8349662: SSLTube SSLSubscriptionWrapper has potential races when switching subscriptions

Reviewed-by: jpai
This commit is contained in:
Daniel Fuchs 2025-02-11 12:10:05 +00:00
parent 964dd18fd2
commit 5ee44c1688
3 changed files with 130 additions and 51 deletions

View File

@ -119,7 +119,7 @@ public class SSLTube implements FlowTube {
// Connect the read sink first. That's the left-hand side
// downstream subscriber from the HttpConnection (or more
// accurately, the SSLSubscriberWrapper that will wrap it
// when SSLTube::connectFlows is called.
// when SSLTube::connectFlows is called).
reader.subscribe(downReader);
// Connect the right hand side tube (the socket tube).
@ -191,7 +191,7 @@ public class SSLTube implements FlowTube {
private volatile Flow.Subscription readSubscription;
// The DelegateWrapper wraps a subscribed {@code Flow.Subscriber} and
// tracks the subscriber's state. In particular it makes sure that
// tracks the subscriber's state. In particular, it makes sure that
// onComplete/onError are not called before onSubscribed.
static final class DelegateWrapper implements FlowTube.TubeSubscriber {
private final FlowTube.TubeSubscriber delegate;
@ -302,7 +302,7 @@ public class SSLTube implements FlowTube {
// Used to read data from the SSLTube.
final class SSLSubscriberWrapper implements FlowTube.TubeSubscriber {
private AtomicReference<DelegateWrapper> pendingDelegate =
private final AtomicReference<DelegateWrapper> pendingDelegate =
new AtomicReference<>();
private volatile DelegateWrapper subscribed;
private volatile boolean onCompleteReceived;
@ -353,15 +353,15 @@ public class SSLTube implements FlowTube {
return;
}
// sslDelegate field should have been initialized by the
// the time we reach here, as there can be no subscriber
// time we reach here, as there can be no subscriber
// until SSLTube is fully constructed.
if (handleNow || !sslDelegate.resumeReader()) {
processPendingSubscriber();
}
}
// Can be called outside of the flow if an error has already been
// raise. Otherwise, must be called within the SSLFlowDelegate
// Can be called outside the flow if an error has already been
// raised. Otherwise, must be called within the SSLFlowDelegate
// downstream reader flow.
// If there is a subscription, and if there is a pending delegate,
// calls dropSubscription() on the previous delegate (if any),
@ -619,32 +619,56 @@ public class SSLTube implements FlowTube {
private volatile boolean cancelled;
void setSubscription(Flow.Subscription sub) {
long demand = writeDemand.get(); // FIXME: isn't it a racy way of passing the demand?
delegate = sub;
if (debug.on())
debug.log("setSubscription: demand=%d, cancelled:%s", demand, cancelled);
long demand;
// Avoid race condition and requesting demand twice if
// request() runs concurrently with setSubscription()
boolean cancelled;
synchronized (this) {
demand = writeDemand.get();
delegate = sub;
cancelled = this.cancelled;
}
if (debug.on()) {
debug.log("setSubscription: demand=%d, cancelled:%s, new subscription %s",
demand, cancelled, sub);
}
if (cancelled)
delegate.cancel();
sub.cancel();
else if (demand > 0)
sub.request(demand);
}
@Override
public void request(long n) {
writeDemand.increase(n);
if (debug.on()) debug.log("request: n=%d", n);
Flow.Subscription sub = delegate;
// Avoid race condition and requesting demand twice if
// request() runs concurrently with setSubscription()
Flow.Subscription sub;
long demanded;
synchronized (this) {
sub = delegate;
demanded = writeDemand.get();
writeDemand.increase(n);
}
if (debug.on()) {
debug.log("request: n=%s to %s (%s already demanded)",
n, sub, demanded);
}
if (sub != null && n > 0) {
if (debug.on()) debug.log("requesting %s from %s", n, sub);
sub.request(n);
}
}
@Override
public void cancel() {
cancelled = true;
if (delegate != null)
delegate.cancel();
Flow.Subscription sub;
synchronized (this) {
cancelled = true;
sub = delegate;
}
if (debug.on()) debug.log("cancel: cancelling subscription: " + sub);
if (sub != null) sub.cancel();
}
}
@ -652,10 +676,16 @@ public class SSLTube implements FlowTube {
@Override
public void onSubscribe(Flow.Subscription subscription) {
Objects.requireNonNull(subscription);
Flow.Subscription x = writeSubscription.delegate;
if (x != null)
x.cancel();
Flow.Subscription old;
synchronized (this) {
old = writeSubscription.delegate;
}
if (old != null && old != subscription) {
if (debug.on()) debug.log("onSubscribe: cancelling old subscription: " + old);
old.cancel();
}
if (debug.on()) debug.log("onSubscribe: new subscription: " + subscription);
writeSubscription.setSubscription(subscription);
}
@ -664,8 +694,10 @@ public class SSLTube implements FlowTube {
Objects.requireNonNull(item);
boolean decremented = writeDemand.tryDecrement();
assert decremented : "Unexpected writeDemand: ";
if (debug.on())
debug.log("sending %d buffers to SSL flow delegate", item.size());
if (debug.on()) {
debug.log("sending %s buffers to SSL flow delegate (%s bytes)",
item.size(), Utils.remaining(item));
}
sslDelegate.upstreamWriter().onNext(item);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -33,9 +33,6 @@
* CookieHeaderTest
*/
import com.sun.net.httpserver.HttpServer;
import com.sun.net.httpserver.HttpsConfigurator;
import com.sun.net.httpserver.HttpsServer;
import jdk.test.lib.net.SimpleSSLContext;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
@ -51,7 +48,6 @@ import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.net.CookieHandler;
import java.net.CookieManager;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
@ -65,7 +61,6 @@ import java.net.http.HttpResponse.BodyHandlers;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@ -76,7 +71,6 @@ import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import jdk.httpclient.test.lib.common.HttpServerAdapters;
import jdk.httpclient.test.lib.http2.Http2TestServer;
import static java.lang.System.out;
import static java.net.http.HttpClient.Version.HTTP_1_1;

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -58,6 +58,7 @@ import sun.net.NetProperties;
import sun.net.www.HeaderParser;
import static java.lang.System.out;
import static java.lang.System.err;
import static java.lang.String.format;
/**
@ -304,6 +305,9 @@ public class DigestEchoClient {
} catch(Throwable t) {
out.println(DigestEchoServer.now()
+ ": Unexpected exception: " + t);
t.printStackTrace(System.out);
err.println(DigestEchoServer.now()
+ ": Unexpected exception: " + t);
t.printStackTrace();
failed = t;
throw t;
@ -393,15 +397,21 @@ public class DigestEchoClient {
HttpResponse<String> r;
CompletableFuture<HttpResponse<String>> cf1;
String auth = null;
Throwable failed = null;
URI reqURI = null;
try {
for (int i=0; i<data.length; i++) {
for (int i = 0; i < data.length; i++) {
out.println(DigestEchoServer.now() + " ----- iteration " + i + " -----");
List<String> lines = List.of(Arrays.copyOfRange(data, 0, i+1));
List<String> lines = List.of(Arrays.copyOfRange(data, 0, i + 1));
assert lines.size() == i + 1;
String body = lines.stream().collect(Collectors.joining("\r\n"));
BodyPublisher reqBody = BodyPublishers.ofString(body);
HttpRequest.Builder builder = HttpRequest.newBuilder(uri).version(clientVersion)
URI baseReq = URI.create(uri + "?iteration=" + i + ",async=" + async
+ ",addHeaders=" + addHeaders + ",preemptive=" + preemptive
+ ",expectContinue=" + expectContinue + ",version=" + clientVersion);
reqURI = URI.create(baseReq + ",basicCount=" + basicCount.get());
HttpRequest.Builder builder = HttpRequest.newBuilder(reqURI).version(clientVersion)
.POST(reqBody).expectContinue(expectContinue);
boolean isTunnel = isProxy(authType) && useSSL;
if (addHeaders) {
@ -433,8 +443,10 @@ public class DigestEchoClient {
HttpResponse<Stream<String>> resp;
try {
if (async) {
out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request);
resp = client.sendAsync(request, BodyHandlers.ofLines()).join();
} else {
out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request);
resp = client.send(request, BodyHandlers.ofLines());
}
} catch (Throwable t) {
@ -443,17 +455,10 @@ public class DigestEchoClient {
long n = basicCount.getAndIncrement();
basics.set((basics.get() * n + (stop - start)) / (n + 1));
}
// unwrap CompletionException
if (t instanceof CompletionException) {
assert t.getCause() != null;
t = t.getCause();
}
out.println(DigestEchoServer.now()
+ ": Unexpected exception: " + t);
throw new RuntimeException("Unexpected exception: " + t, t);
throw t;
}
if (addHeaders && !preemptive && (i==0 || isSchemeDisabled())) {
if (addHeaders && !preemptive && (i == 0 || isSchemeDisabled())) {
assert resp.statusCode() == 401 || resp.statusCode() == 407;
Stream<String> respBody = resp.body();
if (respBody != null) {
@ -462,11 +467,15 @@ public class DigestEchoClient {
}
System.out.println(String.format("%s received: adding header %s: %s",
resp.statusCode(), authorizationKey(authType), auth));
request = HttpRequest.newBuilder(uri).version(clientVersion)
reqURI = URI.create(baseReq + ",withAuthorization="
+ authType + ",basicCount=" + basicCount.get());
request = HttpRequest.newBuilder(reqURI).version(clientVersion)
.POST(reqBody).header(authorizationKey(authType), auth).build();
if (async) {
out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request);
resp = client.sendAsync(request, BodyHandlers.ofLines()).join();
} else {
out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request);
resp = client.send(request, BodyHandlers.ofLines());
}
}
@ -500,6 +509,15 @@ public class DigestEchoClient {
throw new RuntimeException("Unexpected response: " + respLines);
}
}
} catch (Throwable t) {
if (reqURI == null) {
failed = t;
throw t;
}
String decoration = "%s Unexpected exception %s for %s".formatted(DigestEchoServer.now(), t, reqURI);
RuntimeException decorated = new RuntimeException(decoration, t);
failed = decorated;
throw decorated;
} finally {
client = null;
System.gc();
@ -508,7 +526,10 @@ public class DigestEchoClient {
if (queue.remove(100) == ref) break;
}
var error = TRACKER.checkShutdown(900);
if (error != null) throw error;
if (error != null) {
if (failed != null) error.addSuppressed(failed);
throw error;
}
}
System.out.println("OK");
}
@ -546,16 +567,22 @@ public class DigestEchoClient {
byte[] cnonce = new byte[16];
String cnonceStr = null;
DigestEchoServer.DigestResponse challenge = null;
ReferenceQueue<HttpClient> queue = new ReferenceQueue<>();
WeakReference<HttpClient> ref = new WeakReference<>(client, queue);
URI reqURI = null;
Throwable failed = null;
try {
for (int i=0; i<data.length; i++) {
for (int i = 0; i < data.length; i++) {
out.println(DigestEchoServer.now() + "----- iteration " + i + " -----");
List<String> lines = List.of(Arrays.copyOfRange(data, 0, i+1));
List<String> lines = List.of(Arrays.copyOfRange(data, 0, i + 1));
assert lines.size() == i + 1;
String body = lines.stream().collect(Collectors.joining("\r\n"));
HttpRequest.BodyPublisher reqBody = HttpRequest.BodyPublishers.ofString(body);
URI baseReq = URI.create(uri + "?iteration=" + i + ",async=" + async
+ ",expectContinue=" + expectContinue + ",version=" + clientVersion);
reqURI = URI.create(baseReq + ",digestCount=" + digestCount.get());
HttpRequest.Builder reqBuilder = HttpRequest
.newBuilder(uri).version(clientVersion).POST(reqBody)
.newBuilder(reqURI).version(clientVersion).POST(reqBody)
.expectContinue(expectContinue);
boolean isTunnel = isProxy(authType) && useSSL;
@ -578,8 +605,10 @@ public class DigestEchoClient {
HttpRequest request = reqBuilder.build();
HttpResponse<Stream<String>> resp;
if (async) {
out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request);
resp = client.sendAsync(request, BodyHandlers.ofLines()).join();
} else {
out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request);
resp = client.send(request, BodyHandlers.ofLines());
}
System.out.println(resp);
@ -609,16 +638,18 @@ public class DigestEchoClient {
challenge = DigestEchoServer.DigestResponse
.create(authenticate.substring("Digest ".length()));
String auth = digestResponse(uri, digestMethod, challenge, cnonceStr);
reqURI = URI.create(baseReq + ",withAuth=" + authType + ",digestCount=" + digestCount.get());
try {
request = HttpRequest.newBuilder(uri).version(clientVersion)
.POST(reqBody).header(authorizationKey(authType), auth).build();
request = HttpRequest.newBuilder(reqURI).version(clientVersion)
.POST(reqBody).header(authorizationKey(authType), auth).build();
} catch (IllegalArgumentException x) {
throw x;
}
if (async) {
out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request);
resp = client.sendAsync(request, BodyHandlers.ofLines()).join();
} else {
out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request);
resp = client.send(request, BodyHandlers.ofLines());
}
System.out.println(resp);
@ -649,7 +680,29 @@ public class DigestEchoClient {
throw new RuntimeException("Unexpected response: " + respLines);
}
}
} catch (Throwable t) {
if (reqURI == null) {
failed = t;
throw t;
}
String decoration = "%s Unexpected exception %s for %s".formatted(DigestEchoServer.now(), t, reqURI);
RuntimeException decorated = new RuntimeException(decoration, t);
failed = decorated;
throw decorated;
} finally {
client = null;
System.gc();
while (!ref.refersTo(null)) {
System.gc();
if (queue.remove(100) == ref) break;
}
var error = TRACKER.checkShutdown(900);
if (error != null) {
if (failed != null) {
error.addSuppressed(failed);
}
throw error;
}
}
System.out.println("OK");
}