Skip to content

HTTP upgrade handler for Websockets #5569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 10, 2023
Merged
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,14 +16,31 @@

package io.helidon.nima.tests.integration.websocket.webserver;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

import io.helidon.common.http.Headers;
import io.helidon.common.http.HttpPrologue;
import io.helidon.common.http.WritableHeaders;
import io.helidon.nima.websocket.WsListener;
import io.helidon.nima.websocket.WsSession;
import io.helidon.nima.websocket.WsUpgradeException;
import io.helidon.nima.websocket.webserver.WsUpgradeProvider;

class EchoService implements WsListener {
private final AtomicReference<CloseInfo> closed = new AtomicReference<>();

private volatile String subProtocol;

@Override
public void onOpen(WsSession session) {
String p = session.subProtocol().orElse(null);
if (subProtocol != null && !subProtocol.equals(p)) {
throw new InternalError("Invalid sub-protocol in session");
}
}

@Override
public void receive(WsSession session, String text, boolean last) {
session.send(text, last);
Expand All @@ -34,6 +51,31 @@ public void onClose(WsSession session, int status, String reason) {
closed.set(new CloseInfo(status, reason));
}

@Override
public Optional<Headers> onHttpUpgrade(HttpPrologue prologue, Headers headers) throws WsUpgradeException {
WritableHeaders<?> upgradeHeaders = WritableHeaders.create();
if (headers.contains(WsUpgradeProvider.PROTOCOL)) {
List<String> subProtocols = headers.get(WsUpgradeProvider.PROTOCOL).allValues(true);
if (subProtocols.contains("chat")) {
upgradeHeaders.set(WsUpgradeProvider.PROTOCOL, "chat");
subProtocol = "chat";
} else {
throw new WsUpgradeException("Unable to negotiate WS sub-protocol");
}
} else {
subProtocol = null;
}
if (headers.contains(WsUpgradeProvider.EXTENSIONS)) {
List<String> extensions = headers.get(WsUpgradeProvider.EXTENSIONS).allValues(true);
if (extensions.contains("nima")) {
upgradeHeaders.set(WsUpgradeProvider.EXTENSIONS, "nima");
} else {
throw new WsUpgradeException("Unable to negotiate WS extensions");
}
}
return upgradeHeaders.size() > 0 ? Optional.of(upgradeHeaders) : Optional.empty();
}

void resetClosed() {
closed.set(null);
}
Expand All @@ -43,4 +85,5 @@ CloseInfo closeInfo() {
}

record CloseInfo(int status, String reason) { }

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,8 +81,11 @@ void testOnce() throws Exception {
TestListener listener = new TestListener();

java.net.http.WebSocket ws = client.newWebSocketBuilder()
.subprotocols("chat", "mute")
// .header(EXTENSIONS.defaultCase(), "nima") rejected by client
.buildAsync(URI.create("ws://localhost:" + port + "/echo"), listener)
.get(5, TimeUnit.SECONDS);
assertThat(ws.getSubprotocol(), is("chat")); // negotiated
ws.request(10);

ws.sendText("Hello", true).get(5, TimeUnit.SECONDS);
Expand Down
2 changes: 1 addition & 1 deletion nima/websocket/webserver/pom.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright (c) 2022 Oracle and/or its affiliates.
Copyright (c) 2022, 2023 Oracle and/or its affiliates.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,9 +18,11 @@

import java.lang.System.Logger.Level;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

import io.helidon.common.buffers.BufferData;
import io.helidon.common.buffers.DataReader;
import io.helidon.common.http.Headers;
import io.helidon.common.http.HttpPrologue;
import io.helidon.common.http.WritableHeaders;
import io.helidon.nima.webserver.CloseConnectionException;
Expand All @@ -30,12 +32,15 @@
import io.helidon.nima.websocket.WsListener;
import io.helidon.nima.websocket.WsSession;

import static io.helidon.nima.websocket.webserver.WsUpgradeProvider.PROTOCOL;

class WsConnection implements ServerConnection, WsSession {
private static final System.Logger LOGGER = System.getLogger(WsConnection.class.getName());

private final ConnectionContext ctx;
private final HttpPrologue prologue;
private final WritableHeaders<?> headers;
private final Headers upgradeHeaders;
private final String wsKey;
private final WsListener listener;

Expand All @@ -49,11 +54,13 @@ class WsConnection implements ServerConnection, WsSession {
WsConnection(ConnectionContext ctx,
HttpPrologue prologue,
WritableHeaders<?> headers,
Headers upgradeHeaders,
String wsKey,
WebSocket wsRoute) {
this.ctx = ctx;
this.prologue = prologue;
this.headers = headers;
this.upgradeHeaders = upgradeHeaders;
this.wsKey = wsKey;
this.listener = wsRoute.listener();
this.dataReader = ctx.dataReader();
Expand Down Expand Up @@ -113,6 +120,14 @@ public WsSession abort() {
throw new CloseConnectionException("Aborting from WebSocket");
}

@Override
public Optional<String> subProtocol() {
if (upgradeHeaders != null) {
return upgradeHeaders.first(PROTOCOL);
}
return Optional.empty();
}

private boolean processFrame(ClientFrame frame) {
// TODO listener.onError should be called for errors
BufferData payload = frame.unmasked();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,16 +17,17 @@
package io.helidon.nima.websocket.webserver;

import java.lang.System.Logger.Level;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import io.helidon.common.buffers.BufferData;
import io.helidon.common.buffers.DataWriter;
import io.helidon.common.http.DirectHandler;
import io.helidon.common.http.Headers;
import io.helidon.common.http.Http;
import io.helidon.common.http.Http.Header;
import io.helidon.common.http.Http.HeaderName;
Expand All @@ -36,6 +37,9 @@
import io.helidon.nima.webserver.ConnectionContext;
import io.helidon.nima.webserver.http1.spi.Http1UpgradeProvider;
import io.helidon.nima.webserver.spi.ServerConnection;
import io.helidon.nima.websocket.WsUpgradeException;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
* {@link java.util.ServiceLoader} provider implementation for upgrade from HTTP/1.1 to WebSocket.
Expand All @@ -45,17 +49,22 @@ public class WsUpgradeProvider implements Http1UpgradeProvider {
/**
* Websocket key header name.
*/
protected static final HeaderName WS_KEY = Header.create("Sec-WebSocket-Key");
public static final HeaderName WS_KEY = Header.create("Sec-WebSocket-Key");

/**
* Websocket version header name.
*/
protected static final HeaderName WS_VERSION = Header.create("Sec-WebSocket-Version");
public static final HeaderName WS_VERSION = Header.create("Sec-WebSocket-Version");

/**
* Websocket protocol header name.
*/
public static final HeaderName PROTOCOL = Header.create("Sec-WebSocket-Protocol");

/**
* Websocket protocol header name.
*/
protected static final HeaderName PROTOCOL = Header.create("Sec-WebSocket-Protocol");
public static final HeaderName EXTENSIONS = Header.create("Sec-WebSocket-Extensions");

/**
* Switching response prefix.
Expand All @@ -81,10 +90,11 @@ public class WsUpgradeProvider implements Http1UpgradeProvider {
protected static final Http.HeaderValue SUPPORTED_VERSION_HEADER = Header.create(WS_VERSION, SUPPORTED_VERSION);

private static final System.Logger LOGGER = System.getLogger(WsUpgradeProvider.class.getName());
private static final byte[] KEY_SUFFIX = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(StandardCharsets.US_ASCII);
private static final byte[] KEY_SUFFIX = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(US_ASCII);
private static final int KEY_SUFFIX_LENGTH = KEY_SUFFIX.length;
private static final Base64.Decoder B64_DECODER = Base64.getDecoder();
private static final Base64.Encoder B64_ENCODER = Base64.getEncoder();
private static final byte[] HEADERS_SEPARATOR = "\r\n".getBytes(US_ASCII);

private final Set<String> origins;
private final boolean anyOrigin;
Expand Down Expand Up @@ -160,17 +170,33 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr
}
}

// todo support subprotocols (must be provided by route)
// Sec-WebSocket-Protocol: sub-protocol (list provided in PROTOCOL header, separated by comma space
// invoke user-provided HTTP upgrade handler
Optional<Headers> upgradeHeaders;
try {
upgradeHeaders = route.listener().onHttpUpgrade(prologue, headers);
} catch (WsUpgradeException e) {
LOGGER.log(Level.TRACE, "Websocket upgrade rejected", e);
return null;
}

// write switch protocol response including headers from listener
DataWriter dataWriter = ctx.dataWriter();
String switchingProtocols = SWITCHING_PROTOCOL_PREFIX + hash(ctx, wsKey) + SWITCHING_PROTOCOLS_SUFFIX;
dataWriter.write(BufferData.create(switchingProtocols.getBytes(StandardCharsets.US_ASCII)));
String switchingProtocols = SWITCHING_PROTOCOL_PREFIX + hash(ctx, wsKey);
dataWriter.write(BufferData.create(switchingProtocols.getBytes(US_ASCII)));
BufferData separator = BufferData.create(HEADERS_SEPARATOR);
dataWriter.write(separator);
upgradeHeaders.ifPresent(hs -> {
BufferData headerData = BufferData.growing(128);
hs.forEach(h -> h.writeHttp1Header(headerData));
dataWriter.write(headerData);
});
dataWriter.write(separator.rewind());

if (LOGGER.isLoggable(Level.TRACE)) {
LOGGER.log(Level.TRACE, "Upgraded to websocket version " + version);
}

return new WsConnection(ctx, prologue, headers, wsKey, route);
return new WsConnection(ctx, prologue, headers, upgradeHeaders.orElse(null), wsKey, route);
}

protected boolean anyOrigin() {
Expand All @@ -197,7 +223,7 @@ protected String hash(ConnectionContext ctx, String wsKey) {
.message("Invalid Sec-WebSocket-Key header")
.build();
}
byte[] wsKeyBytes = wsKey.getBytes(StandardCharsets.US_ASCII);
byte[] wsKeyBytes = wsKey.getBytes(US_ASCII);
int wsKeyBytesLength = wsKeyBytes.length;
byte[] toHash = new byte[wsKeyBytesLength + KEY_SUFFIX_LENGTH];
System.arraycopy(wsKeyBytes, 0, toHash, 0, wsKeyBytesLength);
Expand Down
6 changes: 5 additions & 1 deletion nima/websocket/websocket/pom.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright (c) 2022 Oracle and/or its affiliates.
Copyright (c) 2022, 2023 Oracle and/or its affiliates.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,10 @@
<artifactId>helidon-common-buffers</artifactId>
</dependency>
<dependency>
<groupId>io.helidon.common</groupId>
<artifactId>helidon-common-http</artifactId>
</dependency>
<dependency>
<groupId>io.helidon.common.features</groupId>
<artifactId>helidon-common-features-api</artifactId>
<scope>provided</scope>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,11 @@

package io.helidon.nima.websocket;

import java.util.Optional;

import io.helidon.common.buffers.BufferData;
import io.helidon.common.http.Headers;
import io.helidon.common.http.HttpPrologue;

/**
* WebSocket listener.
Expand Down Expand Up @@ -49,7 +53,6 @@ default void receive(WsSession session, BufferData buffer, boolean last) {
* @param buffer buffer with data
*/
default void onPing(WsSession session, BufferData buffer) {

}

/**
Expand All @@ -59,7 +62,6 @@ default void onPing(WsSession session, BufferData buffer) {
* @param buffer buffer with data
*/
default void onPong(WsSession session, BufferData buffer) {

}

/**
Expand All @@ -70,7 +72,6 @@ default void onPong(WsSession session, BufferData buffer) {
* @param reason reason of close
*/
default void onClose(WsSession session, int status, String reason) {

}

/**
Expand All @@ -80,7 +81,6 @@ default void onClose(WsSession session, int status, String reason) {
* @param t throwable caught
*/
default void onError(WsSession session, Throwable t) {

}

/**
Expand All @@ -89,6 +89,18 @@ default void onError(WsSession session, Throwable t) {
* @param session WebSocket session
*/
default void onOpen(WsSession session) {
}

/**
* Invoked during handshake process. Can be used to negotiate sub-protocols and/or
* reject an upgrade by throwing {@link WsUpgradeException}.
*
* @param prologue the http handshake request
* @param headers headers in request
* @return headers to be included in handshake response
* @throws WsUpgradeException if handshake is rejected
*/
default Optional<Headers> onHttpUpgrade(HttpPrologue prologue, Headers headers) throws WsUpgradeException {
return Optional.empty();
}
}
Loading