8328919: Add BodyHandlers / BodySubscribers methods to handle excessive server input

Reviewed-by: jpai
This commit is contained in:
Volkan Yazici 2025-01-28 10:22:55 +00:00 committed by Jaikiran Pai
parent 1f74caa7da
commit 558255ae70
3 changed files with 620 additions and 1 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -51,6 +51,7 @@ import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.net.ssl.SSLSession;
import jdk.internal.net.http.BufferingSubscriber;
import jdk.internal.net.http.LimitingSubscriber;
import jdk.internal.net.http.LineSubscriberAdapter;
import jdk.internal.net.http.ResponseBodyHandlers.FileDownloadBodyHandler;
import jdk.internal.net.http.ResponseBodyHandlers.PathBodyHandler;
@ -748,6 +749,33 @@ public interface HttpResponse<T> {
.buffering(downstreamHandler.apply(responseInfo),
bufferSize);
}
/**
* {@return a {@code BodyHandler} that limits the number of body bytes
* that are delivered to the given {@code downstreamHandler}}
* <p>
* If the number of body bytes received exceeds the given
* {@code capacity}, {@link BodySubscriber#onError(Throwable) onError}
* is called on the downstream {@code BodySubscriber} with an
* {@link IOException} indicating that the capacity is exceeded, and
* the upstream subscription is cancelled.
*
* @param downstreamHandler the downstream handler to pass received data to
* @param capacity the maximum number of bytes that are allowed
* @throws IllegalArgumentException if {@code capacity} is negative
* @since 25
*/
public static <T> BodyHandler<T> limiting(BodyHandler<T> downstreamHandler, long capacity) {
Objects.requireNonNull(downstreamHandler, "downstreamHandler");
if (capacity < 0) {
throw new IllegalArgumentException("capacity must not be negative: " + capacity);
}
return responseInfo -> {
BodySubscriber<T> downstreamSubscriber = downstreamHandler.apply(responseInfo);
return BodySubscribers.limiting(downstreamSubscriber, capacity);
};
}
}
/**
@ -1350,5 +1378,30 @@ public interface HttpResponse<T> {
{
return new ResponseSubscribers.MappingSubscriber<>(upstream, mapper);
}
/**
* {@return a {@code BodySubscriber} that limits the number of body
* bytes that are delivered to the given {@code downstreamSubscriber}}
* <p>
* If the number of body bytes received exceeds the given
* {@code capacity}, {@link BodySubscriber#onError(Throwable) onError}
* is called on the downstream {@code BodySubscriber} with an
* {@link IOException} indicating that the capacity is exceeded, and
* the upstream subscription is cancelled.
*
* @param downstreamSubscriber the downstream subscriber to pass received data to
* @param capacity the maximum number of bytes that are allowed
* @throws IllegalArgumentException if {@code capacity} is negative
* @since 25
*/
public static <T> BodySubscriber<T> limiting(BodySubscriber<T> downstreamSubscriber, long capacity) {
Objects.requireNonNull(downstreamSubscriber, "downstreamSubscriber");
if (capacity < 0) {
throw new IllegalArgumentException("capacity must not be negative: " + capacity);
}
return new LimitingSubscriber<>(downstreamSubscriber, capacity);
}
}
}

View File

@ -0,0 +1,148 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.net.http;
import jdk.internal.net.http.ResponseSubscribers.TrustedSubscriber;
import jdk.internal.net.http.common.Utils;
import java.io.IOException;
import java.net.http.HttpResponse.BodySubscriber;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Flow.Subscription;
import static java.util.Objects.requireNonNull;
/**
* A subscriber limiting the maximum number of bytes that are allowed to be consumed by a downstream subscriber.
*
* @param <T> the response type
*/
public final class LimitingSubscriber<T> implements TrustedSubscriber<T> {
private final BodySubscriber<T> downstreamSubscriber;
private final long capacity;
private State state;
private long length;
private interface State {
State TERMINATED = new State() {};
record Subscribed(Subscription subscription) implements State {}
}
/**
* @param downstreamSubscriber the downstream subscriber to pass received data to
* @param capacity the maximum number of bytes that are allowed
* @throws IllegalArgumentException if {@code capacity} is negative
*/
public LimitingSubscriber(BodySubscriber<T> downstreamSubscriber, long capacity) {
if (capacity < 0) {
throw new IllegalArgumentException("capacity must not be negative: " + capacity);
}
this.downstreamSubscriber = requireNonNull(downstreamSubscriber, "downstreamSubscriber");
this.capacity = capacity;
}
@Override
public void onSubscribe(Subscription subscription) {
requireNonNull(subscription, "subscription");
if (state != null) {
subscription.cancel();
} else {
state = new State.Subscribed(subscription);
downstreamSubscriber.onSubscribe(subscription);
}
}
@Override
public void onNext(List<ByteBuffer> buffers) {
// Check arguments
requireNonNull(buffers, "buffers");
assert Utils.hasRemaining(buffers);
// Short-circuit if not subscribed
if (!(state instanceof State.Subscribed subscribed)) {
return;
}
// See if we may consume the input
boolean lengthAllocated = allocateLength(buffers);
if (lengthAllocated) {
downstreamSubscriber.onNext(buffers);
} else { // Otherwise, trigger failure
state = State.TERMINATED;
downstreamSubscriber.onError(new IOException("body exceeds capacity: " + capacity));
subscribed.subscription.cancel();
}
}
private boolean allocateLength(List<ByteBuffer> buffers) {
long bufferLength = Utils.remaining(buffers);
long nextLength;
try {
nextLength = Math.addExact(length, bufferLength);
} catch (ArithmeticException _) {
return false;
}
if (nextLength > capacity) {
return false;
}
length = nextLength;
return true;
}
@Override
public void onError(Throwable throwable) {
requireNonNull(throwable, "throwable");
if (state instanceof State.Subscribed) {
state = State.TERMINATED;
downstreamSubscriber.onError(throwable);
}
}
@Override
public void onComplete() {
if (state instanceof State.Subscribed) {
state = State.TERMINATED;
downstreamSubscriber.onComplete();
}
}
@Override
public CompletionStage<T> getBody() {
return downstreamSubscriber.getBody();
}
}

View File

@ -0,0 +1,418 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8328919
* @summary tests `limiting()` in `HttpResponse.Body{Handlers,Subscribers}`
* @key randomness
* @library /test/lib
* /test/jdk/java/net/httpclient/lib
* @build jdk.httpclient.test.lib.common.HttpServerAdapters
* jdk.test.lib.RandomFactory
* jdk.test.lib.net.SimpleSSLContext
* @run junit HttpResponseLimitingTest
*/
import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer;
import jdk.test.lib.RandomFactory;
import jdk.test.lib.net.SimpleSSLContext;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandler;
import java.net.http.HttpResponse.BodyHandlers;
import java.net.http.HttpResponse.BodySubscriber;
import java.net.http.HttpResponse.BodySubscribers;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Flow.Subscription;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static java.net.http.HttpClient.Builder.NO_PROXY;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.copyOfRange;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
class HttpResponseLimitingTest {
private static final Random RANDOM = RandomFactory.getRandom();
private static final byte[] RESPONSE_BODY = "random non-empty body".getBytes(UTF_8);
private static final String RESPONSE_HEADER_NAME = "X-Excessive-Data";
/**
* A header value larger than {@link #RESPONSE_BODY} to verify that {@code limiting()} doesn't affect header parsing.
*/
private static final String RESPONSE_HEADER_VALUE = "!".repeat(RESPONSE_BODY.length + 1);
private static final ServerClientPair HTTP1 = ServerClientPair.of(HttpClient.Version.HTTP_1_1, false);
private static final ServerClientPair HTTPS1 = ServerClientPair.of(HttpClient.Version.HTTP_1_1, true);
private static final ServerClientPair HTTP2 = ServerClientPair.of(HttpClient.Version.HTTP_2, false);
private static final ServerClientPair HTTPS2 = ServerClientPair.of(HttpClient.Version.HTTP_2, true);
private record ServerClientPair(HttpTestServer server, HttpClient client, HttpRequest request) {
private static final SSLContext SSL_CONTEXT = createSslContext();
private static SSLContext createSslContext() {
try {
return new SimpleSSLContext().get();
} catch (IOException exception) {
throw new UncheckedIOException(exception);
}
}
private ServerClientPair {
try {
server.start();
} catch (Exception serverException) {
try {
client.close();
} catch (Exception clientException) {
Exception localClientException = new RuntimeException("failed closing client", clientException);
serverException.addSuppressed(localClientException);
}
throw new RuntimeException("failed closing server", serverException);
}
}
private static ServerClientPair of(HttpClient.Version version, boolean secure) {
// Create the server and the request URI
SSLContext sslContext = secure ? SSL_CONTEXT : null;
HttpTestServer server = createServer(version, sslContext);
String handlerPath = "/";
String requestUriScheme = secure ? "https" : "http";
URI requestUri = URI.create(requestUriScheme + "://" + server.serverAuthority() + handlerPath);
// Register the request handler
server.addHandler(
(exchange) -> {
exchange.getResponseHeaders().addHeader(RESPONSE_HEADER_NAME, RESPONSE_HEADER_VALUE);
exchange.sendResponseHeaders(200, RESPONSE_BODY.length);
try (var outputStream = exchange.getResponseBody()) {
outputStream.write(RESPONSE_BODY);
}
exchange.close();
},
handlerPath);
// Create the client and the request
HttpClient client = createClient(version, sslContext);
HttpRequest request = HttpRequest.newBuilder(requestUri).version(version).build();
// Create the pair
return new ServerClientPair(server, client, request);
}
private static HttpTestServer createServer(HttpClient.Version version, SSLContext sslContext) {
try {
return HttpTestServer.create(version, sslContext);
} catch (IOException exception) {
throw new UncheckedIOException(exception);
}
}
private static HttpClient createClient(HttpClient.Version version, SSLContext sslContext) {
HttpClient.Builder builder = HttpClient.newBuilder().version(version).proxy(NO_PROXY);
if (sslContext != null) {
builder.sslContext(sslContext);
}
return builder.build();
}
private <T> HttpResponse<T> request(BodyHandler<T> downstreamHandler, long capacity) throws Exception {
var handler = BodyHandlers.limiting(downstreamHandler, capacity);
return client.send(request, handler);
}
@Override
public String toString() {
String version = client.version().toString();
return client.sslContext() != null ? version.replaceFirst("_", "S_") : version;
}
}
@AfterAll
static void closeServerClientPairs() {
Exception[] exceptionRef = {null};
Stream
.of(HTTP1, HTTPS1, HTTP2, HTTPS2)
.flatMap(pair -> Stream.<Runnable>of(
pair.client::close,
pair.server::stop))
.forEach(closer -> {
try {
closer.run();
} catch (Exception exception) {
if (exceptionRef[0] == null) {
exceptionRef[0] = exception;
} else {
exceptionRef[0].addSuppressed(exception);
}
}
});
if (exceptionRef[0] != null) {
throw new RuntimeException("failed closing one or more server-client pairs", exceptionRef[0]);
}
}
@ParameterizedTest
@MethodSource("sufficientCapacities")
void testSuccessOnSufficientCapacityForByteArray(ServerClientPair pair, long sufficientCapacity) throws Exception {
HttpResponse<byte[]> response = pair.request(BodyHandlers.ofByteArray(), sufficientCapacity);
verifyHeaders(response.headers());
assertArrayEquals(RESPONSE_BODY, response.body());
}
@ParameterizedTest
@MethodSource("sufficientCapacities")
void testSuccessOnSufficientCapacityForInputStream(ServerClientPair pair, long sufficientCapacity) throws Exception {
HttpResponse<InputStream> response = pair.request(BodyHandlers.ofInputStream(), sufficientCapacity);
verifyHeaders(response.headers());
try (InputStream responseBodyStream = response.body()) {
byte[] responseBodyBuffer = responseBodyStream.readAllBytes();
assertArrayEquals(RESPONSE_BODY, responseBodyBuffer);
}
}
static Arguments[] sufficientCapacities() {
long minExtremeCapacity = RESPONSE_BODY.length;
long maxExtremeCapacity = Long.MAX_VALUE;
long nonExtremeCapacity = RANDOM.nextLong(minExtremeCapacity + 1, maxExtremeCapacity);
return capacityArgs(minExtremeCapacity, nonExtremeCapacity, maxExtremeCapacity);
}
@ParameterizedTest
@MethodSource("insufficientCapacities")
void testFailureOnInsufficientCapacityForByteArray(ServerClientPair pair, long insufficientCapacity) {
BodyHandler<byte[]> handler = responseInfo -> {
verifyHeaders(responseInfo.headers());
return BodySubscribers.limiting(BodySubscribers.ofByteArray(), insufficientCapacity);
};
var exception = assertThrows(IOException.class, () -> pair.request(handler, insufficientCapacity));
assertEquals(exception.getMessage(), "body exceeds capacity: " + insufficientCapacity);
}
@ParameterizedTest
@MethodSource("insufficientCapacities")
void testFailureOnInsufficientCapacityForInputStream(ServerClientPair pair, long insufficientCapacity) throws Exception {
HttpResponse<InputStream> response = pair.request(BodyHandlers.ofInputStream(), insufficientCapacity);
verifyHeaders(response.headers());
try (InputStream responseBodyStream = response.body()) {
var exception = assertThrows(IOException.class, responseBodyStream::readAllBytes);
assertNotNull(exception.getCause());
assertEquals(exception.getCause().getMessage(), "body exceeds capacity: " + insufficientCapacity);
}
}
static Arguments[] insufficientCapacities() {
long minExtremeCapacity = 0;
long maxExtremeCapacity = RESPONSE_BODY.length - 1;
long nonExtremeCapacity = RANDOM.nextLong(minExtremeCapacity + 1, maxExtremeCapacity);
return capacityArgs(minExtremeCapacity, nonExtremeCapacity, maxExtremeCapacity);
}
private static void verifyHeaders(HttpHeaders responseHeaders) {
List<String> responseHeaderValues = responseHeaders.allValues(RESPONSE_HEADER_NAME);
assertEquals(List.of(RESPONSE_HEADER_VALUE), responseHeaderValues);
}
@ParameterizedTest
@MethodSource("invalidCapacities")
void testFailureOnInvalidCapacityForHandler(long invalidCapacity) {
var exception = assertThrows(
IllegalArgumentException.class,
() -> BodyHandlers.limiting(BodyHandlers.ofByteArray(), invalidCapacity));
assertEquals(exception.getMessage(), "capacity must not be negative: " + invalidCapacity);
}
@ParameterizedTest
@MethodSource("invalidCapacities")
void testFailureOnInvalidCapacityForSubscriber(long invalidCapacity) {
var exception = assertThrows(
IllegalArgumentException.class,
() -> BodySubscribers.limiting(BodySubscribers.ofByteArray(), invalidCapacity));
assertEquals(exception.getMessage(), "capacity must not be negative: " + invalidCapacity);
}
static long[] invalidCapacities() {
long minExtremeCapacity = Long.MIN_VALUE;
long maxExtremeCapacity = -1;
long nonExtremeCapacity = RANDOM.nextLong(minExtremeCapacity + 1, maxExtremeCapacity);
return new long[]{minExtremeCapacity, nonExtremeCapacity, maxExtremeCapacity};
}
@Test
void testFailureOnNullDownstreamHandler() {
var exception = assertThrows(NullPointerException.class, () -> BodyHandlers.limiting(null, 0));
assertEquals(exception.getMessage(), "downstreamHandler");
}
@Test
void testFailureOnNullDownstreamSubscriber() {
var exception = assertThrows(NullPointerException.class, () -> BodySubscribers.limiting(null, 0));
assertEquals(exception.getMessage(), "downstreamSubscriber");
}
private static Arguments[] capacityArgs(long... capacities) {
return Stream
.of(HTTP1, HTTPS1, HTTP2, HTTPS2)
.flatMap(pair -> Arrays
.stream(capacities)
.mapToObj(capacity -> Arguments.of(pair, capacity)))
.toArray(Arguments[]::new);
}
@Test
void testSubscriberForCompleteConsumption() {
// Create the subscriber (with sufficient capacity)
ObserverSubscriber downstreamSubscriber = new ObserverSubscriber();
int sufficientCapacity = RESPONSE_BODY.length;
BodySubscriber<String> subscriber = BodySubscribers.limiting(downstreamSubscriber, sufficientCapacity);
// Emit values
subscriber.onSubscribe(DummySubscription.INSTANCE);
byte[] responseBodyPart1 = {RESPONSE_BODY[0]};
byte[] responseBodyPart2 = copyOfRange(RESPONSE_BODY, 1, RESPONSE_BODY.length);
List<ByteBuffer> buffers = toByteBuffers(responseBodyPart1, responseBodyPart2);
subscriber.onNext(buffers);
// Verify the downstream propagation
assertSame(buffers, downstreamSubscriber.lastBuffers);
assertNull(downstreamSubscriber.lastThrowable);
assertFalse(downstreamSubscriber.completed);
}
@Test
void testSubscriberForFailureOnExcess() {
// Create the subscriber (with insufficient capacity)
ObserverSubscriber downstreamSubscriber = new ObserverSubscriber();
int insufficientCapacity = 2;
BodySubscriber<String> subscriber = BodySubscribers.limiting(downstreamSubscriber, insufficientCapacity);
// Emit values
subscriber.onSubscribe(DummySubscription.INSTANCE);
byte[] responseBodyPart1 = {RESPONSE_BODY[0]};
byte[] responseBodyPart2 = copyOfRange(RESPONSE_BODY, 1, RESPONSE_BODY.length);
List<ByteBuffer> buffers = toByteBuffers(responseBodyPart1, responseBodyPart2);
subscriber.onNext(buffers);
// Verify the downstream propagation
assertNull(downstreamSubscriber.lastBuffers);
assertNotNull(downstreamSubscriber.lastThrowable);
assertEquals(
"body exceeds capacity: " + insufficientCapacity,
downstreamSubscriber.lastThrowable.getMessage());
assertFalse(downstreamSubscriber.completed);
}
private static List<ByteBuffer> toByteBuffers(byte[]... buffers) {
return Arrays.stream(buffers).map(ByteBuffer::wrap).collect(Collectors.toList());
}
private static final class ObserverSubscriber implements BodySubscriber<String> {
private List<ByteBuffer> lastBuffers;
private Throwable lastThrowable;
private boolean completed;
@Override
public CompletionStage<String> getBody() {
throw new UnsupportedOperationException();
}
@Override
public void onSubscribe(Subscription subscription) {
subscription.request(Long.MAX_VALUE);
}
@Override
public void onNext(List<ByteBuffer> buffers) {
lastBuffers = buffers;
}
@Override
public void onError(Throwable throwable) {
lastThrowable = throwable;
}
@Override
public void onComplete() {
completed = true;
}
}
private enum DummySubscription implements Subscription {
INSTANCE;
@Override
public void request(long n) {
// Do nothing
}
@Override
public void cancel() {
// Do nothing
}
}
}