8268294: Reusing HttpClient in a WebSocket.Listener hangs.

Reviewed-by: dfuchs
This commit is contained in:
Michael McMahon 2021-06-17 07:13:59 +00:00
parent e84461072a
commit 2d088fa91d
12 changed files with 1066 additions and 12 deletions

View File

@ -49,7 +49,7 @@ import jdk.internal.net.http.common.OperationTrackers.Tracker;
* An HttpClientFacade is a simple class that wraps an HttpClient implementation
* and delegates everything to its implementation delegate.
*/
final class HttpClientFacade extends HttpClient implements Trackable {
public final class HttpClientFacade extends HttpClient implements Trackable {
final HttpClientImpl impl;
@ -110,6 +110,10 @@ final class HttpClientFacade extends HttpClient implements Trackable {
return impl.executor();
}
public Executor theExecutor() {
return impl.theExecutor();
}
@Override
public <T> HttpResponse<T>
send(HttpRequest req, HttpResponse.BodyHandler<T> responseBodyHandler)

View File

@ -60,9 +60,16 @@ class MessageDecoder implements Frame.Consumer {
private long payloadLen;
private long unconsumedPayloadLen;
private ByteBuffer binaryData;
private final boolean server;
private int maskingKey;
MessageDecoder(MessageStreamConsumer output) {
this(output, false);
}
MessageDecoder(MessageStreamConsumer output, boolean server) {
this.output = requireNonNull(output);
this.server = server;
}
/* Exposed for testing purposes */
@ -143,9 +150,12 @@ class MessageDecoder implements Frame.Consumer {
if (debug.on()) {
debug.log("mask %s", value);
}
if (value) {
if (value && !server) {
throw new FailWebSocketException("Masked frame received");
}
if (!value && server) {
throw new FailWebSocketException("Masked frame expected");
}
}
@Override
@ -175,7 +185,9 @@ class MessageDecoder implements Frame.Consumer {
// So this method (`maskingKey`) is not supposed to be invoked while
// reading a frame that has came from the server. If this method is
// invoked, then it's an error in implementation, thus InternalError
throw new InternalError();
if (!server)
throw new InternalError();
maskingKey = value;
}
@Override
@ -204,10 +216,17 @@ class MessageDecoder implements Frame.Consumer {
boolean last = fin && lastPayloadChunk;
boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
if (!text) {
output.onBinary(data.slice(), last);
ByteBuffer slice = data.slice();
if (server) {
unMask(slice);
}
output.onBinary(slice, last);
data.position(data.limit()); // Consume
} else {
boolean binaryNonEmpty = data.hasRemaining();
if (server) {
unMask(data);
}
CharBuffer textData;
try {
textData = decoder.decode(data, last);
@ -225,6 +244,17 @@ class MessageDecoder implements Frame.Consumer {
}
}
private void unMask(ByteBuffer src) {
int pos = src.position();
int size = src.remaining();
ByteBuffer temp = ByteBuffer.allocate(size);
Frame.Masker.transferMasking(src, temp, maskingKey);
temp.flip();
src.position(pos);
src.put(temp);
src.position(pos).limit(pos+size);
}
@Override
public void endFrame() {
if (debug.on()) {

View File

@ -81,6 +81,15 @@ public class MessageEncoder {
/* Was the previous frame TEXT or a CONTINUATION thereof? */
private boolean previousText;
private boolean closed;
private final boolean server;
MessageEncoder() {
this(false);
}
MessageEncoder(boolean isServer) {
this.server = isServer;
}
/*
* How many bytes of the current message have been already encoded.
@ -369,12 +378,20 @@ public class MessageEncoder {
opcode, fin, payloadLen);
}
headerBuffer.clear();
int mask = maskingKeySource.nextInt();
headerWriter.fin(fin)
// for server setting mask to 0 disables masking (xor)
int mask = this.server ? 0 : maskingKeySource.nextInt();
if (mask == 0) {
headerWriter.fin(fin)
.opcode(opcode)
.payloadLen(payloadLen)
.write(headerBuffer);
} else {
headerWriter.fin(fin)
.opcode(opcode)
.payloadLen(payloadLen)
.mask(mask)
.write(headerBuffer);
}
headerBuffer.flip();
payloadMasker.mask(mask);
}

View File

@ -25,6 +25,7 @@
package jdk.internal.net.http.websocket;
import jdk.internal.net.http.HttpClientFacade;
import jdk.internal.net.http.common.Demand;
import jdk.internal.net.http.common.Log;
import jdk.internal.net.http.common.Logger;
@ -37,6 +38,7 @@ import java.io.IOException;
import java.lang.ref.Reference;
import java.net.ProtocolException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
@ -44,6 +46,7 @@ import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.Executor;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
@ -115,10 +118,12 @@ public final class WebSocketImpl implements WebSocket {
private final SequentialScheduler receiveScheduler
= new SequentialScheduler(new ReceiveTask());
private final Demand demand = new Demand();
private final Executor clientExecutor;
public static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
Function<Result, WebSocket> newWebSocket = r -> {
WebSocket ws = newInstance(b.getUri(),
b.getClient(),
r.subprotocol,
b.getListener(),
r.transport);
@ -140,10 +145,11 @@ public final class WebSocketImpl implements WebSocket {
/* Exposed for testing purposes */
static WebSocketImpl newInstance(URI uri,
HttpClient client,
String subprotocol,
Listener listener,
TransportFactory transport) {
WebSocketImpl ws = new WebSocketImpl(uri, subprotocol, listener, transport);
WebSocketImpl ws = new WebSocketImpl(uri, client, subprotocol, listener, transport);
// This initialisation is outside of the constructor for the sake of
// safe publication of WebSocketImpl.this
ws.signalOpen();
@ -151,10 +157,12 @@ public final class WebSocketImpl implements WebSocket {
}
private WebSocketImpl(URI uri,
HttpClient client,
String subprotocol,
Listener listener,
TransportFactory transportFactory) {
this.uri = requireNonNull(uri);
this.clientExecutor = ((HttpClientFacade)client).theExecutor();
this.subprotocol = requireNonNull(subprotocol);
this.listener = requireNonNull(listener);
// Why 6? 1 sendPing/sendPong + 1 sendText/sendBinary + 1 Close +
@ -356,7 +364,7 @@ public final class WebSocketImpl implements WebSocket {
debug.log("request %s", n);
}
if (demand.increase(n)) {
receiveScheduler.runOrSchedule();
receiveScheduler.runOrSchedule(clientExecutor);
}
}
@ -398,7 +406,7 @@ public final class WebSocketImpl implements WebSocket {
* The assumptions about order is as follows:
*
* - state is never changed more than twice inside the `run` method:
* x --(1)--> IDLE --(2)--> y (otherwise we're loosing events, or
* x --(1)--> IDLE --(2)--> y (otherwise we're losing events, or
* overwriting parts of messages creating a mess since there's no
* queueing)
* - OPEN is always the first state
@ -702,7 +710,7 @@ public final class WebSocketImpl implements WebSocket {
private void signalOpen() {
debug.log("signalOpen");
receiveScheduler.runOrSchedule();
receiveScheduler.runOrSchedule(clientExecutor);
}
private void signalError(Throwable error) {
@ -834,7 +842,7 @@ public final class WebSocketImpl implements WebSocket {
if (currentState == ERROR || currentState == CLOSE) {
break;
} else if (state.compareAndSet(currentState, newState)) {
receiveScheduler.runOrSchedule();
receiveScheduler.runOrSchedule(clientExecutor);
success = true;
break;
}
@ -850,7 +858,7 @@ public final class WebSocketImpl implements WebSocket {
State witness = state.compareAndExchange(expectedState, newState);
boolean success = false;
if (witness == expectedState) {
receiveScheduler.runOrSchedule();
receiveScheduler.runOrSchedule(clientExecutor);
success = true;
} else if (witness != ERROR && witness != CLOSE) {
// This should be the only reason for inability to change the state

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8268294
* @modules java.net.http/jdk.internal.net.http.websocket:open jdk.httpserver
* @run main/othervm
* --add-reads java.net.http=ALL-UNNAMED
* --add-reads java.net.http=jdk.httpserver
* java.net.http/jdk.internal.net.http.websocket.WebSocketAndHttpTest
*/
public final class WebSocketServerDriver { }

View File

@ -0,0 +1,48 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http.websocket;
import java.nio.ByteBuffer;
/**
* No implementation provided for onInit() because that must always be
* implemented by user
*/
abstract class DefaultMessageStreamHandler implements MessageStreamHandler {
public void onText(CharSequence data, boolean last) {}
public void onBinary(ByteBuffer data, boolean last) {}
public void onPing(ByteBuffer data) {}
public void onPong(ByteBuffer data) {}
public void onClose(int statusCode, CharSequence reason) {}
public void onComplete() {}
public void onError(Throwable e) {}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http.websocket;
/**
* WebSocket server listener interface, which is the same as the client API
* in java.net.http. See MessageStreamResponder for how listener methods
* can send response messages back to the client
*
* All MessageStreamConsumer methods must be implemented (plus the handler method
* declared here). DefaultMessageStreamHandler provides empty implementations of all
* that can be extended, except for onInit() which must always be implemented.
*
* void onText(CharSequence data, boolean last);
*
* void onBinary(ByteBuffer data, boolean last);
*
* void onPing(ByteBuffer data);
*
* void onPong(ByteBuffer data);
*
* void onClose(int statusCode, CharSequence reason);
*
* void onComplete();
*
* void onError(Throwable e);
*/
interface MessageStreamHandler extends MessageStreamConsumer {
/**
* called before any of the methods above to supply a
* MessageStreamResponder for any new connection, which can be used to send replies
* sendText(), sendBinary(), sendClose() etc
*/
void onInit(MessageStreamResponder responder);
}

View File

@ -0,0 +1,47 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http.websocket;
import java.io.*;
import java.nio.*;
import java.util.List;
/**
* One of these supplied for each incoming client connection for use
* by user written MessageStreamConsumer.
*/
interface MessageStreamResponder {
public void sendText(CharBuffer src, boolean last) throws IOException;
public void sendBinary(ByteBuffer src, boolean last) throws IOException;
public void sendPing(ByteBuffer src) throws IOException;
public void sendPong(ByteBuffer src) throws IOException;
public void sendClose(int statusCode, CharBuffer reason) throws IOException;
public void close();
}

View File

@ -0,0 +1,113 @@
package jdk.internal.net.http.websocket;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.WebSocket;
import java.util.Optional;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
/**
* This is the client side of the test invoked from WebSocketAndHttpTest:
*
* The two args are the addresses of a (local) Websocket and Http server
*
* The test first sends a request to the WS server and in the listener
* which handles the response, it tries to send a request to the http
* server. This hangs if the listener was invoked from the selector
* manager thread. If invoked from a different thread then the http
* response is received and the response string is mapped to string
* "succeeded"
*/
public class WebSocketAndHttpClient {
public static void main(String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newCachedThreadPool();
HttpClient httpClient = HttpClient.newBuilder().executor(executorService).build();
WebSocketTest wsTest = new WebSocketTest(httpClient, args[0]);
HttpTest httpTest = new HttpTest(httpClient, args[1]);
final CompletableFuture<String> result = new CompletableFuture<>();
wsTest.listen(message -> {
try {
String r = httpTest.getData(message);
result.complete(r);
} catch (Exception e) {
result.completeExceptionally(e);
}
});
wsTest.sendData("TEST_DATA");
System.out.println("Wait for result");
try {
result.join();
System.out.println("Result: success");
} finally {
executorService.shutdownNow();
}
}
static class WebSocketTest {
final HttpClient httpClient;
final String server;
volatile WebSocket webSocket;
WebSocketTest(HttpClient httpClient, String server) {
this.httpClient = httpClient;
this.server = server;
}
public void listen(Consumer<String> consumer) {
URI uri = URI.create(server);
System.out.println("WS API client - Connecting to " + uri.toString());
CompletableFuture<WebSocket> cf = httpClient.newWebSocketBuilder()
.buildAsync(uri, new WebSocket.Listener() {
@Override
public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
System.out.println("WS API client - received data: " + data);
consumer.accept(data.toString());
return null;
}
public void onError(WebSocket webSocket, Throwable error) {
System.out.println("WS API client - error");
error.printStackTrace();
}
});
System.out.println("CF created");
webSocket = cf.join();
System.out.println("Websocket created");
}
void sendData(String data) {
System.out.println("WS API client - sending data via WebSocket: {}" + data);
webSocket.sendText(data, true).join();
}
}
static class HttpTest {
final HttpClient httpClient;
final String baseUrl;
HttpTest(HttpClient httpClient, String baseUrl) {
this.httpClient = httpClient;
this.baseUrl = baseUrl;
}
private String getData(String data) throws Exception {
URI uri = URI.create(baseUrl + "?param=" + data);
HttpRequest request = HttpRequest.newBuilder().GET().uri(uri).build();
System.out.println("Http API Client - send HTTP GET request with parameter {}" + data);
HttpResponse<String> send = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
return send.body();
}
}
}

View File

@ -0,0 +1,75 @@
/*
* Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http.websocket;
import java.net.*;
import java.nio.CharBuffer;
import java.io.*;
import com.sun.net.httpserver.*;
public class WebSocketAndHttpTest {
static class WHandler extends DefaultMessageStreamHandler {
volatile MessageStreamResponder responder;
public void onText(CharSequence data, boolean last) {
System.out.println("onText: " + data);
System.out.println("onText: " + Thread.currentThread());
try {
responder.sendText(CharBuffer.wrap(data), true);
System.out.println("onText: send ok");
} catch (IOException e) {
System.out.println("onText: " + e);
throw new UncheckedIOException(e);
}
}
public void onInit(MessageStreamResponder responder) {
System.out.println("onInit");
this.responder = responder;
}
}
static HttpHandler httpHandler = (ex) -> ex.sendResponseHeaders(200, -1);
public static void main(String[] args) throws Exception {
HttpServer hserver = null;
try {
WebSocketServer server = new WebSocketServer(new WHandler());
server.open();
URI uri = server.getURI();
hserver = HttpServer.create(new InetSocketAddress(0), 4);
hserver.createContext("/", httpHandler);
hserver.start();
int port = hserver.getAddress().getPort();
URI huri = new URI("http://127.0.0.1:" + port + "/foo");
WebSocketAndHttpClient.main(new String[]{uri.toString(), huri.toString()});
} finally {
hserver.stop(0);
}
}
}

View File

@ -0,0 +1,186 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http.websocket;
import java.util.LinkedList;
import java.util.List;
import java.io.*;
import java.nio.*;
import java.nio.channels.*;
public class WebSocketResponder implements MessageStreamResponder {
final MessageStreamConsumer consumer;
final LinkedList<ByteBuffer> queue;
volatile boolean closed = false;
final MessageEncoder encoder;
final MessageDecoder decoder;
static final int BUF_SIZE = 1024;
public WebSocketResponder(MessageStreamConsumer consumer) {
this.consumer = consumer;
this.queue = new LinkedList<>();
this.decoder = new MessageDecoder(consumer, true);
this.encoder = new MessageEncoder(true);
}
// own thread
public void readLoop(SocketChannel chan) throws IOException {
chan.configureBlocking(true);
boolean eof = false;
ByteBuffer buf = ByteBuffer.allocate(8 * 1024);
Frame.Reader reader = new Frame.Reader();
try {
while (!eof) {
int count;
buf.clear();
eof = ((count=chan.read(buf)) == -1);
if (!eof) {
buf.flip();
reader.readFrame(buf, decoder);
}
}
} catch (IOException e) {
if (!closed)
throw e;
}
}
// own thread
public void writeLoop(SocketChannel chan) throws IOException {
// read queue and send data
while (true) {
ByteBuffer buf;
synchronized(queue) {
while (queue.isEmpty()) {
try {
queue.wait();
} catch (InterruptedException e) {
throw new IOException(e);
}
if (queue.isEmpty() && closed) {
chan.close();
return;
}
}
buf = queue.remove(0);
}
chan.write(buf);
}
}
/**
* Public methods below used y MessageStreamHandler to send replies
* to client.
*/
@Override
public void sendText(CharBuffer src, boolean last) throws IOException {
ByteBuffer buf = ByteBuffer.allocate(BUF_SIZE);
LinkedList<ByteBuffer> bufs = new LinkedList<>();
boolean done = false;
do {
buf.clear();
done = encoder.encodeText(src, last, buf);
buf.flip();
bufs.add(buf);
} while (!done);
sendMessage(bufs);
}
@Override
public void sendBinary(ByteBuffer src, boolean last) throws IOException {
ByteBuffer buf = ByteBuffer.allocate(BUF_SIZE);
LinkedList<ByteBuffer> bufs = new LinkedList<>();
boolean done = false;
do {
buf.clear();
done = encoder.encodeBinary(src, last, buf);
buf.flip();
bufs.add(buf);
} while (!done);
sendMessage(bufs);
}
@Override
public void sendPing(ByteBuffer src) throws IOException {
ByteBuffer buf = ByteBuffer.allocate(BUF_SIZE);
LinkedList<ByteBuffer> bufs = new LinkedList<>();
boolean done = false;
do {
buf.clear();
done = encoder.encodePing(src, buf);
buf.flip();
bufs.add(buf);
} while (!done);
sendMessage(bufs);
}
@Override
public void sendPong(ByteBuffer src) throws IOException {
ByteBuffer buf = ByteBuffer.allocate(BUF_SIZE);
LinkedList<ByteBuffer> bufs = new LinkedList<>();
boolean done = false;
do {
buf.clear();
done = encoder.encodePong(src, buf);
buf.flip();
bufs.add(buf);
} while (!done);
sendMessage(bufs);
}
@Override
public void sendClose(int statusCode, CharBuffer reason) throws IOException {
ByteBuffer buf = ByteBuffer.allocate(BUF_SIZE);
LinkedList<ByteBuffer> bufs = new LinkedList<>();
boolean done = false;
do {
buf.clear();
done = encoder.encodeClose(statusCode, reason, buf);
buf.flip();
bufs.add(buf);
} while (!done);
sendMessage(bufs);
close();
}
private void sendMessage(List<ByteBuffer> bufs) throws IOException {
if (closed)
throw new IOException("closed");
synchronized(queue) {
queue.addAll(bufs);
queue.notify();
}
}
@Override
public void close() {
synchronized(queue) {
closed = true;
queue.notify();
}
}
}

View File

@ -0,0 +1,435 @@
/*
* Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http.websocket;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
/**
* WebSocket Server. This is a copy of the DummyWebSocketServer test class
* but which also supports sending and receiving of websocket messages
* using a simple API once the connection has been established
*
* MessageStreamHandler is the "listener" API to be implemented for handling
* incoming messages. MessageStreamResponder is used by that handler to send
* responses back to the client.
*
* Performs simpler version of the WebSocket Opening Handshake over HTTP (i.e.
* no proxying, cookies, etc.) Supports sequential connections, one at a time,
* i.e. in order for a client to connect to the server the previous client must
* disconnect first.
*
* Expected client request:
*
* GET /chat HTTP/1.1
* Host: server.example.com
* Upgrade: websocket
* Connection: Upgrade
* Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
* Origin: http://example.com
* Sec-WebSocket-Protocol: chat, superchat
* Sec-WebSocket-Version: 13
*
* This server response:
*
* HTTP/1.1 101 Switching Protocols
* Upgrade: websocket
* Connection: Upgrade
* Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
* Sec-WebSocket-Protocol: chat
*/
public class WebSocketServer implements Closeable {
private final AtomicBoolean started = new AtomicBoolean();
private final Thread thread;
private volatile ServerSocketChannel ssc;
private volatile InetSocketAddress address;
private ByteBuffer read = ByteBuffer.allocate(16384);
private final CountDownLatch readReady = new CountDownLatch(1);
private final MessageStreamHandler handler;
private final WebSocketResponder responder;
private volatile int receiveBufferSize;
private static class Credentials {
private final String name;
private final String password;
private Credentials(String name, String password) {
this.name = name;
this.password = password;
}
public String name() { return name; }
public String password() { return password; }
}
public WebSocketServer(MessageStreamHandler handler) {
this(handler, defaultMapping(), null, null);
}
public WebSocketServer() {
this(null, defaultMapping(), null, null);
}
public WebSocketServer(String username, String password) {
this(null, defaultMapping(), username, password);
}
public WebSocketServer(MessageStreamHandler handler,
BiFunction<List<String>,Credentials,List<String>> mapping,
String username, String password) {
requireNonNull(mapping);
this.handler = handler;
if (handler == null) {
this.responder = null;
} else {
this.responder = new WebSocketResponder(handler);
handler.onInit(this.responder);
}
Credentials credentials = username != null ?
new Credentials(username, password) : null;
thread = new Thread(() -> {
try {
while (!Thread.currentThread().isInterrupted()) {
err.println("Accepting next connection at: " + ssc);
SocketChannel channel = ssc.accept();
err.println("Accepted: " + channel);
try {
channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
channel.configureBlocking(true);
while (true) {
StringBuilder request = new StringBuilder();
if (!readRequest(channel, request)) {
throw new IOException("Bad request:[" + request + "]");
}
List<String> strings = asList(request.toString().split("\r\n"));
List<String> response = mapping.apply(strings, credentials);
writeResponse(channel, response);
if (response.get(0).startsWith("HTTP/1.1 401")) {
err.println("Sent 401 Authentication response " + channel);
continue;
} else {
serve(channel);
break;
}
}
} catch (IOException e) {
err.println("Error in connection: " + channel + ", " + e);
} finally {
err.println("Closed: " + channel);
close(channel);
readReady.countDown();
}
}
} catch (ClosedByInterruptException ignored) {
} catch (Exception e) {
e.printStackTrace(err);
} finally {
close(ssc);
err.println("Stopped at: " + getURI());
}
});
thread.setName("WebSocketServer");
thread.setDaemon(false);
}
// runs in own thread. Override to implement different behavior
protected void read(SocketChannel ch) throws IOException {
responder.readLoop(ch);
}
// runs in own thread. Override to implement different behavior
protected void write(SocketChannel ch) throws IOException {
responder.writeLoop(ch);
}
protected final void serve(SocketChannel channel)
throws InterruptedException
{
Thread reader = new Thread(() -> {
try {
read(channel);
} catch (IOException ignored) { }
});
Thread writer = new Thread(() -> {
try {
write(channel);
} catch (IOException ignored) { }
});
reader.start();
writer.start();
try {
reader.join();
} finally {
reader.interrupt();
try {
writer.join();
} finally {
writer.interrupt();
}
}
}
public ByteBuffer read() throws InterruptedException {
readReady.await();
return read.duplicate().asReadOnlyBuffer().flip();
}
public void setReceiveBufferSize(int bufsize) {
assert ssc == null : "Must configure before calling open()";
this.receiveBufferSize = bufsize;
}
public void open() throws IOException {
err.println("Starting");
if (!started.compareAndSet(false, true)) {
throw new IllegalStateException("Already started");
}
ssc = ServerSocketChannel.open();
try {
ssc.configureBlocking(true);
var bufsize = receiveBufferSize;
if (bufsize > 0) {
err.printf("Configuring receive buffer size to %d%n", bufsize);
try {
ssc.setOption(StandardSocketOptions.SO_RCVBUF, bufsize);
} catch (IOException x) {
err.printf("Failed to configure receive buffer size to %d%n", bufsize);
}
}
ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
address = (InetSocketAddress) ssc.getLocalAddress();
thread.start();
} catch (IOException e) {
close(ssc);
throw e;
}
err.println("Started at: " + getURI());
}
@Override
public void close() {
err.println("Stopping: " + getURI());
thread.interrupt();
close(ssc);
}
URI getURI() {
if (!started.get()) {
throw new IllegalStateException("Not yet started");
}
return URI.create("ws://localhost:" + address.getPort());
}
private boolean readRequest(SocketChannel channel, StringBuilder request)
throws IOException
{
ByteBuffer buffer = ByteBuffer.allocate(512);
while (channel.read(buffer) != -1) {
// read the complete HTTP request headers, there should be no body
CharBuffer decoded;
buffer.flip();
try {
decoded = ISO_8859_1.newDecoder().decode(buffer);
} catch (CharacterCodingException e) {
throw new UncheckedIOException(e);
}
request.append(decoded);
if (Pattern.compile("\r\n\r\n").matcher(request).find())
return true;
buffer.clear();
}
return false;
}
private void writeResponse(SocketChannel channel, List<String> response)
throws IOException
{
String s = response.stream().collect(Collectors.joining("\r\n"))
+ "\r\n\r\n";
ByteBuffer encoded;
try {
encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
} catch (CharacterCodingException e) {
throw new UncheckedIOException(e);
}
while (encoded.hasRemaining()) {
channel.write(encoded);
}
}
private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
return (request, credentials) -> {
List<String> response = new LinkedList<>();
Iterator<String> iterator = request.iterator();
if (!iterator.hasNext()) {
throw new IllegalStateException("The request is empty");
}
String statusLine = iterator.next();
if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
throw new IllegalStateException
("Unexpected status line: " + request.get(0));
}
response.add("HTTP/1.1 101 Switching Protocols");
Map<String, List<String>> requestHeaders = new HashMap<>();
while (iterator.hasNext()) {
String header = iterator.next();
String[] split = header.split(": ");
if (split.length != 2) {
throw new IllegalStateException
("Unexpected header: " + header
+ ", split=" + Arrays.toString(split));
}
requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
}
if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
throw new IllegalStateException("Subprotocols are not expected");
}
if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
throw new IllegalStateException("Extensions are not expected");
}
expectHeader(requestHeaders, "Connection", "Upgrade");
response.add("Connection: Upgrade");
expectHeader(requestHeaders, "Upgrade", "websocket");
response.add("Upgrade: websocket");
expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
List<String> key = requestHeaders.get("Sec-WebSocket-Key");
if (key == null || key.isEmpty()) {
throw new IllegalStateException("Sec-WebSocket-Key is missing");
}
if (key.size() != 1) {
throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
}
MessageDigest sha1 = null;
try {
sha1 = MessageDigest.getInstance("SHA-1");
} catch (NoSuchAlgorithmException e) {
throw new InternalError(e);
}
String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
sha1.update(x.getBytes(ISO_8859_1));
String v = Base64.getEncoder().encodeToString(sha1.digest());
response.add("Sec-WebSocket-Accept: " + v);
// check authorization credentials, if required by the server
if (credentials != null && !authorized(credentials, requestHeaders)) {
response.clear();
response.add("HTTP/1.1 401 Unauthorized");
response.add("Content-Length: 0");
response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
}
return response;
};
}
// Checks credentials in the request against those allowable by the server.
private static boolean authorized(Credentials credentials,
Map<String,List<String>> requestHeaders) {
List<String> authorization = requestHeaders.get("Authorization");
if (authorization == null)
return false;
if (authorization.size() != 1) {
throw new IllegalStateException("Authorization unexpected count:" + authorization);
}
String header = authorization.get(0);
if (!header.startsWith("Basic "))
throw new IllegalStateException("Authorization not Basic: " + header);
header = header.substring("Basic ".length());
String values = new String(Base64.getDecoder().decode(header), UTF_8);
int sep = values.indexOf(':');
if (sep < 1) {
throw new IllegalStateException("Authorization not colon: " + values);
}
String name = values.substring(0, sep);
String password = values.substring(sep + 1);
if (name.equals(credentials.name()) && password.equals(credentials.password()))
return true;
return false;
}
protected static String expectHeader(Map<String, List<String>> headers,
String name,
String value) {
List<String> v = headers.get(name);
if (v == null) {
throw new IllegalStateException(
format("Expected '%s' header, not present in %s",
name, headers));
}
if (!v.contains(value)) {
throw new IllegalStateException(
format("Expected '%s: %s', actual: '%s: %s'",
name, value, name, v)
);
}
return value;
}
private static void close(AutoCloseable... acs) {
for (AutoCloseable ac : acs) {
try {
ac.close();
} catch (Exception ignored) { }
}
}
}