From e8db14f584fa92db170e056bc68074ccabae82c9 Mon Sep 17 00:00:00 2001 From: Daniel Fuchs Date: Mon, 22 Sep 2025 10:12:12 +0000 Subject: [PATCH] 8349910: Implement JEP 517: HTTP/3 for the HTTP Client API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Aleksei Efimov Co-authored-by: Bradford Wetmore Co-authored-by: Daniel JeliƄski Co-authored-by: Darragh Clarke Co-authored-by: Jaikiran Pai Co-authored-by: Michael McMahon Co-authored-by: Volkan Yazici Co-authored-by: Conor Cleary Co-authored-by: Patrick Concannon Co-authored-by: Rahul Yadav Co-authored-by: Daniel Fuchs Reviewed-by: djelinski, jpai, aefimov, abarashev, michaelm --- .../net/quic/QuicKeyUnavailableException.java | 45 + .../internal/net/quic/QuicOneRttContext.java | 38 + .../jdk/internal/net/quic/QuicTLSContext.java | 152 + .../jdk/internal/net/quic/QuicTLSEngine.java | 508 ++ .../net/quic/QuicTransportErrors.java | 349 ++ .../net/quic/QuicTransportException.java | 111 + .../quic/QuicTransportParametersConsumer.java | 39 + .../jdk/internal/net/quic/QuicVersion.java | 100 + src/java.base/share/classes/module-info.java | 3 + .../share/classes/sun/security/ssl/Alert.java | 8 +- .../sun/security/ssl/AlpnExtension.java | 20 +- .../sun/security/ssl/CertificateMessage.java | 29 +- .../classes/sun/security/ssl/ClientHello.java | 2 +- .../classes/sun/security/ssl/Finished.java | 20 + .../classes/sun/security/ssl/KeyUpdate.java | 6 + .../sun/security/ssl/OutputRecord.java | 14 +- .../security/ssl/PostHandshakeContext.java | 17 +- .../classes/sun/security/ssl/QuicCipher.java | 699 +++ .../security/ssl/QuicEngineOutputRecord.java | 245 + .../sun/security/ssl/QuicKeyManager.java | 1216 +++++ .../sun/security/ssl/QuicTLSEngineImpl.java | 893 ++++ .../ssl/QuicTransportParametersExtension.java | 189 + .../security/ssl/SSLAlgorithmConstraints.java | 40 + .../sun/security/ssl/SSLConfiguration.java | 21 +- .../sun/security/ssl/SSLContextImpl.java | 4 + .../sun/security/ssl/SSLExtension.java | 27 +- .../classes/sun/security/ssl/ServerHello.java | 38 +- .../security/ssl/SunX509KeyManagerImpl.java | 17 + .../sun/security/ssl/TransportContext.java | 6 +- .../sun/security/ssl/X509Authentication.java | 46 +- .../ssl/X509KeyManagerCertChecking.java | 20 + .../sun/security/ssl/X509KeyManagerImpl.java | 17 + .../security/ssl/X509TrustManagerImpl.java | 65 + .../share/conf/security/java.security | 27 + .../classes/java/net/http/HttpClient.java | 100 +- .../classes/java/net/http/HttpOption.java | 176 + .../classes/java/net/http/HttpRequest.java | 64 +- .../java/net/http/HttpRequestOptionImpl.java | 34 + .../classes/java/net/http/HttpResponse.java | 119 +- .../java/net/http/StreamLimitException.java | 104 + .../UnsupportedProtocolVersionException.java | 55 + .../classes/java/net/http/package-info.java | 21 +- .../net/http/AltServicesRegistry.java | 569 +++ .../internal/net/http/AltSvcProcessor.java | 495 ++ .../jdk/internal/net/http/Exchange.java | 102 +- .../jdk/internal/net/http/ExchangeImpl.java | 434 +- .../net/http/H3FrameOrderVerifier.java | 200 + .../jdk/internal/net/http/Http1Exchange.java | 2 +- .../internal/net/http/Http2ClientImpl.java | 3 +- .../internal/net/http/Http2Connection.java | 46 +- .../internal/net/http/Http3ClientImpl.java | 844 ++++ .../net/http/Http3ClientProperties.java | 171 + .../internal/net/http/Http3Connection.java | 1657 +++++++ .../net/http/Http3ConnectionPool.java | 207 + .../internal/net/http/Http3ExchangeImpl.java | 1795 +++++++ .../net/http/Http3PendingConnections.java | 224 + .../internal/net/http/Http3PushManager.java | 811 +++ .../net/http/Http3PushPromiseStream.java | 746 +++ .../jdk/internal/net/http/Http3Stream.java | 693 +++ .../jdk/internal/net/http/HttpClientImpl.java | 155 +- .../jdk/internal/net/http/HttpConnection.java | 75 +- .../internal/net/http/HttpQuicConnection.java | 690 +++ .../net/http/HttpRequestBuilderImpl.java | 52 + .../internal/net/http/HttpRequestImpl.java | 48 +- .../internal/net/http/HttpResponseImpl.java | 14 +- .../net/http/ImmutableHttpRequest.java | 16 +- .../jdk/internal/net/http/MultiExchange.java | 433 +- .../classes/jdk/internal/net/http/Origin.java | 41 + .../net/http/PlainHttpConnection.java | 68 +- .../jdk/internal/net/http/PushGroup.java | 19 +- .../jdk/internal/net/http/Response.java | 28 +- .../net/http/ResponseSubscribers.java | 12 +- .../classes/jdk/internal/net/http/Stream.java | 41 +- .../jdk/internal/net/http/common/Alpns.java | 5 + .../common/ConnectionExpiredException.java | 4 +- .../internal/net/http/common/Deadline.java | 54 +- .../common/HttpBodySubscriberWrapper.java | 5 + .../net/http/common/HttpHeadersBuilder.java | 11 +- .../jdk/internal/net/http/common/Log.java | 240 +- .../net/http/common/OperationTrackers.java | 2 + .../internal/net/http/common/TimeSource.java | 3 +- .../jdk/internal/net/http/common/Utils.java | 572 ++- .../internal/net/http/frame/AltSvcFrame.java | 77 + .../net/http/frame/FramesDecoder.java | 36 +- .../net/http/frame/FramesEncoder.java | 16 + .../internal/net/http/frame/Http2Frame.java | 3 +- .../jdk/internal/net/http/hpack/Decoder.java | 2 +- .../internal/net/http/hpack/ISO_8859_1.java | 7 +- .../internal/net/http/hpack/QuickHuffman.java | 30 +- .../net/http/http3/ConnectionSettings.java | 63 + .../internal/net/http/http3/Http3Error.java | 308 ++ .../http/http3/frames/AbstractHttp3Frame.java | 118 + .../http/http3/frames/CancelPushFrame.java | 115 + .../net/http/http3/frames/DataFrame.java | 60 + .../net/http/http3/frames/FramesDecoder.java | 331 ++ .../net/http/http3/frames/GoAwayFrame.java | 123 + .../net/http/http3/frames/HeadersFrame.java | 60 + .../net/http/http3/frames/Http3Frame.java | 214 + .../net/http/http3/frames/Http3FrameType.java | 201 + .../net/http/http3/frames/MalformedFrame.java | 124 + .../net/http/http3/frames/MaxPushIdFrame.java | 117 + .../net/http/http3/frames/PartialFrame.java | 154 + .../http/http3/frames/PushPromiseFrame.java | 84 + .../net/http/http3/frames/SettingsFrame.java | 364 ++ .../net/http/http3/frames/UnknownFrame.java | 67 + .../net/http/http3/streams/Http3Streams.java | 117 + .../streams/PeerUniStreamDispatcher.java | 328 ++ .../http/http3/streams/QueuingStreamPair.java | 183 + .../http3/streams/QuicStreamIntReader.java | 192 + .../net/http/http3/streams/UniStreamPair.java | 505 ++ .../jdk/internal/net/http/qpack/Decoder.java | 400 ++ .../net/http/qpack/DecodingCallback.java | 202 + .../internal/net/http/qpack/DynamicTable.java | 1069 ++++ .../jdk/internal/net/http/qpack/Encoder.java | 672 +++ .../net/http/qpack/FieldSectionPrefix.java | 75 + .../internal/net/http/qpack/HeaderField.java | 37 + .../internal/net/http/qpack/HeadersTable.java | 64 + .../net/http/qpack/InsertionPolicy.java | 29 + .../jdk/internal/net/http/qpack/QPACK.java | 229 + .../net/http/qpack/QPackException.java | 68 + .../internal/net/http/qpack/StaticTable.java | 192 + .../internal/net/http/qpack/TableEntry.java | 76 + .../net/http/qpack/TablesIndexer.java | 112 + .../internal/net/http/qpack/package-info.java | 34 + .../readers/DecoderInstructionsReader.java | 159 + .../readers/EncoderInstructionsReader.java | 245 + .../FieldLineIndexedPostBaseReader.java | 83 + .../qpack/readers/FieldLineIndexedReader.java | 86 + .../readers/FieldLineLiteralsReader.java | 119 + .../FieldLineNameRefPostBaseReader.java | 125 + .../readers/FieldLineNameReferenceReader.java | 126 + .../http/qpack/readers/FieldLineReader.java | 128 + .../http/qpack/readers/HeaderFrameReader.java | 414 ++ .../net/http/qpack/readers/IntegerReader.java | 177 + .../net/http/qpack/readers/ReaderError.java | 49 + .../net/http/qpack/readers/StringReader.java | 152 + .../writers/BinaryRepresentationWriter.java | 33 + .../writers/DecoderInstructionsWriter.java | 121 + .../writers/EncoderDuplicateEntryWriter.java | 56 + .../EncoderDynamicTableCapacityWriter.java | 54 + .../EncoderInsertIndexedNameWriter.java | 102 + .../EncoderInsertLiteralNameWriter.java | 101 + .../writers/EncoderInstructionsWriter.java | 190 + .../writers/FieldLineIndexedNameWriter.java | 144 + .../qpack/writers/FieldLineIndexedWriter.java | 105 + .../writers/FieldLineLiteralsWriter.java | 104 + .../writers/FieldLineSectionPrefixWriter.java | 96 + .../http/qpack/writers/HeaderFrameWriter.java | 108 + .../net/http/qpack/writers/IntegerWriter.java | 133 + .../net/http/qpack/writers/StringWriter.java | 139 + .../internal/net/http/quic/BuffersReader.java | 707 +++ .../internal/net/http/quic/CodingContext.java | 169 + .../net/http/quic/ConnectionTerminator.java | 38 + .../http/quic/ConnectionTerminatorImpl.java | 475 ++ .../net/http/quic/IdleTimeoutManager.java | 528 ++ .../net/http/quic/LocalConnIdManager.java | 175 + .../internal/net/http/quic/OrderedFlow.java | 389 ++ .../internal/net/http/quic/PacketEmitter.java | 134 + .../net/http/quic/PacketSpaceManager.java | 2370 +++++++++ .../net/http/quic/PeerConnIdManager.java | 520 ++ .../net/http/quic/PeerConnectionId.java | 92 + .../internal/net/http/quic/QuicClient.java | 585 +++ .../http/quic/QuicCongestionController.java | 75 + .../net/http/quic/QuicConnection.java | 229 + .../net/http/quic/QuicConnectionId.java | 151 + .../http/quic/QuicConnectionIdFactory.java | 354 ++ .../net/http/quic/QuicConnectionImpl.java | 4353 +++++++++++++++++ .../internal/net/http/quic/QuicEndpoint.java | 2062 ++++++++ .../internal/net/http/quic/QuicInstance.java | 150 + .../net/http/quic/QuicPacketReceiver.java | 144 + .../quic/QuicRenoCongestionController.java | 220 + .../net/http/quic/QuicRttEstimator.java | 170 + .../internal/net/http/quic/QuicSelector.java | 536 ++ .../http/quic/QuicStreamLimitException.java | 38 + .../net/http/quic/QuicTimedEvent.java | 160 + .../net/http/quic/QuicTimerQueue.java | 522 ++ .../http/quic/QuicTransportParameters.java | 1319 +++++ .../net/http/quic/TerminationCause.java | 211 + .../net/http/quic/VariableLengthEncoder.java | 341 ++ .../net/http/quic/frames/AckFrame.java | 931 ++++ .../quic/frames/ConnectionCloseFrame.java | 238 + .../net/http/quic/frames/CryptoFrame.java | 153 + .../http/quic/frames/DataBlockedFrame.java | 91 + .../http/quic/frames/HandshakeDoneFrame.java | 71 + .../net/http/quic/frames/MaxDataFrame.java | 91 + .../http/quic/frames/MaxStreamDataFrame.java | 101 + .../net/http/quic/frames/MaxStreamsFrame.java | 108 + .../quic/frames/NewConnectionIDFrame.java | 141 + .../net/http/quic/frames/NewTokenFrame.java | 103 + .../net/http/quic/frames/PaddingFrame.java | 102 + .../http/quic/frames/PathChallengeFrame.java | 89 + .../http/quic/frames/PathResponseFrame.java | 89 + .../net/http/quic/frames/PingFrame.java | 64 + .../net/http/quic/frames/QuicFrame.java | 387 ++ .../http/quic/frames/ResetStreamFrame.java | 109 + .../quic/frames/RetireConnectionIDFrame.java | 91 + .../http/quic/frames/StopSendingFrame.java | 93 + .../quic/frames/StreamDataBlockedFrame.java | 118 + .../net/http/quic/frames/StreamFrame.java | 264 + .../http/quic/frames/StreamsBlockedFrame.java | 107 + .../internal/net/http/quic/package-info.java | 40 + .../http/quic/packets/HandshakePacket.java | 97 + .../net/http/quic/packets/InitialPacket.java | 125 + .../net/http/quic/packets/LongHeader.java | 57 + .../http/quic/packets/LongHeaderPacket.java | 71 + .../net/http/quic/packets/OneRttPacket.java | 125 + .../net/http/quic/packets/PacketSpace.java | 245 + .../net/http/quic/packets/QuicPacket.java | 249 + .../http/quic/packets/QuicPacketDecoder.java | 1748 +++++++ .../http/quic/packets/QuicPacketEncoder.java | 1746 +++++++ .../http/quic/packets/QuicPacketNumbers.java | 197 + .../net/http/quic/packets/RetryPacket.java | 100 + .../http/quic/packets/ShortHeaderPacket.java | 54 + .../packets/VersionNegotiationPacket.java | 76 + .../net/http/quic/packets/ZeroRttPacket.java | 103 + .../http/quic/streams/AbstractQuicStream.java | 118 + .../http/quic/streams/CryptoWriterQueue.java | 213 + .../net/http/quic/streams/QuicBidiStream.java | 135 + .../http/quic/streams/QuicBidiStreamImpl.java | 151 + .../quic/streams/QuicConnectionStreams.java | 1590 ++++++ .../http/quic/streams/QuicReceiverStream.java | 190 + .../quic/streams/QuicReceiverStreamImpl.java | 942 ++++ .../http/quic/streams/QuicSenderStream.java | 197 + .../quic/streams/QuicSenderStreamImpl.java | 662 +++ .../net/http/quic/streams/QuicStream.java | 149 + .../http/quic/streams/QuicStreamReader.java | 138 + .../http/quic/streams/QuicStreamWriter.java | 169 + .../net/http/quic/streams/QuicStreams.java | 90 + .../quic/streams/StreamCreationPermit.java | 317 ++ .../http/quic/streams/StreamWriterQueue.java | 550 +++ .../share/classes/module-info.java | 107 +- .../security/pkcs11/P11SecretKeyFactory.java | 4 + test/jdk/com/sun/net/httpserver/SANTest.java | 12 + .../java/net/httpclient/AbstractNoBody.java | 83 +- .../AbstractThrowingPublishers.java | 134 +- .../AbstractThrowingPushPromises.java | 150 +- .../AbstractThrowingSubscribers.java | 131 +- .../httpclient/AggregateRequestBodyTest.java | 70 +- .../net/httpclient/AltServiceUsageTest.java | 454 ++ .../net/httpclient/AsFileDownloadTest.java | 160 +- .../net/httpclient/AsyncExecutorShutdown.java | 197 +- .../java/net/httpclient/AsyncShutdownNow.java | 122 +- .../net/httpclient/AuthFilterCacheTest.java | 111 +- .../java/net/httpclient/BasicAuthTest.java | 14 +- .../java/net/httpclient/BasicHTTP2Test.java | 320 ++ .../java/net/httpclient/BasicHTTP3Test.java | 482 ++ .../net/httpclient/BasicRedirectTest.java | 173 +- .../net/httpclient/CancelRequestTest.java | 101 +- .../httpclient/CancelStreamedBodyTest.java | 81 +- ...java => CancelledPartialResponseTest.java} | 170 +- .../net/httpclient/CancelledResponse.java | 9 +- .../net/httpclient/CancelledResponse2.java | 88 +- .../net/httpclient/ConcurrentResponses.java | 69 +- .../httpclient/ContentLengthHeaderTest.java | 48 +- .../java/net/httpclient/CookieHeaderTest.java | 23 +- .../httpclient/CustomRequestPublisher.java | 63 +- .../httpclient/CustomResponseSubscriber.java | 2 +- .../net/httpclient/DependentActionsTest.java | 147 +- .../DependentPromiseActionsTest.java | 263 +- .../java/net/httpclient/DigestEchoClient.java | 185 +- .../net/httpclient/DigestEchoClientSSL.java | 10 +- .../java/net/httpclient/DigestEchoServer.java | 100 +- .../net/httpclient/EmptyAuthenticate.java | 14 +- .../net/httpclient/EncodedCharsInURI.java | 65 +- .../net/httpclient/EscapedOctetsInURI.java | 64 +- .../java/net/httpclient/ExecutorShutdown.java | 93 +- .../net/httpclient/ExpectContinueTest.java | 18 +- .../httpclient/FlowAdapterPublisherTest.java | 31 +- .../httpclient/FlowAdapterSubscriberTest.java | 31 +- .../net/httpclient/ForbiddenHeadTest.java | 29 +- .../net/httpclient/GZIPInputStreamTest.java | 156 +- .../net/httpclient/HandshakeFailureTest.java | 1 + test/jdk/java/net/httpclient/HeadTest.java | 87 +- .../net/httpclient/HeadersLowerCaseTest.java | 231 + .../net/httpclient/HttpClientBuilderTest.java | 11 + .../java/net/httpclient/HttpClientClose.java | 130 +- .../net/httpclient/HttpClientShutdown.java | 102 +- .../httpclient/HttpGetInCancelledFuture.java | 105 +- .../java/net/httpclient/HttpRedirectTest.java | 35 +- .../httpclient/HttpRequestBuilderTest.java | 78 +- .../httpclient/HttpRequestNewBuilderTest.java | 30 +- .../HttpResponseConnectionLabelTest.java | 64 +- .../httpclient/HttpResponseLimitingTest.java | 48 +- .../net/httpclient/HttpSlowServerTest.java | 27 +- .../java/net/httpclient/ISO_8859_1_Test.java | 82 +- .../httpclient/IdleConnectionTimeoutTest.java | 360 ++ .../net/httpclient/ImmutableFlowItems.java | 1 - .../httpclient/ImmutableSSLSessionTest.java | 381 ++ ...InvalidInputStreamSubscriptionRequest.java | 206 +- .../InvalidSubscriptionRequest.java | 137 +- .../net/httpclient/LargeHandshakeTest.java | 33 +- .../net/httpclient/LargeResponseTest.java | 34 +- .../net/httpclient/LineBodyHandlerTest.java | 65 +- .../jdk/java/net/httpclient/ManyRequests.java | 19 +- .../java/net/httpclient/ManyRequests2.java | 6 +- .../net/httpclient/ManyRequestsLegacy.java | 20 +- .../httpclient/MappingResponseSubscriber.java | 7 +- .../java/net/httpclient/NoBodyPartOne.java | 10 + .../java/net/httpclient/NoBodyPartThree.java | 17 +- .../java/net/httpclient/NoBodyPartTwo.java | 16 +- .../net/httpclient/NonAsciiCharsInURI.java | 67 +- .../BodyHandlerOfFileDownloadTest.java | 74 +- .../PathSubscriber/BodyHandlerOfFileTest.java | 69 +- .../BodySubscriberOfFileTest.java | 70 +- .../ProxyAuthDisabledSchemesSSL.java | 15 +- test/jdk/java/net/httpclient/ProxyTest.java | 22 +- .../net/httpclient/RedirectMethodChange.java | 50 +- .../net/httpclient/RedirectTimeoutTest.java | 65 +- .../net/httpclient/RedirectWithCookie.java | 38 +- .../java/net/httpclient/ReferenceTracker.java | 29 + .../net/httpclient/RequestBuilderTest.java | 19 +- .../java/net/httpclient/Response1xxTest.java | 219 +- .../net/httpclient/Response204V2Test.java | 54 +- .../httpclient/ResponseBodyBeforeError.java | 2 +- .../net/httpclient/ResponsePublisher.java | 73 +- .../net/httpclient/RestrictedHeadersTest.java | 4 +- .../java/net/httpclient/RetryWithCookie.java | 40 +- test/jdk/java/net/httpclient/ShutdownNow.java | 68 +- test/jdk/java/net/httpclient/SmokeTest.java | 22 +- .../net/httpclient/SpecialHeadersTest.java | 9 +- .../java/net/httpclient/SplitResponse.java | 10 +- .../java/net/httpclient/StreamCloseTest.java | 6 +- .../java/net/httpclient/StreamingBody.java | 37 +- test/jdk/java/net/httpclient/TEST.properties | 17 +- .../jdk/java/net/httpclient/TimeoutBasic.java | 82 +- .../java/net/httpclient/TlsContextTest.java | 49 +- .../java/net/httpclient/UnauthorizedTest.java | 37 +- .../httpclient/UserAuthWithAuthenticator.java | 465 +- .../java/net/httpclient/UserCookieTest.java | 32 +- test/jdk/java/net/httpclient/VersionTest.java | 5 +- .../net/http/Http3ConnectionAccess.java | 64 + .../common/ImmutableSSLSessionAccess.java | 42 + .../altsvc/AltServiceReasonableAssurance.java | 688 +++ .../httpclient/altsvc/altsvc-dns-hosts.txt | 23 + .../net/http/common/TestLoggerUtil.java | 46 + .../httpclient/http2/BadPushPromiseTest.java | 2 +- .../http2/ContinuationFrameTest.java | 24 +- .../java/net/httpclient/http2/ErrorTest.java | 19 +- .../http2/HpackBinaryTestDriver.java | 2 +- .../httpclient/http2/HpackHuffmanDriver.java | 2 +- .../http2/IdleConnectionTimeoutTest.java | 227 - .../http2/IdlePooledConnectionTest.java | 5 +- .../java/net/httpclient/http2/ProxyTest2.java | 25 +- .../http2/PushPromiseContinuation.java | 14 +- .../net/httpclient/http2/RedirectTest.java | 8 +- .../java/net/httpclient/http2/SimpleGet.java | 225 + .../http2/StreamFlowControlTest.java | 2 +- .../httpclient/http2/TrailingHeadersTest.java | 12 +- .../net/httpclient/http2/UserInfoTest.java | 13 +- .../http3/BadCipherSuiteErrorTest.java | 119 + .../httpclient/http3/FramesDecoderTest.java | 227 + .../net/httpclient/http3/GetHTTP3Test.java | 476 ++ .../httpclient/http3/H3BadHeadersTest.java | 330 ++ .../net/httpclient/http3/H3BasicTest.java | 405 ++ .../httpclient/http3/H3ConcurrentPush.java | 471 ++ .../http3/H3ConnectionPoolTest.java | 581 +++ .../httpclient/http3/H3DataLimitsTest.java | 270 + .../httpclient/http3/H3ErrorHandlingTest.java | 1063 ++++ .../http3/H3FixedThreadPoolTest.java | 300 ++ .../net/httpclient/http3/H3GoAwayTest.java | 180 + .../http3/H3HeaderSizeLimitTest.java | 162 + .../httpclient/http3/H3HeadersEncoding.java | 315 ++ .../http3/H3ImplicitPushCancel.java | 258 + .../http3/H3InsertionsLimitTest.java | 179 + .../http3/H3MalformedResponseTest.java | 437 ++ .../http3/H3MaxInitialTimeoutTest.java | 250 + .../http3/H3MemoryHandlingTest.java | 232 + .../H3MultipleConnectionsToSameHost.java | 338 ++ .../net/httpclient/http3/H3ProxyTest.java | 396 ++ .../net/httpclient/http3/H3PushCancel.java | 508 ++ .../httpclient/http3/H3QuicTLSConnection.java | 363 ++ .../net/httpclient/http3/H3RedirectTest.java | 260 + .../net/httpclient/http3/H3ServerPush.java | 396 ++ .../httpclient/http3/H3ServerPushCancel.java | 607 +++ .../httpclient/http3/H3ServerPushTest.java | 1224 +++++ .../http3/H3ServerPushWithDiffTypes.java | 292 ++ .../net/httpclient/http3/H3SimpleGet.java | 315 ++ .../net/httpclient/http3/H3SimplePost.java | 207 + .../net/httpclient/http3/H3SimpleTest.java | 129 + .../httpclient/http3/H3StopSendingTest.java | 259 + .../http3/H3StreamLimitReachedTest.java | 971 ++++ .../java/net/httpclient/http3/H3Timeout.java | 187 + .../http3/H3UnsupportedSSLParametersTest.java | 78 + .../net/httpclient/http3/H3UserInfoTest.java | 193 + .../net/httpclient/http3/HTTP3NoBodyTest.java | 324 ++ .../http3/Http3ExpectContinueTest.java | 250 + .../http3/PeerUniStreamDispatcherTest.java | 436 ++ .../net/httpclient/http3/PostHTTP3Test.java | 522 ++ .../net/httpclient/http3/StopSendingTest.java | 215 + .../net/httpclient/http3/StreamLimitTest.java | 266 + .../test/lib/common/DynamicKeyStoreUtil.java | 266 + .../test/lib/common/HttpServerAdapters.java | 880 +++- .../lib/common/RequestPathMatcherUtil.java | 68 + .../lib/common/TestServerConfigurator.java | 19 +- .../httpclient/test/lib/common/TestUtil.java | 58 + .../test/lib/http2/BodyOutputStream.java | 48 +- .../test/lib/http2/EchoHandler.java | 7 +- .../test/lib/http2/Http2EchoHandler.java | 34 +- .../test/lib/http2/Http2Handler.java | 3 +- .../test/lib/http2/Http2RedirectHandler.java | 9 +- .../test/lib/http2/Http2TestExchange.java | 170 +- .../test/lib/http2/Http2TestExchangeImpl.java | 25 +- .../test/lib/http2/Http2TestServer.java | 259 +- .../lib/http2/Http2TestServerConnection.java | 261 +- .../test/lib/http2/OutgoingPushPromise.java | 16 +- .../test/lib/http3/Http3ServerConnection.java | 801 +++ .../test/lib/http3/Http3ServerExchange.java | 801 +++ .../test/lib/http3/Http3ServerStreamImpl.java | 489 ++ .../test/lib/http3/Http3TestServer.java | 371 ++ .../lib/http3/UnknownOrReservedFrame.java | 94 + .../test/lib/quic/ClientConnection.java | 134 + .../test/lib/quic/ConnectedBidiStream.java | 129 + .../test/lib/quic/DatagramDeliveryPolicy.java | 315 ++ .../httpclient/test/lib/quic/OutStream.java | 75 + .../test/lib/quic/QueueInputStream.java | 164 + .../httpclient/test/lib/quic/QuicServer.java | 913 ++++ .../test/lib/quic/QuicServerConnection.java | 570 +++ .../test/lib/quic/QuicServerHandler.java | 87 + .../test/lib/quic/QuicStandaloneServer.java | 185 + .../test/lib/quic/RetryCodingContext.java | 82 + .../qpack/BlockingDecodingTest.java | 374 ++ .../qpack/DecoderInstructionsReaderTest.java | 165 + .../qpack/DecoderInstructionsWriterTest.java | 177 + .../qpack/DecoderSectionSizeLimitTest.java | 265 + .../net/httpclient/qpack/DecoderTest.java | 392 ++ ...namicTableFieldLineRepresentationTest.java | 404 ++ .../httpclient/qpack/DynamicTableTest.java | 366 ++ .../qpack/EncoderDecoderConnectionTest.java | 241 + .../qpack/EncoderDecoderConnector.java | 509 ++ .../httpclient/qpack/EncoderDecoderTest.java | 442 ++ .../qpack/EncoderInstructionsReaderTest.java | 373 ++ .../qpack/EncoderInstructionsWriterTest.java | 343 ++ .../net/httpclient/qpack/EncoderTest.java | 670 +++ .../httpclient/qpack/EntriesEvictionTest.java | 190 + .../qpack/FieldSectionPrefixTest.java | 133 + .../qpack/IntegerReaderMaxValuesTest.java | 84 + .../qpack/StaticTableFieldsTest.java | 182 + .../qpack/StringLengthLimitsTest.java | 482 ++ .../httpclient/qpack/TablesIndexerTest.java | 196 + .../qpack/UnacknowledgedInsertionTest.java | 243 + .../net/httpclient/quic/AckElicitingTest.java | 729 +++ .../net/httpclient/quic/AckFrameTest.java | 387 ++ .../httpclient/quic/BuffersReaderTest.java | 520 ++ .../httpclient/quic/BuffersReaderVLTest.java | 325 ++ .../httpclient/quic/ConnectionIDSTest.java | 170 + .../quic/CryptoWriterQueueTest.java | 73 + .../net/httpclient/quic/KeyUpdateTest.java | 272 + .../net/httpclient/quic/OrderedFlowTest.java | 401 ++ .../httpclient/quic/PacketEncodingTest.java | 1440 ++++++ .../net/httpclient/quic/PacketLossTest.java | 253 + .../httpclient/quic/PacketNumbersTest.java | 159 + .../quic/PacketSpaceManagerTest.java | 1094 +++++ .../quic/QuicFramesDecoderTest.java | 298 ++ .../quic/QuicRequestResponseTest.java | 172 + .../quic/StatelessResetReceiptTest.java | 303 ++ .../httpclient/quic/VariableLengthTest.java | 348 ++ .../quic/VersionNegotiationTest.java | 146 + .../quic/quic-tls-keylimits-java.security | 4 + .../quic/tls/PacketEncryptionTest.java | 457 ++ .../tls/QuicTLSEngineBadParametersTest.java | 122 + .../quic/tls/QuicTLSEngineFailedALPNTest.java | 110 + .../QuicTLSEngineMissingParametersTest.java | 112 + .../quic/tls/Quicv2PacketEncryptionTest.java | 451 ++ .../ssl/QuicTLSEngineImplAccessor.java | 59 + .../httpclient/ssltest/CertificateTest.java | 7 +- .../java/net/httpclient/ssltest/Server.java | 10 +- .../httpclient/ssltest/TlsVersionTest.java | 14 +- .../websocket/HandshakeUrlEncodingTest.java | 1 - .../httpclient/websocket/ReaderDriver.java | 2 +- .../httpclient/whitebox/AltSvcFrameTest.java | 239 + .../whitebox/AltSvcRegistryTest.java | 161 + .../internal/net/http/HttpClientAccess.java | 10 + .../quic/packets/QuicPacketNumbersTest.java | 137 + 473 files changed, 104314 insertions(+), 2646 deletions(-) create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicKeyUnavailableException.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicOneRttContext.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicTLSContext.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicTLSEngine.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicTransportErrors.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicTransportException.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicTransportParametersConsumer.java create mode 100644 src/java.base/share/classes/jdk/internal/net/quic/QuicVersion.java create mode 100644 src/java.base/share/classes/sun/security/ssl/QuicCipher.java create mode 100644 src/java.base/share/classes/sun/security/ssl/QuicEngineOutputRecord.java create mode 100644 src/java.base/share/classes/sun/security/ssl/QuicKeyManager.java create mode 100644 src/java.base/share/classes/sun/security/ssl/QuicTLSEngineImpl.java create mode 100644 src/java.base/share/classes/sun/security/ssl/QuicTransportParametersExtension.java create mode 100644 src/java.net.http/share/classes/java/net/http/HttpOption.java create mode 100644 src/java.net.http/share/classes/java/net/http/HttpRequestOptionImpl.java create mode 100644 src/java.net.http/share/classes/java/net/http/StreamLimitException.java create mode 100644 src/java.net.http/share/classes/java/net/http/UnsupportedProtocolVersionException.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/AltServicesRegistry.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/AltSvcProcessor.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/H3FrameOrderVerifier.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientProperties.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3Connection.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3ConnectionPool.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3ExchangeImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3PendingConnections.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3PushManager.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3PushPromiseStream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/Http3Stream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/HttpQuicConnection.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/frame/AltSvcFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/ConnectionSettings.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/Http3Error.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/AbstractHttp3Frame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/CancelPushFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/DataFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/FramesDecoder.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/GoAwayFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/HeadersFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3Frame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3FrameType.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MalformedFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MaxPushIdFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PartialFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PushPromiseFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/SettingsFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/UnknownFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/Http3Streams.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/PeerUniStreamDispatcher.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QueuingStreamPair.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QuicStreamIntReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/UniStreamPair.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/Decoder.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/DecodingCallback.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/DynamicTable.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/Encoder.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/FieldSectionPrefix.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeaderField.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeadersTable.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/InsertionPolicy.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPACK.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPackException.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/StaticTable.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/TableEntry.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/TablesIndexer.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/package-info.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/DecoderInstructionsReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/EncoderInstructionsReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedPostBaseReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineLiteralsReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameRefPostBaseReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameReferenceReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/HeaderFrameReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/IntegerReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/ReaderError.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/StringReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/BinaryRepresentationWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/DecoderInstructionsWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDuplicateEntryWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDynamicTableCapacityWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertIndexedNameWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertLiteralNameWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInstructionsWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedNameWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineLiteralsWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineSectionPrefixWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/HeaderFrameWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/IntegerWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/StringWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/BuffersReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/CodingContext.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminator.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminatorImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/IdleTimeoutManager.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/LocalConnIdManager.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/OrderedFlow.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketEmitter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketSpaceManager.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnIdManager.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnectionId.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicClient.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicCongestionController.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnection.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionId.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionIdFactory.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicEndpoint.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicInstance.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicPacketReceiver.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRenoCongestionController.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRttEstimator.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicSelector.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicStreamLimitException.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimedEvent.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimerQueue.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTransportParameters.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/TerminationCause.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/VariableLengthEncoder.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/AckFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ConnectionCloseFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/CryptoFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/DataBlockedFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/HandshakeDoneFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxDataFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamDataFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamsFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewConnectionIDFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewTokenFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PaddingFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PathChallengeFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PathResponseFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PingFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/QuicFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ResetStreamFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/RetireConnectionIDFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StopSendingFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamDataBlockedFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamsBlockedFrame.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/package-info.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/HandshakePacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/InitialPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeaderPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/OneRttPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/PacketSpace.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketDecoder.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketEncoder.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketNumbers.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/RetryPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ShortHeaderPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/VersionNegotiationPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ZeroRttPacket.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/AbstractQuicStream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/CryptoWriterQueue.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStreamImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicConnectionStreams.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStreamImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStreamImpl.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStream.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamReader.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamWriter.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreams.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamCreationPermit.java create mode 100644 src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamWriterQueue.java create mode 100644 test/jdk/java/net/httpclient/AltServiceUsageTest.java create mode 100644 test/jdk/java/net/httpclient/BasicHTTP2Test.java create mode 100644 test/jdk/java/net/httpclient/BasicHTTP3Test.java rename test/jdk/java/net/httpclient/{http2/ExpectContinueResetTest.java => CancelledPartialResponseTest.java} (52%) create mode 100644 test/jdk/java/net/httpclient/HeadersLowerCaseTest.java create mode 100644 test/jdk/java/net/httpclient/IdleConnectionTimeoutTest.java create mode 100644 test/jdk/java/net/httpclient/ImmutableSSLSessionTest.java create mode 100644 test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/Http3ConnectionAccess.java create mode 100644 test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/common/ImmutableSSLSessionAccess.java create mode 100644 test/jdk/java/net/httpclient/altsvc/AltServiceReasonableAssurance.java create mode 100644 test/jdk/java/net/httpclient/altsvc/altsvc-dns-hosts.txt create mode 100644 test/jdk/java/net/httpclient/debug/java.net.http/jdk/internal/net/http/common/TestLoggerUtil.java delete mode 100644 test/jdk/java/net/httpclient/http2/IdleConnectionTimeoutTest.java create mode 100644 test/jdk/java/net/httpclient/http2/SimpleGet.java create mode 100644 test/jdk/java/net/httpclient/http3/BadCipherSuiteErrorTest.java create mode 100644 test/jdk/java/net/httpclient/http3/FramesDecoderTest.java create mode 100644 test/jdk/java/net/httpclient/http3/GetHTTP3Test.java create mode 100644 test/jdk/java/net/httpclient/http3/H3BadHeadersTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3BasicTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ConcurrentPush.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ConnectionPoolTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3DataLimitsTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ErrorHandlingTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3FixedThreadPoolTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3GoAwayTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3HeaderSizeLimitTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3HeadersEncoding.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ImplicitPushCancel.java create mode 100644 test/jdk/java/net/httpclient/http3/H3InsertionsLimitTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3MalformedResponseTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3MaxInitialTimeoutTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3MemoryHandlingTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3MultipleConnectionsToSameHost.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ProxyTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3PushCancel.java create mode 100644 test/jdk/java/net/httpclient/http3/H3QuicTLSConnection.java create mode 100644 test/jdk/java/net/httpclient/http3/H3RedirectTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ServerPush.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ServerPushCancel.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ServerPushTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3ServerPushWithDiffTypes.java create mode 100644 test/jdk/java/net/httpclient/http3/H3SimpleGet.java create mode 100644 test/jdk/java/net/httpclient/http3/H3SimplePost.java create mode 100644 test/jdk/java/net/httpclient/http3/H3SimpleTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3StopSendingTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3StreamLimitReachedTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3Timeout.java create mode 100644 test/jdk/java/net/httpclient/http3/H3UnsupportedSSLParametersTest.java create mode 100644 test/jdk/java/net/httpclient/http3/H3UserInfoTest.java create mode 100644 test/jdk/java/net/httpclient/http3/HTTP3NoBodyTest.java create mode 100644 test/jdk/java/net/httpclient/http3/Http3ExpectContinueTest.java create mode 100644 test/jdk/java/net/httpclient/http3/PeerUniStreamDispatcherTest.java create mode 100644 test/jdk/java/net/httpclient/http3/PostHTTP3Test.java create mode 100644 test/jdk/java/net/httpclient/http3/StopSendingTest.java create mode 100644 test/jdk/java/net/httpclient/http3/StreamLimitTest.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/DynamicKeyStoreUtil.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/RequestPathMatcherUtil.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestUtil.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerConnection.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerExchange.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerStreamImpl.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3TestServer.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/UnknownOrReservedFrame.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ClientConnection.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ConnectedBidiStream.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/DatagramDeliveryPolicy.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/OutStream.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QueueInputStream.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServer.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerConnection.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerHandler.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicStandaloneServer.java create mode 100644 test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/RetryCodingContext.java create mode 100644 test/jdk/java/net/httpclient/qpack/BlockingDecodingTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/DecoderInstructionsReaderTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/DecoderInstructionsWriterTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/DecoderSectionSizeLimitTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/DecoderTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/DynamicTableFieldLineRepresentationTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/DynamicTableTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/EncoderDecoderConnectionTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/EncoderDecoderConnector.java create mode 100644 test/jdk/java/net/httpclient/qpack/EncoderDecoderTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/EncoderInstructionsReaderTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/EncoderInstructionsWriterTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/EncoderTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/EntriesEvictionTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/FieldSectionPrefixTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/IntegerReaderMaxValuesTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/StaticTableFieldsTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/StringLengthLimitsTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/TablesIndexerTest.java create mode 100644 test/jdk/java/net/httpclient/qpack/UnacknowledgedInsertionTest.java create mode 100644 test/jdk/java/net/httpclient/quic/AckElicitingTest.java create mode 100644 test/jdk/java/net/httpclient/quic/AckFrameTest.java create mode 100644 test/jdk/java/net/httpclient/quic/BuffersReaderTest.java create mode 100644 test/jdk/java/net/httpclient/quic/BuffersReaderVLTest.java create mode 100644 test/jdk/java/net/httpclient/quic/ConnectionIDSTest.java create mode 100644 test/jdk/java/net/httpclient/quic/CryptoWriterQueueTest.java create mode 100644 test/jdk/java/net/httpclient/quic/KeyUpdateTest.java create mode 100644 test/jdk/java/net/httpclient/quic/OrderedFlowTest.java create mode 100644 test/jdk/java/net/httpclient/quic/PacketEncodingTest.java create mode 100644 test/jdk/java/net/httpclient/quic/PacketLossTest.java create mode 100644 test/jdk/java/net/httpclient/quic/PacketNumbersTest.java create mode 100644 test/jdk/java/net/httpclient/quic/PacketSpaceManagerTest.java create mode 100644 test/jdk/java/net/httpclient/quic/QuicFramesDecoderTest.java create mode 100644 test/jdk/java/net/httpclient/quic/QuicRequestResponseTest.java create mode 100644 test/jdk/java/net/httpclient/quic/StatelessResetReceiptTest.java create mode 100644 test/jdk/java/net/httpclient/quic/VariableLengthTest.java create mode 100644 test/jdk/java/net/httpclient/quic/VersionNegotiationTest.java create mode 100644 test/jdk/java/net/httpclient/quic/quic-tls-keylimits-java.security create mode 100644 test/jdk/java/net/httpclient/quic/tls/PacketEncryptionTest.java create mode 100644 test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineBadParametersTest.java create mode 100644 test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineFailedALPNTest.java create mode 100644 test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineMissingParametersTest.java create mode 100644 test/jdk/java/net/httpclient/quic/tls/Quicv2PacketEncryptionTest.java create mode 100644 test/jdk/java/net/httpclient/quic/tls/java.base/sun/security/ssl/QuicTLSEngineImplAccessor.java create mode 100644 test/jdk/java/net/httpclient/whitebox/AltSvcFrameTest.java create mode 100644 test/jdk/java/net/httpclient/whitebox/AltSvcRegistryTest.java create mode 100644 test/jdk/java/net/httpclient/whitebox/java.net.http/jdk/internal/net/http/HttpClientAccess.java create mode 100644 test/jdk/jdk/internal/net/http/quic/packets/QuicPacketNumbersTest.java diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicKeyUnavailableException.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicKeyUnavailableException.java new file mode 100644 index 00000000000..89d15eb3439 --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicKeyUnavailableException.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + + +import java.util.Objects; + +import jdk.internal.net.quic.QuicTLSEngine.KeySpace; + +/** + * Thrown when an operation on {@link QuicTLSEngine} doesn't have the necessary + * QUIC keys for encrypting or decrypting packets. This can either be because + * the keys aren't available for a particular {@linkplain KeySpace keyspace} or + * the keys for the {@code keyspace} have been discarded. + */ +public final class QuicKeyUnavailableException extends Exception { + @java.io.Serial + private static final long serialVersionUID = 8553365136999153478L; + + public QuicKeyUnavailableException(final String message, final KeySpace keySpace) { + super(Objects.requireNonNull(keySpace) + " keyspace: " + message); + } +} diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicOneRttContext.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicOneRttContext.java new file mode 100644 index 00000000000..fd0b405069c --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicOneRttContext.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +/** + * Supplies contextual 1-RTT information that's available in the QUIC implementation of the + * {@code java.net.http} module, to the QUIC TLS layer in the {@code java.base} module. + */ +public interface QuicOneRttContext { + + /** + * {@return the largest packet number that was acknowledged by + * the peer in the 1-RTT packet space} + */ + long getLargestPeerAckedPN(); +} diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicTLSContext.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicTLSContext.java new file mode 100644 index 00000000000..5cf0c999fb9 --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicTLSContext.java @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.util.Arrays; +import java.util.Objects; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLContextSpi; +import javax.net.ssl.SSLParameters; + +import sun.security.ssl.QuicTLSEngineImpl; +import sun.security.ssl.SSLContextImpl; + +/** + * Instances of this class act as a factory for creation + * of {@link QuicTLSEngine QUIC TLS engine}. + */ +public final class QuicTLSContext { + + // In this implementation, we have a dependency on + // sun.security.ssl.SSLContextImpl. We can only support + // Quic on SSLContext instances created by the default + // SunJSSE Provider + private final SSLContextImpl sslCtxImpl; + + /** + * {@return {@code true} if the given {@code sslContext} supports QUIC TLS, {@code false} otherwise} + * @param sslContext an {@link SSLContext} + */ + public static boolean isQuicCompatible(final SSLContext sslContext) { + boolean parametersSupported = isQuicCompatible(sslContext.getSupportedSSLParameters()); + if (!parametersSupported) { + return false; + } + // horrible hack - what we do here is try and get hold of a SSLContext + // that has already been initialised and configured with the HttpClient. + // We see if that SSLContext is created using an implementation of + // sun.security.ssl.SSLContextImpl. Since there's no API + // available to get hold of that underlying implementation, we use + // MethodHandle lookup to get access to the field which holds that + // detail. + final Object underlyingImpl = CONTEXT_SPI.get(sslContext); + if (!(underlyingImpl instanceof SSLContextImpl ssci)) { + return false; + } + return ssci.isUsableWithQuic(); + } + + /** + * {@return {@code true} if protocols of the given {@code parameters} support QUIC TLS, {@code false} otherwise} + */ + public static boolean isQuicCompatible(SSLParameters parameters) { + String[] protocols = parameters.getProtocols(); + return protocols != null && Arrays.asList(protocols).contains("TLSv1.3"); + } + + private static SSLContextImpl getSSLContextImpl( + final SSLContext sslContext) { + final Object underlyingImpl = CONTEXT_SPI.get(sslContext); + assert underlyingImpl instanceof SSLContextImpl; + return (SSLContextImpl) underlyingImpl; + } + + /** + * Constructs a QuicTLSContext for the given {@code sslContext} + * + * @param sslContext The SSLContext + * @throws IllegalArgumentException If the passed {@code sslContext} isn't + * supported by the QuicTLSContext + * @see #isQuicCompatible(SSLContext) + */ + public QuicTLSContext(final SSLContext sslContext) { + Objects.requireNonNull(sslContext); + if (!isQuicCompatible(sslContext)) { + throw new IllegalArgumentException( + "Cannot construct a QUIC TLS context with the given SSLContext"); + } + this.sslCtxImpl = getSSLContextImpl(sslContext); + } + + /** + * Creates a {@link QuicTLSEngine} using this context + *

+ * This method does not provide hints for session caching. + * + * @return the newly created QuicTLSEngine + */ + public QuicTLSEngine createEngine() { + return createEngine(null, -1); + } + + /** + * Creates a {@link QuicTLSEngine} using this context using + * advisory peer information. + *

+ * The provided parameters will be used as hints for session caching. + * The {@code peerHost} parameter will be used in the server_name extension, + * unless overridden later. + * + * @param peerHost The peer hostname or IP address. Can be null. + * @param peerPort The peer port, can be -1 if the port is unknown + * @return the newly created QuicTLSEngine + */ + public QuicTLSEngine createEngine(final String peerHost, final int peerPort) { + return new QuicTLSEngineImpl(this.sslCtxImpl, peerHost, peerPort); + } + + // This VarHandle is used to access the SSLContext::contextSpi + // field which is not publicly accessible. + // In this implementation, Quic is only supported for SSLContext + // instances whose underlying implementation is provided by a + // sun.security.ssl.SSLContextImpl + private static final VarHandle CONTEXT_SPI; + static { + try { + final MethodHandles.Lookup lookup = + MethodHandles.privateLookupIn(SSLContext.class, + MethodHandles.lookup()); + final VarHandle vh = lookup.findVarHandle(SSLContext.class, + "contextSpi", SSLContextSpi.class); + CONTEXT_SPI = vh; + } catch (Exception x) { + throw new ExceptionInInitializerError(x); + } + } +} + diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicTLSEngine.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicTLSEngine.java new file mode 100644 index 00000000000..70ed86bbf01 --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicTLSEngine.java @@ -0,0 +1,508 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +import javax.crypto.AEADBadTagException; +import javax.crypto.ShortBufferException; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Set; +import java.util.function.IntFunction; + +/** + * One instance of these per QUIC connection. Configuration methods not shown + * but would be similar to SSLEngine. + */ +public interface QuicTLSEngine { + + /** + * Represents the encryption level associated with a packet encryption or + * decryption. A QUIC connection has a current keyspace for sending and + * receiving which can be queried. + */ + enum KeySpace { + INITIAL, + HANDSHAKE, + RETRY, // Special algorithm used for this packet + ZERO_RTT, + ONE_RTT + } + + enum HandshakeState { + /** + * Need to receive a CRYPTO frame + */ + NEED_RECV_CRYPTO, + /** + * Need to receive a HANDSHAKE_DONE frame from server to complete the + * handshake, but application data can be sent in this state (client + * only state). + */ + NEED_RECV_HANDSHAKE_DONE, + /** + * Need to send a CRYPTO frame + */ + NEED_SEND_CRYPTO, + /** + * Need to send a HANDSHAKE_DONE frame to complete the handshake, but + * application data can be sent in this state (server only state) + */ + NEED_SEND_HANDSHAKE_DONE, + /** + * Need to execute a task + */ + NEED_TASK, + /** + * Handshake is confirmed, as specified in section 4.1.2 of RFC-9001 + */ + // On client side this happens when client receives HANDSHAKE_DONE + // frame. On server side this happens when the TLS stack has both + // sent a Finished message and verified the peer's Finished message. + HANDSHAKE_CONFIRMED, + } + + /** + * {@return the QUIC versions supported by this engine} + */ + Set getSupportedQuicVersions(); + + /** + * If {@code mode} is {@code true} then configures this QuicTLSEngine to + * operate in client mode. If {@code false}, then this QuicTLSEngine + * operates in server mode. + * + * @param mode true to make this QuicTLSEngine operate in client + * mode, false otherwise + */ + void setUseClientMode(boolean mode); + + /** + * {@return true if this QuicTLSEngine is operating in client mode, false + * otherwise} + */ + boolean getUseClientMode(); + + /** + * {@return the SSLParameters in effect for this engine.} + */ + SSLParameters getSSLParameters(); + + /** + * Sets the {@code SSLParameters} to be used by this engine + * + * @param sslParameters the SSLParameters + * @throws IllegalArgumentException if + * {@linkplain SSLParameters#getProtocols() TLS protocol versions} on the + * {@code sslParameters} is either empty or contains anything other + * than {@code TLSv1.3} + * @throws NullPointerException if {@code sslParameters} is null + */ + void setSSLParameters(SSLParameters sslParameters); + + /** + * {@return the most recent application protocol value negotiated by the + * engine. Returns null if no application protocol has yet been negotiated + * by the engine} + */ + String getApplicationProtocol(); + + /** + * {@return the SSLSession} + * + * @see SSLEngine#getSession() + */ + SSLSession getSession(); + + /** + * Returns the SSLSession being constructed during a QUIC handshake. + * + * @return null if this instance is not currently handshaking, or if the + * current handshake has not progressed far enough to create + * a basic SSLSession. Otherwise, this method returns the + * {@code SSLSession} currently being negotiated. + * + * @see SSLEngine#getHandshakeSession() + */ + SSLSession getHandshakeSession(); + + /** + * Returns the current handshake state of the connection. Sometimes packets + * that could be decrypted can be received before the handshake has + * completed, but should not be decrypted until it is complete + * + * @return the HandshakeState + */ + HandshakeState getHandshakeState(); + + /** + * Returns true if the TLS handshake is considered complete. + *

+ * The TLS handshake is considered complete when the TLS stack + * has reported that the handshake is complete. This happens when + * the TLS stack has both sent a {@code Finished} message and verified + * the peer's {@code Finished} message. + * + * @return true if TLS handshake is complete, false otherwise. + */ + boolean isTLSHandshakeComplete(); + + /** + * {@return the current sending key space (encryption level)} + */ + KeySpace getCurrentSendKeySpace(); + + /** + * Checks whether the keys for the given key space are available. + *

+ * Keys are available when they are already computed and not discarded yet. + * + * @param keySpace key space to check + * @return true if the given keys are available + */ + boolean keysAvailable(KeySpace keySpace); + + /** + * Discard the keys used by the {@code keySpace}. + *

+ * Once the keys for a particular {@code keySpace} have been discarded, the + * keySpace will no longer be able to + * {@linkplain #encryptPacket(KeySpace, long, IntFunction, + * ByteBuffer, ByteBuffer) encrypt} or + * {@linkplain #decryptPacket(KeySpace, long, int, ByteBuffer, int, ByteBuffer) + * decrypt} packets. + * + * @param keySpace The keyspace whose current keys should be discarded + */ + void discardKeys(KeySpace keySpace); + + /** + * Provide quic_transport_parameters for inclusion in handshake message. + * + * @param params encoded quic_transport_parameters + */ + void setLocalQuicTransportParameters(ByteBuffer params); + + /** + * Reset the handshake state and produce a new ClientHello message. + * + * When a Quic client receives a Version Negotiation packet, + * it restarts the handshake by calling this method after updating the + * {@linkplain #setLocalQuicTransportParameters(ByteBuffer) transport parameters} + * with the new version information. + */ + void restartHandshake() throws IOException; + + /** + * Set consumer for quic_transport_parameters sent by the remote side. + * Consumer will receive a byte buffer containing the value of + * quic_transport_parameters extension sent by the remote endpoint. + * + * @param consumer consumer for remote quic transport parameters + */ + void setRemoteQuicTransportParametersConsumer( + QuicTransportParametersConsumer consumer); + + /** + * Derive initial keys for the given QUIC version and connection ID + * @param quicVersion QUIC protocol version + * @param connectionId initial destination connection ID + * @throws IllegalArgumentException if the {@code quicVersion} isn't + * {@linkplain #getSupportedQuicVersions() supported} on this + * {@code QuicTLSEngine} + */ + void deriveInitialKeys(QuicVersion quicVersion, ByteBuffer connectionId) throws IOException; + + /** + * Get the sample size for header protection algorithm + * + * @param keySpace Packet key space + * @return required sample size for header protection + * @throws IllegalArgumentException when keySpace does not require + * header protection + */ + int getHeaderProtectionSampleSize(KeySpace keySpace); + + /** + * Compute the header protection mask for the given sample, + * packet key space and direction (incoming/outgoing). + * + * @param keySpace Packet key space + * @param incoming true for incoming packets, false for outgoing + * @param sample sampled data + * @return mask bytes, at least 5. + * @throws IllegalArgumentException when keySpace does not require + * header protection or sample length is different from required + * @see #getHeaderProtectionSampleSize(KeySpace) + * @spec https://www.rfc-editor.org/rfc/rfc9001.html#name-header-protection-applicati + * RFC 9001, Section 5.4.1 Header Protection Application + */ + ByteBuffer computeHeaderProtectionMask(KeySpace keySpace, + boolean incoming, ByteBuffer sample) + throws QuicKeyUnavailableException, QuicTransportException; + + /** + * Get the authentication tag size. Encryption adds this number of bytes. + * + * @return authentication tag size + */ + int getAuthTagSize(); + + /** + * Encrypt into {@code output}, the given {@code packetPayload} bytes using the + * keys for the given {@code keySpace}. + *

+ * Before encrypting the {@code packetPayload}, this method invokes the {@code headerGenerator} + * passing it the key phase corresponding to the encryption key that's in use. + * For {@code KeySpace}s where key phase isn't applicable, the {@code headerGenerator} will + * be invoked with a value of {@code 0} for the key phase. + *

+ * The {@code headerGenerator} is expected to return a {@code ByteBuffer} representing the + * packet header and where applicable, the returned header must contain the key phase + * that was passed to the {@code headerGenerator}. The packet header will be used as + * the Additional Authentication Data (AAD) for encrypting the {@code packetPayload}. + *

+ * Upon return, the {@code output} will contain the encrypted packet payload bytes + * and the authentication tag. The {@code packetPayload} and the packet header, returned + * by the {@code headerGenerator}, will have their {@code position} equal to their + * {@code limit}. The limit of either of those buffers will not have changed. + *

+ * It is recommended to do the encryption in place by using slices of a bigger + * buffer as the input and output buffer: + *

+     *          +--------+-------------------+
+     * input:   | header | plaintext payload |
+     *          +--------+-------------------+----------+
+     * output:           | encrypted payload | AEAD tag |
+     *                   +-------------------+----------+
+     * 
+ * + * @param keySpace Packet key space + * @param packetNumber full packet number + * @param headerGenerator an {@link IntFunction} which takes a key phase and returns + * the packet header + * @param packetPayload buffer containing unencrypted packet payload + * @param output buffer into which the encrypted packet payload will be written + * @throws QuicKeyUnavailableException if keys are not available + * @throws QuicTransportException if encrypting the packet would result + * in exceeding the AEAD cipher confidentiality limit + */ + void encryptPacket(KeySpace keySpace, long packetNumber, + IntFunction headerGenerator, + ByteBuffer packetPayload, + ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException, ShortBufferException; + + /** + * Decrypt the given packet bytes using keys for the given packet key space. + * Header protection must be removed before calling this method. + *

+ * The input buffer contains the packet header and the encrypted packet payload. + * The packet header (first {@code headerLength} bytes of the input buffer) + * is consumed by this method, but is not decrypted. + * The packet payload (bytes following the packet header) is decrypted + * by this method. This method consumes the entire input buffer. + *

+ * The decrypted payload bytes are written + * to the output buffer. + *

+ * It is recommended to do the decryption in place by using slices of a bigger + * buffer as the input and output buffer: + *

+     *          +--------+-------------------+----------+
+     * input:   | header | encrypted payload | AEAD tag |
+     *          +--------+-------------------+----------+
+     * output:           | decrypted payload |
+     *                   +-------------------+
+     * 
+ * + * @param keySpace Packet key space + * @param packetNumber full packet number + * @param keyPhase key phase bit (0 or 1) found on the packet, or -1 + * if the packet does not have a key phase bit + * @param packet buffer containing encrypted packet bytes + * @param headerLength length of the packet header + * @param output buffer where decrypted packet bytes will be stored + * @throws IllegalArgumentException if keyPhase bit is invalid + * @throws QuicKeyUnavailableException if keys are not available + * @throws AEADBadTagException if the provided packet's authentication tag + * is incorrect + * @throws QuicTransportException if decrypting the invalid packet resulted + * in exceeding the AEAD cipher integrity limit + */ + void decryptPacket(KeySpace keySpace, long packetNumber, int keyPhase, + ByteBuffer packet, int headerLength, ByteBuffer output) + throws IllegalArgumentException, QuicKeyUnavailableException, + AEADBadTagException, QuicTransportException, ShortBufferException; + + /** + * Sign the provided retry packet. Input buffer contains the retry packet + * payload. Integrity tag is stored in the output buffer. + * + * @param version Quic version + * @param originalConnectionId original destination connection ID, + * without length + * @param packet retry packet bytes without tag + * @param output buffer where integrity tag will be stored + * @throws ShortBufferException if output buffer is too short to + * hold the tag + * @throws IllegalArgumentException if originalConnectionId is + * longer than 255 bytes + * @throws IllegalArgumentException if {@code version} isn't + * {@linkplain #getSupportedQuicVersions() supported} + */ + void signRetryPacket(QuicVersion version, ByteBuffer originalConnectionId, + ByteBuffer packet, ByteBuffer output) throws ShortBufferException, QuicTransportException; + + /** + * Verify the provided retry packet. + * + * @param version Quic version + * @param originalConnectionId original destination connection ID, + * without length + * @param packet retry packet bytes with tag + * @throws AEADBadTagException if integrity tag is invalid + * @throws IllegalArgumentException if originalConnectionId is + * longer than 255 bytes + * @throws IllegalArgumentException if {@code version} isn't + * {@linkplain #getSupportedQuicVersions() supported} + */ + void verifyRetryPacket(QuicVersion version, ByteBuffer originalConnectionId, + ByteBuffer packet) throws AEADBadTagException, QuicTransportException; + + /** + * If the current handshake state is {@link HandshakeState#NEED_SEND_CRYPTO} + * meaning that a CRYPTO frame needs to be sent then this method is called + * to obtain the contents of the frame. Current handshake state + * can be obtained from {@link #getHandshakeState()}, and the current + * key space can be obtained with {@link #getCurrentSendKeySpace()} + * The bytes returned by this call are used to build a CRYPTO frame. + * + * @param keySpace the key space of the packet in which the + * requested data will be placed + * @return buffer containing data that will be put by caller in a CRYPTO + * frame, or null if there are no more handshake bytes to send in + * this key space at this time. + */ + ByteBuffer getHandshakeBytes(KeySpace keySpace) throws IOException; + + /** + * This method consumes crypto stream. + * + * @param keySpace the key space of the packet in which the provided + * crypto data was encountered. + * @param payload contents of the next CRYPTO frame + * @throws IllegalArgumentException if keySpace is ZERORTT or + * payload is empty + * @throws QuicTransportException if the handshake failed + */ + void consumeHandshakeBytes(KeySpace keySpace, ByteBuffer payload) + throws QuicTransportException; + + /** + * Returns a delegated {@code Runnable} task for + * this {@code QuicTLSEngine}. + *

+ * {@code QuicTLSEngine} operations may require the results of + * operations that block, or may take an extended period of time to + * complete. This method is used to obtain an outstanding {@link + * java.lang.Runnable} operation (task). Each task must be assigned + * a thread (possibly the current) to perform the {@link + * java.lang.Runnable#run() run} operation. Once the + * {@code run} method returns, the {@code Runnable} object + * is no longer needed and may be discarded. + *

+ * A call to this method will return each outstanding task + * exactly once. + *

+ * Multiple delegated tasks can be run in parallel. + * + * @return a delegated {@code Runnable} task, or null + * if none are available. + */ + Runnable getDelegatedTask(); + + /** + * Called to check if a {@code HANDSHAKE_DONE} frame needs to be sent by the + * server. This method will only be called for a {@code QuicTLSEngine} which + * is in {@linkplain #getUseClientMode() server mode}. If the current TLS handshake + * state is + * {@link HandshakeState#NEED_SEND_HANDSHAKE_DONE + * NEED_SEND_HANDSHAKE_DONE} then this method returns {@code true} and + * advances the TLS handshake state to + * {@link HandshakeState#HANDSHAKE_CONFIRMED HANDSHAKE_CONFIRMED}. Else + * returns {@code false}. + * + * @return true if handshake state was {@code NEED_SEND_HANDSHAKE_DONE}, + * false otherwise + * @throws IllegalStateException If this {@code QuicTLSEngine} is + * not in server mode + */ + boolean tryMarkHandshakeDone() throws IllegalStateException; + + /** + * Called when HANDSHAKE_DONE message is received from the server. This + * method will only be called for a {@code QuicTLSEngine} which is in + * {@linkplain #getUseClientMode() client mode}. If the current TLS handshake state + * is + * {@link HandshakeState#NEED_RECV_HANDSHAKE_DONE + * NEED_RECV_HANDSHAKE_DONE} then this method returns {@code true} and + * advances the TLS handshake state to + * {@link HandshakeState#HANDSHAKE_CONFIRMED HANDSHAKE_CONFIRMED}. Else + * returns {@code false}. + * + * @return true if handshake state was {@code NEED_RECV_HANDSHAKE_DONE}, + * false otherwise + * @throws IllegalStateException if this {@code QuicTLSEngine} is + * not in client mode + */ + boolean tryReceiveHandshakeDone() throws IllegalStateException; + + /** + * Called when the client and the server, during the connection creation + * handshake, have settled on a Quic version to use for the connection. This + * can happen either due to an explicit version negotiation (as outlined in + * Quic RFC) or the server accepting the Quic version that the client chose + * in its first INITIAL packet. In either of those cases, this method will + * be called. + * + * @param quicVersion the negotiated {@code QuicVersion} + * @throws IllegalArgumentException if the {@code quicVersion} isn't + * {@linkplain #getSupportedQuicVersions() supported} on this engine + */ + void versionNegotiated(QuicVersion quicVersion); + + /** + * Sets the {@link QuicOneRttContext} on the {@code QuicTLSEngine}. + *

The {@code ctx} will be used by the {@code QuicTLSEngine} to access contextual 1-RTT + * data that might be required for the TLS operations. + * + * @param ctx the 1-RTT context to set + * @throws NullPointerException if {@code ctx} is null + */ + void setOneRttContext(QuicOneRttContext ctx); +} diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportErrors.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportErrors.java new file mode 100644 index 00000000000..d081458d40e --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportErrors.java @@ -0,0 +1,349 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +import sun.security.ssl.Alert; + +import java.util.Optional; +import java.util.stream.Stream; + +/** + * An enum to model Quic transport errors. + * Some errors have a single possible code value, some, like + * {@link #CRYPTO_ERROR} have a range of possible values. + * Usually, the value (a long) would be used instead of the + * enum, but the enum itself can be useful - for instance in + * switch statements. + * This enum models QUIC transport error codes as defined in + * RFC 9000, section 20.1. + */ +public enum QuicTransportErrors { + /** + * No error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint uses this with CONNECTION_CLOSE to signal that
+     * the connection is being closed abruptly in the absence
+     * of any error.
+     * }
+ */ + NO_ERROR(0x00), + + /** + * Internal Error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * The endpoint encountered an internal error and cannot
+     * continue with the connection.
+     * }
+ */ + INTERNAL_ERROR(0x01), + + /** + * Connection refused error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * The server refused to accept a new connection.
+     * }
+ */ + CONNECTION_REFUSED(0x02), + + /** + * Flow control error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint received more data than it permitted in its advertised data limits;
+     * see Section 4.
+     * }
+ * @see + * RFC 9000, Section 20.1: + *
{@code
+     * An endpoint received a frame for a stream identifier that exceeded its advertised
+     * stream limit for the corresponding stream type.
+     * }
+ */ + STREAM_LIMIT_ERROR(0x04), + + /** + * Stream state error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint received a frame for a stream that was not in a state that permitted
+     * that frame; see Section 3.
+     * }
+ * @see + * RFC 9000, Section 20.1: + *
{@code
+     * (1) An endpoint received a STREAM frame containing data that exceeded the previously
+     *     established final size,
+     * (2) an endpoint received a STREAM frame or a RESET_STREAM frame containing a final
+     *     size that was lower than the size of stream data that was already received, or
+     * (3) an endpoint received a STREAM frame or a RESET_STREAM frame containing a
+     *     different final size to the one already established.
+     * }
+ */ + FINAL_SIZE_ERROR(0x06), + + /** + * Frame encoding error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint received a frame that was badly formatted -- for instance,
+     * a frame of an unknown type or an ACK frame that has more
+     * acknowledgment ranges than the remainder of the packet could carry.
+     * }
+ */ + FRAME_ENCODING_ERROR(0x07), + + /** + * Transport parameter error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint received transport parameters that were badly
+     * formatted, included an invalid value, omitted a mandatory
+     * transport parameter, included a forbidden transport
+     * parameter, or were otherwise in error.
+     * }
+ */ + TRANSPORT_PARAMETER_ERROR(0x08), + + /** + * Connection id limit error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * The number of connection IDs provided by the peer exceeds
+     * the advertised active_connection_id_limit.
+     * }
+ */ + CONNECTION_ID_LIMIT_ERROR(0x09), + + /** + * Protocol violiation error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint detected an error with protocol compliance that
+     * was not covered by more specific error codes.
+     * }
+ */ + PROTOCOL_VIOLATION(0x0a), + + /** + * Invalid token error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * A server received a client Initial that contained an invalid Token field.
+     * }
+ */ + INVALID_TOKEN(0x0b), + + /** + * Application error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * The application or application protocol caused the connection to be closed.
+     * }
+ */ + APPLICATION_ERROR(0x0c), + + /** + * Crypto buffer exceeded error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint has received more data in CRYPTO frames than it can buffer.
+     * }
+ */ + CRYPTO_BUFFER_EXCEEDED(0x0d), + + /** + * Key update error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint detected errors in performing key updates; see Section 6 of [QUIC-TLS].
+     * }
+ * @see Section 6 of RFC 9001 [QUIC-TLS] + */ + KEY_UPDATE_ERROR(0x0e), + + /** + * AEAD limit reached error + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint has reached the confidentiality or integrity limit
+     * for the AEAD algorithm used by the given connection.
+     * }
+ */ + AEAD_LIMIT_REACHED(0x0f), + + /** + * No viable path error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * An endpoint has determined that the network path is incapable of
+     * supporting QUIC. An endpoint is unlikely to receive a
+     * CONNECTION_CLOSE frame carrying this code except when the
+     * path does not support a large enough MTU.
+     * }
+ */ + NO_VIABLE_PATH(0x10), + + /** + * Error negotiating version. + * @spec https://www.rfc-editor.org/rfc/rfc9368#name-version-downgrade-preventio + * RFC 9368, Section 4 + */ + VERSION_NEGOTIATION_ERROR(0x11), + + /** + * Crypto error. + *

+ * From + * RFC 9000, Section 20.1: + *

{@code
+     * The cryptographic handshake failed. A range of 256 values is
+     * reserved for carrying error codes specific to the cryptographic
+     * handshake that is used. Codes for errors occurring when
+     * TLS is used for the cryptographic handshake are described
+     * in Section 4.8 of [QUIC-TLS].
+     * }
+ * @see Section 4.8 of RFC 9001 [QUIC-TLS] + */ + CRYPTO_ERROR(0x0100, 0x01ff); + + private final long from; + private final long to; + + QuicTransportErrors(long code) { + this(code, code); + } + + QuicTransportErrors(long from, long to) { + assert from <= to; + this.from = from; + this.to = to; + } + + /** + * {@return the code for this transport error, if this error + * {@linkplain #hasCode() has a single possible code value}, + * {@code -1} otherwise} + */ + public long code() { return hasCode() ? from : -1;} + + /** + * {@return true if this error has a single possible code value} + */ + public boolean hasCode() { return from == to; } + + /** + * {@return true if this error has a range of possible code values} + */ + public boolean hasRange() { return from < to;} + + /** + * {@return the first possible code value in the range, or the + * code value if this error has a single possible code value} + */ + public long from() {return from;} + + /** + * {@return the last possible code value in the range, or the + * code value if this error has a single possible code value} + */ + public long to() { return to; } + + /** + * Tells whether the given {@code code} value corresponds to + * this error. + * @param code an error code value + * @return true if the given {@code code} value corresponds to + * this error. + */ + boolean isFor(long code) { + return code >= from && code <= to; + } + + /** + * {@return the {@link QuicTransportErrors} instance corresponding + * to the given {@code code} value, if any} + * @param code a {@code code} value + */ + public static Optional ofCode(long code) { + return Stream.of(values()).filter(e -> e.isFor(code)).findAny(); + } + + public static String toString(long code) { + Optional c = Stream.of(values()).filter(e -> e.isFor(code)).findAny(); + if (c.isEmpty()) return "Unknown [0x"+Long.toHexString(code) + "]"; + if (c.get().hasCode()) return c.get().toString(); + if (c.get() == CRYPTO_ERROR) + return c.get() + "|" + Alert.nameOf((byte)code); + return c.get() + " [0x" + Long.toHexString(code) + "]"; + + } +} diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportException.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportException.java new file mode 100644 index 00000000000..3341cc527f2 --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportException.java @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +/** + * Exception that wraps QUIC transport error codes. + * Thrown in response to packets or frames that violate QUIC protocol. + * This is a fatal exception; connection is always closed when this exception is caught. + * + *

For a list of errors see: + * https://www.rfc-editor.org/rfc/rfc9000.html#name-transport-error-codes + */ +public final class QuicTransportException extends Exception { + @java.io.Serial + private static final long serialVersionUID = 5259674758792412464L; + + private final QuicTLSEngine.KeySpace keySpace; + private final long frameType; + private final long errorCode; + + /** + * Constructs a new {@code QuicTransportException}. + * + * @param reason the reason why the exception occurred + * @param keySpace the key space in which the frame appeared. + * May be {@code null}, for instance, in + * case of {@link QuicTransportErrors#INTERNAL_ERROR}. + * @param frameType the frame type of the frame whose parsing / handling + * caused the error. + * May be 0 if not related to any specific frame. + * @param errorCode a quic transport error + */ + public QuicTransportException(String reason, QuicTLSEngine.KeySpace keySpace, + long frameType, QuicTransportErrors errorCode) { + super(reason); + this.keySpace = keySpace; + this.frameType = frameType; + this.errorCode = errorCode.code(); + } + + /** + * Constructs a new {@code QuicTransportException}. For use with TLS alerts. + * + * @param reason the reason why the exception occurred + * @param keySpace the key space in which the frame appeared. + * May be {@code null}, for instance, in + * case of {@link QuicTransportErrors#INTERNAL_ERROR}. + * @param frameType the frame type of the frame whose parsing / handling + * caused the error. + * May be 0 if not related to any specific frame. + * @param errorCode a quic transport error code + * @param cause the cause + */ + public QuicTransportException(String reason, QuicTLSEngine.KeySpace keySpace, + long frameType, long errorCode, Throwable cause) { + super(reason, cause); + this.keySpace = keySpace; + this.frameType = frameType; + this.errorCode = errorCode; + } + + /** + * {@return the reason to include in the {@code ConnectionCloseFrame}} + */ + public String getReason() { + return getMessage(); + } + + /** + * {@return the key space for which the error occurred, or {@code null}} + */ + public QuicTLSEngine.KeySpace getKeySpace() { + return keySpace; + } + + /** + * {@return the frame type for which the error occurred, or 0} + */ + public long getFrameType() { + return frameType; + } + + /** + * {@return the transport error that occurred} + */ + public long getErrorCode() { + return errorCode; + } +} diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportParametersConsumer.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportParametersConsumer.java new file mode 100644 index 00000000000..9429f6bf26f --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicTransportParametersConsumer.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +import java.nio.ByteBuffer; + +/** + * Interface for consumer of QUIC transport parameters, in wire-encoded format + */ +public interface QuicTransportParametersConsumer { + /** + * Consumes the provided QUIC transport parameters + * @param buffer byte buffer containing encoded quic transport parameters + * @throws QuicTransportException if buffer does not represent valid parameters + */ + void accept(ByteBuffer buffer) throws QuicTransportException; +} diff --git a/src/java.base/share/classes/jdk/internal/net/quic/QuicVersion.java b/src/java.base/share/classes/jdk/internal/net/quic/QuicVersion.java new file mode 100644 index 00000000000..14bbaba816d --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/net/quic/QuicVersion.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.quic; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Objects; +import java.util.Optional; + +/** + * Represents the Quic versions defined in their corresponding RFCs + */ +public enum QuicVersion { + // the version numbers are defined in their respective RFCs + QUIC_V1(1), // RFC-9000 + QUIC_V2(0x6b3343cf); // RFC 9369 + + // 32 bits unsigned integer representing the version as + // defined in RFC. This is the version number as sent + // in long headers packets (see RFC 9000). + private final int versionNumber; + + private QuicVersion(final int versionNumber) { + this.versionNumber = versionNumber; + } + + /** + * {@return the version number} + */ + public int versionNumber() { + return this.versionNumber; + } + + /** + * {@return the QuicVersion corresponding to the {@code versionNumber} or + * {@link Optional#empty() an empty Optional} if the {@code versionNumber} + * doesn't correspond to a Quic version} + * + * @param versionNumber The version number + */ + public static Optional of(int versionNumber) { + for (QuicVersion qv : QuicVersion.values()) { + if (qv.versionNumber == versionNumber) { + return Optional.of(qv); + } + } + return Optional.empty(); + } + + /** + * From among the {@code quicVersions}, selects a {@code QuicVersion} to be used in the + * first packet during connection initiation. + * + * @param quicVersions the available QUIC versions + * @return the QUIC version to use in the first packet + * @throws NullPointerException if {@code quicVersions} is null or any element + * in it is null + * @throws IllegalArgumentException if {@code quicVersions} is empty + */ + public static QuicVersion firstFlightVersion(final Collection quicVersions) { + Objects.requireNonNull(quicVersions); + if (quicVersions.isEmpty()) { + throw new IllegalArgumentException("Empty quic versions"); + } + if (quicVersions.size() == 1) { + return quicVersions.iterator().next(); + } + for (final QuicVersion version : quicVersions) { + if (version == QUIC_V1) { + // we always prefer QUIC v1 for first flight version + return QUIC_V1; + } + } + // the given versions did not have QUIC v1, which implies the + // only available first flight version is QUIC v2 + return QUIC_V2; + } +} diff --git a/src/java.base/share/classes/module-info.java b/src/java.base/share/classes/module-info.java index 2a51a0af38d..3ae84fdf198 100644 --- a/src/java.base/share/classes/module-info.java +++ b/src/java.base/share/classes/module-info.java @@ -190,6 +190,8 @@ module java.base { jdk.jlink; exports jdk.internal.logger to java.logging; + exports jdk.internal.net.quic to + java.net.http; exports jdk.internal.org.xml.sax to jdk.jfr; exports jdk.internal.org.xml.sax.helpers to @@ -260,6 +262,7 @@ module java.base { jdk.jfr; exports jdk.internal.util to java.desktop, + java.net.http, java.prefs, java.security.jgss, java.smartcardio, diff --git a/src/java.base/share/classes/sun/security/ssl/Alert.java b/src/java.base/share/classes/sun/security/ssl/Alert.java index 4e1ccf385c7..960b3f3b37d 100644 --- a/src/java.base/share/classes/sun/security/ssl/Alert.java +++ b/src/java.base/share/classes/sun/security/ssl/Alert.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2003, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -34,9 +34,9 @@ import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLProtocolException; /** - * SSL/(D)TLS Alter description + * SSL/(D)TLS Alert description */ -enum Alert { +public enum Alert { // Please refer to TLS Alert Registry for the latest (D)TLS Alert values: // https://www.iana.org/assignments/tls-parameters/ CLOSE_NOTIFY ((byte)0, "close_notify", false), @@ -103,7 +103,7 @@ enum Alert { return null; } - static String nameOf(byte id) { + public static String nameOf(byte id) { for (Alert al : Alert.values()) { if (al.id == id) { return al.description; diff --git a/src/java.base/share/classes/sun/security/ssl/AlpnExtension.java b/src/java.base/share/classes/sun/security/ssl/AlpnExtension.java index d44ec034411..aa5933ddab0 100644 --- a/src/java.base/share/classes/sun/security/ssl/AlpnExtension.java +++ b/src/java.base/share/classes/sun/security/ssl/AlpnExtension.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -71,7 +71,7 @@ final class AlpnExtension { /** * The "application_layer_protocol_negotiation" extension. - * + *

* See RFC 7301 for the specification of this extension. */ static final class AlpnSpec implements SSLExtensionSpec { @@ -344,6 +344,13 @@ final class AlpnExtension { // The producing happens in server side only. ServerHandshakeContext shc = (ServerHandshakeContext)context; + if (shc.sslConfig.isQuic) { + // RFC 9001: endpoints MUST use ALPN + throw shc.conContext.fatal( + Alert.NO_APPLICATION_PROTOCOL, + "Client did not offer application layer protocol"); + } + // Please don't use the previous negotiated application protocol. shc.applicationProtocol = ""; shc.conContext.applicationProtocol = ""; @@ -513,6 +520,15 @@ final class AlpnExtension { // The producing happens in client side only. ClientHandshakeContext chc = (ClientHandshakeContext)context; + if (chc.sslConfig.isQuic) { + // RFC 9001: QUIC clients MUST use error 0x0178 + // [no_application_protocol] to terminate a connection when + // ALPN negotiation fails + throw chc.conContext.fatal( + Alert.NO_APPLICATION_PROTOCOL, + "Server did not offer application layer protocol"); + } + // Please don't use the previous negotiated application protocol. chc.applicationProtocol = ""; chc.conContext.applicationProtocol = ""; diff --git a/src/java.base/share/classes/sun/security/ssl/CertificateMessage.java b/src/java.base/share/classes/sun/security/ssl/CertificateMessage.java index 609a81571ed..d4587d35ae9 100644 --- a/src/java.base/share/classes/sun/security/ssl/CertificateMessage.java +++ b/src/java.base/share/classes/sun/security/ssl/CertificateMessage.java @@ -1219,12 +1219,19 @@ final class CertificateMessage { certs.clone(), authType, engine); - } else { - SSLSocket socket = (SSLSocket)shc.conContext.transport; + } else if (shc.conContext.transport instanceof SSLSocket socket){ ((X509ExtendedTrustManager)tm).checkClientTrusted( certs.clone(), authType, socket); + } else if (shc.conContext.transport + instanceof QuicTLSEngineImpl qtlse) { + if (tm instanceof X509TrustManagerImpl tmImpl) { + tmImpl.checkClientTrusted(certs.clone(), authType, qtlse); + } else { + throw new CertificateException( + "QUIC only supports SunJSSE trust managers"); + } } } else { // Unlikely to happen, because we have wrapped the old @@ -1268,18 +1275,26 @@ final class CertificateMessage { try { X509TrustManager tm = chc.sslContext.getX509TrustManager(); - if (tm instanceof X509ExtendedTrustManager) { + if (tm instanceof X509ExtendedTrustManager x509ExtTm) { if (chc.conContext.transport instanceof SSLEngine engine) { - ((X509ExtendedTrustManager)tm).checkServerTrusted( + x509ExtTm.checkServerTrusted( certs.clone(), authType, engine); - } else { - SSLSocket socket = (SSLSocket)chc.conContext.transport; - ((X509ExtendedTrustManager)tm).checkServerTrusted( + } else if (chc.conContext.transport instanceof SSLSocket socket) { + x509ExtTm.checkServerTrusted( certs.clone(), authType, socket); + } else if (chc.conContext.transport instanceof QuicTLSEngineImpl qtlse) { + if (x509ExtTm instanceof X509TrustManagerImpl tmImpl) { + tmImpl.checkServerTrusted(certs.clone(), authType, qtlse); + } else { + throw new CertificateException( + "QUIC only supports SunJSSE trust managers"); + } + } else { + throw new AssertionError("Unexpected transport type"); } } else { // Unlikely to happen, because we have wrapped the old diff --git a/src/java.base/share/classes/sun/security/ssl/ClientHello.java b/src/java.base/share/classes/sun/security/ssl/ClientHello.java index 3e43921520d..c9432ea3979 100644 --- a/src/java.base/share/classes/sun/security/ssl/ClientHello.java +++ b/src/java.base/share/classes/sun/security/ssl/ClientHello.java @@ -568,7 +568,7 @@ final class ClientHello { } if (sessionId.length() == 0 && chc.maximumActiveProtocol.useTLS13PlusSpec() && - SSLConfiguration.useCompatibilityMode) { + chc.sslConfig.isUseCompatibilityMode()) { // In compatibility mode, the TLS 1.3 legacy_session_id // field MUST be non-empty, so a client not offering a // pre-TLS 1.3 session MUST generate a new 32-byte value. diff --git a/src/java.base/share/classes/sun/security/ssl/Finished.java b/src/java.base/share/classes/sun/security/ssl/Finished.java index 9421d12ec15..04fe61760d0 100644 --- a/src/java.base/share/classes/sun/security/ssl/Finished.java +++ b/src/java.base/share/classes/sun/security/ssl/Finished.java @@ -846,6 +846,16 @@ final class Finished { // update the context for the following key derivation shc.handshakeKeyDerivation = secretKD; + if (shc.sslConfig.isQuic) { + QuicTLSEngineImpl engine = + (QuicTLSEngineImpl) shc.conContext.transport; + try { + engine.deriveOneRTTKeys(); + } catch (IOException e) { + throw shc.conContext.fatal(Alert.INTERNAL_ERROR, + "Failure to derive application secrets", e); + } + } } catch (GeneralSecurityException gse) { throw shc.conContext.fatal(Alert.INTERNAL_ERROR, "Failure to derive application secrets", gse); @@ -1010,6 +1020,16 @@ final class Finished { // update the context for the following key derivation chc.handshakeKeyDerivation = secretKD; + if (chc.sslConfig.isQuic) { + QuicTLSEngineImpl engine = + (QuicTLSEngineImpl) chc.conContext.transport; + try { + engine.deriveOneRTTKeys(); + } catch (IOException e) { + throw chc.conContext.fatal(Alert.INTERNAL_ERROR, + "Failure to derive application secrets", e); + } + } } catch (GeneralSecurityException gse) { throw chc.conContext.fatal(Alert.INTERNAL_ERROR, "Failure to derive application secrets", gse); diff --git a/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java b/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java index c4549070f02..2b17c7406a3 100644 --- a/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java +++ b/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java @@ -269,6 +269,12 @@ final class KeyUpdate { HandshakeMessage message) throws IOException { // The producing happens in server side only. PostHandshakeContext hc = (PostHandshakeContext)context; + if (hc.sslConfig.isQuic) { + // Quic doesn't allow KEY_UPDATE TLS message. It has its own Quic specific + // key update mechanism, RFC-9001, section 6: + // Endpoints MUST NOT send a TLS KeyUpdate message. + return null; + } KeyUpdateMessage km = (KeyUpdateMessage)message; if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) { SSLLogger.fine( diff --git a/src/java.base/share/classes/sun/security/ssl/OutputRecord.java b/src/java.base/share/classes/sun/security/ssl/OutputRecord.java index 0fa831f6351..f2c30b3ff72 100644 --- a/src/java.base/share/classes/sun/security/ssl/OutputRecord.java +++ b/src/java.base/share/classes/sun/security/ssl/OutputRecord.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 1996, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1996, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -31,6 +31,8 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.quic.QuicTLSEngine; import sun.security.ssl.SSLCipher.SSLWriteCipher; /** @@ -154,6 +156,16 @@ abstract class OutputRecord throw new UnsupportedOperationException(); } + // apply to QuicEngine only + byte[] getHandshakeMessage() { + throw new UnsupportedOperationException(); + } + + // apply to QuicEngine only + QuicTLSEngine.KeySpace getHandshakeMessageKeySpace() { + throw new UnsupportedOperationException(); + } + // apply to SSLEngine only void encodeV2NoCipher() throws IOException { throw new UnsupportedOperationException(); diff --git a/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java b/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java index b06549b40e3..a4f87616245 100644 --- a/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java +++ b/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java @@ -47,17 +47,15 @@ final class PostHandshakeContext extends HandshakeContext { context.conSession.getLocalSupportedSignatureSchemes()); // Add the potential post-handshake consumers. - if (context.sslConfig.isClientMode) { + if (!context.sslConfig.isQuic) { handshakeConsumers.putIfAbsent( SSLHandshake.KEY_UPDATE.id, SSLHandshake.KEY_UPDATE); + } + if (context.sslConfig.isClientMode) { handshakeConsumers.putIfAbsent( SSLHandshake.NEW_SESSION_TICKET.id, SSLHandshake.NEW_SESSION_TICKET); - } else { - handshakeConsumers.putIfAbsent( - SSLHandshake.KEY_UPDATE.id, - SSLHandshake.KEY_UPDATE); } handshakeFinished = true; @@ -93,6 +91,15 @@ final class PostHandshakeContext extends HandshakeContext { static boolean isConsumable(TransportContext context, byte handshakeType) { if (handshakeType == SSLHandshake.KEY_UPDATE.id) { + // Quic doesn't allow KEY_UPDATE TLS message. It has its own + // Quic-specific key update mechanism, RFC-9001, section 6: + // Endpoints MUST NOT send a TLS KeyUpdate message. Endpoints + // MUST treat the receipt of a TLS KeyUpdate message as a + // connection error of type 0x010a, equivalent to a fatal + // TLS alert of unexpected_message; + if (context.sslConfig.isQuic) { + return false; + } // The KeyUpdate handshake message does not apply to TLS 1.2 and // previous protocols. return context.protocolVersion.useTLS13PlusSpec(); diff --git a/src/java.base/share/classes/sun/security/ssl/QuicCipher.java b/src/java.base/share/classes/sun/security/ssl/QuicCipher.java new file mode 100644 index 00000000000..c1d812e4c40 --- /dev/null +++ b/src/java.base/share/classes/sun/security/ssl/QuicCipher.java @@ -0,0 +1,699 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package sun.security.ssl; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.GeneralSecurityException; +import java.security.Security; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +import javax.crypto.AEADBadTagException; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.ChaCha20ParameterSpec; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.IvParameterSpec; + +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import sun.security.util.KeyUtil; + +import static jdk.internal.net.quic.QuicTLSEngine.KeySpace.ONE_RTT; +import static sun.security.ssl.QuicTLSEngineImpl.BASE_CRYPTO_ERROR; + +abstract class QuicCipher { + private static final String + SEC_PROP_QUIC_TLS_KEY_LIMITS = "jdk.quic.tls.keyLimits"; + + private static final Map KEY_LIMITS; + + static { + final String propVal = Security.getProperty( + SEC_PROP_QUIC_TLS_KEY_LIMITS); + if (propVal == null) { + KEY_LIMITS = Map.of(); // no specific limits + } else { + final Map limits = new HashMap<>(); + for (final String entry : propVal.split(",")) { + // each entry is of the form + // example: + // AES/GCM/NoPadding 2^23 + // ChaCha20-Poly1305 -1 + final String[] parts = entry.trim().split(" "); + if (parts.length != 2) { + // TODO: exception type + throw new RuntimeException("invalid value for " + + SEC_PROP_QUIC_TLS_KEY_LIMITS + + " security property"); + } + final String cipher = parts[0]; + if (limits.containsKey(cipher)) { + throw new RuntimeException( + "key limit defined more than once for cipher " + + cipher); + } + final String limitVal = parts[1]; + final long limit; + final int index = limitVal.indexOf("^"); + if (index >= 0) { + // of the form x^y (example: 2^23) + limit = (long) Math.pow( + Integer.parseInt(limitVal.substring(0, index)), + Integer.parseInt(limitVal.substring(index + 1))); + } else { + limit = Long.parseLong(limitVal); + } + if (limit == 0 || limit < -1) { + // we allow -1 to imply no limits, but any other zero + // or negative value is invalid + // TODO: exception type + throw new RuntimeException("invalid value for " + + SEC_PROP_QUIC_TLS_KEY_LIMITS + + " security property"); + } + limits.put(cipher, limit); + } + KEY_LIMITS = Collections.unmodifiableMap(limits); + } + } + + private final CipherSuite cipherSuite; + private final QuicHeaderProtectionCipher hpCipher; + private final SecretKey baseSecret; + private final int keyPhase; + + protected QuicCipher(final CipherSuite cipherSuite, final SecretKey baseSecret, + final QuicHeaderProtectionCipher hpCipher, final int keyPhase) { + assert keyPhase == 0 || keyPhase == 1 : + "invalid key phase: " + keyPhase; + this.cipherSuite = cipherSuite; + this.baseSecret = baseSecret; + this.hpCipher = hpCipher; + this.keyPhase = keyPhase; + } + + final SecretKey getBaseSecret() { + return this.baseSecret; + } + + final CipherSuite getCipherSuite() { + return this.cipherSuite; + } + + final SecretKey getHeaderProtectionKey() { + return this.hpCipher.headerProtectionKey; + } + + final ByteBuffer computeHeaderProtectionMask(ByteBuffer sample) + throws QuicTransportException { + return hpCipher.computeHeaderProtectionMask(sample); + } + + final int getKeyPhase() { + return this.keyPhase; + } + + final void discard(boolean destroyHP) { + safeDiscard(this.baseSecret); + if (destroyHP) { + this.hpCipher.discard(); + } + this.doDiscard(); + } + + protected abstract void doDiscard(); + + static QuicReadCipher createReadCipher(final CipherSuite cipherSuite, + final SecretKey baseSecret, final SecretKey key, + final byte[] iv, final SecretKey hp, + final int keyPhase) throws GeneralSecurityException { + return switch (cipherSuite) { + case TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384 -> + new T13GCMReadCipher( + cipherSuite, baseSecret, key, iv, hp, keyPhase); + case TLS_CHACHA20_POLY1305_SHA256 -> + new T13CC20P1305ReadCipher( + cipherSuite, baseSecret, key, iv, hp, keyPhase); + default -> throw new IllegalArgumentException("Cipher suite " + + cipherSuite + " not supported"); + }; + } + + static QuicWriteCipher createWriteCipher(final CipherSuite cipherSuite, + final SecretKey baseSecret, final SecretKey key, + final byte[] iv, final SecretKey hp, + final int keyPhase) throws GeneralSecurityException { + return switch (cipherSuite) { + case TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384 -> + new T13GCMWriteCipher(cipherSuite, baseSecret, key, iv, hp, + keyPhase); + case TLS_CHACHA20_POLY1305_SHA256 -> + new T13CC20P1305WriteCipher(cipherSuite, baseSecret, key, iv, + hp, keyPhase); + default -> throw new IllegalArgumentException("Cipher suite " + + cipherSuite + " not supported"); + }; + } + + static void safeDiscard(final SecretKey secretKey) { + KeyUtil.destroySecretKeys(secretKey); + } + + abstract static class QuicReadCipher extends QuicCipher { + private final AtomicLong lowestDecryptedPktNum = new AtomicLong(-1); + + QuicReadCipher(CipherSuite cipherSuite, SecretKey baseSecret, + QuicHeaderProtectionCipher hpCipher, int keyPhase) { + super(cipherSuite, baseSecret, hpCipher, keyPhase); + } + + final void decryptPacket(long packetNumber, ByteBuffer packet, + int headerLength, ByteBuffer output) + throws AEADBadTagException, ShortBufferException, QuicTransportException { + doDecrypt(packetNumber, packet, headerLength, output); + boolean updated; + do { + final long current = lowestDecryptedPktNum.get(); + assert packetNumber >= 0 : + "unexpected packet number: " + packetNumber; + final long newLowest = current == -1 ? packetNumber : + Math.min(current, packetNumber); + updated = lowestDecryptedPktNum.compareAndSet(current, + newLowest); + } while (!updated); + } + + protected abstract void doDecrypt(long packetNumber, + ByteBuffer packet, int headerLength, ByteBuffer output) + throws AEADBadTagException, ShortBufferException, QuicTransportException; + + /** + * Returns the maximum limit on the number of packets that fail + * decryption, across all key (updates), using this + * {@code QuicReadCipher}. This method must not return a value less + * than 0. + * + * @return the limit + */ + // RFC-9001, section 6.6 + abstract long integrityLimit(); + + /** + * {@return the lowest packet number that this {@code QuicReadCipher} + * has decrypted. If no packets have yet been decrypted by this + * instance, then this method returns -1} + */ + final long lowestDecryptedPktNum() { + return this.lowestDecryptedPktNum.get(); + } + + /** + * {@return true if this {@code QuicReadCipher} has successfully + * decrypted any packet sent by the peer, else returns false} + */ + final boolean hasDecryptedAny() { + return this.lowestDecryptedPktNum.get() != -1; + } + } + + abstract static class QuicWriteCipher extends QuicCipher { + private final AtomicLong numPacketsEncrypted = new AtomicLong(); + private final AtomicLong lowestEncryptedPktNum = new AtomicLong(-1); + + QuicWriteCipher(CipherSuite cipherSuite, SecretKey baseSecret, + QuicHeaderProtectionCipher hpCipher, int keyPhase) { + super(cipherSuite, baseSecret, hpCipher, keyPhase); + } + + final void encryptPacket(final long packetNumber, + final ByteBuffer packetHeader, + final ByteBuffer packetPayload, + final ByteBuffer output) + throws QuicTransportException, ShortBufferException { + final long confidentialityLimit = confidentialityLimit(); + final long numEncrypted = this.numPacketsEncrypted.get(); + if (confidentialityLimit > 0 && + numEncrypted > confidentialityLimit) { + // the OneRttKeyManager is responsible for detecting and + // initiating a key update before this limit is hit. The fact + // that we hit this limit indicates that either the key + // update wasn't initiated or the key update failed. In + // either case we just throw an exception which + // should lead to the connection being closed as required by + // RFC-9001, section 6.6: + // If a key update is not possible or integrity limits are + // reached, the endpoint MUST stop using the connection and + // only send stateless resets in response to receiving + // packets. It is RECOMMENDED that endpoints immediately + // close the connection with a connection error of type + // AEAD_LIMIT_REACHED before reaching a state where key + // updates are not possible. + throw new QuicTransportException("confidentiality limit " + + "reached", ONE_RTT, 0, + QuicTransportErrors.AEAD_LIMIT_REACHED); + } + this.numPacketsEncrypted.incrementAndGet(); + doEncryptPacket(packetNumber, packetHeader, packetPayload, output); + boolean updated; + do { + final long current = lowestEncryptedPktNum.get(); + assert packetNumber >= 0 : + "unexpected packet number: " + packetNumber; + final long newLowest = current == -1 ? packetNumber : + Math.min(current, packetNumber); + updated = lowestEncryptedPktNum.compareAndSet(current, + newLowest); + } while (!updated); + } + + /** + * {@return the lowest packet number that this {@code QuicWriteCipher} + * has encrypted. If no packets have yet been encrypted by this + * instance, then this method returns -1} + */ + final long lowestEncryptedPktNum() { + return this.lowestEncryptedPktNum.get(); + } + + /** + * {@return true if this {@code QuicWriteCipher} has successfully + * encrypted any packet to send to the peer, else returns false} + */ + final boolean hasEncryptedAny() { + // rely on the lowestEncryptedPktNum field instead of the + // numPacketsEncrypted field. this avoids a race where the + // lowestEncryptedPktNum() might return a value contradicting + // the return value of this method. + return this.lowestEncryptedPktNum.get() != -1; + } + + /** + * {@return the number of packets encrypted by this {@code + * QuicWriteCipher}} + */ + final long getNumEncrypted() { + return this.numPacketsEncrypted.get(); + } + + abstract void doEncryptPacket(long packetNumber, ByteBuffer packetHeader, + ByteBuffer packetPayload, ByteBuffer output) + throws ShortBufferException, QuicTransportException; + + /** + * Returns the maximum limit on the number of packets that are allowed + * to be encrypted with this instance of {@code QuicWriteCipher}. A + * value less than 0 implies that there's no limit. + * + * @return the limit or -1 + */ + // RFC-9001, section 6.6: The confidentiality limit applies to the + // number of + // packets encrypted with a given key. + abstract long confidentialityLimit(); + } + + abstract static class QuicHeaderProtectionCipher { + protected final SecretKey headerProtectionKey; + + protected QuicHeaderProtectionCipher( + final SecretKey headerProtectionKey) { + this.headerProtectionKey = headerProtectionKey; + } + + int getHeaderProtectionSampleSize() { + return 16; + } + + abstract ByteBuffer computeHeaderProtectionMask(ByteBuffer sample) + throws QuicTransportException; + + final void discard() { + safeDiscard(this.headerProtectionKey); + } + } + + static final class T13GCMReadCipher extends QuicReadCipher { + // RFC-9001, section 6.6: For AEAD_AES_128_GCM and AEAD_AES_256_GCM, + // the integrity limit is 2^52 invalid packets + private static final long INTEGRITY_LIMIT = 1L << 52; + + private final Cipher cipher; + private final SecretKey key; + private final byte[] iv; + + T13GCMReadCipher(final CipherSuite cipherSuite, final SecretKey baseSecret, + final SecretKey key, final byte[] iv, final SecretKey hp, + final int keyPhase) + throws GeneralSecurityException { + super(cipherSuite, baseSecret, new T13AESHPCipher(hp), keyPhase); + this.key = key; + this.iv = iv; + this.cipher = Cipher.getInstance("AES/GCM/NoPadding"); + } + + @Override + protected void doDecrypt(long packetNumber, ByteBuffer packet, + int headerLength, ByteBuffer output) + throws AEADBadTagException, ShortBufferException, QuicTransportException { + byte[] iv = this.iv.clone(); + + // apply packet number to IV + int i = 11; + while (packetNumber > 0) { + iv[i] ^= (byte) (packetNumber & 0xFF); + packetNumber = packetNumber >>> 8; + i--; + } + final GCMParameterSpec ivSpec = new GCMParameterSpec(128, iv); + synchronized (cipher) { + try { + cipher.init(Cipher.DECRYPT_MODE, key, ivSpec); + int limit = packet.limit(); + packet.limit(packet.position() + headerLength); + cipher.updateAAD(packet); + packet.limit(limit); + cipher.doFinal(packet, output); + } catch (AEADBadTagException | ShortBufferException e) { + throw e; + } catch (Exception e) { + throw new QuicTransportException("Decryption failed", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + } + + @Override + long integrityLimit() { + return INTEGRITY_LIMIT; + } + + @Override + protected final void doDiscard() { + safeDiscard(this.key); + } + } + + static final class T13GCMWriteCipher extends QuicWriteCipher { + private static final String CIPHER_ALGORITHM_NAME = "AES/GCM/NoPadding"; + private static final long CONFIDENTIALITY_LIMIT; + + static { + // RFC-9001, section 6.6: For AEAD_AES_128_GCM and AEAD_AES_256_GCM, + // the confidentiality limit is 2^23 encrypted packets + final long defaultVal = 1 << 23; + long limit = + KEY_LIMITS.getOrDefault(CIPHER_ALGORITHM_NAME, defaultVal); + // don't allow the configuration to increase the confidentiality + // limit, but only let it lower the limit + limit = limit > defaultVal ? defaultVal : limit; + CONFIDENTIALITY_LIMIT = limit; + } + + private final SecretKey key; + private final Cipher cipher; + private final byte[] iv; + + T13GCMWriteCipher(final CipherSuite cipherSuite, final SecretKey baseSecret, + final SecretKey key, final byte[] iv, final SecretKey hp, + final int keyPhase) throws GeneralSecurityException { + super(cipherSuite, baseSecret, new T13AESHPCipher(hp), keyPhase); + this.key = key; + this.iv = iv; + this.cipher = Cipher.getInstance(CIPHER_ALGORITHM_NAME); + } + + @Override + void doEncryptPacket(long packetNumber, ByteBuffer packetHeader, + ByteBuffer packetPayload, ByteBuffer output) + throws ShortBufferException, QuicTransportException { + byte[] iv = this.iv.clone(); + + // apply packet number to IV + int i = 11; + while (packetNumber > 0) { + iv[i] ^= (byte) (packetNumber & 0xFF); + packetNumber = packetNumber >>> 8; + i--; + } + final GCMParameterSpec ivSpec = new GCMParameterSpec(128, iv); + synchronized (cipher) { + try { + cipher.init(Cipher.ENCRYPT_MODE, key, ivSpec); + cipher.updateAAD(packetHeader); + cipher.doFinal(packetPayload, output); + } catch (ShortBufferException e) { + throw e; + } catch (Exception e) { + throw new QuicTransportException("Encryption failed", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + } + + @Override + long confidentialityLimit() { + return CONFIDENTIALITY_LIMIT; + } + + @Override + protected final void doDiscard() { + safeDiscard(this.key); + } + } + + static final class T13AESHPCipher extends QuicHeaderProtectionCipher { + private final Cipher cipher; + + T13AESHPCipher(SecretKey hp) throws GeneralSecurityException { + super(hp); + cipher = Cipher.getInstance("AES/ECB/NoPadding"); + } + + @Override + public ByteBuffer computeHeaderProtectionMask(ByteBuffer sample) + throws QuicTransportException { + if (sample.remaining() != getHeaderProtectionSampleSize()) { + throw new IllegalArgumentException("Invalid sample size"); + } + ByteBuffer output = ByteBuffer.allocate(sample.remaining()); + try { + synchronized (cipher) { + // Some providers (Jipher) don't re-initialize the cipher + // after doFinal, and need init every time. + cipher.init(Cipher.ENCRYPT_MODE, headerProtectionKey); + cipher.doFinal(sample, output); + } + output.flip(); + assert output.remaining() >= 5; + return output; + } catch (Exception e) { + throw new QuicTransportException("Encryption failed", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + } + + static final class T13CC20P1305ReadCipher extends QuicReadCipher { + // RFC-9001, section 6.6: For AEAD_CHACHA20_POLY1305, + // the integrity limit is 2^36 invalid packets + private static final long INTEGRITY_LIMIT = 1L << 36; + + private final SecretKey key; + private final Cipher cipher; + private final byte[] iv; + + T13CC20P1305ReadCipher(final CipherSuite cipherSuite, + final SecretKey baseSecret, final SecretKey key, + final byte[] iv, final SecretKey hp, final int keyPhase) + throws GeneralSecurityException { + super(cipherSuite, baseSecret, new T13CC20HPCipher(hp), keyPhase); + this.key = key; + this.iv = iv; + this.cipher = Cipher.getInstance("ChaCha20-Poly1305"); + } + + @Override + protected void doDecrypt(long packetNumber, ByteBuffer packet, + int headerLength, ByteBuffer output) + throws AEADBadTagException, ShortBufferException, QuicTransportException { + byte[] iv = this.iv.clone(); + + // apply packet number to IV + int i = 11; + while (packetNumber > 0) { + iv[i] ^= (byte) (packetNumber & 0xFF); + packetNumber = packetNumber >>> 8; + i--; + } + final IvParameterSpec ivSpec = new IvParameterSpec(iv); + synchronized (cipher) { + try { + cipher.init(Cipher.DECRYPT_MODE, key, ivSpec); + int limit = packet.limit(); + packet.limit(packet.position() + headerLength); + cipher.updateAAD(packet); + packet.limit(limit); + cipher.doFinal(packet, output); + } catch (AEADBadTagException | ShortBufferException e) { + throw e; + } catch (Exception e) { + throw new QuicTransportException("Decryption failed", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + } + + @Override + long integrityLimit() { + return INTEGRITY_LIMIT; + } + + @Override + protected final void doDiscard() { + safeDiscard(this.key); + } + } + + static final class T13CC20P1305WriteCipher extends QuicWriteCipher { + private static final String CIPHER_ALGORITHM_NAME = "ChaCha20-Poly1305"; + private static final long CONFIDENTIALITY_LIMIT; + + static { + // RFC-9001, section 6.6: For AEAD_CHACHA20_POLY1305, the + // confidentiality limit is greater than the number of possible + // packets (2^62) and so can be disregarded. + final long defaultVal = -1; // no limit + long limit = + KEY_LIMITS.getOrDefault(CIPHER_ALGORITHM_NAME, defaultVal); + limit = limit < 0 ? -1 /* no limit */ : limit; + CONFIDENTIALITY_LIMIT = limit; + } + + private final SecretKey key; + private final Cipher cipher; + private final byte[] iv; + + T13CC20P1305WriteCipher(final CipherSuite cipherSuite, + final SecretKey baseSecret, final SecretKey key, + final byte[] iv, final SecretKey hp, + final int keyPhase) + throws GeneralSecurityException { + super(cipherSuite, baseSecret, new T13CC20HPCipher(hp), keyPhase); + this.key = key; + this.iv = iv; + this.cipher = Cipher.getInstance(CIPHER_ALGORITHM_NAME); + } + + @Override + void doEncryptPacket(final long packetNumber, final ByteBuffer packetHeader, + final ByteBuffer packetPayload, final ByteBuffer output) + throws ShortBufferException, QuicTransportException { + byte[] iv = this.iv.clone(); + + // apply packet number to IV + int i = 11; + long pn = packetNumber; + while (pn > 0) { + iv[i] ^= (byte) (pn & 0xFF); + pn = pn >>> 8; + i--; + } + final IvParameterSpec ivSpec = new IvParameterSpec(iv); + synchronized (cipher) { + try { + cipher.init(Cipher.ENCRYPT_MODE, key, ivSpec); + cipher.updateAAD(packetHeader); + cipher.doFinal(packetPayload, output); + } catch (ShortBufferException e) { + throw e; + } catch (Exception e) { + throw new QuicTransportException("Encryption failed", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + } + + @Override + long confidentialityLimit() { + return CONFIDENTIALITY_LIMIT; + } + + @Override + protected final void doDiscard() { + safeDiscard(this.key); + } + } + + static final class T13CC20HPCipher extends QuicHeaderProtectionCipher { + private final Cipher cipher; + + T13CC20HPCipher(final SecretKey hp) throws GeneralSecurityException { + super(hp); + cipher = Cipher.getInstance("ChaCha20"); + } + + @Override + public ByteBuffer computeHeaderProtectionMask(ByteBuffer sample) + throws QuicTransportException { + if (sample.remaining() != getHeaderProtectionSampleSize()) { + throw new IllegalArgumentException("Invalid sample size"); + } + try { + // RFC 7539: [counter is a] 32-bit block count parameter, + // treated as a 32-bit little-endian integer + // RFC 9001: + // counter = sample[0..3] + // nonce = sample[4..15] + // mask = ChaCha20(hp_key, counter, nonce, {0,0,0,0,0}) + + sample.order(ByteOrder.LITTLE_ENDIAN); + byte[] nonce = new byte[12]; + int counter = sample.getInt(); + sample.get(nonce); + ChaCha20ParameterSpec ivSpec = + new ChaCha20ParameterSpec(nonce, counter); + byte[] output = new byte[5]; + + synchronized (cipher) { + // DECRYPT produces the same output as ENCRYPT, but does + // not throw when the same IV is used repeatedly + cipher.init(Cipher.DECRYPT_MODE, headerProtectionKey, + ivSpec); + int numBytes = cipher.doFinal(output, 0, 5, output); + assert numBytes == 5; + } + return ByteBuffer.wrap(output); + } catch (Exception e) { + throw new QuicTransportException("Encryption failed", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + } +} diff --git a/src/java.base/share/classes/sun/security/ssl/QuicEngineOutputRecord.java b/src/java.base/share/classes/sun/security/ssl/QuicEngineOutputRecord.java new file mode 100644 index 00000000000..893eb282116 --- /dev/null +++ b/src/java.base/share/classes/sun/security/ssl/QuicEngineOutputRecord.java @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package sun.security.ssl; + +import jdk.internal.net.quic.QuicTLSEngine; +import sun.security.ssl.SSLCipher.SSLWriteCipher; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedList; + +/** + * {@code OutputRecord} implementation for {@code QuicTLSEngineImpl}. + */ +final class QuicEngineOutputRecord extends OutputRecord implements SSLRecord { + + private final HandshakeFragment fragmenter = new HandshakeFragment(); + + private volatile boolean isCloseWaiting; + + private Alert alert; + + QuicEngineOutputRecord(HandshakeHash handshakeHash) { + super(handshakeHash, SSLWriteCipher.nullTlsWriteCipher()); + + this.packetSize = SSLRecord.maxRecordSize; + this.protocolVersion = ProtocolVersion.NONE; + } + + @Override + public void close() throws IOException { + recordLock.lock(); + try { + if (!isClosed) { + if (!fragmenter.isEmpty()) { + isCloseWaiting = true; + } else { + super.close(); + } + } + } finally { + recordLock.unlock(); + } + } + + boolean isClosed() { + return isClosed || isCloseWaiting; + } + + @Override + void encodeAlert(byte level, byte description) throws IOException { + recordLock.lock(); + try { + if (isClosed()) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.warning("outbound has closed, ignore outbound " + + "alert message: " + Alert.nameOf(description)); + } + return; + } + if (level == Alert.Level.WARNING.level) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.warning("Suppressing warning-level " + + "alert message: " + Alert.nameOf(description)); + } + return; + } + + if (alert != null) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.warning("Suppressing subsequent alert: " + + description + ", original: " + alert.id); + } + return; + } + + alert = Alert.valueOf(description); + } finally { + recordLock.unlock(); + } + } + + @Override + void encodeHandshake(byte[] source, + int offset, int length) throws IOException { + recordLock.lock(); + try { + if (isClosed()) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.warning("outbound has closed, ignore outbound " + + "handshake message", + ByteBuffer.wrap(source, offset, length)); + } + return; + } + + firstMessage = false; + + byte handshakeType = source[offset]; + if (handshakeHash.isHashable(handshakeType)) { + handshakeHash.deliver(source, offset, length); + } + + fragmenter.queueUpFragment(source, offset, length); + } finally { + recordLock.unlock(); + } + } + + @Override + void encodeChangeCipherSpec() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + void changeWriteCiphers(SSLWriteCipher writeCipher, boolean useChangeCipherSpec) throws IOException { + recordLock.lock(); + try { + fragmenter.changePacketSpace(); + } finally { + recordLock.unlock(); + } + } + + @Override + void changeWriteCiphers(SSLWriteCipher writeCipher, byte keyUpdateRequest) throws IOException { + throw new UnsupportedOperationException("Should not call this"); + } + + @Override + byte[] getHandshakeMessage() { + recordLock.lock(); + try { + return fragmenter.acquireCiphertext(); + } finally { + recordLock.unlock(); + } + } + + @Override + QuicTLSEngine.KeySpace getHandshakeMessageKeySpace() { + recordLock.lock(); + try { + return switch (fragmenter.currentPacketSpace) { + case 0-> QuicTLSEngine.KeySpace.INITIAL; + case 1-> QuicTLSEngine.KeySpace.HANDSHAKE; + case 2-> QuicTLSEngine.KeySpace.ONE_RTT; + default -> throw new IllegalStateException("Unexpected state"); + }; + } finally { + recordLock.unlock(); + } + } + + @Override + boolean isEmpty() { + recordLock.lock(); + try { + return fragmenter.isEmpty(); + } finally { + recordLock.unlock(); + } + } + + Alert getAlert() { + recordLock.lock(); + try { + return alert; + } finally { + recordLock.unlock(); + } + } + + // buffered record fragment + private static class HandshakeMemo { + boolean changeSpace; + byte[] fragment; + } + + static final class HandshakeFragment { + private final LinkedList handshakeMemos = + new LinkedList<>(); + + private int currentPacketSpace; + + void queueUpFragment(byte[] source, + int offset, int length) throws IOException { + HandshakeMemo memo = new HandshakeMemo(); + + memo.fragment = new byte[length]; + assert Record.getInt24(ByteBuffer.wrap(source, offset + 1, 3)) + == length - 4 : "Invalid handshake message length"; + System.arraycopy(source, offset, memo.fragment, 0, length); + + handshakeMemos.add(memo); + } + + void changePacketSpace() { + HandshakeMemo lastMemo = handshakeMemos.peekLast(); + if (lastMemo != null) { + lastMemo.changeSpace = true; + } else { + currentPacketSpace++; + } + } + + byte[] acquireCiphertext() { + HandshakeMemo hsMemo = handshakeMemos.pollFirst(); + if (hsMemo == null) { + return null; + } + if (hsMemo.changeSpace) { + currentPacketSpace++; + } + return hsMemo.fragment; + } + + boolean isEmpty() { + return handshakeMemos.isEmpty(); + } + } +} diff --git a/src/java.base/share/classes/sun/security/ssl/QuicKeyManager.java b/src/java.base/share/classes/sun/security/ssl/QuicKeyManager.java new file mode 100644 index 00000000000..fb9077af022 --- /dev/null +++ b/src/java.base/share/classes/sun/security/ssl/QuicKeyManager.java @@ -0,0 +1,1216 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package sun.security.ssl; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.NoSuchAlgorithmException; +import java.util.HexFormat; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.IntFunction; + +import javax.crypto.AEADBadTagException; +import javax.crypto.Cipher; +import javax.crypto.KDF; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.HKDFParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; + +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicOneRttContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.internal.vm.annotation.Stable; +import sun.security.ssl.QuicCipher.QuicReadCipher; +import sun.security.ssl.QuicCipher.QuicWriteCipher; +import sun.security.util.KeyUtil; + +import static jdk.internal.net.quic.QuicTLSEngine.KeySpace.HANDSHAKE; +import static jdk.internal.net.quic.QuicTLSEngine.KeySpace.INITIAL; +import static jdk.internal.net.quic.QuicTLSEngine.KeySpace.ONE_RTT; +import static jdk.internal.net.quic.QuicTransportErrors.AEAD_LIMIT_REACHED; +import static jdk.internal.net.quic.QuicTransportErrors.KEY_UPDATE_ERROR; +import static sun.security.ssl.QuicTLSEngineImpl.BASE_CRYPTO_ERROR; + +sealed abstract class QuicKeyManager + permits QuicKeyManager.HandshakeKeyManager, + QuicKeyManager.InitialKeyManager, QuicKeyManager.OneRttKeyManager { + + private record QuicKeys(SecretKey key, byte[] iv, SecretKey hp) { + } + + private record CipherPair(QuicReadCipher readCipher, + QuicWriteCipher writeCipher) { + void discard(boolean destroyHP) { + writeCipher.discard(destroyHP); + readCipher.discard(destroyHP); + } + + /** + * {@return true if the keys represented by this {@code CipherPair} + * were used by both this endpoint and the peer, thus implying these + * keys are available to both of them} + */ + boolean usedByBothEndpoints() { + return this.readCipher.hasDecryptedAny() && + this.writeCipher.hasEncryptedAny(); + } + } + + final QuicTLSEngine.KeySpace keySpace; + // counter towards the integrity limit + final AtomicLong invalidPackets = new AtomicLong(); + volatile boolean keysDiscarded; + + private QuicKeyManager(final QuicTLSEngine.KeySpace keySpace) { + this.keySpace = keySpace; + } + + protected abstract boolean keysAvailable(); + + protected abstract QuicReadCipher getReadCipher() + throws QuicKeyUnavailableException; + + protected abstract QuicWriteCipher getWriteCipher() + throws QuicKeyUnavailableException; + + abstract void discardKeys(); + + void decryptPacket(final long packetNumber, final int keyPhase, + final ByteBuffer packet,final int headerLength, + final ByteBuffer output) throws QuicKeyUnavailableException, + IllegalArgumentException, AEADBadTagException, + QuicTransportException, ShortBufferException { + // keyPhase is only applicable for 1-RTT packets; the decryptPacket + // method is overridden by OneRttKeyManager, so this check is for + // other packet types + if (keyPhase != -1) { + throw new IllegalArgumentException( + "Unexpected key phase value: " + keyPhase); + } + // use current keys to decrypt + QuicReadCipher readCipher = getReadCipher(); + try { + readCipher.decryptPacket(packetNumber, packet, headerLength, output); + } catch (AEADBadTagException e) { + if (invalidPackets.incrementAndGet() >= + readCipher.integrityLimit()) { + throw new QuicTransportException("Integrity limit reached", + keySpace, 0, AEAD_LIMIT_REACHED); + } + throw e; + } + } + + void encryptPacket(final long packetNumber, + final IntFunction headerGenerator, + final ByteBuffer packetPayload, + final ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException, ShortBufferException { + // generate the packet header passing the generator the key phase + final ByteBuffer header = headerGenerator.apply(0); // key phase is always 0 for non-ONE_RTT + getWriteCipher().encryptPacket(packetNumber, header, packetPayload, output); + } + + private static QuicKeys deriveQuicKeys(final QuicVersion quicVersion, + final CipherSuite cs, final SecretKey traffic_secret) + throws IOException { + final SSLKeyDerivation kd = new QuicTLSKeyDerivation(cs, + traffic_secret); + final QuicTLSData tlsData = getQuicData(quicVersion); + final SecretKey quic_key = kd.deriveKey(tlsData.getTlsKeyLabel()); + final byte[] quic_iv = kd.deriveData(tlsData.getTlsIvLabel()); + final SecretKey quic_hp = kd.deriveKey(tlsData.getTlsHpLabel()); + return new QuicKeys(quic_key, quic_iv, quic_hp); + } + + // Used in 1RTT when advancing the keyphase. quic_hp is not advanced. + private static QuicKeys deriveQuicKeys(final QuicVersion quicVersion, + final CipherSuite cs, final SecretKey traffic_secret, + final SecretKey quic_hp) throws IOException { + final SSLKeyDerivation kd = new QuicTLSKeyDerivation(cs, + traffic_secret); + final QuicTLSData tlsData = getQuicData(quicVersion); + final SecretKey quic_key = kd.deriveKey(tlsData.getTlsKeyLabel()); + final byte[] quic_iv = kd.deriveData(tlsData.getTlsIvLabel()); + return new QuicKeys(quic_key, quic_iv, quic_hp); + } + + private static QuicTLSData getQuicData(final QuicVersion quicVersion) { + return switch (quicVersion) { + case QUIC_V1 -> QuicTLSData.V1; + case QUIC_V2 -> QuicTLSData.V2; + }; + } + + private static byte[] createHkdfInfo(final String label, final int length) { + final byte[] tls13Label = + ("tls13 " + label).getBytes(StandardCharsets.UTF_8); + return createHkdfInfo(tls13Label, length); + } + + private static byte[] createHkdfInfo(final byte[] tls13Label, + final int length) { + final byte[] info = new byte[4 + tls13Label.length]; + final ByteBuffer m = ByteBuffer.wrap(info); + try { + Record.putInt16(m, length); + Record.putBytes8(m, tls13Label); + Record.putInt8(m, 0x00); // zero-length context + } catch (IOException ioe) { + // unlikely + throw new UncheckedIOException("Unexpected exception", ioe); + } + return info; + } + + static final class InitialKeyManager extends QuicKeyManager { + + private volatile CipherPair cipherPair; + + InitialKeyManager() { + super(INITIAL); + } + + @Override + protected boolean keysAvailable() { + return this.cipherPair != null && !this.keysDiscarded; + } + + @Override + protected QuicReadCipher getReadCipher() + throws QuicKeyUnavailableException { + final CipherPair pair = this.cipherPair; + if (pair == null) { + final String msg = this.keysDiscarded + ? "Keys have been discarded" + : "Keys not available"; + throw new QuicKeyUnavailableException(msg, this.keySpace); + } + return pair.readCipher; + } + + @Override + protected QuicWriteCipher getWriteCipher() + throws QuicKeyUnavailableException { + final CipherPair pair = this.cipherPair; + if (pair == null) { + final String msg = this.keysDiscarded + ? "Keys have been discarded" + : "Keys not available"; + throw new QuicKeyUnavailableException(msg, this.keySpace); + } + return pair.writeCipher; + } + + @Override + void discardKeys() { + final CipherPair toDiscard = this.cipherPair; + this.keysDiscarded = true; + this.cipherPair = null; // no longer needed + if (toDiscard == null) { + return; + } + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("discarding keys (keyphase=" + + toDiscard.writeCipher.getKeyPhase() + + ") of " + this.keySpace + " key space"); + } + toDiscard.discard(true); + } + + void deriveKeys(final QuicVersion quicVersion, + final byte[] connectionId, + final boolean clientMode) throws IOException{ + Objects.requireNonNull(quicVersion); + final CipherSuite cs = CipherSuite.TLS_AES_128_GCM_SHA256; + final CipherSuite.HashAlg hashAlg = cs.hashAlg; + + KDF hkdf; + try { + hkdf = KDF.getInstance(hashAlg.hkdfAlgorithm); + } catch (NoSuchAlgorithmException e) { + throw new SSLHandshakeException("Could not generate secret", e); + } + final QuicTLSData tlsData = QuicKeyManager.getQuicData(quicVersion); + SecretKey initial_secret = null; + SecretKey server_initial_secret = null; + SecretKey client_initial_secret = null; + try { + initial_secret = hkdf.deriveKey("TlsInitialSecret", + HKDFParameterSpec.ofExtract() + .addSalt(tlsData.getInitialSalt()) + .addIKM(connectionId).extractOnly()); + + byte[] clientInfo = createHkdfInfo("client in", + hashAlg.hashLength); + client_initial_secret = + hkdf.deriveKey("TlsClientInitialTrafficSecret", + HKDFParameterSpec.expandOnly( + initial_secret, + clientInfo, + hashAlg.hashLength)); + QuicKeys clientKeys = deriveQuicKeys(quicVersion, cs, + client_initial_secret); + + byte[] serverInfo = createHkdfInfo("server in", + hashAlg.hashLength); + server_initial_secret = + hkdf.deriveKey("TlsServerInitialTrafficSecret", + HKDFParameterSpec.expandOnly( + initial_secret, + serverInfo, + hashAlg.hashLength)); + QuicKeys serverKeys = deriveQuicKeys(quicVersion, cs, + server_initial_secret); + + final QuicReadCipher readCipher; + final QuicWriteCipher writeCipher; + final int keyPhase = 0; + if (clientMode) { + readCipher = QuicCipher.createReadCipher(cs, + server_initial_secret, + serverKeys.key, serverKeys.iv, serverKeys.hp, + keyPhase); + writeCipher = QuicCipher.createWriteCipher(cs, + client_initial_secret, + clientKeys.key, clientKeys.iv, clientKeys.hp, + keyPhase); + } else { + readCipher = QuicCipher.createReadCipher(cs, + client_initial_secret, + clientKeys.key, clientKeys.iv, clientKeys.hp, + keyPhase); + writeCipher = QuicCipher.createWriteCipher(cs, + server_initial_secret, + serverKeys.key, serverKeys.iv, serverKeys.hp, + keyPhase); + } + final CipherPair old = this.cipherPair; + // we don't check if keys are already available, since it's a + // valid case where the INITIAL keys are regenerated due to a + // RETRY packet from the peer or even for the case where a + // different quic version was negotiated by the server + this.cipherPair = new CipherPair(readCipher, writeCipher); + if (old != null) { + old.discard(true); + } + } catch (GeneralSecurityException e) { + throw new SSLException("Missing cipher algorithm", e); + } finally { + KeyUtil.destroySecretKeys(initial_secret, client_initial_secret, + server_initial_secret); + } + } + + static Cipher getRetryCipher(final QuicVersion quicVersion, + final boolean incoming) throws QuicTransportException { + final QuicTLSData tlsData = QuicKeyManager.getQuicData(quicVersion); + return tlsData.getRetryCipher(incoming); + } + } + + static final class HandshakeKeyManager extends QuicKeyManager { + private volatile CipherPair cipherPair; + + HandshakeKeyManager() { + super(HANDSHAKE); + } + + @Override + protected boolean keysAvailable() { + return this.cipherPair != null && !this.keysDiscarded; + } + + @Override + protected QuicReadCipher getReadCipher() + throws QuicKeyUnavailableException { + final CipherPair pair = this.cipherPair; + if (pair == null) { + final String msg = this.keysDiscarded + ? "Keys have been discarded" + : "Keys not available"; + throw new QuicKeyUnavailableException(msg, this.keySpace); + } + return pair.readCipher; + } + + @Override + protected QuicWriteCipher getWriteCipher() + throws QuicKeyUnavailableException { + final CipherPair pair = this.cipherPair; + if (pair == null) { + final String msg = this.keysDiscarded + ? "Keys have been discarded" + : "Keys not available"; + throw new QuicKeyUnavailableException(msg, this.keySpace); + } + return pair.writeCipher; + } + + @Override + void discardKeys() { + final CipherPair toDiscard = this.cipherPair; + this.cipherPair = null; // no longer needed + this.keysDiscarded = true; + if (toDiscard == null) { + return; + } + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("discarding keys (keyphase=" + + toDiscard.writeCipher.getKeyPhase() + + ") of " + this.keySpace + " key space"); + } + toDiscard.discard(true); + } + + void deriveKeys(final QuicVersion quicVersion, + final HandshakeContext handshakeContext, + final boolean clientMode) throws IOException { + Objects.requireNonNull(quicVersion); + if (keysAvailable()) { + throw new IllegalStateException( + "Keys already derived for " + this.keySpace + + " key space"); + } + SecretKey client_handshake_traffic_secret = null; + SecretKey server_handshake_traffic_secret = null; + try { + final SSLKeyDerivation kd = + handshakeContext.handshakeKeyDerivation; + client_handshake_traffic_secret = kd.deriveKey( + "TlsClientHandshakeTrafficSecret"); + final QuicKeys clientKeys = deriveQuicKeys(quicVersion, + handshakeContext.negotiatedCipherSuite, + client_handshake_traffic_secret); + server_handshake_traffic_secret = kd.deriveKey( + "TlsServerHandshakeTrafficSecret"); + final QuicKeys serverKeys = deriveQuicKeys(quicVersion, + handshakeContext.negotiatedCipherSuite, + server_handshake_traffic_secret); + + final CipherSuite negotiatedCipherSuite = + handshakeContext.negotiatedCipherSuite; + final QuicReadCipher readCipher; + final QuicWriteCipher writeCipher; + final int keyPhase = 0; + if (clientMode) { + readCipher = + QuicCipher.createReadCipher(negotiatedCipherSuite, + server_handshake_traffic_secret, + serverKeys.key, serverKeys.iv, + serverKeys.hp, keyPhase); + writeCipher = + QuicCipher.createWriteCipher(negotiatedCipherSuite, + client_handshake_traffic_secret, + clientKeys.key, clientKeys.iv, + clientKeys.hp, keyPhase); + } else { + readCipher = + QuicCipher.createReadCipher(negotiatedCipherSuite, + client_handshake_traffic_secret, + clientKeys.key, clientKeys.iv, + clientKeys.hp, keyPhase); + writeCipher = + QuicCipher.createWriteCipher(negotiatedCipherSuite, + server_handshake_traffic_secret, + serverKeys.key, serverKeys.iv, + serverKeys.hp, keyPhase); + } + synchronized (this) { + if (this.cipherPair != null) { + // don't allow setting more than once + throw new IllegalStateException("Keys already " + + "available for keyspace: " + + this.keySpace); + } + this.cipherPair = new CipherPair(readCipher, writeCipher); + } + } catch (GeneralSecurityException e) { + throw new SSLException("Missing cipher algorithm", e); + } finally { + KeyUtil.destroySecretKeys(client_handshake_traffic_secret, + server_handshake_traffic_secret); + } + } + } + + static final class OneRttKeyManager extends QuicKeyManager { + // a series of keys that the 1-RTT key manager uses + private record KeySeries(QuicReadCipher old, CipherPair current, + CipherPair next) { + private KeySeries { + Objects.requireNonNull(current); + if (old != null) { + if (old.getKeyPhase() == + current.writeCipher.getKeyPhase()) { + throw new IllegalArgumentException("Both old keys and" + + " current keys have the same key phase: " + + current.writeCipher.getKeyPhase()); + } + } + if (next != null) { + if (next.writeCipher.getKeyPhase() == + current.writeCipher.getKeyPhase()) { + throw new IllegalArgumentException("Both next keys " + + "and current keys have the same key phase: " + + current.writeCipher.getKeyPhase()); + } + } + } + + /** + * {@return true if this {@code KeySeries} has an old decryption key + * and the {@code pktNum} is lower than the least packet number the + * current decryption key has decrypted so far} + * + * @param pktNum the packet number for which the old key + * might be needed + */ + boolean canUseOldDecryptKey(final long pktNum) { + assert pktNum >= 0 : "unexpected packet number: " + pktNum; + if (this.old == null) { + return false; + } + final QuicReadCipher currentKey = this.current.readCipher; + final long lowestDecrypted = currentKey.lowestDecryptedPktNum(); + // if the incoming packet number is lesser than the lowest + // decrypted packet number by the current key, then it + // implies that this might be a delayed packet and thus is + // allowed to use the old key (if available) from + // the previous key phase. + // see RFC-9001, section 6.5 + if (lowestDecrypted == -1) { + return true; + } + return pktNum < lowestDecrypted; + } + } + + // will be set when the keys are derived + private volatile QuicVersion negotiatedVersion; + + private final Lock keySeriesLock = new ReentrantLock(); + // will be set when keys are derived and will + // be updated whenever keys are updated. + // Must be updated/written only + // when holding the keySeriesLock lock + private volatile KeySeries keySeries; + + @Stable + private volatile QuicOneRttContext oneRttContext; + + OneRttKeyManager() { + super(ONE_RTT); + } + + @Override + protected boolean keysAvailable() { + return this.keySeries != null && !this.keysDiscarded; + } + + @Override + protected QuicReadCipher getReadCipher() + throws QuicKeyUnavailableException { + final KeySeries series = requireKeySeries(); + return series.current.readCipher; + } + + @Override + protected QuicWriteCipher getWriteCipher() + throws QuicKeyUnavailableException { + final KeySeries series = requireKeySeries(); + return series.current.writeCipher; + } + + @Override + void discardKeys() { + this.keysDiscarded = true; + final KeySeries series; + this.keySeriesLock.lock(); + try { + series = this.keySeries; + this.keySeries = null; // no longer available + } finally { + this.keySeriesLock.unlock(); + } + if (series == null) { + return; + } + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("discarding key (series) of " + + this.keySpace + " key space"); + } + if (series.old != null) { + series.old.discard(false); + } + discardKeys(series.current); + discardKeys(series.next); + } + + @Override + void decryptPacket(final long packetNumber, final int keyPhase, + final ByteBuffer packet, final int headerLength, + final ByteBuffer output) throws QuicKeyUnavailableException, + QuicTransportException, AEADBadTagException, ShortBufferException { + if (keyPhase != 0 && keyPhase != 1) { + throw new IllegalArgumentException("Unexpected key phase " + + "value: " + keyPhase); + } + final KeySeries series = requireKeySeries(); + final CipherPair current = series.current; + // Use the write cipher's key phase to detect a key update as noted + // in RFC-9001, section 6.2: + // An endpoint detects a key update when processing a packet with + // a key phase that differs from the value used to protect the + // last packet it sent. + final int currentKeyPhase = current.writeCipher.getKeyPhase(); + if (keyPhase == currentKeyPhase) { + current.readCipher.decryptPacket(packetNumber, packet, + headerLength, output); + return; + } + // incoming packet is using a key phase which doesn't match the + // current key phase. this implies that either a key update + // is being initiated or a key update initiated by the current + // endpoint is in progress and some older packet with the + // previous key phase has arrived. + if (series.canUseOldDecryptKey(packetNumber)) { + final QuicReadCipher oldReadCipher = series.old; + assert oldReadCipher != null : "old key is unexpectedly null"; + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("using old read key to decrypt packet: " + + packetNumber + ", with incoming key phase: " + + keyPhase + ", current key phase: " + + currentKeyPhase); + } + oldReadCipher.decryptPacket( + packetNumber, packet, headerLength, output); + // we were able to decrypt using an old key. now verify + // that it was OK to use this old key for this packet. + if (!series.current.usedByBothEndpoints() + && series.current.writeCipher.hasEncryptedAny() + && oneRttContext.getLargestPeerAckedPN() + >= series.current.writeCipher.lowestEncryptedPktNum()) { + // RFC-9001, section 6.2: + // An endpoint that receives an acknowledgment that is + // carried in a packet protected with old keys where any + // acknowledged packet was protected with newer keys MAY + // treat that as a connection error of type + // KEY_UPDATE_ERROR. This indicates that a peer has + // received and acknowledged a packet that initiates a key + // update, but has not updated keys in response. + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("peer used incorrect key, was" + + " expected to use updated key of" + + " key phase: " + currentKeyPhase + + ", incoming key phase: " + keyPhase + + ", packet number: " + packetNumber); + } + throw new QuicTransportException("peer used incorrect" + + " key, was expected to use updated key", + this.keySpace, 0, KEY_UPDATE_ERROR); + } + return; + } + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("detected ONE_RTT key update, current key " + + "phase: " + currentKeyPhase + + ", incoming key phase: " + keyPhase + + ", packet number: " + packetNumber); + } + decryptUsingNextKeys( + series, packetNumber, packet, headerLength, output); + } + + @Override + void encryptPacket(final long packetNumber, + final IntFunction headerGenerator, + final ByteBuffer packetPayload, + final ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException, ShortBufferException { + KeySeries currentSeries = requireKeySeries(); + if (currentSeries.next == null) { + // next keys haven't yet been generated, + // generate them now + try { + currentSeries = generateNextKeys( + this.negotiatedVersion, currentSeries); + } catch (GeneralSecurityException | IOException e) { + throw new QuicTransportException("Failed to update keys", + ONE_RTT, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + maybeInitiateKeyUpdate(currentSeries, packetNumber); + // call getWriteCipher() afresh so that it can use + // the new keyseries if at all the key update was + // initiated + final QuicWriteCipher writeCipher = getWriteCipher(); + final int keyPhase = writeCipher.getKeyPhase(); + // generate the packet header passing the generator the key phase + final ByteBuffer header = headerGenerator.apply(keyPhase); + writeCipher.encryptPacket(packetNumber, header, packetPayload, output); + } + + void setOneRttContext(final QuicOneRttContext ctx) { + Objects.requireNonNull(ctx); + this.oneRttContext = ctx; + } + + private KeySeries requireKeySeries() + throws QuicKeyUnavailableException { + final KeySeries series = this.keySeries; + if (series != null) { + return series; + } + final String msg = this.keysDiscarded + ? "Keys have been discarded" + : "Keys not available"; + throw new QuicKeyUnavailableException(msg, this.keySpace); + } + + // based on certain internal criteria, this method may trigger a key + // update. + // returns true if it does trigger the key update. false otherwise. + private boolean maybeInitiateKeyUpdate(final KeySeries currentSeries, + final long packetNumber) { + final QuicWriteCipher cipher = currentSeries.current.writeCipher; + // when we notice that we have reached 80% (which is arbitrary) + // of the confidentiality limit, we trigger a key update instead + // of waiting to hit the limit + final long confidentialityLimit = cipher.confidentialityLimit(); + if (confidentialityLimit < 0) { + return false; + } + final long numEncrypted = cipher.getNumEncrypted(); + if (numEncrypted >= 0.8 * confidentialityLimit) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("about to reach confidentiality limit, " + + "attempting to initiate a 1-RTT key update," + + " packet number: " + + packetNumber + ", current key phase: " + + cipher.getKeyPhase()); + } + final boolean initiated = initiateKeyUpdate(currentSeries); + if (initiated) { + final int newKeyPhase = + this.keySeries.current.writeCipher.getKeyPhase(); + assert cipher.getKeyPhase() != newKeyPhase + : "key phase of updated key unexpectedly matches " + + "the key phase " + + cipher.getKeyPhase() + " of current keys"; + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest( + "1-RTT key update initiated, new key phase: " + + newKeyPhase); + } + } + return initiated; + } + return false; + } + + private boolean initiateKeyUpdate(final KeySeries series) { + // we only initiate a key update if this current endpoint and the + // peer have both been using this current key + if (!series.current.usedByBothEndpoints()) { + // RFC-9001, section 6.1: + // An endpoint MUST NOT initiate a subsequent key update + // unless it has received + // an acknowledgment for a packet that was sent protected + // with keys from the + // current key phase. This ensures that keys are + // available to both peers before + // another key update can be initiated. + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest( + "skipping key update initiation because peer " + + "hasn't yet sent us a packet encrypted with " + + "current key of key phase: " + + series.current.readCipher.getKeyPhase()); + } + return false; + } + // OK to initiate a key update. + // An endpoint initiates a key update by updating its packet + // protection write secret + // and using that to protect new packets. + rolloverKeys(this.negotiatedVersion, series); + return true; + } + + private static void discardKeys(final CipherPair cipherPair) { + if (cipherPair == null) { + return; + } + cipherPair.discard(true); + } + + /** + * uses "next" keys to try and decrypt the incoming packet. if that + * succeeded then it implies that the key update was indeed initiated by + * the peer and this method then rolls over the keys to start using + * these "next" keys. this method then returns true in such cases. if + * the packet decryption using the "next" key fails, then this method + * just returns back false (and doesn't roll over the keys) + */ + private void decryptUsingNextKeys( + final KeySeries currentKeySeries, + final long packetNumber, + final ByteBuffer packet, + final int headerLength, + final ByteBuffer output) + throws QuicKeyUnavailableException, AEADBadTagException, + ShortBufferException, QuicTransportException { + if (currentKeySeries.next == null) { + // this can happen if the peer initiated another + // key update before we could generate the next + // keys during our encryption flow. in such + // cases we reject the key update for the packet + // (we avoid timing attacks by not generating + // keys during decryption, our key generation + // only happens during encryption) + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("next keys unavailable," + + " won't decrypt a packet which appears to be" + + " a key update"); + } + throw new QuicKeyUnavailableException( + "next keys unavailable to handle key update", + this.keySpace); + } + // use the next keys to attempt decrypting + currentKeySeries.next.readCipher.decryptPacket(packetNumber, packet, + headerLength, output); + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest( + "decrypted using next keys for peer-initiated" + + " key update; will now switch to new key phase: " + + currentKeySeries.next.readCipher.getKeyPhase()); + } + // we have successfully decrypted the packet using the new/next + // read key. So we now update even the write key as noted in + // RFC-9001, section 6.2: + // If a packet is successfully processed using the next key and + // IV, then the peer has initiated a key update. The endpoint + // MUST update its send keys to the corresponding + // key phase in response, as described in Section 6.1. Sending + // keys MUST be updated before sending an acknowledgment for the + // packet that was received with updated keys. rollover the + // keys == old gets discarded and is replaced by + // current, current is replaced by next and next is set to null + // (a new set of next keys will be generated separately on + // a schedule) + rolloverKeys(this.negotiatedVersion, currentKeySeries); + } + + void deriveKeys(final QuicVersion negotiatedVersion, + final HandshakeContext handshakeContext, + final boolean clientMode) throws IOException { + Objects.requireNonNull(negotiatedVersion); + if (keysAvailable()) { + throw new IllegalStateException("Keys already derived for " + + this.keySpace + " key space"); + } + this.negotiatedVersion = negotiatedVersion; + + try { + SSLKeyDerivation kd = handshakeContext.handshakeKeyDerivation; + SecretKey client_application_traffic_secret_0 = kd.deriveKey( + "TlsClientAppTrafficSecret"); + SecretKey server_application_traffic_secret_0 = kd.deriveKey( + "TlsServerAppTrafficSecret"); + + deriveOneRttKeys(this.negotiatedVersion, + client_application_traffic_secret_0, + server_application_traffic_secret_0, + handshakeContext.negotiatedCipherSuite, + clientMode); + } catch (GeneralSecurityException e) { + throw new SSLException("Missing cipher algorithm", e); + } + } + + void deriveOneRttKeys(final QuicVersion version, + final SecretKey client_application_traffic_secret_0, + final SecretKey server_application_traffic_secret_0, + final CipherSuite negotiatedCipherSuite, + final boolean clientMode) throws IOException, + GeneralSecurityException { + final QuicKeys clientKeys = deriveQuicKeys(version, + negotiatedCipherSuite, + client_application_traffic_secret_0); + final QuicKeys serverKeys = deriveQuicKeys(version, + negotiatedCipherSuite, + server_application_traffic_secret_0); + final QuicReadCipher readCipher; + final QuicWriteCipher writeCipher; + // this method always derives the first key for the 1-RTT, so key + // phase is always 0 + final int keyPhase = 0; + if (clientMode) { + readCipher = QuicCipher.createReadCipher(negotiatedCipherSuite, + server_application_traffic_secret_0, serverKeys.key, + serverKeys.iv, serverKeys.hp, keyPhase); + writeCipher = + QuicCipher.createWriteCipher(negotiatedCipherSuite, + client_application_traffic_secret_0, clientKeys.key, + clientKeys.iv, clientKeys.hp, keyPhase); + } else { + readCipher = QuicCipher.createReadCipher(negotiatedCipherSuite, + client_application_traffic_secret_0, clientKeys.key, + clientKeys.iv, clientKeys.hp, keyPhase); + writeCipher = + QuicCipher.createWriteCipher(negotiatedCipherSuite, + server_application_traffic_secret_0, serverKeys.key, + serverKeys.iv, serverKeys.hp, keyPhase); + } + // generate the next set of keys beforehand to prevent any timing + // attacks + // during key update + final QuicReadCipher nPlus1ReadCipher = + generateNextReadCipher(version, readCipher); + final QuicWriteCipher nPlus1WriteCipher = + generateNextWriteCipher(version, writeCipher); + this.keySeriesLock.lock(); + try { + if (this.keySeries != null) { + // don't allow deriving the first set of 1-RTT keys more + // than once + throw new IllegalStateException("Keys already available " + + "for keyspace: " + + this.keySpace); + } + this.keySeries = new KeySeries(null, + new CipherPair(readCipher, writeCipher), + new CipherPair(nPlus1ReadCipher, nPlus1WriteCipher)); + } finally { + this.keySeriesLock.unlock(); + } + } + + private static QuicWriteCipher generateNextWriteCipher( + final QuicVersion quicVersion, final QuicWriteCipher current) + throws IOException, GeneralSecurityException { + final SSLKeyDerivation kd = + new QuicTLSKeyDerivation(current.getCipherSuite(), + current.getBaseSecret()); + final QuicTLSData tlsData = QuicKeyManager.getQuicData(quicVersion); + final SecretKey nplus1Secret = + kd.deriveKey(tlsData.getTlsKeyUpdateLabel()); + final QuicKeys quicKeys = + QuicKeyManager.deriveQuicKeys(quicVersion, + current.getCipherSuite(), + nplus1Secret, current.getHeaderProtectionKey()); + final int nextKeyPhase = current.getKeyPhase() == 0 ? 1 : 0; + // toggle the 1 bit keyphase + final QuicWriteCipher next = + QuicCipher.createWriteCipher(current.getCipherSuite(), + nplus1Secret, quicKeys.key, quicKeys.iv, + quicKeys.hp, nextKeyPhase); + return next; + } + + private static QuicReadCipher generateNextReadCipher( + final QuicVersion quicVersion, final QuicReadCipher current) + throws IOException, GeneralSecurityException { + final SSLKeyDerivation kd = + new QuicTLSKeyDerivation(current.getCipherSuite(), + current.getBaseSecret()); + final QuicTLSData tlsData = QuicKeyManager.getQuicData(quicVersion); + final SecretKey nPlus1Secret = + kd.deriveKey(tlsData.getTlsKeyUpdateLabel()); + final QuicKeys quicKeys = + QuicKeyManager.deriveQuicKeys(quicVersion, + current.getCipherSuite(), + nPlus1Secret, current.getHeaderProtectionKey()); + final int nextKeyPhase = current.getKeyPhase() == 0 ? 1 : 0; + // toggle the 1 bit keyphase + final QuicReadCipher next = + QuicCipher.createReadCipher(current.getCipherSuite(), + nPlus1Secret, quicKeys.key, + quicKeys.iv, quicKeys.hp, nextKeyPhase); + return next; + } + + private KeySeries generateNextKeys(final QuicVersion version, + final KeySeries currentSeries) + throws GeneralSecurityException, IOException { + this.keySeriesLock.lock(); + try { + // nothing to do if some other thread + // already changed the keySeries + if (this.keySeries != currentSeries) { + return this.keySeries; + } + final QuicReadCipher nPlus1ReadCipher = + generateNextReadCipher(version, + currentSeries.current.readCipher); + final QuicWriteCipher nPlus1WriteCipher = + generateNextWriteCipher(version, + currentSeries.current.writeCipher); + // only the next keys will differ in the new series + // as compared to the current series + final KeySeries newSeries = new KeySeries(currentSeries.old, + currentSeries.current, + new CipherPair(nPlus1ReadCipher, nPlus1WriteCipher)); + this.keySeries = newSeries; + return newSeries; + } finally { + this.keySeriesLock.unlock(); + } + } + + /** + * Updates the key series by "left shifting" the series of keys. + * i.e. old keys (if any) are discarded, current keys + * are moved to old keys and next keys are moved to current keys. + * Note that no new keys will be generated by this method. + * @return the key series that will be in use going forward + */ + private KeySeries rolloverKeys(final QuicVersion version, + final KeySeries currentSeries) { + this.keySeriesLock.lock(); + try { + // nothing to do if some other thread + // already changed the keySeries + if (this.keySeries != currentSeries) { + return this.keySeries; + } + assert currentSeries.next != null : "Key series missing next" + + " keys"; + // discard the old read cipher which will no longer be used + final QuicReadCipher oldReadCipher = currentSeries.old; + // once we move current key to old, we won't be using the + // write cipher of that + // moved pair + final QuicWriteCipher writeCipherToDiscard = + currentSeries.current.writeCipher; + final KeySeries newSeries = new KeySeries( + currentSeries.current.readCipher, currentSeries.next, + null); + // update the key series + this.keySeries = newSeries; + if (oldReadCipher != null) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest( + "discarding old read key of key phase: " + + oldReadCipher.getKeyPhase()); + } + oldReadCipher.discard(false); + } + if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { + SSLLogger.finest("discarding write key of key phase: " + + writeCipherToDiscard.getKeyPhase()); + } + writeCipherToDiscard.discard(false); + return newSeries; + } finally { + this.keySeriesLock.unlock(); + } + } + } + + private static final class QuicTLSKeyDerivation + implements SSLKeyDerivation { + + private enum HkdfLabel { + // RFC 9001: quic version 1 + quickey("quic key"), + quiciv("quic iv"), + quichp("quic hp"), + quicku("quic ku"), + + // RFC 9369: quic version 2 + quicv2key("quicv2 key"), + quicv2iv("quicv2 iv"), + quicv2hp("quicv2 hp"), + quicv2ku("quicv2 ku"); + + private final String label; + private final byte[] tls13LabelBytes; + + HkdfLabel(final String label) { + Objects.requireNonNull(label); + this.label = label; + this.tls13LabelBytes = + ("tls13 " + label).getBytes(StandardCharsets.UTF_8); + } + + private static HkdfLabel fromLabel(final String label) { + Objects.requireNonNull(label); + for (final HkdfLabel hkdfLabel : HkdfLabel.values()) { + if (hkdfLabel.label.equals(label)) { + return hkdfLabel; + } + } + throw new IllegalArgumentException( + "unrecognized label: " + label); + } + } + + private final CipherSuite cs; + private final SecretKey secret; + + private QuicTLSKeyDerivation(final CipherSuite cs, + final SecretKey secret) { + this.cs = Objects.requireNonNull(cs); + this.secret = Objects.requireNonNull(secret); + } + + @Override + public SecretKey deriveKey(final String algorithm) throws IOException { + final HkdfLabel hkdfLabel = HkdfLabel.fromLabel(algorithm); + try { + final KDF hkdf = KDF.getInstance(this.cs.hashAlg.hkdfAlgorithm); + final int keyLength = getKeyLength(hkdfLabel); + final byte[] hkdfInfo = + createHkdfInfo(hkdfLabel.tls13LabelBytes, keyLength); + final String keyAlgo = getKeyAlgorithm(hkdfLabel); + return hkdf.deriveKey(keyAlgo, + HKDFParameterSpec.expandOnly( + secret, hkdfInfo, keyLength)); + } catch (GeneralSecurityException gse) { + throw new SSLHandshakeException("Could not derive key", gse); + } + } + + @Override + public byte[] deriveData(final String algorithm) throws IOException { + final HkdfLabel hkdfLabel = HkdfLabel.fromLabel(algorithm); + try { + final KDF hkdf = KDF.getInstance(this.cs.hashAlg.hkdfAlgorithm); + final int keyLength = getKeyLength(hkdfLabel); + final byte[] hkdfInfo = + createHkdfInfo(hkdfLabel.tls13LabelBytes, keyLength); + return hkdf.deriveData(HKDFParameterSpec.expandOnly( + secret, hkdfInfo, keyLength)); + } catch (GeneralSecurityException gse) { + throw new SSLHandshakeException("Could not derive key", gse); + } + } + + private int getKeyLength(final HkdfLabel hkdfLabel) { + return switch (hkdfLabel) { + case quicku, quicv2ku -> { + // RFC-9001, section 6.1: + // secret_ = HKDF-Expand-Label(secret_, "quic + // ku", "", Hash.length) + yield this.cs.hashAlg.hashLength; + } + case quiciv, quicv2iv -> this.cs.bulkCipher.ivSize; + default -> this.cs.bulkCipher.keySize; + }; + } + + private String getKeyAlgorithm(final HkdfLabel hkdfLabel) { + return switch (hkdfLabel) { + case quicku, quicv2ku -> "TlsUpdateNplus1"; + case quiciv, quicv2iv -> + throw new IllegalArgumentException("IV not expected"); + default -> this.cs.bulkCipher.algorithm; + }; + } + } + + private enum QuicTLSData { + V1("38762cf7f55934b34d179ae6a4c80cadccbb7f0a", + "be0c690b9f66575a1d766b54e368c84e", + "461599d35d632bf2239825bb", + "quic key", "quic iv", "quic hp", "quic ku"), + V2("0dede3def700a6db819381be6e269dcbf9bd2ed9", + "8fb4b01b56ac48e260fbcbcead7ccc92", + "d86969bc2d7c6d9990efb04a", + "quicv2 key", "quicv2 iv", "quicv2 hp", "quicv2 ku"); + + private final byte[] initialSalt; + private final SecretKey retryKey; + private final GCMParameterSpec retryIvSpec; + private final String keyLabel; + private final String ivLabel; + private final String hpLabel; + private final String kuLabel; + + QuicTLSData(String initialSalt, String retryKey, String retryIv, + String keyLabel, String ivLabel, String hpLabel, + String kuLabel) { + this.initialSalt = HexFormat.of() + .parseHex(initialSalt); + this.retryKey = new SecretKeySpec(HexFormat.of() + .parseHex(retryKey), "AES"); + retryIvSpec = new GCMParameterSpec(128, + HexFormat.of().parseHex(retryIv)); + this.keyLabel = keyLabel; + this.ivLabel = ivLabel; + this.hpLabel = hpLabel; + this.kuLabel = kuLabel; + } + + public byte[] getInitialSalt() { + return initialSalt; + } + + public Cipher getRetryCipher(boolean incoming) throws QuicTransportException { + Cipher retryCipher = null; + try { + retryCipher = Cipher.getInstance("AES/GCM/NoPadding"); + retryCipher.init(incoming ? Cipher.DECRYPT_MODE : + Cipher.ENCRYPT_MODE, + retryKey, retryIvSpec); + } catch (Exception e) { + throw new QuicTransportException("Cipher not available", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + return retryCipher; + } + + public String getTlsKeyLabel() { + return keyLabel; + } + + public String getTlsIvLabel() { + return ivLabel; + } + + public String getTlsHpLabel() { + return hpLabel; + } + + public String getTlsKeyUpdateLabel() { + return kuLabel; + } + } +} diff --git a/src/java.base/share/classes/sun/security/ssl/QuicTLSEngineImpl.java b/src/java.base/share/classes/sun/security/ssl/QuicTLSEngineImpl.java new file mode 100644 index 00000000000..6765f554fcc --- /dev/null +++ b/src/java.base/share/classes/sun/security/ssl/QuicTLSEngineImpl.java @@ -0,0 +1,893 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package sun.security.ssl; + +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicOneRttContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicTransportParametersConsumer; +import jdk.internal.net.quic.QuicVersion; +import sun.security.ssl.QuicKeyManager.HandshakeKeyManager; +import sun.security.ssl.QuicKeyManager.InitialKeyManager; +import sun.security.ssl.QuicKeyManager.OneRttKeyManager; + +import javax.crypto.AEADBadTagException; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteBuffer; +import java.security.AlgorithmConstraints; +import java.security.GeneralSecurityException; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.IntFunction; + +import static jdk.internal.net.quic.QuicTLSEngine.HandshakeState.*; +import static jdk.internal.net.quic.QuicTLSEngine.KeySpace.*; + +/** + * One instance per QUIC connection. Configuration methods similar to + * SSLEngine. + *

+ * The implementation of this class uses the {@link QuicKeyManager} to maintain + * all state relating to keys for each encryption levels. + */ +public final class QuicTLSEngineImpl implements QuicTLSEngine, SSLTransport { + + private static final Map messageTypeMap = + Map.of(SSLHandshake.CLIENT_HELLO.id, INITIAL, + SSLHandshake.SERVER_HELLO.id, INITIAL, + SSLHandshake.ENCRYPTED_EXTENSIONS.id, HANDSHAKE, + SSLHandshake.CERTIFICATE_REQUEST.id, HANDSHAKE, + SSLHandshake.CERTIFICATE.id, HANDSHAKE, + SSLHandshake.CERTIFICATE_VERIFY.id, HANDSHAKE, + SSLHandshake.FINISHED.id, HANDSHAKE, + SSLHandshake.NEW_SESSION_TICKET.id, ONE_RTT); + static final long BASE_CRYPTO_ERROR = 256; + + private static final Set SUPPORTED_QUIC_VERSIONS = + Set.of(QuicVersion.QUIC_V1, QuicVersion.QUIC_V2); + + // VarHandles are used to access compareAndSet semantics. + private static final VarHandle HANDSHAKE_STATE_HANDLE; + static { + final MethodHandles.Lookup lookup = MethodHandles.lookup(); + try { + final Class quicTlsEngineImpl = QuicTLSEngineImpl.class; + HANDSHAKE_STATE_HANDLE = lookup.findVarHandle( + quicTlsEngineImpl, + "handshakeState", + HandshakeState.class); + } catch (Exception e) { + throw new ExceptionInInitializerError(e); + } + } + + private final TransportContext conContext; + private final String peerHost; + private final int peerPort; + private volatile HandshakeState handshakeState; + private volatile KeySpace sendKeySpace; + // next message to send or receive + private volatile ByteBuffer localQuicTransportParameters; + private volatile QuicTransportParametersConsumer + remoteQuicTransportParametersConsumer; + + // keymanagers for individual keyspaces + private final InitialKeyManager initialKeyManager = new InitialKeyManager(); + private final HandshakeKeyManager handshakeKeyManager = + new HandshakeKeyManager(); + private final OneRttKeyManager oneRttKeyManager = new OneRttKeyManager(); + + // buffer for crypto data that was received but not yet processed (i.e. + // incomplete messages) + private volatile ByteBuffer incomingCryptoBuffer; + // key space for incomingCryptoBuffer + private volatile KeySpace incomingCryptoSpace; + + private volatile QuicVersion negotiatedVersion; + + public QuicTLSEngineImpl(SSLContextImpl sslContextImpl) { + this(sslContextImpl, null, -1); + } + + public QuicTLSEngineImpl(SSLContextImpl sslContextImpl, final String peerHost, final int peerPort) { + this.peerHost = peerHost; + this.peerPort = peerPort; + this.sendKeySpace = INITIAL; + HandshakeHash handshakeHash = new HandshakeHash(); + this.conContext = new TransportContext(sslContextImpl, this, + new SSLEngineInputRecord(handshakeHash), + new QuicEngineOutputRecord(handshakeHash)); + conContext.sslConfig.enabledProtocols = List.of(ProtocolVersion.TLS13); + if (peerHost != null) { + conContext.sslConfig.serverNames = + Utilities.addToSNIServerNameList( + conContext.sslConfig.serverNames, peerHost); + } + conContext.setQuic(true); + } + + @Override + public void setUseClientMode(boolean mode) { + conContext.setUseClientMode(mode); + this.handshakeState = mode + ? HandshakeState.NEED_SEND_CRYPTO + : HandshakeState.NEED_RECV_CRYPTO; + } + + @Override + public boolean getUseClientMode() { + return conContext.sslConfig.isClientMode; + } + + @Override + public void setSSLParameters(final SSLParameters params) { + Objects.requireNonNull(params); + // section, 4.2 of RFC-9001 + // Clients MUST NOT offer TLS versions older than 1.3 + final String[] protos = params.getProtocols(); + if (protos == null || protos.length == 0) { + throw new IllegalArgumentException("No TLS protocols set"); + } + boolean tlsv13Present = false; + Set unsupported = new HashSet<>(); + for (String p : protos) { + if ("TLSv1.3".equals(p)) { + tlsv13Present = true; + } else { + unsupported.add(p); + } + } + if (!tlsv13Present) { + throw new IllegalArgumentException( + "required TLSv1.3 protocol version hasn't been set"); + } + if (!unsupported.isEmpty()) { + throw new IllegalArgumentException( + "Unsupported TLS protocol versions " + unsupported); + } + conContext.sslConfig.setSSLParameters(params); + } + + @Override + public SSLSession getSession() { + return conContext.conSession; + } + + @Override + public SSLSession getHandshakeSession() { + final HandshakeContext handshakeContext = conContext.handshakeContext; + return handshakeContext == null + ? null + : handshakeContext.handshakeSession; + } + + /** + * {@return the {@link AlgorithmConstraints} that are applicable for this engine, + * or null if none are applicable} + */ + AlgorithmConstraints getAlgorithmConstraints() { + final HandshakeContext handshakeContext = conContext.handshakeContext; + // if we are handshaking then use the handshake context + // to determine the constraints, else use the configured + // SSLParameters + return handshakeContext == null + ? getSSLParameters().getAlgorithmConstraints() + : handshakeContext.sslConfig.userSpecifiedAlgorithmConstraints; + } + + @Override + public SSLParameters getSSLParameters() { + return conContext.sslConfig.getSSLParameters(); + } + + @Override + public String getApplicationProtocol() { + // TODO: review thread safety when dealing with conContext + return conContext.applicationProtocol; + } + + @Override + public Set getSupportedQuicVersions() { + return SUPPORTED_QUIC_VERSIONS; + } + + @Override + public void setOneRttContext(final QuicOneRttContext ctx) { + this.oneRttKeyManager.setOneRttContext(ctx); + } + + private QuicVersion getNegotiatedVersion() { + final QuicVersion negotiated = this.negotiatedVersion; + if (negotiated == null) { + throw new IllegalStateException( + "Quic version hasn't been negotiated yet"); + } + return negotiated; + } + + private boolean isEnabled(final QuicVersion quicVersion) { + final Set enabled = getSupportedQuicVersions(); + if (enabled == null) { + return false; + } + return enabled.contains(quicVersion); + } + + /** + * Returns the current handshake state of the connection. Sometimes packets + * that could be decrypted can be received before the handshake has + * completed, but should not be decrypted until it is complete + * + * @return the HandshakeState + */ + @Override + public HandshakeState getHandshakeState() { + return handshakeState; + } + + /** + * Returns the current sending key space (encryption level) + * + * @return the current sending key space + */ + @Override + public KeySpace getCurrentSendKeySpace() { + return sendKeySpace; + } + + @Override + public boolean keysAvailable(KeySpace keySpace) { + return switch (keySpace) { + case INITIAL -> this.initialKeyManager.keysAvailable(); + case HANDSHAKE -> this.handshakeKeyManager.keysAvailable(); + case ONE_RTT -> this.oneRttKeyManager.keysAvailable(); + case ZERO_RTT -> false; + case RETRY -> true; + default -> throw new IllegalArgumentException( + keySpace + " not expected here"); + }; + } + + @Override + public void discardKeys(KeySpace keySpace) { + switch (keySpace) { + case INITIAL -> this.initialKeyManager.discardKeys(); + case HANDSHAKE -> this.handshakeKeyManager.discardKeys(); + case ONE_RTT -> this.oneRttKeyManager.discardKeys(); + default -> throw new IllegalArgumentException( + "key discarding not implemented for " + keySpace); + } + } + + @Override + public int getHeaderProtectionSampleSize(KeySpace keySpace) { + return switch (keySpace) { + case INITIAL, HANDSHAKE, ZERO_RTT, ONE_RTT -> 16; + default -> throw new IllegalArgumentException( + "Type '" + keySpace + "' not expected here"); + }; + } + + @Override + public ByteBuffer computeHeaderProtectionMask(KeySpace keySpace, + boolean incoming, ByteBuffer sample) + throws QuicKeyUnavailableException, QuicTransportException { + final QuicKeyManager keyManager = keyManager(keySpace); + if (incoming) { + final QuicCipher.QuicReadCipher quicCipher = + keyManager.getReadCipher(); + return quicCipher.computeHeaderProtectionMask(sample); + } else { + final QuicCipher.QuicWriteCipher quicCipher = + keyManager.getWriteCipher(); + return quicCipher.computeHeaderProtectionMask(sample); + } + } + + @Override + public int getAuthTagSize() { + // RFC-9001, section 5.3 + // QUIC can use any of the cipher suites defined in [TLS13] with the + // exception of TLS_AES_128_CCM_8_SHA256. ... + // These cipher suites have a 16-byte authentication tag and produce + // an output 16 bytes larger than their input. + return 16; + } + + @Override + public void encryptPacket(final KeySpace keySpace, final long packetNumber, + final IntFunction headerGenerator, + final ByteBuffer packetPayload, final ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException, ShortBufferException { + final QuicKeyManager keyManager = keyManager(keySpace); + keyManager.encryptPacket(packetNumber, headerGenerator, packetPayload, output); + } + + @Override + public void decryptPacket(final KeySpace keySpace, + final long packetNumber, final int keyPhase, + final ByteBuffer packet, final int headerLength, + final ByteBuffer output) + throws QuicKeyUnavailableException, AEADBadTagException, + QuicTransportException, ShortBufferException { + if (keySpace == ONE_RTT && !isTLSHandshakeComplete()) { + // RFC-9001, section 5.7 specifies that the server or the client MUST NOT + // decrypt 1-RTT packets, even if 1-RTT keys are available, before the + // TLS handshake is complete. + throw new QuicKeyUnavailableException("QUIC TLS handshake not yet complete", ONE_RTT); + } + final QuicKeyManager keyManager = keyManager(keySpace); + keyManager.decryptPacket(packetNumber, keyPhase, packet, headerLength, + output); + } + + @Override + public void signRetryPacket(final QuicVersion quicVersion, + final ByteBuffer originalConnectionId, final ByteBuffer packet, + final ByteBuffer output) + throws ShortBufferException, QuicTransportException { + if (!isEnabled(quicVersion)) { + throw new IllegalArgumentException( + "Quic version " + quicVersion + " isn't enabled"); + } + int connIdLength = originalConnectionId.remaining(); + if (connIdLength >= 256 || connIdLength < 0) { + throw new IllegalArgumentException("connection ID length"); + } + final Cipher cipher = InitialKeyManager.getRetryCipher( + quicVersion, false); + cipher.updateAAD(new byte[]{(byte) connIdLength}); + cipher.updateAAD(originalConnectionId); + cipher.updateAAD(packet); + try { + // No data to encrypt, just outputting the tag which will be + // verified later. + cipher.doFinal(ByteBuffer.allocate(0), output); + } catch (ShortBufferException e) { + throw e; + } catch (Exception e) { + throw new QuicTransportException("Failed to sign packet", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + + @Override + public void verifyRetryPacket(final QuicVersion quicVersion, + final ByteBuffer originalConnectionId, + final ByteBuffer packet) + throws AEADBadTagException, QuicTransportException { + if (!isEnabled(quicVersion)) { + throw new IllegalArgumentException( + "Quic version " + quicVersion + " isn't enabled"); + } + int connIdLength = originalConnectionId.remaining(); + if (connIdLength >= 256 || connIdLength < 0) { + throw new IllegalArgumentException("connection ID length"); + } + int originalLimit = packet.limit(); + packet.limit(originalLimit - 16); + final Cipher cipher = + InitialKeyManager.getRetryCipher(quicVersion, true); + cipher.updateAAD(new byte[]{(byte) connIdLength}); + cipher.updateAAD(originalConnectionId); + cipher.updateAAD(packet); + packet.limit(originalLimit); + try { + assert packet.remaining() == 16; + int outBufLength = cipher.getOutputSize(packet.remaining()); + // No data to decrypt, just checking the tag. + ByteBuffer outBuffer = ByteBuffer.allocate(outBufLength); + cipher.doFinal(packet, outBuffer); + assert outBuffer.position() == 0; + } catch (AEADBadTagException e) { + throw e; + } catch (Exception e) { + throw new QuicTransportException("Failed to verify packet", + null, 0, BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + } + + private QuicKeyManager keyManager(final KeySpace keySpace) { + return switch (keySpace) { + case INITIAL -> this.initialKeyManager; + case HANDSHAKE -> this.handshakeKeyManager; + case ONE_RTT -> this.oneRttKeyManager; + default -> throw new IllegalArgumentException( + "No key manager available for key space: " + keySpace); + }; + } + + @Override + public ByteBuffer getHandshakeBytes(KeySpace keySpace) throws IOException { + if (keySpace != sendKeySpace) { + throw new IllegalStateException("Unexpected key space: " + + keySpace + " (expected " + sendKeySpace + ")"); + } + if (handshakeState == HandshakeState.NEED_SEND_CRYPTO || + !conContext.outputRecord.isEmpty()) { // session ticket + byte[] bytes = produceNextHandshakeMessage(); + return ByteBuffer.wrap(bytes); + } else { + return null; + } + } + + private byte[] produceNextHandshakeMessage() throws IOException { + if (!conContext.isNegotiated && !conContext.isBroken && + !conContext.isInboundClosed() && + !conContext.isOutboundClosed()) { + conContext.kickstart(); + } + byte[] message = conContext.outputRecord.getHandshakeMessage(); + if (handshakeState == NEED_SEND_CRYPTO) { + if (conContext.outputRecord.isEmpty()) { + if (conContext.isNegotiated) { + // client, done + handshakeState = NEED_RECV_HANDSHAKE_DONE; + sendKeySpace = ONE_RTT; + } else { + handshakeState = NEED_RECV_CRYPTO; + } + } else if (sendKeySpace == INITIAL && !getUseClientMode()) { + // Server sends handshake messages immediately after + // the initial server hello. Need to check the next key space. + sendKeySpace = conContext.outputRecord.getHandshakeMessageKeySpace(); + } + } else { + assert conContext.isNegotiated; + } + return message; + } + + @Override + public void consumeHandshakeBytes(KeySpace keySpace, ByteBuffer payload) + throws QuicTransportException { + if (!payload.hasRemaining()) { + throw new IllegalArgumentException("Empty crypto buffer"); + } + if (keySpace == KeySpace.ZERO_RTT) { + throw new IllegalArgumentException("Crypto in zero-rtt"); + } + if (incomingCryptoSpace != null && incomingCryptoSpace != keySpace) { + throw new QuicTransportException("Unexpected message", null, 0, + BASE_CRYPTO_ERROR + Alert.UNEXPECTED_MESSAGE.id, + new SSLHandshakeException( + "Unfinished message in " + incomingCryptoSpace)); + } + try { + if (!conContext.isNegotiated && !conContext.isBroken && + !conContext.isInboundClosed() && + !conContext.isOutboundClosed()) { + conContext.kickstart(); + } + } catch (IOException e) { + throw new QuicTransportException(e.toString(), null, 0, + BASE_CRYPTO_ERROR + Alert.INTERNAL_ERROR.id, e); + } + // previously unconsumed bytes in incomingCryptoBuffer, new bytes in + // payload. if incomingCryptoBuffer is not null, it's either 4 bytes + // or large enough to hold the entire message. + while (payload.hasRemaining()) { + if (keySpace != KeySpace.ONE_RTT && + handshakeState != HandshakeState.NEED_RECV_CRYPTO) { + // in one-rtt we may receive session tickets at any time; + // during handshake we're either sending or receiving + throw new QuicTransportException("Unexpected message", null, 0, + BASE_CRYPTO_ERROR + Alert.UNEXPECTED_MESSAGE.id, + new SSLHandshakeException( + "Not expecting a handshake message, state: " + + handshakeState)); + } + if (incomingCryptoBuffer != null) { + // message type validated already; pump more bytes + if (payload.remaining() <= incomingCryptoBuffer.remaining()) { + incomingCryptoBuffer.put(payload); + } else { + // more than one message in buffer, or we don't have a + // header yet + int remaining = incomingCryptoBuffer.remaining(); + incomingCryptoBuffer.put(incomingCryptoBuffer.position(), + payload, payload.position(), remaining); + incomingCryptoBuffer.position(incomingCryptoBuffer.limit()); + payload.position(payload.position() + remaining); + if (incomingCryptoBuffer.capacity() == 4) { + // small buffer for header only; retrieve size and + // expand if necessary + int messageSize = + ((incomingCryptoBuffer.get(1) & 0xFF) << 16) | + ((incomingCryptoBuffer.get(2) & 0xFF) << 8) | + (incomingCryptoBuffer.get(3) & 0xFF); + if (messageSize != 0) { + if (messageSize > SSLConfiguration.maxHandshakeMessageSize) { + throw new QuicTransportException( + "The size of the handshake message (" + + messageSize + + ") exceeds the maximum allowed size (" + + SSLConfiguration.maxHandshakeMessageSize + + ")", + null, 0, + QuicTransportErrors.CRYPTO_BUFFER_EXCEEDED); + } + ByteBuffer newBuffer = + ByteBuffer.allocate(messageSize + 4); + incomingCryptoBuffer.flip(); + newBuffer.put(incomingCryptoBuffer); + incomingCryptoBuffer = newBuffer; + assert incomingCryptoBuffer.position() == 4 : + incomingCryptoBuffer.position(); + // start over with larger buffer + continue; + } + // message size was zero... can it really happen? + } + } + } else { + // incoming crypto buffer is null. Validate message type, + // check if size is available + byte messageType = payload.get(payload.position()); + if (SSLLogger.isOn) { + SSLLogger.fine("Received message of type 0x" + + Integer.toHexString(messageType & 0xFF)); + } + KeySpace expected = messageTypeMap.get(messageType); + if (expected != keySpace) { + throw new QuicTransportException("Unexpected message", + null, 0, + BASE_CRYPTO_ERROR + Alert.UNEXPECTED_MESSAGE.id, + new SSLHandshakeException("Message " + messageType + + " received in " + keySpace + + " but should be " + expected)); + } + if (payload.remaining() < 4) { + // partial message, length missing. Store in + // incomingCryptoBuffer + incomingCryptoBuffer = ByteBuffer.allocate(4); + incomingCryptoBuffer.put(payload); + incomingCryptoSpace = keySpace; + return; + } + int payloadPos = payload.position(); + int messageSize = ((payload.get(payloadPos + 1) & 0xFF) << 16) + | ((payload.get(payloadPos + 2) & 0xFF) << 8) + | (payload.get(payloadPos + 3) & 0xFF); + if (payload.remaining() < messageSize + 4) { + // partial message, length known. Store in + // incomingCryptoBuffer + if (messageSize > SSLConfiguration.maxHandshakeMessageSize) { + throw new QuicTransportException( + "The size of the handshake message (" + + messageSize + + ") exceeds the maximum allowed size (" + + SSLConfiguration.maxHandshakeMessageSize + + ")", + null, 0, + QuicTransportErrors.CRYPTO_BUFFER_EXCEEDED); + } + incomingCryptoBuffer = ByteBuffer.allocate(messageSize + 4); + incomingCryptoBuffer.put(payload); + incomingCryptoSpace = keySpace; + return; + } + incomingCryptoSpace = keySpace; + incomingCryptoBuffer = payload.slice(payloadPos, + messageSize + 4); + // set position at end to indicate that the buffer is ready + // for processing + incomingCryptoBuffer.position(messageSize + 4); + assert !incomingCryptoBuffer.hasRemaining() : + incomingCryptoBuffer.remaining(); + payload.position(payloadPos + messageSize + 4); + } + if (!incomingCryptoBuffer.hasRemaining()) { + incomingCryptoBuffer.flip(); + handleHandshakeMessage(keySpace, incomingCryptoBuffer); + incomingCryptoBuffer = null; + incomingCryptoSpace = null; + } else { + assert !payload.hasRemaining() : payload.remaining(); + return; + } + } + } + + private void handleHandshakeMessage(KeySpace keySpace, ByteBuffer message) + throws QuicTransportException { + // message param contains one whole TLS message + boolean useClientMode = getUseClientMode(); + byte messageType = message.get(); + int messageSize = ((message.get() & 0xFF) << 16) + | ((message.get() & 0xFF) << 8) + | (message.get() & 0xFF); + + assert message.remaining() == messageSize : + message.remaining() - messageSize; + try { + if (conContext.inputRecord.handshakeHash.isHashable(messageType)) { + ByteBuffer temp = message.duplicate(); + temp.position(0); + conContext.inputRecord.handshakeHash.receive(temp); + } + if (conContext.handshakeContext == null) { + if (!conContext.isNegotiated) { + throw new QuicTransportException( + "Cannot process crypto message, broken: " + + conContext.isBroken, + null, 0, QuicTransportErrors.INTERNAL_ERROR); + } + conContext.handshakeContext = + new PostHandshakeContext(conContext); + } + conContext.handshakeContext.dispatch(messageType, message.slice()); + } catch (SSLHandshakeException e) { + if (e.getCause() instanceof QuicTransportException qte) { + // rethrow quic transport parameters validation exception + throw qte; + } + Alert alert = ((QuicEngineOutputRecord) + conContext.outputRecord).getAlert(); + throw new QuicTransportException(alert.description, keySpace, 0, + BASE_CRYPTO_ERROR + alert.id, e); + } catch (IOException e) { + throw new RuntimeException(e); + } + if (handshakeState == NEED_RECV_CRYPTO) { + if (conContext.outputRecord.isEmpty()) { + if (conContext.isNegotiated) { + // dead code? done, server side, no session ticket + handshakeState = NEED_SEND_HANDSHAKE_DONE; + sendKeySpace = ONE_RTT; + } else { + // expect more messages + // client side: if we're still in INITIAL, switch + // to HANDSHAKE + if (sendKeySpace == INITIAL) { + sendKeySpace = HANDSHAKE; + } + } + } else { + // our turn to send + if (conContext.isNegotiated && !useClientMode) { + // done, server side, wants to send session ticket + handshakeState = NEED_SEND_HANDSHAKE_DONE; + sendKeySpace = ONE_RTT; + } else { + // more messages needed to finish handshake + handshakeState = HandshakeState.NEED_SEND_CRYPTO; + } + } + } else { + assert conContext.isNegotiated; + } + } + + @Override + public void deriveInitialKeys(final QuicVersion quicVersion, + final ByteBuffer connectionId) throws IOException { + if (!isEnabled(quicVersion)) { + throw new IllegalArgumentException("Quic version " + quicVersion + + " isn't enabled"); + } + final byte[] connectionIdBytes = new byte[connectionId.remaining()]; + connectionId.get(connectionIdBytes); + this.initialKeyManager.deriveKeys(quicVersion, connectionIdBytes, + getUseClientMode()); + } + + @Override + public void versionNegotiated(final QuicVersion quicVersion) { + Objects.requireNonNull(quicVersion); + if (!isEnabled(quicVersion)) { + throw new IllegalArgumentException("Quic version " + quicVersion + + " is not enabled"); + } + synchronized (this) { + final QuicVersion prevNegotiated = this.negotiatedVersion; + if (prevNegotiated != null) { + throw new IllegalStateException("A Quic version has already " + + "been negotiated previously"); + } + this.negotiatedVersion = quicVersion; + } + } + + public void deriveHandshakeKeys() throws IOException { + final QuicVersion quicVersion = getNegotiatedVersion(); + this.handshakeKeyManager.deriveKeys(quicVersion, + this.conContext.handshakeContext, + getUseClientMode()); + } + + public void deriveOneRTTKeys() throws IOException { + final QuicVersion quicVersion = getNegotiatedVersion(); + this.oneRttKeyManager.deriveKeys(quicVersion, + this.conContext.handshakeContext, + getUseClientMode()); + } + + // for testing (PacketEncryptionTest) + void deriveOneRTTKeys(final QuicVersion version, + final SecretKey client_application_traffic_secret_0, + final SecretKey server_application_traffic_secret_0, + final CipherSuite negotiatedCipherSuite, + final boolean clientMode) throws IOException, + GeneralSecurityException { + this.oneRttKeyManager.deriveOneRttKeys(version, + client_application_traffic_secret_0, + server_application_traffic_secret_0, + negotiatedCipherSuite, clientMode); + } + + @Override + public Runnable getDelegatedTask() { + // TODO: actually delegate tasks + return null; + } + + @Override + public String getPeerHost() { + return peerHost; + } + + @Override + public int getPeerPort() { + return peerPort; + } + + @Override + public boolean useDelegatedTask() { + return true; + } + + public byte[] getLocalQuicTransportParameters() { + ByteBuffer ltp = localQuicTransportParameters; + if (ltp == null) { + return null; + } + byte[] result = new byte[ltp.remaining()]; + ltp.get(0, result); + return result; + } + + @Override + public void setLocalQuicTransportParameters(ByteBuffer params) { + this.localQuicTransportParameters = params; + } + + @Override + public void restartHandshake() throws IOException { + if (negotiatedVersion != null) { + throw new IllegalStateException("Version already negotiated"); + } + if (sendKeySpace != INITIAL || handshakeState != NEED_RECV_CRYPTO) { + throw new IllegalStateException("Unexpected handshake state"); + } + HandshakeContext context = conContext.handshakeContext; + ClientHandshakeContext chc = (ClientHandshakeContext)context; + + // Refresh handshake hash + chc.handshakeHash.finish(); // reset the handshake hash + + // Update the initial ClientHello handshake message. + chc.initialClientHelloMsg.extensions.reproduce(chc, + new SSLExtension[] { + SSLExtension.CH_QUIC_TRANSPORT_PARAMETERS, + SSLExtension.CH_PRE_SHARED_KEY + }); + + // produce handshake message + chc.initialClientHelloMsg.write(chc.handshakeOutput); + handshakeState = NEED_SEND_CRYPTO; + } + + @Override + public void setRemoteQuicTransportParametersConsumer( + QuicTransportParametersConsumer consumer) { + this.remoteQuicTransportParametersConsumer = consumer; + } + + void processRemoteQuicTransportParameters(ByteBuffer buffer) + throws QuicTransportException{ + remoteQuicTransportParametersConsumer.accept(buffer); + } + + @Override + public boolean tryMarkHandshakeDone() { + if (getUseClientMode()) { + // not expected to be called on client + throw new IllegalStateException( + "Not expected to be called in client mode"); + } + final boolean confirmed = HANDSHAKE_STATE_HANDLE.compareAndSet(this, + NEED_SEND_HANDSHAKE_DONE, HANDSHAKE_CONFIRMED); + if (confirmed) { + if (SSLLogger.isOn) { + SSLLogger.fine("QuicTLSEngine (server) marked handshake " + + "state as HANDSHAKE_CONFIRMED"); + } + } + return confirmed; + } + + @Override + public boolean tryReceiveHandshakeDone() { + final boolean isClient = getUseClientMode(); + if (!isClient) { + throw new IllegalStateException( + "Not expected to receive HANDSHAKE_DONE in server mode"); + } + final boolean confirmed = HANDSHAKE_STATE_HANDLE.compareAndSet(this, + NEED_RECV_HANDSHAKE_DONE, HANDSHAKE_CONFIRMED); + if (confirmed) { + if (SSLLogger.isOn) { + SSLLogger.fine( + "QuicTLSEngine (client) received HANDSHAKE_DONE," + + " marking state as HANDSHAKE_DONE"); + } + } + return confirmed; + } + + @Override + public boolean isTLSHandshakeComplete() { + final boolean isClient = getUseClientMode(); + final HandshakeState hsState = this.handshakeState; + if (isClient) { + // the client has received TLS Finished message from server and + // has sent its own TLS Finished message and is waiting for the server + // to send QUIC HANDSHAKE_DONE frame. + // OR + // the client has received TLS Finished message from server and + // has sent its own TLS Finished message and has even received the + // QUIC HANDSHAKE_DONE frame. + // Either of these implies the TLS handshake is complete for the client + return hsState == NEED_RECV_HANDSHAKE_DONE || hsState == HANDSHAKE_CONFIRMED; + } + // on the server side the TLS handshake is complete only when the server has + // sent a TLS Finished message and received the client's Finished message. + return hsState == HANDSHAKE_CONFIRMED; + } + + /** + * {@return the key phase being used when decrypting incoming 1-RTT + * packets} + */ + // this is only used in tests + public int getOneRttKeyPhase() throws QuicKeyUnavailableException { + return this.oneRttKeyManager.getReadCipher().getKeyPhase(); + } +} diff --git a/src/java.base/share/classes/sun/security/ssl/QuicTransportParametersExtension.java b/src/java.base/share/classes/sun/security/ssl/QuicTransportParametersExtension.java new file mode 100644 index 00000000000..83e977ee446 --- /dev/null +++ b/src/java.base/share/classes/sun/security/ssl/QuicTransportParametersExtension.java @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package sun.security.ssl; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import jdk.internal.net.quic.QuicTransportException; +import sun.security.ssl.SSLExtension.ExtensionConsumer; +import sun.security.ssl.SSLHandshake.HandshakeMessage; + +/** + * Pack of the "quic_transport_parameters" extensions [RFC 9001]. + */ +final class QuicTransportParametersExtension { + + static final HandshakeProducer chNetworkProducer = + new T13CHQuicParametersProducer(); + static final ExtensionConsumer chOnLoadConsumer = + new T13CHQuicParametersConsumer(); + static final HandshakeAbsence chOnLoadAbsence = + new T13CHQuicParametersAbsence(); + static final HandshakeProducer eeNetworkProducer = + new T13EEQuicParametersProducer(); + static final ExtensionConsumer eeOnLoadConsumer = + new T13EEQuicParametersConsumer(); + static final HandshakeAbsence eeOnLoadAbsence = + new T13EEQuicParametersAbsence(); + + private static final class T13CHQuicParametersProducer + implements HandshakeProducer { + // Prevent instantiation of this class. + private T13CHQuicParametersProducer() { + } + + @Override + public byte[] produce(ConnectionContext context, + HandshakeMessage message) throws IOException { + + ClientHandshakeContext chc = (ClientHandshakeContext) context; + if (!chc.sslConfig.isQuic) { + return null; + } + QuicTLSEngineImpl quicTLSEngine = + (QuicTLSEngineImpl) chc.conContext.transport; + + return quicTLSEngine.getLocalQuicTransportParameters(); + } + + } + + private static final class T13CHQuicParametersConsumer + implements ExtensionConsumer { + // Prevent instantiation of this class. + private T13CHQuicParametersConsumer() { + } + + @Override + public void consume(ConnectionContext context, + HandshakeMessage message, ByteBuffer buffer) + throws IOException { + ServerHandshakeContext shc = (ServerHandshakeContext) context; + if (!shc.sslConfig.isQuic) { + throw shc.conContext.fatal(Alert.UNSUPPORTED_EXTENSION, + "Client sent the quic_transport_parameters " + + "extension in a non-QUIC context"); + } + QuicTLSEngineImpl quicTLSEngine = + (QuicTLSEngineImpl) shc.conContext.transport; + try { + quicTLSEngine.processRemoteQuicTransportParameters(buffer); + } catch (QuicTransportException e) { + throw shc.conContext.fatal(Alert.DECODE_ERROR, e); + } + + } + } + + private static final class T13CHQuicParametersAbsence + implements HandshakeAbsence { + // Prevent instantiation of this class. + private T13CHQuicParametersAbsence() { + } + + @Override + public void absent(ConnectionContext context, + HandshakeMessage message) throws IOException { + // The producing happens in server side only. + ServerHandshakeContext shc = (ServerHandshakeContext)context; + + if (shc.sslConfig.isQuic) { + // RFC 9001: endpoints MUST send quic_transport_parameters + throw shc.conContext.fatal( + Alert.MISSING_EXTENSION, + "Client did not send QUIC transport parameters"); + } + } + } + + private static final class T13EEQuicParametersProducer + implements HandshakeProducer { + // Prevent instantiation of this class. + private T13EEQuicParametersProducer() { + } + + @Override + public byte[] produce(ConnectionContext context, + HandshakeMessage message) { + + ServerHandshakeContext shc = (ServerHandshakeContext) context; + if (!shc.sslConfig.isQuic) { + return null; + } + QuicTLSEngineImpl quicTLSEngine = + (QuicTLSEngineImpl) shc.conContext.transport; + + return quicTLSEngine.getLocalQuicTransportParameters(); + } + } + + private static final class T13EEQuicParametersConsumer + implements ExtensionConsumer { + // Prevent instantiation of this class. + private T13EEQuicParametersConsumer() { + } + + @Override + public void consume(ConnectionContext context, + HandshakeMessage message, ByteBuffer buffer) + throws IOException { + ClientHandshakeContext chc = (ClientHandshakeContext) context; + if (!chc.sslConfig.isQuic) { + throw chc.conContext.fatal(Alert.UNSUPPORTED_EXTENSION, + "Server sent the quic_transport_parameters " + + "extension in a non-QUIC context"); + } + QuicTLSEngineImpl quicTLSEngine = + (QuicTLSEngineImpl) chc.conContext.transport; + try { + quicTLSEngine.processRemoteQuicTransportParameters(buffer); + } catch (QuicTransportException e) { + throw chc.conContext.fatal(Alert.DECODE_ERROR, e); + } + } + } + + private static final class T13EEQuicParametersAbsence + implements HandshakeAbsence { + // Prevent instantiation of this class. + private T13EEQuicParametersAbsence() { + } + + @Override + public void absent(ConnectionContext context, + HandshakeMessage message) throws IOException { + ClientHandshakeContext chc = (ClientHandshakeContext) context; + + if (chc.sslConfig.isQuic) { + // RFC 9001: endpoints MUST send quic_transport_parameters + throw chc.conContext.fatal( + Alert.MISSING_EXTENSION, + "Server did not send QUIC transport parameters"); + } + } + } +} diff --git a/src/java.base/share/classes/sun/security/ssl/SSLAlgorithmConstraints.java b/src/java.base/share/classes/sun/security/ssl/SSLAlgorithmConstraints.java index 95cfc6082be..1d5a4c4e73d 100644 --- a/src/java.base/share/classes/sun/security/ssl/SSLAlgorithmConstraints.java +++ b/src/java.base/share/classes/sun/security/ssl/SSLAlgorithmConstraints.java @@ -37,6 +37,8 @@ import java.util.List; import java.util.Set; import java.util.TreeSet; import javax.net.ssl.*; + +import jdk.internal.net.quic.QuicTLSEngine; import sun.security.util.DisabledAlgorithmConstraints; import static sun.security.util.DisabledAlgorithmConstraints.*; @@ -162,6 +164,33 @@ final class SSLAlgorithmConstraints implements AlgorithmConstraints { withDefaultCertPathConstraints); } + /** + * Returns an {@link AlgorithmConstraints} instance that uses the + * constraints configured for the given {@code engine} in addition + * to the platform configured constraints. + *

+ * If the given {@code allowedAlgorithms} is non-null then the returned + * {@code AlgorithmConstraints} will only permit those allowed algorithms. + * + * @param engine QuicTLSEngine used to determine the constraints + * @param mode SIGNATURE_CONSTRAINTS_MODE + * @param withDefaultCertPathConstraints whether or not to apply the default certpath + * algorithm constraints too + * @return a AlgorithmConstraints instance + */ + static AlgorithmConstraints forQUIC(QuicTLSEngine engine, + SIGNATURE_CONSTRAINTS_MODE mode, + boolean withDefaultCertPathConstraints) { + if (engine == null) { + return wrap(null, withDefaultCertPathConstraints); + } + + return new SSLAlgorithmConstraints( + nullIfDefault(getUserSpecifiedConstraints(engine)), + new SupportedSignatureAlgorithmConstraints(engine.getHandshakeSession(), mode), + withDefaultCertPathConstraints); + } + private static AlgorithmConstraints nullIfDefault( AlgorithmConstraints constraints) { return constraints == DEFAULT ? null : constraints; @@ -207,6 +236,17 @@ final class SSLAlgorithmConstraints implements AlgorithmConstraints { return null; } + private static AlgorithmConstraints getUserSpecifiedConstraints( + QuicTLSEngine quicEngine) { + if (quicEngine != null) { + if (quicEngine instanceof QuicTLSEngineImpl engineImpl) { + return engineImpl.getAlgorithmConstraints(); + } + return quicEngine.getSSLParameters().getAlgorithmConstraints(); + } + return null; + } + @Override public boolean permits(Set primitives, String algorithm, AlgorithmParameters parameters) { diff --git a/src/java.base/share/classes/sun/security/ssl/SSLConfiguration.java b/src/java.base/share/classes/sun/security/ssl/SSLConfiguration.java index bb032e019d3..aacac465027 100644 --- a/src/java.base/share/classes/sun/security/ssl/SSLConfiguration.java +++ b/src/java.base/share/classes/sun/security/ssl/SSLConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -78,6 +78,7 @@ final class SSLConfiguration implements Cloneable { boolean noSniExtension; boolean noSniMatcher; + boolean isQuic; // To switch off the extended_master_secret extension. static final boolean useExtendedMasterSecret; @@ -91,7 +92,7 @@ final class SSLConfiguration implements Cloneable { Utilities.getBooleanProperty("jdk.tls.allowLegacyMasterSecret", true); // Use TLS1.3 middlebox compatibility mode. - static final boolean useCompatibilityMode = Utilities.getBooleanProperty( + private static final boolean useCompatibilityMode = Utilities.getBooleanProperty( "jdk.tls.client.useCompatibilityMode", true); // Respond a close_notify alert if receiving close_notify alert. @@ -524,6 +525,14 @@ final class SSLConfiguration implements Cloneable { } } + public boolean isUseCompatibilityMode() { + return useCompatibilityMode && !isQuic; + } + + public void setQuic(boolean quic) { + isQuic = quic; + } + @Override @SuppressWarnings({"unchecked", "CloneDeclaresCloneNotSupported"}) public Object clone() { @@ -567,7 +576,10 @@ final class SSLConfiguration implements Cloneable { */ private static String[] getCustomizedSignatureScheme(String propertyName) { String property = System.getProperty(propertyName); - if (SSLLogger.isOn && SSLLogger.isOn("ssl,sslctx")) { + // this method is called from class initializer; logging here + // will occasionally pin threads and deadlock if called from a virtual thread + if (SSLLogger.isOn && SSLLogger.isOn("ssl,sslctx") + && !Thread.currentThread().isVirtual()) { SSLLogger.fine( "System property " + propertyName + " is set to '" + property + "'"); @@ -595,7 +607,8 @@ final class SSLConfiguration implements Cloneable { if (scheme != null && scheme.isAvailable) { signatureSchemes.add(schemeName); } else { - if (SSLLogger.isOn && SSLLogger.isOn("ssl,sslctx")) { + if (SSLLogger.isOn && SSLLogger.isOn("ssl,sslctx") + && !Thread.currentThread().isVirtual()) { SSLLogger.fine( "The current installed providers do not " + "support signature scheme: " + schemeName); diff --git a/src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java b/src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java index a0cb28201e9..85dde5b0dbb 100644 --- a/src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java +++ b/src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java @@ -481,6 +481,10 @@ public abstract class SSLContextImpl extends SSLContextSpi { return availableProtocols; } + public boolean isUsableWithQuic() { + return trustManager instanceof X509TrustManagerImpl; + } + /* * The SSLContext implementation for SSL/(D)TLS algorithm * diff --git a/src/java.base/share/classes/sun/security/ssl/SSLExtension.java b/src/java.base/share/classes/sun/security/ssl/SSLExtension.java index c7175ea7fdc..082914b4b4b 100644 --- a/src/java.base/share/classes/sun/security/ssl/SSLExtension.java +++ b/src/java.base/share/classes/sun/security/ssl/SSLExtension.java @@ -458,6 +458,28 @@ enum SSLExtension implements SSLStringizer { null, null, null, null, KeyShareExtension.hrrStringizer), + // Extension defined in RFC 9001 + CH_QUIC_TRANSPORT_PARAMETERS (0x0039, "quic_transport_parameters", + SSLHandshake.CLIENT_HELLO, + ProtocolVersion.PROTOCOLS_OF_13, + QuicTransportParametersExtension.chNetworkProducer, + QuicTransportParametersExtension.chOnLoadConsumer, + QuicTransportParametersExtension.chOnLoadAbsence, + null, + null, + // TODO properly stringize, rather than hex output. + null), + EE_QUIC_TRANSPORT_PARAMETERS (0x0039, "quic_transport_parameters", + SSLHandshake.ENCRYPTED_EXTENSIONS, + ProtocolVersion.PROTOCOLS_OF_13, + QuicTransportParametersExtension.eeNetworkProducer, + QuicTransportParametersExtension.eeOnLoadConsumer, + QuicTransportParametersExtension.eeOnLoadAbsence, + null, + null, + // TODO properly stringize, rather than hex output + null), + // Extensions defined in RFC 5746 (TLS Renegotiation Indication Extension) CH_RENEGOTIATION_INFO (0xff01, "renegotiation_info", SSLHandshake.CLIENT_HELLO, @@ -820,7 +842,10 @@ enum SSLExtension implements SSLStringizer { private static Collection getDisabledExtensions( String propertyName) { String property = System.getProperty(propertyName); - if (SSLLogger.isOn && SSLLogger.isOn("ssl,sslctx")) { + // this method is called from class initializer; logging here + // will occasionally pin threads and deadlock if called from a virtual thread + if (SSLLogger.isOn && SSLLogger.isOn("ssl,sslctx") + && !Thread.currentThread().isVirtual()) { SSLLogger.fine( "System property " + propertyName + " is set to '" + property + "'"); diff --git a/src/java.base/share/classes/sun/security/ssl/ServerHello.java b/src/java.base/share/classes/sun/security/ssl/ServerHello.java index d092d6c07de..1d2faa5351f 100644 --- a/src/java.base/share/classes/sun/security/ssl/ServerHello.java +++ b/src/java.base/share/classes/sun/security/ssl/ServerHello.java @@ -235,7 +235,8 @@ final class ServerHello { serverVersion.name, Utilities.toHexString(serverRandom.randomBytes), sessionId.toString(), - cipherSuite.name + "(" + Utilities.byte16HexString(cipherSuite.id) + ")", + cipherSuite.name + + "(" + Utilities.byte16HexString(cipherSuite.id) + ")", HexFormat.of().toHexDigits(compressionMethod), Utilities.indent(extensions.toString(), " ") }; @@ -534,8 +535,9 @@ final class ServerHello { // consider the handshake extension impact SSLExtension[] enabledExtensions = - shc.sslConfig.getEnabledExtensions( - SSLHandshake.CLIENT_HELLO, shc.negotiatedProtocol); + shc.sslConfig.getEnabledExtensions( + SSLHandshake.CLIENT_HELLO, + shc.negotiatedProtocol); clientHello.extensions.consumeOnTrade(shc, enabledExtensions); shc.negotiatedProtocol = @@ -670,6 +672,17 @@ final class ServerHello { // Update the context for master key derivation. shc.handshakeKeyDerivation = kd; + if (shc.sslConfig.isQuic) { + QuicTLSEngineImpl engine = + (QuicTLSEngineImpl) shc.conContext.transport; + try { + engine.deriveHandshakeKeys(); + } catch (IOException e) { + // unlikely + throw shc.conContext.fatal(Alert.HANDSHAKE_FAILURE, + "Failed to derive keys", e); + } + } // Check if the server supports stateless resumption if (sessionCache.statelessEnabled()) { shc.statelessResumption = true; @@ -784,9 +797,9 @@ final class ServerHello { // first handshake message. This may either be after // a ServerHello or a HelloRetryRequest. // (RFC 8446, Appendix D.4) - shc.conContext.outputRecord.changeWriteCiphers( - SSLWriteCipher.nullTlsWriteCipher(), - (clientHello.sessionId.length() != 0)); + if (clientHello.sessionId.length() != 0) { + shc.conContext.outputRecord.encodeChangeCipherSpec(); + } // Stateless, shall we clean up the handshake context as well? shc.handshakeHash.finish(); // forgot about the handshake hash @@ -1366,10 +1379,21 @@ final class ServerHello { // Should use resumption_master_secret for TLS 1.3. // chc.handshakeSession.setMasterSecret(masterSecret); - // Update the context for master key derivation. chc.handshakeKeyDerivation = secretKD; + if (chc.sslConfig.isQuic) { + QuicTLSEngineImpl engine = + (QuicTLSEngineImpl) chc.conContext.transport; + try { + engine.deriveHandshakeKeys(); + } catch (IOException e) { + // unlikely + throw chc.conContext.fatal(Alert.HANDSHAKE_FAILURE, + "Failed to derive keys", e); + } + } + // update the consumers and producers // // The server sends a dummy change_cipher_spec record immediately diff --git a/src/java.base/share/classes/sun/security/ssl/SunX509KeyManagerImpl.java b/src/java.base/share/classes/sun/security/ssl/SunX509KeyManagerImpl.java index 2441ad91fde..6bf138f4e45 100644 --- a/src/java.base/share/classes/sun/security/ssl/SunX509KeyManagerImpl.java +++ b/src/java.base/share/classes/sun/security/ssl/SunX509KeyManagerImpl.java @@ -195,6 +195,13 @@ final class SunX509KeyManagerImpl extends X509KeyManagerCertChecking { getAlgorithmConstraints(engine), null, null); } + @Override + String chooseQuicClientAlias(String[] keyTypes, Principal[] issuers, + QuicTLSEngineImpl quicTLSEngine) { + return chooseAlias(getKeyTypes(keyTypes), issuers, CheckType.CLIENT, + getAlgorithmConstraints(quicTLSEngine), null, null); + } + /* * Choose an alias to authenticate the server side of a secure * socket given the public key type and the list of @@ -222,6 +229,16 @@ final class SunX509KeyManagerImpl extends X509KeyManagerCertChecking { X509TrustManagerImpl.getRequestedServerNames(engine), "HTTPS"); } + @Override + String chooseQuicServerAlias(String keyType, + X500Principal[] issuers, + QuicTLSEngineImpl quicTLSEngine) { + return chooseAlias(getKeyTypes(keyType), issuers, CheckType.SERVER, + getAlgorithmConstraints(quicTLSEngine), + X509TrustManagerImpl.getRequestedServerNames(quicTLSEngine), + "HTTPS"); + } + /* * Get the matching aliases for authenticating the client side of a secure * socket given the public key type and the list of diff --git a/src/java.base/share/classes/sun/security/ssl/TransportContext.java b/src/java.base/share/classes/sun/security/ssl/TransportContext.java index 717c81723ff..49fd664e9ed 100644 --- a/src/java.base/share/classes/sun/security/ssl/TransportContext.java +++ b/src/java.base/share/classes/sun/security/ssl/TransportContext.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -489,6 +489,10 @@ final class TransportContext implements ConnectionContext { isUnsureMode = false; } + public void setQuic(boolean quic) { + sslConfig.setQuic(quic); + } + // The OutputRecord is closed and not buffered output record. boolean isOutboundDone() { return outputRecord.isClosed() && outputRecord.isEmpty(); diff --git a/src/java.base/share/classes/sun/security/ssl/X509Authentication.java b/src/java.base/share/classes/sun/security/ssl/X509Authentication.java index 4e91df2806e..5abc2cb1bf4 100644 --- a/src/java.base/share/classes/sun/security/ssl/X509Authentication.java +++ b/src/java.base/share/classes/sun/security/ssl/X509Authentication.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -218,6 +218,28 @@ enum X509Authentication implements SSLAuthentication { chc.peerSupportedAuthorities == null ? null : chc.peerSupportedAuthorities.clone(), engine); + } else if (chc.conContext.transport instanceof QuicTLSEngineImpl quicEngineImpl) { + // TODO add a method on javax.net.ssl.X509ExtendedKeyManager that + // takes QuicTLSEngine. + // For now, in context of QUIC, for KeyManager implementations other than + // subclasses of sun.security.ssl.X509KeyManagerCertChecking + // we don't take into account + // any algorithm constraints when choosing the client alias and + // just call the functionally limited + // javax.net.ssl.X509KeyManager.chooseClientAlias(...) + if (km instanceof X509KeyManagerCertChecking xkm) { + clientAlias = xkm.chooseQuicClientAlias(keyTypes, + chc.peerSupportedAuthorities == null + ? null + : chc.peerSupportedAuthorities.clone(), + quicEngineImpl); + } else { + clientAlias = km.chooseClientAlias(keyTypes, + chc.peerSupportedAuthorities == null + ? null + : chc.peerSupportedAuthorities.clone(), + null); + } } if (clientAlias == null) { @@ -290,6 +312,28 @@ enum X509Authentication implements SSLAuthentication { shc.peerSupportedAuthorities == null ? null : shc.peerSupportedAuthorities.clone(), engine); + } else if (shc.conContext.transport instanceof QuicTLSEngineImpl quicEngineImpl) { + // TODO add a method on javax.net.ssl.X509ExtendedKeyManager that + // takes QuicTLSEngine. + // For now, in context of QUIC, for KeyManager implementations other than + // subclasses of sun.security.ssl.X509KeyManagerCertChecking + // we don't take into account + // any algorithm constraints when choosing the server alias + // and just call the functionally limited + // javax.net.ssl.X509KeyManager.chooseServerAlias(...) + if (km instanceof X509KeyManagerCertChecking xkm) { + serverAlias = xkm.chooseQuicServerAlias(keyType, + shc.peerSupportedAuthorities == null + ? null + : shc.peerSupportedAuthorities.clone(), + quicEngineImpl); + } else { + serverAlias = km.chooseServerAlias(keyType, + shc.peerSupportedAuthorities == null + ? null + : shc.peerSupportedAuthorities.clone(), + null); + } } if (serverAlias == null) { diff --git a/src/java.base/share/classes/sun/security/ssl/X509KeyManagerCertChecking.java b/src/java.base/share/classes/sun/security/ssl/X509KeyManagerCertChecking.java index 162a938cddb..9484ab4f830 100644 --- a/src/java.base/share/classes/sun/security/ssl/X509KeyManagerCertChecking.java +++ b/src/java.base/share/classes/sun/security/ssl/X509KeyManagerCertChecking.java @@ -74,6 +74,15 @@ abstract class X509KeyManagerCertChecking extends X509ExtendedKeyManager { abstract boolean isCheckingDisabled(); + // TODO move this method to a public interface / class + abstract String chooseQuicClientAlias(String[] keyTypes, Principal[] issuers, + QuicTLSEngineImpl quicTLSEngine); + + // TODO move this method to a public interface / class + abstract String chooseQuicServerAlias(String keyType, + X500Principal[] issuers, + QuicTLSEngineImpl quicTLSEngine); + // Entry point to do all certificate checks. protected EntryStatus checkAlias(int keyStoreIndex, String alias, Certificate[] chain, Date verificationDate, List keyTypes, @@ -185,6 +194,17 @@ abstract class X509KeyManagerCertChecking extends X509ExtendedKeyManager { engine, SIGNATURE_CONSTRAINTS_MODE.PEER, true); } + // Gets algorithm constraints of QUIC TLS engine. + protected AlgorithmConstraints getAlgorithmConstraints(QuicTLSEngineImpl engine) { + + if (checksDisabled) { + return null; + } + + return SSLAlgorithmConstraints.forQUIC( + engine, SIGNATURE_CONSTRAINTS_MODE.PEER, true); + } + // Algorithm constraints check. private boolean conformsToAlgorithmConstraints( AlgorithmConstraints constraints, Certificate[] chain, diff --git a/src/java.base/share/classes/sun/security/ssl/X509KeyManagerImpl.java b/src/java.base/share/classes/sun/security/ssl/X509KeyManagerImpl.java index df6ecaf7a42..e48096cc363 100644 --- a/src/java.base/share/classes/sun/security/ssl/X509KeyManagerImpl.java +++ b/src/java.base/share/classes/sun/security/ssl/X509KeyManagerImpl.java @@ -129,6 +129,13 @@ final class X509KeyManagerImpl extends X509KeyManagerCertChecking { getAlgorithmConstraints(engine), null, null); } + @Override + String chooseQuicClientAlias(String[] keyTypes, Principal[] issuers, + QuicTLSEngineImpl quicTLSEngine) { + return chooseAlias(getKeyTypes(keyTypes), issuers, CheckType.CLIENT, + getAlgorithmConstraints(quicTLSEngine), null, null); + } + @Override public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { @@ -165,6 +172,16 @@ final class X509KeyManagerImpl extends X509KeyManagerCertChecking { // It is not a really HTTPS endpoint identification. } + @Override + String chooseQuicServerAlias(String keyType, + X500Principal[] issuers, + QuicTLSEngineImpl quicTLSEngine) { + return chooseAlias(getKeyTypes(keyType), issuers, CheckType.SERVER, + getAlgorithmConstraints(quicTLSEngine), + X509TrustManagerImpl.getRequestedServerNames(quicTLSEngine), + "HTTPS"); + } + @Override public String[] getClientAliases(String keyType, Principal[] issuers) { return getAliases(keyType, issuers, CheckType.CLIENT); diff --git a/src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java b/src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java index 5001181fecf..d82b94a1d7d 100644 --- a/src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java +++ b/src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java @@ -30,7 +30,9 @@ import java.security.*; import java.security.cert.*; import java.util.*; import java.util.concurrent.locks.ReentrantLock; + import javax.net.ssl.*; + import sun.security.ssl.SSLAlgorithmConstraints.SIGNATURE_CONSTRAINTS_MODE; import sun.security.util.AnchorCertificates; import sun.security.util.HostnameChecker; @@ -145,6 +147,16 @@ final class X509TrustManagerImpl extends X509ExtendedTrustManager checkTrusted(chain, authType, engine, false); } + public void checkClientTrusted(X509Certificate[] chain, String authType, + QuicTLSEngineImpl quicTLSEngine) throws CertificateException { + checkTrusted(chain, authType, quicTLSEngine, true); + } + + void checkServerTrusted(X509Certificate[] chain, String authType, + QuicTLSEngineImpl quicTLSEngine) throws CertificateException { + checkTrusted(chain, authType, quicTLSEngine, false); + } + private Validator checkTrustedInit(X509Certificate[] chain, String authType, boolean checkClientTrusted) { if (chain == null || chain.length == 0) { @@ -236,6 +248,52 @@ final class X509TrustManagerImpl extends X509ExtendedTrustManager } } + private void checkTrusted(X509Certificate[] chain, + String authType, QuicTLSEngineImpl quicTLSEngine, + boolean checkClientTrusted) throws CertificateException { + Validator v = checkTrustedInit(chain, authType, checkClientTrusted); + + final X509Certificate[] trustedChain; + if (quicTLSEngine != null) { + + final SSLSession session = quicTLSEngine.getHandshakeSession(); + if (session == null) { + throw new CertificateException("No handshake session"); + } + + // create the algorithm constraints + final AlgorithmConstraints constraints = SSLAlgorithmConstraints.forQUIC( + quicTLSEngine, SIGNATURE_CONSTRAINTS_MODE.LOCAL, false); + final List responseList; + // grab any stapled OCSP responses for use in validation + if (!checkClientTrusted && + session instanceof ExtendedSSLSession extSession) { + responseList = extSession.getStatusResponses(); + } else { + responseList = Collections.emptyList(); + } + // do the certificate chain validation + trustedChain = v.validate(chain, null, responseList, + constraints, checkClientTrusted ? null : authType); + + // check endpoint identity + String identityAlg = quicTLSEngine.getSSLParameters(). + getEndpointIdentificationAlgorithm(); + if (identityAlg != null && !identityAlg.isEmpty()) { + checkIdentity(session, trustedChain, + identityAlg, checkClientTrusted); + } + } else { + trustedChain = v.validate(chain, null, Collections.emptyList(), + null, checkClientTrusted ? null : authType); + } + + if (SSLLogger.isOn && SSLLogger.isOn("ssl,trustmanager")) { + SSLLogger.fine("Found trusted certificate", + trustedChain[trustedChain.length - 1]); + } + } + private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine engine, boolean checkClientTrusted) throws CertificateException { @@ -344,6 +402,13 @@ final class X509TrustManagerImpl extends X509ExtendedTrustManager return Collections.emptyList(); } + static List getRequestedServerNames(QuicTLSEngineImpl engine) { + if (engine != null) { + return getRequestedServerNames(engine.getHandshakeSession()); + } + return Collections.emptyList(); + } + private static List getRequestedServerNames( SSLSession session) { if (session instanceof ExtendedSSLSession) { diff --git a/src/java.base/share/conf/security/java.security b/src/java.base/share/conf/security/java.security index 32d1ddaf0f7..2464361b9ef 100644 --- a/src/java.base/share/conf/security/java.security +++ b/src/java.base/share/conf/security/java.security @@ -971,6 +971,33 @@ jdk.tls.legacyAlgorithms=NULL, anon, RC4, DES, 3DES_EDE_CBC jdk.tls.keyLimits=AES/GCM/NoPadding KeyUpdate 2^37, \ ChaCha20-Poly1305 KeyUpdate 2^37 +# +# QUIC TLS key limits on symmetric cryptographic algorithms +# +# This security property sets limits on algorithms key usage in QUIC. +# When the number of encrypted datagrams reaches the algorithm value +# listed below, key update operation will be initiated. +# +# The syntax for the property is described below: +# KeyLimits: +# " KeyLimit { , KeyLimit } " +# +# KeyLimit: +# AlgorithmName Length +# +# AlgorithmName: +# A full algorithm transformation. +# +# Length: +# The amount of encrypted data in a session before the Action occurs +# This value may be an integer value in bytes, or as a power of two, 2^23. +# +# Note: This property is currently used by OpenJDK's JSSE implementation. It +# is not guaranteed to be examined and used by other implementations. +# +jdk.quic.tls.keyLimits=AES/GCM/NoPadding 2^23, \ + ChaCha20-Poly1305 2^23 + # # Cryptographic Jurisdiction Policy defaults # diff --git a/src/java.net.http/share/classes/java/net/http/HttpClient.java b/src/java.net.http/share/classes/java/net/http/HttpClient.java index 59afff013c7..4ce77486e70 100644 --- a/src/java.net.http/share/classes/java/net/http/HttpClient.java +++ b/src/java.net.http/share/classes/java/net/http/HttpClient.java @@ -28,9 +28,11 @@ package java.net.http; import java.io.IOException; import java.io.UncheckedIOException; import java.net.InetAddress; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse.BodyHandlers; import java.net.http.HttpResponse.BodySubscriber; import java.net.http.HttpResponse.BodySubscribers; +import java.net.URI; import java.nio.channels.Selector; import java.net.Authenticator; import java.net.CookieHandler; @@ -59,7 +61,7 @@ import jdk.internal.net.http.HttpClientBuilderImpl; * The {@link #newBuilder() newBuilder} method returns a builder that creates * instances of the default {@code HttpClient} implementation. * The builder can be used to configure per-client state, like: the preferred - * protocol version ( HTTP/1.1 or HTTP/2 ), whether to follow redirects, a + * protocol version ( HTTP/1.1, HTTP/2 or HTTP/3 ), whether to follow redirects, a * proxy, an authenticator, etc. Once built, an {@code HttpClient} is immutable, * and can be used to send multiple requests. * @@ -162,6 +164,59 @@ import jdk.internal.net.http.HttpClientBuilderImpl; * prevent the resources allocated by the associated client from * being reclaimed by the garbage collector. * + *

+ * The default implementation of the {@code HttpClient} supports HTTP/1.1, + * HTTP/2, and HTTP/3. Which version of the protocol is actually used when sending + * a request can depend on multiple factors. In the case of HTTP/2, it may depend + * on an initial upgrade to succeed (when using a plain connection), or on HTTP/2 + * being successfully negotiated during the Transport Layer Security (TLS) handshake. + * + *

If {@linkplain Version#HTTP_2 HTTP/2} is selected over a clear + * connection, and no HTTP/2 connection to the + * origin server + * already exists, the client will create a new connection and attempt an upgrade + * from HTTP/1.1 to HTTP/2. + * If the upgrade succeeds, then the response to this request will use HTTP/2. + * If the upgrade fails, then the response will be handled using HTTP/1.1. + * + *

Other constraints may also affect the selection of protocol version. + * For example, if HTTP/2 is requested through a proxy, and if the implementation + * does not support this mode, then HTTP/1.1 may be used. + *

+ * The HTTP/3 protocol is not selected by default, but can be enabled by setting + * the {@linkplain Builder#version(Version) HttpClient preferred version} or the + * {@linkplain HttpRequest.Builder#version(Version) HttpRequest preferred version} to + * {@linkplain Version#HTTP_3 HTTP/3}. Like for HTTP/2, which protocol version is + * actually used when HTTP/3 is enabled may depend on several factors. + * {@linkplain HttpOption#H3_DISCOVERY Configuration hints} can + * be {@linkplain HttpRequest.Builder#setOption(HttpOption, Object) provided} + * to help the {@code HttpClient} implementation decide how to establish + * and carry out the HTTP exchange when the HTTP/3 protocol is enabled. + * If no configuration hints are provided, the {@code HttpClient} will select + * one as explained in the {@link HttpOption#H3_DISCOVERY H3_DISCOVERY} + * option API documentation. + *
Note that a request whose {@linkplain URI#getScheme() URI scheme} is not + * {@code "https"} will never be sent over HTTP/3. In this implementation, + * HTTP/3 is not used if a proxy is selected. + * + *

+ * If a concrete instance of {@link HttpClient} doesn't support sending a + * request through HTTP/3, an {@link UnsupportedProtocolVersionException} may be + * thrown, either when {@linkplain Builder#build() building} the client with + * a {@linkplain Builder#version(Version) preferred version} set to HTTP/3, + * or when attempting to send a request with {@linkplain HttpRequest.Builder#version(Version) + * HTTP/3 enabled} when {@link Http3DiscoveryMode#HTTP_3_URI_ONLY HTTP_3_URI_ONLY} + * was {@linkplain HttpRequest.Builder#setOption(HttpOption, Object) specified}. + * This may typically happen if the {@link #sslContext() SSLContext} + * or {@link #sslParameters() SSLParameters} configured on the client instance cannot + * be used with HTTP/3. + * + * @see UnsupportedProtocolVersionException + * @see Builder#version(Version) + * @see HttpRequest.Builder#version(Version) + * @see HttpRequest.Builder#setOption(HttpOption, Object) + * @see HttpOption#H3_DISCOVERY + * * @since 11 */ public abstract class HttpClient implements AutoCloseable { @@ -320,23 +375,19 @@ public abstract class HttpClient implements AutoCloseable { public Builder followRedirects(Redirect policy); /** - * Requests a specific HTTP protocol version where possible. + * Sets the default preferred HTTP protocol version for requests + * issued by this client. * *

If this method is not invoked prior to {@linkplain #build() * building}, then newly built clients will prefer {@linkplain * Version#HTTP_2 HTTP/2}. * - *

If set to {@linkplain Version#HTTP_2 HTTP/2}, then each request - * will attempt to upgrade to HTTP/2. If the upgrade succeeds, then the - * response to this request will use HTTP/2 and all subsequent requests - * and responses to the same - * origin server - * will use HTTP/2. If the upgrade fails, then the response will be - * handled using HTTP/1.1 + *

If a request doesn't have a preferred version, then + * the effective preferred version for the request is the + * client's preferred version.

* - * @implNote Constraints may also affect the selection of protocol version. - * For example, if HTTP/2 is requested through a proxy, and if the implementation - * does not support this mode, then HTTP/1.1 may be used + * @implNote Some constraints may also affect the {@linkplain + * HttpClient##ProtocolVersionSelection selection of the actual protocol version}. * * @param version the requested HTTP protocol version * @return this builder @@ -439,9 +490,14 @@ public abstract class HttpClient implements AutoCloseable { * @return a new {@code HttpClient} * * @throws UncheckedIOException may be thrown if underlying IO resources required - * by the implementation cannot be allocated. For instance, + * by the implementation cannot be allocated, or if the resulting configuration + * does not satisfy the implementation requirements. For instance, * if the implementation requires a {@link Selector}, and opening - * one fails due to {@linkplain Selector#open() lack of necessary resources}. + * one fails due to {@linkplain Selector#open() lack of necessary resources}, + * or if the {@linkplain #version(Version) preferred protocol version} is not + * {@linkplain HttpClient##UnsupportedProtocolVersion supported by + * the implementation or cannot be used in this configuration}. + * */ public HttpClient build(); } @@ -525,9 +581,11 @@ public abstract class HttpClient implements AutoCloseable { * Returns the preferred HTTP protocol version for this client. The default * value is {@link HttpClient.Version#HTTP_2} * - * @implNote Constraints may also affect the selection of protocol version. - * For example, if HTTP/2 is requested through a proxy, and if the - * implementation does not support this mode, then HTTP/1.1 may be used + * @implNote + * The protocol version that the {@code HttpClient} eventually + * decides to use for a request is affected by various factors noted + * in {@linkplain ##ProtocolVersionSelection protocol version selection} + * section. * * @return the HTTP protocol version requested */ @@ -562,7 +620,13 @@ public abstract class HttpClient implements AutoCloseable { /** * HTTP version 2 */ - HTTP_2 + HTTP_2, + + /** + * HTTP version 3 + * @since 26 + */ + HTTP_3 } /** diff --git a/src/java.net.http/share/classes/java/net/http/HttpOption.java b/src/java.net.http/share/classes/java/net/http/HttpOption.java new file mode 100644 index 00000000000..cbff11f71ee --- /dev/null +++ b/src/java.net.http/share/classes/java/net/http/HttpOption.java @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package java.net.http; + +import java.net.ProxySelector; +import java.net.URI; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest.Builder; + +/** + * This interface is used to provide additional request configuration + * option hints on how an HTTP request/response exchange should + * be carried out by the {@link HttpClient} implementation. + * Request configuration option hints can be provided to an + * {@link HttpRequest} with the {@link + * Builder#setOption(HttpOption, Object) HttpRequest.Builder + * setOption} method. + * + *

Concrete instances of this class and its subclasses are immutable. + * + * @apiNote + * In this version, the {@code HttpOption} interface is sealed and + * only allows the {@link #H3_DISCOVERY} option. However, it could be + * extended in the future to support additional options. + *

+ * The {@link #H3_DISCOVERY} option can be used to help the + * {@link HttpClient} decide how to select or establish an + * HTTP/3 connection through which to carry out an HTTP/3 + * request/response exchange. + * + * @param The {@linkplain #type() type of the option value} + * + * @since 26 + */ +public sealed interface HttpOption permits HttpRequestOptionImpl { + /** + * {@return the option name} + * + * @implSpec Different options must have different names. + */ + String name(); + + /** + * {@return the type of the value associated with the option} + * + * @apiNote Different options may have the same type. + */ + Class type(); + + /** + * An option that can be used to configure how the {@link HttpClient} will + * select or establish an HTTP/3 connection through which to carry out + * the request. If {@link Version#HTTP_3} is not selected either as + * the {@linkplain Builder#version(Version) request preferred version} + * or the {@linkplain HttpClient.Builder#version(Version) HttpClient + * preferred version} setting this option on the request has no effect. + *

+ * The {@linkplain #name() name of this option} is {@code "H3_DISCOVERY"}. + * + * @implNote + * The JDK built-in implementation of the {@link HttpClient} understands the + * request option {@link #H3_DISCOVERY} hint. + *
+ * If no {@code H3_DISCOVERY} hint is provided, and the {@linkplain Version#HTTP_3 + * HTTP/3 version} is selected, either as {@linkplain Builder#version(Version) + * request preferred version} or {@linkplain HttpClient.Builder#version(Version) + * client preferred version}, the JDK built-in implementation will establish + * the exchange as per {@link Http3DiscoveryMode#ANY}. + *

+ * In case of {@linkplain HttpClient.Redirect redirect}, the + * {@link #H3_DISCOVERY} option, if present, is always transferred to + * the new request. + *

+ * In this implementation, HTTP/3 through proxies is not supported. + * Unless {@link Http3DiscoveryMode#HTTP_3_URI_ONLY} is specified, if + * a {@linkplain HttpClient.Builder#proxy(ProxySelector) proxy} is {@linkplain + * ProxySelector#select(URI) selected} for the {@linkplain HttpRequest#uri() + * request URI}, the protocol version is downgraded to HTTP/2 or + * HTTP/1.1 and the {@link #H3_DISCOVERY} option is ignored. If, on the + * other hand, {@link Http3DiscoveryMode#HTTP_3_URI_ONLY} is specified, + * the request will fail. + * + * @see Http3DiscoveryMode + * @see Builder#setOption(HttpOption, Object) + */ + HttpOption H3_DISCOVERY = + new HttpRequestOptionImpl<>(Http3DiscoveryMode.class, "H3_DISCOVERY"); + + /** + * This enumeration can be used to help the {@link HttpClient} decide + * how an HTTP/3 exchange should be established, and can be provided + * as the value of the {@link HttpOption#H3_DISCOVERY} option + * to {@link Builder#setOption(HttpOption, Object) Builder.setOption}. + *

+ * Note that if neither the {@linkplain Builder#version(Version) request preferred + * version} nor the {@linkplain HttpClient.Builder#version(Version) client preferred + * version} is {@linkplain Version#HTTP_3 HTTP/3}, no HTTP/3 exchange will + * be established and the {@code Http3DiscoveryMode} is ignored. + * + * @since 26 + */ + enum Http3DiscoveryMode { + /** + * This instructs the {@link HttpClient} to use its own implementation + * specific algorithm to find or establish a connection for the exchange. + * Typically, if no connection was previously established with the origin + * server defined by the request URI, the {@link HttpClient} implementation + * may attempt to establish both an HTTP/3 connection over QUIC and an HTTP + * connection over TLS/TCP at the authority present in the request URI, + * and use the first that succeeds. The exchange may then be carried out with + * any of the {@linkplain Version + * three HTTP protocol versions}, depending on which method succeeded first. + * + * @implNote + * If the {@linkplain Builder#version(Version) request preferred version} is {@linkplain + * Version#HTTP_3 HTTP/3}, the {@code HttpClient} may give priority to HTTP/3 by + * attempting to establish an HTTP/3 connection, before attempting a TLS + * connection over TCP. + *

+ * When attempting an HTTP/3 connection in this mode, the {@code HttpClient} may + * use any HTTP Alternative Services + * information it may have previously obtained from the origin server. If no + * such information is available, a direct HTTP/3 connection at the authority (host, port) + * present in the {@linkplain HttpRequest#uri() request URI} will be attempted. + */ + ANY, + /** + * This instructs the {@link HttpClient} to only use the + * HTTP Alternative Services + * to find or establish an HTTP/3 connection with the origin server. + * The exchange may then be carried out with any of the {@linkplain + * Version three HTTP protocol versions}, depending on + * whether an Alternate Service record for HTTP/3 could be found, and which HTTP version + * was negotiated with the origin server, if no such record could be found. + *

+ * In this mode, requests sent to the origin server will be sent through HTTP/1.1 or HTTP/2 + * until a {@code h3} HTTP Alternative Services + * endpoint for that server is advertised to the client. Usually, an alternate service is + * advertised by a server when responding to a request, so that subsequent requests can make + * use of that alternative service. + */ + ALT_SVC, + /** + * This instructs the {@link HttpClient} to only attempt an HTTP/3 connection + * with the origin server. The connection will only succeed if the origin server + * is listening for incoming HTTP/3 connections over QUIC at the same authority (host, port) + * as defined in the {@linkplain HttpRequest#uri() request URI}. In this mode, + * HTTP Alternative Services + * are not used. + */ + HTTP_3_URI_ONLY + } + +} diff --git a/src/java.net.http/share/classes/java/net/http/HttpRequest.java b/src/java.net.http/share/classes/java/net/http/HttpRequest.java index 84a521336b6..c56328ba4b4 100644 --- a/src/java.net.http/share/classes/java/net/http/HttpRequest.java +++ b/src/java.net.http/share/classes/java/net/http/HttpRequest.java @@ -29,6 +29,7 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.net.URI; +import java.net.http.HttpClient.Version; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.Charset; @@ -91,6 +92,24 @@ public abstract class HttpRequest { */ protected HttpRequest() {} + /** + * {@return the value configured on this request for the given option, if any} + * @param option a request configuration option + * @param the type of the option + * + * @see Builder#setOption(HttpOption, Object) + * + * @implSpec + * The default implementation of this method returns {@link Optional#empty()} + * if {@code option} is non-null, otherwise throws {@link NullPointerException}. + * + * @since 26 + */ + public Optional getOption(HttpOption option) { + Objects.requireNonNull(option); + return Optional.empty(); + } + /** * A builder of {@linkplain HttpRequest HTTP requests}. * @@ -144,14 +163,53 @@ public abstract class HttpRequest { * *

The corresponding {@link HttpResponse} should be checked for the * version that was actually used. If the version is not set in a - * request, then the version requested will be that of the sending - * {@link HttpClient}. + * request, then the version requested will be {@linkplain + * HttpClient.Builder#version(Version) that of the sending + * {@code HttpClient}}. + * + * @implNote + * Constraints may also affect the {@linkplain HttpClient##ProtocolVersionSelection + * selection of the actual protocol version}. * * @param version the HTTP protocol version requested * @return this builder */ public Builder version(HttpClient.Version version); + /** + * Provides request configuration option hints modeled as key value pairs + * to help an {@link HttpClient} implementation decide how the + * request/response exchange should be established or carried out. + * + *

An {@link HttpClient} implementation may decide to ignore request + * configuration option hints, or fail the request, if provided with any + * option hints that it does not understand. + *

+ * If this method is invoked twice for the same {@linkplain HttpOption + * request option}, any value previously provided to this builder for the + * corresponding option is replaced by the new value. + * If {@code null} is supplied as a value, any value previously + * provided is discarded. + * + * @implSpec + * The default implementation of this method discards the provided option + * hint and does nothing. + * + * @implNote + * The JDK built-in implementation of the {@link HttpClient} understands the + * request option {@link HttpOption#H3_DISCOVERY} hint. + * + * @param option the request configuration option + * @param value the request configuration option value (can be null) + * + * @return this builder + * + * @see HttpRequest#getOption(HttpOption) + * + * @since 26 + */ + public default Builder setOption(HttpOption option, T value) { return this; } + /** * Adds the given name value pair to the set of headers for this request. * The given value is added to the list of values for that name. @@ -394,6 +452,8 @@ public abstract class HttpRequest { } } ); + request.getOption(HttpOption.H3_DISCOVERY) + .ifPresent(opt -> builder.setOption(HttpOption.H3_DISCOVERY, opt)); return builder; } diff --git a/src/java.net.http/share/classes/java/net/http/HttpRequestOptionImpl.java b/src/java.net.http/share/classes/java/net/http/HttpRequestOptionImpl.java new file mode 100644 index 00000000000..f5562c7068b --- /dev/null +++ b/src/java.net.http/share/classes/java/net/http/HttpRequestOptionImpl.java @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package java.net.http; + +// Package private implementation of HttpRequest options +record HttpRequestOptionImpl(Class type, String name) + implements HttpOption { + @Override + public String toString() { + return name(); + } +} diff --git a/src/java.net.http/share/classes/java/net/http/HttpResponse.java b/src/java.net.http/share/classes/java/net/http/HttpResponse.java index 52f5298452a..9843e4c7c5b 100644 --- a/src/java.net.http/share/classes/java/net/http/HttpResponse.java +++ b/src/java.net.http/share/classes/java/net/http/HttpResponse.java @@ -803,24 +803,66 @@ public interface HttpResponse { /** * A handler for push promises. * - *

A push promise is a synthetic request sent by an HTTP/2 server + *

A push promise is a synthetic request sent by an HTTP/2 or HTTP/3 server * when retrieving an initiating client-sent request. The server has * determined, possibly through inspection of the initiating request, that * the client will likely need the promised resource, and hence pushes a * synthetic push request, in the form of a push promise, to the client. The * client can choose to accept or reject the push promise request. * - *

A push promise request may be received up to the point where the + *

For HTTP/2, a push promise request may be received up to the point where the * response body of the initiating client-sent request has been fully * received. The delivery of a push promise response, however, is not * coordinated with the delivery of the response to the initiating - * client-sent request. + * client-sent request. These are delivered with the + * {@link #applyPushPromise(HttpRequest, HttpRequest, Function)} method. + *

+ * For HTTP/3, push promises are handled in a similar way, except that promises + * of the same resource (request URI, request headers and response body) can be + * promised multiple times, but are only delivered by the server (and this API) + * once though the method {@link #applyPushPromise(HttpRequest, HttpRequest, PushId, Function)}. + * Subsequent promises of the same resource, receive a notification only + * of the promise by the method {@link #notifyAdditionalPromise(HttpRequest, PushId)}. + * The same {@link PushPromiseHandler.PushId} is supplied for each of these + * notifications. Additionally, HTTP/3 push promises are not restricted to a context + * of a single initiating request. The same push promise can be delivered and then notified + * across multiple client initiated requests within the same HTTP/3 (QUIC) connection. * * @param the push promise response body type * @since 11 */ public interface PushPromiseHandler { + /** + * Represents a HTTP/3 PushID. PushIds can be shared across + * multiple client initiated requests on the same HTTP/3 connection. + * @since 26 + */ + public sealed interface PushId { + + /** + * Represents an HTTP/3 PushId. + * + * @param pushId the pushId as a long + * @param connectionLabel the {@link HttpResponse#connectionLabel()} + * of the HTTP/3 connection + * @apiNote + * The {@code connectionLabel} should be considered opaque, and ensures that + * two long pushId emitted by different connections correspond to distinct + * instances of {@code PushId}. The {@code pushId} corresponds to the + * unique push ID assigned by the server that identifies a given server + * push on that connection, as defined by + * RFC 9114, + * section 4.6 + * + * @spec https://www.rfc-editor.org/info/rfc9114 + * RFC 9114: HTTP/3 + * + * @since 26 + */ + record Http3PushId(long pushId, String connectionLabel) implements PushId { } + } + /** * Notification of an incoming push promise. * @@ -838,6 +880,12 @@ public interface HttpResponse { * then the push promise is rejected. The {@code acceptor} function will * throw an {@code IllegalStateException} if invoked more than once. * + *

This method is invoked for all HTTP/2 push promises and also + * by default for the first promise of all HTTP/3 push promises. + * If {@link #applyPushPromise(HttpRequest, HttpRequest, PushId, Function)} + * is overridden, then this method is not directly invoked for HTTP/3 + * push promises. + * * @param initiatingRequest the initiating client-send request * @param pushPromiseRequest the synthetic push request * @param acceptor the acceptor function that must be successfully @@ -849,6 +897,67 @@ public interface HttpResponse { Function,CompletableFuture>> acceptor ); + /** + * Notification of the first occurrence of an HTTP/3 incoming push promise. + * + * Subsequent promises of the same resource (with the same PushId) are notified + * using {@link #notifyAdditionalPromise(HttpRequest, PushId) + * notifyAdditionalPromise(HttpRequest, PushId)}. + * + *

This method is invoked once for each push promise received, up + * to the point where the response body of the initiating client-sent + * request has been fully received. + * + *

A push promise is accepted by invoking the given {@code acceptor} + * function. The {@code acceptor} function must be passed a non-null + * {@code BodyHandler}, that is to be used to handle the promise's + * response body. The acceptor function will return a {@code + * CompletableFuture} that completes with the promise's response. + * + *

If the {@code acceptor} function is not successfully invoked, + * then the push promise is rejected. The {@code acceptor} function will + * throw an {@code IllegalStateException} if invoked more than once. + * + * @implSpec the default implementation invokes + * {@link #applyPushPromise(HttpRequest, HttpRequest, Function)}. This allows + * {@code PushPromiseHandlers} from previous releases to handle HTTP/3 push + * promise in a reasonable way. + * + * @param initiatingRequest the client request that resulted in the promise + * @param pushPromiseRequest the promised HttpRequest from the server + * @param pushid the PushId which can be linked to subsequent notifications + * @param acceptor the acceptor function that must be successfully + * invoked to accept the push promise + * + * @since 26 + */ + public default void applyPushPromise( + HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + PushId pushid, + Function,CompletableFuture>> acceptor + ) { + applyPushPromise(initiatingRequest, pushPromiseRequest, acceptor); + } + + /** + * Invoked for each additional HTTP/3 Push Promise. The {@code pushid} links the promise to the + * original promised {@link HttpRequest} and {@link HttpResponse}. Additional promises + * generally result from different client initiated requests. + * + * @implSpec + * The default implementation of this method does nothing. + * + * @param initiatingRequest the client initiated request which resulted in the push + * @param pushid the pushid which may have been notified previously + * + * @since 26 + */ + public default void notifyAdditionalPromise( + HttpRequest initiatingRequest, + PushId pushid + ) { + } /** * Returns a push promise handler that accumulates push promises, and @@ -915,7 +1024,7 @@ public interface HttpResponse { * * @apiNote To ensure that all resources associated with the corresponding * HTTP exchange are properly released, an implementation of {@code - * BodySubscriber} should ensure to {@linkplain Flow.Subscription#request + * BodySubscriber} should ensure to {@linkplain Flow.Subscription#request(long) * request} more data until one of {@link #onComplete() onComplete} or * {@link #onError(Throwable) onError} are signalled, or {@link * Flow.Subscription#cancel cancel} its {@linkplain @@ -957,7 +1066,7 @@ public interface HttpResponse { * {@snippet : * // Streams the response body to a File * HttpResponse response = client - * .send(request, responseInfo -> BodySubscribers.ofFile(Paths.get("example.html")); } + * .send(request, responseInfo -> BodySubscribers.ofFile(Paths.get("example.html"))); } * * {@snippet : * // Accumulates the response body and returns it as a byte[] diff --git a/src/java.net.http/share/classes/java/net/http/StreamLimitException.java b/src/java.net.http/share/classes/java/net/http/StreamLimitException.java new file mode 100644 index 00000000000..583b515b01b --- /dev/null +++ b/src/java.net.http/share/classes/java/net/http/StreamLimitException.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package java.net.http; + +import java.io.IOException; +import java.io.InvalidObjectException; +import java.io.ObjectInputStream; +import java.net.http.HttpClient.Version; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.util.Objects; + +/** + * An exception raised when the limit imposed for stream creation on an + * HTTP connection is reached, and the client is unable to create a new + * stream. + *

+ * A {@code StreamLimitException} may be raised when attempting to send + * a new request on any {@linkplain #version() + * protocol version} that supports multiplexing on a single connection. Both + * {@linkplain HttpClient.Version#HTTP_2 HTTP/2} and {@linkplain + * HttpClient.Version#HTTP_3 HTTP/3} allow multiplexing concurrent requests + * to the same server on a single connection. Each request/response exchange + * is carried over a single stream, as defined by the corresponding + * protocol. + *

+ * Whether and when a {@code StreamLimitException} may be + * relayed to the code initiating a request/response exchange is + * implementation and protocol version dependent. + * + * @see HttpClient#send(HttpRequest, BodyHandler) + * @see HttpClient#sendAsync(HttpRequest, BodyHandler) + * @see HttpClient#sendAsync(HttpRequest, BodyHandler, PushPromiseHandler) + * + * @since 26 + */ +public final class StreamLimitException extends IOException { + + @java.io.Serial + private static final long serialVersionUID = 2614981180406031159L; + + /** + * The version of the HTTP protocol on which the stream limit exception occurred. + * Must not be null. + * @serial + */ + private final Version version; + + /** + * Creates a new {@code StreamLimitException} + * @param version the version of the protocol on which the stream limit exception + * occurred. Must not be null. + * @param message the detailed exception message, which can be null. + */ + public StreamLimitException(final Version version, final String message) { + super(message); + this.version = Objects.requireNonNull(version); + } + + /** + * {@return the protocol version for which the exception was raised} + */ + public final Version version() { + return version; + } + + /** + * Restores the state of a {@code StreamLimitException} from the stream + * @param in the input stream + * @throws IOException if the class of a serialized object could not be found. + * @throws ClassNotFoundException if an I/O error occurs. + * @throws InvalidObjectException if {@code version} is null. + */ + @java.io.Serial + private void readObject(ObjectInputStream in) + throws IOException, ClassNotFoundException { + in.defaultReadObject(); + if (version == null) { + throw new InvalidObjectException("version must not be null"); + } + } +} diff --git a/src/java.net.http/share/classes/java/net/http/UnsupportedProtocolVersionException.java b/src/java.net.http/share/classes/java/net/http/UnsupportedProtocolVersionException.java new file mode 100644 index 00000000000..eecc039e5d2 --- /dev/null +++ b/src/java.net.http/share/classes/java/net/http/UnsupportedProtocolVersionException.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package java.net.http; + +import java.io.IOException; +import java.io.Serial; +import java.net.http.HttpClient.Builder; + +/** + * Thrown when the HTTP client doesn't support a particular HTTP version. + * @apiNote + * Typically, this exception may be thrown when attempting to + * {@linkplain Builder#build() build} an {@link java.net.http.HttpClient} + * configured to use {@linkplain java.net.http.HttpClient.Version#HTTP_3 + * HTTP version 3} by default, when the underlying {@link javax.net.ssl.SSLContext + * SSLContext} implementation does not meet the requirements for supporting + * the HttpClient's implementation of the underlying QUIC transport protocol. + * @since 26 + */ +public final class UnsupportedProtocolVersionException extends IOException { + + @Serial + private static final long serialVersionUID = 981344214212332893L; + + /** + * Constructs an {@code UnsupportedProtocolVersionException} with the given detail message. + * + * @param message The detail message; can be {@code null} + */ + public UnsupportedProtocolVersionException(String message) { + super(message); + } +} diff --git a/src/java.net.http/share/classes/java/net/http/package-info.java b/src/java.net.http/share/classes/java/net/http/package-info.java index 9958fd94da0..1b8395c2706 100644 --- a/src/java.net.http/share/classes/java/net/http/package-info.java +++ b/src/java.net.http/share/classes/java/net/http/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,7 +26,7 @@ /** *

HTTP Client and WebSocket APIs

* - *

Provides high-level client interfaces to HTTP (versions 1.1 and 2) and + *

Provides high-level client interfaces to HTTP (versions 1.1, 2, and 3) and * low-level client interfaces to WebSocket. The main types defined are: * *

    @@ -37,10 +37,12 @@ *
* *

The protocol-specific requirements are defined in the - * Hypertext Transfer Protocol - * Version 2 (HTTP/2), the + * Hypertext Transfer Protocol + * Version 3 (HTTP/3), the + * Hypertext Transfer Protocol Version 2 (HTTP/2), the + * * Hypertext Transfer Protocol (HTTP/1.1), and - * The WebSocket Protocol. + * The WebSocket Protocol. * *

In general, asynchronous tasks execute in either the thread invoking * the operation, e.g. {@linkplain HttpClient#send(HttpRequest, BodyHandler) @@ -66,6 +68,15 @@ *

Unless otherwise stated, {@code null} parameter values will cause methods * of all classes in this package to throw {@code NullPointerException}. * + * @spec https://www.rfc-editor.org/info/rfc9114 + * RFC 9114: HTTP/3 + * @spec https://www.rfc-editor.org/info/rfc7540 + * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2) + * @spec https://www.rfc-editor.org/info/rfc2616 + * RFC 2616: Hypertext Transfer Protocol -- HTTP/1.1 + * @spec https://www.rfc-editor.org/info/rfc6455 + * RFC 6455: The WebSocket Protocol + * * @since 11 */ package java.net.http; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/AltServicesRegistry.java b/src/java.net.http/share/classes/jdk/internal/net/http/AltServicesRegistry.java new file mode 100644 index 00000000000..08161bcd110 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/AltServicesRegistry.java @@ -0,0 +1,569 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.net.URI; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import javax.net.ssl.SNIServerName; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; + +/** + * A registry for Alternate Services advertised by server endpoints. + * There is one registry per HttpClient. + */ +public final class AltServicesRegistry { + + // id and logger for debugging purposes: the id is the same for the HttpClientImpl. + private final long id; + private final Logger debug = Utils.getDebugLogger(this::dbgString); + + // The key is the origin of the alternate service + // The value is a list of AltService records declared by the origin. + private final Map> altServices = new HashMap<>(); + // alt services which were marked invalid in context of an origin. the reason for + // them being invalid can be connection issues (for example: the alt service didn't present the + // certificate of the origin) + private final InvalidAltServices invalidAltServices = new InvalidAltServices(); + + // used while dealing with both altServices Map and the invalidAltServices Set + private final ReentrantLock registryLock = new ReentrantLock(); + + public AltServicesRegistry(long id) { + this.id = id; + } + + String dbgString() { + return "AltServicesRegistry(" + id + ")"; + } + + public static final class AltService { + // As defined in RFC-7838, section 2, formally an alternate service is a combination of + // ALPN, host and port + public record Identity(String alpn, String host, int port) { + public Identity { + Objects.requireNonNull(alpn); + Objects.requireNonNull(host); + if (port <= 0) { + throw new IllegalArgumentException("Invalid port: " + port); + } + } + + public boolean matches(AltService service) { + return equals(service.identity()); + } + + @Override + public String toString() { + return alpn + "=\"" + Origin.toAuthority(host, port) +"\""; + } + } + + private record AltServiceData(Identity id, Origin origin, Deadline deadline, + boolean persist, boolean advertised, + String authority, + boolean sameAuthorityAsOrigin) { + public String pretty() { + return "AltSvc: " + id + + "; origin=\"" + origin + "\"" + + "; deadline=" + deadline + + "; persist=" + persist + + "; advertised=" + advertised + + "; sameAuthorityAsOrigin=" + sameAuthorityAsOrigin + + ';'; + } + } + private final AltServiceData svc; + + /** + * @param id the alpn, host and port of this alternate service + * @param origin the {@link Origin} for this alternate service + * @param deadline the deadline until which this endpoint is valid + * @param persist whether that information can be persisted (we don't use this) + * @param advertised Whether or not this alt service was advertised as an alt service. + * In certain cases, an alt service is created when no origin server + * has advertised it. In those cases, this param is {@code false} + */ + private AltService(final Identity id, final Origin origin, Deadline deadline, + final boolean persist, + final boolean advertised) { + Objects.requireNonNull(id); + Objects.requireNonNull(origin); + assert origin.isSecure() : "origin " + origin + " is not secure"; + deadline = deadline == null ? Deadline.MAX : deadline; + final String authority = Origin.toAuthority(id.host, id.port); + final String originAuthority = Origin.toAuthority(origin.host(), origin.port()); + // keep track of whether the authority of this alt service is same as that + // of the origin + final boolean sameAuthorityAsOrigin = authority.equals(originAuthority); + svc = new AltServiceData(id, origin, deadline, persist, advertised, + authority, sameAuthorityAsOrigin); + } + + public Identity identity() { + return svc.id; + } + + /** + * @return {@code host:port} of the alternate service + */ + public String authority() { + return svc.authority; + } + + /** + * @return {@code identity().host()} + */ + public String host() { + return svc.id.host; + } + + /** + * @return {@code identity().port()} + */ + public int port() { + return svc.id.port; + } + + public boolean isPersist() { + return svc.persist; + } + + public boolean wasAdvertised() { + return svc.advertised; + } + + public String alpn() { + return svc.id.alpn; + } + + public Origin origin() { + return svc.origin; + } + + public Deadline deadline() { + return svc.deadline; + } + + /** + * {@return true if the origin, for which this is an alternate service, has the + * same authority as this alternate service. false otherwise.} + */ + public boolean originHasSameAuthority() { + return svc.sameAuthorityAsOrigin; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof AltService service)) return false; + return svc.equals(service.svc); + } + + @Override + public int hashCode() { + return svc.hashCode(); + } + + @Override + public String toString() { + return svc.pretty(); + } + + public static Optional create(final Identity id, final Origin origin, + final Deadline deadline, final boolean persist) { + Objects.requireNonNull(id); + Objects.requireNonNull(origin); + if (!origin.isSecure()) { + return Optional.empty(); + } + return Optional.of(new AltService(id, origin, deadline, persist, true)); + } + + private static Optional createUnadvertised(final Logger debug, + final Identity id, final Origin origin, + final HttpConnection conn, + final Deadline deadline, final boolean persist) { + Objects.requireNonNull(id); + Objects.requireNonNull(origin); + if (!origin.isSecure()) { + return Optional.empty(); + } + final List sniServerNames = AltSvcProcessor.getSNIServerNames(conn); + if (sniServerNames == null || sniServerNames.isEmpty()) { + if (debug.on()) { + debug.log("Skipping unadvertised altsvc creation of %s because connection %s" + + " didn't use SNI during connection establishment", id, conn); + } + return Optional.empty(); + } + return Optional.of(new AltService(id, origin, deadline, persist, false)); + } + + } + + // A size limited collection which keeps track of unique InvalidAltSvc instances. + // Upon reaching a pre-defined size limit, after adding newer entries, the collection + // then removes the eldest (the least recently added) entry from the collection. + // The implementation of this class is not thread safe and any concurrent access + // to instances of this class should be guarded externally. + private static final class InvalidAltServices extends LinkedHashMap { + + private static final long serialVersionUID = 2772562283544644819L; + + // we track only a reasonably small number of invalid alt services + private static final int MAX_TRACKED_INVALID_ALT_SVCS = 20; + + @Override + protected boolean removeEldestEntry(final Map.Entry eldest) { + return size() > MAX_TRACKED_INVALID_ALT_SVCS; + } + + private boolean contains(final InvalidAltSvc invalidAltSvc) { + return this.containsKey(invalidAltSvc); + } + + private boolean addUnique(final InvalidAltSvc invalidAltSvc) { + if (contains(invalidAltSvc)) { + return false; + } + this.put(invalidAltSvc, null); + return true; + } + } + + // An alt-service is invalid for a particular origin + private record InvalidAltSvc(Origin origin, AltService.Identity id) { + } + + private boolean keepAltServiceFor(Origin origin, AltService svc) { + // skip invalid alt services + if (isMarkedInvalid(origin, svc.identity())) { + if (debug.on()) { + debug.log("Not registering alt-service which was previously" + + " marked invalid: " + svc); + } + if (Log.altsvc()) { + Log.logAltSvc("AltService skipped (was previously marked invalid): " + svc); + } + return false; + } + return true; + } + + /** + * Declare a new Alternate Service endpoint for the given origin. + * + * @param origin the origin + * @param services a set of alt services for the origin + */ + public void replace(final Origin origin, final List services) { + Objects.requireNonNull(origin); + Objects.requireNonNull(services); + List added; + registryLock.lock(); + try { + // the list needs to be thread safe to ensure that we won't + // get a ConcurrentModificationException when iterating + // through the elements in list::stream(); + added = altServices.compute(origin, (key, list) -> { + Stream svcs = services.stream() + .filter(AltService.class::isInstance) // filter null + .filter((s) -> keepAltServiceFor(origin, s)); + List newList = svcs.toList(); + return newList.isEmpty() ? null : newList; + }); + } finally { + registryLock.unlock(); + } + if (debug.on()) { + debug.log("parsed services: %s", services); + debug.log("resulting services: %s", added); + } + if (Log.altsvc()) { + if (added != null) { + added.forEach((svc) -> Log.logAltSvc("AltService registry updated: {0}", svc)); + } + } + } + + // should be invoked while holding registryLock + private boolean isMarkedInvalid(final Origin origin, final AltService.Identity id) { + assert registryLock.isHeldByCurrentThread() : "Thread isn't holding registry lock"; + return this.invalidAltServices.contains(new InvalidAltSvc(origin, id)); + } + + /** + * Registers an unadvertised alt service for the given origin and the alpn. + * + * @param id The alt service identity + * @param origin The origin + * @return An {@code Optional} containing the registered {@code AltService}, + * or {@link Optional#empty()} if the service was not registered. + */ + Optional registerUnadvertised(final AltService.Identity id, + final Origin origin, + final HttpConnection conn) { + Objects.requireNonNull(id); + Objects.requireNonNull(origin); + registryLock.lock(); + try { + // an unadvertised alt service is registered by an origin only after a + // successful connection has completed with that alt service. This effectively means + // that we shouldn't check our "invalid alt services" collection, since a successful + // connection against the alt service implies a valid alt service. + // Additionally, we remove it from the "invalid alt services" collection for this + // origin, if at all it was part of that collection + this.invalidAltServices.remove(new InvalidAltSvc(origin, id)); + // default max age as per AltService RFC-7838, section 3.1 is 24 hours. we use + // that same value for unadvertised alt-service(s) for an origin. + final long defaultMaxAgeInSecs = 3600 * 24; + final Deadline deadline = TimeSource.now().plusSeconds(defaultMaxAgeInSecs); + final Optional created = AltService.createUnadvertised(debug, + id, origin, conn, deadline, true); + if (created.isEmpty()) { + return Optional.empty(); + } + final AltService altSvc = created.get(); + altServices.compute(origin, (key, list) -> { + Stream old = list == null ? Stream.empty() : list.stream(); + List newList = Stream.concat(old, Stream.of(altSvc)).toList(); + return newList.isEmpty() ? null : newList; + }); + if (debug.on()) { + debug.log("Added unadvertised AltService: %s", created); + } + if (Log.altsvc()) { + Log.logAltSvc("Added unadvertised AltService: {0}", created); + } + return created; + } finally { + registryLock.unlock(); + } + } + + /** + * Clear the alternate services of the specified origin from the registry + * + * @param origin The origin whose alternate services need to be cleared + */ + public void clear(final Origin origin) { + Objects.requireNonNull(origin); + registryLock.lock(); + try { + if (Log.altsvc()) { + Log.logAltSvc("Clearing AltServices for: " + origin); + } + altServices.remove(origin); + } finally { + registryLock.unlock(); + } + } + + public void markInvalid(final AltService altService) { + Objects.requireNonNull(altService); + markInvalid(altService.origin(), altService.identity()); + } + + private void markInvalid(final Origin origin, final AltService.Identity id) { + Objects.requireNonNull(origin); + Objects.requireNonNull(id); + registryLock.lock(); + try { + // remove this alt service from the current active set of the origin + this.altServices.computeIfPresent(origin, + (key, currentActive) -> { + assert currentActive != null; // should never be null according to spec + List newList = currentActive.stream() + .filter(Predicate.not(id::matches)).toList(); + return newList.isEmpty() ? null : newList; + + }); + // additionally keep track of this as an invalid alt service, so that it cannot be + // registered again in the future. Banning is temporary. + // Banned alt services may get removed from the set at some point due to + // implementation constraints. In which case they may get registered again + // and banned again, if connecting to the endpoint fails again. + this.invalidAltServices.addUnique(new InvalidAltSvc(origin, id)); + if (debug.on()) { + debug.log("AltService marked invalid: " + id + " for origin " + origin); + } + if (Log.altsvc()) { + Log.logAltSvc("AltService marked invalid: " + id + " for origin " + origin); + } + } finally { + registryLock.unlock(); + } + + } + + public Stream lookup(final URI uri, final String alpn) { + final Origin origin; + try { + origin = Origin.from(uri); + } catch (IllegalArgumentException iae) { + return Stream.empty(); + } + return lookup(origin, alpn); + } + + /** + * A stream of {@code AlternateService} that are available for the + * given origin and the given ALPN. + * + * @param origin the URI of the origin server + * @param alpn the ALPN of the alternate service + * @return a stream of {@code AlternateService} that are available for the + * given origin and that support the given ALPN + */ + public Stream lookup(final Origin origin, final String alpn) { + return lookup(origin, Predicate.isEqual(alpn)); + } + + public Stream lookup(final URI uri, + final Predicate alpnMatcher) { + final Origin origin; + try { + origin = Origin.from(uri); + } catch (IllegalArgumentException iae) { + return Stream.empty(); + } + return lookup(origin, alpnMatcher); + } + + private boolean isExpired(AltService service, Deadline now) { + var deadline = service.deadline(); + if (now.equals(deadline) || now.isAfter(deadline)) { + // expired, remove from the list + if (debug.on()) { + debug.log("Removing expired alt-service " + service); + } + if (Log.altsvc()) { + Log.logAltSvc("AltService has expired: {0}", service); + } + return true; + } + return false; + } + + /** + * A stream of {@code AlternateService} that are available for the + * given origin and the given ALPN. + * + * @param origin the URI of the origin server + * @param alpnMatcher a predicate to select particular AltService(s) based on the alpn + * of the alternate service + * @return a stream of {@code AlternateService} that are available for the + * given origin and whose ALPN satisfies the {@code alpn} predicate. + */ + private Stream lookup(final Origin origin, + final Predicate alpnMatcher) { + if (debug.on()) debug.log("looking up alt-service for: %s", origin); + final List services; + registryLock.lock(); + try { + // we first drop any expired services + final Deadline now = TimeSource.now(); + services = altServices.compute(origin, (key, list) -> { + if (list == null) return null; + List newList = list.stream() + .filter((s) -> !isExpired(s, now)) + .toList(); + return newList.isEmpty() ? null : newList; + }); + } finally { + registryLock.unlock(); + } + // the order is important - since preferred service are at the head + return services == null + ? Stream.empty() + : services.stream().sequential().filter(s -> alpnMatcher.test(s.identity().alpn())); + } + + /** + * @param altService The alternate service + * {@return true if the {@code service} is known to this registry and the + * service isn't past its max age. false otherwise} + * @throws NullPointerException if {@code service} is null + */ + public boolean isActive(final AltService altService) { + Objects.requireNonNull(altService); + return isActive(altService.origin(), altService.identity()); + } + + private boolean isActive(final Origin origin, final AltService.Identity id) { + Objects.requireNonNull(origin); + Objects.requireNonNull(id); + registryLock.lock(); + try { + final List currentActive = this.altServices.get(origin); + if (currentActive == null) { + return false; + } + AltService svc = null; + for (AltService s : currentActive) { + if (s.identity().equals(id)) { + svc = s; + break; + } + } + if (svc == null) { + return false; + } + // verify that the service hasn't expired + final Deadline now = TimeSource.now(); + final Deadline deadline = svc.deadline(); + final boolean expired = now.equals(deadline) || now.isAfter(deadline); + if (expired) { + // remove from the registry + altServices.put(origin, currentActive.stream() + .filter(Predicate.not(svc::equals)).toList()); + if (debug.on()) { + debug.log("Removed expired alt-service " + svc + " for origin " + origin); + } + if (Log.altsvc()) { + Log.logAltSvc("Removed AltService: {0}", svc); + } + return false; + } + return true; + } finally { + registryLock.unlock(); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/AltSvcProcessor.java b/src/java.net.http/share/classes/jdk/internal/net/http/AltSvcProcessor.java new file mode 100644 index 00000000000..b172a242346 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/AltSvcProcessor.java @@ -0,0 +1,495 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http; + +import jdk.internal.net.http.AltServicesRegistry.AltService; +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.frame.AltSvcFrame; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.StringTokenizer; + +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; + +import static jdk.internal.net.http.Http3ClientProperties.ALTSVC_ALLOW_LOCAL_HOST_ORIGIN; +import static jdk.internal.net.http.common.Alpns.isSecureALPNName; + + +/** + * Responsible for parsing the Alt-Svc values from an Alt-Svc header and/or AltSvc HTTP/2 frame. + */ +final class AltSvcProcessor { + + private static final String HEADER = "alt-svc"; + private static final Logger debug = Utils.getDebugLogger(() -> "AltSvc"); + // a special value that we return back while parsing the header values, + // indicate that all existing alternate services for a origin need to be cleared + private static final List CLEAR_ALL_ALT_SVCS = List.of(); + // whether or not alt service can be created from "localhost" origin host + private static final boolean allowLocalHostOrigin = ALTSVC_ALLOW_LOCAL_HOST_ORIGIN; + + private static final SNIHostName LOCALHOST_SNI = new SNIHostName("localhost"); + + private record ParsedHeaderValue(String rawValue, String alpnName, String host, int port, + Map parameters) { + } + + private AltSvcProcessor() { + throw new UnsupportedOperationException("Instantiation not supported"); + } + + + /** + * Parses the alt-svc header received from origin and update + * registry with the processed values. + * + * @param response response passed on by the server + * @param client client that holds alt-svc registry + * @param request request that holds the origin details + */ + static void processAltSvcHeader(Response response, HttpClientImpl client, + HttpRequestImpl request) { + + // we don't support AltSvc from unsecure origins + if (!request.secure()) { + return; + } + if (response.statusCode == 421) { + // As per AltSvc spec (RFC-7838), section 6: + // An Alt-Svc header field in a 421 (Misdirected Request) response MUST be ignored. + return; + } + final var altSvcHeaderVal = response.headers().firstValue(HEADER); + if (altSvcHeaderVal.isEmpty()) { + return; + } + if (debug.on()) { + debug.log("Processing alt-svc header"); + } + final HttpConnection conn = response.exchange.exchImpl.connection(); + final List sniServerNames = getSNIServerNames(conn); + if (sniServerNames.isEmpty()) { + // we don't trust the alt-svc advertisement if the connection over which it + // was advertised didn't use SNI during TLS handshake while establishing the connection + if (debug.on()) { + debug.log("ignoring alt-svc header because connection %s didn't use SNI during" + + " connection establishment", conn); + } + return; + } + final Origin origin; + try { + origin = Origin.from(request.uri()); + } catch (IllegalArgumentException iae) { + if (debug.on()) { + debug.log("ignoring alt-svc header due to: " + iae); + } + // ignore the alt-svc + return; + } + String altSvcValue = altSvcHeaderVal.get(); + processValueAndUpdateRegistry(client, origin, altSvcValue); + } + + static void processAltSvcFrame(final int streamId, + final AltSvcFrame frame, + final HttpConnection conn, + final HttpClientImpl client) { + final String value = frame.getAltSvcValue(); + if (value == null || value.isBlank()) { + return; + } + if (!conn.isSecure()) { + // don't support alt svc from unsecure origins + return; + } + final List sniServerNames = getSNIServerNames(conn); + if (sniServerNames.isEmpty()) { + // we don't trust the alt-svc advertisement if the connection over which it + // was advertised didn't use SNI during TLS handshake while establishing the connection + if (debug.on()) { + debug.log("ignoring altSvc frame because connection %s didn't use SNI during" + + " connection establishment", conn); + } + return; + } + debug.log("processing AltSvcFrame %s", value); + final Origin origin; + if (streamId == 0) { + // section 4, RFC-7838 - alt-svc frame on stream 0 with empty (zero length) origin + // is invalid and MUST be ignored + if (frame.getOrigin().isBlank()) { + // invalid frame, ignore it + debug.log("Ignoring invalid alt-svc frame on stream 0 of " + conn); + return; + } + // parse origin from frame.getOrigin() string which is in ASCII + // serialized form of an origin (defined in section 6.2 of RFC-6454) + final Origin parsedOrigin; + try { + parsedOrigin = Origin.fromASCIISerializedForm(frame.getOrigin()); + } catch (IllegalArgumentException iae) { + // invalid origin value, ignore the frame + debug.log("origin couldn't be parsed, ignoring invalid alt-svc frame" + + " on stream " + streamId + " of " + conn); + return; + } + // currently we do not allow an alt service to be advertised for a different origin. + // if the origin advertised in the alt-svc frame doesn't match the origin of the + // connection, then we ignore it. the RFC allows us to do that: + // RFC-7838, section 4: + // An ALTSVC frame from a server to a client on stream 0 indicates that + // the conveyed alternative service is associated with the origin + // contained in the Origin field of the frame. An association with an + // origin that the client does not consider authoritative for the + // current connection MUST be ignored. + if (!parsedOrigin.equals(conn.getOriginServer())) { + debug.log("ignoring alt-svc frame on stream 0 for origin: " + parsedOrigin + + " received on connection of origin: " + conn.getOriginServer()); + return; + } + origin = parsedOrigin; + } else { + // (section 4, RFC-7838) - for non-zero stream id, the alt-svc is for the origin of + // the stream. Additionally, an ALTSVC frame on a stream other than stream 0 containing + // non-empty "Origin" information is invalid and MUST be ignored. + if (!frame.getOrigin().isEmpty()) { + // invalid frame, ignore it + debug.log("non-empty origin in alt-svc frame on stream " + streamId + " of " + + conn + ", ignoring alt-svc frame"); + return; + } + origin = conn.getOriginServer(); + assert origin != null : "origin server is null on connection: " + conn; + } + processValueAndUpdateRegistry(client, origin, value); + } + + private static void processValueAndUpdateRegistry(HttpClientImpl client, + Origin origin, + String altSvcValue) { + final List altServices = processHeaderValue(origin, altSvcValue); + // intentional identity check + if (altServices == CLEAR_ALL_ALT_SVCS) { + // clear all existing alt services for this origin + debug.log("clearing AltServiceRegistry for " + origin); + client.registry().clear(origin); + return; + } + debug.log("AltServices: %s", altServices); + if (altServices.isEmpty()) { + return; + } + // AltService RFC-7838, section 3.1 states: + // + // When an Alt-Svc response header field is received from an origin, its + // value invalidates and replaces all cached alternative services for + // that origin. + client.registry().replace(origin, altServices); + } + + static List getSNIServerNames(final HttpConnection conn) { + final List sniServerNames = conn.getSNIServerNames(); + if (sniServerNames != null && !sniServerNames.isEmpty()) { + return sniServerNames; + } + // no SNI server name(s) were used when establishing this connection. check if + // this connection is to a loopback address and if it is then see if a configuration + // has been set to allow alt services advertised by loopback addresses to be trusted/accepted. + // if such a configuration has been set, then we return a SNIHostName for "localhost" + final InetSocketAddress addr = conn.address(); + final boolean isLoopbackAddr = addr.isUnresolved() + ? false + : conn.address.getAddress().isLoopbackAddress(); + if (!isLoopbackAddr) { + return List.of(); // no SNI server name(s) used for this connection + } + if (!allowLocalHostOrigin) { + // this is a connection to a loopback address, with no SNI server name(s) used + // during TLS handshake and the configuration doesn't allow accepting/trusting + // alt services from loopback address, so we return no SNI server name(s) for this + // connection + return List.of(); + } + // at this point, we have identified this as a loopback address and the configuration + // has been set to accept/trust alt services from loopback address, so we return a + // SNIHostname corresponding to "localhost" + return List.of(LOCALHOST_SNI); + } + + // Here are five examples of values for the Alt-Svc header: + // String svc1 = """foo=":443"; ma=2592000; persist=1""" + // String svc2 = """h3="localhost:5678""""; + // String svc3 = """bar3=":446"; ma=2592000; persist=1"""; + // String svc4 = """h3-34=":5678"; ma=2592000; persist=1"""; + // String svc5 = "%s, %s, %s, %s".formatted(svc1, svc2, svc3, svc4); + // The last one (svc5) should result in two services being registered: + // AltService[origin=https://localhost:64077/, alpn=h3, endpoint=localhost/127.0.0.1:5678, + // deadline=2021-03-13T01:41:01.369488Z, persist=false] + // AltService[origin=https://localhost:64077/, alpn=h3-34, endpoint=localhost/127.0.0.1:5678, + // deadline=2021-04-11T01:41:01.369912Z, persist=true] + private static List processHeaderValue(final Origin origin, + final String headerValue) { + final List altServices = new ArrayList<>(); + // multiple alternate services can be specified with comma as a delimiter + final var altSvcs = headerValue.split(","); + for (var altSvc : altSvcs) { + altSvc = altSvc.trim(); + + // each value is expected to be of the following form, as noted in RFC-7838, section 3 + // Alt-Svc = clear / 1#alt-value + // clear = %s"clear"; "clear", case-sensitive + // alt-value = alternative *( OWS ";" OWS parameter ) + // alternative = protocol-id "=" alt-authority + // protocol-id = token ; percent-encoded ALPN protocol name + // alt-authority = quoted-string ; containing [ uri-host ] ":" port + // parameter = token "=" ( token / quoted-string ) + + // As per the spec, the value "clear" is expected to be case-sensitive + if (altSvc.equals("clear")) { + return CLEAR_ALL_ALT_SVCS; + } + final ParsedHeaderValue parsed = parseAltValue(origin, altSvc); + if (parsed == null) { + // this implies the alt-svc header value couldn't be parsed and thus is malformed. + // we skip such header values. + debug.log("skipping %s", altSvc); + continue; + } + final var deadline = getValidTill(parsed.parameters().get("ma")); + final var persist = getPersist(parsed.parameters().get("persist")); + final AltService.Identity altSvcId = new AltService.Identity(parsed.alpnName(), + parsed.host(), parsed.port()); + AltService.create(altSvcId, origin, deadline, persist) + .ifPresent((altsvc) -> { + altServices.add(altsvc); + if (Log.altsvc()) { + final var s = altsvc; + Log.logAltSvc("Created AltService: {0}", s); + } else if (debug.on()) { + debug.log("Created AltService for id=%s, origin=%s%n", altSvcId, origin); + } + }); + } + return altServices; + } + + private static ParsedHeaderValue parseAltValue(final Origin origin, final String altValue) { + // header value is expected to be of the following form, as noted in RFC-7838, section 3 + // Alt-Svc = clear / 1#alt-value + // clear = %s"clear"; "clear", case-sensitive + // alt-value = alternative *( OWS ";" OWS parameter ) + // alternative = protocol-id "=" alt-authority + // protocol-id = token ; percent-encoded ALPN protocol name + // alt-authority = quoted-string ; containing [ uri-host ] ":" port + // parameter = token "=" ( token / quoted-string ) + + // find the = sign that separates the protocol-id and alt-authority + debug.log("parsing %s", altValue); + final int alternativeDelimIndex = altValue.indexOf("="); + if (alternativeDelimIndex == -1 || alternativeDelimIndex == altValue.length() - 1) { + // not a valid alt value + debug.log("no \"=\" character in %s", altValue); + return null; + } + // key is always the protocol-id. example, in 'h3="localhost:5678"; ma=23232; persist=1' + // "h3" acts as the key with '"localhost:5678"; ma=23232; persist=1' as the value + final String protocolId = altValue.substring(0, alternativeDelimIndex); + // the protocol-id can be percent encoded as per the spec, so we decode it to get the alpn name + final var alpnName = decodePotentialPercentEncoded(protocolId); + debug.log("alpn is %s in %s", alpnName, altValue); + if (!isSecureALPNName(alpnName)) { + // no reasonable assurance that the alternate service will be under the control + // of the origin (section 2.1, RFC-7838) + debug.log("alpn %s is not secure, skipping", alpnName); + return null; + } + String remaining = altValue.substring(alternativeDelimIndex + 1); + // now parse alt-authority + if (!remaining.startsWith("\"") || remaining.length() == 1) { + // we expect a quoted string for alt-authority + debug.log("no quoted authority in %s", altValue); + return null; + } + remaining = remaining.substring(1); // skip the starting double quote + final int nextDoubleQuoteIndex = remaining.indexOf("\""); + if (nextDoubleQuoteIndex == -1) { + // malformed value + debug.log("missing closing quote in %s", altValue); + return null; + } + final String altAuthority = remaining.substring(0, nextDoubleQuoteIndex); + final HostPort hostPort = getHostPort(origin, altAuthority); + if (hostPort == null) return null; // host port could not be parsed + if (nextDoubleQuoteIndex == remaining.length() - 1) { + // there's nothing more left to parse + return new ParsedHeaderValue(altValue, alpnName, hostPort.host(), hostPort.port(), Map.of()); + } + // parse the semicolon delimited parameters out of the rest of the remaining string + remaining = remaining.substring(nextDoubleQuoteIndex + 1); + final Map parameters = extractParameters(remaining); + return new ParsedHeaderValue(altValue, alpnName, hostPort.host(), hostPort.port(), parameters); + } + + private static String decodePotentialPercentEncoded(final String val) { + if (!val.contains("%")) { + return val; + } + // TODO: impl this + // In practice this method is only used for the ALPN. + // We only support h3 for now, so we do not need to + // decode percents: anything else but h3 will eventually be ignored. + return val; + } + + private static Map extractParameters(final String val) { + // As per the spec, parameters take the form of: + // *( OWS ";" OWS parameter ) + // ... + // parameter = token "=" ( token / quoted-string ) + // + // where * represents "any number of" and OWS means "optional whitespace" + + final var tokenizer = new StringTokenizer(val, ";"); + if (!tokenizer.hasMoreTokens()) { + return Map.of(); + } + Map parameters = null; + while (tokenizer.hasMoreTokens()) { + final var parameter = tokenizer.nextToken().trim(); + if (parameter.isEmpty()) { + continue; + } + final var equalSignIndex = parameter.indexOf('='); + if (equalSignIndex == -1 || equalSignIndex == parameter.length() - 1) { + // a parameter is expected to have a "=" delimiter which separates a key and a value. + // we skip parameters which don't conform to that rule + continue; + } + final var paramKey = parameter.substring(0, equalSignIndex); + final var paramValue = parameter.substring(equalSignIndex + 1); + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(paramKey, paramValue); + } + if (parameters == null) { + return Map.of(); + } + return Collections.unmodifiableMap(parameters); + } + + private record HostPort(String host, int port) {} + + private static HostPort getHostPort(Origin origin, String altAuthority) { + // The AltService spec defines an alt-authority as follows: + // + // alt-authority = quoted-string ; containing [ uri-host ] ":" port + // + // When this method is called the passed altAuthority is already stripped of the leading and trailing + // double-quotes. The value will this be of the form [uri-host]:port where uri-host is optional. + String host; int port; + try { + // Use URI to do the parsing, with a special case for optional host + URI uri = new URI("http://" + altAuthority + "/"); + host = uri.getHost(); + port = uri.getPort(); + if (host == null && port == -1) { + var auth = uri.getRawAuthority(); + if (auth.isEmpty()) return null; + if (auth.charAt(0) == ':') { + uri = new URI("http://x" + altAuthority + "/"); + if ("x".equals(uri.getHost())) { + port = uri.getPort(); + } + } + } + if (port == -1) { + debug.log("Can't parse authority: " + altAuthority); + return null; + } + String hostport; + if (host == null || host.isEmpty()) { + hostport = ":" + port; + host = origin.host(); + } else { + hostport = host + ":" + port; + } + // reject anything unexpected. altAuthority should match hostport + if (!hostport.equals(altAuthority)) { + debug.log("Authority \"%s\" doesn't match host:port \"%s\"", + altAuthority, hostport); + return null; + } + } catch (URISyntaxException x) { + debug.log("Failed to parse authority: %s - %s", + altAuthority, x); + return null; + } + return new HostPort(host, port); + } + + private static Deadline getValidTill(final String maxAge) { + // There's a detailed algorithm in RFC-7234 section 4.2.3, for calculating the age. This + // RFC section is referenced from the alternate service RFC-7838 section 3.1. + // For now though, we use "now" as the instant against which the age will be applied. + final Deadline responseGenerationInstant = TimeSource.now(); + // default max age as per AltService RFC-7838, section 3.1 is 24 hours + final long defaultMaxAgeInSecs = 3600 * 24; + if (maxAge == null) { + return responseGenerationInstant.plusSeconds(defaultMaxAgeInSecs); + } + try { + final long seconds = Long.parseLong(maxAge); + // negative values aren't allowed for max-age as per RFC-7234, section 1.2.1 + return seconds < 0 ? responseGenerationInstant.plusSeconds(defaultMaxAgeInSecs) + : responseGenerationInstant.plusSeconds(seconds); + } catch (NumberFormatException nfe) { + return responseGenerationInstant.plusSeconds(defaultMaxAgeInSecs); + } + } + + private static boolean getPersist(final String persist) { + // AltService RFC-7838, section 3.1, states: + // + // This specification only defines a single value for "persist". + // Clients MUST ignore "persist" parameters with values other than "1". + // + return "1".equals(persist); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Exchange.java b/src/java.net.http/share/classes/jdk/internal/net/http/Exchange.java index 1ee54ed2bef..c50a4922e80 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Exchange.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Exchange.java @@ -27,6 +27,7 @@ package jdk.internal.net.http; import java.io.IOException; import java.net.ProtocolException; +import java.net.http.HttpClient.Version; import java.time.Duration; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -62,6 +63,7 @@ final class Exchange { volatile ExchangeImpl exchImpl; volatile CompletableFuture> exchangeCF; volatile CompletableFuture bodyIgnored; + volatile boolean streamLimitReached; // used to record possible cancellation raised before the exchImpl // has been established. @@ -74,11 +76,18 @@ final class Exchange { final String dbgTag; // Keeps track of the underlying connection when establishing an HTTP/2 - // exchange so that it can be aborted/timed out mid setup. + // or HTTP/3 exchange so that it can be aborted/timed out mid-setup. final ConnectionAborter connectionAborter = new ConnectionAborter(); final AtomicInteger nonFinalResponses = new AtomicInteger(); + // This will be set to true only when it is guaranteed that the server hasn't processed + // the request. Typically, this happens when the server explicitly states (through a GOAWAY frame + // or a relevant error code in reset frame) that the corresponding stream (id) wasn't processed. + // However, there can be cases where the client is certain that the request wasn't sent + // to the server (and thus not processed). In such cases, the client can set this to true. + private volatile boolean unprocessedByPeer; + Exchange(HttpRequestImpl request, MultiExchange multi) { this.request = request; this.upgrading = false; @@ -110,9 +119,13 @@ final class Exchange { } // Keeps track of the underlying connection when establishing an HTTP/2 - // exchange so that it can be aborted/timed out mid setup. - static final class ConnectionAborter { + // or HTTP/3 exchange so that it can be aborted/timed out mid setup. + final class ConnectionAborter { + // In case of HTTP/3 requests we may have + // two connections in parallel: a regular TCP connection + // and a QUIC connection. private volatile HttpConnection connection; + private volatile HttpQuicConnection quicConnection; private volatile boolean closeRequested; private volatile Throwable cause; @@ -123,10 +136,11 @@ final class Exchange { // closed closeRequested = this.closeRequested; if (!closeRequested) { - this.connection = connection; - } else { - // assert this.connection == null - this.closeRequested = false; + if (connection instanceof HttpQuicConnection quicConnection) { + this.quicConnection = quicConnection; + } else { + this.connection = connection; + } } } if (closeRequested) closeConnection(connection, cause); @@ -134,6 +148,7 @@ final class Exchange { void closeConnection(Throwable error) { HttpConnection connection; + HttpQuicConnection quicConnection; Throwable cause; synchronized (this) { cause = this.cause; @@ -141,39 +156,64 @@ final class Exchange { cause = error; } connection = this.connection; - if (connection == null) { + quicConnection = this.quicConnection; + if (connection == null || quicConnection == null) { closeRequested = true; this.cause = cause; } else { + this.quicConnection = null; this.connection = null; this.cause = null; } } closeConnection(connection, cause); + closeConnection(quicConnection, cause); } + // Called by HTTP/2 after an upgrade. + // There is no upgrade for HTTP/3 HttpConnection disable() { HttpConnection connection; synchronized (this) { connection = this.connection; this.connection = null; + this.quicConnection = null; this.closeRequested = false; this.cause = null; } return connection; } - private static void closeConnection(HttpConnection connection, Throwable cause) { - if (connection != null) { - try { - connection.close(cause); - } catch (Throwable t) { - // ignore + void clear(HttpConnection connection) { + synchronized (this) { + var c = this.connection; + if (connection == c) this.connection = null; + var qc = this.quicConnection; + if (connection == qc) this.quicConnection = null; + } + } + + private void closeConnection(HttpConnection connection, Throwable cause) { + if (connection == null) { + return; + } + try { + connection.close(cause); + } catch (Throwable t) { + // ignore + if (debug.on()) { + debug.log("ignoring exception that occurred during closing of connection: " + + connection, t); } } } } + // true if previous attempt resulted in streamLimitReached + public boolean hasReachedStreamLimit() { return streamLimitReached; } + // can be used to set or clear streamLimitReached (for instance clear it after retrying) + void streamLimitReached(boolean streamLimitReached) { this.streamLimitReached = streamLimitReached; } + // Called for 204 response - when no body is permitted // This is actually only needed for HTTP/1.1 in order // to return the connection to the pool (or close it) @@ -253,7 +293,7 @@ final class Exchange { impl.cancel(cause); } else { // abort/close the connection if setting up the exchange. This can - // be important when setting up HTTP/2 + // be important when setting up HTTP/2 or HTTP/3 closeReason = failed.get(); if (closeReason != null) { connectionAborter.closeConnection(closeReason); @@ -283,6 +323,9 @@ final class Exchange { cf = exchangeCF; } } + if (multi.requestCancelled() && impl != null && cause == null) { + cause = new IOException("Request cancelled"); + } if (cause == null) return; if (impl != null) { // The exception is raised by propagating it to the impl. @@ -314,7 +357,7 @@ final class Exchange { // if upgraded, we don't close the connection. // cancelling will be handled by the HTTP/2 exchange // in its own time. - if (!upgraded) { + if (!upgraded && !(connection instanceof HttpQuicConnection)) { t = getCancelCause(); if (t == null) t = new IOException("Request cancelled"); if (debug.on()) debug.log("exchange cancelled during connect: " + t); @@ -350,8 +393,8 @@ final class Exchange { private CompletableFuture> establishExchange(HttpConnection connection) { if (debug.on()) { - debug.log("establishing exchange for %s,%n\t proxy=%s", - request, request.proxy()); + debug.log("establishing exchange for %s #%s,%n\t proxy=%s", + request, multi.id, request.proxy()); } // check if we have been cancelled first. Throwable t = getCancelCause(); @@ -364,7 +407,17 @@ final class Exchange { } CompletableFuture> cf, res; - cf = ExchangeImpl.get(this, connection); + + cf = ExchangeImpl.get(this, connection) + // set exchImpl and call checkCancelled to make sure exchImpl + // gets cancelled even if the exchangeCf was completed exceptionally + // before the CF returned by ExchangeImpl.get completed. This deals + // with issues when the request is cancelled while the exchange impl + // is being created. + .thenApply((eimpl) -> { + synchronized (Exchange.this) {exchImpl = eimpl;} + checkCancelled(); return eimpl; + }).copy(); // We should probably use a VarHandle to get/set exchangeCF // instead - as we need CAS semantics. synchronized (this) { exchangeCF = cf; }; @@ -390,7 +443,7 @@ final class Exchange { } // Completed HttpResponse will be null if response succeeded - // will be a non null responseAsync if expect continue returns an error + // will be a non-null responseAsync if expect continue returns an error public CompletableFuture responseAsync() { return responseAsyncImpl(null); @@ -715,4 +768,13 @@ final class Exchange { String dbgString() { return dbgTag; } + + final boolean isUnprocessedByPeer() { + return this.unprocessedByPeer; + } + + // Marks the exchange as unprocessed by the peer + final void markUnprocessedByPeer() { + this.unprocessedByPeer = true; + } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/ExchangeImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/ExchangeImpl.java index f393b021cd4..74600e78557 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/ExchangeImpl.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/ExchangeImpl.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,17 +26,26 @@ package jdk.internal.net.http; import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.http.HttpConnectTimeoutException; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.ResponseInfo; +import java.net.http.UnsupportedProtocolVersionException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.function.Supplier; +import jdk.internal.net.http.Http2Connection.ALPNException; import jdk.internal.net.http.common.HttpBodySubscriberWrapper; import jdk.internal.net.http.common.Logger; import jdk.internal.net.http.common.MinimalFuture; import jdk.internal.net.http.common.Utils; import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; /** * Splits request so that headers and body can be sent separately with optional @@ -60,10 +69,6 @@ abstract class ExchangeImpl { private volatile boolean expectTimeoutRaised; - // this will be set to true only when the peer explicitly states (through a GOAWAY frame or - // a relevant error code in reset frame) that the corresponding stream (id) wasn't processed - private volatile boolean unprocessedByPeer; - ExchangeImpl(Exchange e) { // e == null means a http/2 pushed stream this.exchange = e; @@ -98,23 +103,414 @@ abstract class ExchangeImpl { static CompletableFuture> get(Exchange exchange, HttpConnection connection) { - if (exchange.version() == HTTP_1_1) { + HttpRequestImpl request = exchange.request(); + var version = exchange.version(); + if (version == HTTP_1_1 || request.isWebSocket()) { if (debug.on()) debug.log("get: HTTP/1.1: new Http1Exchange"); return createHttp1Exchange(exchange, connection); - } else { - Http2ClientImpl c2 = exchange.client().client2(); // #### improve - HttpRequestImpl request = exchange.request(); - CompletableFuture c2f = c2.getConnectionFor(request, exchange); + } else if (!request.secure() && request.isHttp3Only(version)) { + assert version == HTTP_3; + assert !request.isWebSocket(); if (debug.on()) - debug.log("get: Trying to get HTTP/2 connection"); - // local variable required here; see JDK-8223553 - CompletableFuture>> fxi = - c2f.handle((h2c, t) -> createExchangeImpl(h2c, t, exchange, connection)); - return fxi.thenCompose(x->x); + debug.log("get: HTTP/3: HTTP/3 is not supported on plain connections"); + return MinimalFuture.failedFuture( + new UnsupportedProtocolVersionException( + "HTTP/3 is not supported on plain connections")); + } else if (version == HTTP_2 || isTCP(connection) || !request.secure()) { + assert !request.isWebSocket(); + return attemptHttp2Exchange(exchange, connection); + } else { + assert request.secure(); + assert version == HTTP_3; + assert !request.isWebSocket(); + return attemptHttp3Exchange(exchange, connection); } } + private static boolean isTCP(HttpConnection connection) { + if (connection instanceof HttpQuicConnection) return false; + if (connection == null) return false; + // if it's not an HttpQuicConnection and it's not null it's + // a TCP connection + return true; + } + + private static CompletableFuture> + attemptHttp2Exchange(Exchange exchange, HttpConnection connection) { + HttpRequestImpl request = exchange.request(); + Http2ClientImpl c2 = exchange.client().client2(); // #### improve + CompletableFuture c2f = c2.getConnectionFor(request, exchange); + if (debug.on()) + debug.log("get: Trying to get HTTP/2 connection"); + // local variable required here; see JDK-8223553 + CompletableFuture>> fxi = + c2f.handle((h2c, t) -> createExchangeImpl(h2c, t, exchange, connection)); + return fxi.thenCompose(x -> x); + } + + private static CompletableFuture> + attemptHttp3Exchange(Exchange exchange, HttpConnection connection) { + HttpRequestImpl request = exchange.request(); + var exchvers = exchange.version(); + assert request.secure() : request.uri() + " is not secure"; + assert exchvers == HTTP_3 : "expected HTTP/3, got " + exchvers; + // when we reach here, it's guaranteed that the client supports HTTP3 + assert exchange.client().client3().isPresent() : "HTTP3 isn't supported by the client"; + var client3 = exchange.client().client3().get(); + CompletableFuture c3f; + Supplier> c2fs; + var config = request.http3Discovery(); + + if (debug.on()) { + debug.log("get: Trying to get HTTP/3 connection; config is %s", config); + } + // The algorithm here depends on whether HTTP/3 is specified on + // the request itself, or on the HttpClient. + // In both cases, we may attempt a direct HTTP/3 connection if + // we don't have an H3 endpoint registered in the AltServicesRegistry. + // However, if HTTP/3 is not specified explicitly on the request, + // we will start both an HTTP/2 and an HTTP/3 connection at the + // same time, and use the one that complete first. If HTTP/3 is + // specified on the request, we will give priority to HTTP/3 ond + // only start the HTTP/2 connection if the HTTP/3 connection fails, + // or doesn't succeed in the imparted timeout. The timeout can be + // specified with the property "jdk.httpclient.http3.maxDirectConnectionTimeout". + // If unspecified it defaults to 2750ms. + // + // Because the HTTP/2 connection may start as soon as we create the + // CompletableFuture returned by the Http2Client, + // we are using a Supplier> to + // set up the call chain that would start the HTTP/2 connection. + try { + // first look to see if we already have an HTTP/3 connection in + // the pool. If we find one, we're almost done! We won't need + // to start any HTTP/2 connection. + Http3Connection pooled = client3.findPooledConnectionFor(request, exchange); + if (pooled != null) { + c3f = MinimalFuture.completedFuture(pooled); + c2fs = null; + } else { + if (debug.on()) + debug.log("get: no HTTP/3 pooled connection found"); + // possibly start an HTTP/3 connection + boolean mayAttemptDirectConnection = client3.mayAttemptDirectConnection(request); + c3f = client3.getConnectionFor(request, exchange); + if ((!c3f.isDone() || c3f.isCompletedExceptionally()) && mayAttemptDirectConnection) { + // We don't know if the server supports HTTP/3. + // happy eyeball: prepare to try both HTTP/3 and HTTP/2 and + // to use the first that succeeds + if (config != Http3DiscoveryMode.HTTP_3_URI_ONLY) { + if (debug.on()) { + debug.log("get: trying with both HTTP/3 and HTTP/2"); + } + Http2ClientImpl client2 = exchange.client().client2(); + c2fs = () -> client2.getConnectionFor(request, exchange); + } else { + if (debug.on()) { + debug.log("get: trying with HTTP/3 only"); + } + c2fs = null; + } + } else { + // We have a completed Http3Connection future. + // No need to attempt direct HTTP/3 connection. + c2fs = null; + } + } + } catch (IOException io) { + return MinimalFuture.failedFuture(io); + } + if (c2fs == null) { + // Do not attempt a happy eyeball: go the normal route to + // attempt an HTTP/3 connection + // local variable required here; see JDK-8223553 + if (debug.on()) debug.log("No HTTP/3 eyeball needed"); + CompletableFuture>> fxi = + c3f.handle((h3c, t) -> createExchangeImpl(h3c, t, exchange, connection)); + return fxi.thenCompose(x->x); + } else if (request.version().orElse(null) == HTTP_3) { + // explicit request to use HTTP/3, only use HTTP/2 if HTTP/3 fails, but + // still start both connections in parallel. HttpQuicConnection will + // attempt a direct connection. Because we register + // firstToComplete as a dependent action of c3f we will actually + // only use HTTP/2 (or HTTP/1.1) if HTTP/3 failed + CompletableFuture>> fxi = + c3f.handle((h3c, e) -> firstToComplete(exchange, connection, c2fs, c3f)); + if (debug.on()) { + debug.log("Explicit HTTP/3 request: " + + "attempt HTTP/3 first, then default to HTTP/2"); + } + return fxi.thenCompose(x->x); + } + if (debug.on()) { + debug.log("Attempt HTTP/3 and HTTP/2 in parallel, use the first that connects"); + } + // default client version is HTTP/3 - request version is not set. + // so try HTTP/3 + HTTP/2 in parallel and take the first that completes. + return firstToComplete(exchange, connection, c2fs, c3f); + } + + // Use the first connection that successfully completes. + // This is a bit hairy because HTTP/2 may be downgraded to HTTP/1 if the server + // doesn't support HTTP/2. In which case the connection attempt will succeed but + // c2f will be completed with a ALPNException. + private static CompletableFuture> firstToComplete( + Exchange exchange, + HttpConnection connection, + Supplier> c2fs, + CompletableFuture c3f) { + if (debug.on()) { + debug.log("firstToComplete(connection=%s)", connection); + debug.log("Will use the first connection that succeeds from HTTP/2 or HTTP/3"); + } + assert connection == null : "should not come here if connection is not null: " + connection; + + // Set up a completable future (cf) that will complete + // when the first HTTP/3 or HTTP/2 connection result is + // available. Error cases (when the result is exceptional) + // is handled in a dependent action of cf later below + final CompletableFuture cf; + // c3f is used for HTTP/3, c2f for HTTP/2 + final CompletableFuture c2f; + if (c3f.isDone()) { + // We already have a result for HTTP/3, consider that first; + // There's no need to start HTTP/2 yet if the result is successful. + c2f = null; + cf = c3f; + } else { + // No result for HTTP/3 yet, start HTTP/2 now and wait for the + // first that completes. + c2f = c2fs.get(); + cf = CompletableFuture.anyOf(c2f, c3f); + } + + CompletableFuture>> cfxi = cf.handle((r, t) -> { + if (debug.on()) { + debug.log("Checking which from HTTP/2 or HTTP/3 succeeded first"); + } + CompletableFuture> res; + // first check if c3f is completed successfully + if (c3f.isDone()) { + Http3Connection h3c = c3f.exceptionally((e) -> null).resultNow(); + if (h3c != null) { + // HTTP/3 success! Use HTTP/3 + if (debug.on()) { + debug.log("HTTP/3 connect completed first, using HTTP/3"); + } + res = createExchangeImpl(h3c, null, exchange, connection); + if (c2f != null) c2f.thenApply(c -> { + if (c != null) { + c.abandonStream(); + } + return c; + }); + } else { + // HTTP/3 failed! Use HTTP/2 + if (debug.on()) { + debug.log("HTTP/3 connect completed unsuccessfully," + + " either with null or with exception - waiting for HTTP/2"); + c3f.handle((r3, t3) -> { + debug.log("\tcf3: result=%s, throwable=%s", + r3, Utils.getCompletionCause(t3)); + return r3; + }).exceptionally((e) -> null).join(); + } + // c2f may be null here in the case where c3f was already completed + // when firstToComplete was called. + var h2cf = c2f == null ? c2fs.get() : c2f; + // local variable required here; see JDK-8223553 + CompletableFuture>> fxi = h2cf + .handle((h2c, e) -> createExchangeImpl(h2c, e, exchange, connection)); + res = fxi.thenCompose(x -> x); + } + } else if (c2f != null && c2f.isDone()) { + Http2Connection h2c = c2f.exceptionally((e) -> null).resultNow(); + if (h2c != null) { + // HTTP/2 succeeded first! Use it. + if (debug.on()) { + debug.log("HTTP/2 connect completed first, using HTTP/2"); + } + res = createExchangeImpl(h2c, null, exchange, connection); + } else if (exchange.multi.requestCancelled()) { + // special case for when the exchange is cancelled + if (debug.on()) { + debug.log("HTTP/2 connect completed unsuccessfully, but request cancelled"); + } + CompletableFuture>> fxi = c2f + .handle((c, e) -> createExchangeImpl(c, e, exchange, connection)); + res = fxi.thenCompose(x -> x); + } else { + if (debug.on()) { + debug.log("HTTP/2 connect completed unsuccessfully," + + " either with null or with exception"); + c2f.handle((r2, t2) -> { + debug.log("\tcf2: result=%s, throwable=%s", + r2, Utils.getCompletionCause(t2)); + return r2; + }).exceptionally((e) -> null).join(); + } + + // Now is the more complex stuff. + // HTTP/2 could have failed in the ALPN, but we still + // created a valid TLS connection to the server => default + // to HTTP/1.1 over TLS + HttpConnection http1Connection = null; + if (c2f.isCompletedExceptionally() && !c2f.isCancelled()) { + Throwable cause = Utils.getCompletionCause(c2f.exceptionNow()); + if (cause instanceof ALPNException alpn) { + debug.log("HTTP/2 downgraded to HTTP/1.1 - use HTTP/1.1"); + http1Connection = alpn.getConnection(); + } + } + if (http1Connection != null) { + if (debug.on()) { + debug.log("HTTP/1.1 connect completed first, using HTTP/1.1"); + } + // ALPN failed - but we have a valid HTTP/1.1 connection + // to the server: use that. + res = createHttp1Exchange(exchange, http1Connection); + } else { + if (c2f.isCompletedExceptionally()) { + // Wait for HTTP/3 to complete, potentially fallback to + // HTTP/1.1 + // local variable required here; see JDK-8223553 + debug.log("HTTP/2 completed with exception, wait for HTTP/3, " + + "possibly fallback to HTTP/1.1"); + CompletableFuture>> fxi = c3f + .handle((h3c, e) -> fallbackToHttp1OnTimeout(h3c, e, exchange, connection)); + res = fxi.thenCompose(x -> x); + } else { + // + // r2 == null && t2 == null - which means we know the + // server doesn't support h2, and we probably already + // have an HTTP/1.1 connection to it + // + // If an HTTP/1.1 connection is available use it. + // Otherwise, wait for the HTTP/3 to complete, potentially + // fallback to HTTP/1.1 + HttpRequestImpl request = exchange.request(); + InetSocketAddress proxy = Utils.resolveAddress(request.proxy()); + InetSocketAddress addr = request.getAddress(); + ConnectionPool pool = exchange.client().connectionPool(); + // if we have an HTTP/1.1 connection in the pool, use that. + http1Connection = pool.getConnection(true, addr, proxy); + if (http1Connection != null && http1Connection.isOpen()) { + debug.log("Server doesn't support HTTP/2, " + + "but we have an HTTP/1.1 connection in the pool"); + debug.log("Using HTTP/1.1"); + res = createHttp1Exchange(exchange, http1Connection); + } else { + // we don't have anything ready to use in the pool: + // wait for http/3 to complete, possibly falling back + // to HTTP/1.1 + debug.log("Server doesn't support HTTP/2, " + + "and we do not have an HTTP/1.1 connection"); + debug.log("Waiting for HTTP/3, possibly fallback to HTTP/1.1"); + CompletableFuture>> fxi = c3f + .handle((h3c, e) -> fallbackToHttp1OnTimeout(h3c, e, exchange, connection)); + res = fxi.thenCompose(x -> x); + } + } + } + } + } else { + assert c2f != null; + Throwable failed = t != null ? t : new InternalError("cf1 or cf2 should have completed"); + res = MinimalFuture.failedFuture(failed); + } + return res; + }); + return cfxi.thenCompose(x -> x); + } + + private static CompletableFuture> + fallbackToHttp1OnTimeout(Http3Connection c, + Throwable t, + Exchange exchange, + HttpConnection connection) { + if (t != null) { + Throwable cause = Utils.getCompletionCause(t); + if (cause instanceof HttpConnectTimeoutException) { + // when we reach here we already tried with HTTP/2, + // and we most likely have an HTTP/1.1 connection in + // the idle pool. So fallback to that. + if (debug.on()) { + debug.log("HTTP/3 connection timed out: fall back to HTTP/1.1"); + } + return createHttp1Exchange(exchange, null); + } + } + return createExchangeImpl(c, t, exchange, connection); + } + + + + // Creates an HTTP/3 exchange, possibly downgrading to HTTP/2 + private static CompletableFuture> + createExchangeImpl(Http3Connection c, + Throwable t, + Exchange exchange, + HttpConnection connection) { + if (debug.on()) + debug.log("handling HTTP/3 connection creation result"); + if (t == null && exchange.multi.requestCancelled()) { + return MinimalFuture.failedFuture(new IOException("Request cancelled")); + } + if (c == null && t == null) { + if (debug.on()) + debug.log("downgrading to HTTP/2"); + return attemptHttp2Exchange(exchange, connection); + } else if (t != null) { + t = Utils.getCompletionCause(t); + if (debug.on()) { + if (t instanceof HttpConnectTimeoutException || t instanceof ConnectException) { + debug.log("HTTP/3 connection creation failed: " + t); + } else { + debug.log("HTTP/3 connection creation failed " + + "with unexpected exception:", t); + } + } + return MinimalFuture.failedFuture(t); + } else { + if (debug.on()) + debug.log("creating HTTP/3 exchange"); + try { + if (exchange.hasReachedStreamLimit()) { + // clear the flag before attempting to create a stream again + exchange.streamLimitReached(false); + } + return c.createStream(exchange) + .thenApply(ExchangeImpl::checkCancelled); + } catch (IOException e) { + return MinimalFuture.failedFuture(e); + } + } + } + + private static > T checkCancelled(T exchangeImpl) { + Exchange e = exchangeImpl.getExchange(); + if (debug.on()) { + debug.log("checking cancellation for: " + exchangeImpl); + } + if (e.multi.requestCancelled()) { + if (debug.on()) { + debug.log("request was cancelled"); + } + if (!exchangeImpl.isCanceled()) { + if (debug.on()) { + debug.log("cancelling exchange: " + exchangeImpl); + } + var cause = e.getCancelCause(); + if (cause == null) cause = new IOException("Request cancelled"); + exchangeImpl.cancel(cause); + } + } + return exchangeImpl; + } + + + // Creates an HTTP/2 exchange, possibly downgrading to HTTP/1 private static CompletableFuture> createExchangeImpl(Http2Connection c, Throwable t, @@ -280,12 +676,4 @@ abstract class ExchangeImpl { // an Expect-Continue void expectContinueFailed(int rcode) { } - final boolean isUnprocessedByPeer() { - return this.unprocessedByPeer; - } - - // Marks the exchange as unprocessed by the peer - final void markUnprocessedByPeer() { - this.unprocessedByPeer = true; - } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/H3FrameOrderVerifier.java b/src/java.net.http/share/classes/jdk/internal/net/http/H3FrameOrderVerifier.java new file mode 100644 index 00000000000..3eb4d631c75 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/H3FrameOrderVerifier.java @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import jdk.internal.net.http.http3.frames.DataFrame; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.Http3Frame; +import jdk.internal.net.http.http3.frames.Http3FrameType; +import jdk.internal.net.http.http3.frames.MalformedFrame; +import jdk.internal.net.http.http3.frames.PushPromiseFrame; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.http3.frames.UnknownFrame; + +/** + * Verifies that when a HTTP3 frame arrives on a stream, then that particular frame type + * is in the expected order as compared to the previous frame type that was received. + * In effect, does what the RFC-9114, section 4.1 and section 6.2.1 specifies. + * Note that the H3FrameOrderVerifier is only responsible for checking the order in which a + * frame type is received on a stream. It isn't responsible for checking if that particular frame + * type is expected to be received on a particular stream type. + */ +abstract class H3FrameOrderVerifier { + long currentProcessingFrameType = -1; // -1 implies no frame being processed currently + long lastCompletedFrameType = -1; // -1 implies no frame processing has completed yet + + /** + * {@return a frame order verifier for HTTP3 request/response stream} + */ + static H3FrameOrderVerifier newForRequestResponseStream() { + return new ResponseStreamVerifier(false); + } + + /** + * {@return a frame order verifier for HTTP3 push promise stream} + */ + static H3FrameOrderVerifier newForPushPromiseStream() { + return new ResponseStreamVerifier(true); + } + + /** + * {@return a frame order verifier for HTTP3 control stream} + */ + static H3FrameOrderVerifier newForControlStream() { + return new ControlStreamVerifier(); + } + + /** + * @param frame The frame that has been received + * {@return true if the {@code frameType} processing can start. false otherwise} + */ + abstract boolean allowsProcessing(final Http3Frame frame); + + /** + * Marks the receipt of complete content of a frame that was currently being processed + * + * @param frame The frame whose content was fully received + * @throws IllegalStateException If the passed frame type wasn't being currently processed + */ + void completed(final Http3Frame frame) { + if (frame instanceof UnknownFrame) { + return; + } + final long frameType = frame.type(); + if (currentProcessingFrameType != frameType) { + throw new IllegalStateException("Unexpected completion of processing " + + "of frame type (" + frameType + "): " + + Http3FrameType.asString(frameType) + ", expected " + + Http3FrameType.asString(currentProcessingFrameType)); + } + currentProcessingFrameType = -1; + lastCompletedFrameType = frameType; + } + + private static final class ControlStreamVerifier extends H3FrameOrderVerifier { + + @Override + boolean allowsProcessing(final Http3Frame frame) { + if (frame instanceof MalformedFrame) { + // a malformed frame can come in any time, so we allow it to be processed + // and we don't "track" it either + return true; + } + if (frame instanceof UnknownFrame) { + // unknown frames can come in any time, we allow them to be processed + // and we don't track their processing/completion. However, if an unknown frame + // is the first frame on a control stream then that's an error and we return "false" + // to prevent processing that frame. + // RFC-9114, section 9, which states - "where a known frame type is required to be + // in a specific location, such as the SETTINGS frame as the first frame of the + // control stream, an unknown frame type does not satisfy that requirement and + // SHOULD be treated as an error" + return lastCompletedFrameType != -1; + } + final long frameType = frame.type(); + if (currentProcessingFrameType != -1) { + // we are in the middle of processing a particular frame type and we + // only expect additional frames of only that type + return frameType == currentProcessingFrameType; + } + // we are not currently processing any frame + if (lastCompletedFrameType == -1) { + // there was no previous frame either, so this is the first frame to have been + // received + if (frameType != SettingsFrame.TYPE) { + // unexpected first frame type + return false; + } + currentProcessingFrameType = frameType; + // expected first frame type + return true; + } + // there's no specific ordering specified on control stream other than expecting + // the SETTINGS frame to be the first received (which we have already verified before + // reaching here) + currentProcessingFrameType = frameType; + return true; + } + } + + private static final class ResponseStreamVerifier extends H3FrameOrderVerifier { + private boolean headerSeen; + private boolean dataSeen; + private boolean trailerCompleted; + private final boolean pushStream; + + private ResponseStreamVerifier(boolean pushStream) { + this.pushStream = pushStream; + } + + @Override + boolean allowsProcessing(final Http3Frame frame) { + if (frame instanceof MalformedFrame) { + // a malformed frame can come in any time, so we allow it to be processed + // and we don't track their processing/completion + return true; + } + if (frame instanceof UnknownFrame) { + // unknown frames can come in any time, we allow them to be processed + // and we don't track their processing/completion + return true; + } + final long frameType = frame.type(); + if (currentProcessingFrameType != -1) { + // we are in the middle of processing a particular frame type and we + // only expect additional frames of only that type + return frameType == currentProcessingFrameType; + } + if (frameType == DataFrame.TYPE) { + if (!headerSeen || trailerCompleted) { + // DATA is not permitted before HEADERS or after trailer + return false; + } + dataSeen = true; + } else if (frameType == HeadersFrame.TYPE) { + if (trailerCompleted) { + // HEADERS is not permitted after trailer + return false; + } + headerSeen = true; + if (dataSeen) { + trailerCompleted = true; + } + } else if (frameType == PushPromiseFrame.TYPE) { + // a push promise is only permitted on a response, + // and not on a push stream + if (pushStream) { + return false; + } + } else { + // no other frames permitted + return false; + } + + currentProcessingFrameType = frameType; + return true; + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http1Exchange.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http1Exchange.java index ecc4a63c9d0..02ce63b6314 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Http1Exchange.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http1Exchange.java @@ -244,7 +244,7 @@ class Http1Exchange extends ExchangeImpl { this.connection = connection; } else { InetSocketAddress addr = request.getAddress(); - this.connection = HttpConnection.getConnection(addr, client, request, HTTP_1_1); + this.connection = HttpConnection.getConnection(addr, client, exchange, request, HTTP_1_1); } this.requestAction = new Http1Request(request, this); this.asyncReceiver = new Http1AsyncReceiver(executor, this); diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http2ClientImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http2ClientImpl.java index 92a48d901ff..cc8a2a7142b 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Http2ClientImpl.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http2ClientImpl.java @@ -76,7 +76,7 @@ class Http2ClientImpl { /** * When HTTP/2 requested only. The following describes the aggregate behavior including the - * calling code. In all cases, the HTTP2 connection cache + * calling code. In all cases, the HTTP/2 connection cache * is checked first for a suitable connection and that is returned if available. * If not, a new connection is opened, except in https case when a previous negotiate failed. * In that case, we want to continue using http/1.1. When a connection is to be opened and @@ -144,6 +144,7 @@ class Http2ClientImpl { if (conn != null) { try { conn.reserveStream(true, exchange.pushEnabled()); + exchange.connectionAborter.clear(conn.connection); } catch (IOException e) { throw new UncheckedIOException(e); // shouldn't happen } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java index c33cc93e7dd..63889fa6af2 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java @@ -33,6 +33,7 @@ import java.lang.invoke.VarHandle; import java.net.InetSocketAddress; import java.net.ProtocolException; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpHeaders; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -70,6 +71,7 @@ import jdk.internal.net.http.common.SequentialScheduler; import jdk.internal.net.http.common.Utils; import jdk.internal.net.http.common.ValidatingHeadersConsumer; import jdk.internal.net.http.common.ValidatingHeadersConsumer.Context; +import jdk.internal.net.http.frame.AltSvcFrame; import jdk.internal.net.http.frame.ContinuationFrame; import jdk.internal.net.http.frame.DataFrame; import jdk.internal.net.http.frame.ErrorFrame; @@ -90,6 +92,7 @@ import jdk.internal.net.http.hpack.Decoder; import jdk.internal.net.http.hpack.DecodingCallback; import jdk.internal.net.http.hpack.Encoder; import static java.nio.charset.StandardCharsets.UTF_8; +import static jdk.internal.net.http.AltSvcProcessor.processAltSvcFrame; import static jdk.internal.net.http.frame.SettingsFrame.ENABLE_PUSH; import static jdk.internal.net.http.frame.SettingsFrame.HEADER_TABLE_SIZE; import static jdk.internal.net.http.frame.SettingsFrame.INITIAL_CONNECTION_WINDOW_SIZE; @@ -527,6 +530,7 @@ class Http2Connection { AbstractAsyncSSLConnection connection = (AbstractAsyncSSLConnection) HttpConnection.getConnection(request.getAddress(), h2client.client(), + exchange, request, HttpClient.Version.HTTP_2); @@ -635,6 +639,32 @@ class Http2Connection { return true; } + void abandonStream() { + boolean shouldClose = false; + stateLock.lock(); + try { + long reserved = --numReservedClientStreams; + assert reserved >= 0; + if (finalStream && reserved == 0 && streams.isEmpty()) { + shouldClose = true; + } + } catch (Throwable t) { + shutdown(t); // in case the assert fires... + } finally { + stateLock.unlock(); + } + + // We should close the connection here if + // it's not pooled. If it's not pooled it will + // be marked final stream, reserved will be 0 + // after decrementing it by one, and there should + // be no active request-response streams. + if (shouldClose) { + shutdown(new IOException("HTTP/2 connection abandoned")); + } + + } + boolean shouldClose() { stateLock.lock(); try { @@ -1218,6 +1248,8 @@ class Http2Connection { case PingFrame.TYPE -> handlePing((PingFrame) frame); case GoAwayFrame.TYPE -> handleGoAway((GoAwayFrame) frame); case WindowUpdateFrame.TYPE -> handleWindowUpdate((WindowUpdateFrame) frame); + case AltSvcFrame.TYPE -> processAltSvcFrame(0, (AltSvcFrame) frame, + connection, connection.client()); default -> protocolError(ErrorFrame.PROTOCOL_ERROR); } @@ -1323,7 +1355,8 @@ class Http2Connection { try { // idleConnectionTimeoutEvent is always accessed within a lock protected block if (streams.isEmpty() && idleConnectionTimeoutEvent == null) { - idleConnectionTimeoutEvent = client().idleConnectionTimeout() + final HttpClient.Version version = Version.HTTP_2; + idleConnectionTimeoutEvent = client().idleConnectionTimeout(version) .map(IdleConnectionTimeoutEvent::new) .orElse(null); if (idleConnectionTimeoutEvent != null) { @@ -1367,6 +1400,7 @@ class Http2Connection { String protocolError = "protocol error" + (msg == null?"":(": " + msg)); ProtocolException protocolException = new ProtocolException(protocolError); + this.cause.compareAndSet(null, protocolException); if (markHalfClosedLocal()) { framesDecoder.close(protocolError); subscriber.stop(protocolException); @@ -1844,8 +1878,16 @@ class Http2Connection { } finally { Throwable x = errorRef.get(); if (x != null) { - if (debug.on()) debug.log("Stopping scheduler", x); scheduler.stop(); + if (client2.stopping()) { + if (debug.on()) { + debug.log("Stopping scheduler"); + } + } else { + if (debug.on()) { + debug.log("Stopping scheduler", x); + } + } Http2Connection.this.shutdown(x); } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientImpl.java new file mode 100644 index 00000000000..05b27c3d529 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientImpl.java @@ -0,0 +1,844 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.UnsupportedProtocolVersionException; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.AltServicesRegistry.AltService; +import jdk.internal.net.http.common.ConnectionExpiredException; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.http.quic.QuicTransportParameters; +import jdk.internal.net.quic.QuicVersion; +import jdk.internal.net.quic.QuicTLSContext; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static jdk.internal.net.http.Http3ClientProperties.WAIT_FOR_PENDING_CONNECT; +import static jdk.internal.net.http.common.Alpns.H3; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_remote; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_streams_bidi; + +/** + * Http3 specific aspects of HttpClientImpl + */ +final class Http3ClientImpl implements AutoCloseable { + // Setting this property disables HTTPS hostname verification. Use with care. + private static final boolean disableHostnameVerification = Utils.isHostnameVerificationDisabled(); + // QUIC versions in their descending order of preference + private static final List availableQuicVersions; + static { + // we default to QUIC v1 followed by QUIC v2, if no specific preference cannot be + // determined + final List defaultPref = List.of(QuicVersion.QUIC_V1, QuicVersion.QUIC_V2); + // check user specified preference + final String sysPropVal = Utils.getProperty("jdk.httpclient.quic.available.versions"); + if (sysPropVal == null || sysPropVal.isBlank()) { + // default to supporting both v1 and v2, with v1 given preference + availableQuicVersions = defaultPref; + } else { + final List descendingPref = new ArrayList<>(); + for (final String val : sysPropVal.split(",")) { + final QuicVersion qv; + try { + // parse QUIC version number represented as a hex string + final var vernum = Integer.parseInt(val.trim(), 16); + qv = QuicVersion.of(vernum).orElse(null); + } catch (NumberFormatException nfe) { + // ignore and continue with next + continue; + } + if (qv == null) { + continue; + } + descendingPref.add(qv); + } + availableQuicVersions = descendingPref.isEmpty() ? defaultPref : descendingPref; + } + } + + private final Logger debug = Utils.getDebugLogger(this::dbgString); + + final HttpClientImpl client; + private final Http3ConnectionPool connections = new Http3ConnectionPool(debug); + private final Http3PendingConnections reconnections = new Http3PendingConnections(); + private final Set pendingClose = ConcurrentHashMap.newKeySet(); + private final Set noH3 = ConcurrentHashMap.newKeySet(); + + private final QuicClient quicClient; + private volatile boolean closed; + private final AtomicReference errorRef = new AtomicReference<>(); + private final ReentrantLock lock = new ReentrantLock(); + + Http3ClientImpl(HttpClientImpl client) { + this.client = client; + var executor = client.theExecutor().safeDelegate(); + var context = client.theSSLContext(); + var parameters = client.sslParameters(); + if (!disableHostnameVerification) { + // setting the endpoint identification algo to HTTPS ensures that + // during the TLS handshake, the cert presented by the server is verified + // for hostname checks against the SNI hostname(s) set by the client + // or in its absence the peer's hostname. + // see sun.security.ssl.X509TrustManagerImpl#checkIdentity(...) + parameters.setEndpointIdentificationAlgorithm("HTTPS"); + } + final QuicTLSContext quicTLSContext = new QuicTLSContext(context); + final QuicClient.Builder builder = new QuicClient.Builder(); + builder.availableVersions(availableQuicVersions) + .tlsContext(quicTLSContext) + .sslParameters(parameters) + .executor(executor) + .applicationErrors(Http3Error::stringForCode) + .clientId(client.dbgString()); + if (client.localAddress() != null) { + builder.bindAddress(new InetSocketAddress(client.localAddress(), 0)); + } + final QuicTransportParameters transportParameters = new QuicTransportParameters(); + // HTTP/3 doesn't allow remote bidirectional stream + transportParameters.setIntParameter(initial_max_streams_bidi, 0); + // HTTP/3 doesn't allow remote bidirectional stream: no need to allow data + transportParameters.setIntParameter(initial_max_stream_data_bidi_remote, 0); + builder.transportParameters(transportParameters); + this.quicClient = builder.build(); + } + + // Records an exchange waiting for a connection recovery to complete. + // A connection recovery happens when a connection has maxed out its number + // of streams, and no MAX_STREAM frame has arrived. In that case, the connection + // is abandoned (marked with setFinalStream() and taken out of the pool) and a + // new connection is initiated. Waiters are waiting for the new connection + // handshake to finish and for the connection to be put in the pool. + record Waiter(MinimalFuture cf, HttpRequestImpl request, Exchange exchange) { + void complete(Http3Connection conn, Throwable error) { + if (error != null) cf.completeExceptionally(error); + else cf.complete(conn); + } + static Waiter of(HttpRequestImpl request, Exchange exchange) { + return new Waiter(new MinimalFuture<>(), request, exchange); + } + } + + // Indicates that recovery is needed, or in progress, for a given + // connection + sealed interface ConnectionRecovery permits PendingConnection, StreamLimitReached { + } + + // Indicates that recovery of a connection has been initiated. + // Waiters will be put in wait until the handshake is completed + // and the connection is inserted in the pool + record PendingConnection(AltService altSvc, Exchange exchange, ConcurrentLinkedQueue waiters) + implements ConnectionRecovery { + PendingConnection(AltService altSvc, Exchange exchange, ConcurrentLinkedQueue waiters) { + this.altSvc = altSvc; + this.waiters = Objects.requireNonNull(waiters); + this.exchange = exchange; + } + PendingConnection(AltService altSvc, Exchange exchange) { + this(altSvc, exchange, new ConcurrentLinkedQueue<>()); + } + } + + // Indicates that a connection that was in the pool has maxed out + // its stream limit and will be taken out of the pool. A new connection + // will be created for the first request/response exchange that needs + // it. + record StreamLimitReached(Http3Connection connection) implements ConnectionRecovery {} + + // Called when recovery is needed for a given connection, with + // the request that got the StreamLimitException + public void streamLimitReached(Http3Connection connection, HttpRequestImpl request) { + lock.lock(); + try { + reconnections.streamLimitReached(connectionKey(request), connection); + } finally { + lock.unlock(); + } + } + + HttpClientImpl client() { + return client; + } + + String dbgString() { + return "Http3ClientImpl(" + client.dbgString() + ")"; + } + + QuicClient quicClient() { + return this.quicClient; + } + + String connectionKey(HttpRequestImpl request) { + return connections.connectionKey(request); + } + + Http3Connection findPooledConnectionFor(HttpRequestImpl request, + Exchange exchange) + throws IOException { + if (request.secure() && request.proxy() == null) { + final var pooled = connections.lookupFor(request); + if (pooled == null) { + return null; + } + if (pooled.tryReserveForPoolCheckout() && !pooled.isFinalStream()) { + final var altService = pooled.connection() + .getSourceAltService().orElse(null); + if (altService != null) { + // if this connection was created because it was advertised by some alt-service + // then verify that the alt-service is still valid/active + if (altService.wasAdvertised() && !client.registry().isActive(altService)) { + if (debug.on()) { + debug.log("Alt-Service %s for pooled connection has expired," + + " marking the connection as unusable for new streams", altService); + } + // alt-service that was the reason for this H3 connection to be created (and pooled) + // is no longer valid. We set a state on the connection to disallow any new streams + // and be auto-closed when all current streams are done + pooled.setFinalStreamAndCloseIfIdle(); + return null; + } + } + if (debug.on()) { + debug.log("Found Http3Connection in connection pool"); + } + // found a valid connection in pool, return it + return pooled; + } else { + if (debug.on()) { + debug.log("Pooled connection expired. Removing it."); + } + removeFromPool(pooled); + } + } + return null; + } + + private static String label(Http3Connection conn) { + return Optional.ofNullable(conn) + .map(Http3Connection::connection) + .map(HttpQuicConnection::label) + .orElse("null"); + } + + private static String describe(HttpRequestImpl request, long id) { + return String.format("%s #%s", request, id); + } + + private static String describe(Exchange exchange) { + if (exchange == null) return "null"; + return describe(exchange.request, exchange.multi.id); + } + + private static String describePendingExchange(String prefix, PendingConnection pending) { + return String.format("%s %s", prefix, describe(pending.exchange)); + } + + private static String describeAltSvc(PendingConnection pendingConnection) { + return Optional.ofNullable(pendingConnection) + .map(PendingConnection::altSvc) + .map(AltService::toString) + .map(s -> "altsvc: " + s) + .orElse("no altSvc"); + } + + // Called after a recovered connection has been put back in the pool + // (or when recovery has failed), or when a new connection handshake + // has completed. + // Waiters, if any, will be notified. + private void connectionCompleted(String connectionKey, Exchange origExchange, Http3Connection conn, Throwable error) { + try { + if (Log.http3()) { + Log.logHttp3("Checking waiters on completed connection {0} to {1} created for {2}", + label(conn), connectionKey, describe(origExchange)); + } + connectionCompleted0(connectionKey, origExchange, conn, error); + } catch (Throwable t) { + if (Log.http3() || Log.errors()) { + Log.logError(t); + } + throw t; + } + } + + private void connectionCompleted0(String connectionKey, Exchange origExchange, Http3Connection conn, Throwable error) { + lock.lock(); + // There should be a connection in the pool at this point, + // so we can remove the PendingConnection from the reconnections list; + PendingConnection pendingConnection = null; + try { + var recovery = reconnections.removeCompleted(connectionKey, origExchange, conn); + if (recovery instanceof PendingConnection pending) { + pendingConnection = pending; + } + } finally { + lock.unlock(); + } + if (pendingConnection == null) { + if (Log.http3()) { + Log.logHttp3("No waiters to complete for " + label(conn)); + } + return; + } + + int waitersCount = pendingConnection.waiters.size(); + if (waitersCount != 0 && Log.http3()) { + Log.logHttp3("Completing " + waitersCount + + " waiters on recreated connection " + label(conn) + + describePendingExchange(" - originally created for", pendingConnection)); + } + + // now for each waiter we're going to try to complete it. + // however, there may be more waiters than available streams! + // so it's rinse and repeat at this point + boolean origExchangeCancelled = origExchange == null ? false : origExchange.multi.requestCancelled(); + int completedWaiters = 0; + int errorWaiters = 0; + int retriedWaiters = 0; + try { + while (!pendingConnection.waiters.isEmpty()) { + var waiter = pendingConnection.waiters.poll(); + if (error != null && (!origExchangeCancelled || waiter.exchange == origExchange)) { + if (Log.http3()) { + Log.logHttp3("Completing pending waiter for: " + waiter.request + " #" + + waiter.exchange.multi.id + " with " + error); + } else if (debug.on()) { + debug.log("Completing waiter for: " + waiter.request + + " #" + waiter.exchange.multi.id + " with " + conn + " error=" + error); + } + errorWaiters++; + waiter.complete(conn, error); + } else { + var request = waiter.request; + var exchange = waiter.exchange; + try { + Http3Connection pooled = findPooledConnectionFor(request, exchange); + if (pooled != null && !pooled.isFinalStream() && !waiter.cf.isDone()) { + if (Log.http3()) { + Log.logHttp3("Completing pending waiter for: " + waiter.request + " #" + + waiter.exchange.multi.id + " with " + label(pooled)); + } else if (debug.on()) { + debug.log("Completing waiter for: " + waiter.request + + " #" + waiter.exchange.multi.id + " with pooled conn " + label(pooled)); + } + completedWaiters++; + waiter.cf.complete(pooled); + } else if (!waiter.cf.isDone()) { + // we call getConnectionFor: it should put waiter in the + // new waiting list, or attempt to open a connection again + if (conn != null) { + if (Log.http3()) { + Log.logHttp3("Not enough streams on recreated connection for: " + waiter.request + " #" + + waiter.exchange.multi.id + " with " + label(conn)); + } else if (debug.on()) { + debug.log("Not enough streams on recreated connection for: " + waiter.request + + " #" + waiter.exchange.multi.id + " with " + label(conn) + + ": retrying on new connection"); + } + retriedWaiters++; + getConnectionFor(request, exchange, waiter); + } else { + if (Log.http3()) { + Log.logHttp3("No HTTP/3 connection for:: " + waiter.request + " #" + + waiter.exchange.multi.id + ": will downgrade or fail"); + } else if (debug.on()) { + debug.log("No HTTP/3 connection for: " + waiter.request + + " #" + waiter.exchange.multi.id + ": will downgrade or fail"); + } + completedWaiters++; + waiter.complete(null, error); + } + } + } catch (Throwable t) { + if (debug.on()) { + debug.log("Completing waiter for: " + waiter.request + + " #" + waiter.exchange.multi.id + " with error: " + + Utils.getCompletionCause(t)); + } + var cause = Utils.getCompletionCause(t); + if (cause instanceof ClosedChannelException) { + cause = new ConnectionExpiredException(cause); + } + if (Log.http3()) { + Log.logHttp3("Completing pending waiter for: " + waiter.request + " #" + + waiter.exchange.multi.id + " with " + cause); + } + errorWaiters++; + waiter.cf.completeExceptionally(cause); + } + } + } + } finally { + if (Log.http3()) { + String pendingInfo = describePendingExchange(" - originally created for", pendingConnection); + + if (conn != null) { + Log.logHttp3(("Connection creation completed for requests to %s: " + + "waiters[%s](completed:%s, retried:%s, errors:%s)%s") + .formatted(connectionKey, waitersCount, completedWaiters, + retriedWaiters, errorWaiters, pendingInfo)); + } else { + Log.logHttp3(("No HTTP/3 connection created for requests to %s, will fail or downgrade: " + + "waiters[%s](completed:%s, retried:%s, errors:%s)%s") + .formatted(connectionKey, waitersCount, completedWaiters, + retriedWaiters, errorWaiters, pendingInfo)); + } + } + } + } + + CompletableFuture getConnectionFor(HttpRequestImpl request, Exchange exchange) { + assert request != null; + return getConnectionFor(request, exchange, null); + } + + private void completeWaiter(Logger debug, Waiter pendingWaiter, Http3Connection r, Throwable t) { + // the recovery was done on behalf of a pending waiter. + // this can happen if the new connection has already maxed out, + // and recovery was initiated on behalf of the next waiter. + if (Log.http3()) { + Log.logHttp3("Completing waiter for: " + pendingWaiter.request + " #" + + pendingWaiter.exchange.multi.id + " with (conn: " + label(r) + " error: " + t +")"); + } else if (debug.on()) { + debug.log("Completing pending waiter for " + pendingWaiter.request + " #" + + pendingWaiter.exchange.multi.id + " with (conn: " + label(r) + " error: " + t +")"); + } + pendingWaiter.complete(r, t); + } + + private CompletableFuture wrapForDebug(CompletableFuture h3Cf, + Exchange exchange, + HttpRequestImpl request) { + if (debug.on() || Log.http3()) { + if (Log.http3()) { + Log.logHttp3("Recreating connection for: " + request + " #" + + exchange.multi.id); + } else if (debug.on()) { + debug.log("Recreating connection for: " + request + " #" + + exchange.multi.id); + } + return h3Cf.whenComplete((r, t) -> { + if (Log.http3()) { + if (r != null && t == null) { + Log.logHttp3("Connection recreated for " + request + " #" + + exchange.multi.id + " on " + label(r)); + } else if (t != null) { + Log.logHttp3("Connection creation failed for " + request + " #" + + exchange.multi.id + ": " + t); + } else if (r == null) { + Log.logHttp3("No connection found for " + request + " #" + + exchange.multi.id); + } + } else if (debug.on()) { + debug.log("Connection recreated for " + request + " #" + + exchange.multi.id); + } + }); + } else { + return h3Cf; + } + } + + Optional lookupAltSvc(HttpRequestImpl request) { + return client.registry() + .lookup(request.uri(), H3::equals) + .findFirst(); + } + + CompletableFuture getConnectionFor(HttpRequestImpl request, + Exchange exchange, + Waiter pendingWaiter) { + assert request != null; + if (Log.http3()) { + if (pendingWaiter != null) { + Log.logHttp3("getConnectionFor pendingWaiter {0}", + describe(pendingWaiter.request, pendingWaiter.exchange.multi.id)); + } else { + Log.logHttp3("getConnectionFor exchange {0}", + describe(request, exchange.multi.id)); + } + } + try { + Http3Connection pooled = findPooledConnectionFor(request, exchange); + if (pooled != null) { + if (pendingWaiter != null) { + if (Log.http3()) { + Log.logHttp3("Completing pending waiter for: " + request + " #" + + exchange.multi.id + " with " + pooled.dbgTag()); + } else if (debug.on()) { + debug.log("Completing pending waiter for: " + request + " #" + + exchange.multi.id + " with " + pooled.dbgTag()); + } + pendingWaiter.cf.complete(pooled); + return pendingWaiter.cf; + } else { + return MinimalFuture.completedFuture(pooled); + } + } + if (request.secure() && request.proxy() == null) { + boolean reconnecting, waitForPendingConnect; + PendingConnection pendingConnection = null; + String key; + Waiter waiter = null; + if (reconnecting = exchange.hasReachedStreamLimit()) { + if (debug.on()) { + debug.log("Exchange has reached limit for: " + request + " #" + + exchange.multi.id); + } + } + if (pendingWaiter != null) reconnecting = true; + lock.lock(); + try { + key = connectionKey(request); + + var recovery = reconnections.lookupFor(key, request, client); + if (debug.on()) debug.log("lookup found %s for %s", recovery, request); + if (recovery instanceof PendingConnection pending) { + // Recovery already initiated. Add waiter to the list! + if (debug.on()) { + debug.log("PendingConnection (%s) found for %s", + describePendingExchange("originally created for", pending), + describe(request, exchange.multi.id)); + } + pendingConnection = pending; + waiter = pendingWaiter == null + ? Waiter.of(request, exchange) + : pendingWaiter; + exchange.streamLimitReached(false); + pendingConnection.waiters.add(waiter); + return waiter.cf; + } else if (recovery instanceof StreamLimitReached) { + // A connection to this server has maxed out its allocated + // streams and will be taken out of the pool, but recovery + // has not been initiated yet. Do that now. + reconnecting = waitForPendingConnect = true; + } else waitForPendingConnect = WAIT_FOR_PENDING_CONNECT; + // By default, we allow concurrent attempts to + // create HTTP/3 connections to the same host, except when + // one connection has reached the maximum number of streams + // it is allowed to use. However, + // if waitForPendingConnect is set to `true` above we will + // only allow one connection to attempt handshake at a given + // time, other requests will be added to a pending list so + // that they can go through that connection. + if (waitForPendingConnect) { + // check again + if ((pooled = findPooledConnectionFor(request, exchange)) == null) { + // initiate recovery + var altSvc = lookupAltSvc(request).orElse(null); + // maybe null if ALT_SVC && altSvc == null + pendingConnection = reconnections.addPending(key, request, altSvc, exchange); + } else if (pendingWaiter != null) { + if (Log.http3()) { + Log.logHttp3("Completing pending waiter for: " + request + " #" + + exchange.multi.id + " with " + pooled.dbgTag()); + } else if (debug.on()) { + debug.log("Completing pending waiter for: " + request + " #" + + exchange.multi.id + " with " + pooled.dbgTag()); + } + pendingWaiter.cf.complete(pooled); + return pendingWaiter.cf; + } else { + return MinimalFuture.completedFuture(pooled); + } + } + } finally { + lock.unlock(); + if (waiter != null && waiter != pendingWaiter && Log.http3()) { + var altSvc = describeAltSvc(pendingConnection); + var orig = Optional.of(pendingConnection) + .map(PendingConnection::exchange) + .map(e -> " created for #" + e.multi.id) + .orElse(""); + Log.logHttp3("Waiting for connection for: " + describe(request, exchange.multi.id) + + " " + altSvc + orig); + } else if (pendingWaiter != null && Log.http3()) { + var altSvc = describeAltSvc(pendingConnection); + Log.logHttp3("Creating connection for: " + describe(request, exchange.multi.id) + + " " + altSvc); + } else if (debug.on() && waiter != null) { + debug.log("Waiting for connection for: " + describe(request, exchange.multi.id) + + (waiter == pendingWaiter ? " (still pending)" : "")); + } + } + + if (Log.http3()) { + Log.logHttp3("Creating connection for Exchange {0}", describe(exchange)); + } else if (debug.on()) { + debug.log("Creating connection for Exchange %s", describe(exchange)); + } + + CompletableFuture h3Cf = Http3Connection + .createAsync(request, this, exchange); + if (reconnecting) { + // System.err.println("Recreating connection for: " + request + " #" + // + exchange.multi.id); + h3Cf = wrapForDebug(h3Cf, exchange, request); + } + if (pendingWaiter != null) { + // the connection was done on behalf of a pending waiter. + // this can happen if the new connection has already maxed out, + // and recovery was initiated on behalf of the next waiter. + h3Cf = h3Cf.whenComplete((r,t) -> completeWaiter(debug, pendingWaiter, r, t)); + } + h3Cf = h3Cf.thenApply(conn -> { + if (conn != null) { + if (debug.on()) { + debug.log("Offering connection %s created for %s", + label(conn), exchange.multi.id); + } + var offered = offerConnection(conn); + if (debug.on()) { + debug.log("Connection offered %s created for %s", + label(conn), exchange.multi.id); + } + // if we return null here, we will downgrade + // but if we return `conn` we will open a new connection. + return offered == null ? conn : offered; + } else { + if (debug.on()) { + debug.log("No connection for exchange #" + exchange.multi.id); + } + return null; + } + }); + if (pendingConnection != null) { + // need to wake up waiters after successful handshake and recovery + h3Cf = h3Cf.whenComplete((r, t) -> connectionCompleted(key, exchange, r, t)); + } + return h3Cf; + } else { + if (debug.on()) + debug.log("Request is unsecure, or proxy isn't null: can't use HTTP/3"); + if (request.isHttp3Only(exchange.version())) { + return MinimalFuture.failedFuture(new UnsupportedProtocolVersionException( + "can't use HTTP/3 with proxied or unsecured connection")); + } + return MinimalFuture.completedFuture(null); + } + } catch (Throwable t) { + if (Log.http3() || Log.errors()) { + Log.logError("Failed to get connection for {0}: {1}", + describe(exchange), t); + } + return MinimalFuture.failedFuture(t); + } + } + + /* + * Cache the given connection, if no connection to the same + * destination exists. If one exists, then we let the initial stream + * complete but allow it to close itself upon completion. + * This situation should not arise with https because the request + * has not been sent as part of the initial alpn negotiation + */ + Http3Connection offerConnection(Http3Connection c) { + if (debug.on()) debug.log("offering to the connection pool: %s", c); + if (!c.isOpen() || c.isFinalStream()) { + if (debug.on()) + debug.log("skipping offered closed or closing connection: %s", c); + return null; + } + + String key = c.key(); + lock.lock(); + try { + if (closed) { + var error = errorRef.get(); + if (error == null) error = new IOException("client closed"); + c.connectionError(error, Http3Error.H3_INTERNAL_ERROR); + return null; + } + Http3Connection c1 = connections.putIfAbsent(key, c); + if (c1 != null) { + // there was a connection in the pool + if (!c1.isFinalStream() || c.isFinalStream()) { + if (!c.isFinalStream()) { + c.allowOnlyOneStream(); + return c; + } else if (c1.isFinalStream()) { + return c; + } + if (debug.on()) + debug.log("existing entry %s in connection pool for %s", c1, key); + // c1 will remain in the pool and we will use c for the given + // request. + if (Log.http3()) { + Log.logHttp3("Existing connection {0} for {1} found in the pool", label(c1), c1.key()); + Log.logHttp3("New connection {0} marked final and not offered to the pool", label(c)); + } + return c1; + } + connections.put(key, c); + } + if (debug.on()) + debug.log("put in the connection pool: %s", c); + return c; + } finally { + lock.unlock(); + } + } + + void removeFromPool(Http3Connection c) { + lock.lock(); + try { + if (connections.remove(c.key(), c)) { + if (debug.on()) + debug.log("removed from the connection pool: %s", c); + } + if (c.isOpen()) { + if (debug.on()) + debug.log("adding to pending close: %s", c); + pendingClose.add(c); + } + } finally { + lock.unlock(); + } + } + + void connectionClosed(Http3Connection c) { + removeFromPool(c); + if (pendingClose.remove(c)) { + if (debug.on()) + debug.log("removed from pending close: %s", c); + } + } + + public Logger debug() { return debug;} + + @Override + public void close() { + try { + lock.lock(); + try { + closed = true; + pendingClose.clear(); + connections.clear(); + } finally { + lock.unlock(); + } + // The client itself is being closed, so we don't individually close the connections + // here and instead just close the QuicClient which then initiates the close of + // the QUIC endpoint. That will silently terminate the underlying QUIC connections + // without exchanging any datagram packets with the peer, since there's no point + // sending/receiving those (including GOAWAY frame) when the endpoint (socket channel) + // itself won't be around after this point. + } finally { + quicClient.close(); + } + } + + // Called in case of RejectedExecutionException, or shutdownNow; + public void abort(Throwable t) { + if (debug.on()) { + debug.log("HTTP/3 client aborting due to " + t); + } + try { + errorRef.compareAndSet(null, t); + List connectionList; + lock.lock(); + try { + closed = true; + connectionList = new ArrayList<>(connections.values().toList()); + connectionList.addAll(pendingClose); + pendingClose.clear(); + connections.clear(); + } finally { + lock.unlock(); + } + for (var conn : connectionList) { + conn.close(t); + } + } finally { + quicClient.abort(t); + } + } + + public void stop() { + close(); + } + + /** + * After an unsuccessful H3 direct connection attempt, + * mark the authority as not supporting h3. + * @param rawAuthority the raw authority (host:port) + */ + public void noH3(String rawAuthority) { + noH3.add(rawAuthority); + } + + /** + * Tells whether the given authority has been marked as + * not supporting h3 + * @param rawAuthority the raw authority (host:port) + * @return true if the given authority is believed to not support h3 + */ + public boolean hasNoH3(String rawAuthority) { + return noH3.contains(rawAuthority); + } + + /** + * A direct HTTP/3 attempt may be attempted if we don't have an + * AltService h3 endpoint recorded for it, and if the given request + * URI's raw authority hasn't been marked as not supporting HTTP/3, + * and if the request discovery config is not ALT_SVC. + * Note that a URI may be marked has not supporting H3 if it doesn't + * acknowledge the first initial quic packet in the time defined + * by {@systemProperty jdk.httpclient.http3.maxDirectConnectionTimeout}. + * @param request the request that may go through h3 + * @return true if there's no h3 endpoint already registered for the given uri. + */ + public boolean mayAttemptDirectConnection(HttpRequestImpl request) { + var config = request.http3Discovery(); + return switch (config) { + // never attempt direct connection with ALT_SVC + case Http3DiscoveryMode.ALT_SVC -> false; + // always attempt direct connection with HTTP_3_ONLY, unless + // it was attempted before and failed + case Http3DiscoveryMode.HTTP_3_URI_ONLY -> + !hasNoH3(request.uri().getRawAuthority()); + // otherwise, attempt direct connection only if we have no + // alt service and it wasn't attempted and failed before + default -> lookupAltSvc(request).isEmpty() + && !hasNoH3(request.uri().getRawAuthority()); + }; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientProperties.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientProperties.java new file mode 100644 index 00000000000..81f8c8109d5 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ClientProperties.java @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http; + +import jdk.internal.net.http.common.Utils; + +import static jdk.internal.net.http.http3.frames.SettingsFrame.DEFAULT_SETTINGS_MAX_FIELD_SECTION_SIZE; +import static jdk.internal.net.http.http3.frames.SettingsFrame.DEFAULT_SETTINGS_QPACK_BLOCKED_STREAMS; +import static jdk.internal.net.http.http3.frames.SettingsFrame.DEFAULT_SETTINGS_QPACK_MAX_TABLE_CAPACITY; + +/** + * A class that groups initial values for HTTP/3 client properties. + *

+ * Properties starting with {@code jdk.internal.} are not exposed and + * typically reserved for testing. They could be removed, and their name, + * semantics, or values, could be changed at any time. + *

+ * Properties that are exposed are JDK specifics and typically documented + * in the {@link java.net.http} module API documentation. + *

    + *
  1. + *
  2. + *
+ * + * @apiNote + * Not all properties are exposed. Properties that are not included in + * the {@link java.net.http} module API documentation are subject to + * change, and should be considered internal, though we might also consider + * exposing them in the future if needed. + * + */ +public final class Http3ClientProperties { + + private Http3ClientProperties() { + throw new InternalError("should not come here"); + } + + // The maximum timeout to wait for a reply to the first INITIAL + // packet when attempting a direct connection + public static final long MAX_DIRECT_CONNECTION_TIMEOUT; + + // The maximum timeout to wait for a MAX_STREAM frame + // before throwing StreamLimitException + public static final long MAX_STREAM_LIMIT_WAIT_TIMEOUT; + + // The maximum number of concurrent push streams + // by connection + public static final long MAX_HTTP3_PUSH_STREAMS; + + // Limit for dynamic table capacity that the encoder is allowed + // to set. Its capacity is also limited by the QPACK_MAX_TABLE_CAPACITY + // HTTP/3 setting value received from the peer decoder. + public static final long QPACK_ENCODER_TABLE_CAPACITY_LIMIT; + + // The value of SETTINGS_QPACK_MAX_TABLE_CAPACITY HTTP/3 setting that is + // negotiated by HTTP client's decoder + public static final long QPACK_DECODER_MAX_TABLE_CAPACITY; + + // The value of SETTINGS_MAX_FIELD_SECTION_SIZE HTTP/3 setting that is + // negotiated by HTTP client's decoder + public static final long QPACK_DECODER_MAX_FIELD_SECTION_SIZE; + + // Decoder upper bound on the number of streams that can be blocked + public static final long QPACK_DECODER_BLOCKED_STREAMS; + + // of available space in the dynamic table + + // Percentage of occupied space in the dynamic table that controls when + // the draining index starts increasing. This index determines which entries + // are too close to eviction, and can be referenced by the encoder. + public static final int QPACK_ENCODER_DRAINING_THRESHOLD; + + // If set to "true" allows the encoder to insert a header with a dynamic + // name reference and reference it in a field line section without awaiting + // decoder's acknowledgement. + public static final boolean QPACK_ALLOW_BLOCKING_ENCODING = Utils.getBooleanProperty( + "jdk.internal.httpclient.qpack.allowBlockingEncoding", false); + + // whether localhost is acceptable as an alternative service origin + public static final boolean ALTSVC_ALLOW_LOCAL_HOST_ORIGIN = Utils.getBooleanProperty( + "jdk.httpclient.altsvc.allowLocalHostOrigin", true); + + // whether concurrent HTTP/3 requests to the same host should wait for + // first connection to succeed (or fail) instead of attempting concurrent + // connections. Where concurrent connections are attempted, only one of + // them will be offered to the connection pool. The others will serve a + // single request. + public static final boolean WAIT_FOR_PENDING_CONNECT = Utils.getBooleanProperty( + "jdk.httpclient.http3.waitForPendingConnect", true); + + + static { + // 375 is ~ to the initial loss timer + // 1000 is ~ the initial PTO + // We will set a timeout of 2*1375 ms to wait for the reply to our + // first initial packet for a direct connection + long defaultMaxDirectConnectionTimeout = 1375 << 1; // ms + long maxDirectConnectionTimeout = Utils.getLongProperty( + "jdk.httpclient.http3.maxDirectConnectionTimeout", + defaultMaxDirectConnectionTimeout); + long maxStreamLimitTimeout = Utils.getLongProperty( + "jdk.httpclient.http3.maxStreamLimitTimeout", + defaultMaxDirectConnectionTimeout); + int defaultMaxHttp3PushStreams = Utils.getIntegerProperty( + "jdk.httpclient.maxstreams", + 100); + int maxHttp3PushStreams = Utils.getIntegerProperty( + "jdk.httpclient.http3.maxConcurrentPushStreams", + defaultMaxHttp3PushStreams); + long defaultDecoderMaxCapacity = 0; + long decoderMaxTableCapacity = Utils.getLongProperty( + "jdk.httpclient.qpack.decoderMaxTableCapacity", + defaultDecoderMaxCapacity); + long decoderBlockedStreams = Utils.getLongProperty( + "jdk.httpclient.qpack.decoderBlockedStreams", + DEFAULT_SETTINGS_QPACK_BLOCKED_STREAMS); + long defaultEncoderTableCapacityLimit = 4096; + long encoderTableCapacityLimit = Utils.getLongProperty( + "jdk.httpclient.qpack.encoderTableCapacityLimit", + defaultEncoderTableCapacityLimit); + int defaultDecoderMaxFieldSectionSize = 393216; // 384kB + long decoderMaxFieldSectionSize = Utils.getIntegerNetProperty( + "jdk.http.maxHeaderSize", Integer.MIN_VALUE, Integer.MAX_VALUE, + defaultDecoderMaxFieldSectionSize, true); + // Percentage of occupied space in the dynamic table that when + // exceeded the dynamic table draining index starts increasing + int drainingThreshold = Utils.getIntegerProperty( + "jdk.internal.httpclient.qpack.encoderDrainingThreshold", + 75); + + MAX_DIRECT_CONNECTION_TIMEOUT = maxDirectConnectionTimeout <= 0 + ? defaultMaxDirectConnectionTimeout : maxDirectConnectionTimeout; + MAX_STREAM_LIMIT_WAIT_TIMEOUT = maxStreamLimitTimeout < 0 + ? defaultMaxDirectConnectionTimeout + : maxStreamLimitTimeout; + MAX_HTTP3_PUSH_STREAMS = Math.max(maxHttp3PushStreams, 0); + QPACK_ENCODER_TABLE_CAPACITY_LIMIT = encoderTableCapacityLimit < 0 + ? defaultEncoderTableCapacityLimit : encoderTableCapacityLimit; + QPACK_DECODER_MAX_TABLE_CAPACITY = decoderMaxTableCapacity < 0 ? + DEFAULT_SETTINGS_QPACK_MAX_TABLE_CAPACITY : decoderMaxTableCapacity; + QPACK_DECODER_MAX_FIELD_SECTION_SIZE = decoderMaxFieldSectionSize < 0 ? + DEFAULT_SETTINGS_MAX_FIELD_SECTION_SIZE : decoderMaxFieldSectionSize; + QPACK_DECODER_BLOCKED_STREAMS = decoderBlockedStreams < 0 ? + DEFAULT_SETTINGS_QPACK_BLOCKED_STREAMS : decoderBlockedStreams; + QPACK_ENCODER_DRAINING_THRESHOLD = Math.clamp(drainingThreshold, 10, 90); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3Connection.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3Connection.java new file mode 100644 index 00000000000..b97a441881d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3Connection.java @@ -0,0 +1,1657 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.net.ProtocolException; +import java.net.http.HttpHeaders; +import java.net.http.HttpResponse.PushPromiseHandler.PushId; +import java.net.http.HttpResponse.PushPromiseHandler.PushId.Http3PushId; +import java.net.http.StreamLimitException; +import java.net.http.UnsupportedProtocolVersionException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; + +import jdk.internal.net.http.Http3PushManager.CancelPushReason; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.CancelPushFrame; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.GoAwayFrame; +import jdk.internal.net.http.http3.frames.Http3Frame; +import jdk.internal.net.http.http3.frames.Http3FrameType; +import jdk.internal.net.http.http3.frames.MalformedFrame; +import jdk.internal.net.http.http3.frames.MaxPushIdFrame; +import jdk.internal.net.http.http3.frames.PartialFrame; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.http3.streams.Http3Streams; +import jdk.internal.net.http.http3.streams.Http3Streams.StreamType; +import jdk.internal.net.http.http3.streams.PeerUniStreamDispatcher; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.http3.streams.UniStreamPair; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.QuicStreamLimitException; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; +import jdk.internal.net.http.quic.streams.QuicStreams; +import static java.net.http.HttpClient.Version.HTTP_3; +import static jdk.internal.net.http.Http3ClientProperties.MAX_STREAM_LIMIT_WAIT_TIMEOUT; +import static jdk.internal.net.http.http3.Http3Error.H3_CLOSED_CRITICAL_STREAM; +import static jdk.internal.net.http.http3.Http3Error.H3_INTERNAL_ERROR; +import static jdk.internal.net.http.http3.Http3Error.H3_NO_ERROR; +import static jdk.internal.net.http.http3.Http3Error.H3_STREAM_CREATION_ERROR; + +/** + * An HTTP/3 connection wraps an HttpQuicConnection and implements + * HTTP/3 on top it. + */ +public final class Http3Connection implements AutoCloseable { + + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + private final Http3ClientImpl client; + private final HttpQuicConnection connection; + private final QuicConnection quicConnection; + // key by which this connection will be referred to within the connection pool + private final String connectionKey; + private final String dbgTag; + private final UniStreamPair controlStreamPair; + private final UniStreamPair qpackEncoderStreams; + private final UniStreamPair qpackDecoderStreams; + private final Encoder qpackEncoder; + private final Decoder qpackDecoder; + private final FramesDecoder controlFramesDecoder; + private final Predicate remoteStreamListener; + private final H3FrameOrderVerifier frameOrderVerifier = H3FrameOrderVerifier.newForControlStream(); + // streams for HTTP3 exchanges + private final ConcurrentMap exchangeStreams = new ConcurrentHashMap<>(); + private final ConcurrentMap> exchanges = new ConcurrentHashMap<>(); + // true when the settings frame has been received on the control stream of this connection + private volatile boolean settingsFrameReceived; + // the settings we received from the peer + private volatile ConnectionSettings peerSettings; + // the settings we send to our peer + private volatile ConnectionSettings ourSettings; + // for tests + private final MinimalFuture peerSettingsCF = new MinimalFuture<>(); + // the (lowest) request stream id received in GOAWAY frames on this connection. + // subsequent request stream id(s) (if any) must always be equal to lesser than this value + // as per spec + // -1 is used to imply no GOAWAY received so far + private final AtomicLong lowestGoAwayReceipt = new AtomicLong(-1); + private volatile IdleConnectionTimeoutEvent idleConnectionTimeoutEvent; + // value of true implies no more streams will be initiated on this connection, + // and the connection will be closed once the in-progress streams complete. + private volatile boolean finalStream; + private volatile boolean allowOnlyOneStream; + // set to true if we decide to open a new connection + // due to stream limit reached + private volatile boolean streamLimitReached; + + private static final int GOAWAY_SENT = 1; // local endpoint sent GOAWAY + private static final int GOAWAY_RECEIVED = 2; // received GOAWAY from remote peer + private static final int CLOSED = 4; // close called on QUIC connection + volatile int closedState; + + private final ReentrantLock lock = new ReentrantLock(); + private final Http3PushManager pushManager; + private final AtomicLong reservedStreamCount = new AtomicLong(); + + // The largest pushId for a remote created stream. + // After GOAWAY has been sent, we will not accept + // any larger pushId. + private final AtomicLong largestPushId = new AtomicLong(); + + // The max pushId for which a frame was scheduled to be sent. + // This should always be less or equal to pushManager.maxPushId + private final AtomicLong maxPushIdSent = new AtomicLong(); + + + /** + * Creates a new HTTP/3 connection over a given {@link HttpQuicConnection}. + * + * @apiNote + * This constructor is invoked upon a successful quic connection establishment, + * typically after a successful Quic handshake. Creating the Http3Connection + * earlier, for instance, after receiving the Server Hello, could also be considered. + * + * @implNote + * Creating an HTTP/3 connection will trigger the creation of the HTTP/3 control + * stream, sending of the HTTP/3 Settings frame, and creation of the QPack + * encoder/decoder streams. + * + * @param request the request which triggered the creation of the connection + * @param client the Http3Client instance this connection belongs to + * @param connection the {@code HttpQuicConnection} that was established + */ + Http3Connection(HttpRequestImpl request, Http3ClientImpl client, HttpQuicConnection connection) { + this.connectionKey = client.connectionKey(request); + this.client = client; + this.connection = connection; + this.quicConnection = connection.quicConnection(); + var qdb = quicConnection.dbgTag(); + this.dbgTag = "H3(" + qdb +")"; + this.pushManager = new Http3PushManager(this); // OK to leak this + controlFramesDecoder = new FramesDecoder("H3-control("+qdb+")", + FramesDecoder::isAllowedOnControlStream); + controlStreamPair = new UniStreamPair( + StreamType.CONTROL, + quicConnection, + this::processPeerControlBytes, + this::lcsWriterLoop, + this::controlStreamFailed, + debug); + + qpackEncoder = new Encoder(Http3Connection::shouldUpdateDynamicTable, + this::createEncoderStreams, this::connectionError); + qpackEncoderStreams = qpackEncoder.encoderStreams(); + qpackDecoder = new Decoder(this::createDecoderStreams, this::connectionError); + qpackDecoderStreams = qpackDecoder.decoderStreams(); + // Register listener to be called when the peer opens a new stream + remoteStreamListener = this::onOpenRemoteStream; + quicConnection.addRemoteStreamListener(remoteStreamListener); + + // Registers dependent actions with the controlStreamPair + // .futureSenderStreamWriter() CF, in order to send + // the SETTINGS and MAX_PUSHID frames. + // These actions will be executed when the stream writer is + // available. + // + // This will schedule the SETTINGS and MAX_PUSHID frames + // for writing, buffering them if necessary until control + // flow credits are available. + // + // If an exception happens the connection will be + // closed abruptly (by closing the underlying quic connection) + // with an error of type Http3Error.H3_INTERNAL_ERROR + controlStreamPair.futureSenderStreamWriter() + // Send SETTINGS first + .thenApply(this::sendSettings) + // Chains to sending MAX_PUSHID after SETTINGS + .thenApply(this::sendMaxPushId) + // arranges for the connection to be closed + // in case of exception. Throws in the dependent + // action after wrapping the exception if needed. + .exceptionally(this::exceptionallyAndClose); + if (Log.http3()) { + Log.logHttp3("HTTP/3 connection created for " + quicConnectionTag() + " - local address: " + + quicConnection.localAddress()); + } + } + + public String quicConnectionTag() { + return quicConnection.logTag(); + } + + private static boolean shouldUpdateDynamicTable(TableEntry tableEntry) { + if (tableEntry.type() == TableEntry.EntryType.NAME_VALUE) { + return false; + } + return switch (tableEntry.name().toString()) { + case ":authority", "user-agent" -> !tableEntry.value().isEmpty(); + default -> false; + }; + } + + private void lock() { + lock.lock(); + } + + private void unlock() { + lock.unlock(); + } + + /** + * Debug tag used to create the debug logger for this + * HTTP/3 connection instance. + * + * @return a debug tag + */ + String dbgTag() { + return dbgTag; + } + + /** + * Asynchronously create an instance of an HTTP/3 connection, if the + * server has a known HTTP/3 endpoint. + * @param request the first request that will go over this connection + * @param h3client the HTTP/3 client + * @param exchange the exchange for which this connection is created + * @return a completable future that will be completed with a new + * HTTP/3 connection, or {@code null} if no usable HTTP/3 endpoint + * was found, or completed exceptionally if an error occurred + */ + static CompletableFuture createAsync(HttpRequestImpl request, + Http3ClientImpl h3client, + Exchange exchange) { + assert request.secure(); + final HttpConnection connection = HttpConnection.getConnection(request.getAddress(), + h3client.client(), + exchange, + request, + HTTP_3); + var debug = h3client.debug(); + var where = "Http3Connection.createAsync"; + if (!(connection instanceof HttpQuicConnection httpQuicConnection)) { + if (Log.http3()) { + Log.logHttp3("{0}: Connection for {1} #{2} is not an HttpQuicConnection: {3}", + where, request, exchange.multi.id, connection); + } + if (debug.on()) + debug.log("%s: Connection is not an HttpQuicConnection: %s", where, connection); + if (request.isHttp3Only(exchange.version())) { + assert connection == null; + // may happen if the client doesn't support HTTP3 + return MinimalFuture.failedFuture(new UnsupportedProtocolVersionException( + "cannot establish exchange to requested origin with HTTP/3")); + } + return MinimalFuture.completedFuture(null); + } + if (debug.on()) { + debug.log("%s: Got HttpQuicConnection: %s", where, connection); + } + if (Log.http3()) { + Log.logHttp3("{0}: Got HttpQuicConnection for {1} #{2} is: {3}", + where, request, exchange.multi.id, connection.label()); + } + + // Expose the underlying connection to the exchange's aborter so it can + // be closed if a timeout occurs. + exchange.connectionAborter.connection(httpQuicConnection); + + return httpQuicConnection.connectAsync(exchange) + .thenCompose(unused -> httpQuicConnection.finishConnect()) + .thenCompose(unused -> checkSSLConfig(httpQuicConnection)) + .thenCompose(notused-> { + CompletableFuture cf = new MinimalFuture<>(); + try { + if (debug.on()) + debug.log("creating Http3Connection for %s", httpQuicConnection); + Http3Connection hc = new Http3Connection(request, h3client, httpQuicConnection); + if (!hc.isFinalStream()) { + exchange.connectionAborter.clear(httpQuicConnection); + cf.complete(hc); + } else { + var io = new IOException("can't reserve first stream"); + if (Log.http3()) { + Log.logHttp3(" Unable to use HTTP/3 connection over {0}: {1}", + hc.quicConnectionTag(), + io); + } + hc.protocolError(io); + cf.complete(null); + } + } catch (Exception e) { + cf.completeExceptionally(e); + } + return cf; } ) + .whenComplete(httpQuicConnection::connectionEstablished); + } + + private static CompletableFuture checkSSLConfig(HttpQuicConnection quic) { + // HTTP/2 checks ALPN here; with HTTP/3, we only offer one ALPN, + // and TLS verifies that it's negotiated. + + // We can examine the negotiated parameters here and possibly fail + // if they are not satisfactory. + return MinimalFuture.completedFuture(null); + } + + HttpQuicConnection connection() { + return connection; + } + + String key() { + return connectionKey; + } + + /** + * Whether the final stream (last stream allowed on a connection), has + * been set. + * + * @return true if the final stream has been set. + */ + boolean isFinalStream() { + return this.finalStream; + } + + /** + * Sets the final stream to be the next stream opened on + * the connection. No other stream will be opened after this. + */ + void setFinalStream() { + this.finalStream = true; + } + + void setFinalStreamAndCloseIfIdle() { + boolean closeNow; + lock(); + try { + setFinalStream(); + closeNow = finalStreamClosed(); + } finally { + unlock(); + } + if (closeNow) close(); + } + + void allowOnlyOneStream() { + lock(); + try { + if (isFinalStream()) return; + this.allowOnlyOneStream = true; + this.finalStream = true; + } finally { + unlock(); + } + } + + boolean isOpen() { + return closedState == 0 && quicConnection.isOpen(); + } + + private IOException checkConnectionError() { + final TerminationCause tc = quicConnection.terminationCause(); + return tc == null ? null : tc.getCloseCause(); + } + + // Used only by tests + CompletableFuture peerSettingsCF() { + return peerSettingsCF; + } + + private boolean reserveStream() { + lock(); + try { + boolean allowStream0 = this.allowOnlyOneStream; + this.allowOnlyOneStream = false; + if (finalStream && !allowStream0) { + return false; + } + reservedStreamCount.incrementAndGet(); + return true; + } finally { + unlock(); + } + } + + CompletableFuture> + createStream(final Exchange exchange) throws IOException { + // check if this connection is closing before initiating this new stream + if (!reserveStream()) { + if (Log.http3()) { + Log.logHttp3("Cannot initiate new stream on connection {0} for exchange {1}", + quicConnectionTag(), exchange); + } + // we didn't create the stream and thus the server hasn't yet processed this request. + // mark the request as unprocessed to allow it to be retried on a different connection. + exchange.markUnprocessedByPeer(); + String message = "cannot initiate additional new streams on chosen connection"; + IOException cause = streamLimitReached + ? new StreamLimitException(HTTP_3, message) + : new IOException(message); + return MinimalFuture.failedFuture(cause); + } + // TODO: this duration is currently "computed" from the request timeout duration. + // this computation needs a bit more thought + final Duration streamLimitIncreaseDuration = exchange.request.timeout() + .map((reqTimeout) -> reqTimeout.dividedBy(2)) + .orElse(Duration.ofMillis(MAX_STREAM_LIMIT_WAIT_TIMEOUT)); + final CompletableFuture bidiStream = + quicConnection.openNewLocalBidiStream(streamLimitIncreaseDuration); + // once the bidi stream creation completes: + // - if completed exceptionally, we transform any QuicStreamLimitException into a + // StreamLimitException + // - if completed successfully, we create a Http3 exchange and return that as the result + final CompletableFuture>> h3ExchangeCf = + bidiStream.handle((stream, t) -> { + if (t == null) { + // no exception occurred and a bidi stream was created on the quic + // connection, but check if the connection has been terminated + // in the meantime + final var terminationCause = checkConnectionError(); + if (terminationCause != null) { + // connection already closed and we haven't yet issued the request. + // mark the exchange as unprocessed to allow it to be retried on + // a different connection. + exchange.markUnprocessedByPeer(); + return MinimalFuture.failedFuture(terminationCause); + } + // creation of bidi stream succeeded, now create the H3 exchange impl + // and return it + final Http3ExchangeImpl h3Exchange = createHttp3ExchangeImpl(exchange, stream); + return MinimalFuture.completedFuture(h3Exchange); + } + // failed to open a bidi stream + reservedStreamCount.decrementAndGet(); + final Throwable cause = Utils.getCompletionCause(t); + if (cause instanceof QuicStreamLimitException) { + if (Log.http3()) { + Log.logHttp3("Maximum stream limit reached on {0} for exchange {1}", + quicConnectionTag(), exchange.multi.streamLimitState()); + } + if (debug.on()) { + debug.log("bidi stream creation failed due to stream limit: " + + cause + ", connection will be marked as unusable for subsequent" + + " requests"); + } + // Since we have reached the stream creation limit (which translates to not + // being able to initiate new requests on this connection), we mark the + // connection as "final stream" (i.e. don't consider this (pooled) + // connection for subsequent requests) + this.streamLimitReachedWith(exchange); + return MinimalFuture.failedFuture(new StreamLimitException(HTTP_3, + "No more streams allowed on connection")); + } else if (cause instanceof ClosedChannelException) { + // stream creation failed due to the connection (that was chosen) + // got closed. Thus the request wasn't processed by the server. + // mark the request as unprocessed to allow it to be + // initiated on a different connection + exchange.markUnprocessedByPeer(); + return MinimalFuture.failedFuture(cause); + } + return MinimalFuture.failedFuture(cause); + }); + return h3ExchangeCf.thenCompose(Function.identity()); + } + + private void streamLimitReachedWith(Exchange exchange) { + streamLimitReached = true; + client.streamLimitReached(this, exchange.request); + setFinalStream(); + } + + private Http3ExchangeImpl createHttp3ExchangeImpl(Exchange exchange, QuicBidiStream stream) { + if (debug.on()) { + debug.log("Temporary reference h3 stream: " + stream.streamId()); + } + if (Log.http3()) { + Log.logHttp3("Creating HTTP/3 exchange for {0}/streamId={1}", + quicConnectionTag(), Long.toString(stream.streamId())); + } + client.client.h3StreamReference(); + try { + lock(); + try { + this.exchangeStreams.put(stream.streamId(), stream); + reservedStreamCount.decrementAndGet(); + var te = idleConnectionTimeoutEvent; + if (te != null) { + client.client().cancelTimer(te); + idleConnectionTimeoutEvent = null; + } + } finally { + unlock(); + } + var http3Exchange = new Http3ExchangeImpl<>(this, exchange, stream); + return registerAndStartExchange(http3Exchange); + } finally { + if (debug.on()) { + debug.log("Temporary unreference h3 stream: " + stream.streamId()); + } + client.client.h3StreamUnreference(); + } + } + + private Http3ExchangeImpl registerAndStartExchange(Http3ExchangeImpl exchange) { + var streamId = exchange.streamId(); + if (debug.on()) debug.log("Reference h3 stream: " + streamId); + client.client.h3StreamReference(); + exchanges.put(streamId, exchange); + exchange.start(); + return exchange; + } + + // marks this connection as no longer available for creating additional streams. current + // streams will run to completion. marking the connection as gracefully shutdown + // can involve sending the necessary protocol message(s) to the peer. + private void sendGoAway() throws IOException { + if (markSentGoAway()) { + // already sent (either successfully or an attempt was made) GOAWAY, nothing more to do + return; + } + // RFC-9114, section 5.2: Endpoints initiate the graceful shutdown of an HTTP/3 connection + // by sending a GOAWAY frame. + final QuicStreamWriter writer = controlStreamPair.localWriter(); + if (writer != null && quicConnection.isOpen()) { + try { + // We send here the largest pushId for which the peer has + // opened a stream. We won't process pushIds larger than that, and + // we will later cancel any pending push promises anyway. + final long lastProcessedPushId = largestPushId.get(); + final GoAwayFrame goAwayFrame = new GoAwayFrame(lastProcessedPushId); + final long size = goAwayFrame.size(); + assert size >= 0 && size < Integer.MAX_VALUE; + final var buf = ByteBuffer.allocate((int) size); + goAwayFrame.writeFrame(buf); + buf.flip(); + if (debug.on()) { + debug.log("Sending GOAWAY frame %s from client connection %s", goAwayFrame, this); + } + writer.scheduleForWriting(buf, false); + } catch (Exception e) { + // ignore - we couldn't send a GOAWAY + if (debug.on()) { + debug.log("Failed to send GOAWAY from client " + this, e); + } + Log.logError("Could not send a GOAWAY from client {0}", this); + Log.logError(e); + } + } + } + + @Override + public void close() { + try { + sendGoAway(); + } catch (IOException ioe) { + // log and ignore the failure + // failure to send a GOAWAY shouldn't prevent closing a connection + if (debug.on()) { + debug.log("failed to send a GOAWAY frame before initiating a close: " + ioe); + } + } + // TODO: ideally we should hava flushForClose() which goes all the way to terminator to flush + // streams and increasing the chances of GOAWAY being sent. + // check RFC-9114, section 5.3 which seems to allow including GOAWAY and CONNECTION_CLOSE + // frames in same packet (optionally) + close(Http3Error.H3_NO_ERROR, "H3 connection closed - no error"); + } + + void close(final Throwable throwable) { + close(H3_INTERNAL_ERROR, null, throwable); + } + + void close(final Http3Error error, final String message) { + if (error != H3_NO_ERROR) { + // construct a ProtocolException representing the connection termination cause + final ProtocolException cause = new ProtocolException(message); + close(error, message, cause); + } else { + close(error, message, null); + } + } + + void close(final Http3Error error, final String logMsg, + final Throwable closeCause) { + if (!markClosed()) { + // already closed, nothing to do + return; + } + if (debug.on()) { + debug.log("Closing HTTP/3 connection: %s %s %s", error, logMsg == null ? "" : logMsg, + closeCause == null ? "" : closeCause.toString()); + debug.log("State is: " + describeClosedState(closedState)); + } + exchanges.values().forEach(e -> e.recordError(closeCause)); + // close the underlying QUIC connection + connection.close(error.code(), logMsg, closeCause); + final TerminationCause tc = connection.quicConnection.terminationCause(); + assert tc != null : "termination cause is null"; + // close all HTTP streams + exchanges.values().forEach(exchange -> exchange.cancelImpl(tc.getCloseCause(), error)); + pushManager.cancelAllPromises(tc.getCloseCause(), error); + discardConnectionState(); + // No longer wait for reading HTTP/3 stream types: + // stop waiting on any stream for which we haven't received the stream + // type yet. + try { + var listener = remoteStreamListener; + if (listener != null) { + quicConnection.removeRemoteStreamListener(listener); + } + } finally { + client.connectionClosed(this); + } + if (!peerSettingsCF.isDone()) { + peerSettingsCF.completeExceptionally(tc.getCloseCause()); + } + } + + private void discardConnectionState() { + controlStreamPair.stopSchedulers(); + controlFramesDecoder.clear(); + qpackDecoderStreams.stopSchedulers(); + qpackEncoderStreams.stopSchedulers(); + } + + private boolean markClosed() { + return markClosedState(CLOSED); + } + + void protocolError(IOException error) { + connectionError(error, Http3Error.H3_GENERAL_PROTOCOL_ERROR); + } + + void connectionError(Throwable throwable, Http3Error error) { + connectionError(null, throwable, error.code(), null); + } + + void connectionError(Http3Stream exchange, Throwable throwable, long errorCode, + String logMsg) { + final Optional error = Http3Error.fromCode(errorCode); + assert error.isPresent() : "not a HTTP3 error code: " + errorCode; + close(error.get(), logMsg, throwable); + } + + public String toString() { + return String.format("Http3Connection(%s)", connection()); + } + + private boolean finalStreamClosed() { + lock(); + try { + return this.finalStream && this.exchangeStreams.isEmpty() && this.reservedStreamCount.get() == 0; + } finally { + unlock(); + } + } + + /** + * Called by the {@link Http3ExchangeImpl} when the exchange is closed. + * + * @param streamId The request stream id + */ + void onExchangeClose(Http3ExchangeImpl exch, final long streamId) { + // we expect it to be a request/response stream + if (!(QuicStreams.isClientInitiated(streamId) && QuicStreams.isBidirectional(streamId))) { + throw new IllegalArgumentException("Not a client initiated bidirectional stream"); + } + if (this.exchangeStreams.remove(streamId) != null) { + if (connection().quicConnection().isOpen()) { + qpackDecoder.cancelStream(streamId); + } + decrementStreamsCount(exch, streamId); + exchanges.remove(streamId); + } + + if (finalStreamClosed()) { + // no more streams open on this connection. close the connection + if (Log.http3()) { + Log.logHttp3("Closing HTTP/3 connection {0} on final stream (streamId={1})", + quicConnectionTag(), Long.toString(streamId)); + } + // close will take care of canceling all pending push promises + // if any push promises are left pending + close(); + } else { + if (Log.http3()) { + Log.logHttp3("HTTP/3 connection {0} left open: exchanged streamId={1} closed; " + + "finalStream={2}, exchangeStreams={3}, reservedStreamCount={4}", + quicConnectionTag(), Long.toString(streamId), finalStream, + exchangeStreams.size(), reservedStreamCount.get()); + } + lock(); + try { + var te = idleConnectionTimeoutEvent; + if (te == null && exchangeStreams.isEmpty()) { + te = idleConnectionTimeoutEvent = client.client().idleConnectionTimeout(HTTP_3) + .map(IdleConnectionTimeoutEvent::new).orElse(null); + if (te != null) { + client.client().registerTimer(te); + } + } + } finally { + unlock(); + } + } + } + + void decrementStreamsCount(Http3ExchangeImpl exch, long streamid) { + if (exch.deRegister()) { + debug.log("Unreference h3 stream: " + streamid); + client.client.h3StreamUnreference(); + } else { + debug.log("Already unreferenced h3 stream: " + streamid); + } + } + + // Called from Http3PushPromiseStream::start (via Http3ExchangeImpl) + void onPushPromiseStreamStarted(Http3PushPromiseStream http3PushPromiseStream, long streamId) { + // HTTP/3 push promises are not refcounted. + // At the moment an ongoing push promise will not prevent the client + // to exit normally, if all request-response streams are finished. + // Here would be the place to increment ref-counting if we wanted to + } + + // Called by Http3PushPromiseStream::close + void onPushPromiseStreamClosed(Http3PushPromiseStream http3PushPromiseStream, long streamId) { + // HTTP/3 push promises are not refcounted. + // At the moment an ongoing push promise will not prevent the client + // to exit normally, if all request-response streams are finished. + // Here would be the place to decrement ref-counting if we wanted to + if (connection().quicConnection().isOpen()) { + qpackDecoder.cancelStream(streamId); + } + } + + /** + * A class used to dispatch peer initiated unidirectional streams + * according to their HTTP/3 stream type. + * The type of an HTTP/3 unidirectional stream is determined by + * reading a variable length integer code off the stream, which + * indicates the type of stream. + * @see Http3Streams + */ + private final class Http3StreamDispatcher extends PeerUniStreamDispatcher { + Http3StreamDispatcher(QuicReceiverStream stream) { + super(stream); + } + + @Override + protected Logger debug() { return debug; } + + @Override + protected void onStreamAbandoned(QuicReceiverStream stream) { + if (debug.on()) debug.log("Stream " + stream.streamId() + " abandoned!"); + qpackDecoder.cancelStream(stream.streamId()); + } + + @Override + protected void onControlStreamCreated(String description, QuicReceiverStream stream) { + complete(description, stream, controlStreamPair.futureReceiverStream()); + } + + @Override + protected void onEncoderStreamCreated(String description, QuicReceiverStream stream) { + complete(description, stream, qpackDecoderStreams.futureReceiverStream()); + } + + @Override + protected void onDecoderStreamCreated(String description, QuicReceiverStream stream) { + complete(description, stream, qpackEncoderStreams.futureReceiverStream()); + } + + @Override + protected void onPushStreamCreated(String description, QuicReceiverStream stream, long pushId) { + Http3Connection.this.onPushStreamCreated(stream, pushId); + } + + // completes the given completable future with the given stream + private void complete(String description, QuicReceiverStream stream, CompletableFuture cf) { + debug.log("completing CF for %s with stream %s", description, stream.streamId()); + boolean completed = cf.complete(stream); + if (!completed) { + if (!cf.isCompletedExceptionally()) { + debug.log("CF for %s already completed with stream %s!", description, cf.resultNow().streamId()); + close(Http3Error.H3_STREAM_CREATION_ERROR, + "%s already created".formatted(description)); + } else { + debug.log("CF for %s already completed exceptionally!", description); + } + } + } + + /** + * Dispatches the given remote initiated unidirectional stream to the + * given Http3Connection after reading the stream type off the stream. + * + * @param conn the Http3Connection with which the stream is associated + * @param stream a newly opened remote unidirectional stream. + */ + static CompletableFuture dispatch(Http3Connection conn, QuicReceiverStream stream) { + assert stream.isRemoteInitiated(); + assert !stream.isBidirectional(); + var dispatcher = conn.new Http3StreamDispatcher(stream); + dispatcher.start(); + return dispatcher.dispatchCF(); + } + } + + /** + * Attempts to notify the idle connection management that this connection should + * be considered "in use". This way the idle connection management doesn't close + * this connection during the time the connection is handed out from the pool and any + * new stream created on that connection. + * + * @return true if the connection has been successfully reserved and is {@link #isOpen()}. false + * otherwise; in which case the connection must not be handed out from the pool. + */ + boolean tryReserveForPoolCheckout() { + // must be done with "stateLock" held to co-ordinate idle connection management + lock(); + try { + cancelIdleShutdownEvent(); + // co-ordinate with the QUIC connection to prevent it from silently terminating + // a potentially idle transport + if (!quicConnection.connectionTerminator().tryReserveForUse()) { + // QUIC says the connection can't be used + return false; + } + // consider the reservation successful only if the connection's state hasn't moved + // to "being closed" + return isOpen() && finalStream == false; + } finally { + unlock(); + } + } + + /** + * Cancels any event that might have been scheduled to shutdown this connection. Must be called + * with the stateLock held. + */ + private void cancelIdleShutdownEvent() { + assert lock.isHeldByCurrentThread() : "Current thread doesn't hold " + lock; + if (idleConnectionTimeoutEvent == null) return; + idleConnectionTimeoutEvent.cancel(); + idleConnectionTimeoutEvent = null; + } + + // An Idle connection is one that has no active streams + // and has not sent the final stream flag + final class IdleConnectionTimeoutEvent extends TimeoutEvent { + + // both cancelled and idleShutDownInitiated are to be accessed + // when holding the connection's lock + private boolean cancelled; + private boolean idleShutDownInitiated; + + IdleConnectionTimeoutEvent(Duration duration) { + super(duration); + } + + @Override + public void handle() { + boolean okToIdleTimeout; + lock(); + try { + if (cancelled || idleShutDownInitiated) { + return; + } + idleShutDownInitiated = true; + if (debug.on()) { + debug.log("H3 idle shutdown initiated"); + } + setFinalStream(); + okToIdleTimeout = finalStreamClosed(); + } finally { + unlock(); + } + if (okToIdleTimeout) { + if (debug.on()) { + debug.log("closing idle H3 connection"); + } + close(); + } + } + + /** + * Cancels this event. Should be called with stateLock held + */ + void cancel() { + assert lock.isHeldByCurrentThread() : "Current thread doesn't hold " + lock; + // mark as cancelled to prevent potentially already triggered event from actually + // doing the shutdown + this.cancelled = true; + // cancel the timer to prevent the event from being triggered (if it hasn't already) + client.client().cancelTimer(this); + } + + @Override + public String toString() { + return "IdleConnectionTimeoutEvent, " + super.toString(); + } + + } + + /** + * This method is called when the peer opens a new stream. + * The stream can be unidirectional or bidirectional. + * + * @param stream the new stream + * @return always returns true (see {@link + * QuicConnection#addRemoteStreamListener(Predicate)} + */ + private boolean onOpenRemoteStream(QuicReceiverStream stream) { + debug.log("on open remote stream: " + stream.streamId()); + if (stream instanceof QuicBidiStream bidi) { + // A server will never open a bidirectional stream + // with the client. A client opens a new bidirectional + // stream for each request/response exchange. + return onRemoteBidirectionalStream(bidi); + } else { + // Four types of unidirectional stream are defined: + // control stream, qpack encoder, qpack decoder, push + // promise stream + return onRemoteUnidirectionalStream(stream); + } + } + + /** + * This method is called when the peer opens a unidirectional stream. + * + * @param uni the unidirectional stream opened by the peer + * @return always returns true ({@link + * QuicConnection#addRemoteStreamListener(Predicate)} + */ + protected boolean onRemoteUnidirectionalStream(QuicReceiverStream uni) { + assert !uni.isBidirectional(); + assert uni.isRemoteInitiated(); + if (!isOpen()) return false; + debug.log("dispatching unidirectional remote stream: " + uni.streamId()); + Http3StreamDispatcher.dispatch(this, uni).whenComplete((r, t)-> { + if (t!=null) this.dispatchingFailed(uni, t); + }); + return true; + } + + /** + * Called when the peer opens a bidirectional stream. + * On the client side, this method should never be called. + * + * @param bidi the new bidirectional stream opened by the + * peer. + * @return always returns false ({@link + * QuicConnection#addRemoteStreamListener(Predicate)} + */ + protected boolean onRemoteBidirectionalStream(QuicBidiStream bidi) { + assert bidi.isRemoteInitiated(); + assert bidi.isBidirectional(); + + // From RFC 9114, Section 6.1: + // Clients MUST treat receipt of a server-initiated bidirectional + // stream as a connection error of type H3_STREAM_CREATION_ERROR + // [ unless such an extension has been negotiated]. + // We don't support any extension, so this is a connection error. + close(Http3Error.H3_STREAM_CREATION_ERROR, + "Bidirectional stream %s opened by server peer" + .formatted(bidi.streamId())); + return false; + } + + /** + * Called if the dispatch failed. + * + * @param reason the reason of the failure + */ + protected void dispatchingFailed(QuicReceiverStream uni, Throwable reason) { + debug.log("dispatching failed for streamId=%s: %s", uni.streamId(), reason); + close(H3_STREAM_CREATION_ERROR, "failed to dispatch remote stream " + uni.streamId(), reason); + } + + + /** + * Schedules sending of client settings. + * + * @return a completable future that will be completed with the + * {@link QuicStreamWriter} allowing to write to the local control + * stream + */ + QuicStreamWriter sendSettings(QuicStreamWriter writer) { + try { + final SettingsFrame settings = QPACK.updateDecoderSettings(SettingsFrame.defaultRFCSettings()); + this.ourSettings = ConnectionSettings.createFrom(settings); + this.qpackDecoder.configure(ourSettings); + if (debug.on()) { + debug.log("Sending client settings %s for connection %s", this.ourSettings, this); + } + long size = settings.size(); + assert size >= 0 && size < Integer.MAX_VALUE; + var buf = ByteBuffer.allocate((int) size); + settings.writeFrame(buf); + buf.flip(); + writer.scheduleForWriting(buf, false); + return writer; + } catch (IOException io) { + throw new CompletionException(io); + } + } + + /** + * Schedules sending of max push id that this (client) connection allows. + * + * @param writer the control stream writer + * @return the {@link QuicStreamWriter} passed as parameter + */ + private QuicStreamWriter sendMaxPushId(QuicStreamWriter writer) { + try { + long maxPushId = pushManager.getMaxPushId(); + if (maxPushId > 0 && maxPushId > maxPushIdSent.get()) { + return sendMaxPushId(writer, maxPushId); + } else { + return writer; + } + } catch (IOException io) { + // will wrap the io exception in CompletionException, + // close the connection, and throw. + throw new CompletionException(io); + } + } + + // local control stream write loop + void lcsWriterLoop() { + // since we do not write much data on the control stream + // we don't check for credit and always directly buffer + // the data to send in the writer. Therefore, there is + // nothing to do in the control stream writer loop. + // + // When more credit is available, check if we need + // to send maxpushid; + if (maxPushIdSent.get() < pushManager.getMaxPushId()) { + var writer = controlStreamPair.localWriter(); + if (writer != null && writer.connected()) { + sendMaxPushId(writer); + } + } + } + + void controlStreamFailed(final QuicStream stream, final UniStreamPair uniStreamPair, + final Throwable throwable) { + Http3Streams.debugErrorCode(debug, stream, "Control stream failed"); + if (stream.state() instanceof QuicReceiverStream.ReceivingStreamState rcvrStrmState) { + if (rcvrStrmState.isReset() && quicConnection.isOpen()) { + // RFC-9114, section 6.2.1: + // If either control stream is closed at any point, + // this MUST be treated as a connection error of type H3_CLOSED_CRITICAL_STREAM. + final String logMsg = "control stream " + stream.streamId() + + " was reset"; + close(H3_CLOSED_CRITICAL_STREAM, logMsg); + return; + } + } + if (isOpen()) { + if (debug.on()) { + debug.log("closing connection since control stream " + stream.mode() + + " failed", throwable); + } + } + close(throwable); + } + + /** + * This method is called to process bytes received on the peer + * control stream. + * + * @param buffer the bytes received + */ + private void processPeerControlBytes(final ByteBuffer buffer) { + debug.log("received server control: %s bytes", buffer.remaining()); + controlFramesDecoder.submit(buffer); + Http3Frame frame; + while ((frame = controlFramesDecoder.poll()) != null) { + final long frameType = frame.type(); + debug.log("server control frame: %s", Http3FrameType.asString(frameType)); + if (frame instanceof MalformedFrame malformed) { + var cause = malformed.getCause(); + if (cause != null && debug.on()) { + debug.log(malformed.toString(), cause); + } + final Http3Error error = Http3Error.fromCode(malformed.getErrorCode()) + .orElse(H3_INTERNAL_ERROR); + close(error, malformed.getMessage()); + controlStreamPair.stopSchedulers(); + controlFramesDecoder.clear(); + return; + } + final boolean settingsRcvd = this.settingsFrameReceived; + if ((frameType == SettingsFrame.TYPE && settingsRcvd) + || !this.frameOrderVerifier.allowsProcessing(frame)) { + final String unexpectedFrameType = Http3FrameType.asString(frameType); + // not expected to be arriving now, we either use H3_FRAME_UNEXPECTED + // or H3_MISSING_SETTINGS for the connection error, depending on the context. + // + // RFC-9114, section 4.1: Receipt of an invalid sequence of frames MUST be + // treated as a connection error of type H3_FRAME_UNEXPECTED. + // + // RFC-9114, section 6.2.1: If the first frame of the control stream + // is any other frame type, this MUST be treated as a connection error of + // type H3_MISSING_SETTINGS. + final String logMsg = "unexpected (order of) frame type: " + unexpectedFrameType + + " on control stream"; + if (!settingsRcvd) { + close(Http3Error.H3_MISSING_SETTINGS, logMsg); + } else { + close(Http3Error.H3_FRAME_UNEXPECTED, logMsg); + } + controlStreamPair.stopSchedulers(); + controlFramesDecoder.clear(); + return; + } + if (frame instanceof SettingsFrame settingsFrame) { + this.settingsFrameReceived = true; + this.peerSettings = ConnectionSettings.createFrom(settingsFrame); + if (debug.on()) { + debug.log("Received peer settings %s for connection %s", this.peerSettings, this); + } + peerSettingsCF.completeAsync(() -> peerSettings, + client.client().theExecutor().safeDelegate()); + // We can only initialize encoder's DT only when we get Settings frame with all parameters + qpackEncoder().configure(peerSettings); + } + if (frame instanceof CancelPushFrame cancelPush) { + pushManager.cancelPushPromise(cancelPush.getPushId(), null, CancelPushReason.CANCEL_RECEIVED); + } + if (frame instanceof GoAwayFrame goaway) { + handleIncomingGoAway(goaway); + } + if (frame instanceof PartialFrame partial) { + var payloadBytes = controlFramesDecoder.readPayloadBytes(); + debug.log("added %s bytes to %s", + payloadBytes == null ? 0 : Utils.remaining(payloadBytes), + frame); + if (partial.remaining() == 0) { + this.frameOrderVerifier.completed(frame); + } else if (payloadBytes == null || payloadBytes.isEmpty()) { + break; + } + // only reserved frames reach here; just drop the payload + } else { + this.frameOrderVerifier.completed(frame); + } + if (controlFramesDecoder.eof()) { + break; + } + } + if (controlFramesDecoder.eof()) { + close(H3_CLOSED_CRITICAL_STREAM, "EOF reached while reading server control stream"); + } + } + + /** + * Called when a new push promise stream is created by the peer. + * + * @apiNote this method gives an opportunity to cancel the stream + * before reading the pushId, if it is known that no push + * will be accepted anyway. + * + * @param pushStream the new push promise stream + * @param pushId or -1 if the pushId is not available yet + */ + private void onPushStreamCreated(QuicReceiverStream pushStream, long pushId) { + assert pushStream.isRemoteInitiated(); + assert !pushStream.isBidirectional(); + + onPushPromiseStream(pushStream, pushId); + } + + /** + * Called when a new push promise stream is created by the peer, and + * the pushId has been read. + * + * @param pushStream the new push promise stream + * @param pushId the pushId + */ + void onPushPromiseStream(QuicReceiverStream pushStream, long pushId) { + assert pushId >= 0; + pushManager.onPushPromiseStream(pushStream, pushId); + } + + /** + * This method is called by the {@link Http3PushManager} to figure out whether + * a push stream or a push promise should be processed, with respect to the + * GOAWAY state. Any pushId larger than what was sent in the GOAWAY frame + * should be cancelled /rejected. + * + * @param pushStream a push stream (may be null if not yet materialized) + * @param pushId a pushId, must be > 0 + * @return true if the pushId can be processed + */ + boolean acceptLargerPushPromise(QuicReceiverStream pushStream, long pushId) { + // if GOAWAY has been sent, just cancel the push promise + // otherwise - track this as the maxPushId that will be + // sent in GOAWAY + if (checkMaxPushId(pushId) != null) return false; // connection will be closed + while (true) { + long largestPushId = this.largestPushId.get(); + if ((closedState & GOAWAY_SENT) == GOAWAY_SENT) { + if (pushId >= largestPushId) { + if (pushStream != null) { + pushStream.requestStopSending(H3_NO_ERROR.code()); + } + pushManager.cancelPushPromise(pushId, null, CancelPushReason.PUSH_CANCELLED); + return false; + } + } + if (pushId <= largestPushId) break; + if (!this.largestPushId.compareAndSet(largestPushId, pushId)) continue; + if ((closedState & GOAWAY_SENT) == 0) break; + } + // If we reach here, then either GOAWAY has been sent with a largestPushId >= pushId, + // or GOAWAY has not been sent yet. + return true; + } + + QueuingStreamPair createEncoderStreams(Consumer encoderReceiver) { + return new QueuingStreamPair(StreamType.QPACK_ENCODER, quicConnection, + encoderReceiver, this::onEncoderStreamsFailed, debug); + } + + private void onEncoderStreamsFailed(final QuicStream stream, final UniStreamPair uniStreamPair, + final Throwable throwable) { + Http3Streams.debugErrorCode(debug, stream, "Encoder stream failed"); + if (stream.state() instanceof QuicReceiverStream.ReceivingStreamState rcvrStrmState) { + if (rcvrStrmState.isReset() && quicConnection.isOpen()) { + // RFC-9204, section 4.2: + // Closure of either unidirectional stream type MUST be treated as a connection + // error of type H3_CLOSED_CRITICAL_STREAM. + final String logMsg = "QPACK encoder stream " + stream.streamId() + + " was reset"; + close(H3_CLOSED_CRITICAL_STREAM, logMsg); + return; + } + } + if (isOpen()) { + if (debug.on()) { + debug.log("closing connection since QPack encoder stream " + stream.streamId() + + " failed", throwable); + } + } + close(throwable); + } + + QueuingStreamPair createDecoderStreams(Consumer encoderReceiver) { + return new QueuingStreamPair(StreamType.QPACK_DECODER, quicConnection, + encoderReceiver, this::onDecoderStreamsFailed, debug); + } + + private void onDecoderStreamsFailed(final QuicStream stream, final UniStreamPair uniStreamPair, + final Throwable throwable) { + Http3Streams.debugErrorCode(debug, stream, "Decoder stream failed"); + if (stream.state() instanceof QuicReceiverStream.ReceivingStreamState rcvrStrmState) { + if (rcvrStrmState.isReset() && quicConnection.isOpen()) { + // RFC-9204, section 4.2: + // Closure of either unidirectional stream type MUST be treated as a connection + // error of type H3_CLOSED_CRITICAL_STREAM. + final String logMsg = "QPACK decoder stream " + stream.streamId() + + " was reset"; + close(H3_CLOSED_CRITICAL_STREAM, logMsg); + return; + } + } + if (isOpen()) { + if (debug.on()) { + debug.log("closing connection since QPack decoder stream " + stream.streamId() + + " failed", throwable); + } + } + close(throwable); + } + + // This method never returns anything: it always throws + private T exceptionallyAndClose(Throwable t) { + try { + return exceptionally(t); + } finally { + close(t); + } + } + + // This method never returns anything: it always throws + private T exceptionally(Throwable t) { + try { + debug.log(t.getMessage(), t); + throw t; + } catch (RuntimeException | Error r) { + throw r; + } catch (ExecutionException x) { + throw new CompletionException(x.getMessage(), x.getCause()); + } catch (Throwable e) { + throw new CompletionException(e.getMessage(), e); + } + } + + Decoder qpackDecoder() { + return qpackDecoder; + } + + Encoder qpackEncoder() { + return qpackEncoder; + } + + /** + * {@return the settings, sent by the peer, for this connection. If none is present, due to the SETTINGS + * frame not yet arriving from the peer, this method returns {@link Optional#empty()}} + */ + Optional getPeerSettings() { + return Optional.ofNullable(this.peerSettings); + } + + private void handleIncomingGoAway(final GoAwayFrame incomingGoAway) { + final long quicStreamId = incomingGoAway.getTargetId(); + if (debug.on()) { + debug.log("Received GOAWAY %s", incomingGoAway); + } + // ensure request stream id is a bidirectional stream originating from the client. + // RFC-9114, section 7.2.6: A client MUST treat receipt of a GOAWAY frame containing + // a stream ID of any other type as a connection error of type H3_ID_ERROR. + if (!(QuicStreams.isClientInitiated(quicStreamId) + && QuicStreams.isBidirectional(quicStreamId))) { + close(Http3Error.H3_ID_ERROR, "Invalid stream id in GOAWAY frame"); + return; + } + boolean validStreamId = false; + long current = lowestGoAwayReceipt.get(); + while (current == -1 || quicStreamId <= current) { + if (lowestGoAwayReceipt.compareAndSet(current, quicStreamId)) { + validStreamId = true; + break; + } + current = lowestGoAwayReceipt.get(); + } + if (!validStreamId) { + // the request stream id received in the GOAWAY frame is greater than the one received + // in some previous GOAWAY frame. This isn't allowed by spec. + // RFC-9114, section 5.2: An endpoint MAY send multiple GOAWAY frames indicating + // different identifiers, but the identifier in each frame MUST NOT be greater than + // the identifier in any previous frame, ... Receiving a GOAWAY containing a larger + // identifier than previously received MUST be treated as a connection error of + // type H3_ID_ERROR. + close(Http3Error.H3_ID_ERROR, "Invalid stream id in newer GOAWAY frame"); + return; + } + markReceivedGoAway(); + // mark a state on this connection to let it know that no new streams are allowed on this + // connection. + // RFC-9114, section 5.2: Endpoints MUST NOT initiate new requests or promise new pushes on + // the connection after receipt of a GOAWAY frame from the peer. + setFinalStream(); + if (debug.on()) { + debug.log("Connection will no longer allow new streams due to receipt of GOAWAY" + + " from peer"); + } + handlePeerUnprocessedStreams(quicStreamId); + if (finalStreamClosed()) { + close(Http3Error.H3_NO_ERROR, "GOAWAY received"); + } + } + + private void handlePeerUnprocessedStreams(final long leastUnprocessedStreamId) { + this.exchanges.forEach((id, exchange) -> { + if (id >= leastUnprocessedStreamId) { + // close the exchange as unprocessed + client.client().theExecutor().execute(exchange::closeAsUnprocessed); + } + }); + } + + private boolean isMarked(int state, int mask) { + return (state & mask) == mask; + } + + private boolean markSentGoAway() { + return markClosedState(GOAWAY_SENT); + } + + private boolean markReceivedGoAway() { + return markClosedState(GOAWAY_RECEIVED); + } + + private boolean markClosedState(int flag) { + int state, desired; + do { + state = closedState; + if ((state & flag) == flag) return false; + desired = state | flag; + } while (!CLOSED_STATE.compareAndSet(this, state, desired)); + return true; + } + + String describeClosedState(int state) { + if (state == 0) return "active"; + String desc = null; + if (isMarked(state, GOAWAY_SENT)) { + if (desc == null) desc = "goaway-sent"; + else desc += "+goaway-sent"; + } + if (isMarked(state, GOAWAY_RECEIVED)) { + if (desc == null) desc = "goaway-rcvd"; + else desc += "+goaway-rcvd"; + } + if (isMarked(state, CLOSED)) { + if (desc == null) desc = "quic-closed"; + else desc += "+quic-closed"; + } + return desc != null ? desc : "0x" + Integer.toHexString(state); + } + + // PushPromise handling + // ==================== + + /** + * {@return a new PushId for the given pushId} + * @param pushId the pushId + */ + PushId newPushId(long pushId) { + return new Http3PushId(pushId, connection.label()); + } + + /** + * Called when a pushId needs to be cancelled. + * @param pushId the pushId to cancel + * @param cause the cause (may be {@code null}). + */ + void pushCancelled(long pushId, Throwable cause) { + pushManager.cancelPushPromise(pushId, cause, CancelPushReason.PUSH_CANCELLED); + } + + /** + * Called if a PushPromiseFrame is received by an exchange that doesn't have any + * {@link java.net.http.HttpResponse.PushPromiseHandler}. The pushId will be + * cancelled, unless it's already been accepted by another exchange. + * + * @param pushId the pushId + */ + void noPushHandlerFor(long pushId) { + pushManager.cancelPushPromise(pushId, null, CancelPushReason.NO_HANDLER); + } + + boolean acceptPromises() { + return exchanges.values().stream().anyMatch(Http3ExchangeImpl::acceptPushPromise); + } + + /** + * {@return a completable future that will be completed when a pushId has been + * accepted by the exchange in charge of creating the response body} + *

+ * The completable future is complete with {@code true} if the pushId is + * accepted, and with {@code false} if the pushId was rejected or cancelled. + * + * @apiNote + * This method is intended to be called when {@link + * #onPushPromiseFrame(Http3ExchangeImpl, long, HttpHeaders)}, returns false, + * indicating that the push promise is being delegated to another request/response + * exchange. + * On completion of the future returned here, if the future is completed + * with {@code true}, the caller is expected to call {@link + * PushGroup#acceptPushPromiseId(PushId)} in order to notify the {@link + * java.net.http.HttpResponse.PushPromiseHandler} of the received {@code pushId}. + *

+ * Callers should not forward the pushId to a {@link + * java.net.http.HttpResponse.PushPromiseHandler} unless the future is completed + * with {@code true} + * + * @param pushId the pushId + */ + CompletableFuture whenPushAccepted(long pushId) { + return pushManager.whenAccepted(pushId); + } + + /** + * Called when a PushPromiseFrame has been decoded. + * + * @param exchange The HTTP/3 exchange that received the frame + * @param pushId The pushId contained in the frame + * @param promiseHeaders The push promise headers contained in the frame + * @return true if the exchange should take care of creating the HttpResponse body, + * false otherwise + */ + boolean onPushPromiseFrame(Http3ExchangeImpl exchange, long pushId, HttpHeaders promiseHeaders) + throws IOException { + return pushManager.onPushPromiseFrame(exchange, pushId, promiseHeaders); + } + + /** + * Checks whether a MAX_PUSH_ID frame should be sent. + */ + void checkSendMaxPushId() { + pushManager.checkSendMaxPushId(); + } + + /** + * Schedules sending of max push id that this (client) connection allows. + * + * @return a completable future that will be completed with the + * {@link QuicStreamWriter} allowing to write to the local control + * stream + */ + private QuicStreamWriter sendMaxPushId(QuicStreamWriter writer, long maxPushId) throws IOException { + debug.log("Sending max push id frame with max push id set to " + maxPushId); + final MaxPushIdFrame maxPushIdFrame = new MaxPushIdFrame(maxPushId); + final long frameSize = maxPushIdFrame.size(); + assert frameSize >= 0 && frameSize < Integer.MAX_VALUE; + final ByteBuffer buf = ByteBuffer.allocate((int) frameSize); + maxPushIdFrame.writeFrame(buf); + buf.flip(); + if (writer.credit() > buf.remaining()) { + long previous; + do { + previous = maxPushIdSent.get(); + if (previous >= maxPushId) return writer; + } while (!maxPushIdSent.compareAndSet(previous, maxPushId)); + writer.scheduleForWriting(buf, false); + } + return writer; + } + + /** + * Send a MAX_PUSH_ID frame on the control stream with the given {@code maxPushId} + * + * @param maxPushId the new maxPushId + * + * @throws IOException if the pushId could not be sent + */ + void sendMaxPushId(long maxPushId) throws IOException { + sendMaxPushId(controlStreamPair.localWriter(), maxPushId); + } + + /** + * Sends a CANCEL_PUSH frame for the given {@code pushId}. + * If not null, the cause may indicate why the push is cancelled. + * + * @apiNote the cause is only used for logging + * + * @param pushId the pushId to cancel + * @param cause the reason for cancelling, may be {@code null} + */ + void sendCancelPush(long pushId, Throwable cause) { + // send CANCEL_PUSH frame here + if (debug.on()) { + if (cause != null) { + debug.log("Push Promise %s cancelled: %s", pushId, cause.getMessage()); + } else { + debug.log("Push Promise %s cancelled", pushId); + } + } + try { + CancelPushFrame cancelPush = new CancelPushFrame(pushId); + long size = cancelPush.size(); + // frame should contain type, length, pushId + assert size <= 3 * VariableLengthEncoder.MAX_INTEGER_LENGTH; + ByteBuffer buffer = ByteBuffer.allocate((int) size); + cancelPush.writeFrame(buffer); + controlStreamPair.localWriter().scheduleForWriting(buffer, false); + } catch (IOException io) { + debug.log("Failed to cancel pushId: " + pushId); + } + } + + /** + * Checks whether the given pushId exceed the maximum pushId allowed + * to the peer, and if so, closes the connection. + * + * @param pushId the pushId + * @return an {@code IOException} that can be used to complete a completable + * future if the maximum pushId is exceeded, {@code null} + * otherwise + */ + IOException checkMaxPushId(long pushId) { + return checkMaxPushId(pushId, maxPushIdSent.get()); + } + + /** + * Checks whether the given pushId exceed the maximum pushId allowed + * to the peer, and if so, closes the connection. + * + * @param pushId the pushId + * @return an {@code IOException} that can be used to complete a completable + * future if the maximum pushId is exceeded, {@code null} + * otherwise + */ + private IOException checkMaxPushId(long pushId, long max) { + if (pushId >= max) { + var io = new ProtocolException("Max pushId exceeded (%s >= %s)".formatted(pushId, max)); + connectionError(io, Http3Error.H3_ID_ERROR); + return io; + } + return null; + } + + /** + * {@return the minimum pushId that can be accepted from the peer} + * Any pushId strictly less than this value must be ignored. + * + * @apiNote The minimum pushId represents the smallest pushId that + * was recorded in our history. For smaller pushId, no history has + * been kept, due to history size constraints. Any pushId strictly + * less than this value must be ignored. + */ + public long getMinPushId() { + return pushManager.getMinPushId(); + } + + private static final VarHandle CLOSED_STATE; + static { + try { + CLOSED_STATE = MethodHandles.lookup().findVarHandle(Http3Connection.class, "closedState", int.class); + } catch (Exception x) { + throw new ExceptionInInitializerError(x); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3ConnectionPool.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ConnectionPool.java new file mode 100644 index 00000000000..eaacd8213cb --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ConnectionPool.java @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import jdk.internal.net.http.common.Logger; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; + +/** + * This class encapsulate the HTTP/3 connection pool managed + * by an instance of {@link Http3ClientImpl}. + */ +class Http3ConnectionPool { + /* Map key is "scheme:host:port" */ + private final Map advertised = new ConcurrentHashMap<>(); + /* Map key is "scheme:host:port" */ + private final Map unadvertised = new ConcurrentHashMap<>(); + + private final Logger debug; + Http3ConnectionPool(Logger logger) { + this.debug = Objects.requireNonNull(logger); + } + + // https:: + String connectionKey(HttpRequestImpl request) { + var uri = request.uri(); + var scheme = uri.getScheme().toLowerCase(Locale.ROOT); + var host = uri.getHost(); + var port = uri.getPort(); + assert scheme.equals("https"); + if (port < 0) port = 443; // https + return String.format("%s:%s:%d", scheme, host, port); + } + + private Http3Connection lookupUnadvertised(String key, Http3DiscoveryMode discoveryMode) { + var unadvertisedConn = unadvertised.get(key); + if (unadvertisedConn == null) return null; + if (discoveryMode == ANY) return unadvertisedConn; + if (discoveryMode == ALT_SVC) return null; + + assert discoveryMode == HTTP_3_URI_ONLY : String.valueOf(discoveryMode); + + // Double check that if there is an alt service, it has same origin. + final var altService = Optional.ofNullable(unadvertisedConn) + .map(Http3Connection::connection) + .flatMap(HttpQuicConnection::getSourceAltService) + .orElse(null); + + if (altService == null || altService.originHasSameAuthority()) { + return unadvertisedConn; + } + + // We should never come here. + assert false : "unadvertised connection with different origin: %s -> %s" + .formatted(key, altService); + return null; + } + + Http3Connection lookupFor(HttpRequestImpl request) { + var discoveryMode = request.http3Discovery(); + var key = connectionKey(request); + + Http3Connection unadvertisedConn = null; + // If not ALT_SVC, we can use unadvertised connections + if (discoveryMode != ALT_SVC) { + unadvertisedConn = lookupUnadvertised(key, discoveryMode); + if (unadvertisedConn != null && discoveryMode == HTTP_3_URI_ONLY) { + if (debug.on()) { + debug.log("Direct HTTP/3 connection found for %s in connection pool %s", + discoveryMode, unadvertisedConn.connection().label()); + } + return unadvertisedConn; + } + } + + // Then see if we have a connection which was advertised. + var advertisedConn = advertised.get(key); + // We can use it for HTTP3_URI_ONLY too if it has same origin + if (advertisedConn != null) { + final var altService = advertisedConn.connection() + .getSourceAltService().orElse(null); + assert altService != null && altService.wasAdvertised(); + switch (discoveryMode) { + case ANY -> { + return advertisedConn; + } + case ALT_SVC -> { + if (debug.on()) { + debug.log("HTTP/3 connection found for %s in connection pool %s", + discoveryMode, advertisedConn.connection().label()); + } + return advertisedConn; + } + case HTTP_3_URI_ONLY -> { + if (altService != null && altService.originHasSameAuthority()) { + if (debug.on()) { + debug.log("Same authority HTTP/3 connection found for %s in connection pool %s", + discoveryMode, advertisedConn.connection().label()); + } + return advertisedConn; + } + } + } + } + + if (unadvertisedConn != null) { + assert discoveryMode != ALT_SVC; + if (debug.on()) { + debug.log("Direct HTTP/3 connection found for %s in connection pool %s", + discoveryMode, unadvertisedConn.connection().label()); + } + return unadvertisedConn; + } + + // do not log here: this produces confusing logs as this method + // can be called several times when trying to establish a + // connection, when no connection is found in the pool + return null; + } + + Http3Connection putIfAbsent(String key, Http3Connection c) { + Objects.requireNonNull(key); + Objects.requireNonNull(c); + assert key.equals(c.key()); + var altService = c.connection().getSourceAltService().orElse(null); + if (altService != null && altService.wasAdvertised()) { + return advertised.putIfAbsent(key, c); + } + assert altService == null || altService.originHasSameAuthority(); + return unadvertised.putIfAbsent(key, c); + } + + Http3Connection put(String key, Http3Connection c) { + Objects.requireNonNull(key); + Objects.requireNonNull(c); + assert key.equals(c.key()) : "key mismatch %s -> %s" + .formatted(key, c.key()); + var altService = c.connection().getSourceAltService().orElse(null); + if (altService != null && altService.wasAdvertised()) { + return advertised.put(key, c); + } + assert altService == null || altService.originHasSameAuthority(); + return unadvertised.put(key, c); + } + + boolean remove(String key, Http3Connection c) { + Objects.requireNonNull(key); + Objects.requireNonNull(c); + assert key.equals(c.key()) : "key mismatch %s -> %s" + .formatted(key, c.key()); + + var altService = c.connection().getSourceAltService().orElse(null); + if (altService != null && altService.wasAdvertised()) { + boolean remUndavertised = unadvertised.remove(key, c); + assert !remUndavertised + : "advertised connection found in unadvertised pool for " + key; + return advertised.remove(key, c); + } + + assert altService == null || altService.originHasSameAuthority(); + return unadvertised.remove(key, c); + } + + void clear() { + advertised.clear(); + unadvertised.clear(); + } + + java.util.stream.Stream values() { + return java.util.stream.Stream.concat( + advertised.values().stream(), + unadvertised.values().stream()); + } + +} + diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3ExchangeImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ExchangeImpl.java new file mode 100644 index 00000000000..ff1e024673c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3ExchangeImpl.java @@ -0,0 +1,1795 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.io.EOFException; +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.net.ProtocolException; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest.BodyPublisher; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.ResponseInfo; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.Flow.Subscription; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiPredicate; + +import jdk.internal.net.http.PushGroup.Acceptor; +import jdk.internal.net.http.common.HttpBodySubscriberWrapper; +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.SubscriptionBase; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.common.ValidatingHeadersConsumer; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.DataFrame; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.PushPromiseFrame; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.qpack.writers.HeaderFrameWriter; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; +import static jdk.internal.net.http.http3.ConnectionSettings.UNLIMITED_MAX_FIELD_SECTION_SIZE; + +/** + * This class represents an HTTP/3 Request/Response stream. + */ +final class Http3ExchangeImpl extends Http3Stream { + + private static final String COOKIE_HEADER = "Cookie"; + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + private final Http3Connection connection; + private final HttpRequestImpl request; + private final BodyPublisher requestPublisher; + private final HttpHeadersBuilder responseHeadersBuilder; + private final HeadersConsumer rspHeadersConsumer; + private final HttpHeaders requestPseudoHeaders; + private final HeaderFrameReader headerFrameReader; + private final HeaderFrameWriter headerFrameWriter; + private final Decoder qpackDecoder; + private final Encoder qpackEncoder; + private final AtomicReference errorRef; + private final CompletableFuture requestBodyCF; + + private final FramesDecoder framesDecoder = + new FramesDecoder(this::dbgTag, FramesDecoder::isAllowedOnRequestStream); + private final SequentialScheduler readScheduler = + SequentialScheduler.lockingScheduler(this::processQuicData); + private final SequentialScheduler writeScheduler = + SequentialScheduler.lockingScheduler(this::sendQuicData); + private final List> response_cfs = new ArrayList<>(5); + private final ReentrantLock stateLock = new ReentrantLock(); + private final ReentrantLock response_cfs_lock = new ReentrantLock(); + private final H3FrameOrderVerifier frameOrderVerifier = H3FrameOrderVerifier.newForRequestResponseStream(); + + + final SubscriptionBase userSubscription = + new SubscriptionBase(readScheduler, this::cancel, this::onSubscriptionError); + + private final QuicBidiStream stream; + private final QuicStreamReader reader; + private final QuicStreamWriter writer; + volatile boolean closed; + volatile RequestSubscriber requestSubscriber; + volatile HttpResponse.BodySubscriber pendingResponseSubscriber; + volatile HttpResponse.BodySubscriber responseSubscriber; + volatile CompletableFuture responseBodyCF; + volatile boolean requestSent; + volatile boolean responseReceived; + volatile long requestContentLen; + volatile int responseCode; + volatile Response response; + volatile boolean stopRequested; + volatile boolean deRegistered; + private String dbgTag = null; + private final AtomicLong sentQuicBytes = new AtomicLong(); + + Http3ExchangeImpl(final Http3Connection connection, final Exchange exchange, + final QuicBidiStream stream) { + super(exchange); + this.errorRef = new AtomicReference<>(); + this.requestBodyCF = new MinimalFuture<>(); + this.connection = connection; + this.request = exchange.request(); + this.requestPublisher = request.requestPublisher; // may be null + this.responseHeadersBuilder = new HttpHeadersBuilder(); + this.rspHeadersConsumer = new HeadersConsumer(ValidatingHeadersConsumer.Context.RESPONSE); + this.qpackDecoder = connection.qpackDecoder(); + this.qpackEncoder = connection.qpackEncoder(); + this.headerFrameReader = qpackDecoder.newHeaderFrameReader(rspHeadersConsumer); + this.headerFrameWriter = qpackEncoder.newHeaderFrameWriter(); + this.requestPseudoHeaders = Utils.createPseudoHeaders(request); + this.stream = stream; + this.reader = stream.connectReader(readScheduler); + this.writer = stream.connectWriter(writeScheduler); + if (debug.on()) debug.log("Http3ExchangeImpl created"); + } + + public void start() { + if (exchange.pushGroup != null) { + connection.checkSendMaxPushId(); + } + if (Log.http3()) { + Log.logHttp3("Starting HTTP/3 exchange for {0}/streamId={1} ({2} #{3})", + connection.quicConnectionTag(), Long.toString(stream.streamId()), + request, Long.toString(exchange.multi.id)); + } + this.reader.start(); + } + + boolean acceptPushPromise() { + return exchange.pushGroup != null; + } + + String dbgTag() { + if (dbgTag != null) return dbgTag; + long streamId = streamId(); + String sid = streamId == -1 ? "?" : String.valueOf(streamId); + String ctag = connection == null ? null : connection.dbgTag(); + String tag = "Http3ExchangeImpl(" + ctag + ", streamId=" + sid + ")"; + if (streamId == -1) return tag; + return dbgTag = tag; + } + + @Override + long streamId() { + var stream = this.stream; + return stream == null ? -1 : stream.streamId(); + } + + Http3Connection http3Connection() { + return connection; + } + + void recordError(Throwable closeCause) { + errorRef.compareAndSet(null, closeCause); + } + + private sealed class HeadersConsumer extends StreamHeadersConsumer permits PushHeadersConsumer { + + private HeadersConsumer(Context context) { + super(context); + } + + @Override + protected HeaderFrameReader headerFrameReader() { + return headerFrameReader; + } + + @Override + protected HttpHeadersBuilder headersBuilder() { + return responseHeadersBuilder; + } + + @Override + protected final Decoder qpackDecoder() { + return qpackDecoder; + } + + void resetDone() { + if (debug.on()) { + debug.log("Response builder cleared, ready to receive new headers."); + } + } + + + @Override + String headerFieldType() { + return "RESPONSE HEADER FIELD"; + } + + @Override + protected String formatMessage(String message, String header) { + // Malformed requests or responses that are detected MUST be + // treated as a stream error of type H3_MESSAGE_ERROR. + return "malformed response: " + super.formatMessage(message, header); + } + + @Override + protected void headersCompleted() { + handleResponse(); + } + + @Override + public final long streamId() { + return stream.streamId(); + } + + } + + private final class PushHeadersConsumer extends HeadersConsumer { + volatile PushPromiseState state; + + private PushHeadersConsumer() { + super(Context.REQUEST); + } + + @Override + protected HttpHeadersBuilder headersBuilder() { + return state.headersBuilder(); + } + + @Override + protected HeaderFrameReader headerFrameReader() { + return state.reader(); + } + + @Override + String headerFieldType() { + return "PUSH REQUEST HEADER FIELD"; + } + + void resetDone() { + if (debug.on()) { + debug.log("Push request builder cleared."); + } + } + + @Override + protected String formatMessage(String message, String header) { + // Malformed requests or responses that are detected MUST be + // treated as a stream error of type H3_MESSAGE_ERROR. + return "malformed push request: " + super.formatMessage(message, header); + } + + @Override + protected void headersCompleted() { + try { + if (exchange.pushGroup == null) { + long pushId = state.frame().getPushId(); + connection.noPushHandlerFor(pushId); + reset(); + } else { + handlePromise(this); + } + } catch (IOException io) { + cancelPushPromise(state, io); + } + } + + public void setState(PushPromiseState state) { + this.state = state; + } + } + + // TODO: this is also defined on Stream + // + private static boolean hasProxyAuthorization(HttpHeaders headers) { + return headers.firstValue("proxy-authorization") + .isPresent(); + } + + // TODO: this is also defined on Stream + // + // Determines whether we need to build a new HttpHeader object. + // + // Ideally we should pass the filter to OutgoingHeaders refactor the + // code that creates the HeaderFrame to honor the filter. + // We're not there yet - so depending on the filter we need to + // apply and the content of the header we will try to determine + // whether anything might need to be filtered. + // If nothing needs filtering then we can just use the + // original headers. + private static boolean needsFiltering(HttpHeaders headers, + BiPredicate filter) { + if (filter == Utils.PROXY_TUNNEL_FILTER || filter == Utils.PROXY_FILTER) { + // we're either connecting or proxying + // slight optimization: we only need to filter out + // disabled schemes, so if there are none just + // pass through. + return Utils.proxyHasDisabledSchemes(filter == Utils.PROXY_TUNNEL_FILTER) + && hasProxyAuthorization(headers); + } else { + // we're talking to a server, either directly or through + // a tunnel. + // Slight optimization: we only need to filter out + // proxy authorization headers, so if there are none just + // pass through. + return hasProxyAuthorization(headers); + } + } + + // TODO: this is also defined on Stream + // + private HttpHeaders filterHeaders(HttpHeaders headers) { + HttpConnection conn = connection(); + BiPredicate filter = conn.headerFilter(request); + if (needsFiltering(headers, filter)) { + return HttpHeaders.of(headers.map(), filter); + } + return headers; + } + + @Override + HttpQuicConnection connection() { + return connection.connection(); + } + + @Override + CompletableFuture> sendHeadersAsync() { + final MinimalFuture completable = MinimalFuture.completedFuture(null); + return completable.thenApply(_ -> this.sendHeaders()); + } + + private Http3ExchangeImpl sendHeaders() { + assert stream != null; + assert writer != null; + + if (debug.on()) debug.log("H3 sendHeaders"); + if (Log.requests()) { + Log.logRequest(request.toString()); + } + if (requestPublisher != null) { + requestContentLen = requestPublisher.contentLength(); + } else { + requestContentLen = 0; + } + + Throwable t = errorRef.get(); + if (t != null) { + if (debug.on()) debug.log("H3 stream already cancelled, headers not sent: %s", (Object) t); + if (t instanceof CompletionException ce) throw ce; + throw new CompletionException(t); + } + + HttpHeadersBuilder h = request.getSystemHeadersBuilder(); + if (requestContentLen > 0) { + h.setHeader("content-length", Long.toString(requestContentLen)); + } + HttpHeaders sysh = filterHeaders(h.build()); + HttpHeaders userh = filterHeaders(request.getUserHeaders()); + // Filter context restricted from userHeaders + userh = HttpHeaders.of(userh.map(), Utils.ACCEPT_ALL); + Utils.setUserAuthFlags(request, userh); + + // Don't override Cookie values that have been set by the CookieHandler. + final HttpHeaders uh = userh; + BiPredicate overrides = + (k, v) -> COOKIE_HEADER.equalsIgnoreCase(k) + || uh.firstValue(k).isEmpty(); + + // Filter any headers from systemHeaders that are set in userHeaders + // except for "Cookie:" - user cookies will be appended to system + // cookies + sysh = HttpHeaders.of(sysh.map(), overrides); + + if (Log.headers() || debug.on()) { + StringBuilder sb = new StringBuilder("H3 HEADERS FRAME (stream="); + sb.append(streamId()).append(")\n"); + Log.dumpHeaders(sb, " ", requestPseudoHeaders); + Log.dumpHeaders(sb, " ", sysh); + Log.dumpHeaders(sb, " ", userh); + if (Log.headers()) { + Log.logHeaders(sb.toString()); + } else if (debug.on()) { + debug.log(sb); + } + } + + final Optional peerSettings = connection.getPeerSettings(); + // It's possible that the peer settings hasn't yet arrived, in which case we use the + // default of "unlimited" header size limit and proceed with sending the request. As per + // RFC-9114, section 7.2.4.2, this is allowed: All settings begin at an initial value. Each + // endpoint SHOULD use these initial values to send messages before the peer's SETTINGS frame + // has arrived, as packets carrying the settings can be lost or delayed. + // When the SETTINGS frame arrives, any settings are changed to their new values. This + // removes the need to wait for the SETTINGS frame before sending messages. + final long headerSizeLimit = peerSettings.isEmpty() ? UNLIMITED_MAX_FIELD_SECTION_SIZE + : peerSettings.get().maxFieldSectionSize(); + if (headerSizeLimit != UNLIMITED_MAX_FIELD_SECTION_SIZE) { + // specific limit has been set on the header size for this connection. + // we compute the header size and ensure that it doesn't exceed that limit + final long computedHeaderSize = computeHeaderSize(requestPseudoHeaders, sysh, userh); + if (computedHeaderSize > headerSizeLimit) { + // RFC-9114, section 4.2.2: An implementation that has received this parameter + // SHOULD NOT send an HTTP message header that exceeds the indicated size. + // we fail the request. + throw new CompletionException(new ProtocolException("Request headers size" + + " exceeds limit set by peer")); + } + } + List buffers = qpackEncoder.encodeHeaders(headerFrameWriter, streamId(), + 1024, requestPseudoHeaders, sysh, userh); + HeadersFrame headersFrame = new HeadersFrame(Utils.remaining(buffers)); + ByteBuffer buffer = ByteBuffer.allocate(headersFrame.headersSize()); + headersFrame.writeHeaders(buffer); + buffer.flip(); + long sentBytes = 0; + try { + boolean hasNoBody = requestContentLen == 0; + int last = buffers.size() - 1; + int toSend = buffer.remaining(); + if (last < 0) { + writer.scheduleForWriting(buffer, hasNoBody); + } else { + writer.queueForWriting(buffer); + } + sentBytes += toSend; + for (int i = 0; i <= last; i++) { + var nextBuffer = buffers.get(i); + toSend = nextBuffer.remaining(); + if (i == last) { + writer.scheduleForWriting(nextBuffer, hasNoBody); + } else { + writer.queueForWriting(nextBuffer); + } + sentBytes += toSend; + } + } catch (QPackException qe) { + if (qe.isConnectionError()) { + // close the connection + connection.close(qe.http3Error(), "QPack error", qe.getCause()); + } + // fail the request + throw new CompletionException(qe.getCause()); + } catch (IOException io) { + throw new CompletionException(io); + } finally { + if (sentBytes != 0) sentQuicBytes.addAndGet(sentBytes); + } + return this; + } + + private static long computeHeaderSize(final HttpHeaders... headers) { + // RFC-9114, section 4.2.2 states: The size of a field list is calculated based on + // the uncompressed size of fields, including the length of the name and value in bytes + // plus an overhead of 32 bytes for each field. + final int OVERHEAD_BYTES_PER_FIELD = 32; + long computedHeaderSize = 0; + for (final HttpHeaders h : headers) { + for (final Map.Entry> entry : h.map().entrySet()) { + try { + computedHeaderSize = Math.addExact(computedHeaderSize, + entry.getKey().getBytes(StandardCharsets.US_ASCII).length); + for (final String v : entry.getValue()) { + computedHeaderSize = Math.addExact(computedHeaderSize, + v.getBytes(StandardCharsets.US_ASCII).length); + } + computedHeaderSize = Math.addExact(computedHeaderSize, OVERHEAD_BYTES_PER_FIELD); + } catch (ArithmeticException ae) { + // overflow, no point trying to compute further, return MAX_VALUE + return Long.MAX_VALUE; + } + } + } + return computedHeaderSize; + } + + + @Override + CompletableFuture> sendBodyAsync() { + return sendBodyImpl().thenApply((e) -> this); + } + + CompletableFuture sendBodyImpl() { + requestBodyCF.whenComplete((v, t) -> requestSent()); + try { + if (debug.on()) debug.log("H3 sendBodyImpl"); + if (requestPublisher != null && requestContentLen != 0) { + final RequestSubscriber subscriber = new RequestSubscriber(requestContentLen); + requestPublisher.subscribe(requestSubscriber = subscriber); + } else { + // there is no request body, therefore the request is complete, + // END_STREAM has already sent with outgoing headers + requestBodyCF.complete(null); + } + } catch (Throwable t) { + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + requestBodyCF.completeExceptionally(t); + } + return requestBodyCF; + } + + // The Http3StreamResponseSubscriber is registered with the HttpClient + // to ensure that it gets completed if the SelectorManager aborts due + // to unexpected exceptions. + private void registerResponseSubscriber(Http3StreamResponseSubscriber subscriber) { + if (client().registerSubscriber(subscriber)) { + if (debug.on()) { + debug.log("Reference response body for h3 stream: " + streamId()); + } + client().h3StreamReference(); + } + } + + private void unregisterResponseSubscriber(Http3StreamResponseSubscriber subscriber) { + if (client().unregisterSubscriber(subscriber)) { + if (debug.on()) { + debug.log("Unreference response body for h3 stream: " + streamId()); + } + client().h3StreamUnreference(); + } + } + + final class Http3StreamResponseSubscriber extends HttpBodySubscriberWrapper { + Http3StreamResponseSubscriber(BodySubscriber subscriber) { + super(subscriber); + } + + @Override + protected void unregister() { + unregisterResponseSubscriber(this); + } + + @Override + protected void register() { + registerResponseSubscriber(this); + } + + @Override + protected void logComplete(Throwable error) { + if (error == null) { + if (Log.requests()) { + Log.logResponse(() -> "HTTP/3 body successfully completed for: " + request + + " #" + exchange.multi.id); + } + } else { + if (Log.requests()) { + Log.logResponse(() -> "HTTP/3 body exceptionally completed for: " + + request + " (" + error + ")" + + " #" + exchange.multi.id); + } + } + } + } + + + @Override + Http3StreamResponseSubscriber createResponseSubscriber(BodyHandler handler, + ResponseInfo response) { + if (debug.on()) debug.log("Creating body subscriber"); + Http3StreamResponseSubscriber subscriber = + new Http3StreamResponseSubscriber<>(handler.apply(response)); + return subscriber; + } + + @Override + CompletableFuture readBodyAsync(BodyHandler handler, + boolean returnConnectionToPool, + Executor executor) { + try { + if (Log.trace()) { + Log.logTrace("Reading body on stream {0}", streamId()); + } + if (debug.on()) debug.log("Getting BodySubscriber for: " + response); + Http3StreamResponseSubscriber bodySubscriber = + createResponseSubscriber(handler, new ResponseInfoImpl(response)); + CompletableFuture cf = receiveResponseBody(bodySubscriber, executor); + + PushGroup pg = exchange.getPushGroup(); + if (pg != null) { + // if an error occurs make sure it is recorded in the PushGroup + cf = cf.whenComplete((t, e) -> pg.pushError(e)); + } + var bodyCF = cf; + return bodyCF; + } catch (Throwable t) { + // may be thrown by handler.apply + // TODO: Is this the right error code? + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + return MinimalFuture.failedFuture(t); + } + } + + @Override + CompletableFuture ignoreBody() { + try { + if (debug.on()) debug.log("Ignoring body"); + reader.stream().requestStopSending(Http3Error.H3_REQUEST_CANCELLED.code()); + return MinimalFuture.completedFuture(null); + } catch (Throwable e) { + if (Log.trace()) { + Log.logTrace("Error requesting stop sending for stream {0}: {1}", + streamId(), e.toString()); + } + return MinimalFuture.failedFuture(e); + } + } + + @Override + void cancel() { + if (debug.on()) debug.log("cancel"); + var stream = this.stream; + if ((stream == null)) { + cancel(new IOException("Stream cancelled before streamid assigned")); + } else { + cancel(new IOException("Stream " + stream.streamId() + " cancelled")); + } + } + + @Override + void cancel(IOException cause) { + cancelImpl(cause, Http3Error.H3_REQUEST_CANCELLED); + } + + @Override + void onProtocolError(IOException cause) { + final long streamId = stream.streamId(); + if (debug.on()) { + debug.log("cancelling exchange on stream %d due to protocol error: %s", streamId, cause.getMessage()); + } + Log.logError("cancelling exchange on stream {0} due to protocol error: {1}\n", streamId, cause); + cancelImpl(cause, Http3Error.H3_GENERAL_PROTOCOL_ERROR); + } + + @Override + void released() { + long streamid = streamId(); + if (debug.on()) debug.log("Released stream %d", streamid); + // remove this stream from the Http2Connection map. + connection.onExchangeClose(this, streamid); + } + + @Override + void completed() { + } + + @Override + boolean isCanceled() { + return errorRef.get() != null; + } + + @Override + Throwable getCancelCause() { + return errorRef.get(); + } + + @Override + void cancelImpl(Throwable e, Http3Error error) { + try { + var streamid = streamId(); + if (errorRef.compareAndSet(null, e)) { + if (debug.on()) { + if (streamid == -1) debug.log("cancelling stream", e); + else debug.log("cancelling stream " + streamid + ":", e); + } + if (Log.trace()) { + if (streamid == -1) Log.logTrace("cancelling stream: {0}\n", e); + else Log.logTrace("cancelling stream {0}: {1}\n", streamid, e); + } + } else { + if (debug.on()) { + if (streamid == -1) debug.log("cancelling stream: %s", (Object) e); + else debug.log("cancelling stream %s: %s", streamid, e); + } + } + var firstError = errorRef.get(); + completeResponseExceptionally(firstError); + if (!requestBodyCF.isDone()) { + // complete requestBodyCF before cancelling subscription + requestBodyCF.completeExceptionally(firstError); // we may be sending the body... + var requestSubscriber = this.requestSubscriber; + if (requestSubscriber != null) { + cancel(requestSubscriber.subscription.get()); + } + } + var responseBodyCF = this.responseBodyCF; + if (responseBodyCF != null) { + responseBodyCF.completeExceptionally(firstError); + } + // will send a RST_STREAM frame + var stream = this.stream; + if (connection.isOpen()) { + if (stream != null && stream.sendingState().isSending()) { + // no use reset if already closed. + var cause = Utils.getCompletionCause(firstError); + if (!(cause instanceof EOFException)) { + if (debug.on()) + debug.log("sending reset %s", error); + stream.reset(error.code()); + } + } + if (stream != null) { + if (debug.on()) + debug.log("request stop sending"); + stream.requestStopSending(error.code()); + } + } + } catch (Throwable ex) { + errorRef.compareAndSet(null, ex); + if (debug.on()) + debug.log("failed cancelling request: ", ex); + Log.logError(ex); + } finally { + close(); + } + } + + // cancel subscription and ignore errors in order to continue with + // the cancel/close sequence. + private void cancel(Subscription subscription) { + if (subscription == null) return; + try { subscription.cancel(); } + catch (Throwable t) { + debug.log("Unexpected exception thrown by Subscription::cancel", t); + if (Log.errors()) { + Log.logError("Unexpected exception thrown by Subscription::cancel: " + t); + Log.logError(t); + } + } + } + + @Override + CompletableFuture getResponseAsync(Executor executor) { + CompletableFuture cf; + // The code below deals with race condition that can be caused when + // completeResponse() is being called before getResponseAsync() + response_cfs_lock.lock(); + try { + if (!response_cfs.isEmpty()) { + // This CompletableFuture was created by completeResponse(). + // it will be already completed, unless the expect continue + // timeout fired + cf = response_cfs.get(0); + if (cf.isDone()) { + cf = response_cfs.remove(0); + } + + // if we find a cf here it should be already completed. + // finding a non completed cf should not happen. just assert it. + assert cf.isDone() || request.expectContinue && expectTimeoutRaised() + : "Removing uncompleted response: could cause code to hang!"; + } else { + // getResponseAsync() is called first. Create a CompletableFuture + // that will be completed by completeResponse() when + // completeResponse() is called. + cf = new MinimalFuture<>(); + response_cfs.add(cf); + } + } finally { + response_cfs_lock.unlock(); + } + if (executor != null && !cf.isDone()) { + // protect from executing later chain of CompletableFuture operations from SelectorManager thread + cf = cf.thenApplyAsync(r -> r, executor); + } + if (Log.trace()) { + Log.logTrace("Response future (stream={0}) is: {1}", streamId(), cf); + } + PushGroup pg = exchange.getPushGroup(); + if (pg != null) { + // if an error occurs make sure it is recorded in the PushGroup + cf = cf.whenComplete((t, e) -> pg.pushError(Utils.getCompletionCause(e))); + } + if (debug.on()) debug.log("Response future is %s", cf); + return cf; + } + + /** + * Completes the first uncompleted CF on list, and removes it. If there is no + * uncompleted CF then creates one (completes it) and adds to list + */ + void completeResponse(Response resp) { + if (debug.on()) debug.log("completeResponse: %s", resp); + response_cfs_lock.lock(); + try { + CompletableFuture cf; + int cfs_len = response_cfs.size(); + for (int i = 0; i < cfs_len; i++) { + cf = response_cfs.get(i); + if (!cf.isDone() && !expectTimeoutRaised()) { + if (Log.trace()) { + Log.logTrace("Completing response (streamid={0}): {1}", + streamId(), cf); + } + if (debug.on()) + debug.log("Completing responseCF(%d) with response headers", i); + response_cfs.remove(cf); + cf.complete(resp); + return; + } else if (expectTimeoutRaised()) { + Log.logTrace("Completing response (streamid={0}): {1}", + streamId(), cf); + if (debug.on()) + debug.log("Completing responseCF(%d) with response headers", i); + // The Request will be removed in getResponseAsync() + cf.complete(resp); + return; + } // else we found the previous response: just leave it alone. + } + cf = MinimalFuture.completedFuture(resp); + if (Log.trace()) { + Log.logTrace("Created completed future (streamid={0}): {1}", + streamId(), cf); + } + if (debug.on()) + debug.log("Adding completed responseCF(0) with response headers"); + response_cfs.add(cf); + } finally { + response_cfs_lock.unlock(); + } + } + + @Override + void expectContinueFailed(int rcode) { + // Have to mark request as sent, due to no request body being sent in the + // event of a 417 Expectation Failed or some other non 100 response code + requestSent(); + } + + // methods to update state and remove stream when finished + + void requestSent() { + stateLock.lock(); + try { + requestSent0(); + } finally { + stateLock.unlock(); + } + } + + private void requestSent0() { + assert stateLock.isHeldByCurrentThread(); + requestSent = true; + if (responseReceived) { + if (debug.on()) debug.log("requestSent: streamid=%d", streamId()); + close(); + } else { + if (debug.on()) { + debug.log("requestSent: streamid=%d but response not received", streamId()); + } + } + } + + void responseReceived() { + stateLock.lock(); + try { + responseReceived0(); + } finally { + stateLock.unlock(); + } + } + + private void responseReceived0() { + assert stateLock.isHeldByCurrentThread(); + responseReceived = true; + if (requestSent) { + if (debug.on()) debug.log("responseReceived: streamid=%d", streamId()); + close(); + } else { + if (debug.on()) { + debug.log("responseReceived: streamid=%d but request not sent", streamId()); + } + } + } + + /** + * Same as {@link #completeResponse(Response)} above but for errors + */ + void completeResponseExceptionally(Throwable t) { + response_cfs_lock.lock(); + try { + // use index to avoid ConcurrentModificationException + // caused by removing the CF from within the loop. + for (int i = 0; i < response_cfs.size(); i++) { + CompletableFuture cf = response_cfs.get(i); + if (!cf.isDone()) { + response_cfs.remove(i); + cf.completeExceptionally(t); + return; + } + } + response_cfs.add(MinimalFuture.failedFuture(t)); + } finally { + response_cfs_lock.unlock(); + } + } + + @Override + void nullBody(HttpResponse resp, Throwable t) { + if (debug.on()) debug.log("nullBody: streamid=%d", streamId()); + // We should have an END_STREAM data frame waiting in the inputQ. + // We need a subscriber to force the scheduler to process it. + assert pendingResponseSubscriber == null; + pendingResponseSubscriber = HttpResponse.BodySubscribers.replacing(null); + readScheduler.runOrSchedule(); + } + + /** + * An unprocessed exchange is one that hasn't been processed by a peer. The local end of the + * connection would be notified about such exchanges when it receives a GOAWAY frame with + * a stream id that tells which exchanges have been unprocessed. + * This method is called on such unprocessed exchanges and the implementation of this method + * will arrange for the request, corresponding to this exchange, to be retried afresh on a + * new connection. + */ + void closeAsUnprocessed() { + // null exchange implies a PUSH stream and those aren't + // initiated by the client, so we don't expect them to be + // considered unprocessed. + assert this.exchange != null : "PUSH streams aren't expected to be closed as unprocessed"; + // We arrange for the request to be retried on a new connection as allowed + // by RFC-9114, section 5.2 + this.exchange.markUnprocessedByPeer(); + this.errorRef.compareAndSet(null, new IOException("request not processed by peer")); + // close the exchange and complete the response CF exceptionally + close(); + completeResponseExceptionally(this.errorRef.get()); + if (debug.on()) { + debug.log("request unprocessed by peer " + this.request); + } + } + + // This method doesn't send any frame + void close() { + if (closed) return; + Throwable error; + stateLock.lock(); + try { + if (closed) return; + closed = true; + error = errorRef.get(); + } finally { + stateLock.unlock(); + } + if (Log.http3()) { + if (error == null) { + Log.logHttp3("Closed HTTP/3 exchange for {0}/streamId={1}", + connection.quicConnectionTag(), Long.toString(stream.streamId())); + } else { + Log.logHttp3("Closed HTTP/3 exchange for {0}/streamId={1} with error {2}", + connection.quicConnectionTag(), Long.toString(stream.streamId()), + error); + } + } + if (debug.on()) { + debug.log("stream %d is now closed with %s", + streamId(), + error == null ? "no error" : String.valueOf(error)); + } + if (Log.trace()) { + Log.logTrace("Stream {0} is now closed", streamId()); + } + + BodySubscriber subscriber = responseSubscriber; + if (subscriber == null) subscriber = pendingResponseSubscriber; + if (subscriber instanceof Http3StreamResponseSubscriber h3srs) { + // ensure subscriber is unregistered + h3srs.complete(error); + } + connection.onExchangeClose(this, streamId()); + } + + class RequestSubscriber implements Flow.Subscriber { + // can be < 0 if the actual length is not known. + private final long contentLength; + private volatile long remainingContentLength; + private volatile boolean dataHeaderWritten; + private volatile boolean completed; + private final AtomicReference subscription = new AtomicReference<>(); + + RequestSubscriber(long contentLen) { + this.contentLength = contentLen; + this.remainingContentLength = contentLen; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (!this.subscription.compareAndSet(null, subscription)) { + subscription.cancel(); + throw new IllegalStateException("already subscribed"); + } + if (debug.on()) + debug.log("RequestSubscriber: onSubscribe, request 1"); + subscription.request(1); + } + + @Override + public void onNext(ByteBuffer item) { + if (debug.on()) + debug.log("RequestSubscriber: onNext(%d)", item.remaining()); + var subscription = this.subscription.get(); + if (writer.stopSendingReceived()) { + // whether StopSending contains NO_ERROR or not - we should + // not fail the request and simply stop sending the body. + // The sender should either reset the stream or send a full + // response with an error status code if it wants to fail the request. + Http3Error error = Http3Error.fromCode(writer.stream().sndErrorCode()) + .orElse(Http3Error.H3_NO_ERROR); + if (debug.on()) + debug.log("Stop sending requested by peer (%s): canceling subscription", error); + requestBodyCF.complete(null); + subscription.cancel(); + return; + } + + if (isCanceled() || errorRef.get() != null) { + if (writer.sendingState().isSending()) { + try { + if (debug.on()) { + debug.log("onNext called after stream cancelled: " + + "resetting stream %s", streamId()); + } + writer.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + } catch (Throwable t) { + if (debug.on()) debug.log("Failed to reset stream: ", t); + errorRef.compareAndSet(null, t); + requestBodyCF.completeExceptionally(errorRef.get()); + } + } + return; + } + long len = item.remaining(); + try { + writeHeadersIfNeeded(item); + var remaining = remainingContentLength; + if (contentLength >= 0) { + remaining -= len; + remainingContentLength = remaining; + if (remaining < 0) { + lengthMismatch("Too many bytes in request body"); + subscription.cancel(); + } + } + var completed = remaining == 0; + if (completed) this.completed = true; + writer.scheduleForWriting(item, completed); + sentQuicBytes.addAndGet(len); + if (completed) { + requestBodyCF.complete(null); + } + if (writer.credit() > 0) { + if (debug.on()) + debug.log("RequestSubscriber: request 1"); + subscription.request(1); + } else { + if (debug.on()) + debug.log("RequestSubscriber: no more credit"); + } + } catch (Throwable t) { + if (writer.stopSendingReceived()) { + // We can reach here if we continue sending after stop sending + // was received, which may happen since stop sending is + // received asynchronously. In that case, we should + // not fail the request but simply stop sending the body. + // The sender will either reset the stream or send a full + // response with an error status code if it wants to fail + // or complete the request. + if (debug.on()) + debug.log("Stop sending requested by peer: canceling subscription"); + requestBodyCF.complete(null); + subscription.cancel(); + return; + } + // stop sending was not received: cancel the stream + errorRef.compareAndSet(null, t); + if (debug.on()) { + debug.log("Unexpected exception in onNext: " + t); + debug.log("resetting stream %s", streamId()); + } + try { + writer.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + } catch (Throwable rt) { + if (debug.on()) + debug.log("Failed to reset stream: %s", t); + } + cancelImpl(errorRef.get(), Http3Error.H3_REQUEST_CANCELLED); + } + + } + + private void lengthMismatch(String what) { + if (debug.on()) { + debug.log(what + " (%s/%s)", + contentLength - remainingContentLength, contentLength); + } + try { + var failed = new IOException("stream=" + streamId() + " " + + "[" + Thread.currentThread().getName() + "] " + + what + " (" + + (contentLength - remainingContentLength) + "/" + + contentLength + ")"); + errorRef.compareAndSet(null, failed); + writer.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + requestBodyCF.completeExceptionally(errorRef.get()); + } catch (Throwable t) { + if (debug.on()) + debug.log("Failed to reset stream: %s", t); + } + close(); + } + + private void writeHeadersIfNeeded(ByteBuffer item) throws IOException { + long len = item.remaining(); + if (contentLength >= 0) { + if (!dataHeaderWritten) { + dataHeaderWritten = true; + len = contentLength; + } else { + // headers already written: nothing to do. + return; + } + } + DataFrame df = new DataFrame(len); + ByteBuffer headers = ByteBuffer.allocate(df.headersSize()); + df.writeHeaders(headers); + headers.flip(); + int sent = headers.remaining(); + writer.queueForWriting(headers); + if (sent != 0) sentQuicBytes.addAndGet(sent); + } + + @Override + public void onError(Throwable throwable) { + if (debug.on()) + debug.log(() -> "RequestSubscriber: onError: " + throwable); + // ensure that errors are handled within the flow. + if (errorRef.compareAndSet(null, throwable)) { + try { + writer.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + } catch (Throwable t) { + if (debug.on()) debug.log("Failed to reset stream: %s", t); + } + requestBodyCF.completeExceptionally(throwable); + // no need to cancel subscription + close(); + } + } + + @Override + public void onComplete() { + if (debug.on()) debug.log("RequestSubscriber: send request body completed"); + var completed = this.completed; + if (completed || errorRef.get() != null) return; + if (contentLength >= 0 && remainingContentLength != 0) { + if (remainingContentLength < 0) { + lengthMismatch("Too many bytes in request body"); + } else { + lengthMismatch("Too few bytes returned by the publisher"); + } + return; + } + this.completed = true; + try { + writer.scheduleForWriting(QuicStreamReader.EOF, true); + requestBodyCF.complete(null); + } catch (Throwable t) { + if (debug.on()) debug.log("Failed to complete stream: " + t, t); + requestBodyCF.completeExceptionally(t); + } + } + + void unblock() { + if (completed || errorRef.get() != null) { + return; + } + var subscription = this.subscription.get(); + try { + if (writer.credit() > 0) { + if (subscription != null) { + subscription.request(1); + } + } + } catch (Throwable throwable) { + if (debug.on()) + debug.log(() -> "RequestSubscriber: unblock: " + throwable); + // ensure that errors are handled within the flow. + if (errorRef.compareAndSet(null, throwable)) { + try { + writer.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + } catch (Throwable t) { + if (debug.on()) debug.log("Failed to reset stream: %s", t); + } + requestBodyCF.completeExceptionally(throwable); + cancelImpl(throwable, Http3Error.H3_REQUEST_CANCELLED); + subscription.cancel(); + } + } + } + + } + + @Override + Response newResponse(HttpHeaders responseHeaders, int responseCode) { + this.responseCode = responseCode; + return this.response = new Response( + request, exchange, responseHeaders, connection(), + responseCode, Version.HTTP_3); + } + + protected void handleResponse() { + handleResponse(responseHeadersBuilder, rspHeadersConsumer, readScheduler, debug); + } + + protected void handlePromise(PushHeadersConsumer consumer) throws IOException { + PushPromiseState state = consumer.state; + PushPromiseFrame ppf = state.frame(); + promiseMap.remove(ppf); + long pushId = ppf.getPushId(); + + HttpHeaders promiseHeaders = state.headersBuilder().build(); + consumer.reset(); + + if (debug.on()) { + debug.log("received promise headers: %s", + promiseHeaders); + } + + if (Log.headers() || debug.on()) { + StringBuilder sb = new StringBuilder("PUSH_PROMISE HEADERS (pushId: ") + .append(pushId).append("):\n"); + Log.dumpHeaders(sb, " ", promiseHeaders); + if (Log.headers()) { + Log.logHeaders(sb.toString()); + } else if (debug.on()) { + debug.log(sb); + } + } + + String method = promiseHeaders.firstValue(":method") + .orElseThrow(() -> new ProtocolException("no method in promise request")); + String path = promiseHeaders.firstValue(":path") + .orElseThrow(() -> new ProtocolException("no path in promise request")); + String authority = promiseHeaders.firstValue(":authority") + .orElseThrow(() -> new ProtocolException("no authority in promise request")); + if (Set.of("PUT", "DELETE", "OPTIONS", "TRACE").contains(method)) { + throw new ProtocolException("push method not allowed pushId=" + pushId); + } + long clen = promiseHeaders.firstValueAsLong("Content-Length").orElse(-1); + if (clen > 0) { + throw new ProtocolException("push headers contain non-zero Content-Length for pushId=" + pushId); + } + if (promiseHeaders.firstValue("Transfer-Encoding").isPresent()) { + throw new ProtocolException("push headers contain Transfer-Encoding for pushId=" + pushId); + } + + + // this will clear the response headers + // At this point the push promise stream may not be opened yet + if (connection.onPushPromiseFrame(this, pushId, promiseHeaders)) { + // the promise response will be handled from a child of this exchange + // once the push stream is open, we have nothing more to do here. + if (debug.on()) { + debug.log("handling push promise response for %s with request-response stream %s", + pushId, streamId()); + } + } else { + // the promise response is being handled by another exchange, just accept the id + if (debug.on()) { + debug.log("push promise response for %s is already handled by another stream", + pushId); + } + PushGroup pushGroup = exchange.getPushGroup(); + connection.whenPushAccepted(pushId).thenAccept((accepted) -> { + if (accepted) { + pushGroup.acceptPushPromiseId(connection.newPushId(pushId)); + } + }); + } + } + + private void cancelPushPromise(PushPromiseState state, IOException cause) { + // send CANCEL_PUSH frame here + long pushId = state.frame().getPushId(); + connection.pushCancelled(pushId, cause); + } + + @Override + void onPollException(QuicStreamReader reader, IOException io) { + if (Log.http3()) { + Log.logHttp3("{0}/streamId={1} {2} #{3} (requestSent={4}, responseReceived={5}, " + + "reader={6}, writer={7}, statusCode={8}, finalStream={9}, " + + "receivedQuicBytes={10}, sentQuicBytes={11}): {12}", + connection().quicConnection().logTag(), + String.valueOf(reader.stream().streamId()), request, String.valueOf(exchange.multi.id), + requestSent, responseReceived, reader.receivingState(), writer.sendingState(), + String.valueOf(responseCode), connection.isFinalStream(), String.valueOf(receivedQuicBytes()), + String.valueOf(sentQuicBytes.get()), io); + } + } + + void onReaderReset() { + long errorCode = stream.rcvErrorCode(); + String resetReason = Http3Error.stringForCode(errorCode); + Http3Error resetError = Http3Error.fromCode(errorCode) + .orElse(Http3Error.H3_REQUEST_CANCELLED); + if (!requestSent || !responseReceived) { + cancelImpl(new IOException("Stream %s reset by peer: %s" + .formatted(streamId(), resetReason)), + resetError); + } + if (debug.on()) { + debug.log("Stream %s reset by peer [%s]: Stopping scheduler", + streamId(), resetReason); + } + readScheduler.stop(); + } + + + + // Invoked when some data is received from the request-response + // Quic stream + private void processQuicData() { + // Poll bytes from the request-response stream + // and parses the data to read HTTP/3 frames. + // + // If the frame being read is a header frame, send the + // compacted header field data to QPack. + // + // Otherwise, if it's a data frame, send the bytes + // to the response body subscriber. + // + // Finally, if the frame being read is a PushPromiseFrame, + // sends the compressed field data to the QPack decoder to + // decode the push promise request headers. + // + try { + processQuicData(reader, framesDecoder, frameOrderVerifier, readScheduler, debug); + } catch (Throwable t) { + if (debug.on()) + debug.log("processQuicData - Unexpected exception", t); + if (!requestSent) { + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + } else if (!responseReceived) { + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + } + } finally { + if (debug.on()) + debug.log("processQuicData - leaving - eof: %s", framesDecoder.eof()); + } + } + + void connectionError(Throwable throwable, long errorCode, String errMsg) { + if (errorRef.compareAndSet(null, throwable)) { + var streamid = streamId(); + if (debug.on()) { + if (streamid == -1) { + debug.log("cancelling stream due to connection error", throwable); + } else { + debug.log("cancelling stream " + streamid + " due to connection error", throwable); + } + } + if (Log.trace()) { + if (streamid == -1) { + Log.logTrace("connection error: {0}", errMsg); + } else { + var format = "cancelling stream {0} due to connection error: {1}"; + Log.logTrace(format, streamid, errMsg); + } + } + } + connection.connectionError(this, throwable, errorCode, errMsg); + } + + record PushPromiseState(PushPromiseFrame frame, + HeaderFrameReader reader, + HttpHeadersBuilder headersBuilder, + DecodingCallback consumer) {} + final ConcurrentHashMap promiseMap = new ConcurrentHashMap<>(); + + private void ignorePushPromiseData(PushPromiseFrame ppf, List payload) { + boolean completed = ppf.remaining() == 0; + boolean eof = false; + if (payload != null) { + int last = payload.size() - 1; + for (int i = 0; i <= last; i++) { + ByteBuffer buf = payload.get(i); + buf.limit(buf.position()); + if (buf == QuicStreamReader.EOF) { + eof = true; + } + } + } + if (!completed && eof) { + cancelImpl(new EOFException("EOF reached promise: " + ppf), + Http3Error.H3_FRAME_ERROR); + } + } + + private boolean ignorePushPromiseFrame(PushPromiseFrame ppf, List payload) + throws IOException { + long pushId = ppf.getPushId(); + long minPushId = connection.getMinPushId(); + if (exchange.pushGroup == null) { + IOException checkFailed = connection.checkMaxPushId(pushId); + if (checkFailed != null) { + // connection is closed + throw checkFailed; + } + if (!connection.acceptPromises()) { + // if no stream accept promises, we can ignore the data and + // cancel the promise right away. + if (debug.on()) { + debug.log("ignoring PushPromiseFrame (no promise handler): %s%n", ppf); + } + ignorePushPromiseData(ppf, payload); + if (pushId >= minPushId) { + connection.noPushHandlerFor(pushId); + } + return true; + } + } + if (pushId < minPushId) { + if (debug.on()) { + debug.log("ignoring PushPromiseFrame (pushId=%s < %s): %s%n", + pushId, minPushId, ppf); + } + ignorePushPromiseData(ppf, payload); + return true; + } + return false; + } + + void receivePushPromiseFrame(PushPromiseFrame ppf, List payload) + throws IOException { + var state = promiseMap.get(ppf); + if (state == null) { + if (ignorePushPromiseFrame(ppf, payload)) return; + if (debug.on()) + debug.log("received PushPromiseFrame: " + ppf); + var checkFailed = connection.checkMaxPushId(ppf.getPushId()); + if (checkFailed != null) throw checkFailed; + var builder = new HttpHeadersBuilder(); + var consumer = new PushHeadersConsumer(); + var reader = qpackDecoder.newHeaderFrameReader(consumer); + state = new PushPromiseState(ppf, reader, builder, consumer); + consumer.setState(state); + promiseMap.put(ppf, state); + } + if (debug.on()) + debug.log("receive promise headers: buffer list: " + payload); + HeaderFrameReader headerFrameReader = state.reader(); + boolean completed = ppf.remaining() == 0; + boolean eof = false; + if (payload != null) { + int last = payload.size() - 1; + for (int i = 0; i <= last; i++) { + ByteBuffer buf = payload.get(i); + boolean endOfHeaders = completed && i == last; + if (debug.on()) + debug.log("QPack decoding %s bytes from headers (last: %s)", + buf.remaining(), last); + qpackDecoder.decodeHeader(buf, + endOfHeaders, + headerFrameReader); + if (buf == QuicStreamReader.EOF) { + eof = true; + } + } + } + if (!completed && eof) { + cancelImpl(new EOFException("EOF reached promise: " + ppf), + Http3Error.H3_FRAME_ERROR); + } + } + + /** + * This method is called by the {@link Http3PushManager} in order to + * invoke the {@link Acceptor} that will accept the push + * promise. This method gets the acceptor, invokes its {@link + * Acceptor#accepted()} method, and if {@code true}, returns the + * {@code Acceptor}. + *

+ * If the push request is not accepted this method returns {@code null}. + * + * @apiNote + * This method is called upon reception of a {@link PushPromiseFrame}. + * The quic stream that will carry the body may not be available yet. + * + * @param pushId the pushId + * @param pushRequest the promised push request + * @return an {@link Acceptor} to get the body handler for the + * push request, or {@code null}. + */ + Acceptor acceptPushPromise(long pushId, HttpRequestImpl pushRequest) { + if (Log.requests()) { + Log.logRequest("PUSH_PROMISE: " + pushRequest.toString()); + } + PushGroup pushGroup = exchange.getPushGroup(); + if (pushGroup == null || exchange.multi.requestCancelled()) { + if (Log.trace()) { + Log.logTrace("Rejecting push promise pushId: " + pushId); + } + connection.pushCancelled(pushId, null); + return null; + } + + Acceptor acceptor = null; + boolean accepted = false; + try { + acceptor = pushGroup.acceptPushRequest(pushRequest, connection.newPushId(pushId)); + accepted = acceptor.accepted(); + } catch (Throwable t) { + if (debug.on()) + debug.log("PushPromiseHandler::applyPushPromise threw exception %s", + (Object)t); + } + if (!accepted) { + // cancel / reject + if (Log.trace()) { + Log.logTrace("No body subscriber for {0}: {1}", pushRequest, + "Push " + pushId + " cancelled by users handler"); + } + connection.pushCancelled(pushId, null); + return null; + } + + assert accepted && acceptor != null; + return acceptor; + } + + /** + * This method is called by the {@link Http3PushManager} once the {@link Acceptor#cf() + * responseCF} has been obtained from the acceptor. + * @param pushId the pushId + * @param responseCF the response completable future + */ + void onPushRequestAccepted(long pushId, CompletableFuture> responseCF) { + PushGroup pushGroup = getExchange().getPushGroup(); + assert pushGroup != null; + // setup housekeeping for when the push is received + // TODO: deal with ignoring of CF anti-pattern + CompletableFuture> cf = responseCF; + cf.whenComplete((HttpResponse resp, Throwable t) -> { + t = Utils.getCompletionCause(t); + if (Log.trace()) { + Log.logTrace("Push {0} completed for {1}{2}", pushId, resp, + ((t==null) ? "": " with exception " + t)); + } + if (t != null) { + if (debug.on()) { + debug.log("completing pushResponseCF for" + + ", pushId=" + pushId + " with: " + t); + } + pushGroup.pushError(t); + } else { + if (debug.on()) { + debug.log("completing pushResponseCF for" + + ", pushId=" + pushId + " with: " + resp); + } + } + pushGroup.pushCompleted(); + }); + } + + /** + * This method is called by the {@link Http3PushPromiseStream} when + * starting + * @param pushRequest the pushRequest + * @param pushStream the pushStream + */ + void onHttp3PushStreamStarted(HttpRequestImpl pushRequest, + Http3PushPromiseStream pushStream) { + PushGroup pushGroup = getExchange().getPushGroup(); + assert pushGroup != null; + assert pushStream != null; + connection.onPushPromiseStreamStarted(pushStream, pushStream.streamId()); + } + + // invoked when ByteBuffers containing the next payload bytes for the + // given partial header frame are received + void receiveHeaders(HeadersFrame headers, List payload) { + if (debug.on()) + debug.log("receive headers: buffer list: " + payload); + boolean completed = headers.remaining() == 0; + boolean eof = false; + if (payload != null) { + int last = payload.size() - 1; + for (int i = 0; i <= last; i++) { + ByteBuffer buf = payload.get(i); + boolean endOfHeaders = completed && i == last; + if (debug.on()) + debug.log("QPack decoding %s bytes from headers (last: %s)", + buf.remaining(), last); + // if we have finished receiving the header frame, pause reading until + // the status code has been decoded + if (endOfHeaders) switchReadingPaused(true); + qpackDecoder.decodeHeader(buf, + endOfHeaders, + headerFrameReader); + if (buf == QuicStreamReader.EOF) { + eof = true; + // we are at EOF - no need to pause reading + switchReadingPaused(false); + } + } + } + if (!completed && eof) { + cancelImpl(new EOFException("EOF reached: " + headers), + Http3Error.H3_FRAME_ERROR); + } + } + + + // Invoked when data can be pushed to the quic stream; + // Headers may block the stream - but they will be buffered in the stream + // so should not cause this method to be called. + // We should reach here only when sending body bytes. + private void sendQuicData() { + // This method is invoked when the sending part of the + // stream is unblocked. + if (!requestBodyCF.isDone()) { + if (!exchange.multi.requestCancelled()) { + var requestSubscriber = this.requestSubscriber; + // the requestSubscriber will request more data from + // upstream if needed + if (requestSubscriber != null) requestSubscriber.unblock(); + } + } + } + + // pushes entire response body into response subscriber + // blocking when required by local or remote flow control + CompletableFuture receiveResponseBody(BodySubscriber bodySubscriber, Executor executor) { + // ensure that the body subscriber will be subscribed and onError() is + // invoked + pendingResponseSubscriber = bodySubscriber; + + // We want to allow the subscriber's getBody() method to block, so it + // can work with InputStreams. So, we offload execution. + responseBodyCF = ResponseSubscribers.getBodyAsync(executor, bodySubscriber, + new MinimalFuture<>(), (t) -> this.cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED)); + + if (isCanceled()) { + Throwable t = getCancelCause(); + responseBodyCF.completeExceptionally(t); + } + + readScheduler.runOrSchedule(); // in case data waiting already to be processed, or error + + return responseBodyCF; + } + + void onSubscriptionError(Throwable t) { + errorRef.compareAndSet(null, t); + if (debug.on()) debug.log("Got subscription error: %s", (Object) t); + // This is the special case where the subscriber + // has requested an illegal number of items. + // In this case, the error doesn't come from + // upstream, but from downstream, and we need to + // handle the error without waiting for the inputQ + // to be exhausted. + stopRequested = true; + readScheduler.runOrSchedule(); + } + + // This loop is triggered to push response body data into + // the body subscriber. It is called from the processQuicData + // loop. However, we cannot call onNext() if we have no demands. + // So we're using a responseData queue to buffer incoming data. + void pushResponseData(ConcurrentLinkedQueue> responseData) { + if (debug.on()) debug.log("pushResponseData"); + HttpResponse.BodySubscriber subscriber = responseSubscriber; + boolean done = false; + try { + if (subscriber == null) { + subscriber = responseSubscriber = pendingResponseSubscriber; + if (subscriber == null) { + // can't process anything yet + return; + } else { + if (debug.on()) debug.log("subscribing user subscriber"); + subscriber.onSubscribe(userSubscription); + } + } + while (!responseData.isEmpty() && errorRef.get() == null) { + List data = responseData.peek(); + List dsts = Collections.unmodifiableList(data); + long size = Utils.remaining(dsts, Long.MAX_VALUE); + boolean finished = dsts.contains(QuicStreamReader.EOF); + if (size == 0 && finished) { + responseData.remove(); + if (Log.trace()) { + Log.logTrace("responseSubscriber.onComplete"); + } + if (debug.on()) debug.log("pushResponseData: onComplete"); + done = true; + subscriber.onComplete(); + responseReceived(); + return; + } else if (userSubscription.tryDecrement()) { + responseData.remove(); + if (Log.trace()) { + Log.logTrace("responseSubscriber.onNext {0}", size); + } + if (debug.on()) debug.log("pushResponseData: onNext(%d)", size); + subscriber.onNext(dsts); + } else { + if (stopRequested) break; + if (debug.on()) debug.log("no demand"); + return; + } + } + if (framesDecoder.eof() && responseData.isEmpty()) { + if (debug.on()) debug.log("pushResponseData: EOF"); + if (Log.trace()) { + Log.logTrace("responseSubscriber.onComplete"); + } + if (debug.on()) debug.log("pushResponseData: onComplete"); + done = true; + subscriber.onComplete(); + responseReceived(); + return; + } + } catch (Throwable throwable) { + if (debug.on()) debug.log("pushResponseData: unexpected exception", throwable); + errorRef.compareAndSet(null, throwable); + } finally { + if (done) responseData.clear(); + } + + Throwable t = errorRef.get(); + if (t != null) { + try { + if (debug.on()) + debug.log("calling subscriber.onError: %s", (Object) t); + subscriber.onError(t); + } catch (Throwable x) { + Log.logError("Subscriber::onError threw exception: {0}", x); + } finally { + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + responseData.clear(); + } + } + } + + // This method is called by Http2Connection::decrementStreamCount in order + // to make sure that the stream count is decremented only once for + // a given stream. + boolean deRegister() { + return DEREGISTERED.compareAndSet(this, false, true); + } + + private static final VarHandle DEREGISTERED; + static { + try { + DEREGISTERED = MethodHandles.lookup() + .findVarHandle(Http3ExchangeImpl.class, "deRegistered", boolean.class); + } catch (Exception x) { + throw new ExceptionInInitializerError(x); + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3PendingConnections.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3PendingConnections.java new file mode 100644 index 00000000000..92d51b101f7 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3PendingConnections.java @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import jdk.internal.net.http.AltServicesRegistry.AltService; +import jdk.internal.net.http.Http3ClientImpl.ConnectionRecovery; +import jdk.internal.net.http.Http3ClientImpl.PendingConnection; +import jdk.internal.net.http.Http3ClientImpl.StreamLimitReached; +import jdk.internal.net.http.common.Log; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; + +/** + * This class keeps track of pending HTTP/3 connections + * to avoid making two connections to the same server + * in parallel. Methods in this class are not atomic. + * Therefore, it is expected that they will be called + * while holding a lock in order to ensure atomicity. + */ +class Http3PendingConnections { + + private final Map pendingAdvertised = new ConcurrentHashMap<>(); + private final Map pendingUnadvertised = new ConcurrentHashMap<>(); + + Http3PendingConnections() {} + + + // Called when recovery is needed for a given connection, with + // the request that got the StreamLimitException + // Should be called while holding Http3ClientImpl.lock + void streamLimitReached(String key, Http3Connection connection) { + var altSvc = connection.connection().getSourceAltService().orElse(null); + var advertised = altSvc != null && altSvc.wasAdvertised(); + var queue = advertised ? pendingAdvertised : pendingUnadvertised; + queue.computeIfAbsent(key, k -> new StreamLimitReached(connection)); + } + + // Remove a ConnectionRecovery after the connection was established + // Should be called while holding Http3ClientImpl.lock + ConnectionRecovery removeCompleted(String connectionKey, Exchange origExchange, Http3Connection conn) { + var altSvc = Optional.ofNullable(conn) + .map(Http3Connection::connection) + .flatMap(HttpQuicConnection::getSourceAltService) + .orElse(null); + var discovery = Optional.ofNullable(origExchange) + .map(Exchange::request) + .map(HttpRequestImpl::http3Discovery) + .orElse(null); + var advertised = (altSvc != null && altSvc.wasAdvertised()) + || discovery == ALT_SVC; + var sameOrigin = (altSvc != null && altSvc.originHasSameAuthority()); + + ConnectionRecovery recovered = null; + if (advertised) { + recovered = pendingAdvertised.remove(connectionKey); + } + if (discovery == ALT_SVC || recovered != null) return recovered; + if (altSvc == null) { + // for instance, there was an exception, so we don't + // know if there was an altSvc because conn == null + recovered = pendingAdvertised.get(connectionKey); + if (recovered instanceof PendingConnection pending) { + if (pending.exchange() == origExchange) { + pendingAdvertised.remove(connectionKey, recovered); + return recovered; + } + } + } + recovered = pendingUnadvertised.get(connectionKey); + if (recovered instanceof PendingConnection pending) { + if (pending.exchange() == origExchange) { + pendingUnadvertised.remove(connectionKey, recovered); + return pending; + } + } + if (!sameOrigin && advertised) return null; + return pendingUnadvertised.remove(connectionKey); + } + + // Lookup a ConnectionRecovery for the given request with the + // given key. + // Should be called while holding Http3ClientImpl.lock + ConnectionRecovery lookupFor(String key, HttpRequestImpl request, HttpClientImpl client) { + + var discovery = request.http3Discovery(); + + // if ALT_SVC only look in advertised + if (discovery == ALT_SVC) { + return pendingAdvertised.get(key); + } + + // if HTTP_3_ONLY look first in pendingUnadvertised + var unadvertised = pendingUnadvertised.get(key); + if (discovery == HTTP_3_URI_ONLY && unadvertised != null) { + if (unadvertised instanceof PendingConnection) { + return unadvertised; + } + } + + // then look in advertised + var advertised = pendingAdvertised.get(key); + if (advertised instanceof PendingConnection pending) { + var altSvc = pending.altSvc(); + var sameOrigin = altSvc != null && altSvc.originHasSameAuthority(); + assert altSvc != null; // pending advertised should have altSvc + if (discovery == ANY || sameOrigin) return advertised; + } + + // if HTTP_3_ONLY, nothing found, stop here + assert discovery != HTTP_3_URI_ONLY || !(unadvertised instanceof PendingConnection); + if (discovery == HTTP_3_URI_ONLY) { + if (advertised != null && Log.http3()) { + Log.logHttp3("{0} cannot be used for {1}: return null", advertised, request); + } + assert !(unadvertised instanceof PendingConnection); + return unadvertised; + } + + // if ANY return advertised if found, otherwise unadvertised + if (advertised instanceof PendingConnection) return advertised; + if (unadvertised instanceof PendingConnection) { + if (client.client3().isEmpty()) { + return unadvertised; + } + // if ANY and we have an alt service that's eligible for the request + // and is not same origin as the request's URI authority, then don't + // return unadvertised and instead return advertised (which may be null) + final AltService altSvc = client.client3().get().lookupAltSvc(request).orElse(null); + if (altSvc != null && !altSvc.originHasSameAuthority()) { + return advertised; + } else { + return unadvertised; + } + } + if (advertised != null) return advertised; + return unadvertised; + } + + // Adds a pending connection for the given request with the given + // key and altSvc. + // Should be called while holding Http3ClientImpl.lock + PendingConnection addPending(String key, HttpRequestImpl request, AltService altSvc, Exchange exchange) { + var discovery = request.http3Discovery(); + var advertised = altSvc != null && altSvc.wasAdvertised(); + var sameOrigin = altSvc == null || altSvc.originHasSameAuthority(); + // if advertised and same origin, we don't use pendingUnadvertised + // but pendingAdvertised even if discovery is HTTP_3_URI_ONLY + // if we have an advertised altSvc with not same origin, we still + // want to attempt HTTP_3_URI_ONLY at origin, as an unadvertised + // connection. If advertised & same origin, we can use the advertised + // service instead and use pendingAdvertised, even for HTTP_3_URI_ONLY + if (discovery == HTTP_3_URI_ONLY && (!advertised || !sameOrigin)) { + PendingConnection pendingConnection = new PendingConnection(null, exchange); + var previous = pendingUnadvertised.put(key, pendingConnection); + if (previous instanceof PendingConnection prev) { + String msg = "previous unadvertised pending connection found!" + + " (originally created for %s #%s) while adding pending connection for %s" + .formatted(prev.exchange().request, prev.exchange().multi.id, exchange.multi.id); + if (Log.errors()) Log.logError(msg); + assert false : msg; + } + return pendingConnection; + } + assert discovery != HTTP_3_URI_ONLY || advertised && sameOrigin; + if (advertised) { + PendingConnection pendingConnection = new PendingConnection(altSvc, exchange); + var previous = pendingAdvertised.put(key, pendingConnection); + if (previous instanceof PendingConnection prev) { + String msg = "previous pending advertised connection found!" + + " (originally created for %s #%s) while adding pending connection for %s" + .formatted(prev.exchange().request, prev.exchange().multi.id, exchange.multi.id); + if (Log.errors()) Log.logError(msg); + assert false : msg; + } + return pendingConnection; + } + if (discovery == ANY) { + assert !advertised; + PendingConnection pendingConnection = new PendingConnection(null, exchange); + var previous = pendingUnadvertised.put(key, pendingConnection); + if (previous instanceof PendingConnection prev) { + String msg = ("previous unadvertised pending connection found for ANY!" + + " (originally created for %s #%s) while adding pending connection for %s") + .formatted(prev.exchange().request, prev.exchange().multi.id, exchange.multi.id); + if (Log.errors()) Log.logError(msg); + assert false : msg; + } + return pendingConnection; + } + // last case - if we reach here we're ALT_SVC but couldn't + // find an advertised alt service. + assert discovery == ALT_SVC; + return null; + } +} + diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3PushManager.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3PushManager.java new file mode 100644 index 00000000000..b9cf4dbc0f1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3PushManager.java @@ -0,0 +1,811 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.io.IOException; +import java.net.ProtocolException; +import java.net.http.HttpHeaders; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.PushPromiseHandler.PushId; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; + +import static jdk.internal.net.http.Http3ClientProperties.MAX_HTTP3_PUSH_STREAMS; + +/** + * Manages HTTP/3 push promises for an HTTP/3 connection. + *

+ * This class maintains a bounded collection of recent push promises, + * together with the current state of the promise: pending, processed, or + * cancelled. When a new {@link jdk.internal.net.http.http3.frames.PushPromiseFrame} + * is received, and entry is added in the map, and the state of the promise + * is updated as it goes. + * When the map is full, old entries (lowest pushId) are expunged from + * the map. No promise will be accepted if its pushId is lower than the + * lowest pushId in the map. + * + * @apiNote + * When a PushPromiseFrame is received, {@link + * #onPushPromiseFrame(Http3ExchangeImpl, long, HttpHeaders)} + * is called. This arranges for an entry to be added to the map, unless there's + * already one. Also, the first Http3ExchangeImpl for which this method is called + * for a given pushId gets to handle the PushPromise: its {@link + * java.net.http.HttpResponse.PushPromiseHandler} will be invoked to accept the promise + * and handle the body. + *

+ * When a new PushStream is opened, {@link #onPushPromiseStream(QuicReceiverStream, long)} + * is called. When both {@code onPushPromiseFrame} and {@code onPushPromiseStream} have + * been called for a given {@code pushId}, an {@link Http3PushPromiseStream} is created + * and started to receive the body. + *

+ * {@link Http3ExchangeImpl} that receive a push promise frame, but don't get to handle + * the body (because it's already been delegated to another stream) should call + * {@link #whenAccepted(long)} to figure out when it is safe to invoke {@link + * PushGroup#acceptPushPromiseId(PushId)}. + *

+ * {@link #cancelPushPromise(long, Throwable, CancelPushReason)} can be called to cancel + * a push promise. {@link #pushPromiseProcessed(long)} should be called when the body + * has been fully processed. + */ +final class Http3PushManager { + + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + + private final ReentrantLock promiseLock = new ReentrantLock(); + private final ConcurrentHashMap promises = new ConcurrentHashMap<>(); + private final CompletableFuture DENIED = MinimalFuture.completedFuture(Boolean.FALSE); + private final CompletableFuture ACCEPTED = MinimalFuture.completedFuture(Boolean.TRUE); + + private final AtomicLong maxPushId = new AtomicLong(); + private final AtomicLong maxPushReceived = new AtomicLong(); + private final AtomicLong minPushId = new AtomicLong(); + // the max history we keep in the promiseMap. We start expunging old + // entries from the map when the size of the map exceeds this value + private static final long MAX_PUSH_HISTORY_SIZE = (3*MAX_HTTP3_PUSH_STREAMS)/2; + // the maxPushId increments, we send on MAX_PUSH_ID frame + // with a maxPushId incremented by that amount. + // Ideally should be <= to MAX_PUSH_HISTORY_SIZE, to avoid + // filling up the history right after the first MAX_PUSH_ID + private static final long MAX_PUSH_ID_INCREMENTS = MAX_HTTP3_PUSH_STREAMS; + private final Http3Connection connection; + + // number of pending promises + private final AtomicInteger pendingPromises = new AtomicInteger(); + // push promises are considered blocked if we have failed to send + // the last MAX_PUSH_ID update due to pendingPromises + // count having reached MAX_HTTP3_PUSH_STREAMS + private volatile boolean pushPromisesBlocked; + + + Http3PushManager(Http3Connection connection) { + this.connection = connection; + } + + String dbgTag() { + return connection.dbgTag(); + } + + public void cancelAllPromises(IOException closeCause, Http3Error error) { + for (var promise : promises.entrySet()) { + var pushId = promise.getKey(); + var pp = promise.getValue(); + switch (pp) { + case ProcessedPushPromise ignored -> {} + case CancelledPushPromise ignored -> {} + case PendingPushPromise ppp -> { + cancelPendingPushPromise(ppp, closeCause); + } + } + } + } + + // Different actions needs to be carried out when cancelling a + // push promise, depending on the state of the promise and the + // cancellation reason. + enum CancelPushReason { + NO_HANDLER, // the exchange has no PushGroup + PUSH_CANCELLED, // the PromiseHandler cancelled the push, + // or an error occurred handling the promise + CANCEL_RECEIVED; // received CANCEL_PUSH from server + } + + /** + * A PushPromise can be a PendingPushPromise, until the push + * response is completely received, or a ProcessedPushPromise, + * which replace the PendingPushPromise after the response body + * has been delivered. If the PushPromise is cancelled before + * accepting it or receiving a body, CancelledPushPromise will + * be recorded and replace the PendingPushPromise. + */ + private sealed interface PushPromise + permits PendingPushPromise, ProcessedPushPromise, CancelledPushPromise { + } + + /** + * Represent a PushPromise whose body as already been delivered + */ + private record ProcessedPushPromise(PushId pushId, HttpHeaders promiseHeaders) + implements PushPromise { } + + /** + * Represent a PushPromise that has been cancelled. No body will be delivered. + */ + private record CancelledPushPromise(PushId pushId) implements PushPromise { } + + // difficult to say what will come first - the push promise, + // or the push stream? + // The first push promise frame received will register the + // exchange with this class - and trigger the parsing of + // the request/response when the stream is available. + // The other will trigger a simple call to register the + // push id. + // Probably we also need some timer to clean + // up the map if the stream doesn't manifest after a while. + // We maintain minPushID, where any frame + // containing a push id < to the min will be discarded, + // and any stream with a pushId < will also be discarded. + + /** + * Represents a PushPromise whose body has not been delivered + * yet. + * @param the type of the body + */ + private static final class PendingPushPromise implements PushPromise { + // called when the first push promise frame is received + PendingPushPromise(Http3ExchangeImpl exchange, long pushId, HttpHeaders promiseHeaders) { + this.accepted = new MinimalFuture<>(); + this.exchange = Objects.requireNonNull(exchange); + this.promiseHeaders = Objects.requireNonNull(promiseHeaders); + this.pushId = pushId; + } + + // called when the push promise stream is opened + PendingPushPromise(QuicReceiverStream stream, long pushId) { + this.accepted = new MinimalFuture<>(); + this.stream = Objects.requireNonNull(stream); + this.pushId = pushId; + } + + // volatiles should not be required since we only modify/read + // those within a lock. Final fields should ensure safe publication + final long pushId; // the push id + QuicReceiverStream stream; // the quic promise stream + Http3ExchangeImpl exchange; // the exchange that will create the body subscriber + Http3PushPromiseStream promiseStream; // the HTTP/3 stream to process the quic stream + HttpHeaders promiseHeaders; // the push promise request headers + CompletableFuture> responseCF; + HttpRequestImpl pushReq; + BodyHandler handler; + final CompletableFuture accepted; // whether the push promise was accepted + + public long pushId() { return pushId; } + + public boolean ready() { + if (stream == null) return false; + if (exchange == null) return false; + if (promiseHeaders == null) return false; + if (!accepted.isDone()) return false; + if (responseCF == null) return false; + if (pushReq == null) return false; + if (handler == null) return false; + return true; + } + + @Override + public String toString() { + return "PendingPushPromise{" + + "pushId=" + pushId + + ", stream=" + stream + + ", exchange=" + dbgTag(exchange) + + ", promiseStream=" + dbgTag(promiseStream) + + ", promiseHeaders=" + promiseHeaders + + ", accepted=" + accepted + + '}'; + } + + String dbgTag(Http3ExchangeImpl exchange) { + return exchange == null ? null : exchange.dbgTag(); + } + + String dbgTag(Http3PushPromiseStream promiseStream) { + return promiseStream == null ? null : promiseStream.dbgTag(); + } + } + + /** + * {@return the maximum pushId that can be accepted from the peer} + * This corresponds to the pushId that has been included in the last + * MAX_PUSH_ID frame sent to the peer. A pushId greater than this + * value must be rejected, and cause the connection to close with + * error. + * + * @apiNote due to internal constraints it is possible that the + * MAX_PUSH_ID frame has not been sent yet, but the {@code Http3PushManager} + * will behave as if the peer had received that frame. + * + * @see Http3Connection#checkMaxPushId(long) + * @see #checkMaxPushId(long) + */ + long getMaxPushId() { + return maxPushId.get(); + } + + /** + * {@return the minimum pushId that can be accepted from the peer} + * Any pushId strictly less than this value must be ignored. + * + * @apiNote The minimum pushId represents the smallest pushId that + * was recorded in our history. For smaller pushId, no history has + * been kept, due to history size constraints. Any pushId strictly + * less than this value must be ignored. + */ + long getMinPushId() { + return minPushId.get(); + } + + /** + * Called when a new push promise stream is created by the peer, and + * the pushId has been read. + * @param pushStream the new push promise stream + * @param pushId the pushId + */ + void onPushPromiseStream(QuicReceiverStream pushStream, long pushId) { + assert pushId >= 0; + if (!connection.acceptLargerPushPromise(pushStream, pushId)) return; + PendingPushPromise promise = addPushPromise(pushStream, pushId); + if (promise != null) { + assert promise.stream == pushStream; + // if stream is avoilable start parsing? + tryReceivePromise(promise); + } + } + + /** + * Checks whether a MAX_PUSH_ID frame needs to be sent, + * and send it. + * Called from {@link Http3Connection#checkSendMaxPushId()}. + */ + void checkSendMaxPushId() { + if (MAX_PUSH_ID_INCREMENTS <= 0) return; + long pendingCount = pendingPromises.get(); + long availableSlots = MAX_HTTP3_PUSH_STREAMS - pendingCount; + if (availableSlots <= 0) { + pushPromisesBlocked = true; + if (debug.on()) debug.log("Push promises blocked: availableSlots=%s", pendingCount); + return; + } + long maxPushIdSent = maxPushId.get(); + long maxPushIdReceived = maxPushReceived.get(); + long half = Math.max(1, MAX_PUSH_ID_INCREMENTS /2); + if (maxPushIdSent - maxPushIdReceived < half) { + // do not send a maxPushId that would consume more + // than our available slots + long increment = Math.min(availableSlots, MAX_PUSH_ID_INCREMENTS); + long update = maxPushIdSent + increment; + boolean updated = false; + try { + // let's update the counter before sending the frame, + // otherwise there's a chance we can receive a frame + // before updating the counter. + do { + if (maxPushId.compareAndSet(maxPushIdSent, update)) { + if (debug.on()) { + debug.log("MAX_PUSH_ID updated: %s (%s -> %s), increment %s, pending %s, availableSlots %s", + update, maxPushIdSent, update, increment, + promises.values().stream().filter(PendingPushPromise.class::isInstance) + .map(p -> (PendingPushPromise) p) + .map(PendingPushPromise::pushId).toList(), + availableSlots); + } + updated = true; + break; + } + maxPushIdSent = maxPushId.get(); + } while (maxPushIdSent < update); + if (updated) { + if (pushPromisesBlocked) { + if (debug.on()) debug.log("Push promises unblocked: maxPushIdSent=%s", update); + pushPromisesBlocked = false; + } + connection.sendMaxPushId(update); + } + } catch (IOException io) { + debug.log("Failed to send MAX_PUSH_ID(%s): %s", update, io); + } + } + } + + /** + * Called when a PushPromiseFrame has been decoded. + * + * @apiNote + * This method calls {@link Http3ExchangeImpl#acceptPushPromise(long, HttpRequestImpl)} + * and {@link Http3ExchangeImpl#onPushRequestAccepted(long, CompletableFuture)} + * for the first exchange that receives the {@link + * jdk.internal.net.http.http3.frames.PushPromiseFrame} + * + * @param exchange The HTTP/3 exchange that received the frame + * @param pushId The pushId contained in the frame + * @param promiseHeaders The push promise headers contained in the frame + * + * @return true if the exchange should take care of creating the HttpResponse body, + * false otherwise + * + * @see Http3Connection#onPushPromiseFrame(Http3ExchangeImpl, long, HttpHeaders) + */ + boolean onPushPromiseFrame(Http3ExchangeImpl exchange, long pushId, HttpHeaders promiseHeaders) + throws IOException { + if (!connection.acceptLargerPushPromise(null, pushId)) return false; + PendingPushPromise promise = addPushPromise(exchange, pushId, promiseHeaders); + if (promise == null) { + return false; + } + // A PendingPushPromise is returned only if there was no + // PushPromise present. If a PendingPushPromise is returned + // it should therefore have its exchange already set to the + // current exchange. + assert promise.exchange == exchange; + HttpRequestImpl pushReq = HttpRequestImpl.createPushRequest( + exchange.getExchange().request(), promiseHeaders); + var acceptor = exchange.acceptPushPromise(pushId, pushReq); + if (acceptor == null) { + // nothing to do: the push should already have been cancelled. + return false; + } + @SuppressWarnings("unchecked") + var pppU = (PendingPushPromise) promise; + var responseCF = pppU.responseCF; + assert responseCF == null; + boolean cancelled = false; + promiseLock.lock(); + try { + promise.pushReq = pushReq; + pppU.responseCF = responseCF = acceptor.cf(); + // recheck to verify the push hasn't been cancelled already + var check = promises.get(pushId); + if (check instanceof CancelledPushPromise || check == null) { + cancelled = true; + } else { + assert promise == check; + pppU.handler = acceptor.bodyHandler(); + } + } finally { + promiseLock.unlock(); + } + if (!cancelled) { + exchange.onPushRequestAccepted(pushId, responseCF); + promise.accepted.complete(true); + // if stream is available start parsing? + tryReceivePromise(promise); + return true; + } else { + cancelPendingPushPromise(promise, null); + // should be a no-op - in theory it should already + // have been completed + promise.accepted.complete(false); + return false; + } + } + + /** + * {@return a completable future that will be completed when a pushId has been + * accepted by the exchange in charge of creating the response body} + * + * The completable future is complete with {@code true} if the pushId is + * accepted, and with {@code false} if the pushId was rejected or cancelled. + * + * This method is intended to be called when {@link + * #onPushPromiseFrame(Http3ExchangeImpl, long, HttpHeaders)}, returns false, + * indicating that the push promise is being delegated to another request/response + * exchange. + * On completion of the future returned here, if the future is completed + * with {@code true}, the caller is expected to call {@link + * PushGroup#acceptPushPromiseId(PushId)} in order to notify the {@link + * java.net.http.HttpResponse.PushPromiseHandler} of the received {@code pushId}. + * + * @see Http3Connection#whenPushAccepted(long) + * @param pushId the pushId + */ + CompletableFuture whenAccepted(long pushId) { + var promise = promises.get(pushId); + if (promise instanceof PendingPushPromise pp) { + return pp.accepted; + } else if (promise instanceof ProcessedPushPromise) { + return ACCEPTED; + } else { // CancelledPushPromise or null + return DENIED; + } + } + + + /** + * Cancel a push promise. In case of concurrent requests receiving the + * same pushId, where one has a PushPromiseHandler and the other doesn't, + * we will cancel the push only if reason != CANCEL_RECEIVED, or no request + * stream has already accepted the push. + * + * @param pushId the promise pushId + * @param cause the cause (can be null) + * @param reason reason for cancelling + */ + void cancelPushPromise(long pushId, Throwable cause, CancelPushReason reason) { + boolean sendCancelPush = false; + PendingPushPromise pending = null; + if (cause != null) { + debug.log("PushPromise cancelled: pushId=" + pushId, cause); + } else { + debug.log("PushPromise cancelled: pushId=%s", pushId); + String msg = "cancelPushPromise(pushId="+pushId+")"; + debug.log(msg); + } + if (reason == CancelPushReason.CANCEL_RECEIVED) { + if (checkMaxPushId(pushId) != null) { + // pushId >= max connection will be closed + return; + } + } + promiseLock.lock(); + try { + var promise = promises.get(pushId); + long min = minPushId.get(); + if (promise == null) { + if (pushId > maxPushReceived.get()) maxPushReceived.set(pushId); + checkExpungePromiseMap(); + if (pushId >= min) { + var cancelled = new CancelledPushPromise(connection.newPushId(pushId)); + promises.put(pushId, cancelled); + sendCancelPush = reason != CancelPushReason.CANCEL_RECEIVED; + } + } else if (promise instanceof CancelledPushPromise) { + // nothing to do + } else if (promise instanceof ProcessedPushPromise) { + // nothing we can do? + } else if (promise instanceof PendingPushPromise ppp) { + // only cancel if never accepted, or force cancel requested + if (ppp.promiseStream == null || reason != CancelPushReason.NO_HANDLER) { + var cancelled = new CancelledPushPromise(connection.newPushId(pushId)); + promises.put(pushId, cancelled); + long pendingCount = pendingPromises.decrementAndGet(); + long ppc; + assert (ppc = promises.values().stream().filter(PendingPushPromise.class::isInstance).count()) == pendingCount + : "bad pending promise count: expected %s but found %s".formatted(pendingCount, ppc); + ppp.accepted.complete(false); // NO OP if already completed + pending = ppp; + // send cancel push; do not send if we received + // a CancelPushFrame from the peer + // also do not update MAX_PUSH_ID here - MAX_PUSH_ID will + // be updated when starting the next request/response exchange that accepts + // push promises. + sendCancelPush = reason != CancelPushReason.CANCEL_RECEIVED; + } + } + } finally { + promiseLock.unlock(); + } + if (sendCancelPush) { + connection.sendCancelPush(pushId, cause); + } + if (pending != null) { + cancelPendingPushPromise(pending, cause); + } + } + + private void cancelPendingPushPromise(PendingPushPromise ppp, Throwable cause) { + var ps = ppp.stream; + var http3 = ppp.promiseStream; + var responseCF = ppp.responseCF; + if (ps != null) { + ps.requestStopSending(Http3Error.H3_REQUEST_CANCELLED.code()); + } + if (http3 != null || responseCF != null) { + IOException io; + if (cause == null) { + io = new IOException("Push promise cancelled: " + ppp.pushId); + } else { + io = Utils.toIOException(cause); + } + if (http3 != null) { + http3.cancel(io); + } else if (responseCF != null) { + responseCF.completeExceptionally(io); + } + } + } + + /** + * Called when a push promise response body has been successfully received. + * @param pushId the pushId + */ + void pushPromiseProcessed(long pushId) { + promiseLock.lock(); + try { + var promise = promises.get(pushId); + if (promise instanceof PendingPushPromise ppp) { + var processed = new ProcessedPushPromise(connection.newPushId(pushId), + ppp.promiseHeaders); + promises.put(pushId, processed); + var pendingCount = pendingPromises.decrementAndGet(); + long ppc; + assert (ppc = promises.values().stream().filter(PendingPushPromise.class::isInstance).count()) == pendingCount + : "bad pending promise count: expected %s but found %s".formatted(pendingCount, ppc); + // do not update MAX_PUSH_ID here - MAX_PUSH_ID will + // be updated when starting the next request/response exchange that accepts + // push promises. + } + } finally { + promiseLock.unlock(); + } + } + + /** + * Checks whether the given pushId exceed the maximum pushId allowed + * to the peer, and if so, closes the connection. + * @param pushId the pushId + * @return an {@code IOException} that can be used to complete a completable + * future if the maximum pushId is exceeded, {@code null} + * otherwise + */ + IOException checkMaxPushId(long pushId) { + return connection.checkMaxPushId(pushId); + } + + // Checks whether an Http3PushPromiseStream can be created now + private void tryReceivePromise(PendingPushPromise promise) { + debug.log("tryReceivePromise: " + promise); + promiseLock.lock(); + Http3PushPromiseStream http3PushPromiseStream = null; + IOException failed = null; + try { + if (promise.ready() && promise.promiseStream == null) { + promise.promiseStream = http3PushPromiseStream = + createPushExchange(promise); + } else { + debug.log("tryReceivePromise: Can't create Http3PushPromiseStream for pushId=%s yet", + promise.pushId); + } + } catch (IOException io) { + failed = io; + } finally { + promiseLock.unlock(); + } + if (failed != null) { + cancelPushPromise(promise.pushId, failed, CancelPushReason.PUSH_CANCELLED); + return; + } + if (http3PushPromiseStream != null) { + // HTTP/3 push promises are not ref-counted + // If we were to change that it could be necessary to + // temporarly increment ref-counting here, until the stream + // read loop effectively starts. + http3PushPromiseStream.start(); + } + } + + // try to create and start an Http3PushPromiseStream when all bits have + // been received + private Http3PushPromiseStream createPushExchange(PendingPushPromise promise) + throws IOException { + assert promise.ready() : "promise is not ready: " + promise; + Http3ExchangeImpl parent = promise.exchange; + HttpRequestImpl pushReq = promise.pushReq; + QuicReceiverStream quicStream = promise.stream; + Exchange pushExch = new Exchange<>(pushReq, parent.exchange.multi); + Http3PushPromiseStream pushStream = new Http3PushPromiseStream<>(pushExch, + parent.http3Connection(), this, + quicStream, promise.responseCF, promise.handler, parent, promise.pushId); + pushExch.exchImpl = pushStream; + return pushStream; + } + + // The first exchange that gets the PushPromise gets a PushPromise object, + // others get null + // TODO: ideally we should start a timer to cancel a push promise if + // the stream doesn't materialize after a while. + // Note that the callers can always start their own timeouts using + // the CompletableFutures we returned to them. + private PendingPushPromise addPushPromise(Http3ExchangeImpl exchange, + long pushId, + HttpHeaders promiseHeaders) { + PushPromise promise = promises.get(pushId); + boolean cancelStream = false; + if (promise == null) { + promiseLock.lock(); + try { + promise = promises.get(pushId); + if (promise == null) { + if (checkMaxPushId(pushId) == null) { + if (pushId >= minPushId.get()) { + if (pushId > maxPushReceived.get()) maxPushReceived.set(pushId); + checkExpungePromiseMap(); + var pp = new PendingPushPromise<>(exchange, pushId, promiseHeaders); + promises.put(pushId, pp); + long pendingCount = pendingPromises.incrementAndGet(); + long ppc; + assert (ppc = promises.values().stream().filter(PendingPushPromise.class::isInstance).count()) == pendingCount + : "bad pending promise count: expected %s but found %s".formatted(pendingCount, ppc); + return pp; + } else { + // pushId < minPushId + cancelStream = true; + } + } else return null; + } + } finally { + promiseLock.unlock(); + } + } + if (cancelStream) { + // we don't have the stream; + // the stream will be canceled if it comes later + // do not send push cancel frame (already cancelled, or abandoned) + return null; + } + if (promise instanceof PendingPushPromise ppp) { + var pe = ppp.exchange; + if (pe == null) { + promiseLock.lock(); + try { + if (ppp.exchange == null) { + assert ppp.promiseHeaders == null; + @SuppressWarnings("unchecked") + var pppU = (PendingPushPromise) ppp; + pppU.exchange = exchange; + pppU.promiseHeaders = promiseHeaders; + return pppU; + } + } finally { + promiseLock.unlock(); + } + } + var previousHeaders = ppp.promiseHeaders; + if (previousHeaders != null && !previousHeaders.equals(promiseHeaders)) { + connection.protocolError( + new ProtocolException("push headers do not match with previous promise for " + pushId)); + } + } else if (promise instanceof ProcessedPushPromise ppp) { + if (!ppp.promiseHeaders().equals(promiseHeaders)) { + connection.protocolError( + new ProtocolException("push headers do not match with previous promise for " + pushId)); + } + } else if (promise instanceof CancelledPushPromise) { + // already cancelled - nothing to do + } + return null; + } + + // TODO: the packet opening the push promise stream might reach us before + // the push promise headers are processed. We could start a timer + // here to cancel the push promise if the PushPromiseFrame doesn't materialize + // after a while. + private PendingPushPromise addPushPromise(QuicReceiverStream stream, long pushId) { + PushPromise promise = promises.get(pushId); + boolean cancelStream = false; + if (promise == null) { + promiseLock.lock(); + try { + promise = promises.get(pushId); + if (promise == null) { + if (checkMaxPushId(pushId) == null) { + if (pushId >= minPushId.get()) { + if (pushId > maxPushReceived.get()) maxPushReceived.set(pushId); + checkExpungePromiseMap(); + var pp = new PendingPushPromise(stream, pushId); + promises.put(pushId, pp); + long pendingCount = pendingPromises.incrementAndGet(); + long ppc; + assert (ppc = promises.values().stream().filter(PendingPushPromise.class::isInstance).count()) == pendingCount + : "bad pending promise count: expected %s but found %s".formatted(pendingCount, ppc); + return pp; + } else { + // pushId < minPushId + cancelStream = true; + } + } else return null; // maxPushId exceeded, connection closed + } + } finally { + promiseLock.unlock(); + } + } + if (cancelStream) { + // do not send push cancel frame (already cancelled, or abandoned) + stream.requestStopSending(Http3Error.H3_REQUEST_CANCELLED.code()); + return null; + } + if (promise instanceof PendingPushPromise ppp) { + var ps = ppp.stream; + if (ps == null) { + promiseLock.lock(); + try { + if ((ps = ppp.stream) == null) { + ps = ppp.stream = stream; + } + } finally { + promiseLock.unlock(); + } + } + if (ps == stream) { + @SuppressWarnings("unchecked") + var pp = ((PendingPushPromise) ppp); + return pp; + } else { + // Error! cancel stream... + var io = new ProtocolException("HTTP/3 pushId %s already used on this connection".formatted(pushId)); + connection.connectionError(io, Http3Error.H3_ID_ERROR); + } + } else if (promise instanceof ProcessedPushPromise) { + var io = new ProtocolException("HTTP/3 pushId %s already used on this connection".formatted(pushId)); + connection.connectionError(io, Http3Error.H3_ID_ERROR); + } else { + // already cancelled? + // Error! cancel stream... + // connection.sendCancelPush(pushId, null); + stream.requestStopSending(Http3Error.H3_REQUEST_CANCELLED.code()); + } + return null; + } + + // We only keep MAX_PUSH_HISTORY_SIZE entries in the map. + // If the map has more than MAX_PUSH_HISTORY_SIZE entries, we start expunging + // pushIds starting at minPushId. This method makes room for at least + // on push promise in the map + private void checkExpungePromiseMap() { + assert promiseLock.isHeldByCurrentThread(); + while (promises.size() >= MAX_PUSH_HISTORY_SIZE) { + long min = minPushId.getAndIncrement(); + var pp = promises.remove(min); + if (pp instanceof PendingPushPromise ppp) { + var pendingCount = pendingPromises.decrementAndGet(); + long ppc; + assert (ppc = promises.values().stream().filter(PendingPushPromise.class::isInstance).count()) == pendingCount + : "bad pending promise count: expected %s but found %s".formatted(pendingCount, ppc); + var http3 = ppp.promiseStream; + IOException io = null; + if (http3 != null) { + http3.cancel(io = new IOException("PushPromise cancelled")); + } + if (io == null) { + io = new IOException("PushPromise cancelled"); + } + connection.sendCancelPush(ppp.pushId, io); + var ps = ppp.stream; + if (ps != null) { + ps.requestStopSending(Http3Error.H3_REQUEST_CANCELLED.code()); + } + } + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3PushPromiseStream.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3PushPromiseStream.java new file mode 100644 index 00000000000..27aaa75891c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3PushPromiseStream.java @@ -0,0 +1,746 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.io.EOFException; +import java.io.IOException; +import java.net.ProtocolException; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.ResponseInfo; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.Http3PushManager.CancelPushReason; +import jdk.internal.net.http.common.HttpBodySubscriberWrapper; +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.SubscriptionBase; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.PushPromiseFrame; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; + +import static jdk.internal.net.http.http3.Http3Error.H3_FRAME_UNEXPECTED; + +/** + * This class represents an HTTP/3 PushPromise stream. + */ +final class Http3PushPromiseStream extends Http3Stream { + + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + private final Http3Connection connection; + private final HttpHeadersBuilder respHeadersBuilder; + private final PushRespHeadersConsumer respHeadersConsumer; + private final HeaderFrameReader respHeaderFrameReader; + private final Decoder qpackDecoder; + private final AtomicReference errorRef; + private final CompletableFuture pushCF = new MinimalFuture<>(); + private final CompletableFuture> responseCF; + private final QuicReceiverStream stream; + private final QuicStreamReader reader; + private final Http3ExchangeImpl parent; + private final long pushId; + private final Http3PushManager pushManager; + private final BodyHandler pushHandler; + + private final FramesDecoder framesDecoder = + new FramesDecoder(this::dbgTag, FramesDecoder::isAllowedOnPromiseStream); + private final SequentialScheduler readScheduler = + SequentialScheduler.lockingScheduler(this::processQuicData); + private final ReentrantLock stateLock = new ReentrantLock(); + private final H3FrameOrderVerifier frameOrderVerifier = H3FrameOrderVerifier.newForPushPromiseStream(); + + final SubscriptionBase userSubscription = + new SubscriptionBase(readScheduler, this::cancel, this::onSubscriptionError); + + volatile boolean closed; + volatile BodySubscriber pendingResponseSubscriber; + volatile BodySubscriber responseSubscriber; + volatile CompletableFuture responseBodyCF; + volatile boolean responseReceived; + volatile int responseCode; + volatile Response response; + volatile boolean stopRequested; + private String dbgTag = null; + + Http3PushPromiseStream(Exchange exchange, + final Http3Connection connection, + final Http3PushManager pushManager, + final QuicReceiverStream stream, + final CompletableFuture> responseCF, + final BodyHandler pushHandler, + Http3ExchangeImpl parent, + long pushId) { + super(exchange); + this.responseCF = responseCF; + this.pushHandler = pushHandler; + this.errorRef = new AtomicReference<>(); + this.pushId = pushId; + this.connection = connection; + this.pushManager = pushManager; + this.stream = stream; + this.parent = parent; + this.respHeadersBuilder = new HttpHeadersBuilder(); + this.respHeadersConsumer = new PushRespHeadersConsumer(); + this.qpackDecoder = connection.qpackDecoder(); + this.respHeaderFrameReader = qpackDecoder.newHeaderFrameReader(respHeadersConsumer); + this.reader = stream.connectReader(readScheduler); + debug.log("Http3PushPromiseStream created"); + } + + void start() { + exchange.exchImpl = this; + parent.onHttp3PushStreamStarted(exchange.request(), this); + this.reader.start(); + } + + long pushId() { + return pushId; + } + + String dbgTag() { + if (dbgTag != null) return dbgTag; + long streamId = streamId(); + String sid = streamId == -1 ? "?" : String.valueOf(streamId); + String ctag = connection == null ? null : connection.dbgTag(); + String tag = "Http3PushPromiseStream(" + ctag + ", streamId=" + sid + ", pushId="+ pushId + ")"; + if (streamId == -1) return tag; + return dbgTag = tag; + } + + @Override + long streamId() { + var stream = this.stream; + return stream == null ? -1 : stream.streamId(); + } + + private final class PushRespHeadersConsumer extends StreamHeadersConsumer { + + public PushRespHeadersConsumer() { + super(Context.RESPONSE); + } + + void resetDone() { + if (debug.on()) { + debug.log("Response builder cleared, ready to receive new headers."); + } + } + + @Override + String headerFieldType() { + return "PUSH RESPONSE HEADER FIELD"; + } + + @Override + Decoder qpackDecoder() { + return qpackDecoder; + } + + @Override + protected String formatMessage(String message, String header) { + // Malformed requests or responses that are detected MUST be + // treated as a stream error of type H3_MESSAGE_ERROR. + return "malformed push response: " + super.formatMessage(message, header); + } + + + @Override + HeaderFrameReader headerFrameReader() { + return respHeaderFrameReader; + } + + @Override + HttpHeadersBuilder headersBuilder() { + return respHeadersBuilder; + } + + @Override + void headersCompleted() { + handleResponse(); + } + + @Override + public long streamId() { + return stream.streamId(); + } + } + + @Override + HttpQuicConnection connection() { + return connection.connection(); + } + + + // The Http3StreamResponseSubscriber is registered with the HttpClient + // to ensure that it gets completed if the SelectorManager aborts due + // to unexpected exceptions. + private void registerResponseSubscriber(Http3PushStreamResponseSubscriber subscriber) { + if (client().registerSubscriber(subscriber)) { + debug.log("Reference response body for h3 stream: " + streamId()); + client().h3StreamReference(); + } + } + + private void unregisterResponseSubscriber(Http3PushStreamResponseSubscriber subscriber) { + if (client().unregisterSubscriber(subscriber)) { + debug.log("Unreference response body for h3 stream: " + streamId()); + client().h3StreamUnreference(); + } + } + + final class Http3PushStreamResponseSubscriber extends HttpBodySubscriberWrapper { + Http3PushStreamResponseSubscriber(BodySubscriber subscriber) { + super(subscriber); + } + + @Override + protected void unregister() { + unregisterResponseSubscriber(this); + } + + @Override + protected void register() { + registerResponseSubscriber(this); + } + } + + Http3PushStreamResponseSubscriber createResponseSubscriber(BodyHandler handler, + ResponseInfo response) { + debug.log("Creating body subscriber"); + return new Http3PushStreamResponseSubscriber<>(handler.apply(response)); + } + + @Override + CompletableFuture ignoreBody() { + try { + debug.log("Ignoring body"); + reader.stream().requestStopSending(Http3Error.H3_REQUEST_CANCELLED.code()); + return MinimalFuture.completedFuture(null); + } catch (Throwable e) { + Log.logTrace("Error requesting stop sending for stream {0}: {1}", + streamId(), e.toString()); + return MinimalFuture.failedFuture(e); + } + } + + @Override + void cancel() { + debug.log("cancel"); + var stream = this.stream; + if ((stream == null)) { + cancel(new IOException("Stream cancelled before streamid assigned")); + } else { + cancel(new IOException("Stream " + stream.streamId() + " cancelled")); + } + } + + @Override + void cancel(IOException cause) { + cancelImpl(cause, Http3Error.H3_REQUEST_CANCELLED); + } + + @Override + void onProtocolError(IOException cause) { + final long streamId = stream.streamId(); + if (debug.on()) { + debug.log("cancelling exchange on stream %d due to protocol error: %s", streamId, cause.getMessage()); + } + Log.logError("cancelling exchange on stream {0} due to protocol error: {1}\n", streamId, cause); + cancelImpl(cause, Http3Error.H3_GENERAL_PROTOCOL_ERROR); + } + + @Override + void released() { + + } + + @Override + void completed() { + + } + + @Override + boolean isCanceled() { + return errorRef.get() != null; + } + + @Override + Throwable getCancelCause() { + return errorRef.get(); + } + + @Override + void cancelImpl(Throwable e, Http3Error error) { + try { + var streamid = streamId(); + if (errorRef.compareAndSet(null, e)) { + if (debug.on()) { + if (streamid == -1) debug.log("cancelling stream: %s", e); + else debug.log("cancelling stream " + streamid + ":", e); + } + if (Log.trace()) { + if (streamid == -1) Log.logTrace("cancelling stream: {0}\n", e); + else Log.logTrace("cancelling stream {0}: {1}\n", streamid, e); + } + } else { + if (debug.on()) { + if (streamid == -1) debug.log("cancelling stream: %s", (Object) e); + else debug.log("cancelling stream %s: %s", streamid, e); + } + } + + var firstError = errorRef.get(); + completeResponseExceptionally(firstError); + if (responseBodyCF != null) { + responseBodyCF.completeExceptionally(firstError); + } + // will send a RST_STREAM frame + var stream = this.stream; + if (connection.isOpen()) { + if (stream != null) { + if (debug.on()) + debug.log("request stop sending"); + stream.requestStopSending(error.code()); + } + } + } catch (Throwable ex) { + debug.log("failed cancelling request: ", ex); + Log.logError(ex); + } finally { + close(); + } + } + + @Override + CompletableFuture getResponseAsync(Executor executor) { + var cf = pushCF; + if (executor != null && !cf.isDone()) { + // protect from executing later chain of CompletableFuture operations from SelectorManager thread + cf = cf.thenApplyAsync(r -> r, executor); + } + Log.logTrace("Response future (stream={0}) is: {1}", streamId(), cf); + if (debug.on()) debug.log("Response future is %s", cf); + return cf; + } + + void completeResponse(Response r) { + debug.log("Response: " + r); + Log.logResponse(r::toString); + pushCF.complete(r); // not strictly required for push API + // start reading the body using the obtained BodySubscriber + CompletableFuture start = new MinimalFuture<>(); + start.thenCompose( v -> readBodyAsync(getPushHandler(), false, getExchange().executor())) + .whenComplete((T body, Throwable t) -> { + if (t != null) { + responseCF.completeExceptionally(t); + debug.log("Cancelling push promise %s (stream %s) due to: %s", pushId, streamId(), t); + pushManager.cancelPushPromise(pushId, t, CancelPushReason.PUSH_CANCELLED); + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + } else { + HttpResponseImpl resp = + new HttpResponseImpl<>(r.request, r, null, body, getExchange()); + debug.log("Completing responseCF: " + resp); + pushManager.pushPromiseProcessed(pushId); + responseCF.complete(resp); + } + }); + start.completeAsync(() -> null, getExchange().executor()); + } + + // methods to update state and remove stream when finished + + void responseReceived() { + stateLock.lock(); + try { + responseReceived0(); + } finally { + stateLock.unlock(); + } + } + + private void responseReceived0() { + assert stateLock.isHeldByCurrentThread(); + responseReceived = true; + if (debug.on()) debug.log("responseReceived: streamid=%d", streamId()); + close(); + } + + /** + * same as above but for errors + */ + void completeResponseExceptionally(Throwable t) { + pushManager.cancelPushPromise(pushId, t, CancelPushReason.PUSH_CANCELLED); + responseCF.completeExceptionally(t); + } + + void nullBody(HttpResponse resp, Throwable t) { + if (debug.on()) debug.log("nullBody: streamid=%d", streamId()); + // We should have an END_STREAM data frame waiting in the inputQ. + // We need a subscriber to force the scheduler to process it. + assert pendingResponseSubscriber == null; + pendingResponseSubscriber = HttpResponse.BodySubscribers.replacing(null); + readScheduler.runOrSchedule(); + } + + @Override + CompletableFuture> sendHeadersAsync() { + return MinimalFuture.completedFuture(this); + } + + @Override + CompletableFuture> sendBodyAsync() { + return MinimalFuture.completedFuture(this); + } + + CompletableFuture> responseCF() { + return responseCF; + } + + + BodyHandler getPushHandler() { + // ignored parameters to function can be used as BodyHandler + return this.pushHandler; + } + + @Override + CompletableFuture readBodyAsync(BodyHandler handler, + boolean returnConnectionToPool, + Executor executor) { + try { + Log.logTrace("Reading body on stream {0}", streamId()); + debug.log("Getting BodySubscriber for: " + response); + Http3PushStreamResponseSubscriber bodySubscriber = + createResponseSubscriber(handler, new ResponseInfoImpl(response)); + CompletableFuture cf = receiveResponseBody(bodySubscriber, executor); + + PushGroup pg = parent.exchange.getPushGroup(); + if (pg != null) { + // if an error occurs make sure it is recorded in the PushGroup + cf = cf.whenComplete((t, e) -> pg.pushError(e)); + } + var bodyCF = cf; + return bodyCF; + } catch (Throwable t) { + // may be thrown by handler.apply + // TODO: Is this the right error code? + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + PushGroup pg = parent.exchange.getPushGroup(); + if (pg != null) { + // if an error occurs make sure it is recorded in the PushGroup + pg.pushError(t); + } + return MinimalFuture.failedFuture(t); + } + } + + // This method doesn't send any frame + void close() { + if (closed) return; + stateLock.lock(); + try { + if (closed) return; + closed = true; + } finally { + stateLock.unlock(); + } + if (debug.on()) debug.log("stream %d is now closed", streamId()); + Log.logTrace("Stream {0} is now closed", streamId()); + + BodySubscriber subscriber = responseSubscriber; + if (subscriber == null) subscriber = pendingResponseSubscriber; + if (subscriber instanceof Http3PushStreamResponseSubscriber h3srs) { + // ensure subscriber is unregistered + h3srs.complete(errorRef.get()); + } + connection.onPushPromiseStreamClosed(this, streamId()); + } + + @Override + Response newResponse(HttpHeaders responseHeaders, int responseCode) { + return this.response = new Response( + exchange.request, exchange, responseHeaders, connection(), + responseCode, Version.HTTP_3); + } + + protected void handleResponse() { + handleResponse(respHeadersBuilder, respHeadersConsumer, readScheduler, debug); + } + + @Override + void receivePushPromiseFrame(PushPromiseFrame ppf, List payload) throws IOException { + readScheduler.stop(); + connectionError(new ProtocolException("Unexpected PUSH_PROMISE on push response stream"), H3_FRAME_UNEXPECTED); + } + + @Override + void onPollException(QuicStreamReader reader, IOException io) { + if (Log.http3()) { + Log.logHttp3("{0}/streamId={1} pushId={2} #{3} (responseReceived={4}, " + + "reader={5}, statusCode={6}, finalStream={9}): {10}", + connection().quicConnection().logTag(), + String.valueOf(reader.stream().streamId()), pushId, String.valueOf(exchange.multi.id), + responseReceived, reader.receivingState(), + String.valueOf(responseCode), connection.isFinalStream(), io); + } + } + + @Override + void onReaderReset() { + long errorCode = stream.rcvErrorCode(); + String resetReason = Http3Error.stringForCode(errorCode); + Http3Error resetError = Http3Error.fromCode(errorCode) + .orElse(Http3Error.H3_REQUEST_CANCELLED); + if (!responseReceived) { + cancelImpl(new IOException("Stream %s reset by peer: %s" + .formatted(streamId(), resetReason)), + resetError); + } + if (debug.on()) { + debug.log("Stream %s reset by peer [%s]: Stopping scheduler", + streamId(), resetReason); + } + readScheduler.stop(); + } + + // Invoked when some data is received from the request-response + // Quic stream + private void processQuicData() { + // Poll bytes from the request-response stream + // and parses the data to read HTTP/3 frames. + // + // If the frame being read is a header frame, send the + // compacted header field data to QPack. + // + // Otherwise, if it's a data frame, send the bytes + // to the response body subscriber. + // + // Finally, if the frame being read is a PushPromiseFrame, + // sends the compressed field data to the QPack decoder to + // decode the push promise request headers. + try { + processQuicData(reader, framesDecoder, frameOrderVerifier, readScheduler, debug); + } catch (Throwable t) { + debug.log("processQuicData - Unexpected exception", t); + if (!responseReceived) { + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + } + } finally { + debug.log("processQuicData - leaving - eof: %s", framesDecoder.eof()); + } + } + + // invoked when ByteBuffers containing the next payload bytes for the + // given partial header frame are received + void receiveHeaders(HeadersFrame headers, List payload) + throws IOException { + debug.log("receive headers: buffer list: " + payload); + boolean completed = headers.remaining() == 0; + boolean eof = false; + if (payload != null) { + int last = payload.size() - 1; + for (int i = 0; i <= last; i++) { + ByteBuffer buf = payload.get(i); + boolean endOfHeaders = completed && i == last; + if (debug.on()) + debug.log("QPack decoding %s bytes from headers (last: %s)", + buf.remaining(), last); + // if we have finished receiving the header frame, pause reading until + // the status code has been decoded + if (endOfHeaders) switchReadingPaused(true); + qpackDecoder.decodeHeader(buf, + endOfHeaders, + respHeaderFrameReader); + if (buf == QuicStreamReader.EOF) { + // we are at EOF - no need to pause reading + switchReadingPaused(false); + eof = true; + } + } + } + if (!completed && eof) { + cancelImpl(new EOFException("EOF reached: " + headers), + Http3Error.H3_REQUEST_CANCELLED); + } + } + + void connectionError(Throwable throwable, long errorCode, String errMsg) { + if (errorRef.compareAndSet(null, throwable)) { + var streamid = streamId(); + if (debug.on()) { + if (streamid == -1) { + debug.log("cancelling stream due to connection error", throwable); + } else { + debug.log("cancelling stream " + streamid + + " due to connection error", throwable); + } + } + if (Log.trace()) { + if (streamid == -1) { + Log.logTrace( "connection error: {0}", errMsg); + } else { + var format = "cancelling stream {0} due to connection error: {1}"; + Log.logTrace(format, streamid, errMsg); + } + } + } + connection.connectionError(this, throwable, errorCode, errMsg); + } + + + // pushes entire response body into response subscriber + // blocking when required by local or remote flow control + CompletableFuture receiveResponseBody(BodySubscriber bodySubscriber, Executor executor) { + // We want to allow the subscriber's getBody() method to block so it + // can work with InputStreams. So, we offload execution. + responseBodyCF = ResponseSubscribers.getBodyAsync(executor, bodySubscriber, + new MinimalFuture<>(), (t) -> this.cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED)); + + if (isCanceled()) { + Throwable t = getCancelCause(); + responseBodyCF.completeExceptionally(t); + } + + // ensure that the body subscriber will be subsribed and onError() is + // invoked + pendingResponseSubscriber = bodySubscriber; + readScheduler.runOrSchedule(); // in case data waiting already to be processed, or error + + return responseBodyCF; + } + + void onSubscriptionError(Throwable t) { + errorRef.compareAndSet(null, t); + if (debug.on()) debug.log("Got subscription error: %s", (Object) t); + // This is the special case where the subscriber + // has requested an illegal number of items. + // In this case, the error doesn't come from + // upstream, but from downstream, and we need to + // handle the error without waiting for the inputQ + // to be exhausted. + stopRequested = true; + readScheduler.runOrSchedule(); + } + + // This loop is triggered to push response body data into + // the body subscriber. + void pushResponseData(ConcurrentLinkedQueue> responseData) { + debug.log("pushResponseData"); + boolean onCompleteCalled = false; + BodySubscriber subscriber = responseSubscriber; + boolean done = false; + try { + if (subscriber == null) { + subscriber = responseSubscriber = pendingResponseSubscriber; + if (subscriber == null) { + // can't process anything yet + return; + } else { + if (debug.on()) debug.log("subscribing user subscriber"); + subscriber.onSubscribe(userSubscription); + } + } + while (!responseData.isEmpty()) { + List data = responseData.peek(); + List dsts = Collections.unmodifiableList(data); + long size = Utils.remaining(dsts, Long.MAX_VALUE); + boolean finished = dsts.contains(QuicStreamReader.EOF); + if (size == 0 && finished) { + responseData.remove(); + Log.logTrace("responseSubscriber.onComplete"); + if (debug.on()) debug.log("pushResponseData: onComplete"); + subscriber.onComplete(); + done = true; + onCompleteCalled = true; + responseReceived(); + return; + } else if (userSubscription.tryDecrement()) { + responseData.remove(); + Log.logTrace("responseSubscriber.onNext {0}", size); + if (debug.on()) debug.log("pushResponseData: onNext(%d)", size); + subscriber.onNext(dsts); + } else { + if (stopRequested) break; + debug.log("no demand"); + return; + } + } + if (framesDecoder.eof() && responseData.isEmpty()) { + debug.log("pushResponseData: EOF"); + if (!onCompleteCalled) { + Log.logTrace("responseSubscriber.onComplete"); + if (debug.on()) debug.log("pushResponseData: onComplete"); + subscriber.onComplete(); + done = true; + onCompleteCalled = true; + responseReceived(); + return; + } + } + } catch (Throwable throwable) { + debug.log("pushResponseData: unexpected exception", throwable); + errorRef.compareAndSet(null, throwable); + } finally { + if (done) responseData.clear(); + } + + Throwable t = errorRef.get(); + if (t != null) { + try { + if (!onCompleteCalled) { + if (debug.on()) + debug.log("calling subscriber.onError: %s", (Object) t); + subscriber.onError(t); + } else { + if (debug.on()) + debug.log("already completed: dropping error %s", (Object) t); + } + } catch (Throwable x) { + Log.logError("Subscriber::onError threw exception: {0}", t); + } finally { + cancelImpl(t, Http3Error.H3_REQUEST_CANCELLED); + responseData.clear(); + } + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http3Stream.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http3Stream.java new file mode 100644 index 00000000000..cdac68b47f1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http3Stream.java @@ -0,0 +1,693 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.io.EOFException; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.ProtocolException; +import java.net.http.HttpHeaders; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.OptionalLong; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; + +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.common.ValidatingHeadersConsumer; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.DataFrame; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.Http3Frame; +import jdk.internal.net.http.http3.frames.Http3FrameType; +import jdk.internal.net.http.http3.frames.MalformedFrame; +import jdk.internal.net.http.http3.frames.PartialFrame; +import jdk.internal.net.http.http3.frames.PushPromiseFrame; +import jdk.internal.net.http.http3.frames.UnknownFrame; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.quic.streams.QuicStreamReader; + +import static jdk.internal.net.http.Exchange.MAX_NON_FINAL_RESPONSES; +import static jdk.internal.net.http.RedirectFilter.HTTP_NOT_MODIFIED; + +/** + * A common super class for the HTTP/3 request/response stream ({@link Http3ExchangeImpl} + * and the HTTP/3 push promises stream ({@link Http3PushPromiseStream}. + * @param the expected type of the response body + */ +sealed abstract class Http3Stream extends ExchangeImpl permits Http3ExchangeImpl, Http3PushPromiseStream { + enum ResponseState { PERMIT_HEADER, PERMIT_TRAILER, PERMIT_NONE } + + // count of bytes read from the Quic stream. This is weakly consistent and + // used for debug only. Must not be updated outside of processQuicData + private volatile long receivedQuicBytes; + // keep track of which HTTP/3 frames have been parsed and whether more header + // frames are permitted + private ResponseState responseState = ResponseState.PERMIT_HEADER; + // value of content-length header in the response header, or null + private Long contentLength; + // number of data bytes delivered to user subscriber + private long consumedDataBytes; + // switched to true if reading from the quic stream should be temporarily + // paused. After switching back to false, readScheduler.runOrSchedule() should + // called. + private volatile boolean readingPaused; + + // A temporary buffer for response body bytes + final ConcurrentLinkedQueue> responseData = new ConcurrentLinkedQueue<>(); + + private final AtomicInteger nonFinalResponseCount = new AtomicInteger(); + + + Http3Stream(Exchange exchange) { + super(exchange); + } + + /** + * Cancel the stream exchange on error + * @param throwable an exception to be relayed to the multi exchange + * through the completable future chain + * @param error an HTTP/3 error + */ + abstract void cancelImpl(Throwable throwable, Http3Error error); + + /** + * {@return the Quic stream id for this exchange (request/response or push response)} + */ + abstract long streamId(); + + /** + * A base class implementing {@link DecodingCallback} used for receiving + * and building HttpHeaders. Can be used for request headers, response headers, + * push response headers, or trailers. + */ + abstract class StreamHeadersConsumer extends ValidatingHeadersConsumer + implements DecodingCallback { + + private volatile boolean hasError; + + StreamHeadersConsumer(Context context) { + super(context); + } + + abstract Decoder qpackDecoder(); + + abstract HeaderFrameReader headerFrameReader(); + + abstract HttpHeadersBuilder headersBuilder(); + + abstract void resetDone(); + + @Override + public void reset() { + super.reset(); + headerFrameReader().reset(); + headersBuilder().clear(); + hasError = false; + resetDone(); + } + + String headerFieldType() {return "HEADER FIELD";} + + @Override + public void onDecoded(CharSequence name, CharSequence value) { + try { + String n = name.toString(); + String v = value.toString(); + super.onDecoded(n, v); + headersBuilder().addHeader(n, v); + if (Log.headers() && Log.trace()) { + Log.logTrace("RECEIVED {0} (streamid={1}): {2}: {3}", + headerFieldType(), streamId(), n, v); + } + } catch (Throwable throwable) { + if (throwable instanceof UncheckedIOException uio) { + // UncheckedIOException is thrown by ValidatingHeadersConsumer.onDecoded + // for cases with invalid headers or unknown/unsupported pseudo-headers. + // It should be treated as a malformed request. + // RFC-9114 4.1.2. Malformed Requests and Responses: + // Malformed requests or responses that are + // detected MUST be treated as a stream error of + // type H3_MESSAGE_ERROR. + onStreamError(uio.getCause(), Http3Error.H3_MESSAGE_ERROR); + } else { + onConnectionError(throwable, Http3Error.H3_INTERNAL_ERROR); + } + } + } + + @Override + public void onComplete() { + // RFC-9204 2.2.2.1: After the decoder finishes decoding a field + // section encoded using representations containing dynamic table + // references, it MUST emit a Section Acknowledgment instruction + qpackDecoder().ackSection(streamId(), headerFrameReader()); + qpackDecoder().resetInsertionsCounter(); + headersCompleted(); + } + + abstract void headersCompleted(); + + @Override + public void onStreamError(Throwable throwable, Http3Error http3Error) { + hasError = true; + qpackDecoder().resetInsertionsCounter(); + // Stream error + cancelImpl(throwable, http3Error); + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + hasError = true; + // Connection error + connectionError(throwable, http3Error); + } + + @Override + public boolean hasError() { + return hasError; + } + + } + + /** + * {@return count of bytes read from the QUIC stream so far} + */ + public long receivedQuicBytes() { + return receivedQuicBytes; + } + + /** + * Notify of a connection error. + * + * The implementation of this method is supposed to close all + * exchanges, cancel all push promises, and close the connection. + * + * @implSpec + * The implementation of this method calls + * {@snippet lang=java : + * connectionError(throwable, error.code(), throwable.getMessage()); + * } + * + * @param throwable an exception to be relayed to the multi exchange + * through the completable future chain + * @param error an HTTP/3 error + */ + void connectionError(Throwable throwable, Http3Error error) { + connectionError(throwable, error.code(), throwable.getMessage()); + } + + + /** + * Notify of a connection error. + * + * The implementation of this method is supposed to close all + * exchanges, cancel all push promises, and close the connection. + * + * @param throwable an exception to be relayed to the multi exchange + * through the completable future chain + * @param errorCode an HTTP/3 error code + * @param errMsg an error message to be logged when closing the connection + */ + abstract void connectionError(Throwable throwable, long errorCode, String errMsg); + + + /** + * Push response data to the {@linkplain java.net.http.HttpResponse.BodySubscriber + * response body subscriber} if allowed by the subscription state. + * @param responseData a queue of available data to be pushed to the subscriber + */ + abstract void pushResponseData(ConcurrentLinkedQueue> responseData); + + /** + * Called when an exception is thrown by {@link QuicStreamReader#poll() reader::poll} + * when called from {@link #processQuicData(QuicStreamReader, FramesDecoder, + * H3FrameOrderVerifier, SequentialScheduler, Logger) processQuicData}. + * This is typically only used for logging purposes. + * @param reader the stream reader + * @param io the exception caught + */ + abstract void onPollException(QuicStreamReader reader, IOException io); + + /** + * Called when new payload data is received by {@link #processQuicData(QuicStreamReader, + * FramesDecoder, H3FrameOrderVerifier, SequentialScheduler, Logger) processQuicData} + * for a given header frame. + *

+ * Any exception thrown here will be rethrown by {@code processQuicData} + * + * @param headers a partially received header frame + * @param payload the payload bytes available for that frame + * @throws IOException if an error is detected + */ + abstract void receiveHeaders(HeadersFrame headers, List payload) throws IOException; + + /** + * Called when new payload data is received by {@link #processQuicData(QuicStreamReader, + * FramesDecoder, H3FrameOrderVerifier, SequentialScheduler, Logger) processQuicData} + * for a given push promise frame. + *

+ * Any exception thrown here will be rethrown by {@code processQuicData} + * + * @param ppf a partially received push promise frame + * @param payload the payload bytes available for that frame + * @throws IOException if an error is detected + */ + abstract void receivePushPromiseFrame(PushPromiseFrame ppf, List payload) throws IOException; + + /** + * {@return whether reading from the quic stream is currently paused} + * Typically reading is paused when waiting for headers to be decoded by QPack. + */ + boolean readingPaused() {return readingPaused;} + + /** + * Switches the value of the {@link #readingPaused() readingPaused} + * flag + *

+ * Subclasses of {@code Http3Stream} can call this method to switch + * the value of this flag if needed, typically in their + * concrete implementation of {@link #receiveHeaders(HeadersFrame, List)}. + * @param value the new value + */ + void switchReadingPaused(boolean value) { + readingPaused = value; + } + + // invoked when ByteBuffers containing the next payload bytes for the + // given partial data frame are received. + private void receiveData(DataFrame data, List payload, Logger debug) { + if (debug.on()) { + debug.log("receiveData: adding %s payload byte", Utils.remaining(payload)); + } + responseData.add(payload); + pushResponseData(responseData); + } + + private ByteBuffer pollIfNotReset(QuicStreamReader reader) throws IOException { + ByteBuffer buffer; + try { + if (reader.isReset()) return null; + buffer = reader.poll(); + } catch (IOException io) { + if (reader.isReset()) return null; + onPollException(reader, io); + throw io; + } + return buffer; + } + + private Throwable toThrowable(MalformedFrame malformedFrame) { + Throwable cause = malformedFrame.getCause(); + if (cause != null) return cause; + return new ProtocolException(malformedFrame.toString()); + } + + /** + * Called when {@code processQuicData} detects that the {@linkplain + * QuicStreamReader reader} has been reset. + * This method should do the appropriate garbage collection, + * possibly closing the exchange or the connection if needed, and + * closing the read scheduler. + */ + abstract void onReaderReset(); + + /** + * Invoked when some data is received from the underlying quic stream. + * This implements the read loop for a request-response stream or a + * push response stream. + */ + void processQuicData(QuicStreamReader reader, + FramesDecoder framesDecoder, + H3FrameOrderVerifier frameOrderVerifier, + SequentialScheduler readScheduler, + Logger debug) throws IOException { + + + // Poll bytes from the request-response stream + // and parses the data to read HTTP/3 frames. + // + // If the frame being read is a header frame, send the + // compacted header field data to QPack. + // + // Otherwise, if it's a data frame, send the bytes + // to the response body subscriber. + // + // Finally, if the frame being read is a PushPromiseFrame, + // sends the compressed field data to the QPack decoder to + // decode the push promise request headers. + // + + // the reader might be null if the loop is triggered before + // the field is assigned + if (reader == null) return; + + // check whether we need to wait until response headers + // have been decoded: in that case readingPaused will be true + if (readingPaused) return; + + if (debug.on()) debug.log("processQuicData"); + ByteBuffer buffer; + Http3Frame frame; + pushResponseData(responseData); + boolean readmore = responseData.isEmpty(); + // do not read more until data has been pulled + while (readmore && (buffer = pollIfNotReset(reader)) != null) { + if (debug.on()) + debug.log("processQuicData - submitting buffer: %s bytes (ByteBuffer@%s)", + buffer.remaining(), System.identityHashCode(buffer)); + // only updated here + var received = receivedQuicBytes; + receivedQuicBytes = received + buffer.remaining(); + framesDecoder.submit(buffer); + while ((frame = framesDecoder.poll()) != null) { + if (debug.on()) debug.log("processQuicData - frame: " + frame); + final long frameType = frame.type(); + // before we start processing, verify that this frame *type* has arrived in the + // allowed order + if (!frameOrderVerifier.allowsProcessing(frame)) { + final String unexpectedFrameType = Http3FrameType.asString(frameType); + // not expected to be arriving now + // RFC-9114, section 4.1 - Receipt of an invalid sequence of frames MUST be + // treated as a connection error of type H3_FRAME_UNEXPECTED. + if (debug.on()) { + debug.log("unexpected (order of) frame type: " + + unexpectedFrameType + " on stream"); + } + Log.logError("Connection error due to unexpected (order of) frame type" + + " {0} on stream", unexpectedFrameType); + readScheduler.stop(); + final String errMsg = "Unexpected frame " + unexpectedFrameType; + connectionError(new ProtocolException(errMsg), Http3Error.H3_FRAME_UNEXPECTED); + return; + } + if (frame instanceof PartialFrame partialFrame) { + final List payload = framesDecoder.readPayloadBytes(); + if (debug.on()) { + debug.log("processQuicData - payload: %s", + payload == null ? null : Utils.remaining(payload)); + } + if (framesDecoder.eof() && !framesDecoder.clean()) { + String msg = "Frame truncated: " + partialFrame; + connectionError(new ProtocolException(msg), + Http3Error.H3_FRAME_ERROR.code(), + msg); + break; + } + if ((payload == null || payload.isEmpty()) && partialFrame.remaining() != 0) { + break; + } + if (partialFrame instanceof HeadersFrame headers) { + receiveHeaders(headers, payload); + // check if we need to wait for the status code to be decoded + // before reading more + readmore = !readingPaused; + } else if (partialFrame instanceof DataFrame data) { + if (responseState != ResponseState.PERMIT_TRAILER) { + cancelImpl(new IOException("DATA frame not expected here"), Http3Error.H3_MESSAGE_ERROR); + return; + } + if (payload != null) { + consumedDataBytes += Utils.remaining(payload); + if (contentLength != null && + consumedDataBytes + data.remaining() > contentLength) { + cancelImpl(new IOException( + String.format("DATA frame (length %d) exceeds content-length (%d) by %d", + data.streamingLength(), contentLength, + consumedDataBytes + data.remaining() - contentLength)), + Http3Error.H3_MESSAGE_ERROR); + return; + } + // don't read more if there is pending data waiting + // to be read from downstream + readmore = responseData.isEmpty(); + receiveData(data, payload, debug); + } + } else if (partialFrame instanceof PushPromiseFrame ppf) { + receivePushPromiseFrame(ppf, payload); + } else if (partialFrame instanceof UnknownFrame) { + if (debug.on()) { + debug.log("ignoring %s bytes for unknown frame type: %s", + Utils.remaining(payload), + Http3FrameType.asString(frameType)); + } + } else { + // should never come here: the only frame that we can + // receive on a request-response stream are + // HEADERS, DATA, PUSH_PROMISE, and RESERVED/UNKNOWN + // All have already been taken care above. + // So this here should be dead-code. + String msg = "unhandled frame type: " + + Http3FrameType.asString(frameType); + if (debug.on()) debug.log("Warning: %s", msg); + throw new AssertionError(msg); + } + // mark as complete, if all expected data has been read for a frame + if (partialFrame.remaining() == 0) { + frameOrderVerifier.completed(frame); + } + } else if (frame instanceof MalformedFrame malformed) { + var cause = malformed.getCause(); + if (cause != null && debug.on()) { + debug.log(malformed.toString(), cause); + } + readScheduler.stop(); + connectionError(toThrowable(malformed), + malformed.getErrorCode(), + malformed.getMessage()); + return; + } else { + // should never come here: the only frame that we can + // receive on a request-response stream are + // HEADERS, DATA, PUSH_PROMISE, and RESERVED/UNKNOWN + // All should have already been taken care above, + // including malformed frames. So this here should be + // dead-code. + String msg = "unhandled frame type: " + + Http3FrameType.asString(frameType); + if (debug.on()) debug.log("Warning: %s", msg); + throw new AssertionError(msg); + } + if (framesDecoder.eof()) break; + } + if (framesDecoder.eof()) break; + } + if (framesDecoder.eof()) { + if (!framesDecoder.clean()) { + String msg = "EOF reading frame type and length"; + connectionError(new ProtocolException(msg), + Http3Error.H3_FRAME_ERROR.code(), + msg); + } + if (debug.on()) debug.log("processQuicData - EOF"); + if (responseState == ResponseState.PERMIT_HEADER) { + cancelImpl(new EOFException("EOF reached: no header bytes received"), Http3Error.H3_MESSAGE_ERROR); + } else { + if (contentLength != null && + consumedDataBytes != contentLength) { + cancelImpl(new IOException( + String.format("fixed content-length: %d, bytes received: %d", contentLength, consumedDataBytes)), + Http3Error.H3_MESSAGE_ERROR); + return; + } + receiveData(new DataFrame(0), + List.of(QuicStreamReader.EOF), debug); + } + } + if (framesDecoder.eof() && responseData.isEmpty()) { + if (debug.on()) debug.log("EOF: Stopping scheduler"); + readScheduler.stop(); + } + if (reader.isReset() && responseData.isEmpty()) { + onReaderReset(); + } + } + + final String checkInterimResponseCountExceeded() { + // this is also checked by Exchange - but tracking it here too provides + // a more informative message. + int count = nonFinalResponseCount.incrementAndGet(); + if (MAX_NON_FINAL_RESPONSES > 0 && (count < 0 || count > MAX_NON_FINAL_RESPONSES)) { + return String.format( + "Stream %s PROTOCOL_ERROR: too many interim responses received: %s > %s", + streamId(), count, MAX_NON_FINAL_RESPONSES); + } + return null; + } + + /** + * Called to create a new Response object for the newly receive response headers and + * response status code. This method is called from {@link #handleResponse(HttpHeadersBuilder, + * StreamHeadersConsumer, SequentialScheduler, Logger) handleResponse}, after the status code + * and headers have been validated. + * + * @param responseHeaders response headers + * @param responseCode response code + * @return a new {@code Response} object + */ + abstract Response newResponse(HttpHeaders responseHeaders, int responseCode); + + /** + * Called at the end of {@link #handleResponse(HttpHeadersBuilder, + * StreamHeadersConsumer, SequentialScheduler, Logger) handleResponse}, to propagate + * the response to the multi exchange. + * @param response the {@code Response} that was received. + */ + abstract void completeResponse(Response response); + + /** + * Validate response headers and status code based on the {@link #responseState}. + * If validated, this method will call {@link #newResponse(HttpHeaders, int)} to + * create a {@code Response} object, which it will then pass to + * {@link #completeResponse(Response)}. + * + * @param responseHeadersBuilder the response headers builder + * @param rspHeadersConsumer the response headers consumer + * @param readScheduler the read scheduler + * @param debug the debug logger + */ + void handleResponse(HttpHeadersBuilder responseHeadersBuilder, + StreamHeadersConsumer rspHeadersConsumer, + SequentialScheduler readScheduler, + Logger debug) { + if (responseState == ResponseState.PERMIT_NONE) { + connectionError(new ProtocolException("HEADERS after trailer"), + Http3Error.H3_FRAME_UNEXPECTED.code(), + "HEADERS after trailer"); + return; + } + HttpHeaders responseHeaders = responseHeadersBuilder.build(); + if (responseState == ResponseState.PERMIT_TRAILER) { + if (responseHeaders.firstValue(":status").isPresent()) { + cancelImpl(new IOException("Unexpected :status header in trailer"), Http3Error.H3_MESSAGE_ERROR); + return; + } + if (Log.headers()) { + Log.logHeaders("Ignoring trailers on stream {0}: {1}", streamId(), responseHeaders); + } else if (debug.on()) { + debug.log("Ignoring trailers: %s", responseHeaders); + } + responseState = ResponseState.PERMIT_NONE; + rspHeadersConsumer.reset(); + if (readingPaused) { + readingPaused = false; + readScheduler.runOrSchedule(exchange.executor()); + } + return; + } + + int responseCode; + boolean finalResponse = false; + try { + responseCode = (int) responseHeaders + .firstValueAsLong(":status") + .orElseThrow(() -> new IOException("no statuscode in response")); + } catch (IOException | NumberFormatException exception) { + // RFC-9114: 4.1.2. Malformed Requests and Responses: + // "Malformed requests or responses that are + // detected MUST be treated as a stream error of type H3_MESSAGE_ERROR" + cancelImpl(exception, Http3Error.H3_MESSAGE_ERROR); + return; + } + if (responseCode < 100 || responseCode > 999) { + cancelImpl(new IOException("Unexpected :status header value"), Http3Error.H3_MESSAGE_ERROR); + return; + } + + if (responseCode >= 200) { + responseState = ResponseState.PERMIT_TRAILER; + finalResponse = true; + } else { + assert responseCode >= 100 && responseCode <= 200 : "unexpected responseCode: " + responseCode; + String protocolErrorMsg = checkInterimResponseCountExceeded(); + if (protocolErrorMsg != null) { + if (debug.on()) { + debug.log(protocolErrorMsg); + } + cancelImpl(new ProtocolException(protocolErrorMsg), Http3Error.H3_GENERAL_PROTOCOL_ERROR); + rspHeadersConsumer.reset(); + return; + } + } + + // update readingPaused after having decoded the statusCode and + // switched the responseState. + if (readingPaused) { + readingPaused = false; + readScheduler.runOrSchedule(exchange.executor()); + } + + var response = newResponse(responseHeaders, responseCode); + + if (debug.on()) { + debug.log("received response headers: %s", + responseHeaders); + } + + try { + OptionalLong cl = responseHeaders.firstValueAsLong("content-length"); + if (finalResponse && cl.isPresent()) { + long cll = cl.getAsLong(); + if (cll < 0) { + cancelImpl(new IOException("Invalid content-length value "+cll), Http3Error.H3_MESSAGE_ERROR); + return; + } + if (!(exchange.request().method().equalsIgnoreCase("HEAD") || responseCode == HTTP_NOT_MODIFIED)) { + // HEAD response and 304 response might have a content-length header, + // but it carries no meaning + contentLength = cll; + } + } + } catch (NumberFormatException nfe) { + cancelImpl(nfe, Http3Error.H3_MESSAGE_ERROR); + return; + } + + if (Log.headers() || debug.on()) { + StringBuilder sb = new StringBuilder("H3 RESPONSE HEADERS (stream="); + sb.append(streamId()).append(")\n"); + Log.dumpHeaders(sb, " ", responseHeaders); + if (Log.headers()) { + Log.logHeaders(sb.toString()); + } else if (debug.on()) { + debug.log(sb); + } + } + + // this will clear the response headers + rspHeadersConsumer.reset(); + + completeResponse(response); + } + + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java index c58f0b0c752..b73b92add63 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java @@ -41,6 +41,7 @@ import java.net.ProtocolException; import java.net.ProxySelector; import java.net.http.HttpConnectTimeoutException; import java.net.http.HttpTimeoutException; +import java.net.http.UnsupportedProtocolVersionException; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedChannelException; @@ -93,8 +94,16 @@ import jdk.internal.net.http.common.TimeSource; import jdk.internal.net.http.common.Utils; import jdk.internal.net.http.common.OperationTrackers.Trackable; import jdk.internal.net.http.common.OperationTrackers.Tracker; +import jdk.internal.net.http.common.Utils.SafeExecutor; +import jdk.internal.net.http.common.Utils.SafeExecutorService; import jdk.internal.net.http.websocket.BuilderImpl; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.util.Objects.requireNonNullElse; +import static java.util.Objects.requireNonNullElseGet; +import static jdk.internal.net.quic.QuicTLSContext.isQuicCompatible; + /** * Client implementation. Contains all configuration information and also * the selector manager thread which allows async events to be registered @@ -112,7 +121,8 @@ final class HttpClientImpl extends HttpClient implements Trackable { static final int DEFAULT_KEEP_ALIVE_TIMEOUT = 30; static final long KEEP_ALIVE_TIMEOUT = getTimeoutProp("jdk.httpclient.keepalive.timeout", DEFAULT_KEEP_ALIVE_TIMEOUT); // Defaults to value used for HTTP/1 Keep-Alive Timeout. Can be overridden by jdk.httpclient.keepalive.timeout.h2 property. - static final long IDLE_CONNECTION_TIMEOUT = getTimeoutProp("jdk.httpclient.keepalive.timeout.h2", KEEP_ALIVE_TIMEOUT); + static final long IDLE_CONNECTION_TIMEOUT_H2 = getTimeoutProp("jdk.httpclient.keepalive.timeout.h2", KEEP_ALIVE_TIMEOUT); + static final long IDLE_CONNECTION_TIMEOUT_H3 = getTimeoutProp("jdk.httpclient.keepalive.timeout.h3", IDLE_CONNECTION_TIMEOUT_H2); // Define the default factory as a static inner class // that embeds all the necessary logic to avoid @@ -145,15 +155,23 @@ final class HttpClientImpl extends HttpClient implements Trackable { static final class DelegatingExecutor implements Executor { private final BooleanSupplier isInSelectorThread; private final Executor delegate; + private final SafeExecutor safeDelegate; private final BiConsumer errorHandler; DelegatingExecutor(BooleanSupplier isInSelectorThread, Executor delegate, BiConsumer errorHandler) { this.isInSelectorThread = isInSelectorThread; this.delegate = delegate; + this.safeDelegate = delegate instanceof ExecutorService svc + ? new SafeExecutorService(svc, ASYNC_POOL, errorHandler) + : new SafeExecutor<>(delegate, ASYNC_POOL, errorHandler); this.errorHandler = errorHandler; } + Executor safeDelegate() { + return safeDelegate; + } + Executor delegate() { return delegate; } @@ -325,6 +343,8 @@ final class HttpClientImpl extends HttpClient implements Trackable { private final SelectorManager selmgr; private final FilterFactory filters; private final Http2ClientImpl client2; + private final Http3ClientImpl client3; + private final AltServicesRegistry registry; private final long id; private final String dbgTag; private final InetAddress localAddr; @@ -386,6 +406,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { private final AtomicLong pendingHttpOperationsCount = new AtomicLong(); private final AtomicLong pendingHttpRequestCount = new AtomicLong(); private final AtomicLong pendingHttp2StreamCount = new AtomicLong(); + private final AtomicLong pendingHttp3StreamCount = new AtomicLong(); private final AtomicLong pendingTCPConnectionCount = new AtomicLong(); private final AtomicLong pendingSubscribersCount = new AtomicLong(); private final AtomicBoolean isAlive = new AtomicBoolean(); @@ -429,14 +450,26 @@ final class HttpClientImpl extends HttpClient implements Trackable { id = CLIENT_IDS.incrementAndGet(); dbgTag = "HttpClientImpl(" + id +")"; localAddr = builder.localAddr; - if (builder.sslContext == null) { + version = requireNonNullElse(builder.version, Version.HTTP_2); + sslContext = requireNonNullElseGet(builder.sslContext, () -> { try { - sslContext = SSLContext.getDefault(); + return SSLContext.getDefault(); } catch (NoSuchAlgorithmException ex) { throw new UncheckedIOException(new IOException(ex)); } - } else { - sslContext = builder.sslContext; + }); + final boolean sslCtxSupportedForH3 = isQuicCompatible(sslContext); + if (version == Version.HTTP_3 && !sslCtxSupportedForH3) { + throw new UncheckedIOException(new UnsupportedProtocolVersionException( + "HTTP3 is not supported")); + } + sslParams = requireNonNullElseGet(builder.sslParams, sslContext::getDefaultSSLParameters); + boolean sslParamsSupportedForH3 = sslParams.getProtocols() == null + || sslParams.getProtocols().length == 0 + || isQuicCompatible(sslParams); + if (version == Version.HTTP_3 && !sslParamsSupportedForH3) { + throw new UncheckedIOException(new UnsupportedProtocolVersionException( + "HTTP3 is not supported - TLSv1.3 isn't configured on SSLParameters")); } Executor ex = builder.executor; if (ex == null) { @@ -450,7 +483,6 @@ final class HttpClientImpl extends HttpClient implements Trackable { this::onSubmitFailure); facadeRef = new WeakReference<>(facadeFactory.createFacade(this)); implRef = new WeakReference<>(this); - client2 = new Http2ClientImpl(this); cookieHandler = builder.cookieHandler; connectTimeout = builder.connectTimeout; followRedirects = builder.followRedirects == null ? @@ -462,17 +494,11 @@ final class HttpClientImpl extends HttpClient implements Trackable { debug.log("proxySelector is %s (user-supplied=%s)", this.proxySelector, userProxySelector != null); authenticator = builder.authenticator; - if (builder.version == null) { - version = HttpClient.Version.HTTP_2; - } else { - version = builder.version; - } - if (builder.sslParams == null) { - sslParams = getDefaultParams(sslContext); - } else { - sslParams = builder.sslParams; - } + boolean h3Supported = sslCtxSupportedForH3 && sslParamsSupportedForH3; + registry = new AltServicesRegistry(id); connections = new ConnectionPool(id); + client2 = new Http2ClientImpl(this); + client3 = h3Supported ? new Http3ClientImpl(this) : null; connections.start(); timeouts = new TreeSet<>(); try { @@ -518,6 +544,11 @@ final class HttpClientImpl extends HttpClient implements Trackable { client2.stop(); // make sure all subscribers are completed closeSubscribers(); + // close client3 + if (client3 != null) { + // close client3 + client3.stop(); + } // close TCP connection if any are still opened openedConnections.forEach(this::closeConnection); // shutdown the executor if needed @@ -610,11 +641,6 @@ final class HttpClientImpl extends HttpClient implements Trackable { return isStarted.get() && !isAlive.get(); } - private static SSLParameters getDefaultParams(SSLContext ctx) { - SSLParameters params = ctx.getDefaultSSLParameters(); - return params; - } - // Returns the facade that was returned to the application code. // May be null if that facade is no longer referenced. final HttpClientFacade facade() { @@ -664,12 +690,14 @@ final class HttpClientImpl extends HttpClient implements Trackable { final long count = pendingOperationCount.decrementAndGet(); final long httpCount = pendingHttpOperationsCount.decrementAndGet(); final long http2Count = pendingHttp2StreamCount.get(); + final long http3Count = pendingHttp3StreamCount.get(); final long webSocketCount = pendingWebSocketCount.get(); if (count == 0 && (facadeRef.refersTo(null) || shutdownRequested)) { selmgr.wakeupSelector(); } assert httpCount >= 0 : "count of HTTP/1.1 operations < 0"; assert http2Count >= 0 : "count of HTTP/2 operations < 0"; + assert http3Count >= 0 : "count of HTTP/3 operations < 0"; assert webSocketCount >= 0 : "count of WS operations < 0"; assert count >= 0 : "count of pending operations < 0"; return count; @@ -681,10 +709,35 @@ final class HttpClientImpl extends HttpClient implements Trackable { return pendingOperationCount.incrementAndGet(); } + // Increments the pendingHttp3StreamCount and pendingOperationCount. + final long h3StreamReference() { + pendingHttp3StreamCount.incrementAndGet(); + return pendingOperationCount.incrementAndGet(); + } + // Decrements the pendingHttp2StreamCount and pendingOperationCount. final long streamUnreference() { final long count = pendingOperationCount.decrementAndGet(); final long http2Count = pendingHttp2StreamCount.decrementAndGet(); + final long http3Count = pendingHttp3StreamCount.get(); + final long httpCount = pendingHttpOperationsCount.get(); + final long webSocketCount = pendingWebSocketCount.get(); + if (count == 0 && facadeRef.refersTo(null)) { + selmgr.wakeupSelector(); + } + assert httpCount >= 0 : "count of HTTP/1.1 operations < 0"; + assert http2Count >= 0 : "count of HTTP/2 operations < 0"; + assert http3Count >= 0 : "count of HTTP/3 operations < 0"; + assert webSocketCount >= 0 : "count of WS operations < 0"; + assert count >= 0 : "count of pending operations < 0"; + return count; + } + + // Decrements the pendingHttp3StreamCount and pendingOperationCount. + final long h3StreamUnreference() { + final long count = pendingOperationCount.decrementAndGet(); + final long http2Count = pendingHttp2StreamCount.get(); + final long http3Count = pendingHttp3StreamCount.decrementAndGet(); final long httpCount = pendingHttpOperationsCount.get(); final long webSocketCount = pendingWebSocketCount.get(); if (count == 0 && (facadeRef.refersTo(null) || shutdownRequested)) { @@ -692,6 +745,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { } assert httpCount >= 0 : "count of HTTP/1.1 operations < 0"; assert http2Count >= 0 : "count of HTTP/2 operations < 0"; + assert http3Count >= 0 : "count of HTTP/3 operations < 0"; assert webSocketCount >= 0 : "count of WS operations < 0"; assert count >= 0 : "count of pending operations < 0"; return count; @@ -709,11 +763,13 @@ final class HttpClientImpl extends HttpClient implements Trackable { final long webSocketCount = pendingWebSocketCount.decrementAndGet(); final long httpCount = pendingHttpOperationsCount.get(); final long http2Count = pendingHttp2StreamCount.get(); + final long http3Count = pendingHttp3StreamCount.get(); if (count == 0 && (facadeRef.refersTo(null) || shutdownRequested)) { selmgr.wakeupSelector(); } assert httpCount >= 0 : "count of HTTP/1.1 operations < 0"; assert http2Count >= 0 : "count of HTTP/2 operations < 0"; + assert http3Count >= 0 : "count of HTTP/3 operations < 0"; assert webSocketCount >= 0 : "count of WS operations < 0"; assert count >= 0 : "count of pending operations < 0"; return count; @@ -732,6 +788,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { final AtomicLong requestCount; final AtomicLong httpCount; final AtomicLong http2Count; + final AtomicLong http3Count; final AtomicLong websocketCount; final AtomicLong operationsCount; final AtomicLong connnectionsCount; @@ -744,6 +801,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { HttpClientTracker(AtomicLong request, AtomicLong http, AtomicLong http2, + AtomicLong http3, AtomicLong ws, AtomicLong ops, AtomicLong conns, @@ -756,6 +814,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { this.requestCount = request; this.httpCount = http; this.http2Count = http2; + this.http3Count = http3; this.websocketCount = ws; this.operationsCount = ops; this.connnectionsCount = conns; @@ -787,6 +846,8 @@ final class HttpClientImpl extends HttpClient implements Trackable { @Override public long getOutstandingHttp2Streams() { return http2Count.get(); } @Override + public long getOutstandingHttp3Streams() { return http3Count.get(); } + @Override public long getOutstandingWebSocketOperations() { return websocketCount.get(); } @@ -811,6 +872,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { pendingHttpRequestCount, pendingHttpOperationsCount, pendingHttp2StreamCount, + pendingHttp3StreamCount, pendingWebSocketCount, pendingOperationCount, pendingTCPConnectionCount, @@ -866,6 +928,8 @@ final class HttpClientImpl extends HttpClient implements Trackable { return Thread.currentThread() == selmgr; } + AltServicesRegistry registry() { return registry; } + boolean isSelectorClosed() { return selmgr.isClosed(); } @@ -878,6 +942,10 @@ final class HttpClientImpl extends HttpClient implements Trackable { return client2; } + Optional client3() { + return Optional.ofNullable(client3); + } + private void debugCompleted(String tag, long startNanos, HttpRequest req) { if (debugelapsed.on()) { debugelapsed.log(tag + " elapsed " @@ -917,6 +985,10 @@ final class HttpClientImpl extends HttpClient implements Trackable { HttpConnectTimeoutException hcte = new HttpConnectTimeoutException(msg); hcte.initCause(throwable); throw hcte; + } else if (throwable instanceof UnsupportedProtocolVersionException) { + var upve = new UnsupportedProtocolVersionException(msg); + upve.initCause(throwable); + throw upve; } else if (throwable instanceof HttpTimeoutException) { throw new HttpTimeoutException(msg); } else if (throwable instanceof ConnectException) { @@ -972,6 +1044,13 @@ final class HttpClientImpl extends HttpClient implements Trackable { return MinimalFuture.failedFuture(new IOException("closed")); } + final HttpClient.Version vers = userRequest.version().orElse(this.version()); + if (vers == Version.HTTP_3 && client3 == null + && userRequest.getOption(H3_DISCOVERY).orElse(null) == HTTP_3_URI_ONLY) { + // HTTP3 isn't supported by this client + return MinimalFuture.failedFuture(new UnsupportedProtocolVersionException( + "HTTP3 is not supported")); + } // should not happen, unless the selector manager has // exited abnormally if (selmgr.isClosed()) { @@ -1095,8 +1174,11 @@ final class HttpClientImpl extends HttpClient implements Trackable { } IOException selectorClosedException() { - var io = new IOException("selector manager closed"); - var cause = errorRef.get(); + final var cause = errorRef.get(); + final String msg = cause == null + ? "selector manager closed" + : "selector manager closed due to: " + cause; + final var io = new IOException(msg); if (cause != null) { io.initCause(cause); } @@ -1181,6 +1263,10 @@ final class HttpClientImpl extends HttpClient implements Trackable { } // double check after closing abortPendingRequests(owner, t); + var client3 = owner.client3; + if (client3 != null) { + client3.abort(t); + } IOException io = toAbort.isEmpty() ? null : selectorClosedException(); @@ -1456,8 +1542,8 @@ final class HttpClientImpl extends HttpClient implements Trackable { String keyInterestOps = key.isValid() ? "key.interestOps=" + Utils.interestOps(key) : "invalid key"; return String.format("channel registered with selector, %s, sa.interestOps=%s", - keyInterestOps, - Utils.describeOps(((SelectorAttachment)key.attachment()).interestOps)); + keyInterestOps, + Utils.describeOps(((SelectorAttachment)key.attachment()).interestOps)); } catch (Throwable t) { return String.valueOf(t); } @@ -1627,8 +1713,12 @@ final class HttpClientImpl extends HttpClient implements Trackable { return Optional.ofNullable(connectTimeout); } - Optional idleConnectionTimeout() { - return Optional.ofNullable(getIdleConnectionTimeout()); + Optional idleConnectionTimeout(Version version) { + return switch (version) { + case HTTP_2 -> timeoutDuration(IDLE_CONNECTION_TIMEOUT_H2); + case HTTP_3 -> timeoutDuration(IDLE_CONNECTION_TIMEOUT_H3); + case HTTP_1_1 -> timeoutDuration(KEEP_ALIVE_TIMEOUT); + }; } @Override @@ -1755,7 +1845,7 @@ final class HttpClientImpl extends HttpClient implements Trackable { // error from here - but in this case there's not much we // could do anyway. Just let it flow... if (failed == null) failed = e; - else failed.addSuppressed(e); + else Utils.addSuppressed(failed, e); Log.logTrace("Failed to handle event {0}: {1}", event, e); } } @@ -1799,10 +1889,11 @@ final class HttpClientImpl extends HttpClient implements Trackable { return sslBufferSupplier; } - private Duration getIdleConnectionTimeout() { - if (IDLE_CONNECTION_TIMEOUT >= 0) - return Duration.ofSeconds(IDLE_CONNECTION_TIMEOUT); - return null; + private Optional timeoutDuration(long seconds) { + if (seconds >= 0) { + return Optional.of(Duration.ofSeconds(seconds)); + } + return Optional.empty(); } private static long getTimeoutProp(String prop, long def) { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/HttpConnection.java b/src/java.net.http/share/classes/jdk/internal/net/http/HttpConnection.java index 07cfc4dbdf6..0219b0960d7 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpConnection.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpConnection.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.http.HttpResponse; import java.nio.ByteBuffer; +import java.nio.channels.NetworkChannel; import java.nio.channels.SocketChannel; import java.util.Arrays; import java.util.Comparator; @@ -57,7 +58,10 @@ import jdk.internal.net.http.common.SequentialScheduler; import jdk.internal.net.http.common.SequentialScheduler.DeferredCompleter; import jdk.internal.net.http.common.Log; import jdk.internal.net.http.common.Utils; + +import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; import static jdk.internal.net.http.common.Utils.ProxyHeaders; /** @@ -69,12 +73,13 @@ import static jdk.internal.net.http.common.Utils.ProxyHeaders; * PlainTunnelingConnection: opens plain text (CONNECT) tunnel to server * AsyncSSLConnection: TLS channel direct to server * AsyncSSLTunnelConnection: TLS channel via (CONNECT) proxy tunnel + * HttpQuicConnection: direct QUIC connection to server */ abstract class HttpConnection implements Closeable { final Logger debug = Utils.getDebugLogger(this::dbgString, Utils.DEBUG); static final Logger DEBUG_LOGGER = Utils.getDebugLogger( - () -> "HttpConnection(SocketTube(?))", Utils.DEBUG); + () -> "HttpConnection", Utils.DEBUG); public static final Comparator COMPARE_BY_ID = Comparator.comparing(HttpConnection::id); @@ -112,8 +117,8 @@ abstract class HttpConnection implements Closeable { this.label = label; } - private static String nextLabel() { - return "" + LABEL_COUNTER.incrementAndGet(); + private static String nextLabel(String prefix) { + return prefix + LABEL_COUNTER.incrementAndGet(); } /** @@ -198,9 +203,17 @@ abstract class HttpConnection implements Closeable { abstract InetSocketAddress proxy(); /** Tells whether, or not, this connection is open. */ - final boolean isOpen() { + boolean isOpen() { return channel().isOpen() && - (connected() ? !getConnectionFlow().isFinished() : true); + (connected() ? !isFlowFinished() : true); + } + + /** + * {@return {@code true} if the {@linkplain #getConnectionFlow() + * connection flow} is {@linkplain FlowTube#isFinished() finished}. + */ + boolean isFlowFinished() { + return getConnectionFlow().isFinished(); } /** @@ -232,13 +245,17 @@ abstract class HttpConnection implements Closeable { * still open, and the method returns true. * @return true if the channel appears to be still open. */ - final boolean checkOpen() { + boolean checkOpen() { if (isOpen()) { try { // channel is non blocking - int read = channel().read(ByteBuffer.allocate(1)); - if (read == 0) return true; - close(); + if (channel() instanceof SocketChannel channel) { + int read = channel.read(ByteBuffer.allocate(1)); + if (read == 0) return true; + close(); + } else { + return channel().isOpen(); + } } catch (IOException x) { debug.log("Pooled connection is no longer operational: %s", x.toString()); @@ -294,6 +311,7 @@ abstract class HttpConnection implements Closeable { * is one of the following: * {@link PlainHttpConnection} * {@link PlainTunnelingConnection} + * {@link HttpQuicConnection} * * The returned connection, if not from the connection pool, must have its, * connect() or connectAsync() method invoked, which ( when it completes @@ -301,6 +319,7 @@ abstract class HttpConnection implements Closeable { */ public static HttpConnection getConnection(InetSocketAddress addr, HttpClientImpl client, + Exchange exchange, HttpRequestImpl request, Version version) { // The default proxy selector may select a proxy whose address is @@ -322,18 +341,27 @@ abstract class HttpConnection implements Closeable { return getPlainConnection(addr, proxy, request, client); } } else { // secure - if (version != HTTP_2) { // only HTTP/1.1 connections are in the pool + if (version == HTTP_1_1) { // only HTTP/1.1 connections are in the pool c = pool.getConnection(true, addr, proxy); } if (c != null && c.isOpen()) { - final HttpConnection conn = c; - if (DEBUG_LOGGER.on()) - DEBUG_LOGGER.log(conn.getConnectionFlow() - + ": SSL connection retrieved from HTTP/1.1 pool"); + if (DEBUG_LOGGER.on()) { + DEBUG_LOGGER.log(c.getConnectionFlow() + + ": SSL connection retrieved from HTTP/1.1 pool"); + } return c; + } else if (version == HTTP_3 && client.client3().isPresent()) { + // We only come here after we have checked the HTTP/3 connection pool, + // and if the client config supports HTTP/3 + if (DEBUG_LOGGER.on()) + DEBUG_LOGGER.log("Attempting to get an HTTP/3 connection"); + return HttpQuicConnection.getHttpQuicConnection(addr, proxy, request, exchange, client); } else { + assert !request.isHttp3Only(version); // should have failed before String[] alpn = null; if (version == HTTP_2 && hasRequiredHTTP2TLSVersion(client)) { + // We only come here after we have checked the HTTP/2 connection pool. + // We will not negotiate HTTP/2 if we don't have the appropriate TLS version alpn = new String[] { Alpns.H2, Alpns.HTTP_1_1 }; } return getSSLConnection(addr, proxy, alpn, request, client); @@ -346,7 +374,7 @@ abstract class HttpConnection implements Closeable { String[] alpn, HttpRequestImpl request, HttpClientImpl client) { - final String label = nextLabel(); + final String label = nextLabel("tls:"); final Origin originServer; try { originServer = Origin.from(request.uri()); @@ -433,7 +461,7 @@ abstract class HttpConnection implements Closeable { InetSocketAddress proxy, HttpRequestImpl request, HttpClientImpl client) { - final String label = nextLabel(); + final String label = nextLabel("tcp:"); final Origin originServer; try { originServer = Origin.from(request.uri()); @@ -483,7 +511,7 @@ abstract class HttpConnection implements Closeable { /* Tells whether or not this connection is a tunnel through a proxy */ boolean isTunnel() { return false; } - abstract SocketChannel channel(); + abstract NetworkChannel channel(); final InetSocketAddress address() { return address; @@ -516,6 +544,19 @@ abstract class HttpConnection implements Closeable { close(); } + /** + * {@return the underlying connection flow, if applicable} + * + * @apiNote + * TCP based protocols like HTTP/1.1 and HTTP/2 are built on + * top of a {@linkplain FlowTube bidirectional connection flow}. + * On the other hand, Quic based protocol like HTTP/3 are + * multiplexed at the Quic level, and therefore do not have + * a connection flow. + * + * @throws IllegalStateException if the underlying transport + * does not expose a single connection flow. + */ abstract FlowTube getConnectionFlow(); /** diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/HttpQuicConnection.java b/src/java.net.http/share/classes/jdk/internal/net/http/HttpQuicConnection.java new file mode 100644 index 00000000000..bbbe1157cdf --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpQuicConnection.java @@ -0,0 +1,690 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.SocketOption; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpConnectTimeoutException; +import java.nio.channels.NetworkChannel; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; + +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLParameters; + +import jdk.internal.net.http.ConnectionPool.CacheKey; +import jdk.internal.net.http.AltServicesRegistry.AltService; +import jdk.internal.net.http.common.FlowTube; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.ConnectionTerminator; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.QuicConnection; + +import static jdk.internal.net.http.Http3ClientProperties.MAX_DIRECT_CONNECTION_TIMEOUT; +import static jdk.internal.net.http.common.Alpns.H3; +import static jdk.internal.net.http.http3.Http3Error.H3_INTERNAL_ERROR; +import static jdk.internal.net.http.http3.Http3Error.H3_NO_ERROR; +import static jdk.internal.net.http.quic.TerminationCause.appLayerClose; +import static jdk.internal.net.http.quic.TerminationCause.appLayerException; + +/** + * An {@code HttpQuicConnection} models an HTTP connection over + * QUIC. + * The particulars of the HTTP/3 protocol are handled by the + * Http3Connection class. + */ +abstract class HttpQuicConnection extends HttpConnection { + + final Logger debug = Utils.getDebugLogger(this::quicDbgString); + + final QuicConnection quicConnection; + final ConnectionTerminator quicConnTerminator; + // the alt-service which was advertised, from some origin, for this connection co-ordinates. + // can be null, which indicates this wasn't created because of an alt-service + private final AltService sourceAltService; + // HTTP/2 MUST use TLS version 1.3 or higher for HTTP/3 over TLS + private static final Predicate testRequiredHTTP3TLSVersion = proto -> + proto.equals("TLSv1.3"); + + + HttpQuicConnection(Origin originServer, InetSocketAddress address, HttpClientImpl client, + QuicConnection quicConnection, AltService sourceAltService) { + super(originServer, address, client, "quic:" + quicConnection.uniqueId()); + Objects.requireNonNull(quicConnection); + this.quicConnection = quicConnection; + this.quicConnTerminator = quicConnection.connectionTerminator(); + this.sourceAltService = sourceAltService; + } + + /** + * A HTTP QUIC connection could be created due to an alt-service that was advertised + * from some origin. This method returns that source alt-service if there was one. + * @return The source alt-service if present + */ + Optional getSourceAltService() { + return Optional.ofNullable(this.sourceAltService); + } + + @Override + public List getSNIServerNames() { + final SSLParameters sslParams = this.quicConnection.getTLSEngine().getSSLParameters(); + if (sslParams == null) { + return List.of(); + } + final List sniServerNames = sslParams.getServerNames(); + if (sniServerNames == null) { + return List.of(); + } + return List.copyOf(sniServerNames); + } + + final String quicDbgString() { + String tag = dbgTag; + if (tag == null) tag = dbgTag = "Http" + quicConnection.dbgTag(); + return tag; + } + + /** + * Initiates the connect phase. + * + * Returns a CompletableFuture that completes when the underlying + * TCP connection has been established or an error occurs. + */ + public abstract CompletableFuture connectAsync(Exchange exchange); + + private volatile boolean connected; + /** + * Finishes the connection phase. + * + * Returns a CompletableFuture that completes when any additional, + * type specific, setup has been done. Must be called after connectAsync. + */ + public CompletableFuture finishConnect() { + this.connected = true; + return MinimalFuture.completedFuture(null); + } + + /** Tells whether, or not, this connection is connected to its destination. */ + boolean connected() { + return connected; + } + + /** Tells whether, or not, this connection is secure ( over SSL ) */ + final boolean isSecure() { return true; } // QUIC is secure + + /** + * Tells whether, or not, this connection is proxied. + * Returns true for tunnel connections, or clear connection to + * any host through proxy. + */ + final boolean isProxied() { return false;} // Proxy not supported + + /** + * Returns the address of the proxy used by this connection. + * Returns the proxy address for tunnel connections, or + * clear connection to any host through proxy. + * Returns {@code null} otherwise. + */ + final InetSocketAddress proxy() { return null; } // Proxy not supported + + /** + * This method throws an {@link UnsupportedOperationException} + */ + @Override + final HttpPublisher publisher() { + throw new UnsupportedOperationException("no publisher for a quic connection"); + } + + QuicConnection quicConnection() { + return quicConnection; + } + + /** + * Returns true if the given client's SSL parameter protocols contains at + * least one TLS version that HTTP/3 requires. + */ + private static boolean hasRequiredHTTP3TLSVersion(HttpClient client) { + String[] protos = client.sslParameters().getProtocols(); + if (protos != null) { + return Arrays.stream(protos).anyMatch(testRequiredHTTP3TLSVersion); + } else { + return false; + } + } + + /** + * Called when the HTTP/3 connection is established, either successfully or + * unsuccessfully + * @param connection the HTTP/3 connection, if successful, or null, otherwise + * @param throwable the exception encountered, if unsuccessful + */ + public abstract void connectionEstablished(Http3Connection connection, + Throwable throwable); + + /** + * A functional interface used to update the Alternate Service Registry + * after a direct connection attempt. + */ + @FunctionalInterface + private interface DirectConnectionUpdater { + /** + * This method may update the HttpClient registry, or + * {@linkplain Http3ClientImpl#noH3(String) record the unsuccessful} + * direct connection attempt. + * + * @param conn the connection or null + * @param throwable the exception or null + */ + void onConnectionEstablished( + Http3Connection conn, Throwable throwable); + + /** + * Does nothing + * @param conn the connection + * @param throwable the exception + */ + static void noUpdate( + Http3Connection conn, Throwable throwable) { + } + } + + /** + * This method create and return a new unconnected HttpQuicConnection, + * wrapping a {@link QuicConnection}. May return {@code null} if + * HTTP/3 is not supported with the given parameters. For instance, + * if TLSv1.3 isn't available/enabled in the client's SSLParameters, + * or if ALT_SERVICE is required but no alt service is found. + * + * @param addr the HTTP/3 peer endpoint address, if direct connection + * @param proxy the proxy address, if a proxy is used, in which case this + * method will return {@code null} as proxying is not supported + * with HTTP/3 + * @param request the request for which the connection is being created + * @param exchange the exchange for which the connection is being created + * @param client the HttpClientImpl instance + * @return A new HttpQuicConnection or {@code null} + */ + public static HttpQuicConnection getHttpQuicConnection(final InetSocketAddress addr, + final InetSocketAddress proxy, + final HttpRequestImpl request, + final Exchange exchange, + final HttpClientImpl client) { + if (!client.client3().isPresent()) { + if (Log.http3()) { + Log.logHttp3("HTTP3 isn't supported by the client"); + } + return null; + } + + final Http3ClientImpl h3client = client.client3().get(); + // HTTP_3 with proxy not supported; In this case we will downgrade + // to using HTTP/2 + var debug = h3client.debug(); + var where = "HttpQuicConnection.getHttpQuicConnection"; + if (proxy != null || !hasRequiredHTTP3TLSVersion(client)) { + if (debug.on()) + debug.log("%s: proxy required or SSL version mismatch", where); + return null; + } + + assert request.secure(); + // Question: Do we need this scaffolding? + // I mean - could Http3Connection and HttpQuicConnection be the same + // object? + // Answer: Http3Connection models an established connection which is + // ready to be used. + // HttpQuicConnection serves at establishing a new Http3Connection + // => Http3Connection is pooled, HttpQuicConnection is not. + // => Do we need HttpQuicConnection vs QuicConnection? + // => yes: HttpQuicConnection can access all package protected + // APIs in HttpConnection & al + // QuicConnection is in the quic subpackage. + // HttpQuicConnection makes the necessary adaptation between + // HttpConnection and QuicConnection. + + // find whether we have an alternate service access point for HTTP/3 + // if we do, create a new QuicConnection and a new Http3Connection over it. + var uri = request.uri(); + var config = request.http3Discovery(); + if (debug.on()) { + debug.log("Checking ALT-SVC regardless of H3_DISCOVERY settings"); + } + // we only support H3 right now + var altSvc = client.registry() + .lookup(uri, H3::equals) + .findFirst().orElse(null); + Optional directTimeout = Optional.empty(); + final boolean advertisedAltSvc = altSvc != null && altSvc.wasAdvertised(); + logAltSvcFor(debug, uri, altSvc, where); + switch (config) { + case ALT_SVC: { + if (!advertisedAltSvc) { + // fallback to HTTP/2 + if (altSvc != null) { + if (Log.altsvc()) { + Log.logAltSvc("{0}: Cannot use unadvertised AltService: {1}", + config, altSvc); + } + } + return null; + } + assert altSvc != null && altSvc.wasAdvertised(); + break; + } + // attempt direct connection if HTTP/3 only + case HTTP_3_URI_ONLY: { + if (advertisedAltSvc && !altSvc.originHasSameAuthority()) { + if (Log.altsvc()) { + Log.logAltSvc("{0}: Cannot use advertised AltService: {1}", + config, altSvc); + } + altSvc = null; + } + assert altSvc == null || altSvc.originHasSameAuthority(); + break; + } + default: { + // if direct connection already attempted and failed, + // fallback to HTTP/2 + if (altSvc == null && h3client.hasNoH3(uri.getRawAuthority())) { + return null; + } + if (!advertisedAltSvc) { + // directTimeout is only used for happy eyeball + Duration def = Duration.ofMillis(MAX_DIRECT_CONNECTION_TIMEOUT); + Duration timeout = client.connectTimeout() + .filter(d -> d.compareTo(def) <= 0) + .orElse(def); + directTimeout = Optional.of(timeout); + } + break; + } + } + + if (altSvc != null) { + assert H3.equals(altSvc.alpn()); + Log.logAltSvc("{0}: Using AltService for {1}: {2}", + config, uri.getRawAuthority(), altSvc); + } + if (debug.on()) { + debug.log("%s: creating QuicConnection for: %s", where, uri); + } + final QuicConnection quicConnection = (altSvc != null) ? + h3client.quicClient().createConnectionFor(altSvc) : + h3client.quicClient().createConnectionFor(addr, new String[] {H3}); + if (debug.on()) debug.log("%s: QuicConnection: %s", where, quicConnection); + final DirectConnectionUpdater onConnectFinished = advertisedAltSvc + ? DirectConnectionUpdater::noUpdate + : (c,t) -> registerUnadvertised(client, uri, addr, c, t); + // Note: we could get rid of the updater by introducing + // H3DirectQuicConnectionImpl extends H3QuicConnectionImpl + HttpQuicConnection httpQuicConn = new H3QuicConnectionImpl(Origin.from(request.uri()), addr, client, + quicConnection, onConnectFinished, directTimeout, altSvc); + // if we created a connection and if that connection is to an (advertised) alt service then + // we setup the Exchange's request to include the "alt-used" header to refer to the + // alt service that was used (section 5, RFC-7838) + if (httpQuicConn != null && altSvc != null && advertisedAltSvc) { + exchange.request().setSystemHeader("alt-used", altSvc.authority()); + } + return httpQuicConn; + } + + private static void logAltSvcFor(Logger debug, URI uri, AltService altSvc, String where) { + if (altSvc == null) { + if (Log.altsvc()) { + Log.logAltSvc("No AltService found for {0}", uri.getRawAuthority()); + } else if (debug.on()) { + debug.log("%s: No ALT-SVC for %s", where, uri.getRawAuthority()); + } + } else { + if (debug.on()) debug.log("%s: ALT-SVC: %s", where, altSvc); + } + } + + static void registerUnadvertised(final HttpClientImpl client, + final URI requestURI, + final InetSocketAddress destAddr, + final Http3Connection connection, + final Throwable t) { + if (t == null && connection != null) { + // There is an h3 endpoint at the given origin: update the registry + final Origin origin = connection.connection().getOriginServer(); + assert origin != null : "origin server is null on connection: " + + connection.connection(); + assert origin.port() == destAddr.getPort(); + var id = new AltService.Identity(H3, origin.host(), origin.port()); + client.registry().registerUnadvertised(id, origin, connection.connection()); + return; + } + if (t != null) { + assert client.client3().isPresent() : "HTTP3 isn't supported by the client"; + final URI originURI = requestURI.resolve("/"); + // record that there is no h3 at the given origin + client.client3().get().noH3(originURI.getRawAuthority()); + } + } + + // TODO: we could probably merge H3QuicConnectionImpl with HttpQuicConnection now + static class H3QuicConnectionImpl extends HttpQuicConnection { + private final Optional directTimeout; + private final DirectConnectionUpdater connFinishedAction; + H3QuicConnectionImpl(Origin originServer, + InetSocketAddress address, + HttpClientImpl client, + QuicConnection quic, + DirectConnectionUpdater connFinishedAction, + Optional directTimeout, + AltService sourceAltService) { + super(originServer, address, client, quic, sourceAltService); + this.directTimeout = directTimeout; + this.connFinishedAction = connFinishedAction; + } + + @Override + public CompletableFuture connectAsync(Exchange exchange) { + var request = exchange.request(); + var uri = request.uri(); + // Adapt HandshakeCF to CompletableFuture + CompletableFuture> handshakeCfCf = + quicConnection.startHandshake() + .handle((r, t) -> { + if (t == null) { + // successful handshake + return MinimalFuture.completedFuture(r); + } + final TerminationCause terminationCause = quicConnection.terminationCause(); + final boolean appLayerTermination = terminationCause != null + && terminationCause.isAppLayer(); + // QUIC connection handshake failed. we now decide whether we should + // unregister the alt-service (if any) that was the source of this + // connection attempt. + // + // handshake could have failed for one of several reasons, some of them: + // - something at QUIC layer caused the failure (either some internal + // exception or protocol error or QUIC TLS error) + // - or the app layer, through the HttpClient/HttpConnection + // could have triggered a connection close. + // + // we unregister the alt-service (if any) only if the termination cause + // originated in the QUIC layer. An app layer termination cause doesn't + // necessarily mean that the alt-service isn't valid for subsequent use. + if (!appLayerTermination && this.getSourceAltService().isPresent()) { + final AltService altSvc = this.getSourceAltService().get(); + if (debug.on()) { + debug.log("connection attempt to an alternate service at " + + altSvc.authority() + " failed during handshake: " + t); + } + client().registry().markInvalid(this.getSourceAltService().get()); + // fail with ConnectException to allow the request to potentially + // be retried on a different connection + final ConnectException connectException = new ConnectException( + "QUIC connection handshake to an alternate service failed"); + connectException.initCause(t); + return MinimalFuture.failedFuture(connectException); + } else { + // alt service wasn't the cause of this failed connection attempt. + // return a failed future with the original cause + return MinimalFuture.failedFuture(t); + } + }) + .thenApply((handshakeCompletion) -> { + if (handshakeCompletion.isCompletedExceptionally()) { + return MinimalFuture.failedFuture(handshakeCompletion.exceptionNow()); + } + return MinimalFuture.completedFuture(null); + }); + + // In case of direct connection, set up a timeout on the handshakeReachedPeerCf, + // and arrange for it to complete the handshakeCfCf above with a timeout in + // case that timeout expires... + if (directTimeout.isPresent()) { + debug.log("setting up quic direct connect timeout: " + directTimeout.get().toMillis()); + var handshakeReachedPeerCf = quicConnection.handshakeReachedPeer(); + CompletableFuture> fxi2 = handshakeReachedPeerCf + .thenApply((unused) -> MinimalFuture.completedFuture(null)); + fxi2 = fxi2.completeOnTimeout( + MinimalFuture.failedFuture(new HttpConnectTimeoutException("quic handshake timeout")), + directTimeout.get().toMillis(), TimeUnit.MILLISECONDS); + fxi2.handleAsync((r, t) -> { + if (t != null) { + var cause = Utils.getCompletionCause(t); + // arrange for handshakeCfCf to timeout + handshakeCfCf.completeExceptionally(cause); + } + if (r.isCompletedExceptionally()) { + var cause = Utils.getCompletionCause(r.exceptionNow()); + // arrange for handshakeCfCf to timeout + handshakeCfCf.completeExceptionally(cause); + } + return r; + }, exchange.parentExecutor.safeDelegate()); + } + + Optional timeout = client().connectTimeout(); + CompletableFuture> fxi = handshakeCfCf; + + // In case of connection timeout, set up a timeout on the handshakeCfCf. + // Note: this is a different timeout than the direct connection timeout. + if (timeout.isPresent()) { + // In case of timeout we need to close the quic connection + debug.log("setting up quic connect timeout: " + timeout.get().toMillis()); + fxi = handshakeCfCf.completeOnTimeout( + MinimalFuture.failedFuture(new HttpConnectTimeoutException("quic connect timeout")), + timeout.get().toMillis(), TimeUnit.MILLISECONDS); + } + + // If we have set up any timeout, arrange to close the quicConnection + // if one of the timeout expires + if (timeout.isPresent() || directTimeout.isPresent()) { + fxi = fxi.handleAsync(this::handleTimeout, exchange.parentExecutor.safeDelegate()); + } + return fxi.thenCompose(Function.identity()); + } + + @Override + public void connectionEstablished(Http3Connection connection, + Throwable throwable) { + connFinishedAction.onConnectionEstablished(connection, throwable); + } + + private CompletableFuture handleTimeout(CompletableFuture r, Throwable t) { + if (t != null) { + if (Utils.getCompletionCause(t) instanceof HttpConnectTimeoutException te) { + debug.log("Timeout expired: " + te); + close(H3_NO_ERROR.code(), "timeout expired", te); + return MinimalFuture.failedFuture(te); + } + return MinimalFuture.failedFuture(t); + } else if (r.isCompletedExceptionally()) { + t = r.exceptionNow(); + if (Utils.getCompletionCause(t) instanceof HttpConnectTimeoutException te) { + debug.log("Completed in timeout: " + te); + close(H3_NO_ERROR.code(), "timeout expired", te); + } + } + return r; + } + + + @Override + NetworkChannel /* DatagramChannel */ channel() { + // Note: revisit this + // - don't return a new instance each time + // - see if we could avoid exposing + // the channel in the first place + H3QuicConnectionImpl self = this; + return new NetworkChannel() { + @Override + public NetworkChannel bind(SocketAddress local) throws IOException { + throw new UnsupportedOperationException("no bind for a quic connection"); + } + + @Override + public SocketAddress getLocalAddress() throws IOException { + return quicConnection.localAddress(); + } + + @Override + public NetworkChannel setOption(SocketOption name, T value) throws IOException { + return this; + } + + @Override + public T getOption(SocketOption name) throws IOException { + return null; + } + + @Override + public Set> supportedOptions() { + return Set.of(); + } + + @Override + public boolean isOpen() { + return quicConnection.isOpen(); + } + + @Override + public void close() throws IOException { + self.close(); + } + }; + } + + @Override + CacheKey cacheKey() { + return null; + } + + // close with H3_NO_ERROR + @Override + public final void close() { + close(H3_NO_ERROR.code(), "connection closed", null); + } + + @Override + void close(final Throwable cause) { + close(H3_INTERNAL_ERROR.code(), null, cause); + } + } + + /* Tells whether this connection is a tunnel through a proxy */ + boolean isTunnel() { return false; } + + abstract NetworkChannel /* DatagramChannel */ channel(); + + abstract ConnectionPool.CacheKey cacheKey(); + + /** + * Closes the underlying transport connection with + * the given {@code connCloseCode} code. This will be considered a application + * layer close and will generate a {@code ConnectionCloseFrame} + * of type {@code 0x1d} as the cause of the termination. + * + * @param connCloseCode the connection close code + * @param logMsg the message to be included in the logs as + * the cause of the connection termination. can be null. + * @param closeCause the underlying cause of the connection termination. can be null, + * in which case just the {@code error} will be recorded as the + * cause of the connection termination. + */ + final void close(final long connCloseCode, final String logMsg, + final Throwable closeCause) { + final TerminationCause terminationCause; + if (closeCause == null) { + terminationCause = appLayerClose(connCloseCode); + } else { + terminationCause = appLayerException(connCloseCode, closeCause); + } + // set the log message only if non-null, else let it default to internal + // implementation sensible default + if (logMsg != null) { + terminationCause.loggedAs(logMsg); + } + quicConnTerminator.terminate(terminationCause); + } + + abstract void close(final Throwable t); + + /** + * {@inheritDoc} + * + * @implSpec + * Unlike HTTP/1.1 and HTTP/2, an HTTP/3 connection is not + * built on a single connection flow, since multiplexing is + * provided by the lower layer. Therefore, the higher HTTP + * layer should never call {@code getConnectionFlow()} on an + * {@link HttpQuicConnection}. As a consequence, this method + * always throws {@link IllegalStateException} unconditionally. + * + * @return nothing: this method always throw {@link IllegalStateException} + * + * @throws IllegalStateException always + */ + @Override + final FlowTube getConnectionFlow() { + throw new IllegalStateException( + "An HTTP/3 connection does not expose " + + "a single connection flow"); + } + + /** + * Unlike HTTP/1.1 and HTTP/2, an HTTP/3 connection is not + * built on a single connection flow, since multiplexing is + * provided by the lower layer. This method instead will + * return {@code true} if the underlying quic connection + * has been terminated, either exceptionally or normally. + * + * @return {@code true} if the underlying Quic connection + * has been terminated. + */ + @Override + boolean isFlowFinished() { + return !quicConnection().isOpen(); + } + + @Override + public String toString() { + return quicDbgString(); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestBuilderImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestBuilderImpl.java index a495fcce1ee..c68965626b7 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestBuilderImpl.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestBuilderImpl.java @@ -26,12 +26,18 @@ package jdk.internal.net.http; import java.net.URI; +import java.net.http.HttpRequest.Builder; +import java.net.http.HttpOption; import java.time.Duration; +import java.util.HashMap; import java.util.Locale; +import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublisher; +import java.util.Set; import jdk.internal.net.http.common.HttpHeadersBuilder; import jdk.internal.net.http.common.Utils; @@ -51,6 +57,10 @@ public class HttpRequestBuilderImpl implements HttpRequest.Builder { private BodyPublisher bodyPublisher; private volatile Optional version; private Duration duration; + private final Map, Object> options = new HashMap<>(); + + private static final Set> supportedOptions = + Set.of(HttpOption.H3_DISCOVERY); public HttpRequestBuilderImpl(URI uri) { requireNonNull(uri, "uri must be non-null"); @@ -100,6 +110,7 @@ public class HttpRequestBuilderImpl implements HttpRequest.Builder { b.uri = uri; b.duration = duration; b.version = version; + b.options.putAll(Map.copyOf(options)); return b; } @@ -158,6 +169,19 @@ public class HttpRequestBuilderImpl implements HttpRequest.Builder { return this; } + @Override + public Builder setOption(HttpOption option, T value) { + Objects.requireNonNull(option, "option"); + if (value == null) options.remove(option); + else if (supportedOptions.contains(option)) { + if (!option.type().isInstance(value)) { + throw newIAE("Illegal value type %s for %s", value, option); + } + options.put(option, value); + } // otherwise just ignore the option + return this; + } + HttpHeadersBuilder headersBuilder() { return headersBuilder; } URI uri() { return uri; } @@ -170,6 +194,8 @@ public class HttpRequestBuilderImpl implements HttpRequest.Builder { Optional version() { return version; } + Map, Object> options() { return options; } + @Override public HttpRequest.Builder GET() { return method0("GET", null); @@ -245,4 +271,30 @@ public class HttpRequestBuilderImpl implements HttpRequest.Builder { Duration timeout() { return duration; } + public static Map, Object> copySupportedOptions(HttpRequest request) { + Objects.requireNonNull(request, "request"); + if (request instanceof ImmutableHttpRequest ihr) { + // already checked and immutable + return ihr.options(); + } + Map, Object> options = new HashMap<>(); + for (HttpOption option : supportedOptions) { + var val = request.getOption(option); + if (!val.isPresent()) continue; + options.put(option, option.type().cast(val.get())); + } + return Map.copyOf(options); + } + + public static Map, Object> copySupportedOptions(Map, Object> options) { + Objects.requireNonNull(options, "option"); + Map, Object> result = new HashMap<>(); + for (HttpOption option : supportedOptions) { + var val = options.get(option); + if (val == null) continue; + result.put(option, option.type().cast(val)); + } + return Map.copyOf(result); + } + } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java index 81c693ea192..00730baa0ce 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java @@ -31,9 +31,13 @@ import java.net.InetSocketAddress; import java.net.Proxy; import java.net.ProxySelector; import java.net.URI; +import java.net.http.HttpClient.Version; +import java.net.http.HttpOption; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.time.Duration; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.net.http.HttpClient; @@ -46,6 +50,7 @@ import jdk.internal.net.http.common.HttpHeadersBuilder; import jdk.internal.net.http.common.Utils; import jdk.internal.net.http.websocket.WebSocketRequest; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.net.Authenticator.RequestorType.SERVER; import static jdk.internal.net.http.common.Utils.ALLOWED_HEADERS; import static jdk.internal.net.http.common.Utils.ProxyHeaders; @@ -65,6 +70,8 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { private volatile boolean isWebSocket; private final Duration timeout; // may be null private final Optional version; + // An alternative would be to have one field per supported option + private final Map, Object> options; private volatile boolean userSetAuthorization; private volatile boolean userSetProxyAuthorization; @@ -92,6 +99,7 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { this.requestPublisher = builder.bodyPublisher(); // may be null this.timeout = builder.timeout(); this.version = builder.version(); + this.options = Map.copyOf(builder.options()); this.authority = null; } @@ -110,12 +118,13 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { "uri must be non null"); Duration timeout = request.timeout().orElse(null); this.method = method == null ? "GET" : method; + this.options = HttpRequestBuilderImpl.copySupportedOptions(request); this.userHeaders = HttpHeaders.of(request.headers().map(), Utils.VALIDATE_USER_HEADER); - if (request instanceof HttpRequestImpl) { + if (request instanceof HttpRequestImpl impl) { // all cases exception WebSocket should have a new system headers - this.isWebSocket = ((HttpRequestImpl) request).isWebSocket; + this.isWebSocket = impl.isWebSocket; if (isWebSocket) { - this.systemHeadersBuilder = ((HttpRequestImpl)request).systemHeadersBuilder; + this.systemHeadersBuilder = impl.systemHeadersBuilder; } else { this.systemHeadersBuilder = new HttpHeadersBuilder(); } @@ -199,6 +208,19 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { this.timeout = other.timeout; this.version = other.version(); this.authority = null; + this.options = other.optionsFor(this.uri); + } + + private Map, Object> optionsFor(URI uri) { + if (this.uri == uri || Objects.equals(this.uri.getRawAuthority(), uri.getRawAuthority())) { + return options; + } + // preserve config if version is HTTP/3 + if (version.orElse(null) == Version.HTTP_3) { + Http3DiscoveryMode h3DiscoveryMode = (Http3DiscoveryMode)options.get(H3_DISCOVERY); + if (h3DiscoveryMode != null) return Map.of(H3_DISCOVERY, h3DiscoveryMode); + } + return Map.of(); } private BodyPublisher publisher(HttpRequestImpl other) { @@ -234,12 +256,26 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { // What we want to possibly upgrade is the tunneled connection to the // target server (so not the CONNECT request itself) this.version = Optional.of(HttpClient.Version.HTTP_1_1); + this.options = Map.of(); } final boolean isConnect() { return "CONNECT".equalsIgnoreCase(method); } + final boolean isHttp3Only(Version version) { + return version == Version.HTTP_3 && http3Discovery() == HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; + } + + final Http3DiscoveryMode http3Discovery() { + // see if discovery mode is set on the request + final var h3Discovery = getOption(H3_DISCOVERY); + // if no explicit discovery mode is set, then default to "ANY" + // irrespective of whether the HTTP/3 version may have been + // set on the HttpClient or the HttpRequest + return h3Discovery.orElse(Http3DiscoveryMode.ANY); + } + /** * Creates a HttpRequestImpl from the given set of Headers and the associated * "parent" request. Fields not taken from the headers are taken from the @@ -276,6 +312,7 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { this.timeout = parent.timeout; this.version = parent.version; this.authority = null; + this.options = parent.options; } @Override @@ -399,6 +436,11 @@ public class HttpRequestImpl extends HttpRequest implements WebSocketRequest { @Override public Optional version() { return version; } + @Override + public Optional getOption(HttpOption option) { + return Optional.ofNullable(option.type().cast(options.get(option))); + } + @Override public void setSystemHeader(String name, String value) { systemHeadersBuilder.setHeader(name, value); diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/HttpResponseImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/HttpResponseImpl.java index 1552cd40ede..ba5bea43078 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpResponseImpl.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpResponseImpl.java @@ -41,14 +41,14 @@ import jdk.internal.net.http.websocket.RawChannel; /** * The implementation class for HttpResponse */ -class HttpResponseImpl implements HttpResponse, RawChannel.Provider { +final class HttpResponseImpl implements HttpResponse, RawChannel.Provider { final int responseCode; private final String connectionLabel; final HttpRequest initialRequest; - final Optional> previousResponse; + final HttpResponse previousResponse; // may be null; final HttpHeaders headers; - final Optional sslSession; + final SSLSession sslSession; // may be null final URI uri; final HttpClient.Version version; final RawChannelProvider rawChannelProvider; @@ -62,10 +62,10 @@ class HttpResponseImpl implements HttpResponse, RawChannel.Provider { this.responseCode = response.statusCode(); this.connectionLabel = connectionLabel(exch).orElse(null); this.initialRequest = initialRequest; - this.previousResponse = Optional.ofNullable(previousResponse); + this.previousResponse = previousResponse; this.headers = response.headers(); //this.trailers = trailers; - this.sslSession = Optional.ofNullable(response.getSSLSession()); + this.sslSession = response.getSSLSession(); this.uri = response.request().uri(); this.version = response.version(); this.rawChannelProvider = RawChannelProvider.create(response, exch); @@ -96,7 +96,7 @@ class HttpResponseImpl implements HttpResponse, RawChannel.Provider { @Override public Optional> previousResponse() { - return previousResponse; + return Optional.ofNullable(previousResponse); } @Override @@ -111,7 +111,7 @@ class HttpResponseImpl implements HttpResponse, RawChannel.Provider { @Override public Optional sslSession() { - return sslSession; + return Optional.ofNullable(sslSession); } @Override diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/ImmutableHttpRequest.java b/src/java.net.http/share/classes/jdk/internal/net/http/ImmutableHttpRequest.java index 1fb96944afc..b2c15ac3bef 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/ImmutableHttpRequest.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/ImmutableHttpRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,7 +28,9 @@ package jdk.internal.net.http; import java.net.URI; import java.net.http.HttpHeaders; import java.net.http.HttpRequest; +import java.net.http.HttpOption; import java.time.Duration; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.net.http.HttpClient.Version; @@ -43,6 +45,8 @@ final class ImmutableHttpRequest extends HttpRequest { private final boolean expectContinue; private final Optional timeout; private final Optional version; + // An alternative would be to have one field per supported option + private final Map, Object> options; /** Creates an ImmutableHttpRequest from the given builder. */ ImmutableHttpRequest(HttpRequestBuilderImpl builder) { @@ -53,6 +57,7 @@ final class ImmutableHttpRequest extends HttpRequest { this.expectContinue = builder.expectContinue(); this.timeout = Optional.ofNullable(builder.timeout()); this.version = Objects.requireNonNull(builder.version()); + this.options = Map.copyOf(builder.options()); } @Override @@ -78,8 +83,17 @@ final class ImmutableHttpRequest extends HttpRequest { @Override public Optional version() { return version; } + @Override + public Optional getOption(HttpOption option) { + return Optional.ofNullable(option.type().cast(options.get(option))); + } + @Override public String toString() { return uri.toString() + " " + method; } + + public Map, Object> options() { + return options; + } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/MultiExchange.java b/src/java.net.http/share/classes/jdk/internal/net/http/MultiExchange.java index 20120aad7d5..ec621f7f955 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/MultiExchange.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/MultiExchange.java @@ -29,7 +29,9 @@ import java.io.IOError; import java.io.IOException; import java.lang.ref.WeakReference; import java.net.ConnectException; +import java.net.http.HttpClient.Version; import java.net.http.HttpConnectTimeoutException; +import java.net.http.StreamLimitException; import java.time.Duration; import java.util.List; import java.util.ListIterator; @@ -38,8 +40,6 @@ import java.util.Optional; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.CompletionException; -import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicInteger; @@ -62,6 +62,8 @@ import jdk.internal.net.http.common.ConnectionExpiredException; import jdk.internal.net.http.common.Utils; import static jdk.internal.net.http.common.MinimalFuture.completedFuture; import static jdk.internal.net.http.common.MinimalFuture.failedFuture; +import static jdk.internal.net.http.AltSvcProcessor.processAltSvcHeader; + /** * Encapsulates multiple Exchanges belonging to one HttpRequestImpl. @@ -76,6 +78,16 @@ class MultiExchange implements Cancelable { static final Logger debug = Utils.getDebugLogger("MultiExchange"::toString, Utils.DEBUG); + private record RetryContext(Throwable requestFailureCause, + boolean shouldRetry, + AtomicInteger reqAttemptCounter, + boolean shouldResetConnectTimer) { + private static RetryContext doNotRetry(Throwable requestFailureCause) { + return new RetryContext(requestFailureCause, false, null, false); + } + } + + private static final AtomicLong IDS = new AtomicLong(); private final HttpRequest userRequest; // the user request private final HttpRequestImpl request; // a copy of the user request private final ConnectTimeoutTracker connectTimeout; // null if no timeout @@ -83,12 +95,11 @@ class MultiExchange implements Cancelable { final HttpResponse.BodyHandler responseHandler; final HttpClientImpl.DelegatingExecutor executor; final AtomicInteger attempts = new AtomicInteger(); + final long id = IDS.incrementAndGet(); HttpRequestImpl currentreq; // used for retries & redirect HttpRequestImpl previousreq; // used for retries & redirect Exchange exchange; // the current exchange Exchange previous; - volatile Throwable retryCause; - volatile boolean retriedOnce; volatile HttpResponse response; // Maximum number of times a request will be retried/redirected @@ -98,6 +109,12 @@ class MultiExchange implements Cancelable { "jdk.httpclient.redirects.retrylimit", DEFAULT_MAX_ATTEMPTS ); + // Maximum number of times a request should be retried when + // max streams limit is reached + static final int max_stream_limit_attempts = Utils.getIntegerNetProperty( + "jdk.httpclient.retryOnStreamlimit", max_attempts + ); + private final List filters; volatile ResponseTimerEvent responseTimerEvent; volatile boolean cancelled; @@ -113,20 +130,22 @@ class MultiExchange implements Cancelable { volatile AuthenticationFilter.AuthInfo serverauth, proxyauth; // RedirectHandler volatile int numberOfRedirects = 0; + // StreamLimit + private final AtomicInteger streamLimitRetries = new AtomicInteger(); // This class is used to keep track of the connection timeout // across retries, when a ConnectException causes a retry. // In that case - we will retry the connect, but we don't // want to double the timeout by starting a new timer with // the full connectTimeout again. - // Instead we use the ConnectTimeoutTracker to return a new + // Instead, we use the ConnectTimeoutTracker to return a new // duration that takes into account the time spent in the // first connect attempt. // If however, the connection gets connected, but we later // retry the whole operation, then we reset the timer before // retrying (since the connection used for the second request // will not necessarily be the same: it could be a new - // unconnected connection) - see getExceptionalCF(). + // unconnected connection) - see checkRetryEligible(). private static final class ConnectTimeoutTracker { final Duration max; final AtomicLong startTime = new AtomicLong(); @@ -199,8 +218,22 @@ class MultiExchange implements Cancelable { HttpClient.Version version() { HttpClient.Version vers = request.version().orElse(client.version()); - if (vers == HttpClient.Version.HTTP_2 && !request.secure() && request.proxy() != null) + if (vers != Version.HTTP_1_1 + && !request.secure() && request.proxy() != null + && !request.isHttp3Only(vers)) { + // downgrade to HTTP_1_1 unless HTTP_3_URI_ONLY. + // if HTTP_3_URI_ONLY and not secure it will fail down the road, so + // we don't downgrade here. vers = HttpClient.Version.HTTP_1_1; + } + if (vers == Version.HTTP_3 && request.secure() && !client.client3().isPresent()) { + if (!request.isHttp3Only(vers)) { + // HTTP/3 not supported with the client config. + // Downgrade to HTTP/2, unless HTTP_3_URI_ONLY is specified + vers = Version.HTTP_2; + if (debug.on()) debug.log("HTTP_3 downgraded to " + vers); + } + } return vers; } @@ -229,28 +262,28 @@ class MultiExchange implements Cancelable { } private void requestFilters(HttpRequestImpl r) throws IOException { - Log.logTrace("Applying request filters"); + if (Log.trace()) Log.logTrace("Applying request filters"); for (HeaderFilter filter : filters) { - Log.logTrace("Applying {0}", filter); + if (Log.trace()) Log.logTrace("Applying {0}", filter); filter.request(r, this); } - Log.logTrace("All filters applied"); + if (Log.trace()) Log.logTrace("All filters applied"); } private HttpRequestImpl responseFilters(Response response) throws IOException { - Log.logTrace("Applying response filters"); + if (Log.trace()) Log.logTrace("Applying response filters"); ListIterator reverseItr = filters.listIterator(filters.size()); while (reverseItr.hasPrevious()) { HeaderFilter filter = reverseItr.previous(); - Log.logTrace("Applying {0}", filter); + if (Log.trace()) Log.logTrace("Applying {0}", filter); HttpRequestImpl newreq = filter.response(response); if (newreq != null) { - Log.logTrace("New request: stopping filters"); + if (Log.trace()) Log.logTrace("New request: stopping filters"); return newreq; } } - Log.logTrace("All filters applied"); + if (Log.trace()) Log.logTrace("All filters applied"); return null; } @@ -293,9 +326,13 @@ class MultiExchange implements Cancelable { return true; } else { if (cancelled) { - debug.log("multi exchange already cancelled: " + interrupted.get()); + if (debug.on()) { + debug.log("multi exchange already cancelled: " + interrupted.get()); + } } else { - debug.log("multi exchange mayInterruptIfRunning=" + mayInterruptIfRunning); + if (debug.on()) { + debug.log("multi exchange mayInterruptIfRunning=" + mayInterruptIfRunning); + } } } return false; @@ -316,7 +353,7 @@ class MultiExchange implements Cancelable { // and therefore doesn't have to include header information which indicates no // body is present. This is distinct from responses that also do not contain // response bodies (possibly ever) but which are required to have content length - // info in the header (eg 205). Those cases do not have to be handled specially + // info in the header (e.g. 205). Those cases do not have to be handled specially private static boolean bodyNotPermitted(Response r) { return r.statusCode == 204; @@ -344,19 +381,27 @@ class MultiExchange implements Cancelable { if (exception != null) result.completeExceptionally(exception); else { - this.response = - new HttpResponseImpl<>(r.request(), r, this.response, nullBody, exch); - result.complete(this.response); + result.complete(setNewResponse(r.request(), r, nullBody, exch)); } }); // ensure that the connection is closed or returned to the pool. return result.whenComplete(exch::nullBody); } + // creates a new HttpResponseImpl object and assign it to this.response + private HttpResponse setNewResponse(HttpRequest request, Response r, T body, Exchange exch) { + HttpResponse previousResponse = this.response; + return this.response = new HttpResponseImpl<>(request, r, previousResponse, body, exch); + } + private CompletableFuture> responseAsync0(CompletableFuture start) { - return start.thenCompose( v -> responseAsyncImpl()) - .thenCompose((Response r) -> { + return start.thenCompose( _ -> { + // this is the first attempt to have the request processed by the server + attempts.set(1); + return responseAsyncImpl(true); + }).thenCompose((Response r) -> { + processAltSvcHeader(r, client(), currentreq); Exchange exch = getExchange(); if (bodyNotPermitted(r)) { if (bodyIsPresent(r)) { @@ -368,15 +413,11 @@ class MultiExchange implements Cancelable { return handleNoBody(r, exch); } return exch.readBodyAsync(responseHandler) - .thenApply((T body) -> { - this.response = - new HttpResponseImpl<>(r.request(), r, this.response, body, exch); - return this.response; - }); + .thenApply((T body) -> setNewResponse(r.request, r, body, exch)); }).exceptionallyCompose(this::whenCancelled); } - // returns a CancellationExcpetion that wraps the given cause + // returns a CancellationException that wraps the given cause // if cancel(boolean) was called, the given cause otherwise private Throwable wrapIfCancelled(Throwable cause) { CancellationException interrupt = interrupted.get(); @@ -412,79 +453,100 @@ class MultiExchange implements Cancelable { } } - private CompletableFuture responseAsyncImpl() { - CompletableFuture cf; - if (attempts.incrementAndGet() > max_attempts) { - cf = failedFuture(new IOException("Too many retries", retryCause)); - } else { - if (currentreq.timeout().isPresent()) { - responseTimerEvent = ResponseTimerEvent.of(this); - client.registerTimer(responseTimerEvent); - } - try { - // 1. apply request filters - // if currentreq == previousreq the filters have already - // been applied once. Applying them a second time might - // cause some headers values to be added twice: for - // instance, the same cookie might be added again. - if (currentreq != previousreq) { - requestFilters(currentreq); - } - } catch (IOException e) { - return failedFuture(e); - } - Exchange exch = getExchange(); - // 2. get response - cf = exch.responseAsync() - .thenCompose((Response response) -> { - HttpRequestImpl newrequest; - try { - // 3. apply response filters - newrequest = responseFilters(response); - } catch (Throwable t) { - IOException e = t instanceof IOException io ? io : new IOException(t); - exch.exchImpl.cancel(e); - return failedFuture(e); - } - // 4. check filter result and repeat or continue - if (newrequest == null) { - if (attempts.get() > 1) { - Log.logError("Succeeded on attempt: " + attempts); - } - return completedFuture(response); - } else { - cancelTimer(); - this.response = - new HttpResponseImpl<>(currentreq, response, this.response, null, exch); - Exchange oldExch = exch; - if (currentreq.isWebSocket()) { - // need to close the connection and open a new one. - exch.exchImpl.connection().close(); - } - return exch.ignoreBody().handle((r,t) -> { - previousreq = currentreq; - currentreq = newrequest; - retriedOnce = false; - setExchange(new Exchange<>(currentreq, this)); - return responseAsyncImpl(); - }).thenCompose(Function.identity()); - } }) - .handle((response, ex) -> { - // 5. handle errors and cancel any timer set - cancelTimer(); - if (ex == null) { - assert response != null; - return completedFuture(response); - } - // all exceptions thrown are handled here - CompletableFuture errorCF = getExceptionalCF(ex, exch.exchImpl); - if (errorCF == null) { - return responseAsyncImpl(); - } else { - return errorCF; - } }) - .thenCompose(Function.identity()); + // we call this only when a request is being retried + private CompletableFuture retryRequest() { + // maintain state indicating a request being retried + previousreq = currentreq; + // request is being retried, so the filters have already + // been applied once. Applying them a second time might + // cause some headers values to be added twice: for + // instance, the same cookie might be added again. + final boolean applyReqFilters = false; + return responseAsyncImpl(applyReqFilters); + } + + private CompletableFuture responseAsyncImpl(final boolean applyReqFilters) { + if (currentreq.timeout().isPresent()) { + responseTimerEvent = ResponseTimerEvent.of(this); + client.registerTimer(responseTimerEvent); } + try { + // 1. apply request filters + if (applyReqFilters) { + requestFilters(currentreq); + } + } catch (IOException e) { + return failedFuture(e); + } + final Exchange exch = getExchange(); + // 2. get response + final CompletableFuture cf = exch.responseAsync() + .thenCompose((Response response) -> { + HttpRequestImpl newrequest; + try { + // 3. apply response filters + newrequest = responseFilters(response); + } catch (Throwable t) { + IOException e = t instanceof IOException io ? io : new IOException(t); + exch.exchImpl.cancel(e); + return failedFuture(e); + } + // 4. check filter result and repeat or continue + if (newrequest == null) { + if (attempts.get() > 1) { + if (Log.requests()) { + Log.logResponse(() -> String.format( + "%s #%s Succeeded on attempt %s: statusCode=%s", + request, id, attempts, response.statusCode)); + } + } + return completedFuture(response); + } else { + cancelTimer(); + setNewResponse(currentreq, response, null, exch); + if (currentreq.isWebSocket()) { + // need to close the connection and open a new one. + exch.exchImpl.connection().close(); + } + return exch.ignoreBody().handle((r,t) -> { + previousreq = currentreq; + currentreq = newrequest; + // this is the first attempt to have the new request + // processed by the server + attempts.set(1); + setExchange(new Exchange<>(currentreq, this)); + return responseAsyncImpl(true); + }).thenCompose(Function.identity()); + } }) + .handle((response, ex) -> { + // 5. handle errors and cancel any timer set + cancelTimer(); + if (ex == null) { + assert response != null; + return completedFuture(response); + } + // all exceptions thrown are handled here + final RetryContext retryCtx = checkRetryEligible(ex, exch); + assert retryCtx != null : "retry context is null"; + if (retryCtx.shouldRetry()) { + // increment the request attempt counter and retry the request + assert retryCtx.reqAttemptCounter != null : "request attempt counter is null"; + final int numAttempt = retryCtx.reqAttemptCounter.incrementAndGet(); + if (debug.on()) { + debug.log("Retrying request: " + currentreq + " id: " + id + + " attempt: " + numAttempt + " due to: " + + retryCtx.requestFailureCause); + } + // reset the connect timer if necessary + if (retryCtx.shouldResetConnectTimer && this.connectTimeout != null) { + this.connectTimeout.reset(); + } + return retryRequest(); + } else { + assert retryCtx.requestFailureCause != null : "missing request failure cause"; + return MinimalFuture.failedFuture(retryCtx.requestFailureCause); + } }) + .thenCompose(Function.identity()); return cf; } @@ -492,14 +554,14 @@ class MultiExchange implements Cancelable { String s = Utils.getNetProperty("jdk.httpclient.enableAllMethodRetry"); if (s == null) return false; - return s.isEmpty() ? true : Boolean.parseBoolean(s); + return s.isEmpty() || Boolean.parseBoolean(s); } private static boolean disableRetryConnect() { String s = Utils.getNetProperty("jdk.httpclient.disableRetryConnect"); if (s == null) return false; - return s.isEmpty() ? true : Boolean.parseBoolean(s); + return s.isEmpty() || Boolean.parseBoolean(s); } /** True if ALL ( even non-idempotent ) requests can be automatic retried. */ @@ -517,7 +579,7 @@ class MultiExchange implements Cancelable { } /** Returns true if the given request can be automatically retried. */ - private static boolean canRetryRequest(HttpRequest request) { + private static boolean isHttpMethodRetriable(HttpRequest request) { if (RETRY_ALWAYS) return true; if (isIdempotentRequest(request)) @@ -534,70 +596,125 @@ class MultiExchange implements Cancelable { return interrupted.get() != null; } - private boolean retryOnFailure(Throwable t) { - if (requestCancelled()) return false; - return t instanceof ConnectionExpiredException - || (RETRY_CONNECT && (t instanceof ConnectException)); - } - - private Throwable retryCause(Throwable t) { - Throwable cause = t instanceof ConnectionExpiredException ? t.getCause() : t; - return cause == null ? t : cause; + String streamLimitState() { + return id + " attempt:" + streamLimitRetries.get(); } /** - * Takes a Throwable and returns a suitable CompletableFuture that is - * completed exceptionally, or null. + * This method determines if a failed request can be retried. The returned RetryContext + * will contain the {@linkplain RetryContext#shouldRetry() retry decision} and the + * {@linkplain RetryContext#requestFailureCause() underlying + * cause} (computed out of the given {@code requestFailureCause}) of the request failure. + * + * @param requestFailureCause the exception that caused the request to fail + * @param exchg the Exchange + * @return a non-null RetryContext which contains the result of retry eligibility */ - private CompletableFuture getExceptionalCF(Throwable t, ExchangeImpl exchImpl) { - if ((t instanceof CompletionException) || (t instanceof ExecutionException)) { - if (t.getCause() != null) { - t = t.getCause(); + private RetryContext checkRetryEligible(final Throwable requestFailureCause, + final Exchange exchg) { + assert requestFailureCause != null : "request failure cause is missing"; + assert exchg != null : "exchange cannot be null"; + // determine the underlying cause for the request failure + final Throwable t = Utils.getCompletionCause(requestFailureCause); + final Throwable underlyingCause = switch (t) { + case IOException ioe -> { + if (cancelled && !requestCancelled() && !(ioe instanceof HttpTimeoutException)) { + yield toTimeoutException(ioe); + } + yield ioe; } + default -> { + yield t; + } + }; + if (requestCancelled()) { + // request has been cancelled, do not retry + return RetryContext.doNotRetry(underlyingCause); } - final boolean retryAsUnprocessed = exchImpl != null && exchImpl.isUnprocessedByPeer(); - if (cancelled && !requestCancelled() && t instanceof IOException) { - if (!(t instanceof HttpTimeoutException)) { - t = toTimeoutException((IOException)t); + // check if retry limited is reached. if yes then don't retry. + record Limit(int numAttempts, int maxLimit) { + boolean retryLimitReached() { + return Limit.this.numAttempts >= Limit.this.maxLimit; } - } else if (retryAsUnprocessed || retryOnFailure(t)) { - Throwable cause = retryCause(t); - - if (!(t instanceof ConnectException)) { - // we may need to start a new connection, and if so - // we want to start with a fresh connect timeout again. - if (connectTimeout != null) connectTimeout.reset(); - if (!retryAsUnprocessed && !canRetryRequest(currentreq)) { - // a (peer) processed request which cannot be retried, fail with - // the original cause - return failedFuture(cause); - } - } // ConnectException: retry, but don't reset the connectTimeout. - - // allow the retry mechanism to do its work - retryCause = cause; - if (!retriedOnce) { - if (debug.on()) { - debug.log(t.getClass().getSimpleName() - + " (async): retrying " + currentreq + " due to: ", t); - } - retriedOnce = true; - // The connection was abruptly closed. - // We return null to retry the same request a second time. - // The request filters have already been applied to the - // currentreq, so we set previousreq = currentreq to - // prevent them from being applied again. - previousreq = currentreq; - return null; - } else { - if (debug.on()) { - debug.log(t.getClass().getSimpleName() - + " (async): already retried once " + currentreq, t); - } - t = cause; + }; + final Limit limit = switch (underlyingCause) { + case StreamLimitException _ -> { + yield new Limit(streamLimitRetries.get(), max_stream_limit_attempts); } + case ConnectException _ -> { + // for ConnectException (i.e. inability to establish a connection to the server) + // we currently retry the request only once and don't honour the + // "jdk.httpclient.redirects.retrylimit" configuration value. + yield new Limit(attempts.get(), 2); + } + default -> { + yield new Limit(attempts.get(), max_attempts); + } + }; + if (limit.retryLimitReached()) { + if (debug.on()) { + debug.log("request already attempted " + + limit.numAttempts + " times, won't be retried again " + + currentreq + " " + id, underlyingCause); + } + final var x = underlyingCause instanceof ConnectionExpiredException cee + ? cee.getCause() == null ? cee : cee.getCause() + : underlyingCause; + // do not retry anymore + return RetryContext.doNotRetry(x); } - return failedFuture(t); + return switch (underlyingCause) { + case ConnectException _ -> { + // connection attempt itself failed, so the request hasn't reached the server. + // check if retry on connection failure is enabled, if not then we don't retry + // the request. + if (!RETRY_CONNECT) { + // do not retry + yield RetryContext.doNotRetry(underlyingCause); + } + // OK to retry. Since the failure is due to a connection/stream being unavailable + // we mark the retry context to not allow the connect timer to be reset + // when the retry is actually attempted. + yield new RetryContext(underlyingCause, true, attempts, false); + } + case StreamLimitException sle -> { + // make a note that the stream limit was reached for a particular HTTP version + exchg.streamLimitReached(true); + // OK to retry. Since the failure is due to a connection/stream being unavailable + // we mark the retry context to not allow the connect timer to be reset + // when the retry is actually attempted. + yield new RetryContext(underlyingCause, true, streamLimitRetries, false); + } + case ConnectionExpiredException cee -> { + final Throwable cause = cee.getCause() == null ? cee : cee.getCause(); + // check if the request was explicitly marked as unprocessed, in which case + // we retry + if (exchg.isUnprocessedByPeer()) { + // OK to retry and allow for the connect timer to be reset + yield new RetryContext(cause, true, attempts, true); + } + // the request which failed hasn't been marked as unprocessed which implies that + // it could be processed by the server. check if the request's METHOD allows + // for retry. + if (!isHttpMethodRetriable(currentreq)) { + // request METHOD doesn't allow for retry + yield RetryContext.doNotRetry(cause); + } + // OK to retry and allow for the connect timer to be reset + yield new RetryContext(cause, true, attempts, true); + } + default -> { + // some other exception that caused the request to fail. + // we check if the request has been explicitly marked as "unprocessed", + // which implies the server hasn't processed the request and is thus OK to retry. + if (exchg.isUnprocessedByPeer()) { + // OK to retry and allow for resetting the connect timer + yield new RetryContext(underlyingCause, true, attempts, false); + } + // some other cause of failure, do not retry. + yield RetryContext.doNotRetry(underlyingCause); + } + }; } private HttpTimeoutException toTimeoutException(IOException ioe) { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Origin.java b/src/java.net.http/share/classes/jdk/internal/net/http/Origin.java index adbee565297..8aee8ef2230 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Origin.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Origin.java @@ -26,6 +26,7 @@ package jdk.internal.net.http; import java.net.URI; +import java.net.URISyntaxException; import java.util.Locale; import java.util.Objects; @@ -132,6 +133,46 @@ public record Origin(String scheme, String host, int port) { return host + ":" + port; } + /** + * {@return true if the Origin's scheme is considered secure, else returns false} + */ + boolean isSecure() { + // we consider https to be the only secure scheme + return scheme.equals("https"); + } + + /** + * {@return Creates and returns an Origin parsed from the ASCII serialized form as defined + * in section 6.2 of RFC-6454} + * + * @param value The value to be parsed + */ + static Origin fromASCIISerializedForm(final String value) throws IllegalArgumentException { + Objects.requireNonNull(value); + try { + final URI uri = new URI(value); + // the ASCII-serialized form contains scheme://host, optionally followed by :port + if (uri.getScheme() == null || uri.getHost() == null) { + throw new IllegalArgumentException("Invalid ASCII serialized form of origin"); + } + // normalize the origin string, check if we get the same result + String normalized = uri.getScheme() + "://" + uri.getHost(); + if (uri.getPort() != -1) { + normalized += ":" + uri.getPort(); + } + if (!value.equals(normalized)) { + throw new IllegalArgumentException("Invalid ASCII serialized form of origin"); + } + try { + return Origin.from(uri); + } catch (IllegalArgumentException iae) { + throw new IllegalArgumentException("Invalid ASCII serialized form of origin", iae); + } + } catch (URISyntaxException use) { + throw new IllegalArgumentException("Invalid ASCII serialized form of origin", use); + } + } + private static boolean isValidScheme(final String scheme) { // only "http" and "https" literals allowed return "http".equals(scheme) || "https".equals(scheme); diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java b/src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java index d0d64312f1f..e705aae72a1 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java @@ -266,24 +266,8 @@ class PlainHttpConnection extends HttpConnection { try { this.chan = SocketChannel.open(); chan.configureBlocking(false); - if (debug.on()) { - int bufsize = getSoReceiveBufferSize(); - debug.log("Initial receive buffer size is: %d", bufsize); - bufsize = getSoSendBufferSize(); - debug.log("Initial send buffer size is: %d", bufsize); - } - if (trySetReceiveBufferSize(client.getReceiveBufferSize())) { - if (debug.on()) { - int bufsize = getSoReceiveBufferSize(); - debug.log("Receive buffer size configured: %d", bufsize); - } - } - if (trySetSendBufferSize(client.getSendBufferSize())) { - if (debug.on()) { - int bufsize = getSoSendBufferSize(); - debug.log("Send buffer size configured: %d", bufsize); - } - } + Utils.configureChannelBuffers(debug::log, chan, + client.getReceiveBufferSize(), client.getSendBufferSize()); chan.setOption(StandardSocketOptions.TCP_NODELAY, true); // wrap the channel in a Tube for async reading and writing tube = new SocketTube(client(), chan, Utils::getBuffer, label); @@ -292,54 +276,6 @@ class PlainHttpConnection extends HttpConnection { } } - private int getSoReceiveBufferSize() { - try { - return chan.getOption(StandardSocketOptions.SO_RCVBUF); - } catch (IOException x) { - if (debug.on()) - debug.log("Failed to get initial receive buffer size on %s", chan); - } - return 0; - } - - private int getSoSendBufferSize() { - try { - return chan.getOption(StandardSocketOptions.SO_SNDBUF); - } catch (IOException x) { - if (debug.on()) - debug.log("Failed to get initial receive buffer size on %s", chan); - } - return 0; - } - - private boolean trySetReceiveBufferSize(int bufsize) { - try { - if (bufsize > 0) { - chan.setOption(StandardSocketOptions.SO_RCVBUF, bufsize); - return true; - } - } catch (IOException x) { - if (debug.on()) - debug.log("Failed to set receive buffer size to %d on %s", - bufsize, chan); - } - return false; - } - - private boolean trySetSendBufferSize(int bufsize) { - try { - if (bufsize > 0) { - chan.setOption(StandardSocketOptions.SO_SNDBUF, bufsize); - return true; - } - } catch (IOException x) { - if (debug.on()) - debug.log("Failed to set send buffer size to %d on %s", - bufsize, chan); - } - return false; - } - @Override HttpPublisher publisher() { return writePublisher; } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java b/src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java index f2c7a61c9b6..6bf03f195a7 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java @@ -25,6 +25,7 @@ package jdk.internal.net.http; +import java.net.http.HttpResponse.PushPromiseHandler.PushId; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.net.http.HttpRequest; @@ -105,9 +106,21 @@ class PushGroup { } Acceptor acceptPushRequest(HttpRequest pushRequest) { + return doAcceptPushRequest(pushRequest, null); + } + + Acceptor acceptPushRequest(HttpRequest pushRequest, PushId pushId) { + return doAcceptPushRequest(pushRequest, Objects.requireNonNull(pushId)); + } + + private Acceptor doAcceptPushRequest(HttpRequest pushRequest, PushId pushId) { AcceptorImpl acceptor = new AcceptorImpl<>(executor); try { - pushPromiseHandler.applyPushPromise(initiatingRequest, pushRequest, acceptor::accept); + if (pushId == null) { + pushPromiseHandler.applyPushPromise(initiatingRequest, pushRequest, acceptor::accept); + } else { + pushPromiseHandler.applyPushPromise(initiatingRequest, pushRequest, pushId, acceptor::accept); + } } catch (Throwable t) { if (acceptor.accepted()) { CompletableFuture cf = acceptor.cf(); @@ -128,6 +141,10 @@ class PushGroup { } } + void acceptPushPromiseId(PushId pushId) { + pushPromiseHandler.notifyAdditionalPromise(initiatingRequest, pushId); + } + void pushCompleted() { stateLock.lock(); try { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Response.java b/src/java.net.http/share/classes/jdk/internal/net/http/Response.java index 949024453ca..e416773ee61 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Response.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Response.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -25,13 +25,14 @@ package jdk.internal.net.http; -import java.net.URI; import java.io.IOException; +import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpHeaders; import java.net.InetSocketAddress; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLSession; + import jdk.internal.net.http.common.Utils; /** @@ -71,17 +72,14 @@ class Response { this.statusCode = statusCode; this.isConnectResponse = isConnectResponse; if (connection != null) { - InetSocketAddress a; - try { - a = (InetSocketAddress)connection.channel().getLocalAddress(); - } catch (IOException e) { - a = null; - } - this.localAddress = a; - if (connection instanceof AbstractAsyncSSLConnection) { - AbstractAsyncSSLConnection cc = (AbstractAsyncSSLConnection)connection; + this.localAddress = revealedLocalSocketAddress(connection); + if (connection instanceof AbstractAsyncSSLConnection cc) { SSLEngine engine = cc.getEngine(); sslSession = Utils.immutableSession(engine.getSession()); + } else if (connection instanceof HttpQuicConnection qc) { + // TODO: consider adding Optional getSession() to HttpConnection? + var session = qc.quicConnection().getTLSEngine().getSession(); + sslSession = Utils.immutableSession(session); } else { sslSession = null; } @@ -128,4 +126,12 @@ class Response { sb.append(" Local port: ").append(localAddress.getPort()); return sb.toString(); } + + private static InetSocketAddress revealedLocalSocketAddress(HttpConnection connection) { + try { + return (InetSocketAddress) connection.channel().getLocalAddress(); + } catch (IOException io) { + return null; + } + } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/ResponseSubscribers.java b/src/java.net.http/share/classes/jdk/internal/net/http/ResponseSubscribers.java index 04d019e4c81..071c68720ac 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/ResponseSubscribers.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/ResponseSubscribers.java @@ -526,7 +526,7 @@ public class ResponseSubscribers { @Override public void onError(Throwable thrwbl) { if (debug.on()) - debug.log("onError called: " + thrwbl); + debug.log("onError called", thrwbl); subscription = null; failed = Objects.requireNonNull(thrwbl); // The client process that reads the input stream might @@ -1086,6 +1086,16 @@ public class ResponseSubscribers { bs.getBody().whenComplete((r, t) -> { if (t != null) { cf.completeExceptionally(t); + // if a user-provided BodySubscriber returns + // a getBody() CF completed exceptionally, it's + // the responsibility of that BodySubscriber to cancel + // its subscription in order to cancel the request, + // if operations are still in progress. + // Calling the errorHandler here would ensure that the + // request gets cancelled, but there me cases where that is + // not what the caller wants. Therefore, it's better to + // not call `errorHandler.accept(t);` here, but leave it + // to the provided BodySubscriber implementation. } else { cf.complete(r); } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java index b5dada882b2..4b59a777de2 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java @@ -59,6 +59,8 @@ import jdk.internal.net.http.common.*; import jdk.internal.net.http.frame.*; import jdk.internal.net.http.hpack.DecodingCallback; +import static jdk.internal.net.http.AltSvcProcessor.processAltSvcFrame; + import static jdk.internal.net.http.Exchange.MAX_NON_FINAL_RESPONSES; /** @@ -96,8 +98,8 @@ import static jdk.internal.net.http.Exchange.MAX_NON_FINAL_RESPONSES; * placed on the stream's inputQ which is consumed by the stream's * reader thread. * - * PushedStream sub class - * ====================== + * PushedStream subclass + * ===================== * Sending side methods are not used because the request comes from a PUSH_PROMISE * frame sent by the server. When a PUSH_PROMISE is received the PushedStream * is created. PushedStream does not use responseCF list as there can be only @@ -151,7 +153,7 @@ class Stream extends ExchangeImpl { // Indicates the first reason that was invoked when sending a ResetFrame // to the server. A streamState of 0 indicates that no reset was sent. // (see markStream(int code) - private volatile int streamState; // assigned using STREAM_STATE varhandle. + private volatile int streamState; // assigned while holding the sendLock. private volatile boolean deRegistered; // assigned using DEREGISTERED varhandle. // state flags @@ -219,7 +221,7 @@ class Stream extends ExchangeImpl { List buffers = df.getData(); List dsts = Collections.unmodifiableList(buffers); - int size = Utils.remaining(dsts, Integer.MAX_VALUE); + long size = Utils.remaining(dsts, Long.MAX_VALUE); if (size == 0 && finished) { inputQ.remove(); // consumed will not be called @@ -478,7 +480,9 @@ class Stream extends ExchangeImpl { if (code == 0) return streamState; sendLock.lock(); try { - return (int) STREAM_STATE.compareAndExchange(this, 0, code); + var state = streamState; + if (state == 0) streamState = code; + return state; } finally { sendLock.unlock(); } @@ -534,7 +538,7 @@ class Stream extends ExchangeImpl { this.requestPublisher = request.requestPublisher; // may be null this.responseHeadersBuilder = new HttpHeadersBuilder(); this.rspHeadersConsumer = new HeadersConsumer(); - this.requestPseudoHeaders = createPseudoHeaders(request); + this.requestPseudoHeaders = Utils.createPseudoHeaders(request); this.streamWindowUpdater = new StreamWindowUpdateSender(connection); } @@ -587,6 +591,7 @@ class Stream extends ExchangeImpl { case WindowUpdateFrame.TYPE -> incoming_windowUpdate((WindowUpdateFrame) frame); case ResetFrame.TYPE -> incoming_reset((ResetFrame) frame); case PriorityFrame.TYPE -> incoming_priority((PriorityFrame) frame); + case AltSvcFrame.TYPE -> handleAltSvcFrame(streamid, (AltSvcFrame) frame); default -> throw new IOException("Unexpected frame: " + frame); } @@ -745,6 +750,10 @@ class Stream extends ExchangeImpl { } } + void handleAltSvcFrame(int streamid, AltSvcFrame asf) { + processAltSvcFrame(streamid, asf, connection.connection, connection.client()); + } + void handleReset(ResetFrame frame, Flow.Subscriber subscriber) { Log.logTrace("Handling RST_STREAM on stream {0}", streamid); if (!closed) { @@ -763,12 +772,16 @@ class Stream extends ExchangeImpl { // A REFUSED_STREAM error code implies that the stream wasn't processed by the // peer and the client is free to retry the request afresh. if (error == ErrorFrame.REFUSED_STREAM) { + // null exchange implies a PUSH stream and those aren't + // initiated by the client, so we don't expect them to be + // considered unprocessed. + assert this.exchange != null : "PUSH streams aren't expected to be marked as unprocessed"; // Here we arrange for the request to be retried. Note that we don't call // closeAsUnprocessed() method here because the "closed" state is already set // to true a few lines above and calling close() from within // closeAsUnprocessed() will end up being a no-op. We instead do the additional // bookkeeping here. - markUnprocessedByPeer(); + this.exchange.markUnprocessedByPeer(); errorRef.compareAndSet(null, new IOException("request not processed by peer")); if (debug.on()) { debug.log("request unprocessed by peer (REFUSED_STREAM) " + this.request); @@ -1216,6 +1229,7 @@ class Stream extends ExchangeImpl { assert !endStreamSent : "internal error, send data after END_STREAM flag"; } if ((state = streamState) != 0) { + t = errorRef.get(); if (debug.on()) debug.log("trySend: cancelled: %s", String.valueOf(t)); break; } @@ -1521,7 +1535,7 @@ class Stream extends ExchangeImpl { } else cancelImpl(cause); } - // This method sends a RST_STREAM frame + // This method sends an RST_STREAM frame void cancelImpl(Throwable e) { cancelImpl(e, ResetFrame.CANCEL); } @@ -1856,8 +1870,12 @@ class Stream extends ExchangeImpl { */ void closeAsUnprocessed() { try { + // null exchange implies a PUSH stream and those aren't + // initiated by the client, so we don't expect them to be + // considered unprocessed. + assert this.exchange != null : "PUSH streams aren't expected to be closed as unprocessed"; // We arrange for the request to be retried on a new connection as allowed by the RFC-9113 - markUnprocessedByPeer(); + this.exchange.markUnprocessedByPeer(); this.errorRef.compareAndSet(null, new IOException("request not processed by peer")); if (debug.on()) { debug.log("closing " + this.request + " as unprocessed by peer"); @@ -1905,7 +1923,7 @@ class Stream extends ExchangeImpl { streamid, n, v); } } catch (UncheckedIOException uio) { - // reset stream: From RFC 9113, section 8.1 + // reset stream: From RFC 7540, section-8.1.2.6 // Malformed requests or responses that are detected MUST be // treated as a stream error (Section 5.4.2) of type // PROTOCOL_ERROR. @@ -1953,13 +1971,10 @@ class Stream extends ExchangeImpl { } - private static final VarHandle STREAM_STATE; private static final VarHandle DEREGISTERED; static { try { MethodHandles.Lookup lookup = MethodHandles.lookup(); - STREAM_STATE = lookup - .findVarHandle(Stream.class, "streamState", int.class); DEREGISTERED = lookup .findVarHandle(Stream.class, "deRegistered", boolean.class); } catch (Exception x) { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/Alpns.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/Alpns.java index 66397d93410..5888e6c2de3 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Alpns.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Alpns.java @@ -34,4 +34,9 @@ public final class Alpns { public static final String HTTP_1_1 = "http/1.1"; public static final String H2 = "h2"; public static final String H2C = "h2c"; + public static final String H3 = "h3"; + + public static boolean isSecureALPNName(final String alpnName) { + return H3.equals(alpnName) || H2.equals(alpnName); + } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/ConnectionExpiredException.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/ConnectionExpiredException.java index 07091886533..164e1e5685d 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/ConnectionExpiredException.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/ConnectionExpiredException.java @@ -29,7 +29,9 @@ import java.io.IOException; /** * Signals that an end of file or end of stream has been reached - * unexpectedly before any protocol specific data has been received. + * unexpectedly before any protocol specific data has been received, + * or that a new stream creation was rejected because the underlying + * connection was closed. */ public final class ConnectionExpiredException extends IOException { private static final long serialVersionUID = 0; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/Deadline.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/Deadline.java index bc9a992b3bd..3ee334885a3 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Deadline.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Deadline.java @@ -88,6 +88,24 @@ public final class Deadline implements Comparable { return of(deadline.truncatedTo(unit)); } + /** + * Returns a copy of this deadline with the specified amount subtracted. + *

+ * This returns a {@code Deadline}, based on this one, with the specified amount subtracted. + * The amount is typically {@link Duration} but may be any other type implementing + * the {@link TemporalAmount} interface. + *

+ * This instance is immutable and unaffected by this method call. + * + * @param amountToSubtract the amount to subtract, not null + * @return a {@code Deadline} based on this deadline with the subtraction made, not null + * @throws DateTimeException if the subtraction cannot be made + * @throws ArithmeticException if numeric overflow occurs + */ + public Deadline minus(TemporalAmount amountToSubtract) { + return Deadline.of(deadline.minus(amountToSubtract)); + } + /** * Returns a copy of this deadline with the specified amount added. *

@@ -126,6 +144,21 @@ public final class Deadline implements Comparable { return Deadline.of(deadline.plusSeconds(secondsToAdd)); } + /** + * Returns a copy of this deadline with the specified duration in milliseconds added. + *

+ * This instance is immutable and unaffected by this method call. + * + * @param millisToAdd the milliseconds to add, positive or negative + * @return a {@code Deadline} based on this deadline with the specified milliseconds added, not null + * @throws DateTimeException if the result exceeds the maximum or minimum deadline + * @throws ArithmeticException if numeric overflow occurs + */ + public Deadline plusMillis(long millisToAdd) { + if (millisToAdd == 0) return this; + return Deadline.of(deadline.plusMillis(millisToAdd)); + } + /** * Returns a copy of this deadline with the specified amount added. *

@@ -183,7 +216,7 @@ public final class Deadline implements Comparable { /** * Checks if this deadline is before the specified deadline. *

- * The comparison is based on the time-line position of the deadines. + * The comparison is based on the time-line position of the deadlines. * * @param otherDeadline the other deadine to compare to, not null * @return true if this deadline is before the specified deadine @@ -217,7 +250,26 @@ public final class Deadline implements Comparable { return deadline.hashCode(); } + Instant asInstant() { + return deadline; + } + static Deadline of(Instant instant) { return new Deadline(instant); } + + /** + * Obtains a {@code Duration} representing the duration between two deadlines. + *

+ * The result of this method can be a negative period if the end is before the start. + * + * @param startInclusive the start deadline, inclusive, not null + * @param endExclusive the end deadline, exclusive, not null + * @return a {@code Duration}, not null + * @throws DateTimeException if the seconds between the deadline cannot be obtained + * @throws ArithmeticException if the calculation exceeds the capacity of {@code Duration} + */ + public static Duration between(Deadline startInclusive, Deadline endExclusive) { + return Duration.between(startInclusive.deadline, endExclusive.deadline); + } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpBodySubscriberWrapper.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpBodySubscriberWrapper.java index 6dc79760b0a..1c483ce99f4 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpBodySubscriberWrapper.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpBodySubscriberWrapper.java @@ -284,6 +284,7 @@ public class HttpBodySubscriberWrapper implements TrustedSubscriber { */ public final void complete(Throwable t) { if (markCompleted()) { + logComplete(t); tryUnregister(); t = withError = Utils.getCompletionCause(t); if (t == null) { @@ -312,6 +313,10 @@ public class HttpBodySubscriberWrapper implements TrustedSubscriber { } } + protected void logComplete(Throwable error) { + + } + /** * {@return true if this subscriber has already completed, either normally * or abnormally} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersBuilder.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersBuilder.java index 409a8540b68..7c1d2311ba9 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersBuilder.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -41,6 +41,15 @@ public class HttpHeadersBuilder { headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); } + // used in test library (Http3ServerExchange) + public HttpHeadersBuilder(HttpHeaders headers) { + headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + for (Map.Entry> entry : headers.map().entrySet()) { + List valuesCopy = new ArrayList<>(entry.getValue()); + headersMap.put(entry.getKey(), valuesCopy); + } + } + public HttpHeadersBuilder structuralCopy() { HttpHeadersBuilder builder = new HttpHeadersBuilder(); for (Map.Entry> entry : headersMap.entrySet()) { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/Log.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/Log.java index 48f5a2b06d8..bc89a6e9d8e 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Log.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Log.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,14 +28,28 @@ package jdk.internal.net.http.common; import java.net.http.HttpHeaders; import java.util.ArrayList; +import java.util.Collection; +import java.util.EnumSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.function.Supplier; +import java.util.stream.Stream; + import jdk.internal.net.http.frame.DataFrame; import jdk.internal.net.http.frame.Http2Frame; import jdk.internal.net.http.frame.WindowUpdateFrame; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.HandshakeDoneFrame; +import jdk.internal.net.http.quic.frames.PaddingFrame; +import jdk.internal.net.http.quic.frames.PingFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; +import jdk.internal.net.http.quic.packets.PacketSpace; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLParameters; @@ -43,7 +57,8 @@ import javax.net.ssl.SSLParameters; /** * -Djdk.httpclient.HttpClient.log= * errors,requests,headers, - * frames[:control:data:window:all..],content,ssl,trace,channel + * frames[:control:data:window:all..],content,ssl,trace,channel, + * quic[:control:processed:retransmit:ack:crypto:data:cc:hs:dbb:ping:all] * * Any of errors, requests, headers or content are optional. * @@ -57,15 +72,17 @@ public abstract class Log implements System.Logger { static final String logProp = "jdk.httpclient.HttpClient.log"; - public static final int OFF = 0; - public static final int ERRORS = 0x1; - public static final int REQUESTS = 0x2; - public static final int HEADERS = 0x4; - public static final int CONTENT = 0x8; - public static final int FRAMES = 0x10; - public static final int SSL = 0x20; - public static final int TRACE = 0x40; - public static final int CHANNEL = 0x80; + public static final int OFF = 0x00; + public static final int ERRORS = 0x01; + public static final int REQUESTS = 0x02; + public static final int HEADERS = 0x04; + public static final int CONTENT = 0x08; + public static final int FRAMES = 0x10; + public static final int SSL = 0x20; + public static final int TRACE = 0x40; + public static final int CHANNEL = 0x80; + public static final int QUIC = 0x0100; + public static final int HTTP3 = 0x0200; static int logging; // Frame types: "control", "data", "window", "all" @@ -75,6 +92,27 @@ public abstract class Log implements System.Logger { public static final int ALL = CONTROL| DATA | WINDOW_UPDATES; static int frametypes; + // Quic message types + public static final int QUIC_CONTROL = 1; + public static final int QUIC_PROCESSED = 2; + public static final int QUIC_RETRANSMIT = 4; + public static final int QUIC_DATA = 8; + public static final int QUIC_CRYPTO = 16; + public static final int QUIC_ACK = 32; + public static final int QUIC_PING = 64; + public static final int QUIC_CC = 128; + public static final int QUIC_TIMER = 256; + public static final int QUIC_DIRECT_BUFFER_POOL = 512; + public static final int QUIC_HANDSHAKE = 1024; + public static final int QUIC_ALL = QUIC_CONTROL + | QUIC_PROCESSED | QUIC_RETRANSMIT + | QUIC_DATA | QUIC_CRYPTO + | QUIC_ACK | QUIC_PING | QUIC_CC + | QUIC_TIMER | QUIC_DIRECT_BUFFER_POOL + | QUIC_HANDSHAKE; + static int quictypes; + + static final System.Logger logger; static { @@ -94,6 +132,12 @@ public abstract class Log implements System.Logger { case "headers": logging |= HEADERS; break; + case "quic": + logging |= QUIC; + break; + case "http3": + logging |= HTTP3; + break; case "content": logging |= CONTENT; break; @@ -107,13 +151,14 @@ public abstract class Log implements System.Logger { logging |= TRACE; break; case "all": - logging |= CONTENT|HEADERS|REQUESTS|FRAMES|ERRORS|TRACE|SSL| CHANNEL; + logging |= CONTENT | HEADERS | REQUESTS | FRAMES | ERRORS | TRACE | SSL | CHANNEL | QUIC | HTTP3; frametypes |= ALL; + quictypes |= QUIC_ALL; break; default: // ignore bad values } - if (val.startsWith("frames")) { + if (val.startsWith("frames:") || val.equals("frames")) { logging |= FRAMES; String[] types = val.split(":"); if (types.length == 1) { @@ -139,6 +184,56 @@ public abstract class Log implements System.Logger { } } } + if (val.startsWith("quic:") || val.equals("quic")) { + logging |= QUIC; + String[] types = val.split(":"); + if (types.length == 1) { + quictypes = QUIC_ALL & ~QUIC_TIMER & ~QUIC_DIRECT_BUFFER_POOL; + } else { + for (String type : types) { + switch (type.toLowerCase(Locale.US)) { + case "control": + quictypes |= QUIC_CONTROL; + break; + case "data": + quictypes |= QUIC_DATA; + break; + case "processed": + quictypes |= QUIC_PROCESSED; + break; + case "retransmit": + quictypes |= QUIC_RETRANSMIT; + break; + case "crypto": + quictypes |= QUIC_CRYPTO; + break; + case "cc": + quictypes |= QUIC_CC; + break; + case "hs": + quictypes |= QUIC_HANDSHAKE; + break; + case "ack": + quictypes |= QUIC_ACK; + break; + case "ping": + quictypes |= QUIC_PING; + break; + case "timer": + quictypes |= QUIC_TIMER; + break; + case "dbb": + quictypes |= QUIC_DIRECT_BUFFER_POOL; + break; + case "all": + quictypes = QUIC_ALL; + break; + default: + // ignore bad values + } + } + } + } } } if (logging != OFF) { @@ -175,6 +270,119 @@ public abstract class Log implements System.Logger { return (logging & CHANNEL) != 0; } + public static boolean altsvc() { return headers(); } + + public static boolean quicRetransmit() { + return (logging & QUIC) != 0 && (quictypes & QUIC_RETRANSMIT) != 0; + } + + // not called directly - but impacts isLogging(QuicFrame) + public static boolean quicHandshake() { + return (logging & QUIC) != 0 && (quictypes & QUIC_HANDSHAKE) != 0; + } + + public static boolean quicProcessed() { + return (logging & QUIC) != 0 && (quictypes & QUIC_PROCESSED) != 0; + } + + // not called directly - but impacts isLogging(QuicFrame) + public static boolean quicData() { + return (logging & QUIC) != 0 && (quictypes & QUIC_DATA) != 0; + } + + public static boolean quicCrypto() { + return (logging & QUIC) != 0 && (quictypes & QUIC_CRYPTO) != 0; + } + + public static boolean quicCC() { + return (logging & QUIC) != 0 && (quictypes & QUIC_CC) != 0; + } + + public static boolean quicControl() { + return (logging & QUIC) != 0 && (quictypes & QUIC_CONTROL) != 0; + } + + public static boolean quicTimer() { + return (logging & QUIC) != 0 && (quictypes & QUIC_TIMER) != 0; + } + public static boolean quicDBB() { + return (logging & QUIC) != 0 && (quictypes & QUIC_DIRECT_BUFFER_POOL) != 0; + } + + public static boolean quic() { + return (logging & QUIC) != 0; + } + + public static boolean http3() { + return (logging & HTTP3) != 0; + } + + public static void logHttp3(String s, Object... s1) { + if (http3()) { + logger.log(Level.INFO, "HTTP3: " + s, s1); + } + } + + private static boolean isLogging(QuicFrame frame) { + if (frame instanceof StreamFrame sf) + return (quictypes & QUIC_DATA) != 0 + || (quictypes & QUIC_CONTROL) != 0 && sf.isLast() + || (quictypes & QUIC_CONTROL) != 0 && sf.offset() == 0; + if (frame instanceof AckFrame) + return (quictypes & QUIC_ACK) != 0; + if (frame instanceof CryptoFrame) + return (quictypes & QUIC_CRYPTO) != 0 + || (quictypes & QUIC_HANDSHAKE) != 0; + if (frame instanceof PingFrame) + return (quictypes & QUIC_PING) != 0; + if (frame instanceof PaddingFrame) return false; + if (frame instanceof HandshakeDoneFrame && quicHandshake()) + return true; + return (quictypes & QUIC_CONTROL) != 0; + } + + private static final EnumSet HS_TYPES = EnumSet.complementOf( + EnumSet.of(PacketType.ONERTT)); + + private static boolean quicPacketLoggable(QuicPacket packet) { + return (logging & QUIC) != 0 + && (quictypes == QUIC_ALL + || quicHandshake() && HS_TYPES.contains(packet.packetType()) + || stream(packet.frames()).anyMatch(Log::isLogging)); + } + + public static boolean quicPacketOutLoggable(QuicPacket packet) { + return quicPacketLoggable(packet); + } + + private static Stream stream(Collection list) { + return list == null ? Stream.empty() : list.stream(); + } + + public static boolean quicPacketInLoggable(QuicPacket packet) { + return quicPacketLoggable(packet); + } + + public static void logQuic(String s, Object... s1) { + if (quic()) { + logger.log(Level.INFO, "QUIC: " + s, s1); + } + } + + public static void logQuicPacketOut(String connectionTag, QuicPacket packet) { + if (quicPacketOutLoggable(packet)) { + logger.log(Level.INFO, "QUIC: {0} OUT: {1}", + connectionTag, packet.prettyPrint()); + } + } + + public static void logQuicPacketIn(String connectionTag, QuicPacket packet) { + if (quicPacketInLoggable(packet)) { + logger.log(Level.INFO, "QUIC: {0} IN: {1}", + connectionTag, packet.prettyPrint()); + } + } + public static void logError(String s, Object... s1) { if (errors()) { logger.log(Level.INFO, "ERROR: " + s, s1); @@ -237,6 +445,12 @@ public abstract class Log implements System.Logger { } } + public static void logAltSvc(String s, Object... s1) { + if (altsvc()) { + logger.log(Level.INFO, "ALTSVC: " + s, s1); + } + } + public static boolean loggingFrame(Class clazz) { if (frametypes == ALL) { return true; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/OperationTrackers.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/OperationTrackers.java index 3aec13b59ec..ef031eef999 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/OperationTrackers.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/OperationTrackers.java @@ -55,6 +55,8 @@ public final class OperationTrackers { long getOutstandingHttpOperations(); // The number of active HTTP/2 streams long getOutstandingHttp2Streams(); + // The number of active HTTP/3 streams + long getOutstandingHttp3Streams(); // The number of active WebSockets long getOutstandingWebSocketOperations(); // number of TCP connections still opened diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/TimeSource.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/TimeSource.java index 489fbe7ffd8..c74c67f7d58 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/TimeSource.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/TimeSource.java @@ -25,7 +25,6 @@ package jdk.internal.net.http.common; import java.time.Instant; -import java.time.InstantSource; /** * A {@link TimeLine} based on {@link System#nanoTime()} for the @@ -52,7 +51,7 @@ public final class TimeSource implements TimeLine { // The use of Integer.MAX_VALUE is arbitrary. // Any value not too close to Long.MAX_VALUE // would do. - static final int TIME_WINDOW = Integer.MAX_VALUE; + static final long TIME_WINDOW = Integer.MAX_VALUE; final Instant first; final long firstNanos; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java index 2916a41e62a..b14d76d8dba 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java @@ -37,33 +37,50 @@ import java.io.PrintStream; import java.io.UncheckedIOException; import java.lang.System.Logger.Level; import java.net.ConnectException; +import java.net.Inet6Address; import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.StandardSocketOptions; import java.net.Proxy; import java.net.URI; import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; import java.net.http.HttpTimeoutException; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.channels.CancelledKeyException; +import java.nio.channels.NetworkChannel; import java.nio.channels.SelectionKey; import java.nio.charset.CharacterCodingException; import java.nio.charset.Charset; import java.nio.charset.CodingErrorAction; import java.nio.charset.StandardCharsets; import java.text.Normalizer; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HexFormat; +import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.TreeSet; import java.util.concurrent.CancellationException; +import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; import java.util.function.BiPredicate; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; @@ -152,6 +169,10 @@ public final class Utils { return prop.isEmpty() ? true : Boolean.parseBoolean(prop); } + // A threshold to decide whether to slice or copy. + // see sliceOrCopy + public static final int SLICE_THRESHOLD = 32; + /** * Allocated buffer size. Must never be higher than 16K. But can be lower * if smaller allocation units preferred. HTTP/2 mandates that all @@ -169,7 +190,8 @@ public final class Utils { private static Set getDisallowedHeaders() { Set headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); - headers.addAll(Set.of("connection", "content-length", "expect", "host", "upgrade")); + headers.addAll(Set.of("connection", "content-length", "expect", "host", "upgrade", + "alt-used")); String v = getNetProperty("jdk.httpclient.allowRestrictedHeaders"); if (v != null) { @@ -215,6 +237,56 @@ public final class Utils { return true; }; + public static T addSuppressed(T x, Throwable suppressed) { + if (x != suppressed && suppressed != null) { + var sup = x.getSuppressed(); + if (sup != null && sup.length > 0) { + if (Arrays.asList(sup).contains(suppressed)) { + return x; + } + } + sup = suppressed.getSuppressed(); + if (sup != null && sup.length > 0) { + if (Arrays.asList(sup).contains(x)) { + return x; + } + } + x.addSuppressed(suppressed); + } + return x; + } + + /** + * {@return a string comparing the given deadline with now, typically + * something like "due since Nms" or "due in Nms"} + * + * @apiNote + * This method recognize deadlines set to Instant.MIN + * and Instant.MAX as special cases meaning "due" and + * "not scheduled". + * + * @param now now + * @param deadline the deadline + */ + public static String debugDeadline(Deadline now, Deadline deadline) { + boolean isDue = deadline.compareTo(now) <= 0; + try { + if (isDue) { + if (deadline.equals(Deadline.MIN)) { + return "due (Deadline.MIN)"; + } else { + return "due since " + deadline.until(now, ChronoUnit.MILLIS) + "ms"; + } + } else if (deadline.equals(Deadline.MAX)) { + return "not scheduled (Deadline.MAX)"; + } else { + return "due in " + now.until(deadline, ChronoUnit.MILLIS) + "ms"; + } + } catch (ArithmeticException x) { + return isDue ? "due since too long" : "due in the far future"; + } + } + public record ProxyHeaders(HttpHeaders userHeaders, HttpHeaders systemHeaders) {} public static final BiPredicate PROXY_TUNNEL_RESTRICTED() { @@ -346,6 +418,7 @@ public final class Utils { } public static String interestOps(SelectionKey key) { + if (key == null) return "null-key"; try { return describeOps(key.interestOps()); } catch (CancelledKeyException x) { @@ -354,6 +427,7 @@ public final class Utils { } public static String readyOps(SelectionKey key) { + if (key == null) return "null-key"; try { return describeOps(key.readyOps()); } catch (CancelledKeyException x) { @@ -438,6 +512,21 @@ public final class Utils { return cause; } + public static IOException toIOException(Throwable cause) { + if (cause == null) return null; + if (cause instanceof CompletionException ce) { + cause = ce.getCause(); + } else if (cause instanceof ExecutionException ee) { + cause = ee.getCause(); + } + if (cause instanceof IOException io) { + return io; + } else if (cause instanceof UncheckedIOException uio) { + return uio.getCause(); + } + return new IOException(cause.getMessage(), cause); + } + public static IOException getIOException(Throwable t) { if (t instanceof IOException) { return (IOException) t; @@ -575,6 +664,10 @@ public final class Utils { return Integer.parseInt(System.getProperty(name, String.valueOf(defaultValue))); } + public static long getLongProperty(String name, long defaultValue) { + return Long.parseLong(System.getProperty(name, String.valueOf(defaultValue))); + } + public static int getIntegerNetProperty(String property, int min, int max, int defaultValue, boolean log) { int value = Utils.getIntegerNetProperty(property, defaultValue); // use default value if misconfigured @@ -755,6 +848,91 @@ public final class Utils { return remain; } + // + + /** + * Reads as much bytes as possible from the buffer list, and + * write them in the provided {@code data} byte array. + * Returns the number of bytes read and written to the byte array. + * This method advances the position in the byte buffers it reads + * @param bufs A list of byte buffer + * @param data A byte array to write into + * @param offset Where to start writing in the byte array + * @return the amount of bytes read and written to the byte array + */ + public static int read(List bufs, byte[] data, int offset) { + int pos = offset; + for (ByteBuffer buf : bufs) { + if (pos >= data.length) break; + int read = Math.min(buf.remaining(), data.length - pos); + if (read <= 0) continue; + buf.get(data, pos, read); + pos += read; + } + return pos - offset; + } + + /** + * Returns the next buffer that has remaining bytes, or null. + * @param iterator an iterator + * @return the next buffer that has remaining bytes, or null + */ + public static ByteBuffer next(Iterator iterator) { + ByteBuffer next = null; + while (iterator.hasNext() && !(next = iterator.next()).hasRemaining()); + return next == null || !next.hasRemaining() ? null : next; + } + + /** + * Compute the relative consolidated position in bytes at which the two + * input mismatch, or -1 if there is no mismatch. + * @apiNote This method behaves as {@link ByteBuffer#mismatch(ByteBuffer)}. + * @param these a first list of byte buffers + * @param those a second list of byte buffers + * @return the relative consolidated position in bytes at which the two + * input mismatch, or -1L if there is no mismatch. + */ + public static long mismatch(List these, List those) { + if (these.isEmpty()) return those.isEmpty() ? -1 : 0; + if (those.isEmpty()) return 0; + Iterator lefti = these.iterator(), righti = those.iterator(); + ByteBuffer left = next(lefti), right = next(righti); + long parsed = 0; + while (left != null || right != null) { + int m = left == null || right == null ? 0 : left.mismatch(right); + if (m == -1) { + parsed = parsed + left.remaining(); + assert right.remaining() == left.remaining(); + if ((left = next(lefti)) != null) { + if ((right = next(righti)) != null) { + continue; + } + return parsed; + } + return (right = next(righti)) != null ? parsed : -1; + } + if (m == 0) return parsed; + parsed = parsed + m; + if (m < left.remaining()) { + if (m < right.remaining()) { + return parsed; + } + if ((right = next(righti)) != null) { + left = left.slice(m, left.remaining() - m); + continue; + } + return parsed; + } + assert m < right.remaining(); + if ((left = next(lefti)) != null) { + right = right.slice(m, right.remaining() - m); + continue; + } + return parsed; + } + return -1L; + } + public static long synchronizedRemaining(List bufs) { if (bufs == null) return 0L; synchronized (bufs) { @@ -766,12 +944,13 @@ public final class Utils { if (bufs == null) return 0; long remain = 0; for (ByteBuffer buf : bufs) { - remain += buf.remaining(); - if (remain > max) { + int size = buf.remaining(); + if (max - remain < size) { throw new IllegalArgumentException("too many bytes"); } + remain += size; } - return (int) remain; + return remain; } public static int remaining(List bufs, int max) { @@ -783,12 +962,13 @@ public final class Utils { if (refs == null) return 0; long remain = 0; for (ByteBuffer b : refs) { - remain += b.remaining(); - if (remain > max) { + int size = b.remaining(); + if (max - remain < size) { throw new IllegalArgumentException("too many bytes"); } + remain += size; } - return (int) remain; + return remain; } public static int remaining(ByteBuffer[] refs, int max) { @@ -834,6 +1014,50 @@ public final class Utils { return newb; } + /** + * Creates a slice of a buffer, possibly copying the data instead + * of slicing. + * If the buffer capacity is less than the {@linkplain #SLICE_THRESHOLD + * default slice threshold}, or if the capacity minus the length to slice + * is less than the {@linkplain #SLICE_THRESHOLD threshold}, returns a slice. + * Otherwise, copy so as not to retain a reference to a big buffer + * for a small slice. + * @param src the original buffer + * @param start where to start copying/slicing from src + * @param len how many byte to slice/copy + * @return a new ByteBuffer for the given slice + */ + public static ByteBuffer sliceOrCopy(ByteBuffer src, int start, int len) { + return sliceOrCopy(src, start, len, SLICE_THRESHOLD); + } + + /** + * Creates a slice of a buffer, possibly copying the data instead + * of slicing. + * If the buffer capacity minus the length to slice is less than the threshold, + * returns a slice. + * Otherwise, copy so as not to retain a reference to a buffer + * that contains more bytes than needed. + * @param src the original buffer + * @param start where to start copying/slicing from src + * @param len how many byte to slice/copy + * @param threshold a threshold to decide whether to slice or copy + * @return a new ByteBuffer for the given slice + */ + public static ByteBuffer sliceOrCopy(ByteBuffer src, int start, int len, int threshold) { + assert src.hasArray(); + int cap = src.array().length; + if (cap - len < threshold) { + return src.slice(start, len); + } else { + byte[] b = new byte[len]; + if (len > 0) { + src.get(start, b, 0, len); + } + return ByteBuffer.wrap(b); + } + } + /** * Get the Charset from the Content-encoding header. Defaults to * UTF_8 @@ -849,7 +1073,9 @@ public final class Utils { if (value == null) return StandardCharsets.UTF_8; return Charset.forName(value); } catch (Throwable x) { - Log.logTrace("Can't find charset in \"{0}\" ({1})", type, x); + if (Log.trace()) { + Log.logTrace("Can't find charset in \"{0}\" ({1})", type, x); + } return StandardCharsets.UTF_8; } } @@ -1078,6 +1304,40 @@ public final class Utils { } } + /** + * Creates HTTP/2 HTTP/3 pseudo headers for the given request. + * @param request the request + * @return pseudo headers for that request + */ + public static HttpHeaders createPseudoHeaders(HttpRequest request) { + HttpHeadersBuilder hdrs = new HttpHeadersBuilder(); + String method = request.method(); + hdrs.setHeader(":method", method); + URI uri = request.uri(); + hdrs.setHeader(":scheme", uri.getScheme()); + String host = uri.getHost(); + int port = uri.getPort(); + assert host != null; + if (port != -1) { + hdrs.setHeader(":authority", host + ":" + port); + } else { + hdrs.setHeader(":authority", host); + } + String query = uri.getRawQuery(); + String path = uri.getRawPath(); + if (path == null || path.isEmpty()) { + if (method.equalsIgnoreCase("OPTIONS")) { + path = "*"; + } else { + path = "/"; + } + } + if (query != null) { + path += "?" + query; + } + hdrs.setHeader(":path", Utils.encode(path)); + return hdrs.build(); + } // -- toAsciiString-like support to encode path and query URI segments // Encodes all characters >= \u0080 into escaped, normalized UTF-8 octets, @@ -1121,6 +1381,302 @@ public final class Utils { return sb.toString(); } + /** + * {@return the content of the buffer as an hexadecimal string} + * This method doesn't move the buffer position or limit. + * @param buffer a byte buffer + */ + public static String asHexString(ByteBuffer buffer) { + if (!buffer.hasRemaining()) return ""; + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(buffer.position(), bytes); + return HexFormat.of().formatHex(bytes); + } + + /** + * Converts a ByteBuffer containing bytes encoded using + * the given {@linkplain Charset charset} into a + * string. This method does not throw but will replace + * unrecognized sequences with the replacement character. + * The bytes in the buffer are consumed. + * + * @apiNote + * This method is intended for debugging purposes only, + * since buffers are not guaranteed to be split at character + * boundaries. + * + * @param buffer a buffer containing bytes encoded using + * a charset + * @param charset the charset to use to decode the bytes + * into a string + * + * @return a string built from the bytes contained + * in the buffer decoded using the given charset + */ + public static String asString(ByteBuffer buffer, Charset charset) { + var decoded = charset.decode(buffer); + char[] chars = new char[decoded.length()]; + decoded.get(chars); + return new String(chars); + } + + /** + * Converts a ByteBuffer containing UTF-8 bytes into a + * string. This method does not throw but will replace + * unrecognized sequences with the replacement character. + * The bytes in the buffer are consumed. + * + * @apiNote + * This method is intended for debugging purposes only, + * since buffers are not guaranteed to be split at character + * boundaries. + * + * @param buffer a buffer containing UTF-8 bytes + * + * @return a string built from the decoded UTF-8 bytes contained + * in the buffer + */ + public static String asString(ByteBuffer buffer) { + return asString(buffer, StandardCharsets.UTF_8); + } + + public static String millis(Instant now, Instant deadline) { + if (Instant.MAX.equals(deadline)) return "not scheduled"; + try { + long delay = now.until(deadline, ChronoUnit.MILLIS); + return delay + " ms"; + } catch (ArithmeticException a) { + return "too far away"; + } + } + + public static String millis(Deadline now, Deadline deadline) { + return millis(now.asInstant(), deadline.asInstant()); + } + + public static ExecutorService safeExecutor(ExecutorService delegate, + BiConsumer errorHandler) { + Executor overflow = new CompletableFuture().defaultExecutor(); + return new SafeExecutorService(delegate, overflow, errorHandler); + } + + public static sealed class SafeExecutor implements Executor + permits SafeExecutorService { + final E delegate; + final BiConsumer errorHandler; + final Executor overflow; + + public SafeExecutor(E delegate, Executor overflow, BiConsumer errorHandler) { + this.delegate = delegate; + this.overflow = overflow; + this.errorHandler = errorHandler; + } + + @Override + public void execute(Runnable command) { + ensureExecutedAsync(command); + } + + private void ensureExecutedAsync(Runnable command) { + try { + delegate.execute(command); + } catch (RejectedExecutionException t) { + errorHandler.accept(command, t); + overflow.execute(command); + } + } + + } + + public static final class SafeExecutorService extends SafeExecutor + implements ExecutorService { + + public SafeExecutorService(ExecutorService delegate, + Executor overflow, + BiConsumer errorHandler) { + super(delegate, overflow, errorHandler); + } + + @Override + public void shutdown() { + delegate.shutdown(); + } + + @Override + public List shutdownNow() { + return delegate.shutdownNow(); + } + + @Override + public boolean isShutdown() { + return delegate.isShutdown(); + } + + @Override + public boolean isTerminated() { + return delegate.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return delegate.awaitTermination(timeout, unit); + } + + @Override + public Future submit(Callable task) { + return delegate.submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + return delegate.submit(task, result); + } + + @Override + public Future submit(Runnable task) { + return delegate.submit(task); + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + return delegate.invokeAll(tasks); + } + + @Override + public List> invokeAll(Collection> tasks, + long timeout, TimeUnit unit) + throws InterruptedException { + return delegate.invokeAll(tasks, timeout, unit); + } + + @Override + public T invokeAny(Collection> tasks) + throws InterruptedException, ExecutionException { + return delegate.invokeAny(tasks); + } + + @Override + public T invokeAny(Collection> tasks, + long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return delegate.invokeAny(tasks); + } + } + + public static T configureChannelBuffers(Consumer logSink, T chan, + int receiveBufSize, int sendBufSize) { + + if (logSink != null) { + int bufsize = getSoReceiveBufferSize(logSink, chan); + logSink.accept("Initial receive buffer size is: %d".formatted(bufsize)); + bufsize = getSoSendBufferSize(logSink, chan); + logSink.accept("Initial send buffer size is: %d".formatted(bufsize)); + } + if (trySetReceiveBufferSize(logSink, chan, receiveBufSize)) { + if (logSink != null) { + int bufsize = getSoReceiveBufferSize(logSink, chan); + logSink.accept("Receive buffer size configured: %d".formatted(bufsize)); + } + } + if (trySetSendBufferSize(logSink, chan, sendBufSize)) { + if (logSink != null) { + int bufsize = getSoSendBufferSize(logSink, chan); + logSink.accept("Send buffer size configured: %d".formatted(bufsize)); + } + } + return chan; + } + + public static boolean trySetReceiveBufferSize(Consumer logSink, NetworkChannel chan, int bufsize) { + try { + if (bufsize > 0) { + chan.setOption(StandardSocketOptions.SO_RCVBUF, bufsize); + return true; + } + } catch (IOException x) { + if (logSink != null) + logSink.accept("Failed to set receive buffer size to %d on %s" + .formatted(bufsize, chan)); + } + return false; + } + + public static boolean trySetSendBufferSize(Consumer logSink, NetworkChannel chan, int bufsize) { + try { + if (bufsize > 0) { + chan.setOption(StandardSocketOptions.SO_SNDBUF, bufsize); + return true; + } + } catch (IOException x) { + if (logSink != null) + logSink.accept("Failed to set send buffer size to %d on %s" + .formatted(bufsize, chan)); + } + return false; + } + + public static int getSoReceiveBufferSize(Consumer logSink, NetworkChannel chan) { + try { + return chan.getOption(StandardSocketOptions.SO_RCVBUF); + } catch (IOException x) { + if (logSink != null) + logSink.accept("Failed to get initial receive buffer size on %s".formatted(chan)); + } + return 0; + } + + public static int getSoSendBufferSize(Consumer logSink, NetworkChannel chan) { + try { + return chan.getOption(StandardSocketOptions.SO_SNDBUF); + } catch (IOException x) { + if (logSink!= null) + logSink.accept("Failed to get initial receive buffer size on %s".formatted(chan)); + } + return 0; + } + + + /** + * Try to figure out whether local and remote addresses are compatible. + * Used to diagnose potential communication issues early. + * This is a best effort, and there is no guarantee that all potential + * conflicts will be detected. + * @param local local address + * @param peer peer address + * @return a message describing the conflict, if any, or {@code null} if no + * conflict was detected. + */ + public static String addressConflict(SocketAddress local, SocketAddress peer) { + if (local == null || peer == null) return null; + if (local.equals(peer)) { + return "local endpoint and remote endpoint are bound to the same IP address and port"; + } + if (!(local instanceof InetSocketAddress li) || !(peer instanceof InetSocketAddress pi)) { + return null; + } + var laddr = li.getAddress(); + var paddr = pi.getAddress(); + if (!laddr.isAnyLocalAddress() && !paddr.isAnyLocalAddress()) { + if (laddr.getClass() != paddr.getClass()) { // IPv4 vs IPv6 + if ((laddr instanceof Inet6Address laddr6 && !laddr6.isIPv4CompatibleAddress()) + || (paddr instanceof Inet6Address paddr6 && !paddr6.isIPv4CompatibleAddress())) { + return "local endpoint IP (%s) and remote endpoint IP (%s) don't match" + .formatted(laddr.getClass().getSimpleName(), + paddr.getClass().getSimpleName()); + } + } + } + if (li.getPort() != pi.getPort()) return null; + if (li.getAddress().isAnyLocalAddress() && pi.getAddress().isLoopbackAddress()) { + return "local endpoint (wildcard) and remote endpoint (loopback) ports conflict"; + } + if (pi.getAddress().isAnyLocalAddress() && li.getAddress().isLoopbackAddress()) { + return "local endpoint (loopback) and remote endpoint (wildcard) ports conflict"; + } + return null; + } + /** * {@return the exception the given {@code cf} was completed with, * or a {@link CancellationException} if the given {@code cf} was diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/frame/AltSvcFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/frame/AltSvcFrame.java new file mode 100644 index 00000000000..46d4ebeb772 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/frame/AltSvcFrame.java @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.frame; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Objects; +import java.util.Optional; + +public final class AltSvcFrame extends Http2Frame { + + public static final int TYPE = 0xa; + + private final int length; + private final String origin; + private final String altSvcValue; + + private static final Charset encoding = StandardCharsets.US_ASCII; + + // Strings should be US-ASCII. This is checked by the FrameDecoder. + public AltSvcFrame(int streamid, int flags, Optional originVal, String altValue) { + super(streamid, flags); + this.origin = originVal.orElse(""); + this.altSvcValue = Objects.requireNonNull(altValue); + this.length = 2 + origin.length() + altValue.length(); + assert origin.length() == origin.getBytes(encoding).length; + assert altSvcValue.length() == altSvcValue.getBytes(encoding).length; + } + + @Override + public int type() { + return TYPE; + } + + @Override + int length() { + return length; + } + + public String getOrigin() { + return origin; + } + + public String getAltSvcValue() { + return altSvcValue; + } + + @Override + public String toString() { + return super.toString() + + ", origin=" + this.origin + + ", alt-svc: " + altSvcValue; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesDecoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesDecoder.java index 7ebfa090830..da05f6392c1 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesDecoder.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,11 +26,13 @@ package jdk.internal.net.http.frame; import java.io.IOException; -import java.lang.System.Logger.Level; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; +import java.util.Optional; + import jdk.internal.net.http.common.Log; import jdk.internal.net.http.common.Logger; import jdk.internal.net.http.common.Utils; @@ -344,6 +346,8 @@ public class FramesDecoder { return parseWindowUpdateFrame(frameLength, frameStreamid, frameFlags); case ContinuationFrame.TYPE: return parseContinuationFrame(frameLength, frameStreamid, frameFlags); + case AltSvcFrame.TYPE: + return parseAltSvcFrame(frameLength, frameStreamid, frameFlags); default: // RFC 7540 4.1 // Implementations MUST ignore and discard any frame that has a type that is unknown. @@ -557,4 +561,32 @@ public class FramesDecoder { return new ContinuationFrame(streamid, flags, getBuffers(false, frameLength)); } + private Http2Frame parseAltSvcFrame(int frameLength, int frameStreamid, int frameFlags) { + var len = getShort(); + byte[] origin; + Optional originUri = Optional.empty(); + if (len > 0) { + origin = getBytes(len); + if (!isUSAscii(origin)) { + return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR, frameStreamid, + "illegal character in AltSvcFrame"); + } + originUri = Optional.of(new String(origin, StandardCharsets.US_ASCII)); + } + byte[] altbytes = getBytes(frameLength - 2 - len); + if (!isUSAscii(altbytes)) { + return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR, frameStreamid, + "illegal character in AltSvcFrame"); + } + String altSvc = new String(altbytes, StandardCharsets.US_ASCII); + return new AltSvcFrame(frameStreamid, 0, originUri, altSvc); + } + + static boolean isUSAscii(byte[] bytes) { + for (int i=0; i < bytes.length; i++) { + if (bytes[i] < 0) return false; + } + return true; + } + } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesEncoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesEncoder.java index 4fdd4acd661..2ee2083c22c 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesEncoder.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/frame/FramesEncoder.java @@ -26,6 +26,7 @@ package jdk.internal.net.http.frame; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -70,6 +71,7 @@ public class FramesEncoder { case GoAwayFrame.TYPE -> encodeGoAwayFrame((GoAwayFrame) frame); case WindowUpdateFrame.TYPE -> encodeWindowUpdateFrame((WindowUpdateFrame) frame); case ContinuationFrame.TYPE -> encodeContinuationFrame((ContinuationFrame) frame); + case AltSvcFrame.TYPE -> encodeAltSvcFrame((AltSvcFrame) frame); default -> throw new UnsupportedOperationException("Not supported frame " + frame.type() + " (" + frame.getClass().getName() + ")"); }; @@ -227,6 +229,20 @@ public class FramesEncoder { return join(buf, frame.getHeaderBlock()); } + private List encodeAltSvcFrame(AltSvcFrame frame) { + final int length = frame.length(); + ByteBuffer buf = getBuffer(Http2Frame.FRAME_HEADER_SIZE + length); + putHeader(buf, length, AltSvcFrame.TYPE, NO_FLAGS, frame.streamid); + final String origin = frame.getOrigin(); + assert (origin.length() & 0xffff0000) == 0; + buf.putShort((short)origin.length()); + if (!origin.isEmpty()) + buf.put(frame.getOrigin().getBytes(StandardCharsets.US_ASCII)); + buf.put(frame.getAltSvcValue().getBytes(StandardCharsets.US_ASCII)); + buf.flip(); + return List.of(buf); + } + private List joinWithPadding(ByteBuffer buf, List data, int padLength) { int len = data.size(); if (len == 0) return List.of(buf, getPadding(padLength)); diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/frame/Http2Frame.java b/src/java.net.http/share/classes/jdk/internal/net/http/frame/Http2Frame.java index f837645696f..469d06cef0c 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/frame/Http2Frame.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/frame/Http2Frame.java @@ -91,8 +91,9 @@ public abstract class Http2Frame { case PingFrame.TYPE -> "PING"; case PushPromiseFrame.TYPE -> "PUSH_PROMISE"; case WindowUpdateFrame.TYPE -> "WINDOW_UPDATE"; + case AltSvcFrame.TYPE -> "ALTSVC"; - default -> "UNKNOWN"; + default -> "UNKNOWN"; }; } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/hpack/Decoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/hpack/Decoder.java index 881be12c67c..9cdd604efd6 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/hpack/Decoder.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/hpack/Decoder.java @@ -282,7 +282,7 @@ public final class Decoder { if (endOfHeaderBlock && state != State.READY) { logger.log(NORMAL, () -> format("unexpected end of %s representation", state)); - throw new IOException("Unexpected end of header block"); + throw new ProtocolException("Unexpected end of header block"); } if (endOfHeaderBlock) { size = indexed = 0; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/hpack/ISO_8859_1.java b/src/java.net.http/share/classes/jdk/internal/net/http/hpack/ISO_8859_1.java index a233e0f3a38..979c3ded2bc 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/hpack/ISO_8859_1.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/hpack/ISO_8859_1.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -40,9 +40,10 @@ import java.nio.ByteBuffer; // // The encoding is simple and well known: 1 byte <-> 1 char // -final class ISO_8859_1 { +public final class ISO_8859_1 { - private ISO_8859_1() { } + private ISO_8859_1() { + } public static final class Reader { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/hpack/QuickHuffman.java b/src/java.net.http/share/classes/jdk/internal/net/http/hpack/QuickHuffman.java index 427c2504de5..c6b4c51761b 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/hpack/QuickHuffman.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/hpack/QuickHuffman.java @@ -619,7 +619,7 @@ public final class QuickHuffman { } } - static final class Reader implements Huffman.Reader { + public static final class Reader implements Huffman.Reader { private final BufferUpdateConsumer UPDATER = (buf, bufLen) -> { @@ -703,7 +703,7 @@ public final class QuickHuffman { } } - static final class Writer implements Huffman.Writer { + public static final class Writer implements Huffman.Writer { private final BufferUpdateConsumer UPDATER = (buf, bufLen) -> { @@ -782,12 +782,26 @@ public final class QuickHuffman { @Override public int lengthOf(CharSequence value, int start, int end) { - int len = 0; - for (int i = start; i < end; i++) { - char c = value.charAt(i); - len += codeLengthOf(c); - } - return bytesForBits(len); + return QuickHuffman.lengthOf(value, start, end); } } + + public static int lengthOf(CharSequence value, int start, int end) { + int len = 0; + for (int i = start; i < end; i++) { + char c = value.charAt(i); + len += codeLengthOf(c); + } + return bytesForBits(len); + } + + public static int lengthOf(CharSequence value) { + return lengthOf(value, 0, value.length()); + } + + /* Used to calculate the number of bytes required for Huffman encoding */ + + public static boolean isHuffmanBetterFor(CharSequence input) { + return lengthOf(input) < input.length(); + } } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/ConnectionSettings.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/ConnectionSettings.java new file mode 100644 index 00000000000..ca734d74ae7 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/ConnectionSettings.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3; + +import java.util.Objects; + +import jdk.internal.net.http.http3.frames.SettingsFrame; + +/** + * Represents the settings that are conveyed in a HTTP3 SETTINGS frame for a HTTP3 connection + */ +public record ConnectionSettings( + long maxFieldSectionSize, + long qpackMaxTableCapacity, + long qpackBlockedStreams) { + + // we use -1 (an internal value) to represent unlimited + public static final long UNLIMITED_MAX_FIELD_SECTION_SIZE = -1; + + public static ConnectionSettings createFrom(final SettingsFrame frame) { + Objects.requireNonNull(frame); + // default is unlimited as per RFC-9114 section 7.2.4.1 + final long maxFieldSectionSize = getOrDefault(frame, SettingsFrame.SETTINGS_MAX_FIELD_SECTION_SIZE, + UNLIMITED_MAX_FIELD_SECTION_SIZE); + // default is zero as per RFC-9204 section 5 + final long qpackMaxTableCapacity = getOrDefault(frame, SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, 0); + // default is zero as per RFC-9204, section 5 + final long qpackBlockedStreams = getOrDefault(frame, SettingsFrame.SETTINGS_QPACK_BLOCKED_STREAMS, 0); + return new ConnectionSettings(maxFieldSectionSize, qpackMaxTableCapacity, qpackBlockedStreams); + } + + private static long getOrDefault(final SettingsFrame frame, final int paramId, final long defaultValue) { + final long val = frame.getParameter(paramId); + if (val == -1) { + return defaultValue; + } + return val; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/Http3Error.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/Http3Error.java new file mode 100644 index 00000000000..423fb27f844 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/Http3Error.java @@ -0,0 +1,308 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3; + +import java.util.HexFormat; +import java.util.Optional; +import java.util.stream.Stream; + +import jdk.internal.net.quic.QuicTransportErrors; + +/** + * This enum models HTTP/3 error codes as specified in + * RFC 9114, Section 8, + * augmented with QPack error codes as specified in + * RFC 9204, Section 6. + */ +public enum Http3Error { + + /** + * No error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * This is used when the connection or stream
+     * needs to be closed, but there is no error to signal.
+     * }
+ */ + H3_NO_ERROR (0x0100), // 256 + + /** + * General protocol error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * Peer violated protocol requirements in a way that does
+     * not match a more specific error code, or endpoint declines
+     * to use the more specific error code.
+     * }
+ */ + H3_GENERAL_PROTOCOL_ERROR (0x0101), // 257 + + /** + * Internal error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * An internal error has occurred in the HTTP stack.
+     * }
+ */ + H3_INTERNAL_ERROR (0x0102), // 258 + + /** + * Stream creation error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * The endpoint detected that its peer created a stream that
+     * it will not accept.
+     * }
+ */ + H3_STREAM_CREATION_ERROR (0x0103), // 259 + + /** + * Critical stream closed error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * A stream required by the HTTP/3 connection was closed or reset.
+     * }
+ */ + H3_CLOSED_CRITICAL_STREAM (0x0104), // 260 + + /** + * Frame unexpected error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * A frame was received that was not permitted in the
+     * current state or on the current stream.
+     * }
+ */ + H3_FRAME_UNEXPECTED (0x0105), // 261 + + /** + * Frame error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * A frame that fails to satisfy layout requirements or with
+     * an invalid size was received.
+     * }
+ */ + H3_FRAME_ERROR (0x0106), // 262 + + /** + * Excessive load error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * The endpoint detected that its peer is exhibiting a behavior
+     * that might be generating excessive load.
+     * }
+ */ + H3_EXCESSIVE_LOAD (0x0107), // 263 + + /** + * Stream ID or Push ID error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * A Stream ID or Push ID was used incorrectly, such as exceeding
+     * a limit, reducing a limit, or being reused.
+     * }
+ */ + H3_ID_ERROR (0x0108), // 264 + + /** + * Settings error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * An endpoint detected an error in the payload of a SETTINGS frame.
+     * }
+ */ + H3_SETTINGS_ERROR (0x0109), // 265 + + /** + * Missing settings error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * No SETTINGS frame was received at the beginning of the control
+     * stream.
+     * }
+ */ + H3_MISSING_SETTINGS (0x010a), // 266 + + /** + * Request rejected error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * A server rejected a request without performing any application
+     * processing.
+     * }
+ */ + H3_REQUEST_REJECTED (0x010b), // 267 + + /** + * Request cancelled error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * The request or its response (including pushed response) is
+     * cancelled.
+     * }
+ */ + H3_REQUEST_CANCELLED (0x010c), // 268 + + /** + * Request incomplete error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * The client's stream terminated without containing a
+     * fully-formed request.
+     * }
+ */ + H3_REQUEST_INCOMPLETE (0x010d), //269 + + /** + * Message error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * An HTTP message was malformed and cannot be processed.
+     * }
+ */ + H3_MESSAGE_ERROR (0x010e), // 270 + + /** + * Connect error. + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * The TCP connection established in response to a CONNECT
+     * request was reset or abnormally closed.
+     * }
+ */ + H3_CONNECT_ERROR (0x010f), // 271 + + /** + * Version fallback error + *

+ * From + * RFC 9114, Section 8.1: + *

{@code
+     * The requested operation cannot be served over HTTP/3.
+     * The peer should retry over HTTP/1.1.
+     * }
+ */ + H3_VERSION_FALLBACK (0x0110), // 272 + + /** + * QPack decompression error + *

+ * From + * RFC 9204, Section 6: + *

{@code
+     * The decoder failed to interpret an encoded field section
+     * and is not able to continue decoding that field section.
+     * }
+ */ + QPACK_DECOMPRESSION_FAILED (0x0200), // 512 + + /** + * Qpack encoder stream error. + *

+ * From + * RFC 9204, Section 6: + *

{@code
+     * The decoder failed to interpret an encoder instruction
+     * received on the encoder stream.
+     * }
+ */ + QPACK_ENCODER_STREAM_ERROR (0x0201), // 513 + + /** + * Qpack decoder stream error + *

+ * From + * RFC 9204, Section 6: + *

{@code
+     * The encoder failed to interpret a decoder instruction
+     * received on the decoder stream.
+     * }
+ */ + QPACK_DECODER_STREAM_ERROR (0x0202); // 514 + + final long errorCode; + Http3Error(long errorCode) { + this.errorCode = errorCode; + } + + public long code() { + return errorCode; + } + + public static Optional fromCode(long code) { + return Stream.of(values()).filter((v) -> v.code() == code) + .findFirst(); + } + + public static String stringForCode(long code) { + return fromCode(code).map(Http3Error::name).orElse(unknown(code)); + } + + private static String unknown(long code) { + return "UnknownError(code=0x" + HexFormat.of().withUpperCase().toHexDigits(code) + ")"; + } + + /** + * {@return true if the given code is {@link Http3Error#H3_NO_ERROR} or equivalent} + * Unknown error codes are treated as equivalent to {@code H3_NO_ERROR} + * @param code an HTTP/3 code error code + */ + public static boolean isNoError(long code) { + return fromCode(code).orElse(H3_NO_ERROR) == Http3Error.H3_NO_ERROR; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/AbstractHttp3Frame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/AbstractHttp3Frame.java new file mode 100644 index 00000000000..4cb8aa051fd --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/AbstractHttp3Frame.java @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.Random; + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import static jdk.internal.net.http.http3.frames.Http3FrameType.asString; + +/** + * Super class for all HTTP/3 frames. + */ +public abstract non-sealed class AbstractHttp3Frame implements Http3Frame { + public static final Random RANDOM = new Random(); + final long type; + public AbstractHttp3Frame(long type) { + this.type = type; + } + + public final String typeAsString() { + return asString(type()); + } + + @Override + public long type() { + return type; + } + + + /** + * Computes the size of this frame. This corresponds to + * the {@linkplain #length()} of the frame's payload, plus the + * size needed to encode this length, plus the size needed to + * encode the frame type. + * + * @return the size of this frame. + */ + public long size() { + var len = length(); + return len + VariableLengthEncoder.getEncodedSize(len) + + VariableLengthEncoder.getEncodedSize(type()); + } + + public int headersSize() { + var len = length(); + return VariableLengthEncoder.getEncodedSize(len) + + VariableLengthEncoder.getEncodedSize(type()); + } + + @Override + public long streamingLength() { + return 0; + } + + protected static long decodeRequiredType(final BuffersReader reader, final long expectedType) { + final long type = VariableLengthEncoder.decode(reader); + if (type < 0) throw new BufferUnderflowException(); + // TODO: throw an exception instead? + assert type == expectedType : "bad frame type: " + type + " expected: " + expectedType; + return type; + } + + protected static MalformedFrame checkPayloadSize(long frameType, + BuffersReader reader, + long start, + long length) { + // check position after reading payload + long read = reader.position() - start; + if (length != read) { + reader.position(start + length); + reader.release(); + return new MalformedFrame(frameType, + Http3Error.H3_FRAME_ERROR.code(), + "payload length mismatch (length=%s, read=%s)" + .formatted(length, start)); + + } + + assert length == reader.position() - start; + return null; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(typeAsString()) + .append(": length=") + .append(length()); + return sb.toString(); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/CancelPushFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/CancelPushFrame.java new file mode 100644 index 00000000000..f470efb9b27 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/CancelPushFrame.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * Represents the CANCEL_PUSH HTTP3 frame + */ +public final class CancelPushFrame extends AbstractHttp3Frame { + + public static final int TYPE = Http3FrameType.TYPE.CANCEL_PUSH_FRAME; + private final long length; + private final long pushId; + + public CancelPushFrame(final long pushId) { + super(Http3FrameType.CANCEL_PUSH.type()); + this.pushId = pushId; + // the payload length of this frame + this.length = VariableLengthEncoder.getEncodedSize(this.pushId); + } + + // only used when constructing the frame during decoding content over a stream + private CancelPushFrame(final long pushId, final long length) { + super(Http3FrameType.CANCEL_PUSH.type()); + this.pushId = pushId; + this.length = length; + } + + @Override + public long length() { + return this.length; + } + + public long getPushId() { + return pushId; + } + + public void writeFrame(final ByteBuffer buf) { + // write the type of the frame + VariableLengthEncoder.encode(buf, this.type); + // write the length of the payload + VariableLengthEncoder.encode(buf, this.length); + // write the push id that needs to be cancelled + VariableLengthEncoder.encode(buf, this.pushId); + } + + /** + * This method is expected to be called when the reader + * contains enough bytes to decode the frame. + * @param reader the reader + * @param debug a logger for debugging purposes + * @return the new frame + * @throws BufferUnderflowException if the reader doesn't contain + * enough bytes to decode the frame + */ + static AbstractHttp3Frame decodeFrame(final BuffersReader reader, final Logger debug) { + long position = reader.position(); + decodeRequiredType(reader, TYPE); + long length = VariableLengthEncoder.decode(reader); + if (length > reader.remaining() || length < 0) { + reader.position(position); + throw new BufferUnderflowException(); + } + // position before reading payload + long start = reader.position(); + if (length == 0 || length != VariableLengthEncoder.peekEncodedValueSize(reader, start)) { + // frame length does not match the enclosed pushId + return new MalformedFrame(TYPE, Http3Error.H3_FRAME_ERROR.code(), + "Invalid length in CANCEL_PUSH frame: " + length); + } + + long pushId = VariableLengthEncoder.decode(reader); + if (pushId == -1) { + reader.position(position); + throw new BufferUnderflowException(); + } + + // check position after reading payload + var malformed = checkPayloadSize(TYPE, reader, start, length); + if (malformed != null) return malformed; + + reader.release(); + return new CancelPushFrame(pushId); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/DataFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/DataFrame.java new file mode 100644 index 00000000000..97b475774b2 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/DataFrame.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + + +/** + * This class models an HTTP/3 DATA frame. + * @apiNote + * An instance of {@code DataFrame} is used to read or writes + * the frame's type and length. The payload is supposed to be + * read or written directly to the stream on its own, after having + * read or written the frame type and length. + * @see PartialFrame + */ +public final class DataFrame extends PartialFrame { + + /** + * The DATA frame type, as defined by HTTP/3 + */ + public static final int TYPE = Http3FrameType.TYPE.DATA_FRAME; + + private final long length; + + /** + * Creates a new HTTP/3 HEADERS frame + */ + public DataFrame(long length) { + super(TYPE, length); + this.length = length; + } + + @Override + public long length() { + return length; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/FramesDecoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/FramesDecoder.java new file mode 100644 index 00000000000..a51c71e0a05 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/FramesDecoder.java @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.LongPredicate; +import java.util.function.Supplier; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.BuffersReader.ListBuffersReader; + +/** + * A FramesDecoder accumulates buffers until a frame can be + * decoded. It also supports decoding {@linkplain PartialFrame + * partial frames} and {@linkplain #readPayloadBytes() reading + * their payload} incrementally. + * @apiNote + * When the frame decoder {@linkplain #poll() returns} a partial + * frame, the same frame will be returned until its payload has been + * {@linkplain PartialFrame#remaining() fully} {@linkplain #readPayloadBytes() + * read}. + * The caller is supposed to call {@link #readPayloadBytes()} until + * {@link #poll()} returns a different frame. At this point there will be no + * {@linkplain PartialFrame#remaining() remaining} payload bytes to read for + * the previous frame. + *
+ * The sequence of calls: {@snippet : + * framesDecoder.submit(buffer); + * while ((frame = framesDecoder.poll()) != null) { + * if (frame instanceof PartialFrame partial) { + * var nextPayloadBytes = framesDecoder.readPayloadBytes(); + * if (nextPayloadBytes == null || nextPayloadBytes.isEmpty()) { + * // no more data is available at this moment + * break; + * } + * // nextPayloadBytes are the next bytes for the payload + * // of the partial frame + * deliverBytes(partial, nextPayloadBytes); + * } else ... + * // got a full frame... + * } + * } + * makes it possible to incrementally deliver payload bytes for + * a frame - since {@code poll()} will always return the same partial + * frame until all its payload has been read. + */ +public class FramesDecoder { + + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + private final ListBuffersReader framesReader = BuffersReader.list(); + private final ReentrantLock lock = new ReentrantLock(); + + private final Supplier dbgTag; + private final LongPredicate isAllowed; + + // the current partial frame or null + PartialFrame partialFrame; + boolean eof; + + /** + * A new {@code FramesDecoder} that accepts all frames. + * @param dbgTag a debug tag for logging + */ + public FramesDecoder(String dbgTag) { + this(dbgTag, FramesDecoder::allAllowed); + } + + /** + * A new {@code FramesDecoder} that accepts only frames + * authorized by the given {@code isAllowed} predicate. + * If a frame is not allowed, a {@link MalformedFrame} is + * returned. + * @param dbgTag a debug tag for logging + */ + public FramesDecoder(String dbgTag, LongPredicate isAllowed) { + this(() -> dbgTag, Objects.requireNonNull(isAllowed)); + } + + /** + * A new {@code FramesDecoder} that accepts only frames + * authorized by the given {@code isAllowed} predicate. + * If a frame is not allowed, a {@link MalformedFrame} is + * returned. + * @param dbgTag a debug tag for logging + */ + public FramesDecoder(Supplier dbgTag, LongPredicate isAllowed) { + this.dbgTag = dbgTag; + this.isAllowed = Objects.requireNonNull(isAllowed); + } + + String dbgTag() { return dbgTag.get(); } + + /** + * Submit a new buffer to this frames decoder + * @param buffer a new buffer from the stream + */ + public void submit(ByteBuffer buffer) { + lock.lock(); + try { + if (buffer == QuicStreamReader.EOF) { + eof = true; + } else { + framesReader.add(buffer); + } + } finally { + lock.unlock(); + } + } + + /** + * {@return an {@code Http3Frame}, possibly {@linkplain PartialFrame partial}, + * or {@code null} if not enough bytes have been receive to decode (at least + * partially) a frame} + * If a frame is illegal or not allowed, a {@link MalformedFrame} is + * returned. The caller is supposed to {@linkplain #clear() clear} all data + * and proceed to close the connection in that case. + */ + public Http3Frame poll() { + lock.lock(); + try { + if (partialFrame != null) { + if (partialFrame.remaining() != 0) { + return partialFrame; + } else partialFrame = null; + } + var frame = Http3Frame.decode(framesReader, this::isAllowed, debug); + if (frame instanceof PartialFrame partial) { + partialFrame = partial; + } + return frame; + } finally { + lock.unlock(); + } + } + + /** + * {@return the next payload bytes for the current partial frame, + * or {@code null} if no partial frame} + * If EOF has been reached ({@link QuicStreamReader#EOF EOF} was + * {@linkplain #submit(ByteBuffer) submitted}, and all buffers have + * been read, the returned list will contain {@link QuicStreamReader#EOF + * EOF} + */ + public List readPayloadBytes() { + lock.lock(); + try { + if (partialFrame == null || partialFrame.remaining() == 0) { + partialFrame = null; + return null; + } + if (eof && !framesReader.hasRemaining()) { + return List.of(QuicStreamReader.EOF); + } + return partialFrame.nextPayloadBytes(framesReader); + } finally { + lock.unlock(); + } + } + + /** + * {@return true if EOF has been reached and all buffers have been read} + */ + public boolean eof() { + lock.lock(); + try { + if (!eof) return false; + if (!framesReader.hasRemaining()) return true; + if (partialFrame != null) { + // still some payload data to read... + if (partialFrame.remaining() > 0) return false; + } + var pos = framesReader.position(); + try { + // if there's not enough data to decode a new frame or a new + // partial frame then since no more data will ever come, we do have + // reached EOF. If however, we can read a frame from the remaining + // data in the buffer, then EOF is not reached yet. + // The next call to poll() will return that frame. + var frame = Http3Frame.decode(framesReader, this::isAllowed, debug); + return frame == null; + } finally { + // restore position for the next call to poll. + framesReader.position(pos); + } + } finally { + lock.unlock(); + } + } + + /** + * {@return true if all buffers have been read} + */ + public boolean clean() { + lock.lock(); + try { + if (partialFrame != null) { + // still some payload data to read... + if (partialFrame.remaining() > 0) return false; + } + return !framesReader.hasRemaining(); + } finally { + lock.unlock(); + } + } + + /** + * Clears any unconsumed buffers. + */ + public void clear() { + lock.lock(); + try { + partialFrame = null; + framesReader.clear(); + } finally { + lock.unlock(); + } + } + + /** + * Can be overridden by subclasses to avoid parsing a frame + * fully if the frame is not allowed on this stream, or + * according to the stream state. + * + * @implSpec + * This method delegates to the {@linkplain #FramesDecoder(String, LongPredicate) + * predicate} given at construction time. If {@linkplain #FramesDecoder(String) + * no predicate} was given this method returns true. + * + * @param frameType the frame type + * @return true if the frame is allowed + */ + protected boolean isAllowed(long frameType) { + return isAllowed.test(frameType); + } + + /** + * A predicate that returns true for all frames types allowed + * on the server->client control stream. + * @param frameType a frame type + * @return whether a frame of this type is allowed on a control stream. + */ + public static boolean isAllowedOnControlStream(long frameType) { + if (frameType == Http3FrameType.DATA.type()) return false; + if (frameType == Http3FrameType.HEADERS.type()) return false; + if (frameType == Http3FrameType.PUSH_PROMISE.type()) return false; + if (frameType == Http3FrameType.MAX_PUSH_ID.type()) return false; + if (Http3FrameType.isIllegalType(frameType)) return false; + return true; + } + + /** + * A predicate that returns true for all frames types allowed + * on the client->server control stream. + * @param frameType a frame type + * @return whether a frame of this type is allowed on a control stream. + */ + public static boolean isAllowedOnClientControlStream(long frameType) { + if (frameType == Http3FrameType.DATA.type()) return false; + if (frameType == Http3FrameType.HEADERS.type()) return false; + if (frameType == Http3FrameType.PUSH_PROMISE.type()) return false; + if (Http3FrameType.isIllegalType(frameType)) return false; + return true; + } + + /** + * A predicate that returns true for all frames types allowed + * on a request/response stream. + * @param frameType a frame type + * @return whether a frame of this type is allowed on a request/response + * stream. + */ + public static boolean isAllowedOnRequestStream(long frameType) { + if (frameType == Http3FrameType.SETTINGS.type()) return false; + if (frameType == Http3FrameType.CANCEL_PUSH.type()) return false; + if (frameType == Http3FrameType.GOAWAY.type()) return false; + if (frameType == Http3FrameType.MAX_PUSH_ID.type()) return false; + if (Http3FrameType.isIllegalType(frameType)) return false; + return true; + } + + + /** + * A predicate that returns true for all frames types allowed + * on a push promise stream. + * @param frameType a frame type + * @return whether a frame of this type is allowed on a request/response + * stream. + */ + public static boolean isAllowedOnPromiseStream(long frameType) { + if (frameType == Http3FrameType.SETTINGS.type()) return false; + if (frameType == Http3FrameType.CANCEL_PUSH.type()) return false; + if (frameType == Http3FrameType.GOAWAY.type()) return false; + if (frameType == Http3FrameType.MAX_PUSH_ID.type()) return false; + if (frameType == Http3FrameType.PUSH_PROMISE.type()) return false; + if (Http3FrameType.isIllegalType(frameType)) return false; + return true; + } + + private static boolean allAllowed(long frameType) { + return true; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/GoAwayFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/GoAwayFrame.java new file mode 100644 index 00000000000..274a1af3a56 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/GoAwayFrame.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.frames; + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * Represents a GOAWAY HTTP3 frame + */ +public final class GoAwayFrame extends AbstractHttp3Frame { + + public static final int TYPE = Http3FrameType.TYPE.GOAWAY_FRAME; + private final long length; + // represents either a stream id or a push id depending on the context + // of the frame + private final long id; + + public GoAwayFrame(final long id) { + super(TYPE); + this.id = id; + // the payload length of this frame + this.length = VariableLengthEncoder.getEncodedSize(this.id); + } + + // only used when constructing the frame during decoding content over a stream + private GoAwayFrame(final long length, final long id) { + super(Http3FrameType.GOAWAY.type()); + this.length = length; + this.id = id; + } + + @Override + public long length() { + return this.length; + } + + /** + * {@return the id of either the stream or a push promise, depending on the context + * of this frame} + */ + public long getTargetId() { + return this.id; + } + + public void writeFrame(final ByteBuffer buf) { + // write the type of the frame + VariableLengthEncoder.encode(buf, this.type); + // write the length of the payload + VariableLengthEncoder.encode(buf, this.length); + // write the stream id/push id + VariableLengthEncoder.encode(buf, this.id); + } + + static AbstractHttp3Frame decodeFrame(final BuffersReader reader, final Logger debug) { + final long position = reader.position(); + // read the frame type + decodeRequiredType(reader, Http3FrameType.GOAWAY.type()); + // read length of the payload + final long length = VariableLengthEncoder.decode(reader); + if (length < 0 || length > reader.remaining()) { + reader.position(position); + throw new BufferUnderflowException(); + } + // position before reading payload + long start = reader.position(); + + if (length == 0 || length != VariableLengthEncoder.peekEncodedValueSize(reader, start)) { + // frame length does not match the enclosed targetId + return new MalformedFrame(TYPE, + Http3Error.H3_FRAME_ERROR.code(), + "Invalid length in GOAWAY frame: " + length); + } + + // read stream id / push id + final long targetId = VariableLengthEncoder.decode(reader); + if (targetId == -1) { + reader.position(position); + throw new BufferUnderflowException(); + } + + // check position after reading payload + var malformed = checkPayloadSize(TYPE, reader, start, length); + if (malformed != null) return malformed; + + reader.release(); + return new GoAwayFrame(length, targetId); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + sb.append(super.toString()).append(" stream/push id: ").append(this.id); + return sb.toString(); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/HeadersFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/HeadersFrame.java new file mode 100644 index 00000000000..5d7672458a3 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/HeadersFrame.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + +/** + * This class models an HTTP/3 HEADERS frame. + * @apiNote + * An instance of {@code HeadersFrame} is used to read or writes + * the frame's type and length. The payload is supposed to be + * read or written directly to the stream on its own, after having + * read or written the frame type and length. + * @see jdk.internal.net.http.http3.frames.PartialFrame + */ +public final class HeadersFrame extends PartialFrame { + + /** + * The HEADERS frame type, as defined by HTTP/3 + */ + public static final int TYPE = Http3FrameType.TYPE.HEADERS_FRAME; + + + private final long length; + + /** + * Creates a new HTTP/3 HEADERS frame + */ + public HeadersFrame(long length) { + super(TYPE, length); + this.length = length; + } + + @Override + public long length() { + return length; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3Frame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3Frame.java new file mode 100644 index 00000000000..c4b234553ad --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3Frame.java @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + +import java.util.function.LongPredicate; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +import static jdk.internal.net.http.http3.frames.Http3FrameType.DATA; +import static jdk.internal.net.http.http3.frames.Http3FrameType.HEADERS; +import static jdk.internal.net.http.http3.frames.Http3FrameType.PUSH_PROMISE; +import static jdk.internal.net.http.http3.frames.Http3FrameType.UNKNOWN; +import static jdk.internal.net.http.http3.frames.Http3FrameType.asString; +import static jdk.internal.net.http.http3.frames.Http3FrameType.isIllegalType; + +/** + * An HTTP/3 frame + */ +public sealed interface Http3Frame permits AbstractHttp3Frame { + + /** + * {@return the type of this frame} + */ + long type(); + + /** + * {@return the length of this frame} + */ + long length(); + + + /** + * {@return the portion of the frame payload that can be read + * after the frame was created, when the current frame + * can be read as a partial frame, otherwise 0, if + * the payload can't be streamed} + */ + default long streamingLength() { return 0;} + + /** + * Attempts to decode an HTTP/3 frame from the bytes accumulated + * in the reader. + * + * @apiNote + * + * If an error is detected while parsing the frame, a {@link MalformedFrame} + * error will be returned + * + * @param reader the reader containing the bytes + * @param isFrameTypeAllowed a predicate to test whether a given + * frame type is allowed in this context + * @param debug a logger to log debug traces + * @return the decoded frame, or {@code null} if some bytes are + * missing to decode the frame + */ + static Http3Frame decode(BuffersReader reader, LongPredicate isFrameTypeAllowed, Logger debug) { + long pos = reader.position(); + long limit = reader.limit(); + long remaining = reader.remaining(); + long type = -1; + long before = reader.read(); + Http3Frame frame; + try { + int tsize = VariableLengthEncoder.peekEncodedValueSize(reader, pos); + if (tsize == -1 || remaining - tsize < 0) return null; + type = VariableLengthEncoder.peekEncodedValue(reader, pos); + if (type == -1) return null; + if (isIllegalType(type) || !isFrameTypeAllowed.test(type)) { + var msg = "H3_FRAME_UNEXPECTED: Frame " + + asString(type) + + " is not allowed on this stream"; + if (debug.on()) debug.log(msg); + frame = new MalformedFrame(type, Http3Error.H3_FRAME_UNEXPECTED.code(), msg); + reader.clear(); + return frame; + } + + int lsize = VariableLengthEncoder.peekEncodedValueSize(reader, pos + tsize); + if (lsize == -1 || remaining - tsize - lsize < 0) return null; + final long length = VariableLengthEncoder.peekEncodedValue(reader, pos + tsize); + var frameType = Http3FrameType.forType(type); + if (debug.on()) { + debug.log("Decoding %s(length=%s)", frameType, length); + } + if (frameType == UNKNOWN) { + if (debug.on()) { + debug.log("decode partial unknown frame: " + + "pos:%s, limit:%s, remaining:%s," + + " tsize:%s, lsize:%s, length:%s", + pos, limit, remaining, tsize, lsize, length); + } + reader.position(pos + tsize + lsize); + reader.release(); + return new UnknownFrame(type, length); + } else if (frameType.maxLength() < length) { + var msg = "H3_FRAME_ERROR: Frame " + asString(type) + " length too long"; + if (debug.on()) debug.log(msg); + frame = new MalformedFrame(type, Http3Error.H3_FRAME_ERROR.code(), msg); + reader.clear(); + return frame; + } + + if (frameType == HEADERS) { + if (length == 0) { + var msg = "H3_FRAME_ERROR: Frame " + asString(type) + " does not contain headers"; + if (debug.on()) debug.log(msg); + frame = new MalformedFrame(type, Http3Error.H3_FRAME_ERROR.code(), msg); + reader.clear(); + return frame; + } + reader.position(pos + tsize + lsize); + reader.release(); + return new HeadersFrame(length); + } + + if (frameType == DATA) { + reader.position(pos + tsize + lsize); + reader.release(); + return new DataFrame(length); + } + + if (frameType == PUSH_PROMISE) { + int pidsize = VariableLengthEncoder.peekEncodedValueSize(reader, pos + tsize + lsize); + if (length == 0 || length < pidsize) { + var msg = "H3_FRAME_ERROR: Frame " + asString(type) + " length too short to fit pushID"; + if (debug.on()) debug.log(msg); + frame = new MalformedFrame(type, Http3Error.H3_FRAME_ERROR.code(), msg); + reader.clear(); + return frame; + } + if (length == pidsize) { + var msg = "H3_FRAME_ERROR: Frame " + asString(type) + " does not contain headers"; + if (debug.on()) debug.log(msg); + frame = new MalformedFrame(type, Http3Error.H3_FRAME_ERROR.code(), msg); + reader.clear(); + return frame; + } + if (pidsize == -1 || remaining - tsize - lsize - pidsize < 0) return null; + long pushId = VariableLengthEncoder.peekEncodedValue(reader, pos + tsize + lsize); + reader.position(pos + tsize + lsize + pidsize); + reader.release(); + return new PushPromiseFrame(pushId, length - pidsize); + } + + if (length + tsize + lsize > reader.remaining()) { + // we haven't moved the reader's position. + // we'll be called back when new bytes are available and + // we'll resume reading type + length from the same position + // again, until we have enough to read the frame. + return null; + } + + assert isFrameTypeAllowed.test(type); + + frame = switch(frameType) { + case SETTINGS -> SettingsFrame.decodeFrame(reader, debug); + case GOAWAY -> GoAwayFrame.decodeFrame(reader, debug); + case CANCEL_PUSH -> CancelPushFrame.decodeFrame(reader, debug); + case MAX_PUSH_ID -> MaxPushIdFrame.decodeFrame(reader, debug); + default -> { + reader.position(pos + tsize + lsize); + reader.release(); + yield new UnknownFrame(type, length); + } + }; + + long read; + if (frame instanceof MalformedFrame || frame == null) { + return frame; + } else if ((read = (reader.read() - before - tsize - lsize)) != length) { + String msg = ("H3_FRAME_ERROR: Frame %s payload length does not match" + + " frame length (length=%s, payload=%s)") + .formatted(asString(type), length, read); + if (debug.on()) debug.log(msg); + reader.release(); // mark reader read + reader.position(reader.position() + tsize + lsize + length); + reader.release(); + return new MalformedFrame(type, Http3Error.H3_FRAME_ERROR.code(), msg); + } else { + return frame; + } + } catch (Throwable t) { + if (debug.on()) debug.log("Failed to decode frame", t); + reader.clear(); // mark reader read + return new MalformedFrame(type, Http3Error.H3_INTERNAL_ERROR.code(), t.getMessage(), t); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3FrameType.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3FrameType.java new file mode 100644 index 00000000000..858e10fa6a0 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/Http3FrameType.java @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; +import java.util.stream.Stream; + +import static jdk.internal.net.http.quic.VariableLengthEncoder.MAX_ENCODED_INTEGER; +import static jdk.internal.net.http.quic.VariableLengthEncoder.MAX_INTEGER_LENGTH; + +/** + * An enum to model HTTP/3 frame types. + */ +public enum Http3FrameType { + + /** + * Used to identify an HTTP/3 frame whose type is unknown + */ + UNKNOWN(-1, MAX_ENCODED_INTEGER), + /** + * Used to identify an HTTP/3 DATA frame + */ + DATA(TYPE.DATA_FRAME, MAX_ENCODED_INTEGER), + /** + * Used to identify an HTTP/3 HEADERS frame + */ + HEADERS(TYPE.HEADERS_FRAME, MAX_ENCODED_INTEGER), + /** + * Used to identify an HTTP/3 CANCEL_PUSH frame + */ + CANCEL_PUSH(TYPE.CANCEL_PUSH_FRAME, MAX_INTEGER_LENGTH), + /** + * Used to identify an HTTP/3 SETTINGS frame + */ + SETTINGS(TYPE.SETTINGS_FRAME, TYPE.MAX_SETTINGS_LENGTH), + /** + * Used to identify an HTTP/3 PUSH_PROMISE frame + */ + PUSH_PROMISE(TYPE.PUSH_PROMISE_FRAME, MAX_ENCODED_INTEGER), + /** + * Used to identify an HTTP/3 GOAWAY frame + */ + GOAWAY(TYPE.GOAWAY_FRAME, MAX_INTEGER_LENGTH), + /** + * Used to identify an HTTP/3 MAX_PUSH_ID_FRAME frame + */ + MAX_PUSH_ID(TYPE.MAX_PUSH_ID_FRAME, MAX_INTEGER_LENGTH); + + /** + * A class to hold type constants + */ + static final class TYPE { + private TYPE() { throw new InternalError(); } + + // Frames types + public static final int DATA_FRAME = 0x00; + public static final int HEADERS_FRAME = 0x01; + public static final int CANCEL_PUSH_FRAME = 0x03; + public static final int SETTINGS_FRAME = 0x04; + public static final int PUSH_PROMISE_FRAME = 0x05; + public static final int GOAWAY_FRAME = 0x07; + public static final int MAX_PUSH_ID_FRAME = 0x0d; + + // The maximum size a settings frame can have. + // This is a limit imposed by our implementation. + // There are only 7 settings defined in the current + // specification, but we will allow for a frame to + // contain up to 80. Past that limit, we will consider + // the frame to be malformed: + // 8 x 10 x (max sizeof(id) + max sizeof(value)) = 80 x 16 bytes + public static final long MAX_SETTINGS_LENGTH = + 10L * 8L * MAX_INTEGER_LENGTH * 2L; + } + + + // This is one of the values defined in TYPE above, or + // -1 for the UNKNOWN frame types. + private final int type; + private final long maxLength; + private Http3FrameType(int type, long maxLength) { + this.type = type; + this.maxLength = maxLength; + } + + /** + * {@return the frame type, as defined by HTTP/3} + */ + public long type() { return type;} + + /** + * {@return the maximum length a frame of this type + * can take} + */ + public long maxLength() { + return maxLength; + } + + /** + * {@return the HTTP/3 frame type, as an int} + * + * @apiNote + * HTTP/3 defines frames type as variable length integers + * in the range [0, 2^62-1]. However, the few standard frame + * types registered for HTTP/3 and modeled by this enum + * class can be coded as an int. + * This method provides a convenient way to access the frame + * type as an int, which avoids having to cast when using + * the value in switch statements. + */ + public int intType() { return type;} + + /** + * {@return the {@link Http3FrameType} corresponding to the given + * {@code type}, or {@link #UNKNOWN} if no corresponding + * {@link Http3FrameType} instance is found} + * @param type an HTTP/3 frame type identifier read from an HTTP/3 frame + */ + public static Http3FrameType forType(long type) { + return Stream.of(values()) + .filter(x -> x.type == type) + .findFirst() + .orElse(UNKNOWN); + } + + /** + * {@return a string representation of the given type, suited for inclusion + * in log messages, exceptions, etc...} + * @param type an HTTP/3 frame type identifier read from an HTTP/3 frame + */ + public static String asString(long type) { + String str = null; + if (type >= Integer.MIN_VALUE && type <= Integer.MAX_VALUE) { + str = switch ((int)type) { + case TYPE.DATA_FRAME -> DATA.name(); // 0x00 + case TYPE.HEADERS_FRAME -> HEADERS.name(); // 0x01 + case 0x02 -> "RESERVED(0x02)"; + case TYPE.CANCEL_PUSH_FRAME -> CANCEL_PUSH.name(); // 0x03 + case TYPE.SETTINGS_FRAME -> SETTINGS.name(); // 0x04 + case TYPE.PUSH_PROMISE_FRAME -> PUSH_PROMISE.name(); // 0x05 + case 0x06 -> "RESERVED(0x06)"; + case TYPE.GOAWAY_FRAME -> GOAWAY.name(); // 0x07 + case 0x08 -> "RESERVED(0x08)"; + case 0x09 -> "RESERVED(0x09)"; + case TYPE.MAX_PUSH_ID_FRAME -> MAX_PUSH_ID.name(); // 0x0d + default -> null; + }; + } + if (str != null) return str; + if (isReservedType(type)) { + return "RESERVED(type=" + type + ")"; + } + return "UNKNOWN(type=" + type + ")"; + } + + /** + * {@return whether this frame type is illegal} + * This corresponds to HTTP/2 frame types that have no equivalent in + * HTTP/3. + * @param type the frame type + */ + public static boolean isIllegalType(long type) { + return type == 0x02 || type == 0x06 || type == 0x08 || type == 0x09; + } + + /** + * Whether the given type is one of the reserved frame + * types defined by HTTP/3. For any non-negative integer N: + * {@code 0x21 + 0x1f * N } + * is a reserved frame type that has no meaning. + * + * @param type an HTTP/3 frame type identifier read from an HTTP/3 frame + * + * @return true if the given type matches the {@code 0x21 + 0x1f * N} + * pattern + */ + public static boolean isReservedType(long type) { + return type >= 0x21L && type <= MAX_ENCODED_INTEGER + && (type - 0x21L) % 0x1f == 0; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MalformedFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MalformedFrame.java new file mode 100644 index 00000000000..0d89666ec26 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MalformedFrame.java @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + +import java.util.function.LongPredicate; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.quic.BuffersReader; + +import jdk.internal.net.http.http3.Http3Error; + +/** + * An instance of MalformedFrame can be returned by + * {@link AbstractHttp3Frame#decode(BuffersReader, LongPredicate, Logger)} + * when a malformed frame is detected. This should cause the caller + * to send an error to its peer, and possibly throw an + * exception to the higher layer. + */ +public class MalformedFrame extends AbstractHttp3Frame { + + private final long errorCode; + private final String msg; + private final Throwable cause; + + /** + * Creates Connection Error malformed frame + * + * @param errorCode - error code + * @param msg - internal debug message + */ + public MalformedFrame(long type, long errorCode, String msg) { + this(type, errorCode, msg, null); + } + + /** + * Creates Connection Error malformed frame + * + * @param errorCode - error code + * @param msg - internal debug message + * @param cause - internal cause for the error, if available + * (can be null) + */ + public MalformedFrame(long type, long errorCode, String msg, Throwable cause) { + super(type); + this.errorCode = errorCode; + this.msg = msg; + this.cause = cause; + } + + @Override + public String toString() { + return super.toString() + " MalformedFrame, Error: " + + Http3Error.stringForCode(errorCode) + + " reason: " + msg; + } + + /** + * {@inheritDoc} + * @implSpec this method always returns 0 + */ + @Override + public long length() { + return 0; // Not Applicable + } + + /** + * {@inheritDoc} + * @implSpec this method always returns 0 + */ + @Override + public long size() { + return 0; // Not applicable + } + + /** + * {@return the {@linkplain Http3Error#code() HTTP/3 error code} that + * should be reported to the peer} + */ + public long getErrorCode() { + return errorCode; + } + + /** + * {@return a message that describe the error} + */ + public String getMessage() { + return msg; + } + + /** + * {@return the cause of the error, if available, {@code null} otherwise} + * + * @apiNote + * This is useful for logging and diagnosis purpose, typically when the + * error is an {@linkplain Http3Error#H3_INTERNAL_ERROR internal error}. + */ + public Throwable getCause() { + return cause; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MaxPushIdFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MaxPushIdFrame.java new file mode 100644 index 00000000000..b4f35064da1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/MaxPushIdFrame.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.frames; + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * Represents a MAX_PUSH_ID HTTP3 frame + */ +public final class MaxPushIdFrame extends AbstractHttp3Frame { + + public static final int TYPE = Http3FrameType.TYPE.MAX_PUSH_ID_FRAME; + + private final long length; + private final long maxPushId; + + public MaxPushIdFrame(final long maxPushId) { + super(Http3FrameType.MAX_PUSH_ID.type()); + this.maxPushId = maxPushId; + // the payload length of this frame + this.length = VariableLengthEncoder.getEncodedSize(this.maxPushId); + } + + // only used when constructing the frame during decoding content over a stream + private MaxPushIdFrame(final long maxPushId, final long length) { + super(Http3FrameType.MAX_PUSH_ID.type()); + this.maxPushId = maxPushId; + this.length = length; + } + + @Override + public long length() { + return this.length; + } + + public long getMaxPushId() { + return this.maxPushId; + } + + public void writeFrame(final ByteBuffer buf) { + // write the type of the frame + VariableLengthEncoder.encode(buf, this.type); + // write the length of the payload + VariableLengthEncoder.encode(buf, this.length); + // write the max push id value + VariableLengthEncoder.encode(buf, this.maxPushId); + } + + /** + * This method is expected to be called when the reader + * contains enough bytes to decode the frame. + * @param reader the reader + * @param debug a logger for debugging purposes + * @return the new frame + * @throws BufferUnderflowException if the reader doesn't contain + * enough bytes to decode the frame + */ + static AbstractHttp3Frame decodeFrame(final BuffersReader reader, final Logger debug) { + long position = reader.position(); + decodeRequiredType(reader, TYPE); + long length = VariableLengthEncoder.decode(reader); + if (length > reader.remaining() || length < 0) { + reader.position(position); + throw new BufferUnderflowException(); + } + // position before reading payload + long start = reader.position(); + + if (length == 0 || length != VariableLengthEncoder.peekEncodedValueSize(reader, start)) { + // frame length does not match the enclosed maxPushId + return new MalformedFrame(TYPE, Http3Error.H3_FRAME_ERROR.code(), + "Invalid length in MAX_PUSH_ID frame: " + length); + } + + long maxPushId = VariableLengthEncoder.decode(reader); + if (maxPushId == -1) { + reader.position(position); + throw new BufferUnderflowException(); + } + + // check position after reading payload + var malformed = checkPayloadSize(TYPE, reader, start, length); + if (malformed != null) return malformed; + + reader.release(); + return new MaxPushIdFrame(maxPushId); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PartialFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PartialFrame.java new file mode 100644 index 00000000000..3cbbb814b82 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PartialFrame.java @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.frames; + +import java.nio.ByteBuffer; +import java.util.List; + +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.BuffersReader; + +/** + * A PartialFrame helps to read the payload of a frame. + * This class is not multi-thread safe. + */ +public abstract sealed class PartialFrame + extends AbstractHttp3Frame + permits HeadersFrame, + DataFrame, + PushPromiseFrame, + UnknownFrame { + + private static final List NONE = List.of(); + private final long streamingLength; + private long remaining; + PartialFrame(long frameType, long streamingLength) { + super(frameType); + this.remaining = this.streamingLength = streamingLength; + } + + @Override + public final long streamingLength() { + return streamingLength; + } + + /** + * {@return the number of payload bytes that remains to read} + */ + public final long remaining() { + return remaining; + } + + /** + * Reads remaining payload bytes from the given {@link BuffersReader}. + * This method must not run concurrently with any code that submit + * new buffers to the {@link BuffersReader}. + * @param buffers a {@link BuffersReader} that contains payload bytes. + * @return the payload bytes available so far, an empty list if no + * bytes are available or the whole payload has already been + * read + */ + public final List nextPayloadBytes(BuffersReader buffers) { + var remaining = this.remaining; + if (remaining > 0) { + long available = buffers.remaining(); + if (available > 0) { + long read = Math.min(remaining, available); + this.remaining = remaining - read; + return buffers.getAndRelease(read); + } + } + return NONE; + } + + /** + * Reads remaining payload bytes from the given {@link ByteBuffer}. + * @param buffer a {@link ByteBuffer} that contains payload bytes. + * @return the payload bytes available in the given buffer, or + * {@code null} if all payload has been read. + */ + public final ByteBuffer nextPayloadBytes(ByteBuffer buffer) { + var remaining = this.remaining; + if (remaining > 0) { + int available = buffer.remaining(); + if (available > 0) { + long read = Math.min(remaining, available); + remaining -= read; + this.remaining = remaining; + assert read <= available; + int pos = buffer.position(); + int len = (int) read; + // always create a slice, so that we can move the position + // of the original buffer, as if the data had been read. + ByteBuffer next = buffer.slice(pos, len); + buffer.position(pos + len); + return next; + } else return buffer == QuicStreamReader.EOF ? buffer : buffer.slice(); + } + return null; + } + + /** + * Write the frame headers to the given buffer. + * + * @apiNote + * The caller will be responsible for writing the + * remaining {@linkplain #length() length} bytes of + * the frame content after writing the frame headers. + * + * @implSpec + * Usually the header of a frame is assumed to simply + * contain the frame type and frame length. + * Some subclasses of {@code AbstractHttp3Frame} may + * however include some additional information. + * For instance, {@link PushPromiseFrame} may consider + * the {@link PushPromiseFrame#getPushId() pushId} as + * being in part of the headers, and write it along + * in this method after the frame type and length. + * In such a case, a subclass would also need to + * override {@link #headersSize()} in order to add + * the size of the additional information written + * by {@link #writeHeaders(ByteBuffer)}. + * + * @param buf a buffer to write the headers into + */ + public void writeHeaders(ByteBuffer buf) { + long len = length(); + int pos0 = buf.position(); + VariableLengthEncoder.encode(buf, type()); + VariableLengthEncoder.encode(buf, len); + int pos1 = buf.position(); + assert pos1 - pos0 == super.headersSize(); + } + + @Override + public String toString() { + var len = length(); + return "%s (partial: %s/%s)".formatted(this.getClass().getSimpleName(), len - remaining, len); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PushPromiseFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PushPromiseFrame.java new file mode 100644 index 00000000000..f95fa7f964d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/PushPromiseFrame.java @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.frames; + +import java.nio.ByteBuffer; + +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * Represents a PUSH_PROMISE HTTP3 frame + */ +public final class PushPromiseFrame extends PartialFrame { + + /** + * The PUSH_PROMISE frame type, as defined by HTTP/3 + */ + public static final int TYPE = Http3FrameType.TYPE.PUSH_PROMISE_FRAME; + + private final long length; + private final long pushId; + + public PushPromiseFrame(final long pushId, final long fieldLength) { + super(TYPE, fieldLength); + if (pushId < 0 || pushId > VariableLengthEncoder.MAX_ENCODED_INTEGER) { + throw new IllegalArgumentException("invalid pushId: " + pushId); + } + this.pushId = pushId; + // the payload length of this frame + this.length = VariableLengthEncoder.getEncodedSize(this.pushId) + fieldLength; + } + + @Override + public long length() { + return this.length; + } + + public long getPushId() { + return this.pushId; + } + + /** + * Write the frame header and the promise {@link #getPushId() + * pushId} to the given buffer. The caller will be responsible + * for writing the remaining {@link #streamingLength()} bytes + * that constitutes the field section length. + * @param buf a buffer to write the headers into + */ + @Override + public void writeHeaders(ByteBuffer buf) { + super.writeHeaders(buf); + VariableLengthEncoder.encode(buf, this.pushId); + } + + /** + * {@return the number of bytes needed to write the headers and + * the promised {@link #getPushId() pushId}}. + */ + @Override + public int headersSize() { + return super.headersSize() + VariableLengthEncoder.getEncodedSize(pushId); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/SettingsFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/SettingsFrame.java new file mode 100644 index 00000000000..90fabd47e68 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/SettingsFrame.java @@ -0,0 +1,364 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.frames; + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import static jdk.internal.net.http.quic.VariableLengthEncoder.MAX_ENCODED_INTEGER; + +/** + * This class models an HTTP/3 SETTINGS frame + */ +public class SettingsFrame extends AbstractHttp3Frame { + + // An array of setting parameters. + // The index is the parameter id, minus 1, the value is the parameter value + private final long[] parameters; + // HTTP/3 specifies some reserved identifier for which the parameter + // has no semantics and the value is undefined and should be ignored. + // It's excepted that at least one such parameter should be included + // in the settings frame to exercise the fact that undefined parameters + // should be ignored + private long undefinedId; + private long undefinedValue; + + /** + * The SETTINGS frame type, as defined by HTTP/3 + */ + public static final int TYPE = Http3FrameType.TYPE.SETTINGS_FRAME; + + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(super.toString()) + .append(" Settings: "); + + for (int i = 0; i < MAX_PARAM; i++) { + if (parameters[i] != -1) { + sb.append(name(i+1)) + .append("=") + .append(parameters[i]) + .append(' '); + } + } + if (undefinedId != -1) { + sb.append(name(undefinedId)).append("=") + .append(undefinedValue).append(' '); + } + return sb.toString(); + } + + // TODO: should we use an enum instead? + // HTTP/2 only Parameters - receiving one of those should be + // considered as a protocol error of type SETTINGS_ERROR + public static final int ENABLE_PUSH = 0x2; + public static final int MAX_CONCURRENT_STREAMS = 0x3; + public static final int INITIAL_WINDOW_SIZE = 0x4; + public static final int MAX_FRAME_SIZE = 0x5; + // HTTP/3 Parameters + // This parameter was defined as HEADER_TABLE_SIZE in HTTP/2 + public static final int SETTINGS_QPACK_MAX_TABLE_CAPACITY = 0x1; + public static final int DEFAULT_SETTINGS_QPACK_MAX_TABLE_CAPACITY = 0; + // This parameter was defined as MAX_HEADER_LIST_SIZE in HTTP/2 + public static final int SETTINGS_MAX_FIELD_SECTION_SIZE = 0x6; + public static final long DEFAULT_SETTINGS_MAX_FIELD_SECTION_SIZE = -1; + // Allow compression efficiency by allowing referencing dynamic table entries + // that are still in transit. This parameter specifies the number of streams + // that could become blocked. + public static final int SETTINGS_QPACK_BLOCKED_STREAMS = 0x7; + public static final int DEFAULT_SETTINGS_QPACK_BLOCKED_STREAMS = 0; + + public static final int MAX_PARAM = 0x7; + + // maps a parameter id to a parameter name + private String name(long i) { + if (i <= MAX_PARAM) { + return switch ((int)i) { + case SETTINGS_QPACK_MAX_TABLE_CAPACITY -> "SETTINGS_QPACK_MAX_TABLE_CAPACITY"; // 0x01 + case ENABLE_PUSH -> "ENABLE_PUSH"; // 0x02 + case MAX_CONCURRENT_STREAMS -> "MAX_CONCURRENT_STREAMS"; // 0x03 + case INITIAL_WINDOW_SIZE -> "INITIAL_WINDOW_SIZE"; // 0x04 + case MAX_FRAME_SIZE -> "MAX_FRAME_SIZE"; // 0x05 + case SETTINGS_MAX_FIELD_SECTION_SIZE -> "SETTINGS_MAX_FIELD_SECTION_SIZE"; // 0x06 + case SETTINGS_QPACK_BLOCKED_STREAMS -> "SETTINGS_QPACK_BLOCKED_STREAMS"; // 0x07 + default -> "UNKNOWN(0x00)"; // 0x00 ? + }; + } else if (isReservedId(i)) { + return "RESERVED(" + i + ")"; + } else { + return "UNKNOWN(" + i +")"; + } + } + + /** + * Creates a new HTTP/3 SETTINGS frame, including the given + * reserved identifier id and value pair. + * + * @implNote + * We only keep one reserved id/value pair - there's no + * reason to keep more... + * + * @param undefinedId the id of an undefined (reserved) parameter + * @param undefinedValue a random value for the undefined parameter + */ + public SettingsFrame(long undefinedId, long undefinedValue) { + super(TYPE); + parameters = new long [MAX_PARAM]; + Arrays.fill(parameters, -1); + assert undefinedId == -1 || isReservedId(undefinedId); + assert undefinedId != -1 || undefinedValue == -1; + this.undefinedId = undefinedId; + this.undefinedValue = undefinedValue; + } + + /** + * Creates a new empty SETTINGS frame, and allocate a random + * reserved id and value pair. + */ + public SettingsFrame() { + this(nextRandomReservedParameterId(), nextRandomParameterValue()); + } + + /** + * Get the parameter value for the given parameter id + * + * @param paramID the parameter id + * + * @return the value of the given parameter, if present, + * {@code -1}, if absent + * + * @throws IllegalArgumentException if the parameter id is negative or + * {@linkplain #isIllegal(long) illegal} + * + */ + public synchronized long getParameter(int paramID) { + if (isIllegal(paramID)) { + throw new IllegalArgumentException("illegal parameter: " + paramID); + } + if (undefinedId != -1 && paramID == undefinedId) + return undefinedValue; + if (paramID > MAX_PARAM) return -1; + return parameters[paramID - 1]; + } + + /** + * Sets the given parameter to the given value. + * + * @param paramID the parameter id + * @param value the parameter value + * + * @return this + * + * @throws IllegalArgumentException if the parameter id is negative or + * {@linkplain #isIllegal(long) illegal} + */ + public synchronized SettingsFrame setParameter(long paramID, long value) { + // subclasses can override this to actually send + // an illegal parameter + if (isIllegal(paramID) || paramID < 1 || paramID > MAX_ENCODED_INTEGER) { + throw new IllegalArgumentException("illegal parameter: " + paramID); + } + if (paramID <= MAX_PARAM) { + parameters[(int)paramID - 1] = value; + } else if (isReservedId(paramID)) { + this.undefinedId = paramID; + this.undefinedValue = value; + } + return this; + } + + @Override + public long length() { + int len = 0; + int i = 0; + for (long p : parameters) { + if (p != -1) { + len += VariableLengthEncoder.getEncodedSize(i+1); + len += VariableLengthEncoder.getEncodedSize(p); + } + } + if (undefinedId != -1) { + assert isReservedId(undefinedId); + len += VariableLengthEncoder.getEncodedSize(undefinedId); + len += VariableLengthEncoder.getEncodedSize(undefinedValue); + } + return len; + } + + /** + * Writes this frame to the given buffer. + * + * @param buf a byte buffer to write this frame into + * + * @throws java.nio.BufferUnderflowException if the buffer + * doesn't have enough space + */ + public void writeFrame(ByteBuffer buf) { + long size = size(); + long len = length(); + int pos0 = buf.position(); + VariableLengthEncoder.encode(buf, TYPE); + VariableLengthEncoder.encode(buf, len); + int pos1 = buf.position(); + for (int i = 0; i < MAX_PARAM; i++) { + if (parameters[i] != -1) { + VariableLengthEncoder.encode(buf, i+1); + VariableLengthEncoder.encode(buf, parameters[i]); + } + } + if (undefinedId != -1) { + // Setting identifiers of the format 0x1f * N + 0x21 for + // non-negative integer values of N are reserved to exercise + // the requirement that unknown identifiers be ignored. + // Such settings have no defined meaning. Endpoints SHOULD + // include at least one such setting in their SETTINGS frame + assert isReservedId(undefinedId); + VariableLengthEncoder.encode(buf, undefinedId); + VariableLengthEncoder.encode(buf, undefinedValue); + } + assert buf.position() - pos1 == len; + assert buf.position() == pos0 + size; + } + + /** + * Decodes a SETTINGS frame from the given reader. + * This method is expected to be called when the reader + * contains enough bytes to decode the frame. + * + * @param reader a reader containing bytes + * + * @return a new SettingsFrame frame, or a MalformedFrame. + * + * @throws BufferUnderflowException if the reader doesn't contain + * enough bytes to decode the frame + */ + public static AbstractHttp3Frame decodeFrame(BuffersReader reader, Logger debug) { + final long pos = reader.position(); + decodeRequiredType(reader, TYPE); + final SettingsFrame frame = new SettingsFrame(-1, -1); + long length = VariableLengthEncoder.decode(reader); + + // is that OK? Find what's the actual limit for + // a frame length... + if (length > reader.remaining()) { + reader.position(pos); + throw new BufferUnderflowException(); + } + + // position before reading payload + long start = reader.position(); + + while (length > reader.position() - start) { + long id = VariableLengthEncoder.decode(reader); + long value = VariableLengthEncoder.decode(reader); + if (id == -1 || value == -1) { + return new MalformedFrame(TYPE, + Http3Error.H3_FRAME_ERROR.code(), + "Invalid SETTINGS frame contents."); + } + try { + frame.setParameter(id, value); + } catch (IllegalArgumentException iae) { + String msg = "H3_SETTINGS_ERROR: " + iae.getMessage(); + if (debug.on()) debug.log(msg, iae); + reader.position(start + length); + reader.release(); + return new MalformedFrame(TYPE, + Http3Error.H3_SETTINGS_ERROR.code(), + iae.getMessage(), + iae); + } + } + + // check position after reading payload + var malformed = checkPayloadSize(TYPE, reader, start, length); + if (malformed != null) return malformed; + + reader.release(); + return frame; + } + + public static SettingsFrame defaultRFCSettings() { + SettingsFrame f = new SettingsFrame() + .setParameter(SETTINGS_MAX_FIELD_SECTION_SIZE, + DEFAULT_SETTINGS_MAX_FIELD_SECTION_SIZE) + .setParameter(SETTINGS_QPACK_MAX_TABLE_CAPACITY, + DEFAULT_SETTINGS_QPACK_MAX_TABLE_CAPACITY) + .setParameter(SETTINGS_QPACK_BLOCKED_STREAMS, + DEFAULT_SETTINGS_QPACK_BLOCKED_STREAMS); + return f; + } + + public boolean isIllegal(long parameterId) { + // Parameters with 0x0, 0x2, 0x3, 0x4 and 0x5 ids are reserved, + // 0x6 is the legal one: + // https://www.rfc-editor.org/rfc/rfc9114.html#name-settings-parameters + // 0x1 and 0x7 defined by QPACK as a legal one: + // https://www.rfc-editor.org/rfc/rfc9204.html#name-configuration + return parameterId < SETTINGS_MAX_FIELD_SECTION_SIZE && + parameterId != SETTINGS_QPACK_MAX_TABLE_CAPACITY; + } + + public static long nextRandomParameterValue() { + long value = RANDOM.nextLong(0, MAX_ENCODED_INTEGER + 1); + assert value >= 0 && value <= MAX_ENCODED_INTEGER; + return value; + } + + private static final long MAX_N = (MAX_ENCODED_INTEGER - 0x21L) / 0x1fL; + public static long nextRandomReservedParameterId() { + long N = RANDOM.nextLong(0, MAX_N + 1); + long id = 0x1fL * N + 0x21L; + assert id <= MAX_ENCODED_INTEGER; + assert id >= 0x21L; + assert isReservedId(id) : "generated id is not undefined: " + id; + return id; + } + + /** + * Tells whether the given id is one of the undefined parameter ids that + * are reserved and have no meaning. + * + * @apiNote + * Setting identifiers of the format 0x1f * N + 0x21 + * for non-negative integer values of N are reserved to + * exercise the requirement that unknown identifiers be + * ignored + * + * @param id the parameter id + * + * @return true if this is one of the reserved identifiers + */ + public static boolean isReservedId(long id) { + return id >= 0x21 && id < MAX_ENCODED_INTEGER && (id - 0x21) % 0x1f == 0; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/UnknownFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/UnknownFrame.java new file mode 100644 index 00000000000..4426493ed7c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/frames/UnknownFrame.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.frames; + +/** + * A class to model an unknown or reserved frame. + * @apiNote + * From RFC 9114: + *
+ * Frame types of the format 0x1f * N + 0x21 for non-negative integer + * values of N are reserved to exercise the requirement that + * unknown types be ignored (Section 9). These frames have no semantics, + * and MAY be sent on any stream where frames are allowed to be sent. + * This enables their use for application-layer padding. Endpoints MUST NOT + * consider these frames to have any meaning upon receipt. + *
+ * + * @apiNote + * An instance of {@code UnknownFrame} is used to read or writes + * the frame's type and length. The payload is supposed to be + * read or written directly to the stream on its own, after having + * read or written the frame type and length. + * @see jdk.internal.net.http.http3.frames.PartialFrame + * */ +public final class UnknownFrame extends PartialFrame { + final long length; + UnknownFrame(long type, long length) { + super(type, length); + this.length = length; + } + + @Override + public long length() { + return length; + } + + /** + * {@return true if this frame type is one of the reserved + * types} + */ + public boolean isReserved() { + return Http3FrameType.isReservedType(type); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/Http3Streams.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/Http3Streams.java new file mode 100644 index 00000000000..8399a4550ff --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/Http3Streams.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.streams; + +import java.util.EnumSet; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStream; + +public final class Http3Streams { + public static final int CONTROL_STREAM_CODE = 0x00; + public static final int PUSH_STREAM_CODE = 0x01; + public static final int QPACK_ENCODER_STREAM_CODE = 0x02; + public static final int QPACK_DECODER_STREAM_CODE = 0x03; + + private Http3Streams() { throw new InternalError(); } + + public enum StreamType { + CONTROL(CONTROL_STREAM_CODE), + PUSH(PUSH_STREAM_CODE), + QPACK_ENCODER(QPACK_ENCODER_STREAM_CODE), + QPACK_DECODER(QPACK_DECODER_STREAM_CODE); + final int code; + StreamType(int code) { + this.code = code; + } + public final int code() { + return code; + } + public static Optional ofCode(long code) { + return EnumSet.allOf(StreamType.class).stream() + .filter(s -> s.code() == code) + .findFirst(); + } + } + + /** + * {@return an optional string that represents the error state of the + * stream, or {@code Optional.empty()} if no error code + * has been received or sent} + * @param stream a quic stream that may have errors + */ + public static Optional errorCodeAsString(QuicStream stream) { + long sndErrorCode = -1; + long rcvErrorCode = -1; + if (stream instanceof QuicReceiverStream rcv) { + rcvErrorCode = rcv.rcvErrorCode(); + } + if (stream instanceof QuicSenderStream snd) { + sndErrorCode = snd.sndErrorCode(); + } + if (rcvErrorCode >= 0 || sndErrorCode >= 0) { + Stream rcv = rcvErrorCode >= 0 + ? Stream.of("RCV: " + Http3Error.stringForCode(rcvErrorCode)) + : Stream.empty(); + Stream snd = sndErrorCode >= 0 + ? Stream.of("SND: " + Http3Error.stringForCode(sndErrorCode)) + : Stream.empty(); + return Optional.of(Stream.concat(rcv, snd) + .collect(Collectors.joining(",", "errorCode(", ")" ))); + } + return Optional.empty(); + } + + /** + * If the stream has errors, prints a message recording the + * {@linkplain #errorCodeAsString(QuicStream) error state} of the + * stream through the given logger. The message is of the form: + * {@code : }. + * If the given {@code name} is null or empty, {@code "Stream"} is substituted + * to {@code }. + * @param logger the logger to log through + * @param stream a quic stream that may have errors + * @param name a name for the stream, e.g {@code "Control stream"}, or {@code null}. + */ + public static void debugErrorCode(Logger logger, QuicStream stream, String name) { + if (logger.on()) { + var errorCodeStr = errorCodeAsString(stream); + if (errorCodeStr.isPresent()) { + var what = (name == null || name.isEmpty()) ? "Stream" : name; + logger.log("%s %s: %s", what, stream.streamId(), errorCodeStr.get()); + } + } + } + + public static boolean isReserved(long streamType) { + return streamType % 31 == 2 && streamType > 31; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/PeerUniStreamDispatcher.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/PeerUniStreamDispatcher.java new file mode 100644 index 00000000000..5db767902f9 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/PeerUniStreamDispatcher.java @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.streams; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.streams.Http3Streams.StreamType; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; + +/** + * A class that analyzes the first byte of the stream to figure + * out where to dispatch it. + */ +public abstract class PeerUniStreamDispatcher { + private final QuicStreamIntReader reader; + private final QuicReceiverStream stream; + private final CompletableFuture cf = new MinimalFuture<>(); + + /** + * Creates a {@code PeerUniStreamDispatcher} for the given stream. + * @param stream a new unidirectional stream opened by the peer + */ + protected PeerUniStreamDispatcher(QuicReceiverStream stream) { + this.reader = new QuicStreamIntReader(checkStream(stream), debug()); + this.stream = stream; + } + + private static QuicReceiverStream checkStream(QuicReceiverStream stream) { + if (!stream.isRemoteInitiated()) { + throw new IllegalArgumentException("stream " + stream.streamId() + " is not peer initiated"); + } + if (stream.isBidirectional()) { + throw new IllegalArgumentException("stream " + stream.streamId() + " is not unidirectional"); + } + return stream; + } + + /** + * {@return a completable future that will contain the dispatched stream, + * once dispatched, or a throwable if dispatching the stream failed} + */ + public CompletableFuture dispatchCF() { + return cf; + } + + // The dispatch function. + private void dispatch(Long result, Throwable error) { + if (result != null && result == Http3Streams.PUSH_STREAM_CODE) { + reader.readInt().whenComplete(this::dispatchPushStream); + return; + } + reader.stop(); + if (result != null) { + cf.complete(stream); + if (Http3Streams.isReserved(result)) { + // reserved stream type, 0x1f * N + 0x21 + reservedStreamType(result, stream); + return; + } + if (result < 0) { + debug().log("stream %s EOF, cannot dispatch!", + stream.streamId()); + abandon(); + } + if (result > Integer.MAX_VALUE) { + unknownStreamType(result, stream); + return; + } + int code = (int)(long)result; + switch (code) { + case Http3Streams.CONTROL_STREAM_CODE -> { + controlStream("peer control stream", StreamType.CONTROL); + } + case Http3Streams.QPACK_ENCODER_STREAM_CODE -> { + qpackEncoderStream("peer qpack encoder stream", StreamType.QPACK_ENCODER); + } + case Http3Streams.QPACK_DECODER_STREAM_CODE -> { + qpackDecoderStream("peer qpack decoder stream", StreamType.QPACK_DECODER); + } + default -> { + unknownStreamType(code, stream); + } + } + } else if (error instanceof IOException io) { + if (stream.receivingState().isReset()) { + debug().log("stream %s %s before stream type received, cannot dispatch!", + stream.streamId(), stream.receivingState()); + // RFC 9114: https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-10 + // > A receiver MUST tolerate unidirectional streams being closed or reset + // > prior to the reception of the unidirectional stream header + cf.complete(stream); + abandon(); + return; + } + abort(io); + } else { + // We shouldn't come here, so if we do, it's closer to an + // internal error than a stream creation error. + abort(error); + } + } + + private void dispatchPushStream(Long result, Throwable error) { + reader.stop(); + if (result != null) { + cf.complete(stream); + if (result < 0) { + debug().log("stream %s EOF, cannot dispatch!", + stream.streamId()); + abandon(); + } else { + pushStream("push stream", StreamType.PUSH, result); + } + } else if (error instanceof IOException io) { + if (stream.receivingState().isReset()) { + debug().log("stream %s %s before push stream ID received, cannot dispatch!", + stream.streamId(), stream.receivingState()); + // RFC 9114: https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-10 + // > A receiver MUST tolerate unidirectional streams being closed or reset + // > prior to the reception of the unidirectional stream header + cf.complete(stream); + abandon(); + return; + } + abort(io); + } else { + // We shouldn't come here, so if we do, it's closer to an + // internal error than a stream creation error. + abort(error); + } + + } + + // dispatches the peer control stream + private void controlStream(String description, StreamType type) { + assert type.code() == Http3Streams.CONTROL_STREAM_CODE; + debug().log("dispatching %s %s(%s)", description, type, type.code()); + onControlStreamCreated(description, stream); + } + + // dispatches the peer encoder stream + private void qpackEncoderStream(String description, StreamType type) { + assert type.code() == Http3Streams.QPACK_ENCODER_STREAM_CODE; + debug().log("dispatching %s %s(%s)", description, type, type.code()); + onEncoderStreamCreated(description, stream); + } + + // dispatches the peer decoder stream + private void qpackDecoderStream(String description, StreamType type) { + assert type.code() == Http3Streams.QPACK_DECODER_STREAM_CODE; + debug().log("dispatching %s %s(%s)", description, type, type.code()); + onDecoderStreamCreated(description, stream); + } + + // dispatches a push stream initiated by the peer + private void pushStream(String description, StreamType type, long pushId) { + assert type.code() == Http3Streams.PUSH_STREAM_CODE; + debug().log("dispatching %s %s(%s, %s)", description, type, type.code(), pushId); + onPushStreamCreated(description, stream, pushId); + } + + // dispatches a stream whose stream type was recognized as a reserved stream type + private void reservedStreamType(long code, QuicReceiverStream stream) { + onReservedStreamType(code, stream); + } + + // dispatches a stream whose stream type was not recognized + private void unknownStreamType(long code, QuicReceiverStream stream) { + onUnknownStreamType(code, stream); + // if an exception is thrown above, abort will be called. + } + + /** + * {@return the debug logger that should be used} + */ + protected abstract Logger debug(); + + /** + * Starts the dispatcher. + * @apiNote + * The dispatcher should be explicitly started after + * creating the dispatcher. + */ + protected void start() { + reader.readInt().whenComplete(this::dispatch); + } + + /** + * This method disconnects the reader, stops the dispatch, and unless + * the stream type could be decoded and was a {@linkplain Http3Streams#isReserved(long) + * reserved type}, calls {@link #onStreamAbandoned(QuicReceiverStream)} + */ + protected void abandon() { + onStreamAbandoned(stream); + } + + /** + * Aborts the dispatch - for instance, if the stream type + * can't be read, or isn't recognized. + *

+ * This method requests the peer to stop sending this stream, + * and completes the {@link #dispatchCF() dispatchCF} exceptionally + * with the provided throwable. + * + * @param throwable the reason for aborting the dispatch + */ + private void abort(Throwable throwable) { + try { + var debug = debug(); + if (debug.on()) debug.log("aborting dispatch: " + throwable, throwable); + if (!stream.receivingState().isReset() && !stream.isStopSendingRequested()) { + stream.requestStopSending(Http3Error.H3_INTERNAL_ERROR.code()); + } + } finally { + abandon(); + cf.completeExceptionally(throwable); + } + } + + /** + * Called when a reserved stream type is read off the + * stream. + * + * @implSpec + * The default implementation of this method calls + * {@snippet : + * stream.requestStopSending(Http3Error.H3_STREAM_CREATION_ERROR.code()); + * } + * + * @param code the unrecognized stream type + * @param stream the peer initiated stream + */ + protected void onReservedStreamType(long code, QuicReceiverStream stream) { + debug().log("Ignoring reserved stream type %s", code); + stream.requestStopSending(Http3Error.H3_STREAM_CREATION_ERROR.code()); + } + + /** + * Called when an unrecognized stream type is read off the + * stream. + * + * @implSpec + * The default implementation of this method calls + * {@snippet : + * stream.requestStopSending(Http3Error.H3_STREAM_CREATION_ERROR.code()); + * abandon(); + * } + * + * @param code the unrecognized stream type + * @param stream the peer initiated stream + */ + protected void onUnknownStreamType(long code, QuicReceiverStream stream) { + debug().log("Ignoring unknown stream type %s", code); + stream.requestStopSending(Http3Error.H3_STREAM_CREATION_ERROR.code()); + abandon(); + } + + /** + * Called after disconnecting to abandon a peer initiated stream. + * @param stream a peer initiated stream which was abandoned due to having an + * unknown type, or which was abandoned due to being reset + * before being dispatched. + * @apiNote + * A subclass may want to override this method in order to, e.g, emit a + * QPack Stream Cancellation instruction; + * See https://www.rfc-editor.org/rfc/rfc9204.html#name-abandonment-of-a-stream + */ + protected void onStreamAbandoned(QuicReceiverStream stream) {} + + /** + * Called after disconnecting to handle the peer control stream. + * The stream type has already been read off the stream. + * @param description a brief description of the stream for logging purposes + * @param stream the peer control stream + */ + protected abstract void onControlStreamCreated(String description, QuicReceiverStream stream); + + /** + * Called after disconnecting to handle the peer encoder stream. + * The stream type has already been read off the stream. + * @param description a brief description of the stream for logging purposes + * @param stream the peer encoder stream + */ + protected abstract void onEncoderStreamCreated(String description, QuicReceiverStream stream); + + /** + * Called after disconnecting to handle the peer decoder stream. + * The stream type has already been read off the stream. + * @param description a brief description of the stream for logging purposes + * @param stream the peer decoder stream + */ + protected abstract void onDecoderStreamCreated(String description, QuicReceiverStream stream); + + /** + * Called after disconnecting to handle a peer initiated push stream. + * The stream type has already been read off the stream. + * @param description a brief description of the stream for logging purposes + * @param stream a peer initiated push stream + */ + protected abstract void onPushStreamCreated(String description, QuicReceiverStream stream, long pushId); + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QueuingStreamPair.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QueuingStreamPair.java new file mode 100644 index 00000000000..bff353a648a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QueuingStreamPair.java @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.streams; + +import java.io.IOException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.http3.streams.Http3Streams.StreamType; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream.SendingStreamState; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; + +/** + * A class that models a pair of unidirectional streams, where + * data to be written is simply submitted to a queue. + */ +public class QueuingStreamPair extends UniStreamPair { + + // a queue of ByteBuffers submitted for writing. + protected final ConcurrentLinkedQueue writerQueue; + + /** + * Creates a new {@code QueuingStreamPair} for the given HTTP/3 {@code streamType}. + * Valid values for {@code streamType} are {@link StreamType#CONTROL}, + * {@link StreamType#QPACK_ENCODER}, and {@link StreamType#QPACK_DECODER}. + *

+ * This class implements a read loop and a write loop. + *

+ * The read loop will call the given {@code receiver} + * whenever a {@code ByteBuffer} is received. + *

+ * Data can be written to the stream simply by {@linkplain + * #submitData(ByteBuffer) submitting} it to the + * internal unbounded queue managed by this stream. + * When the stream becomes writable, the write loop is invoked and all + * pending data in the queue is written to the stream, until the stream + * is blocked or the queue is empty. + * + * @param streamType the HTTP/3 stream type + * @param quicConnection the underlying Quic connection + * @param receiver the receiver callback + * @param errorHandler the error handler invoked in case of read errors + * @param logger the debug logger + */ + public QueuingStreamPair(StreamType streamType, + QuicConnection quicConnection, + Consumer receiver, + StreamErrorHandler errorHandler, + Logger logger) { + // initialize writer queue before the parent constructor starts the writer loop + writerQueue = new ConcurrentLinkedQueue<>(); + super(streamType, quicConnection, receiver, errorHandler, logger); + } + + /** + * {@return the available credit, taking into account data that has + * not been submitted yet} + * This is only weakly consistent. + */ + public long credit() { + var writer = localWriter(); + long credit = (writer == null) ? 0 : writer.credit(); + if (writerQueue.isEmpty()) return credit; + return credit - writerQueue.stream().mapToLong(Buffer::remaining).sum(); + } + + /** + * Submit data to be written to the sending stream via this + * object's internal queue. + * @param buffer the data to submit + */ + public final void submitData(ByteBuffer buffer) { + writerQueue.offer(buffer); + localWriteScheduler().runOrSchedule(); + } + + // The local control stream write loop + @Override + void localWriterLoop() { + var writer = localWriter(); + if (writer == null) return; + assert !(writer instanceof QueuingWriter); + ByteBuffer buffer; + if (debug.on()) + debug.log("start control writing loop: credit=" + writer.credit()); + while (writer.credit() > 0 && (buffer = writerQueue.poll()) != null) { + try { + if (debug.on()) + debug.log("schedule %s bytes for writing on control stream", buffer.remaining()); + writer.scheduleForWriting(buffer, buffer == QuicStreamReader.EOF); + } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to write to control stream", t); + } + errorHandler.onError(writer.stream(), this, t); + } + } + } + + @Override + QuicStreamWriter wrap(QuicStreamWriter writer) { + return new QueuingWriter(writer); + } + + /** + * A class that wraps the actual {@code QuicStreamWriter} + * and redirect everything to the QueuingStreamPair's + * writerQueue - so that data is not sent out of order. + */ + class QueuingWriter extends QuicStreamWriter { + final QuicStreamWriter writer; + QueuingWriter(QuicStreamWriter writer) { + super(QueuingStreamPair.this.localWriteScheduler()); + this.writer = writer; + } + + @Override + public SendingStreamState sendingState() { + return writer.sendingState(); + } + + @Override + public void scheduleForWriting(ByteBuffer buffer, boolean last) throws IOException { + if (!last || buffer.hasRemaining()) submitData(buffer); + if (last) submitData(QuicStreamReader.EOF); + } + + @Override + public void queueForWriting(ByteBuffer buffer) throws IOException { + QueuingStreamPair.this.writerQueue.offer(buffer); + } + + @Override + public long credit() { + return QueuingStreamPair.this.credit(); + } + + @Override + public void reset(long errorCode) throws IOException { + writer.reset(errorCode); + } + + @Override + public QuicSenderStream stream() { + return writer.stream(); + } + + @Override + public boolean connected() { + return writer.connected(); + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QuicStreamIntReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QuicStreamIntReader.java new file mode 100644 index 00000000000..a9101b48ad9 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/QuicStreamIntReader.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.http3.streams; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; + +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; + +/** + * A class that reads VL integers from a QUIC stream. + *

+ * After constructing an instance of this class, the application + * is can call {@link #readInt()} to read one VL integer off the stream. + * When the read operation completes, the application can call {@code readInt} + * again, or call {@link #stop()} to disconnect the reader. + */ +final class QuicStreamIntReader { + private final SequentialScheduler scheduler = SequentialScheduler.lockingScheduler(this::dispatch); + private final QuicReceiverStream stream; + private final QuicStreamReader reader; + private final Logger debug; + private CompletableFuture cf; + private ByteBuffer vlongBuf; // accumulate bytes until stream type can be decoded + + /** + * Creates a {@code QuicStreamIntReader} for the given stream. + * @param stream a receiver stream with no connected reader + * @param debug a logger + */ + public QuicStreamIntReader(QuicReceiverStream stream, Logger debug) { + this.stream = stream; + this.reader = stream.connectReader(scheduler); + this.debug = debug; + debug.log("int reader created for stream " + stream.streamId()); + } + + // The read loop. Attempts to read a VL int, and completes the CF when done. + private void dispatch() { + if (cf == null) return; // not reading anything at the moment + try { + ByteBuffer buffer; + while ((buffer = reader.peek()) != null) { + if (buffer == QuicStreamReader.EOF) { + debug.log("stream %s EOF, cannot complete!", + stream.streamId()); + CompletableFuture cf0; + synchronized (this) { + cf0 = cf; + cf = null; + } + cf0.complete(-1L); + return; + } + if (buffer.remaining() == 0) { + var polled = reader.poll(); + assert buffer == polled; + continue; + } + if (vlongBuf == null) { + long vlong = VariableLengthEncoder.decode(buffer); + if (vlong >= 0) { + // happy case: we have enough bytes in the buffer + if (buffer.remaining() == 0) { + var polled = reader.poll(); + assert buffer == polled; + } + CompletableFuture cf0; + synchronized (this) { + cf0 = cf; + cf = null; + } + cf0.complete(vlong); + return; + } + // we don't have enough bytes: start accumulating them + int vlongSize = VariableLengthEncoder.peekEncodedValueSize(buffer, buffer.position()); + assert vlongSize > 0 && vlongSize <= VariableLengthEncoder.MAX_INTEGER_LENGTH + : vlongSize + " is out of bound for a variable integer size (should be in [1..8]"; + assert buffer.remaining() < vlongSize; + vlongBuf = ByteBuffer.allocate(vlongSize); + vlongBuf.put(buffer); + assert buffer.remaining() == 0; + var polled = reader.poll(); + assert polled == buffer; + // continue and wait for more + } else { + // there wasn't enough bytes the first time around, accumulate + // missing bytes + int missing = vlongBuf.remaining(); + int available = Math.min(missing, buffer.remaining()); + for (int i = 0; i < available; i++) { + vlongBuf.put(buffer.get()); + } + // if we have exhausted the buffer, poll it. + if (!buffer.hasRemaining()) { + var polled = reader.poll(); + assert polled == buffer; + } + // if we have all bytes, we can proceed and decode the stream type + if (!vlongBuf.hasRemaining()) { + vlongBuf.flip(); + long vlong = VariableLengthEncoder.decode(vlongBuf); + assert !vlongBuf.hasRemaining(); + vlongBuf = null; + assert vlong >= 0; + CompletableFuture cf0; + synchronized (this) { + cf0 = cf; + cf = null; + } + cf0.complete(vlong); + return; + } // otherwise, wait for more + } + } + } catch (Throwable throwable) { + CompletableFuture cf0; + synchronized (this) { + cf0 = cf; + cf = null; + } + cf0.completeExceptionally(throwable); + } + } + + /** + * Stops and disconnects this reader. This operation must not be done when a read operation + * is in progress. If cancelling a read operation is intended, use + * {@link QuicReceiverStream#requestStopSending(long)}. + * @throws IllegalStateException if a read operation is currently in progress. + */ + public synchronized void stop() { + if (cf != null) { + // if a read is in progress, some bytes might have been read + // off the stream already, and stopping the reader could corrupt the data. + throw new IllegalStateException("Reading in progress"); + } + if (!reader.connected()) return; + stream.disconnectReader(reader); + scheduler.stop(); + } + + /** + * Starts a read operation to decode a single number. + * @return a {@link CompletableFuture} that will be completed + * with the decoded number, or -1 if the stream is terminated before + * the complete number could be read, or an exception + * if the stream is reset or decoding fails. + * @throws IllegalStateException if the reader is stopped, or if a read + * operation is already in progress + */ + public synchronized CompletableFuture readInt() { + if (cf != null) { + throw new IllegalStateException("Read in progress"); + } + if (!reader.connected()) { + throw new IllegalStateException("Reader stopped"); + } + var cf0 = cf = new MinimalFuture<>(); + reader.start(); + scheduler.runOrSchedule(); + return cf0; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/UniStreamPair.java b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/UniStreamPair.java new file mode 100644 index 00000000000..403b26f244d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/http3/streams/UniStreamPair.java @@ -0,0 +1,505 @@ +/* + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.http3.streams; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; + +import jdk.internal.net.http.Http3Connection; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.streams.Http3Streams.StreamType; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import static jdk.internal.net.http.http3.Http3Error.H3_STREAM_CREATION_ERROR; +import static jdk.internal.net.http.quic.TerminationCause.appLayerClose; + +/** + * A class that models a pair of HTTP/3 unidirectional streams. + * This class implements a read loop that calls a {@link + * #UniStreamPair(StreamType, QuicConnection, Consumer, Runnable, StreamErrorHandler, Logger) + * receiver} every time a {@code ByteBuffer} is read from + * the receiver part. + * The {@linkplain #futureSenderStreamWriter() sender stream writer}, + * when available, can be used to write to the sender part. + * The {@link #UniStreamPair(StreamType, QuicConnection, Consumer, Runnable,StreamErrorHandler, Logger) + * writerLoop} is invoked whenever the writer part becomes unblocked, and + * writing can be resumed. + *

+ * @apiNote + * The creator of the stream pair (typically {@link Http3Connection}) is expected + * to complete the {@link #futureReceiverStream()} completable future when the remote + * part of the stream pair is created by the remote peer. This class will not + * listen directly for creation of new remote streams. + *

+ * The {@link QueuingStreamPair} class is a subclass of this class which + * implements a writer loop over an unbounded queue of {@code ByteBuffer}, and + * can be used when unlimited buffering of data for writing is not an issue. + */ +public class UniStreamPair { + + // The sequential scheduler for the local control stream (LCS) writer loop + private final SequentialScheduler localWriteScheduler; + // The QuicStreamWriter for the local control stream + private volatile QuicStreamWriter localWriter; + private final CompletableFuture streamWriterCF; + // A completable future that will be completed when the local sender + // stream is opened and the stream type has been queued to the + // writer queue. + private volatile CompletableFuture localSenderStreamCF; + + // The sequential scheduler for the peer receiver stream (PRS) reader loop + final SequentialScheduler peerReadScheduler = + SequentialScheduler.lockingScheduler(this::peerReaderLoop); + // The QuicStreamReader for the peer control stream + volatile QuicStreamReader peerReader; + private final CompletableFuture streamReaderCF; + // A completable future that will be completed when the peer opens + // the receiver part of the stream pair + private final CompletableFuture peerReceiverStreamCF = new MinimalFuture<>(); + private final ReentrantLock lock = new ReentrantLock(); + + + private final StreamType localStreamType; // The HTTP/3 stream type of the sender part + private final StreamType remoteStreamType; // The HTTP/3 stream type of the receiver part + private final QuicConnection quicConnection; // the underlying quic connection + private final Consumer receiver; // called when a ByteBuffer is received + final StreamErrorHandler errorHandler; // used by QueuingStreamPair + final Logger debug; // the debug logger + + /** + * Creates a new {@code UniStreamPair} for the given HTTP/3 {@code streamType}. + * Valid values for {@code streamType} are {@link StreamType#CONTROL}, + * {@link StreamType#QPACK_ENCODER}, and {@link StreamType#QPACK_DECODER}. + *

+ * This class implements a read loop that will call the given {@code receiver} + * whenever a {@code ByteBuffer} is received. + *

+ * Writing to the sender part can be done by interacting directly with + * the writer. If the writer is blocked due to flow control, and becomes + * unblocked again, the {@code writeLoop} is invoked. + * The {@link QueuingStreamPair} subclass provides a convenient implementation + * of a {@code writeLoop} based on an unbounded queue of {@code ByteBuffer}. + * + * @param streamType the HTTP/3 stream type + * @param quicConnection the underlying Quic connection + * @param receiver the receiver callback + * @param writerLoop the writer loop + * @param errorHandler the error handler invoked in case of read errors + * @param logger the debug logger + */ + public UniStreamPair(StreamType streamType, + QuicConnection quicConnection, + Consumer receiver, + Runnable writerLoop, + StreamErrorHandler errorHandler, + Logger logger) { + this(local(streamType), remote(streamType), + Objects.requireNonNull(quicConnection), + Objects.requireNonNull(receiver), + Optional.of(writerLoop), + Objects.requireNonNull(errorHandler), + Objects.requireNonNull(logger)); + } + + /** + * A constructor used by the {@link QueuingStreamPair} subclass + * @param streamType the HTTP/3 stream type + * @param quicConnection the underlying Quic connection + * @param receiver the receiver callback + * @param errorHandler the error handler invoked in case + * of read or write errors + * @param logger + */ + UniStreamPair(StreamType streamType, + QuicConnection quicConnection, + Consumer receiver, + StreamErrorHandler errorHandler, + Logger logger) { + this(local(streamType), remote(streamType), + Objects.requireNonNull(quicConnection), + Objects.requireNonNull(receiver), + Optional.empty(), + errorHandler, + Objects.requireNonNull(logger)); + } + + // all constructors delegate here + private UniStreamPair(StreamType localStreamType, + StreamType remoteStreamType, + QuicConnection quicConnection, + Consumer receiver, + Optional writerLoop, + StreamErrorHandler errorHandler, + Logger logger) { + assert this.getClass() != UniStreamPair.class + || writerLoop.isPresent(); + this.debug = logger; + this.localStreamType = localStreamType; + this.remoteStreamType = remoteStreamType; + this.quicConnection = quicConnection; + this.receiver = receiver; + this.errorHandler = errorHandler; + var localWriterLoop = writerLoop.orElse(this::localWriterLoop); + this.localWriteScheduler = + SequentialScheduler.lockingScheduler(localWriterLoop); + this.streamWriterCF = startSending(); + this.streamReaderCF = startReceiving(); + } + + private static StreamType local(StreamType localStreamType) { + return switch (localStreamType) { + case CONTROL -> localStreamType; + case QPACK_ENCODER -> localStreamType; + case QPACK_DECODER -> localStreamType; + default -> throw new IllegalArgumentException(localStreamType + + " cannot be part of a stream pair"); + }; + } + + private static StreamType remote(StreamType localStreamType) { + return switch (localStreamType) { + case CONTROL -> localStreamType; + case QPACK_ENCODER -> StreamType.QPACK_DECODER; + case QPACK_DECODER -> StreamType.QPACK_ENCODER; + default -> throw new IllegalArgumentException(localStreamType + + " cannot be part of a stream pair"); + }; + } + + /** + * {@return the HTTP/3 stream type of the sender part of the stream pair} + */ + public final StreamType localStreamType() { + return localStreamType; + } + + /** + * {@return the HTTP/3 stream type of the receiver part of the stream pair} + */ + public final StreamType remoteStreamType() { + return remoteStreamType; + } + + /** + * {@return a completable future that will be completed with a writer connected + * to the sender part of this stream pair after the local HTTP/3 stream type + * has been queued for writing on the writing queue} + */ + public final CompletableFuture futureSenderStreamWriter() { + return streamWriterCF; + } + + /** + * {@return a completable future that will be completed with a reader connected + * to the receiver part of this stream pair after the remote HTTP/3 stream + * type has been read off the remote initiated stream} + */ + public final CompletableFuture futureReceiverStreamReader() { + return streamReaderCF; + } + + /** + * {@return a completable future that will be completed with the sender part + * of this stream pair after the local HTTP/3 stream type + * has been queued for writing on the writing queue} + */ + public CompletableFuture futureSenderStream() { + return localSenderStream(); + } + + /** + * {@return a completable future that will be completed with the receiver part + * of this stream pair after the remote HTTP/3 stream type has been read off + * the remote initiated stream} + */ + public CompletableFuture futureReceiverStream() { + return peerReceiverStreamCF; + } + + /** + * {@return the scheduler for the local writer loop} + */ + public SequentialScheduler localWriteScheduler() { + return localWriteScheduler; + } + + /** + * {@return the writer connected to the sender part of this stream or + * {@code null} if no writer is connected yet} + */ + public QuicStreamWriter localWriter() {return localWriter; } + + /** + * Stops schedulers. Can be called when the connection is + * closed to stop the reading and writing loops. + */ + public void stopSchedulers() { + peerReadScheduler.stop(); + localWriteScheduler.stop(); + } + + // Hooks for QueuingStreamPair + // ============================ + + /** + * This method is overridden by {@link QueuingStreamPair} to implement + * a writer loop for this stream. It is only called when the concrete + * subclass is {@link QueuingStreamPair}. + */ + void localWriterLoop() { + if (debug.on()) debug.log("writing loop not implemented"); + } + + + /** + * Used by subclasses to redirect queuing of data to the + * subclass queue. + * @param writer the downstream writer + * @return a writer that can be safely used. + */ + QuicStreamWriter wrap(QuicStreamWriter writer) { + return writer; + } + + // Undidirectional Stream Pair Implementation + // ========================================== + + + /** + * This method is called to process bytes received on the peer + * control stream. + * @param buffer the bytes received + */ + private void processPeerControlBytes(ByteBuffer buffer) { + receiver.accept(buffer); + } + + /** + * Creates the local sender stream and queues the stream + * type code in its writer queue. + * @return a completable future that will be completed with the + * local sender stream + */ + private CompletableFuture localSenderStream() { + CompletableFuture lcs = localSenderStreamCF; + if (lcs != null) return lcs; + StreamType type = localStreamType(); + lock.lock(); + try { + if ((lcs = localSenderStreamCF) != null) return lcs; + if (debug.on()) { + debug.log("Opening local stream: %s(%s)", + type, type.code()); + } + // TODO: review this duration + final Duration streamLimitIncreaseDuration = Duration.ZERO; + localSenderStreamCF = lcs = quicConnection + .openNewLocalUniStream(streamLimitIncreaseDuration) + .thenApply( s -> openLocalStream(s, type.code())); + // TODO: use thenApplyAsync with the executor instead + } finally { + lock.unlock(); + } + return lcs; + } + + + /** + * Schedules sending of client settings. + * @return a completable future that will be completed with the + * {@link QuicStreamWriter} allowing to write to the local control + * stream + */ + private CompletableFuture startSending() { + return localSenderStream().thenApply((stream) -> { + if (debug.on()) { + debug.log("stream %s is ready for sending", stream.streamId()); + } + var controlWriter = stream.connectWriter(localWriteScheduler); + localWriter = controlWriter; + localWriteScheduler.runOrSchedule(); + return wrap(controlWriter); + }); + } + + /** + * Schedules the receiving of server settings + * @return a completable future that will be completed with the + * {@link QuicStreamReader} allowing to read from the remote control + * stream. + */ + private CompletableFuture startReceiving() { + if (debug.on()) { + debug.log("prepare to receive"); + } + return peerReceiverStreamCF.thenApply(this::connectReceiverStream); + } + + /** + * Connects the peer control stream reader and + * schedules the receiving of the peer settings from the given + * {@code peerControlStream}. + * @param peerControlStream the peer control stream + * @return the peer control stream reader + */ + private QuicStreamReader connectReceiverStream(QuicReceiverStream peerControlStream) { + var reader = peerControlStream.connectReader(peerReadScheduler); + var streamType = remoteStreamType(); + if (debug.on()) { + debug.log("peer %s stream reader connected (stream %s)", + streamType, peerControlStream.streamId()); + } + peerReader = reader; + reader.start(); + return reader; + } + + // The peer receiver stream reader loop + private void peerReaderLoop() { + var reader = peerReader; + if (reader == null) return; + ByteBuffer buffer; + long bytes = 0; + var streamType = remoteStreamType(); + try { + // TODO: Revisit: if the underlying quic connection is closed + // by the peer, we might get a ClosedChannelException from poll() + // here before the upper layer connection (HTTP/3 connection) is + // marked closed. + if (debug.on()) { + debug.log("start reading from peer %s stream", streamType); + } + while ((buffer = reader.poll()) != null) { + final int remaining = buffer.remaining(); + if (remaining == 0 && buffer != QuicStreamReader.EOF) { + continue; // not yet EOF, so poll more + } + bytes += remaining; + processPeerControlBytes(buffer); + if (buffer == QuicStreamReader.EOF) { + // a EOF was processed, don't poll anymore + break; + } + } + if (debug.on()) { + debug.log("stop reading peer %s stream after %s bytes", + streamType, bytes); + } + } catch (IOException | RuntimeException | Error throwable) { + if (debug.on()) { + debug.log("Reading peer %s stream failed: %s", streamType, throwable); + } + // call the error handler and pass it the stream on which the error happened + errorHandler.onError(reader.stream(), this, throwable); + } + } + + /** + * Queues the given HTTP/3 stream type code on the given local unidirectional + * stream writer queue. + * @param stream a new local unidirectional stream + * @param code the code to queue up on the stream writer queue + * @return the given {@code stream} + */ + private QuicSenderStream openLocalStream(QuicSenderStream stream, int code) { + var streamType = localStreamType(); + if (debug.on()) { + debug.log("Opening local stream: %s %s(code=%s)", + stream.streamId(), streamType, code); + } + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = stream.connectWriter(scheduler); + try { + if (debug.on()) { + debug.log("Writing local stream type: stream %s %s(code=%s)", + stream.streamId(), streamType, code); + } + var buffer = ByteBuffer.allocate(VariableLengthEncoder.getEncodedSize(code)); + VariableLengthEncoder.encode(buffer, code); + buffer.flip(); + writer.queueForWriting(buffer); + scheduler.stop(); + stream.disconnectWriter(writer); + } catch (Throwable t) { + if (debug.on()) { + debug.log("failed to create stream %s %s(code=%s): %s", + stream.streamId(), streamType, code, t); + } + try { + switch (streamType) { + case CONTROL, QPACK_ENCODER, QPACK_DECODER -> { + final String logMsg = "stream %s %s(code=%s)" + .formatted(stream.streamId(), streamType, code); + // TODO: revisit - we should probably invoke a method + // on the HttpQuicConnection or H3Connection instead of + // dealing directly with QuicConnection here. + final TerminationCause terminationCause = + appLayerClose(H3_STREAM_CREATION_ERROR.code()).loggedAs(logMsg); + quicConnection.connectionTerminator().terminate(terminationCause); + } + default -> writer.reset(H3_STREAM_CREATION_ERROR.code()); + } + } catch (Throwable suppressed) { + if (debug.on()) { + debug.log("couldn't close connection or reset stream: " + suppressed); + } + Utils.addSuppressed(t, suppressed); + throw new CompletionException(t); + } + } + return stream; + } + + public static interface StreamErrorHandler { + + /** + * Will be invoked when there is an error on a {@code QuicStream} handled by + * the {@code UniStreamPair} + * + * @param stream the stream on which the error occurred + * @param uniStreamPair the UniStreamPair to which the stream belongs + * @param error the error that occurred + */ + void onError(QuicStream stream, UniStreamPair uniStreamPair, Throwable error); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/Decoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/Decoder.java new file mode 100644 index 00000000000..8487100c8ca --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/Decoder.java @@ -0,0 +1,400 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.http3.streams.UniStreamPair; +import jdk.internal.net.http.qpack.QPACK.QPACKErrorHandler; +import jdk.internal.net.http.qpack.QPACK.StreamPairSupplier; +import jdk.internal.net.http.qpack.readers.EncoderInstructionsReader; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.qpack.writers.DecoderInstructionsWriter; +import jdk.internal.net.http.quic.streams.QuicStreamReader; + +import java.io.IOException; +import java.net.ProtocolException; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static jdk.internal.net.http.http3.Http3Error.H3_CLOSED_CRITICAL_STREAM; +import static jdk.internal.net.http.http3.frames.SettingsFrame.DEFAULT_SETTINGS_MAX_FIELD_SECTION_SIZE; +import static jdk.internal.net.http.qpack.DynamicTable.ENTRY_SIZE; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +/** + * Decodes headers from their binary representation. + * + *

Typical lifecycle looks like this: + * + *

{@link #Decoder(StreamPairSupplier, QPACKErrorHandler) new Decoder} + * ({@link #configure(ConnectionSettings)} called once from our HTTP/3 settings + * {@link #decodeHeader(ByteBuffer, boolean, HeaderFrameReader) decodeHeader} + * + *

{@code Decoder} does not require a complete header block in a single + * {@code ByteBuffer}. The header block can be spread across many buffers of any + * size and decoded one-by-one the way it makes most sense for the user. This + * way also allows not to limit the size of the header block. + * + *

Headers are delivered to the {@linkplain DecodingCallback callback} as + * soon as they become decoded. Using the callback also gives the user freedom + * to decide how headers are processed. The callback does not limit the number + * of headers decoded during single decoding operation. + */ + +public final class Decoder { + + private final QPACK.Logger logger; + private final DynamicTable dynamicTable; + private final EncoderInstructionsReader encoderInstructionsReader; + private final QueuingStreamPair decoderStreamPair; + private static final AtomicLong DECODERS_IDS = new AtomicLong(); + // ID of last acknowledged entry acked by Insert Count Increment + // or section acknowledgement instruction + private long acknowledgedInsertsCount; + private final ReentrantLock ackInsertCountLock = new ReentrantLock(); + private final AtomicLong blockedStreamsCounter = new AtomicLong(); + private volatile long maxBlockedStreams; + private final QPACKErrorHandler qpackErrorHandler; + private volatile long maxFieldSectionSize = DEFAULT_SETTINGS_MAX_FIELD_SECTION_SIZE; + private final AtomicLong concurrentDynamicTableInsertions = + new AtomicLong(); + private static final long MAX_LITERAL_WITH_INDEXING = + Utils.getIntegerNetProperty("jdk.httpclient.maxLiteralWithIndexing", 512); + + /** + * Constructs a {@code Decoder} with zero initial capacity of the dynamic table. + * + *

Dynamic table capacity values has to be agreed between decoder and encoder out-of-band, + * e.g. by a protocol that uses QPACK. + *

Maximum dynamic table capacity is determined by the value of SETTINGS_QPACK_MAX_TABLE_CAPACITY + * HTTP/3 setting sent by the decoder side (see + * + * 3.2.3. Maximum Dynamic Table Capacity). + *

An encoder informs the decoder of a change to the dynamic table capacity using the + * "Set Dynamic Table Capacity" instruction + * (see + * 4.3.1. Set Dynamic Table Capacity) + * + * @see Decoder#configure(ConnectionSettings) + */ + public Decoder(StreamPairSupplier streams, QPACKErrorHandler errorHandler) { + long id = DECODERS_IDS.incrementAndGet(); + logger = QPACK.getLogger().subLogger("Decoder#" + id); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> "New decoder"); + } + dynamicTable = new DynamicTable(logger.subLogger("DynamicTable"), false); + decoderStreamPair = streams.create(this::processEncoderInstruction); + qpackErrorHandler = errorHandler; + encoderInstructionsReader = new EncoderInstructionsReader(new DecoderTableCallback(), logger); + } + + public QueuingStreamPair decoderStreams() { + return decoderStreamPair; + } + + /** + * {@return a new {@link HeaderFrameReader} that will hold the decoding + * state for a new request/response stream} + */ + public HeaderFrameReader newHeaderFrameReader(DecodingCallback decodingCallback) { + return new HeaderFrameReader(dynamicTable, decodingCallback, + blockedStreamsCounter, maxBlockedStreams, + maxFieldSectionSize, logger); + } + + public void ackTableInsertions() { + ackInsertCountLock.lock(); + try { + long insertCount = dynamicTable.insertCount(); + assert acknowledgedInsertsCount <= insertCount; + long incrementValue = insertCount - acknowledgedInsertsCount; + if (incrementValue > 0) { + // Write "Insert Count Increment" to the decoder stream + var decoderInstructionsWriter = new DecoderInstructionsWriter(); + int instructionSize = decoderInstructionsWriter.configureForInsertCountInc(incrementValue); + submitDecoderInstruction(decoderInstructionsWriter, instructionSize); + } + // Update lastAck value + acknowledgedInsertsCount = insertCount; + } finally { + ackInsertCountLock.unlock(); + } + } + + /** + * Submit "Section Acknowledgment" instruction to the decoder stream. + * A field line section needs to be acknowledged after completion of + * section decoding. + * @param streamId stream ID associated with the field section's + * @param headerFrameReader header frame reader used to read + * the field line section + */ + public void ackSection(long streamId, HeaderFrameReader headerFrameReader) { + + FieldSectionPrefix prefix = headerFrameReader.decodedSectionPrefix(); + + // 4.4.1. Section Acknowledgment: If an encoder receives a Section Acknowledgment instruction + // referring to a stream on which every encoded field section with a non-zero Required Insert + // Count has already been acknowledged, this MUST be treated as a connection error of type + // QPACK_DECODER_STREAM_ERROR. + long prefixInsertCount = prefix.requiredInsertCount(); + if (prefixInsertCount == 0) return; + ackInsertCountLock.lock(); + try { + var decoderInstructionsWriter = new DecoderInstructionsWriter(); + int instrSize = decoderInstructionsWriter.configureForSectionAck(streamId); + submitDecoderInstruction(decoderInstructionsWriter, instrSize); + if (prefixInsertCount > acknowledgedInsertsCount) { + acknowledgedInsertsCount = prefixInsertCount; + } + } finally { + ackInsertCountLock.unlock(); + } + } + + public void cancelStream(long streamId) { + var decoderInstructionsWriter = new DecoderInstructionsWriter(); + int instrSize = decoderInstructionsWriter.configureForStreamCancel(streamId); + submitDecoderInstruction(decoderInstructionsWriter, instrSize); + dynamicTable.cleanupStreamInsertCountNotifications(streamId); + } + + /** + * Configures maximum capacity of the decoder's dynamic table based on connection settings of + * the HTTP client, also configures the number of allowed blocked streams. + * The decoder's dynamic table capacity can only be changed via + * {@linkplain EncoderInstructionsReader.Callback encoder instructions callback}. + * + * @param ourSettings connection settings + */ + public void configure(ConnectionSettings ourSettings) { + long maxCapacity = ourSettings.qpackMaxTableCapacity(); + dynamicTable.setMaxTableCapacity(maxCapacity); + maxBlockedStreams = ourSettings.qpackBlockedStreams(); + long maxFieldSS = ourSettings.maxFieldSectionSize(); + if (maxFieldSS > 0) { + maxFieldSectionSize = maxFieldSS; + } else { + // Unlimited field section size + maxFieldSectionSize = -1L; + } + } + + /** + * Decodes a header block from the given buffer to the given callback. + * + *

Suppose a header block is represented by a sequence of + * {@code ByteBuffer}s in the form of {@code Iterator}. And the + * consumer of decoded headers is represented by {@linkplain DecodingCallback the callback} + * registered within the provided {@code headerFrameReader}. + * Then to decode the header block, the following approach might be used: + * {@snippet : + * HeaderFrameReader headerFrameReader = + * newHeaderFrameReader(decodingCallback); + * while (buffers.hasNext()) { + * ByteBuffer input = buffers.next(); + * decoder.decodeHeader(input, !buffers.hasNext(), headerFrameReader); + * } + * } + * + *

The decoder reads as much as possible of the header block from the + * given buffer, starting at the buffer's position, and increments its + * position to reflect the bytes read. The buffer's mark and limit will not + * be modified. + * + *

Once the method is invoked with {@code endOfHeaderBlock == true}, the + * current header block is deemed ended, and inconsistencies, if any, are + * reported immediately via a callback registered within the {@code + * headerFrameReader} instance. + * + *

Each callback method is called only after the implementation has + * processed the corresponding bytes. If the bytes revealed a decoding + * error it is reported via a callback registered within the {@code + * headerFrameReader} instance. + * + * @apiNote The method asks for {@code endOfHeaderBlock} flag instead of + * returning it for two reasons. The first one is that the user of the + * decoder always knows which chunk is the last. The second one is to throw + * the most detailed exception possible, which might be useful for + * diagnosing issues. + * + * @implNote This implementation is not atomic in respect to decoding + * errors. In other words, if the decoding operation has thrown a decoding + * error, the decoder is no longer usable. + * + * @param headerBlock + * the chunk of the header block, may be empty + * @param endOfHeaderBlock + * true if the chunk is the final (or the only one) in the sequence + * @param headerFrameReader the stateful header frame reader + * @throws NullPointerException + * if either {@code headerBlock} or {@code headerFrameReader} are null + */ + public void decodeHeader(ByteBuffer headerBlock, boolean endOfHeaderBlock, + HeaderFrameReader headerFrameReader) { + requireNonNull(headerFrameReader, "headerFrameReader"); + headerFrameReader.read(headerBlock, endOfHeaderBlock); + } + + /** + * This method is invoked when the {@linkplain + * UniStreamPair#futureReceiverStreamReader() decoder's stream reader} + * has data available for reading. + */ + private void processEncoderInstruction(ByteBuffer buffer) { + if (buffer == QuicStreamReader.EOF) { + // RFC-9204, section 4.2: + // Closure of either unidirectional stream type MUST be treated as a connection + // error of type H3_CLOSED_CRITICAL_STREAM. + qpackErrorHandler.closeOnError( + new ProtocolException("QPACK " + decoderStreamPair.remoteStreamType() + + " remote stream was unexpectedly closed"), H3_CLOSED_CRITICAL_STREAM); + return; + } + try { + int stringLengthLimit = Math.clamp(dynamicTable.capacity() - ENTRY_SIZE, + 0, Integer.MAX_VALUE - (int) ENTRY_SIZE); + encoderInstructionsReader.read(buffer, stringLengthLimit); + } catch (QPackException qPackException) { + qpackErrorHandler.closeOnError(qPackException.getCause(), qPackException.http3Error()); + } + } + + private void submitDecoderInstruction(DecoderInstructionsWriter decoderInstructionsWriter, + int size) { + if (size > decoderStreamPair.credit()) { + qpackErrorHandler.closeOnError( + new IOException("QPACK not enough credit on a decoder stream " + + decoderStreamPair.remoteStreamType()), H3_CLOSED_CRITICAL_STREAM); + return; + } + // All decoder instructions contain only one variable length integer. + // Which could take up to 9 bytes max. + ByteBuffer buffer = ByteBuffer.allocate(size); + boolean done = decoderInstructionsWriter.write(buffer); + // Assert that instruction is fully written, ie the correct + // instruction size estimation was supplied. + assert done; + buffer.flip(); + decoderStreamPair.submitData(buffer); + } + + void incrementAndCheckDynamicTableInsertsCount() { + if (MAX_LITERAL_WITH_INDEXING > 0) { + long concurrentNumberOfInserts = concurrentDynamicTableInsertions.incrementAndGet(); + if (concurrentNumberOfInserts > MAX_LITERAL_WITH_INDEXING) { + String exceptionMessage = "Too many literal with indexing: %s > %s" + .formatted(concurrentNumberOfInserts, MAX_LITERAL_WITH_INDEXING); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> exceptionMessage); + } + throw QPackException.encoderStreamError(new ProtocolException(exceptionMessage)); + } + } + } + + public void resetInsertionsCounter() { + if (MAX_LITERAL_WITH_INDEXING > 0) { + concurrentDynamicTableInsertions.set(0); + } + } + + private class DecoderTableCallback implements EncoderInstructionsReader.Callback { + + private void ensureInstructionsAllowed() { + // RFC9204 3.2.3. Maximum Dynamic Table Capacity: + // "When the maximum table capacity is zero, the encoder MUST NOT + // insert entries into the dynamic table and MUST NOT send any encoder + // instructions on the encoder stream." + if (dynamicTable.maxCapacity() == 0) { + throw new IllegalStateException("Unexpected encoder instruction"); + } + } + + @Override + public void onCapacityUpdate(long capacity) { + if (capacity == 0 && dynamicTable.maxCapacity() == 0) { + return; + } + ensureInstructionsAllowed(); + dynamicTable.setCapacity(capacity); + } + + @Override + public void onInsert(String name, String value) { + ensureInstructionsAllowed(); + incrementAndCheckDynamicTableInsertsCount(); + if (dynamicTable.insert(name, value) != DynamicTable.ENTRY_NOT_INSERTED) { + ackTableInsertions(); + } else { + // Not enough evictable space in dynamic table to insert entry + throw new IllegalStateException("Not enough space in dynamic table"); + } + } + + @Override + public void onInsertIndexedName(boolean indexInStaticTable, long nameIndex, String valueString) { + // RFC9204 7.4. Implementation Limits: + // "If an implementation encounters a value larger than it is able to decode, this MUST be + // treated as a stream error of type QPACK_DECOMPRESSION_FAILED if on a request stream or + // a connection error of the appropriate type if on the encoder or decoder stream." + ensureInstructionsAllowed(); + incrementAndCheckDynamicTableInsertsCount(); + if (dynamicTable.insert(nameIndex, indexInStaticTable, valueString) != + DynamicTable.ENTRY_NOT_INSERTED) { + ackTableInsertions(); + } else { + // Not enough space in dynamic table to insert entry + throw new IllegalStateException("Not enough space in dynamic table"); + } + } + + @Override + public void onDuplicate(long l) { + // RFC9204 7.4. Implementation Limits: + // "If an implementation encounters a value larger than it is able to decode, this + // MUST be treated as a stream error of type QPACK_DECOMPRESSION_FAILED" + ensureInstructionsAllowed(); + incrementAndCheckDynamicTableInsertsCount(); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, + () -> format("Processing duplicate instruction (%d)", l)); + } + if (dynamicTable.duplicate(l) != DynamicTable.ENTRY_NOT_INSERTED) { + ackTableInsertions(); + } else { + // Not enough space in dynamic table to duplicate entry + throw new IllegalStateException("Not enough space in dynamic table"); + } + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/DecodingCallback.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/DecodingCallback.java new file mode 100644 index 00000000000..bb9dbdadf59 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/DecodingCallback.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; + +import java.nio.ByteBuffer; + +/** + * Delivers results of the {@link Decoder#decodeHeader(ByteBuffer, boolean, HeaderFrameReader)} + * decoding operation. + * + *

Methods of the callback are never called by a decoder with any of the + * arguments being {@code null}. + * + * @apiNote + * + *

The callback provides methods for all possible + * + * field line representations. + * + *

Names and values are {@link CharSequence}s rather than {@link String}s in + * order to allow users to decide whether they need to create objects. A + * {@code CharSequence} might be used in-place, for example, to be appended to + * an {@link Appendable} (e.g. {@link StringBuilder}) and then discarded. + * + *

That said, if a passed {@code CharSequence} needs to outlast the method + * call, it needs to be copied. + * + */ +public interface DecodingCallback { + + /** + * A method the more specific methods of the callback forward their calls + * to. + * + * @param name + * header name + * @param value + * header value + */ + void onDecoded(CharSequence name, CharSequence value); + + /** + * A header fields decoding is completed. + */ + void onComplete(); + + /** + * A connection-level error observed during the decoding process. + * + * @param throwable a {@code Throwable} instance + * @param http3Error a HTTP3 error code + */ + void onConnectionError(Throwable throwable, Http3Error http3Error); + + /** + * A stream-level error observed during the decoding process. + * + * @param throwable a {@code Throwable} instance + * @param http3Error a HTTP3 error code + */ + default void onStreamError(Throwable throwable, Http3Error http3Error) { + onConnectionError(throwable, http3Error); + } + + /** + * Reports if {@linkplain #onConnectionError(Throwable, Http3Error) a connection} + * or {@linkplain #onStreamError(Throwable, Http3Error) a stream} error has been + * observed during the decoding process + * @return true - if error was observed; false - otherwise + */ + default boolean hasError() { + return false; + } + + /** + * Returns request/response stream id or push stream id associated with a decoding callback. + */ + long streamId(); + + /** + * A more finer-grained version of {@link #onDecoded(CharSequence, + * CharSequence)} that also reports on value sensitivity. + * + *

Value sensitivity must be considered, for example, when implementing + * an intermediary. A {@code value} is sensitive if it was represented as Literal Header + * Field Never Indexed. + * + * @implSpec + * + *

The default implementation invokes {@code onDecoded(name, value)}. + * + * @param name + * header name + * @param value + * header value + * @param sensitive + * whether the value is sensitive + */ + default void onDecoded(CharSequence name, + CharSequence value, + boolean sensitive) { + onDecoded(name, value); + } + + /** + * An Indexed + * Field Line decoded. + * + * @implSpec + * + *

The default implementation invokes + * {@code onDecoded(name, value, false)}. + * + * @param index + * index of a name/value pair in static or dynamic table + * @param name + * header name + * @param value + * header value + */ + default void onIndexed(long index, CharSequence name, CharSequence value) { + onDecoded(name, value, false); + } + + /** + * A Literal + * Field Line with Name Reference decoded, where a {@code name} was + * referred by an {@code index}. + * + * @implSpec + * + *

The default implementation invokes + * {@code onDecoded(name, value, false)}. + * + * @param index + * index of an entry in the table + * @param value + * header value + * @param valueHuffman + * if the {@code value} was Huffman encoded + * @param hideIntermediary + * if the header field should be written to intermediary nodes + */ + default void onLiteralWithNameReference(long index, + CharSequence name, + CharSequence value, + boolean valueHuffman, + boolean hideIntermediary) { + onDecoded(name, value, hideIntermediary); + } + + /** + * A Literal Field + * Line with Literal Name decoded, where both a {@code name} and a {@code value} + * were literal. + * + * @implSpec + * + *

The default implementation invokes + * {@code onDecoded(name, value, false)}. + * + * @param name + * header name + * @param nameHuffman + * if the {@code name} was Huffman encoded + * @param value + * header value + * @param valueHuffman + * if the {@code value} was Huffman encoded + */ + default void onLiteralWithLiteralName(CharSequence name, boolean nameHuffman, + CharSequence value, boolean valueHuffman, + boolean hideIntermediary) { + onDecoded(name, value, hideIntermediary); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/DynamicTable.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/DynamicTable.java new file mode 100644 index 00000000000..6f5e567e182 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/DynamicTable.java @@ -0,0 +1,1069 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.qpack.Encoder.EncodingContext; +import jdk.internal.net.http.qpack.Encoder.SectionReference; +import jdk.internal.net.http.qpack.QPACK.Logger; +import jdk.internal.net.http.qpack.writers.EncoderInstructionsWriter; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Comparator; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Predicate; + +import static java.lang.String.format; +import static jdk.internal.net.http.http3.Http3Error.H3_CLOSED_CRITICAL_STREAM; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; +import static jdk.internal.net.http.qpack.TableEntry.EntryType.NAME; + +/* + * The dynamic table to store header fields. Implements dynamic table described + * in "QPACK: Header Compression for HTTP/3" RFC. + * The size of the table is the sum of the sizes of its entries. + */ +public final class DynamicTable implements HeadersTable { + + // QPACK Section 3.2.1: + // The size of an entry is the sum of its name's length in bytes, + // its value's length in bytes, and 32 additional bytes. + public static final long ENTRY_SIZE = 32L; + + // Initial length of the elements array + // It is required for this value to be a power of 2 integer + private static final int INITIAL_HOLDER_ARRAY_LENGTH = 64; + + final Logger logger; + + // Capacity (Maximum size) in bytes (or capacity in RFC 9204) of the dynamic table + private long capacity; + + // RFC-9204: 3.2.3. Maximum Dynamic Table Capacity + private long maxCapacity; + + // Max entries is required to implement encoding of Required Insert Count + // in Field Lines Prefix: + // RFC-9204: 4.5.1.1. Required Insert Count: + // "This encoding limits the length of the prefix on long-lived connections." + private long maxEntries; + + // Size of the dynamic table in bytes - calculated as the sum of the sizes of its entries. + private long size; + + // Table elements holder and its state variables + // Absolute ID of tail and head elements. + // tail id - is an id of the oldest element in the table + // head id - is an id of the next element that will be added to the table. + // head element id is head - 1. + // drain id - is the lowest element id that encoder can reference + private long tail, head, drain = -1; + + // Used space percentage threshold when to start increasing the drain index + private final int drainUsedSpaceThreshold = QPACK.ENCODER_DRAINING_THRESHOLD; + + // true - table is used by the QPack encoder, otherwise used by the + // QPack decoder + private final boolean encoderTable; + + // Array that holds dynamic table entries + private HeaderField[] elements; + + // name -> (value -> [index]) + private final Map>> indicesMap; + + private record TableInsertCountNotification(long streamId, long minimumRIC, + CompletableFuture completion) { + public boolean isStreamId(long streamId) { + return this.streamId == streamId; + } + public boolean isFulfilled(long insertionCount) { + return insertionCount >= minimumRIC; + } + } + + private final Queue insertCountNotifications = + new PriorityQueue<>( + Comparator.comparingLong(TableInsertCountNotification::minimumRIC) + ); + + public CompletableFuture awaitFutureInsertCount(long streamId, + long valueToAwait) { + if (encoderTable) { + throw new IllegalStateException("Misconfigured table"); + } + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + var completion = new CompletableFuture(); + long insertCount = insertCount(); + if (insertCount >= valueToAwait) { + completion.complete(null); + } else { + insertCountNotifications + .add(new TableInsertCountNotification( + streamId, valueToAwait, completion)); + } + return completion; + } finally { + writeLock.unlock(); + } + } + + private void notifyInsertCountChange() { + assert lock.isWriteLockedByCurrentThread(); + if (insertCountNotifications.isEmpty()) { + return; + } + long insertCount = insertCount(); + Predicate isFulfilled = + icn -> icn.isFulfilled(insertCount); + insertCountNotifications.removeIf(icn -> completeIf(isFulfilled, icn)); + } + + public boolean cleanupStreamInsertCountNotifications(long streamId) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + Predicate isSameStreamId = + icn -> icn.isStreamId(streamId); + return insertCountNotifications.removeIf(icn -> completeIf(isSameStreamId, icn)); + } finally { + writeLock.unlock(); + } + } + + private static boolean completeIf(Predicate predicate, + TableInsertCountNotification insertCountNotification) { + if (predicate.test(insertCountNotification)) { + insertCountNotification.completion.complete(null); + return true; + } + return false; + } + + // Read-Write lock to manage access to table entries + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + + public DynamicTable(Logger logger) { + this(logger, true); + } + + public DynamicTable(Logger logger, boolean encoderTable) { + this.logger = logger; + this.encoderTable = encoderTable; + elements = new HeaderField[INITIAL_HOLDER_ARRAY_LENGTH]; + indicesMap = new HashMap<>(); + // -1 signifies that max table capacity was not yet initialized + maxCapacity = -1L; + maxEntries = 0L; + } + + /** + * Returns size of the dynamic table in bytes + * @return size of the dynamic table + */ + public long size() { + var readLock = lock.readLock(); + readLock.lock(); + try { + return size; + } finally { + readLock.unlock(); + } + } + + /** + * Returns current capacity of the dynamic table + * @return current capacity + */ + public long capacity() { + var readLock = lock.readLock(); + readLock.lock(); + try { + return capacity; + } finally { + readLock.unlock(); + } + } + + /** + * Returns a maximum capacity in bytes of the dynamic table. + * @return maximum capacity + */ + public long maxCapacity() { + var readLock = lock.readLock(); + readLock.lock(); + try { + return maxCapacity; + } finally { + readLock.unlock(); + } + } + + /** + * Sets a maximum capacity in bytes of the dynamic table. + * + *

The value has to be agreed between decoder and encoder out-of-band, + * e.g. by a protocol that uses QPACK + * (see + * 3.2.3 Maximum Dynamic Table Capacity). + * + *

May be called only once to set maximum dynamic table capacity. + *

This method doesn't change the actual capacity of the dynamic table. + * + * @see #setCapacity(long) + * @param maxCapacity a non-negative long + * @throws IllegalArgumentException if max capacity is negative + * @throws IllegalStateException if max capacity was already set + */ + public void setMaxTableCapacity(long maxCapacity) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + if (maxCapacity < 0) { + throw new IllegalArgumentException("maxCapacity >= 0: " + maxCapacity); + } + if (this.maxCapacity != -1L) { + // Max table capacity is initialized from SETTINGS frame which can be only received once: + // "If an endpoint receives a second SETTINGS frame on the control stream, + // the endpoint MUST respond with a connection error of type H3_FRAME_UNEXPECTED" + // [RFC 9114 https://www.rfc-editor.org/rfc/rfc9114.html#name-settings] + throw new IllegalStateException("Max Table Capacity can only be set once"); + } + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("setting maximum allowed dynamic table capacity to %s", + maxCapacity)); + } + this.maxCapacity = maxCapacity; + this.maxEntries = maxCapacity / ENTRY_SIZE; + } finally { + writeLock.unlock(); + } + } + + /** + * Returns maximum possible number of entries that could be stored in the dynamic table + * with respect to MAX_CAPACITY setting. + * @return max entries + */ + public long maxEntries() { + var readLock = lock.readLock(); + readLock.lock(); + try { + return maxEntries; + } finally { + readLock.unlock(); + } + } + + /** + * Retrieves a header field by its absolute index. Entry referenced by an absolute + * index does not depend on the state of the dynamic table. + * @param uniqueID an entry unique index + * @return retrieved header field + * @throws IllegalArgumentException if entry is not received yet, + * already evicted or invalid entry index is specified. + */ + @Override + public HeaderField get(long uniqueID) { + var readLock = lock.readLock(); + readLock.lock(); + try { + if (uniqueID < 0) { + throw new IllegalArgumentException("Entry index invalid"); + } + // Not yet received entry + if (uniqueID >= head) { + throw new IllegalArgumentException("Entry not received yet"); + } + // Already evicted entry + if (uniqueID < tail) { + throw new IllegalArgumentException("Entry already evicted"); + } + return elements[(int) (uniqueID & (elements.length - 1))]; + } finally { + readLock.unlock(); + } + } + + /** + * Retrieves a header field by its relative index. Entry referenced by a relative index depends + * on the state of the dynamic table. + * @param relativeId index relative to the most recently inserted entry + * @return retrieved header field + */ + public HeaderField getRelative(long relativeId) { + // RFC 9204: 3.2.5. Relative Indexing + // "Relative indices begin at zero and increase in the opposite direction from the absolute index. + // Determining which entry has a relative index of 0 depends on the context of the reference. + // In encoder instructions (Section 4.3), a relative index of 0 refers to the most recently inserted + // value in the dynamic table." + var readLock = lock.readLock(); + readLock.lock(); + try { + return get(insertCount() - 1 - relativeId); + } finally { + readLock.unlock(); + } + } + + /** + * Converts absolute entry index to relative index that can be used + * in the encoder instructions. + * Relative index of 0 refers to the most recently inserted entry. + * + * @param absoluteId absolute index of an entry + * @return relative entry index + */ + public long toRelative(long absoluteId) { + var readLock = lock.readLock(); + readLock.lock(); + try { + assert absoluteId < head; + return head - 1 - absoluteId; + } finally { + readLock.unlock(); + } + } + + /** + * Search an absolute id of a name:value pair in the dynamic table. + * @param name a name to search for + * @param value a value to search for + * @return positive index if name:value match found, + * negative index if only name match found, + * 0 if no match found + */ + @Override + public long search(String name, String value) { + // This method is only designated for encoder use + if (!encoderTable) { + return 0; + } + var readLock = lock.readLock(); + readLock.lock(); + try { + Map> values = indicesMap.get(name); + if (values == null) { + return 0; + } + Deque indexes = values.get(value); + if (indexes != null) { + // "+1" since the index range [0..id] is mapped to [1..id+1] + return indexes.peekLast() + 1; + } else { + assert !values.isEmpty(); + Long any = values.values().iterator().next().peekLast(); // Iterator allocation + // Use last entry in found values with matching name, and use its index for + // encoding with name reference. + // Negation and "-1" since name-only matches are mapped from [0..id] to + // [-1..-id-1] region + return -any - 1; + } + } finally { + readLock.unlock(); + } + } + + /** + * Add an entry to the dynamic table. + * Entries could be evicted from the dynamic table. + * Unacknowledged section references are not checked by this method, therefore + * this method is intended to be used by the decoder only. The encoder should use + * overloaded method that takes global unacknowledged section reference. + * + * @param name header name + * @param value header value + * @return unique index of an entry added to the table. + * If element cannot be added {@code -1} is returned. + */ + @Override + public long insert(String name, String value) { + // Invoking toString() will possibly allocate Strings. But that's + // unavoidable at this stage. If a CharSequence is going to be stored in + // the table, it must not be mutable (e.g. for the sake of hashing). + return insert(new HeaderField(name, value), SectionReference.noReferences()); + } + + + /** + * Add entry to the dynamic table with name specified as index in static + * or dynamic table. + * Entries could be evicted from the dynamic table. + * Unacknowledged section references are not checked by this method, therefore + * this method is intended to be used by the decoder only. The encoder should use + * overloaded method that takes global unacknowledged section reference. + * + * @param nameIndex index of the header name to add + * @param isStaticIndex if name index references static table header name + * @param value header value + * @return unique index of an entry added to the table. + * If element cannot be added {@code -1} is returned. + * @throws IllegalStateException if table memory reclamation error observed + */ + public long insert(long nameIndex, boolean isStaticIndex, String value) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Inserting with name index (nameIndex='%s' isStaticIndex='%s' value=%s)", + nameIndex, isStaticIndex, value)); + } + String name = isStaticIndex ? + StaticTable.HTTP3.get(nameIndex).name() : + getRelative(nameIndex).name(); + return insert(name, value); + } finally { + writeLock.unlock(); + } + } + + /** + * Add an entry to the dynamic table. + * Entries could be evicted from the dynamic table. + * The supplied unacknowledged section references are checked by this method to check + * if entries are evictable. + * Such checks are performed when there is not enough space in the dynamic table to insert + * the requested header. + * This method is intended to be used by the encoder only. + * + * @param name header name + * @param value header value + * @param sectionReference unacknowledged section references + * @return unique index of an entry added to the table. + * If element cannot be added {@code -1} is returned. + * @throws IllegalStateException if table memory reclamation error observed + */ + public long insert(String name, String value, SectionReference sectionReference) { + return insert(new HeaderField(name, value), sectionReference); + } + + /** + * Inserts an entry to the dynamic table and sends encoder insert instruction bytes + * to the peer decoder. + * This method is designated to be used by the {@link Encoder} class only. + * If an entry with matching name:value is available, its index is returned + * and no insert instruction is generated on encoder stream. If duplicate entry is required + * due to entry being non-referencable then {@link DynamicTable#duplicateWithEncoderStreamUpdate( + * EncoderInstructionsWriter, long, QueuingStreamPair, EncodingContext)} is used. + * + * @param entry table entry to add + * @param writer non-configured encoder instruction writer for generating encoder + * instruction + * @param encoderStreams encoder stream pair + * @param encodingContext encoder encoding context + * @return absolute id of inserted entry OR already available entry, -1L if entry cannot + * be added + */ + public long insertWithEncoderStreamUpdate(TableEntry entry, + EncoderInstructionsWriter writer, + QueuingStreamPair encoderStreams, + EncodingContext encodingContext) { + if (!encoderTable) { + throw new IllegalStateException("Misconfigured table"); + } + String name = entry.name().toString(); + String value = entry.value().toString(); + // Entry with name only match in dynamic table + boolean nameOnlyDynamicEntry = !entry.isStaticTable() && entry.type() == NAME; + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + // First, check if entry is in the table already - + // no need to add a new one. + long index = search(name, value); + if (index > 0) { + long absIndex = index - 1; + // Check if found entry can be referenced, + // if not issue duplicate instruction + if (!canReferenceEntry(absIndex)) { + return duplicateWithEncoderStreamUpdate(writer, + absIndex, encoderStreams, encodingContext); + } + return absIndex; + } + SectionReference evictionLimitSR = encodingContext.evictionLimit(); + if (nameOnlyDynamicEntry) { + long nameIndex = entry.index(); + if (!canReferenceEntry(nameIndex)) { + return ENTRY_NOT_INSERTED; + } + evictionLimitSR = evictionLimitSR.reduce(nameIndex); + encodingContext.registerSessionReference(nameIndex); + } + // Relative index calculation should precede the insertion + // due to dependency on insert count value + long relativeNameIndex = + nameOnlyDynamicEntry ? toRelative(entry.index()) : -1; + + // Insert new entry to the table with respect to entry + // references range provided by the encoding context + long idx = insert(name, value, evictionLimitSR); + if (idx == ENTRY_NOT_INSERTED) { + // Insertion requires eviction of entries from unacknowledged + // sections therefore entry is not added + return ENTRY_NOT_INSERTED; + } + // Entry was successfully inserted + if (nameOnlyDynamicEntry) { + // Absolute index only needs to be replaced with the relative one + // when it references a name in the dynamic table. + entry = entry.relativizeDynamicTableEntry(relativeNameIndex); + } + int instructionSize = writer.configureForEntryInsertion(entry); + writeEncoderInstruction(writer, instructionSize, encoderStreams); + return idx; + } finally { + writeLock.unlock(); + } + } + + private long insert(HeaderField h, SectionReference sectionReference) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("adding ('%s', '%s')", h.name(), h.value())); + } + long entrySize = headerSize(h); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("size of ('%s', '%s') is %s", h.name(), h.value(), entrySize)); + } + + long availableEvictableSpace = availableEvictableSpace(sectionReference); + if (availableEvictableSpace < entrySize) { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Header size exceeds available evictable space=%s." + + " Combined section reference=%s", + availableEvictableSpace, sectionReference)); + } + // Evicting entries won't help to gather enough space to insert the requested one + return ENTRY_NOT_INSERTED; + } + while (entrySize > capacity - size && size != 0) { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("insufficient space %s, must evict entry", (capacity - size))); + } + // Only Encoder will supply section with referenced + // entries + if (sectionReference.referencesEntries()) { + // Check if tail element is evictable + if (tail < sectionReference.min()) { + if (!evictEntry()) { + return ENTRY_NOT_INSERTED; + } + } else { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Cannot evict entry: sectionRef=%s tail=%s", + sectionReference, tail)); + } + // For now -1 is returned to notify the Encoder that entry + // cannot be inserted to the dynamic table + return ENTRY_NOT_INSERTED; + } + } else { + // This call can be called by both Encoder and Decoder: + // - Encoder when add new entry with no unacked section references + // - Decoder when processing insert entry instructions. + // Entries are evicted until there is enough space OR until table + // is empty. + if (!evictEntry()) { + return ENTRY_NOT_INSERTED; + } + } + } + size += entrySize; + // At this stage it is clear that there are enough bytes (max capacity is not exceeded) in the dynamic + // table to add new header field + addWithInverseMapping(h); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("('%s, '%s') added", h.name(), h.value())); + logger.log(EXTRA, this::toString); + } + notifyInsertCountChange(); + return head - 1; + } finally { + writeLock.unlock(); + } + } + + public long duplicate(long relativeId) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + var entry = getRelative(relativeId); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Duplicate entry with absId=%s" + + " insertCount=%s ('%s', '%s')", + insertCount() - 1 - relativeId, insertCount(), + entry.name(), entry.value())); + } + return insert(entry.name(), entry.value()); + } finally { + writeLock.unlock(); + } + } + + public long duplicateWithEncoderStreamUpdate(EncoderInstructionsWriter writer, + long absoluteEntryId, + QueuingStreamPair encoderStreams, + EncodingContext encodingContext) { + if (!encoderTable) { + throw new IllegalStateException("Misconfigured table"); + } + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + var entry = get(absoluteEntryId); + // Relative index calculation should precede the insertion + // due to dependency on insert count value + long relativeEntryId = toRelative(absoluteEntryId); + + // Make entry id that needs to be duplicated non-evictable + SectionReference evictionLimit = encodingContext.evictionLimit() + .reduce(absoluteEntryId); + + // Put duplicated entry to our dynamic table first + long idx = insert(entry.name(), entry.value(), + evictionLimit); + if (idx == ENTRY_NOT_INSERTED) { + return ENTRY_NOT_INSERTED; + } + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Issuing entry duplication instruction" + + " for absId=%s relId=%s ('%s', '%s')", + absoluteEntryId, relativeEntryId, entry.name(), entry.value())); + } + + // Configure writer for entry duplication + int instructionSize = + writer.configureForEntryDuplication(relativeEntryId); + + // Write instruction to the encoder stream + writeEncoderInstruction(writer, instructionSize, encoderStreams); + return idx; + } finally { + writeLock.unlock(); + } + } + + private HeaderField remove() { + assert lock.isWriteLockedByCurrentThread(); + // Remove element from the holder array first + if (getElementsCount() == 0) { + throw new IllegalStateException("Empty table"); + } + + int tailIdx = (int) (tail++ & (elements.length - 1)); + HeaderField f = elements[tailIdx]; + elements[tailIdx] = null; + + // Log the removal event + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("removing ('%s', '%s')", f.name(), f.value())); + } + + // Update indices map on the encoder table only + if (encoderTable) { + Map> values = indicesMap.get(f.name()); + Deque indexes = values.get(f.value()); + // Remove the oldest index of the name:value pair + Long index = indexes.pollFirst(); + // Clean-up indexes associated with a value from values map + if (indexes.isEmpty()) { + values.remove(f.value()); + } + assert index != null; + // If indexes map associated with name is empty remove name + // entry from indices map + if (values.isEmpty()) { + indicesMap.remove(f.name()); + } + } + return f; + } + + /** + * Sets the dynamic table capacity in bytes. + * The new capacity must be lower than or equal to the limit defined by + * SETTINGS_QPACK_MAX_TABLE_CAPACITY HTTP/3 settings parameter. This limit is + * enforced by {@linkplain DynamicTable#setMaxTableCapacity(long)}. + * + * @param capacity dynamic table capacity to set + */ + public void setCapacity(long capacity) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + if (capacity > maxCapacity) { + // Calling code catches IllegalArgumentException and generates the connection error: + // 4.3.1. Set Dynamic Table Capacity: + // "The decoder MUST treat a new dynamic table capacity value that exceeds + // this limit as a connection error of type QPACK_ENCODER_STREAM_ERROR." + throw new IllegalArgumentException("Illegal dynamic table capacity"); + } + if (capacity < 0) { + throw new IllegalArgumentException("capacity >= 0: capacity=" + capacity); + } + while (capacity < size && size != 0) { + // Evict entries until existing elements fit into + // new table capacity + boolean entryEvicted = evictEntry(); + assert entryEvicted; + } + this.capacity = capacity; + if (usedSpace() < drainUsedSpaceThreshold) { + if (drain != -1) { + drain = -1; + } + } else if (drain == -1 || tail > drain) { + drain = tail; + } + } finally { + writeLock.unlock(); + } + } + + /** + * Updates the capacity of the dynamic table and sends encoder capacity update instruction + * bytes to the peer decoder. + * This method is designated to be used by the {@link Encoder} class only. + * @param writer non-configured encoder instruction writer for generating encoder instruction + * @param capacity new capacity value + * @param encoderStreams encoder stream pair + */ + public void setCapacityWithEncoderStreamUpdate(EncoderInstructionsWriter writer, long capacity, + QueuingStreamPair encoderStreams) { + var writeLock = lock.writeLock(); + writeLock.lock(); + try { + // Configure writer for capacity update + int instructionSize = writer.configureForTableCapacityUpdate(capacity); + // Check and set our capacity + setCapacity(capacity); + // Write instruction + writeEncoderInstruction(writer, instructionSize, encoderStreams); + } finally { + writeLock.unlock(); + } + } + + + /** + * Evicts one entry from the table tail. + * @return {@code true} if entry was evicted, + * {@code false} if nothing to remove + */ + private boolean evictEntry() { + assert lock.isWriteLockedByCurrentThread(); + try { + HeaderField f = remove(); + long s = headerSize(f); + this.size -= s; + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, + () -> format("evicted entry ('%s', '%s') of size %s with absId=%s", + f.name(), f.value(), s, tail - 1)); + } + } catch (IllegalStateException ise) { + // Entry cannot be evicted from empty table + return false; + } + return true; + } + + public long availableEvictableSpace(SectionReference sectionReference) { + var readLock = lock.readLock(); + readLock.lock(); + try { + if (!sectionReference.referencesEntries()) { + return capacity; + } + // Size that can be reclaimed in the dynamic table by evicting + // non-referenced entries + long availableEvictableCapacity = 0; + for (long absId = tail; absId < sectionReference.min(); absId++) { + HeaderField field = get(absId); + availableEvictableCapacity += headerSize(field); + } + // (capacity - size) - free space in the dynamic table + return availableEvictableCapacity + (capacity - size); + } finally { + readLock.unlock(); + } + } + + public boolean tryReferenceEntry(TableEntry tableEntry, EncodingContext context) { + var readLock = lock.readLock(); + readLock.lock(); + try { + long absId = tableEntry.index(); + if (canReferenceEntry(absId)) { + context.registerSessionReference(absId); + context.referenceEntry(tableEntry); + return true; + } else { + return false; + } + } finally { + readLock.unlock(); + } + } + + @Override + public String toString() { + var readLock = lock.readLock(); + readLock.lock(); + try { + double used = usedSpace(); + return format("full length: %s, used space: %s/%s (%.1f%%)", + getElementsCount(), size, capacity, used); + } finally { + readLock.unlock(); + } + } + + private boolean canReferenceEntry(long absId) { + // The dynamic table lock is acquired by the calling methods + return absId > drain; + } + + private double usedSpace() { + return capacity == 0 ? 0 : 100 * (((double) size) / capacity); + } + + public static long headerSize(HeaderField f) { + return headerSize(f.name(), f.value()); + } + + public static long headerSize(String name, String value) { + return name.length() + value.length() + ENTRY_SIZE; + } + + // To quickly find an index of an entry in the dynamic table with the + // given contents an effective inverse mapping is needed. + private void addWithInverseMapping(HeaderField field) { + assert lock.isWriteLockedByCurrentThread(); + // Check if holder array has at least one free slot to add header field + // The method below can increase elements.length if no free slot found + ensureElementsArrayLength(); + long counterSnapshot = head++; + elements[(int) (counterSnapshot & (elements.length - 1))] = field; + if (encoderTable) { + // Allocate unique index and use it to store in indicesMap + Map> values = indicesMap.computeIfAbsent( + field.name(), _ -> new HashMap<>()); + Deque indexes = values.computeIfAbsent( + field.value(), _ -> new LinkedList<>()); + indexes.add(counterSnapshot); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, + () -> format("added '%s' header field with '%s' unique id", + field, counterSnapshot)); + } + assert indexesUniqueAndOrdered(indexes); + // Draining index is only used by the Encoder + updateDrainIndex(); + } + } + + private void updateDrainIndex() { + if (!encoderTable) { + return; + } + assert lock.isWriteLockedByCurrentThread(); + if (usedSpace() > drainUsedSpaceThreshold) { + if (drain == -1L) { + drain = tail; + } else { + drain++; + } + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Draining index changed: %d", drain)); + } + } + } + + private void ensureElementsArrayLength() { + assert lock.isWriteLockedByCurrentThread(); + + int currentArrayLength = elements.length; + if (getElementsCount() == currentArrayLength) { + if (currentArrayLength == (1 << 30)) { + throw new IllegalStateException("No room for more elements"); + } + // Increase elements array by factor of 2 + resize(currentArrayLength << 1); + } + } + + private boolean indexesUniqueAndOrdered(Deque indexes) { + long maxIndexSoFar = -1L; + for (long l : indexes) { + if (l <= maxIndexSoFar) { + return false; + } else { + maxIndexSoFar = l; + } + } + return true; + } + + private int getElementsCount() { + // head and tail are unique and monotonic indexes - therefore we can just use their + // difference to determine number of header:value pairs stored in the dynamic table. + // Since head points to the next unused element head == 0 means that there is + // no elements in the dynamic table + return head > 0 ? (int) (head - tail) : 0; + } + + private void resize(int newSize) { + // newSize is always a power of 2: + // - its initial size is a power of 2 + // - it is shifted 1 bit left every + // time when there is not enough space in the 'elements' array + assert lock.isWriteLockedByCurrentThread(); + int elementsCnt = getElementsCount(); + final int oldSize = elements.length; + + if (newSize < elementsCnt) { + throw new IllegalArgumentException("New size is too low to hold existing elements"); + } + + HeaderField[] newElements = new HeaderField[newSize]; + if (elementsCnt == 0) { + elements = newElements; + return; + } + long headID = head - 1; + final int oldTailIdx = (int) (tail & (oldSize - 1)); + final int oldHeadIdx = (int) (headID & (oldSize - 1)); + final int newTailIdx = (int) (tail & (newSize - 1)); + final int newHeadIdx = (int) (headID & (newSize - 1)); + + if (oldTailIdx <= oldHeadIdx) { + // Elements in an old array are stored in a continuous segment + if (newTailIdx <= newHeadIdx) { + // Elements in a new array will be stored in a continuous segment + System.arraycopy(elements, oldTailIdx, newElements, newTailIdx, elementsCnt); + } else { + // Elements in a new array will split in two segments due to wrapping around + // the end of a new array. + int sizeFromNewTailToEnd = newSize - newTailIdx; + System.arraycopy(elements, oldTailIdx, newElements, newTailIdx, sizeFromNewTailToEnd); + System.arraycopy(elements, oldTailIdx + sizeFromNewTailToEnd, + newElements, 0, newHeadIdx + 1); + } + } else { + // Elements in an old array are split in two segments + if (newTailIdx <= newHeadIdx) { + // Elements in a new array will be stored in a continuous segment + int firstSegmentSize = oldSize - oldTailIdx; + System.arraycopy(elements, oldTailIdx, newElements, newTailIdx, firstSegmentSize); + System.arraycopy(elements, 0, + newElements, newTailIdx + firstSegmentSize, oldHeadIdx + 1); + } else { + // Elements in a new array will be stored in two segments + // Size from the tail to the end in an old array + int oldPart1Size = oldSize - oldTailIdx; + // Size from the tail to the end in a new array + int newPart1Size = newSize - newTailIdx; + if (oldPart1Size <= newPart1Size) { + // Segment from tail to the end of an old array + // fits into the corresponding segment in a new array + System.arraycopy(elements, oldTailIdx, newElements, newTailIdx, oldPart1Size); + int leftToCopyToNewPart1 = newPart1Size - oldPart1Size; + System.arraycopy(elements, 0, newElements, + newTailIdx + oldPart1Size, leftToCopyToNewPart1); + System.arraycopy(elements, leftToCopyToNewPart1, + newElements, 0, newHeadIdx + 1); + } else { // oldPart1Size > newPart1Size + // Not possible given two restrictions: + // - we do not allow rewriting of entries if size is not enough, + // IAE is thrown above. + // - the size of elements holder array can only be a power of 2 + throw new AssertionError("Not possible dynamic table indexes configuration"); + } + } + } + elements = newElements; + } + + /** + * Method returns number of elements inserted to the dynamic table. + * Since element ids start from 0 the returned value is equal + * to the id of the head element plus one. + * @return number of elements in the dynamic table + */ + public long insertCount() { + var rl = lock.readLock(); + rl.lock(); + try { + // head points to the next unallocated element + return head; + } finally { + rl.unlock(); + } + } + + // Writes an encoder instruction to the encoder stream associated with dynamic table. + // This method is kept in DynamicTable class since most instructions depend on + // and/or update the dynamic table state. + // Also, we want to send encoder instruction and update the dynamic table state while + // holding the write-lock. + private void writeEncoderInstruction(EncoderInstructionsWriter writer, int instructionSize, + QueuingStreamPair encoderStreams) { + if (instructionSize > encoderStreams.credit()) { + throw new QPackException(H3_CLOSED_CRITICAL_STREAM, + new IOException("QPACK not enough credit on an encoder stream " + + encoderStreams.remoteStreamType()), true); + } + boolean done; + ByteBuffer buffer; + do { + if (instructionSize > MAX_BUFFER_SIZE) { + buffer = ByteBuffer.allocate(MAX_BUFFER_SIZE); + instructionSize -= MAX_BUFFER_SIZE; + } else { + buffer = ByteBuffer.allocate(instructionSize); + } + done = writer.write(buffer); + buffer.flip(); + encoderStreams.submitData(buffer); + } while (!done); + } + private static final int MAX_BUFFER_SIZE = 1024 * 16; + static final long ENTRY_NOT_INSERTED = -1L; +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/Encoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/Encoder.java new file mode 100644 index 00000000000..26da27a702a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/Encoder.java @@ -0,0 +1,672 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.qpack.QPACK.Logger; +import jdk.internal.net.http.qpack.QPACK.QPACKErrorHandler; +import jdk.internal.net.http.qpack.QPACK.StreamPairSupplier; +import jdk.internal.net.http.qpack.TableEntry.EntryType; +import jdk.internal.net.http.qpack.readers.DecoderInstructionsReader; +import jdk.internal.net.http.qpack.writers.EncoderInstructionsWriter; +import jdk.internal.net.http.qpack.writers.FieldLineSectionPrefixWriter; +import jdk.internal.net.http.qpack.writers.HeaderFrameWriter; +import jdk.internal.net.http.quic.streams.QuicStreamReader; + +import java.net.ProtocolException; +import java.net.http.HttpHeaders; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Predicate; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static jdk.internal.net.http.http3.Http3Error.H3_CLOSED_CRITICAL_STREAM; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +/** + * Encodes headers to their binary representation. + */ +public class Encoder { + private static final AtomicLong ENCODERS_IDS = new AtomicLong(); + + // RFC 9204 7.1.3. Never-Indexed Literals: + // "Implementations can also choose to protect sensitive fields by not + // compressing them and instead encoding their value as literals" + private static final Set SENSITIVE_HEADER_NAMES = + Set.of("cookie", "authorization", "proxy-authorization"); + + private final Logger logger; + private final InsertionPolicy policy; + private final TablesIndexer tablesIndexer; + private final DynamicTable dynamicTable; + private final QueuingStreamPair encoderStreams; + private final DecoderInstructionsReader decoderInstructionsReader; + // RFC-9204: 2.1.4. Known Received Count + private long knownReceivedCount; + + // Lock for Known Received Count variable + private final ReentrantReadWriteLock krcLock = new ReentrantReadWriteLock(); + + // Max blocked streams setting value received from the peer decoder + // can be set only once + private long maxBlockedStreams = -1L; + + // Number of streams in process of headers encoding that expected to be blocked + // but their unacknowledged section is not registered yet + private long blockedStreamsInFlight; + + private final ReentrantLock blockedStreamsCounterLock = new ReentrantLock(); + + // stream id -> fifo list of max and min ids referenced from field sections for each stream id + private final ConcurrentMap> unacknowledgedSections = + new ConcurrentHashMap<>(); + + // stream id -> set of referenced entry absolute indexes from a field line section that currently + // are in process of encoding and not added to the unacknowledged field sections map yet. + private final ConcurrentMap> liveContextReferences = + new ConcurrentHashMap<>(); + + private final QPACKErrorHandler qpackErrorHandler; + + public HeaderFrameWriter newHeaderFrameWriter() { + return new HeaderFrameWriter(logger); + } + + /** + * Constructs an {@code Encoder} with zero initial capacity of the dynamic table. + * Maximum dynamic table capacity is not initialized until peer (decoder) HTTP/3 settings frame is + * received (see {@link Encoder#configure(ConnectionSettings)}). + * + *

Dynamic table capacity values has to be agreed between decoder and encoder out-of-band, + * e.g. by a protocol that uses QPACK. + *

Maximum dynamic table capacity is determined by the value of SETTINGS_QPACK_MAX_TABLE_CAPACITY + * HTTP/3 setting sent by the decoder side (see + * + * 3.2.3. Maximum Dynamic Table Capacity). + *

An encoder informs the decoder of a change to the dynamic table capacity using the + * "Set Dynamic Table Capacity" instruction + * (see + * 4.3.1. Set Dynamic Table Capacity) + * + * @param streamPairs supplier of the encoder unidirectional stream pair + * @throws IllegalArgumentException if maxCapacity is negative + * @see Encoder#configure(ConnectionSettings) + */ + public Encoder(InsertionPolicy policy, StreamPairSupplier streamPairs, QPACKErrorHandler codingError) { + this.policy = policy; + long id = ENCODERS_IDS.incrementAndGet(); + this.logger = QPACK.getLogger().subLogger("Encoder#" + id); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> "New encoder"); + } + if (logger.isLoggable(EXTRA)) { + /* To correlate with logging outside QPACK, knowing + hashCode/toString is important */ + logger.log(EXTRA, () -> { + String hashCode = Integer.toHexString( + System.identityHashCode(this)); + /* Since Encoder can be subclassed hashCode AND identity + hashCode might be different. So let's print both. */ + return format("toString='%s', hashCode=%s, identityHashCode=%s", + this, hashCode(), hashCode); + }); + } + // Set maximum dynamic table to 0, postpone setting of max capacity until peer + // settings frame is received + dynamicTable = new DynamicTable(logger.subLogger("DynamicTable"), true); + tablesIndexer = new TablesIndexer(StaticTable.HTTP3, dynamicTable); + encoderStreams = streamPairs.create(this::processDecoderAcks); + decoderInstructionsReader = new DecoderInstructionsReader(new TableUpdatesCallback(), + logger); + qpackErrorHandler = codingError; + } + + /** + * Configures encoder according to the settings received from the peer. + * + * @param peerSettings the peer settings + */ + public void configure(ConnectionSettings peerSettings) { + blockedStreamsCounterLock.lock(); + try { + if (maxBlockedStreams == -1) { + maxBlockedStreams = peerSettings.qpackBlockedStreams(); + } else { + throw new IllegalStateException("Encoder already configured"); + } + } finally { + blockedStreamsCounterLock.unlock(); + } + // Set max dynamic table capacity + long maxCapacity = peerSettings.qpackMaxTableCapacity(); + dynamicTable.setMaxTableCapacity(maxCapacity); + // Send DT capacity update instruction if the peer negotiated non-zero + // max table capacity, and limit the value with encoder's table capacity + // limit system property value + if (QPACK.ENCODER_TABLE_CAPACITY_LIMIT > 0 && maxCapacity > 0) { + long encoderCapacity = Math.min(maxCapacity, QPACK.ENCODER_TABLE_CAPACITY_LIMIT); + setTableCapacity(encoderCapacity); + } + } + + public QueuingStreamPair encoderStreams() { + return encoderStreams; + } + + public void header(EncodingContext context, CharSequence name, CharSequence value, + boolean sensitive) throws IllegalStateException { + header(context, name, value, sensitive, knownReceivedCount()); + } + + /** + * Sets up the given header {@code (name, value)} with possibly sensitive + * value. + * + *

If the {@code value} is sensitive (think security, secrecy, etc.) + * this encoder will compress it using a special representation + * (see + * 7.1.3. Never-Indexed Literals). + * + *

Fixates {@code name} and {@code value} for the duration of encoding. + * + * @param context the encoding context + * @param name the name + * @param value the value + * @param sensitive whether the value is sensitive + * @param knownReceivedCount the count of received entries known to a peer decoder or + * {@code -1} to skip the dynamic table entry index check during header encoding. + * @throws NullPointerException if any of the arguments are {@code null} + * @throws IllegalStateException if the encoder hasn't fully encoded the previous header, or + * hasn't yet started to encode it + * @see DecodingCallback#onDecoded(CharSequence, CharSequence, boolean) + */ + public void header(EncodingContext context, CharSequence name, CharSequence value, + boolean sensitive, long knownReceivedCount) throws IllegalStateException { + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("encoding ('%s', '%s'), sensitive: %s", + name, value, sensitive)); + } + requireNonNull(name, "name"); + requireNonNull(value, "value"); + + // TablesIndexer.entryOf checks if the found entry is a dynamic table entry, + // and if its insertion was already ACKed. If not - use literal or name index encoding. + var tableEntry = tablesIndexer.entryOf(name, value, knownReceivedCount); + + // NAME_VALUE table entry type means that one of dynamic or static tables contain + // exact name:value pair. + if (dynamicTable.capacity() > 0L + && tableEntry.type() != EntryType.NAME_VALUE + && !sensitive && policy.shouldUpdateDynamicTable(tableEntry)) { + // We should check if we have an entry in dynamic table: + // - If we have it - do nothing + // - if we do not have it - insert it and use the index straight-away + // when blocking encoding is allowed + tableEntry = context.tryInsertEntry(tableEntry); + } + + // First, check that found/newly inserted entry is in the dynamic table + // and can be referenced + if (!tableEntry.isStaticTable() && tableEntry.index() >= 0 && + tableEntry.type() != EntryType.NEITHER) { + if (!dynamicTable.tryReferenceEntry(tableEntry, context)) { + // If entry cannot be referenced - use literal encoding instead + tableEntry = tableEntry.toLiteralsEntry(); + } + } + + // Configure header frame writer to write header field to the headers frame. One of the following + // writers is selected based on entry type, the base value and the referenced table (static or dynamic): + // - static table and name:value match - "Indexed Field Line" + // - static table and name match - "Literal Field Line with Name Reference" + // - dynamic table, name:value match and index < base - "Indexed Field Line" + // - dynamic table, name match and index < base - "Literal Field Line with Name Reference" + // - dynamic table, name:value match and index >= base - "Indexed Field Line with Post-Base Index" + // - dynamic table, name match and index >= base - "Literal Field Line with Post-Base Name Reference" + // - not in dynamic or static tables - "Literal Field Line with Literal Name" + context.writer.configure(tableEntry, sensitive, context.base); + } + + /** + * Sets the capacity of the encoder's dynamic table and notifies the decoder by + * issuing "Set Dynamic Table Capacity" instruction. + * + *

The value has to be agreed between decoder and encoder out-of-band, + * e.g. by a protocol that uses QPACK + * (see + * 4.3.1. Set Dynamic Table Capacity). + * + * @param capacity a non-negative long + * @throws IllegalArgumentException if capacity is negative or exceeds the negotiated max capacity HTTP/3 setting + */ + public void setTableCapacity(long capacity) { + dynamicTable.setCapacityWithEncoderStreamUpdate(new EncoderInstructionsWriter(logger), + capacity, encoderStreams); + } + + /** + * This method is called when the peer decoder sends + * data on the peer's decoder stream + * + * @param buffer data sent by the peer's decoder + */ + private void processDecoderAcks(ByteBuffer buffer) { + if (buffer == QuicStreamReader.EOF) { + // RFC-9204, section 4.2: + // Closure of either unidirectional stream type MUST be treated as a connection + // error of type H3_CLOSED_CRITICAL_STREAM. + qpackErrorHandler.closeOnError( + new ProtocolException("QPACK " + encoderStreams.remoteStreamType() + + " remote stream was unexpectedly closed"), H3_CLOSED_CRITICAL_STREAM); + return; + } + try { + decoderInstructionsReader.read(buffer); + } catch (QPackException e) { + qpackErrorHandler.closeOnError(e.getCause(), e.http3Error()); + } + } + + public List encodeHeaders(HeaderFrameWriter writer, long streamId, + int bufferSize, HttpHeaders... headers) { + List buffers = new ArrayList<>(); + ByteBuffer buffer = getByteBuffer(bufferSize); + + try (EncodingContext encodingContext = newEncodingContext(streamId, + dynamicTable.insertCount(), writer)) { + for (HttpHeaders header : headers) { + for (Map.Entry> e : header.map().entrySet()) { + // RFC-9114, section 4.2: Field names are strings containing a subset of + // ASCII characters. .... Characters in field names MUST be converted to + // lowercase prior to their encoding. + final String lKey = e.getKey().toLowerCase(Locale.ROOT); + final List values = e.getValue(); + // An encoder might also choose not to index values for fields that are + // considered to be highly valuable or sensitive to recovery, such as the + // Cookie or Authorization header fields + final boolean sensitive = SENSITIVE_HEADER_NAMES.contains(lKey); + for (String value : values) { + header(encodingContext, lKey, value, sensitive); + while (!writer.write(buffer)) { + buffer.flip(); + buffers.add(buffer); + buffer = getByteBuffer(bufferSize); + } + } + } + } + buffer.flip(); + buffers.add(buffer); + + // Put field line section prefix as the first byte buffer + generateFieldLineSectionPrefix(encodingContext, buffers); + + // Register field line section as unacked if it uses references to the + // dynamic table entries + registerUnackedFieldLineSection(streamId, SectionReference.of(encodingContext)); + } + return buffers; + } + + public void generateFieldLineSectionPrefix(EncodingContext encodingContext, List buffers) { + // Write field section prefix according to RFC 9204: "4.5.1. Encoded Field Section Prefix" + FieldLineSectionPrefixWriter prefixWriter = new FieldLineSectionPrefixWriter(); + FieldSectionPrefix fsp = encodingContext.sectionPrefix(); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("Encoding Field Section Prefix - required insert" + + " count: %d base: %d", + fsp.requiredInsertCount(), fsp.base())); + } + int requiredSize = prefixWriter.configure(fsp, dynamicTable.maxEntries()); + var fspBuffer = getByteBuffer(requiredSize); + if (!prefixWriter.write(fspBuffer)) { + throw new IllegalStateException("Field Line Section Prefix"); + } + fspBuffer.flip(); + buffers.addFirst(fspBuffer); + } + + public void registerUnackedFieldLineSection(long streamId, SectionReference sectionReference) { + if (sectionReference.referencesEntries()) { + unacknowledgedSections + .computeIfAbsent(streamId, k -> new ConcurrentLinkedQueue<>()) + .add(sectionReference); + } + } + + // This one is for tracking evict-ability of dynamic table entries + public SectionReference unackedFieldLineSectionsRange(EncodingContext context) { + SectionReference referenceNotRegisteredYet = SectionReference.of(context); + return unackedFieldLineSectionsRange(referenceNotRegisteredYet); + } + + private SectionReference unackedFieldLineSectionsRange(SectionReference initial) { + return unacknowledgedSections.values().stream() + .flatMap(Queue::stream) + .reduce(initial, SectionReference::reduce); + } + + long blockedStreamsCount() { + long blockedStreams = 0; + long krc = knownReceivedCount(); + for (var streamSections : unacknowledgedSections.values()) { + boolean hasBlockedSection = streamSections.stream() + .anyMatch(sectionReference -> !sectionReference.fullyAcked(krc)); + blockedStreams = hasBlockedSection ? blockedStreams + 1 : blockedStreams; + } + return blockedStreams; + } + + + public long knownReceivedCount() { + krcLock.readLock().lock(); + try { + return knownReceivedCount; + } finally { + krcLock.readLock().unlock(); + } + } + + private void updateKrcSectionAck(long streamId) { + krcLock.writeLock().lock(); + try { + var queue = unacknowledgedSections.get(streamId); + // max() + 1 - since it is "Required Insert Count" not entry ID + SectionReference oldestSectionRef = queue != null ? queue.poll() : null; + long oldestNonAckedRic = oldestSectionRef != null ? oldestSectionRef.max() + 1 : -1L; + if (oldestNonAckedRic == -1L) { + // RFC 9204 4.4.1. Section Acknowledgment: + // If an encoder receives a Section Acknowledgment instruction referring + // to a stream on which every encoded field section with a non-zero + // Required Insert Count has already been acknowledged, this MUST be treated + // as a connection error of type QPACK_DECODER_STREAM_ERROR. + var qPackException = QPackException.decoderStreamError( + new IllegalStateException("No unacknowledged sections found" + + " for stream id = " + streamId)); + throw qPackException; + } + // "2.1.4. Known Received Count": + // If the Required Insert Count of the acknowledged field section is greater + // than the current Known Received Count, the Known Received Count is updated + // to that Required Insert Count value. + if (oldestNonAckedRic != -1 && knownReceivedCount < oldestNonAckedRic) { + knownReceivedCount = oldestNonAckedRic; + } + } finally { + krcLock.writeLock().unlock(); + } + } + + private void updateKrcInsertCountIncrement(long increment) { + long insertCount = dynamicTable.insertCount(); + krcLock.writeLock().lock(); + try { + // An encoder that receives an Increment field equal to zero, or one that increases + // the Known Received Count beyond what the encoder has sent, MUST treat this as + // a connection error of type QPACK_DECODER_STREAM_ERROR. + if (increment == 0 || knownReceivedCount > insertCount - increment) { + var qpackException = QPackException.decoderStreamError( + new IllegalStateException("Invalid increment field value: " + increment)); + throw qpackException; + } + knownReceivedCount += increment; + } finally { + krcLock.writeLock().unlock(); + } + } + + private void cleanupStreamData(long streamId) { + liveContextReferences.remove(streamId); + unacknowledgedSections.remove(streamId); + } + + private class TableUpdatesCallback implements DecoderInstructionsReader.Callback { + @Override + public void onSectionAck(long streamId) { + updateKrcSectionAck(streamId); + } + + @Override + public void onInsertCountIncrement(long increment) { + updateKrcInsertCountIncrement(increment); + } + + @Override + public void onStreamCancel(long streamId) { + cleanupStreamData(streamId); + } + } + + public class EncodingContext implements AutoCloseable { + final long base; + final long streamId; + final ConcurrentSkipListSet referencedIndexes; + long maxIndex; + long minIndex; + boolean blockedDecoderExpected; + final HeaderFrameWriter writer; + final EncoderInstructionsWriter encoderInstructionsWriter; + + public EncodingContext(long streamId, long base, HeaderFrameWriter writer) { + this.base = base; + this.encoderInstructionsWriter = new EncoderInstructionsWriter(logger); + this.writer = writer; + this.maxIndex = -1L; + this.minIndex = Long.MAX_VALUE; + this.streamId = streamId; + this.referencedIndexes = liveContextReferences.computeIfAbsent(streamId, + _ -> new ConcurrentSkipListSet<>()); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Begin encoding session with base = %s stream-id = %s", base, streamId)); + } + } + + public void registerSessionReference(long absoluteEntryId) { + referencedIndexes.add(absoluteEntryId); + } + + @Override + public void close() { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Closing encoding context for stream-id=%s" + + " session references:%s", + streamId, referencedIndexes)); + } + liveContextReferences.remove(streamId); + // Deregister if this stream was marked as in-flight blocked + blockedStreamsCounterLock.lock(); + try { + if (blockedDecoderExpected) { + blockedStreamsInFlight--; + } + } finally { + blockedStreamsCounterLock.unlock(); + } + } + + public FieldSectionPrefix sectionPrefix() { + // RFC 9204: 2.1.2. Blocked Streams + // "the Required Insert Count is one larger than the largest absolute index + // of all referenced dynamic table entries" + // largestAbsoluteIndex is initialized to -1, and if there is no dynamic + // table entry references - RIC will be set to 0. + return new FieldSectionPrefix(maxIndex + 1, base); + } + + public SectionReference evictionLimit() { + // In-flight references - a set with entry ids referenced from all + // active header encoding sessions not fully encoded yet + SectionReference inFlightReferences = SectionReference.singleReference( + liveContextReferences.values().stream() + .filter(Predicate.not(ConcurrentSkipListSet::isEmpty)) + .map(ConcurrentSkipListSet::first) + .min(Long::compare) + .orElse(-1L)); + + // Calculate the eviction limit with respect to: + // - in-flight references + // - acknowledged dynamic table insertions + // - range of unacknowledged sections which already fully encoded + // and sent as part of other request/response streams + return inFlightReferences + .reduce(knownReceivedCount()) + .reduce(unackedFieldLineSectionsRange(this)); + } + + public TableEntry tryInsertEntry(TableEntry entry) { + long idx = dynamicTable.insertWithEncoderStreamUpdate(entry, + encoderInstructionsWriter, encoderStreams, + this); + if (idx == DynamicTable.ENTRY_NOT_INSERTED) { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Not adding entry '%s' to the dynamic " + + "table - not enough space, or unacknowledged entry needs to be evicted", + entry)); + } + // Return what we previously found in the dynamic or static table + return entry; + } + + if (QPACK.ALLOW_BLOCKING_ENCODING && canReferenceNewEntry()) { + // Create a new TableEntry that describes newly added header field + return entry.toNewDynamicTableEntry(idx); + } else { + return entry; + } + } + + private boolean canReferenceNewEntry() { + blockedStreamsCounterLock.lock(); + try { + // If current encoding context is already marked as blocked we can + // reference new entries without analyzing number of blocked streams + if (blockedDecoderExpected) { + return true; + } + // Number of streams with unacknowledged field line section + long alreadyBlocked = blockedStreamsCount(); + // Other streams might be in progress of headers encoding + boolean canReferenceNewEntry = maxBlockedStreams - alreadyBlocked - blockedStreamsInFlight > 0; + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("%s reference to newly added header. " + + "Number of blocked streams based on unAcked sections: %d " + + "Number of blocked streams in progress of encoding: %d " + + "Max allowed by HTTP/3 settings: %d", + canReferenceNewEntry ? "Allowing" : "Restricting", + alreadyBlocked, blockedStreamsInFlight, maxBlockedStreams)); + } + if (canReferenceNewEntry && !blockedDecoderExpected) { + blockedStreamsInFlight++; + blockedDecoderExpected = true; + } + return canReferenceNewEntry; + } finally { + blockedStreamsCounterLock.unlock(); + } + } + + public void referenceEntry(TableEntry tableEntry) { + assert tableEntry.index() >= 0; + if (!tableEntry.isStaticTable()) { + long index = tableEntry.index(); + maxIndex = Long.max(maxIndex, index); + minIndex = Long.min(minIndex, index); + } + } + } + + /** + * Descriptor of entries range referenced from a field lines section. + * + * @param min minimum entry id referenced from a field lines section + * @param max maximum entry id referenced from a field lines section + */ + public record SectionReference(long min, long max) { + public static SectionReference of(EncodingContext context) { + if (context.maxIndex == -1L) { + return SectionReference.noReferences(); + } + return new SectionReference(context.minIndex, context.maxIndex); + } + + public SectionReference reduce(SectionReference other) { + if (!referencesEntries()) { + return other; + } else if (!other.referencesEntries()) { + return this; + } + long newMin = Long.min(this.min, other.min); + long newMax = Long.max(this.max, other.max); + return new SectionReference(newMin, newMax); + } + + public SectionReference reduce(long entryId) { + return reduce(singleReference(entryId)); + } + + public static SectionReference singleReference(long entryId) { + return new SectionReference(entryId, entryId); + } + + public boolean fullyAcked(long knownReceiveCount) { + return max < knownReceiveCount; + } + + public static SectionReference noReferences() { + return new SectionReference(-1L, -1L); + } + + public boolean referencesEntries() { + return max != -1L; + } + } + + public EncodingContext newEncodingContext(long streamId, long base, HeaderFrameWriter writer) { + assert streamId >= 0; + assert base >= 0; + return new EncodingContext(streamId, base, writer); + } + + private ByteBuffer getByteBuffer(int size) { + ByteBuffer buf = ByteBuffer.allocate(size); + buf.limit(size); + return buf; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/FieldSectionPrefix.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/FieldSectionPrefix.java new file mode 100644 index 00000000000..cbf2037af6a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/FieldSectionPrefix.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack; + +public record FieldSectionPrefix(long requiredInsertCount, long base) { + + public static FieldSectionPrefix decode(long encodedRIC, long deltaBase, + int baseSign, DynamicTable dynamicTable) { + long decodedRIC = decodeRIC(encodedRIC, dynamicTable); + long decodedBase = decodeBase(decodedRIC, deltaBase, baseSign); + return new FieldSectionPrefix(decodedRIC, decodedBase); + } + + private static long decodeRIC(long encodedRIC, DynamicTable dynamicTable) { + if (encodedRIC == 0) { + return 0; + } + long maxEntries = dynamicTable.maxEntries(); + long insertCount = dynamicTable.insertCount(); + long fullRange = 2 * maxEntries; + if (encodedRIC > fullRange) { + throw decompressionFailed(); + } + long maxValue = insertCount + maxEntries; + long maxWrapped = (maxValue/fullRange) * fullRange; + long ric = maxWrapped + encodedRIC - 1; + if (ric > maxValue) { + if (ric <= fullRange) { + throw decompressionFailed(); + } + ric -= fullRange; + } + + if (ric == 0) { + throw decompressionFailed(); + } + return ric; + } + + private static long decodeBase(long decodedRic, long deltaBase, int signBit) { + if (signBit == 0) { + return decodedRic + deltaBase; + } else { + return decodedRic - deltaBase - 1; + } + } + + private static QPackException decompressionFailed() { + var decompressionFailed = new IllegalStateException("QPACK decompression failed"); + return QPackException.decompressionFailed(decompressionFailed, true); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeaderField.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeaderField.java new file mode 100644 index 00000000000..83ee21eda30 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeaderField.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +public record HeaderField(String name, String value) { + + public HeaderField(String name) { + this(name, ""); + } + + @Override + public String toString() { + return value.isEmpty() ? name : name + ":" + value; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeadersTable.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeadersTable.java new file mode 100644 index 00000000000..4b70777401d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/HeadersTable.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack; + +public sealed interface HeadersTable permits StaticTable, DynamicTable { + + /** + * Add an entry to the table. + * + * @param name header name + * @param value header value + * @return unique index of entry added to the table. + * If element cannot be added {@code -1} is returned. + */ + long insert(String name, String value); + + /** + * Get a table entry with specified unique index. + * + * @param index an entry unique index + * @return table entry + */ + HeaderField get(long index); + + /** + * Returns an index for name:value pair, or just name in a headers table. + * The contract for return values is the following: + * - a positive integer {@code i} where {@code i - 1} is an index of an + * entry with a header (n, v), where {@code n.equals(name) && v.equals(value)}. + *

+ * - a negative integer {@code j} where {@code -j - 1} is an index of an entry with + * a header (n, v), where {@code n.equals(name)}. + *

+ * - {@code 0} if there's no entry 'e' found such that {@code e.getName().equals(name)} + * + * @param name a name to search for + * @param value a value to search for + * @return a non-zero value if a matching entry is found, 0 otherwise + */ + long search(String name, String value); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/InsertionPolicy.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/InsertionPolicy.java new file mode 100644 index 00000000000..1bdf84304ca --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/InsertionPolicy.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +public interface InsertionPolicy { + boolean shouldUpdateDynamicTable(TableEntry tableEntry); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPACK.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPACK.java new file mode 100644 index 00000000000..44282e5ad53 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPACK.java @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2017, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.qpack.QPACK.Logger.Level; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.ResourceBundle; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static java.lang.String.format; +import static jdk.internal.net.http.Http3ClientProperties.QPACK_ALLOW_BLOCKING_ENCODING; +import static jdk.internal.net.http.Http3ClientProperties.QPACK_DECODER_BLOCKED_STREAMS; +import static jdk.internal.net.http.Http3ClientProperties.QPACK_DECODER_MAX_FIELD_SECTION_SIZE; +import static jdk.internal.net.http.Http3ClientProperties.QPACK_DECODER_MAX_TABLE_CAPACITY; +import static jdk.internal.net.http.Http3ClientProperties.QPACK_ENCODER_DRAINING_THRESHOLD; +import static jdk.internal.net.http.Http3ClientProperties.QPACK_ENCODER_TABLE_CAPACITY_LIMIT; +import static jdk.internal.net.http.http3.frames.SettingsFrame.SETTINGS_MAX_FIELD_SECTION_SIZE; +import static jdk.internal.net.http.http3.frames.SettingsFrame.SETTINGS_QPACK_BLOCKED_STREAMS; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NONE; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +/** + * Internal utilities and stuff. + */ +public final class QPACK { + + + // A dynamic table capacity that the encoder is allowed to set given that it doesn't + // exceed the max capacity value negotiated by the decoder. If the max capacity + // less than this limit the encoder's dynamic table capacity is set to the max capacity + // value. + public static final long ENCODER_TABLE_CAPACITY_LIMIT = QPACK_ENCODER_TABLE_CAPACITY_LIMIT; + + // The value of SETTINGS_QPACK_MAX_TABLE_CAPACITY HTTP/3 setting that is + // negotiated by HTTP client's decoder + public static final long DECODER_MAX_TABLE_CAPACITY = QPACK_DECODER_MAX_TABLE_CAPACITY; + + // The value of SETTINGS_MAX_FIELD_SECTION_SIZE HTTP/3 setting that is + // negotiated by HTTP client's decoder + public static final long DECODER_MAX_FIELD_SECTION_SIZE = QPACK_DECODER_MAX_FIELD_SECTION_SIZE; + + // Decoder upper bound on the number of streams that can be blocked + public static final long DECODER_BLOCKED_STREAMS = QPACK_DECODER_BLOCKED_STREAMS; + + // If set to "true" allows the encoder to insert a header with a dynamic + // name reference and reference it in a field line section without awaiting + // decoder's acknowledgement. + public static final boolean ALLOW_BLOCKING_ENCODING = QPACK_ALLOW_BLOCKING_ENCODING; + + // Threshold of available dynamic table space after which the draining + // index starts increasing. This index determines which entries are + // too close to eviction, and can be referenced by the encoder + public static final int ENCODER_DRAINING_THRESHOLD = QPACK_ENCODER_DRAINING_THRESHOLD; + + private static final RootLogger LOGGER; + private static final Map logLevels = + Map.of("NORMAL", NORMAL, "EXTRA", EXTRA); + + static { + String PROPERTY = "jdk.internal.httpclient.qpack.log.level"; + String value = Utils.getProperty(PROPERTY); + + if (value == null) { + LOGGER = new RootLogger(NONE); + } else { + String upperCasedValue = value.toUpperCase(); + Level l = logLevels.get(upperCasedValue); + if (l == null) { + LOGGER = new RootLogger(NONE); + LOGGER.log(System.Logger.Level.INFO, + () -> format("%s value '%s' not recognized (use %s); logging disabled", + PROPERTY, value, String.join(", ", logLevels.keySet()))); + } else { + LOGGER = new RootLogger(l); + LOGGER.log(System.Logger.Level.DEBUG, + () -> format("logging level %s", l)); + } + } + } + + public static Logger getLogger() { + return LOGGER; + } + + public static SettingsFrame updateDecoderSettings(SettingsFrame defaultSettingsFrame) { + SettingsFrame settingsFrame = defaultSettingsFrame; + settingsFrame.setParameter(SETTINGS_QPACK_BLOCKED_STREAMS, DECODER_BLOCKED_STREAMS); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, DECODER_MAX_TABLE_CAPACITY); + settingsFrame.setParameter(SETTINGS_MAX_FIELD_SECTION_SIZE, DECODER_MAX_FIELD_SECTION_SIZE); + return settingsFrame; + } + + private QPACK() { } + + /** + * The purpose of this logger is to provide means of diagnosing issues _in + * the QPACK implementation_. It's not a general purpose logger. + */ + // implements System.Logger to make it possible to skip this class + // when looking for the Caller. + public static class Logger implements System.Logger { + + /** + * Log detail level. + */ + public enum Level { + + NONE(0, System.Logger.Level.OFF), + NORMAL(1, System.Logger.Level.DEBUG), + EXTRA(2, System.Logger.Level.TRACE); + + private final int level; + final System.Logger.Level systemLevel; + + Level(int i, System.Logger.Level system) { + level = i; + systemLevel = system; + } + + public final boolean implies(Level other) { + return this.level >= other.level; + } + } + + private final String name; + private final Level level; + private final String path; + private final System.Logger logger; + + private Logger(String path, String name, Level level) { + this.path = path; + this.name = name; + this.level = level; + this.logger = Utils.getHpackLogger(path::toString, level.systemLevel); + } + + public final String getName() { + return name; + } + + @Override + public boolean isLoggable(System.Logger.Level level) { + return logger.isLoggable(level); + } + + @Override + public void log(System.Logger.Level level, ResourceBundle bundle, String msg, Throwable thrown) { + logger.log(level, bundle, msg,thrown); + } + + @Override + public void log(System.Logger.Level level, ResourceBundle bundle, String format, Object... params) { + logger.log(level, bundle, format, params); + } + + /* + * Usual performance trick for logging, reducing performance overhead in + * the case where logging with the specified level is a NOP. + */ + + public boolean isLoggable(Level level) { + return this.level.implies(level); + } + + public void log(Level level, Supplier s) { + if (this.level.implies(level)) { + logger.log(level.systemLevel, s); + } + } + + public Logger subLogger(String name) { + return new Logger(path + "/" + name, name, level); + } + + } + + private static final class RootLogger extends Logger { + + protected RootLogger(Level level) { + super("qpack", "qpack", level); + } + + } + + // -- low-level utilities -- + + /** + * An interface used to obtain the encoder or decoder stream pair + * from the enclosing HTTP/3 connection. + */ + @FunctionalInterface + public interface StreamPairSupplier { + QueuingStreamPair create(Consumer receiver); + } + + public interface QPACKErrorHandler { + void closeOnError(Throwable throwable, Http3Error error); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPackException.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPackException.java new file mode 100644 index 00000000000..6252a1b9549 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/QPackException.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.http3.Http3Error; + +/** + * Represents a QPack related failure as a failure cause and + * an HTTP/3 error code. + */ +public final class QPackException extends RuntimeException { + + @java.io.Serial + private static final long serialVersionUID = 8443631555257118370L; + + private final boolean isConnectionError; + private final Http3Error http3Error; + + public QPackException(Http3Error http3Error, Throwable cause, boolean isConnectionError) { + super(cause); + this.isConnectionError = isConnectionError; + this.http3Error = http3Error; + } + + public static QPackException encoderStreamError(Throwable cause) { + throw new QPackException(Http3Error.QPACK_ENCODER_STREAM_ERROR, cause, true); + } + + public static QPackException decoderStreamError(Throwable cause) { + throw new QPackException(Http3Error.QPACK_DECODER_STREAM_ERROR, cause, true); + } + + public static QPackException decompressionFailed(Throwable cause, boolean isConnectionError) { + throw new QPackException(Http3Error.QPACK_DECOMPRESSION_FAILED, cause, isConnectionError); + } + + + public Http3Error http3Error() { + return http3Error; + } + + public boolean isConnectionError() { + return isConnectionError; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/StaticTable.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/StaticTable.java new file mode 100644 index 00000000000..52900a2d1c3 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/StaticTable.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/* + * A header table with most common header fields. + * This table was generated by analyzing actual Internet traffic in 2018 + * and is a part of "QPACK: Header Compression for HTTP/3" RFC. + */ +public final class StaticTable implements HeadersTable { + + /* An immutable list of static header fields */ + public static final List HTTP3_HEADER_FIELDS = List.of( + new HeaderField(":authority"), + new HeaderField(":path", "/"), + new HeaderField("age", "0"), + new HeaderField("content-disposition"), + new HeaderField("content-length", "0"), + new HeaderField("cookie"), + new HeaderField("date"), + new HeaderField("etag"), + new HeaderField("if-modified-since"), + new HeaderField("if-none-match"), + new HeaderField("last-modified"), + new HeaderField("link"), + new HeaderField("location"), + new HeaderField("referer"), + new HeaderField("set-cookie"), + new HeaderField(":method", "CONNECT"), + new HeaderField(":method", "DELETE"), + new HeaderField(":method", "GET"), + new HeaderField(":method", "HEAD"), + new HeaderField(":method", "OPTIONS"), + new HeaderField(":method", "POST"), + new HeaderField(":method", "PUT"), + new HeaderField(":scheme", "http"), + new HeaderField(":scheme", "https"), + new HeaderField(":status", "103"), + new HeaderField(":status", "200"), + new HeaderField(":status", "304"), + new HeaderField(":status", "404"), + new HeaderField(":status", "503"), + new HeaderField("accept", "*/*"), + new HeaderField("accept", "application/dns-message"), + new HeaderField("accept-encoding", "gzip, deflate, br"), + new HeaderField("accept-ranges", "bytes"), + new HeaderField("access-control-allow-headers", "cache-control"), + new HeaderField("access-control-allow-headers", "content-type"), + new HeaderField("access-control-allow-origin", "*"), + new HeaderField("cache-control", "max-age=0"), + new HeaderField("cache-control", "max-age=2592000"), + new HeaderField("cache-control", "max-age=604800"), + new HeaderField("cache-control", "no-cache"), + new HeaderField("cache-control", "no-store"), + new HeaderField("cache-control", "public, max-age=31536000"), + new HeaderField("content-encoding", "br"), + new HeaderField("content-encoding", "gzip"), + new HeaderField("content-type", "application/dns-message"), + new HeaderField("content-type", "application/javascript"), + new HeaderField("content-type", "application/json"), + new HeaderField("content-type", "application/x-www-form-urlencoded"), + new HeaderField("content-type", "image/gif"), + new HeaderField("content-type", "image/jpeg"), + new HeaderField("content-type", "image/png"), + new HeaderField("content-type", "text/css"), + new HeaderField("content-type", "text/html; charset=utf-8"), + new HeaderField("content-type", "text/plain"), + new HeaderField("content-type", "text/plain;charset=utf-8"), + new HeaderField("range", "bytes=0-"), + new HeaderField("strict-transport-security", "max-age=31536000"), + new HeaderField("strict-transport-security", "max-age=31536000; includesubdomains"), + new HeaderField("strict-transport-security", "max-age=31536000; includesubdomains; preload"), + new HeaderField("vary", "accept-encoding"), + new HeaderField("vary", "origin"), + new HeaderField("x-content-type-options", "nosniff"), + new HeaderField("x-xss-protection", "1; mode=block"), + new HeaderField(":status", "100"), + new HeaderField(":status", "204"), + new HeaderField(":status", "206"), + new HeaderField(":status", "302"), + new HeaderField(":status", "400"), + new HeaderField(":status", "403"), + new HeaderField(":status", "421"), + new HeaderField(":status", "425"), + new HeaderField(":status", "500"), + new HeaderField("accept-language"), + new HeaderField("access-control-allow-credentials", "FALSE"), + new HeaderField("access-control-allow-credentials", "TRUE"), + new HeaderField("access-control-allow-headers", "*"), + new HeaderField("access-control-allow-methods", "get"), + new HeaderField("access-control-allow-methods", "get, post, options"), + new HeaderField("access-control-allow-methods", "options"), + new HeaderField("access-control-expose-headers", "content-length"), + new HeaderField("access-control-request-headers", "content-type"), + new HeaderField("access-control-request-method", "get"), + new HeaderField("access-control-request-method", "post"), + new HeaderField("alt-svc", "clear"), + new HeaderField("authorization"), + new HeaderField("content-security-policy", "script-src 'none'; object-src 'none'; base-uri 'none'"), + new HeaderField("early-data", "1"), + new HeaderField("expect-ct"), + new HeaderField("forwarded"), + new HeaderField("if-range"), + new HeaderField("origin"), + new HeaderField("purpose", "prefetch"), + new HeaderField("server"), + new HeaderField("timing-allow-origin", "*"), + new HeaderField("upgrade-insecure-requests", "1"), + new HeaderField("user-agent"), + new HeaderField("x-forwarded-for"), + new HeaderField("x-frame-options", "deny"), + new HeaderField("x-frame-options", "sameorigin") + ); + + public static final StaticTable HTTP3 = new StaticTable(HTTP3_HEADER_FIELDS); + + private final List headerFields; + private final Map> indicesMap; + + private StaticTable(List headerFields) { + this.headerFields = headerFields; + this.indicesMap = buildIndicesMap(headerFields); + } + + @Override + public HeaderField get(long index) { + if (index >= headerFields.size()) { + throw new IllegalArgumentException("Invalid static table entry index"); + } + return headerFields.get((int)index); + } + + @Override + public long insert(String name, String value) { + throw new UnsupportedOperationException("Operation not supported by static tables"); + } + + @Override + public long search(String name, String value) { + Map values = indicesMap.get(name); + // 0 return value if no match is found in the static table + int searchResult = 0; + if (values != null) { + Integer idx = values.get(value); + if (idx != null) { + searchResult = idx + 1; + } else { + // Only name is found - return first id from indices for the name provided + searchResult = -values.values().iterator().next() - 1; + } + } + return searchResult; + } + + private static Map> buildIndicesMap(List fields) { + int numEntries = fields.size(); + Map> map = new HashMap<>(numEntries); + for (int i = 0; i < numEntries; i++) { + HeaderField f = fields.get(i); + Map values = map.computeIfAbsent(f.name(), _ -> new HashMap<>()); + values.put(f.value(), i); + } + return map; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/TableEntry.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/TableEntry.java new file mode 100644 index 00000000000..6603c1d4c44 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/TableEntry.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack; + +import jdk.internal.net.http.hpack.QuickHuffman; + +// Record containing table information for entry +public record TableEntry(boolean isStaticTable, long index, CharSequence name, CharSequence value, + EntryType type, boolean huffmanName, boolean huffmanValue) { + + public TableEntry(boolean isStaticTable, long index, CharSequence name, CharSequence value, EntryType type) { + this(isStaticTable, index, name, value, type, + isHuffmanBetterFor(name, true, type), + isHuffmanBetterFor(value, false, type)); + } + + public TableEntry toNewDynamicTableEntry(long index) { + return new TableEntry(false, index, name, value, EntryType.NAME_VALUE); + } + + public TableEntry relativizeDynamicTableEntry(long relativeIndex) { + assert !isStaticTable; + assert relativeIndex >= 0; + return new TableEntry(false, relativeIndex, name, value, type); + } + + public TableEntry(CharSequence name, CharSequence value) { + this(false, -1L, name, value, EntryType.NEITHER, + isHuffmanBetterFor(name, true, EntryType.NEITHER), + isHuffmanBetterFor(value, false, EntryType.NEITHER)); + } + + public TableEntry toLiteralsEntry() { + return new TableEntry(name, value); + } + + /** + * EntryType describes the type of TableEntry as either: + *

+ * - NAME_VALUE: a table entry where both name and value exist in table + * - NAME: a table entry where only name is present in table + * - NEITHER: a table entry where neither name nor value have been found + */ + public enum EntryType {NAME_VALUE, NAME, NEITHER} + + static boolean isHuffmanBetterFor(CharSequence str, boolean isName, EntryType type) { + return switch (type) { + case NEITHER -> QuickHuffman.isHuffmanBetterFor(str); + case NAME_VALUE -> false; + case NAME -> !isName && QuickHuffman.isHuffmanBetterFor(str); + }; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/TablesIndexer.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/TablesIndexer.java new file mode 100644 index 00000000000..5afe25d9781 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/TablesIndexer.java @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack; + +import static jdk.internal.net.http.qpack.TableEntry.EntryType.NAME; +import static jdk.internal.net.http.qpack.TableEntry.EntryType.NAME_VALUE; + +/* + * Adds reverse lookup to dynamic and static tables. + * Decoder does not need this functionality. On the other hand, + * Encoder does. + */ +public final class TablesIndexer { + + private final DynamicTable dynamicTable; + private final StaticTable staticTable; + + public TablesIndexer(StaticTable staticTable, DynamicTable dynamicTable) { + this.dynamicTable = dynamicTable; + this.staticTable = staticTable; + } + + /** + * Searches in dynamic and static tables for an entry that has matching name + * or name:value. + * Found dynamic table entry ids are matched against provided + * known receive count value if it is non-negative. + * If known receive count value is negative the entry id check is + * not performed. + * + * @param name entry name to search + * @param value entry value to search + * @param knownReceivedCount known received count to match dynamic table + * entries, if negative - id check is not performed. + * @return a table entry that matches provided parameters + */ + public TableEntry entryOf(CharSequence name, CharSequence value, + long knownReceivedCount) { + // Invoking toString() will possibly allocate Strings for the sake of + // the searchDynamic, which doesn't feel right. + String n = name.toString(); + String v = value.toString(); + + // Tests can use -1 known receive count value to filter dynamic table + // entry ids. + boolean limitDynamicTableEntryIds = knownReceivedCount >= 0; + + // 1. Try exact match in the static table + var staticSearchResult = staticTable.search(n, v); + if (staticSearchResult > 0) { + // name:value pair is found in static table + return new TableEntry(true, staticSearchResult - 1, + name, value, NAME_VALUE); + } + // 2. Try exact match in the dynamic table + var dynamicSearchResult = dynamicTable.search(n, v); + if (dynamicSearchResult == 0 && staticSearchResult == 0) { + // dynamic and static tables do not contain name or name:value entries + // - use literal table entry + return new TableEntry(name, value); + } + long dtEntryId; + // name:value hit in dynamic table + if (dynamicSearchResult > 0) { + dtEntryId = dynamicSearchResult - 1; + if (!limitDynamicTableEntryIds || dtEntryId < knownReceivedCount) { + return new TableEntry(false, dtEntryId, name, value, + NAME_VALUE); + } + } + // Name only hit in the static table + if (staticSearchResult < 0) { + return new TableEntry(true, -staticSearchResult - 1, name, + value, NAME); + } + + // Name only hit in the dynamic table + if (dynamicSearchResult < 0) { + dtEntryId = -dynamicSearchResult - 1; + if (!limitDynamicTableEntryIds || dtEntryId < knownReceivedCount) { + return new TableEntry(false, dtEntryId, name, value, NAME); + } + } + + // No match found in the tables, or there is a dynamic table entry that has + // name or 'name:value' match but its index is greater than max allowed dynamic + // table index, ie the entry is not acknowledged by the decoder. + return new TableEntry(name, value); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/package-info.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/package-info.java new file mode 100644 index 00000000000..d171e81dfbf --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/package-info.java @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +/** + * QPACK (Header Compression for HTTP/3) implementation conforming to + * RFC 9204. + * + *

Headers can be decoded and encoded by {@link jdk.internal.net.http.qpack.Decoder} + * and {@link jdk.internal.net.http.qpack.Encoder} respectively. + * + *

Instances of these classes are not safe for use by multiple threads. + */ +package jdk.internal.net.http.qpack; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/DecoderInstructionsReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/DecoderInstructionsReader.java new file mode 100644 index 00000000000..70ddace72ce --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/DecoderInstructionsReader.java @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.QPACK.Logger; +import jdk.internal.net.http.qpack.QPackException; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static jdk.internal.net.http.http3.Http3Error.QPACK_DECODER_STREAM_ERROR; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +/* + * Reader for decoder instructions described in RFC9204 + * "4.4 Encoder Instructions" section. + * Read instructions are passed to the consumer via the DecoderInstructionsReader.Callback + * instance supplied to the reader constructor. + */ +public class DecoderInstructionsReader { + enum State { + INIT, + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 1 | Stream ID (7+) | + +---+---------------------------+ + */ + SECTION_ACKNOWLEDGMENT, + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 1 | Stream ID (6+) | + +---+---+-----------------------+ + */ + STREAM_CANCELLATION, + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | Increment (6+) | + +---+---+-----------------------+ + */ + INSERT_COUNT_INCREMENT + } + + private State state; + private final IntegerReader integerReader; + private final Callback callback; + private final Logger logger; + + public DecoderInstructionsReader(Callback callback, Logger logger) { + this.integerReader = new IntegerReader( + new ReaderError(QPACK_DECODER_STREAM_ERROR, true)); + this.callback = callback; + this.state = State.INIT; + this.logger = logger.subLogger("DecoderInstructionsReader"); + } + + public void read(ByteBuffer buffer) { + requireNonNull(buffer, "buffer"); + while (buffer.hasRemaining()) { + switch (state) { + case INIT: + integerReader.reset(); + state = identifyDecoderInstruction(buffer); + break; + case INSERT_COUNT_INCREMENT, SECTION_ACKNOWLEDGMENT, STREAM_CANCELLATION: + // All decoder instructions consists of only one variable + // length integer field, therefore we fully read integer and + // then call the callback method depending on the state value + if (integerReader.read(buffer)) { + long value = integerReader.get(); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Instruction: %s value: %s", + state.name(), value)); + } + // dispatch instruction to the consumer via the callback + dispatchParsedInstruction(value); + state = State.INIT; + } + break; + } + } + } + + private State identifyDecoderInstruction(ByteBuffer buffer) { + int b = buffer.get(buffer.position()) & 0xFF; // absolute read + int pos = Integer.numberOfLeadingZeros(b) - 24; + return switch (pos) { + case 0 -> { + integerReader.configure(7); + yield State.SECTION_ACKNOWLEDGMENT; + } + case 1 -> { + integerReader.configure(6); + yield State.STREAM_CANCELLATION; + } + default -> { + if ((b & 0b1100_0000) == 0) { + integerReader.configure(6); + yield State.INSERT_COUNT_INCREMENT; + } else { + throw QPackException.decoderStreamError( + new IOException("Unexpected decoder instruction: " + b)); + } + } + }; + } + + private void dispatchParsedInstruction(long value) { + switch (state) { + case INSERT_COUNT_INCREMENT: + callback.onInsertCountIncrement(value); + break; + case SECTION_ACKNOWLEDGMENT: + callback.onSectionAck(value); + break; + case STREAM_CANCELLATION: + callback.onStreamCancel(value); + break; + default: + throw QPackException.decoderStreamError( + new IOException("Unknown decoder instruction")); + } + } + + public interface Callback { + void onSectionAck(long streamId); + + void onStreamCancel(long streamId); + + void onInsertCountIncrement(long increment); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/EncoderInstructionsReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/EncoderInstructionsReader.java new file mode 100644 index 00000000000..ff8b2868639 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/EncoderInstructionsReader.java @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.QPackException; + +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static java.lang.System.Logger.Level.TRACE; +import static java.util.Objects.requireNonNull; +import static jdk.internal.net.http.http3.Http3Error.QPACK_ENCODER_STREAM_ERROR; + +/* + * Reader for encoder instructions defined in RFC9204 + * "4.3 Encoder Instructions" section. + * Read instruction is passed to the consumer via Callback + * interface supplied to the EncoderInstructionsReader constructor. + */ +public class EncoderInstructionsReader { + + enum State { + INIT, + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | 1 | Capacity (5+) | + +---+---+---+-------------------+ + */ + DT_CAPACITY, + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 1 | T | Name Index (6+) | + +---+---+-----------------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + INSERT_NAME_REF_NAME, + INSERT_NAME_REF_VALUE, + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 1 | H | Name Length (5+) | + +---+---+---+-------------------+ + | Name String (Length bytes) | + +---+---------------------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + INSERT_NAME_LIT_NAME, + INSERT_NAME_LIT_VALUE, + + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | 0 | Index (5+) | + +---+---+---+-------------------+ + */ + DUPLICATE + } + + private final QPACK.Logger logger; + private final Callback updateCallback; + private State state; + private final IntegerReader integerReader; + private final StringReader stringReader; + private int bitT = -1; + private long nameIndex = -1L; + private boolean huffmanValue; + private final StringBuilder valueString = new StringBuilder(); + + private boolean huffmanName; + private final StringBuilder nameString = new StringBuilder(); + + public EncoderInstructionsReader(Callback dtUpdateCallback, QPACK.Logger logger) { + this.logger = logger; + this.updateCallback = dtUpdateCallback; + this.state = State.INIT; + var errorToReport = new ReaderError(QPACK_ENCODER_STREAM_ERROR, true); + this.integerReader = new IntegerReader(errorToReport); + this.stringReader = new StringReader(errorToReport); + } + + public void read(ByteBuffer buffer, int maxStringLength) { + try { + read0(buffer, maxStringLength); + } catch (IllegalArgumentException | IllegalStateException exception) { + // "Duplicate" and "Insert With Name Reference" instructions can reference + // non-existing entries in the dynamic table. + // Such errors are treated as encoder stream errors. + throw QPackException.encoderStreamError(exception); + } + } + + private void read0(ByteBuffer buffer, int maxStringLength) { + requireNonNull(buffer, "buffer"); + while (buffer.hasRemaining()) { + switch (state) { + case INIT: + state = identifyEncoderInstruction(buffer); + break; + case DT_CAPACITY: + if (integerReader.read(buffer)) { + long capacity = integerReader.get(); + if (logger.isLoggable(TRACE)) { + logger.log(TRACE, () -> format("Dynamic Table Capacity update: %d", + capacity)); + } + updateCallback.onCapacityUpdate(integerReader.get()); + reset(); + } + break; + case INSERT_NAME_LIT_NAME: + if (stringReader.read(5, buffer, nameString, maxStringLength)) { + huffmanName = stringReader.isHuffmanEncoded(); + stringReader.reset(); + state = State.INSERT_NAME_LIT_VALUE; + } + break; + case INSERT_NAME_LIT_VALUE: + int stringReaderLimit = maxStringLength > 0 ? + Math.max(maxStringLength - nameString.length(), 0) : -1; + if (stringReader.read(buffer, valueString, stringReaderLimit)) { + huffmanValue = stringReader.isHuffmanEncoded(); + // Insert with literal name instruction completely parsed + if (logger.isLoggable(TRACE)) { + logger.log(TRACE, () -> format("Insert with Literal Name ('%s','%s'," + + " huffmanName='%s', huffmanValue='%s')", nameString, + valueString, huffmanName, huffmanValue)); + } + updateCallback.onInsert(nameString.toString(), valueString.toString()); + reset(); + } + break; + case INSERT_NAME_REF_NAME: + if (integerReader.read(buffer)) { + nameIndex = integerReader.get(); + state = State.INSERT_NAME_REF_VALUE; + } + break; + case INSERT_NAME_REF_VALUE: + if (stringReader.read(buffer, valueString, maxStringLength)) { + // Insert with name reference instruction completely parsed + if (logger.isLoggable(TRACE)) { + logger.log(TRACE, () -> format("Insert With Name Reference (T=%d, nameIdx=%d," + + " value='%s', valueHuffman='%s')", + bitT, nameIndex, valueString, stringReader.isHuffmanEncoded())); + } + updateCallback.onInsertIndexedName(bitT == 1, nameIndex, valueString.toString()); + reset(); + } + break; + case DUPLICATE: + if (integerReader.read(buffer)) { + updateCallback.onDuplicate(integerReader.get()); + reset(); + } + break; + } + } + } + + private State identifyEncoderInstruction(ByteBuffer buffer) { + int b = buffer.get(buffer.position()) & 0xFF; // absolute read + int pos = Integer.numberOfLeadingZeros(b) - 24; + return switch (pos) { + case 0 -> { + // Configure integer reader to read out name index and read the T bit + integerReader.configure(6); + bitT = (b & 0b0100_0000) == 0 ? 0 : 1; + yield State.INSERT_NAME_REF_NAME; + } + case 1 -> State.INSERT_NAME_LIT_NAME; + case 2 -> { + integerReader.configure(5); + yield State.DT_CAPACITY; + } + default -> { + boolean isDuplicateInstruction = (b & 0b1110_0000) == 0; + if (isDuplicateInstruction) { + integerReader.configure(5); + yield State.DUPLICATE; + } else { + throw QPackException.encoderStreamError( + new InternalError("Unexpected encoder instruction: " + b)); + } + } + }; + } + + public void reset() { + state = State.INIT; + bitT = -1; + nameIndex = -1L; + huffmanName = false; + huffmanValue = false; + resetBuffersAndReaders(); + } + + private void resetBuffersAndReaders() { + integerReader.reset(); + stringReader.reset(); + nameString.setLength(0); + valueString.setLength(0); + } + + public interface Callback { + void onCapacityUpdate(long capacity); + + void onInsert(String name, String value); + + void onInsertIndexedName(boolean indexInStaticTable, long nameIndex, String valueString); + + void onDuplicate(long l); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedPostBaseReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedPostBaseReader.java new file mode 100644 index 00000000000..ce2da3552fa --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedPostBaseReader.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.http3.Http3Error.QPACK_DECOMPRESSION_FAILED; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +final class FieldLineIndexedPostBaseReader extends FieldLineReader { + private final IntegerReader integerReader; + private final QPACK.Logger logger; + + public FieldLineIndexedPostBaseReader(DynamicTable dynamicTable, long maxSectionSize, + AtomicLong sectionSizeTracker, QPACK.Logger logger) { + super(dynamicTable, maxSectionSize, sectionSizeTracker); + this.integerReader = new IntegerReader( + new ReaderError(QPACK_DECOMPRESSION_FAILED, false)); + this.logger = logger; + } + + public void configure(int b) { + integerReader.configure(4); + } + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 1 | Index (4+) | + // +---+---+---+---+---------------+ + // + public boolean read(ByteBuffer input, FieldSectionPrefix prefix, + DecodingCallback action) { + if (!integerReader.read(input)) { + return false; + } + long relativeIndex = integerReader.get(); + long absoluteIndex = prefix.base() + relativeIndex; + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("Post-Base Indexed Field Line: base=%s index=%s[%s]", + prefix.base(), relativeIndex, absoluteIndex)); + } + checkEntryIndex(absoluteIndex, prefix); + HeaderField f = entryAtIndex(absoluteIndex); + checkSectionSize(DynamicTable.headerSize(f)); + action.onIndexed(absoluteIndex, f.name(), f.value()); + reset(); + return true; + } + + public void reset() { + integerReader.reset(); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedReader.java new file mode 100644 index 00000000000..321a640bb5f --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineIndexedReader.java @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +final class FieldLineIndexedReader extends FieldLineReader { + private final IntegerReader integerReader; + private final QPACK.Logger logger; + + public FieldLineIndexedReader(DynamicTable dynamicTable, long maxSectionSize, + AtomicLong sectionSizeTracker, QPACK.Logger logger) { + super(dynamicTable, maxSectionSize, sectionSizeTracker); + this.logger = logger; + integerReader = new IntegerReader( + new ReaderError(Http3Error.QPACK_DECOMPRESSION_FAILED, false)); + } + + public void configure(int b) { + integerReader.configure(6); + fromStaticTable = (b & 0b0100_0000) != 0; + } + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 1 | T | Index (6+) | + // +---+---------------------------+ + // + public boolean read(ByteBuffer input, FieldSectionPrefix prefix, + DecodingCallback action) { + if (!integerReader.read(input)) { + return false; + } + long intValue = integerReader.get(); + // "In a field line representation, a relative index of 0 refers to the + // entry with absolute index equal to Base - 1." + long absoluteIndex = fromStaticTable ? intValue : prefix.base() - 1 - intValue; + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("%s index %s", fromStaticTable ? "Static" : "Dynamic", + absoluteIndex)); + } + checkEntryIndex(absoluteIndex, prefix); + HeaderField f = entryAtIndex(absoluteIndex); + checkSectionSize(DynamicTable.headerSize(f)); + action.onIndexed(absoluteIndex, f.name(), f.value()); + reset(); + return true; + } + + public void reset() { + integerReader.reset(); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineLiteralsReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineLiteralsReader.java new file mode 100644 index 00000000000..1dfa81b5631 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineLiteralsReader.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.QPACK; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.http3.Http3Error.QPACK_DECOMPRESSION_FAILED; +import static jdk.internal.net.http.qpack.DynamicTable.ENTRY_SIZE; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +final class FieldLineLiteralsReader extends FieldLineReader { + private boolean hideIntermediary; + private boolean huffmanName, huffmanValue; + private final StringBuilder name, value; + private final StringReader stringReader; + private final QPACK.Logger logger; + private boolean firstValueRead = false; + + public FieldLineLiteralsReader(long maxSectionSize, AtomicLong sectionSizeTracker, + QPACK.Logger logger) { + // Dynamic table is not needed for literals reader + super(null, maxSectionSize, sectionSizeTracker); + this.logger = logger; + stringReader = new StringReader(new ReaderError(QPACK_DECOMPRESSION_FAILED, false)); + name = new StringBuilder(512); + value = new StringBuilder(1024); + } + + public void configure(int b) { + hideIntermediary = (b & 0b0001_0000) != 0; + } + + // + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 1 | N | H |NameLen(3+)| + // +---+---+-----------------------+ + // | Name String (Length bytes) | + // +---+---------------------------+ + // | H | Value Length (7+) | + // +---+---------------------------+ + // | Value String (Length bytes) | + // +-------------------------------+ + // + public boolean read(ByteBuffer input, FieldSectionPrefix prefix, + DecodingCallback action) { + if (!completeReading(input)) { + long readPart = ENTRY_SIZE + name.length() + value.length(); + checkPartialSize(readPart); + return false; + } + String n = name.toString(); + String v = value.toString(); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format( + "literal with literal name ('%s', huffman=%b, '%s', huffman=%b)", + n, huffmanName, v, huffmanValue)); + } + checkSectionSize(DynamicTable.headerSize(n, v)); + action.onLiteralWithLiteralName(n, huffmanName, v, huffmanValue, hideIntermediary); + reset(); + return true; + } + + private boolean completeReading(ByteBuffer input) { + if (!firstValueRead) { + if (!stringReader.read(3, input, name, getMaxFieldLineLimit(name.length()))) { + return false; + } + huffmanName = stringReader.isHuffmanEncoded(); + stringReader.reset(); + firstValueRead = true; + return false; + } else { + int maxLength = getMaxFieldLineLimit(name.length() + value.length()); + if (!stringReader.read(input, value, maxLength)) { + return false; + } + } + huffmanValue = stringReader.isHuffmanEncoded(); + stringReader.reset(); + return true; + } + + public void reset() { + name.setLength(0); + value.setLength(0); + firstValueRead = false; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameRefPostBaseReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameRefPostBaseReader.java new file mode 100644 index 00000000000..420844ad72d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameRefPostBaseReader.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.http3.Http3Error.QPACK_DECOMPRESSION_FAILED; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +final class FieldLineNameRefPostBaseReader extends FieldLineReader { + private long intValue; + private boolean hideIntermediary; + private boolean huffmanValue; + private final StringBuilder value; + private final IntegerReader integerReader; + private final StringReader stringReader; + private final QPACK.Logger logger; + + private boolean firstValueRead = false; + + FieldLineNameRefPostBaseReader(DynamicTable dynamicTable, long maxSectionSize, + AtomicLong sectionSizeTracker, QPACK.Logger logger) { + super(dynamicTable, maxSectionSize, sectionSizeTracker); + this.logger = logger; + var errorToReport = new ReaderError(QPACK_DECOMPRESSION_FAILED, false); + integerReader = new IntegerReader(errorToReport); + stringReader = new StringReader(errorToReport); + value = new StringBuilder(1024); + } + + public void configure(int b) { + hideIntermediary = (b & 0b0000_1000) != 0; + integerReader.configure(3); + } + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 0 | N |NameIdx(3+)| + // +---+---+---+---+---+-----------+ + // | H | Value Length (7+) | + // +---+---------------------------+ + // | Value String (Length bytes) | + // +-------------------------------+ + public boolean read(ByteBuffer input, FieldSectionPrefix prefix, + DecodingCallback action) { + if (!completeReading(input)) { + if (firstValueRead) { + long readPart = DynamicTable.ENTRY_SIZE + value.length(); + checkPartialSize(readPart); + } + return false; + } + + long absoluteIndex = prefix.base() + intValue; + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format( + "literal with post-base name reference (%s, %s, '%s', huffman=%b)", + absoluteIndex, prefix.base(), value, huffmanValue)); + } + checkEntryIndex(absoluteIndex, prefix); + HeaderField f = entryAtIndex(absoluteIndex); + String valueStr = value.toString(); + checkSectionSize(DynamicTable.headerSize(f.name(), valueStr)); + action.onLiteralWithNameReference(absoluteIndex, + f.name(), valueStr, huffmanValue, hideIntermediary); + reset(); + return true; + } + + private boolean completeReading(ByteBuffer input) { + if (!firstValueRead) { + if (!integerReader.read(input)) { + return false; + } + intValue = integerReader.get(); + integerReader.reset(); + + firstValueRead = true; + return false; + } else { + if (!stringReader.read(input, value, getMaxFieldLineLimit())) { + return false; + } + } + huffmanValue = stringReader.isHuffmanEncoded(); + stringReader.reset(); + + return true; + } + + public void reset() { + value.setLength(0); + firstValueRead = false; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameReferenceReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameReferenceReader.java new file mode 100644 index 00000000000..f7e3d300628 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineNameReferenceReader.java @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.http3.Http3Error.QPACK_DECOMPRESSION_FAILED; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +final class FieldLineNameReferenceReader extends FieldLineReader { + private long intValue; + private boolean hideIntermediary; + private boolean huffmanValue; + private final StringBuilder value; + private final IntegerReader integerReader; + private final StringReader stringReader; + private final QPACK.Logger logger; + + private boolean firstValueRead = false; + + FieldLineNameReferenceReader(DynamicTable dynamicTable, long maxSectionSize, + AtomicLong sectionSizeTracker, QPACK.Logger logger) { + super(dynamicTable, maxSectionSize, sectionSizeTracker); + this.logger = logger; + var errorToReport = new ReaderError(QPACK_DECOMPRESSION_FAILED, false); + integerReader = new IntegerReader(errorToReport); + stringReader = new StringReader(errorToReport); + value = new StringBuilder(1024); + } + + public void configure(int b) { + fromStaticTable = (b & 0b0001_0000) != 0; + hideIntermediary = (b & 0b0010_0000) != 0; + integerReader.configure(4); + } + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 1 | N | T | NameIndex (4+)| + // +---+---+-----------------------+ + // | H | Value Length (7+) | + // +---+---------------------------+ + // | Value String (Length octets) | + // +-------------------------------+ + // + public boolean read(ByteBuffer input, FieldSectionPrefix prefix, + DecodingCallback action) { + if (!completeReading(input)) { + if (firstValueRead) { + long readPart = DynamicTable.ENTRY_SIZE + value.length(); + checkPartialSize(readPart); + } + return false; + } + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format( + "literal with name reference (%s, %s, '%s', huffman=%b)", + fromStaticTable ? "static" : "dynamic", intValue, value, huffmanValue)); + } + long absoluteIndex = fromStaticTable ? intValue : prefix.base() - 1 - intValue; + checkEntryIndex(absoluteIndex, prefix); + HeaderField f = entryAtIndex(absoluteIndex); + String valueStr = value.toString(); + checkSectionSize(DynamicTable.headerSize(f.name(), valueStr)); + action.onLiteralWithNameReference(absoluteIndex, f.name(), valueStr, + huffmanValue, hideIntermediary); + reset(); + return true; + } + + private boolean completeReading(ByteBuffer input) { + if (!firstValueRead) { + if (!integerReader.read(input)) { + return false; + } + intValue = integerReader.get(); + integerReader.reset(); + + firstValueRead = true; + return false; + } else { + if (!stringReader.read(input, value, getMaxFieldLineLimit())) { + return false; + } + } + huffmanValue = stringReader.isHuffmanEncoded(); + stringReader.reset(); + + return true; + } + + public void reset() { + value.setLength(0); + firstValueRead = false; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineReader.java new file mode 100644 index 00000000000..da67f92bf19 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/FieldLineReader.java @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.StaticTable; + +import java.io.IOException; +import java.net.ProtocolException; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +sealed abstract class FieldLineReader permits FieldLineIndexedPostBaseReader, + FieldLineIndexedReader, FieldLineLiteralsReader, FieldLineNameRefPostBaseReader, + FieldLineNameReferenceReader { + + final long maxSectionSize; + boolean fromStaticTable; + private final AtomicLong sectionSizeTracker; + private final DynamicTable dynamicTable; + + FieldLineReader(DynamicTable dynamicTable, long maxSectionSize, AtomicLong sectionSizeTracker) { + this.maxSectionSize = maxSectionSize; + this.sectionSizeTracker = sectionSizeTracker; + this.dynamicTable = dynamicTable; + } + + abstract void reset(); + abstract void configure(int b); + abstract boolean read(ByteBuffer input, FieldSectionPrefix prefix, + DecodingCallback action); + + final void checkSectionSize(long fieldSize) { + long sectionSize = sectionSizeTracker.addAndGet(fieldSize); + if (maxSectionSize > 0 && sectionSize > maxSectionSize) { + throw maxFieldSectionExceeded(sectionSize, maxSectionSize); + } + } + + final void checkPartialSize(long partialFieldSize) { + long sectionSize = sectionSizeTracker.get() + partialFieldSize; + if (maxSectionSize > 0 && sectionSize > maxSectionSize) { + throw maxFieldSectionExceeded(sectionSize, maxSectionSize); + } + } + + final int getMaxFieldLineLimit(int partiallyRead) { + int maxLimit = -1; + if (maxSectionSize > 0) { + maxLimit = Math.clamp(maxSectionSize - partiallyRead - 32 - + sectionSizeTracker.get(), 0, Integer.MAX_VALUE); + } + return maxLimit; + } + + final int getMaxFieldLineLimit() { + return getMaxFieldLineLimit(0); + } + + private static QPackException maxFieldSectionExceeded(long sectionSize, long maxSize) { + throw QPackException.decompressionFailed( + new ProtocolException("Size exceeds MAX_FIELD_SECTION_SIZE: %s > %s" + .formatted(sectionSize, maxSize)), false); + } + + /** + * Checks if the decoder encounters a reference in a field line representation to + * a dynamic table entry that has already been evicted or that has an absolute index + * greater than or equal to the declared Required Insert Count (Section 4.5.1), + * it MUST treat this as a connection error of type QPACK_DECOMPRESSION_FAILED. + * @param absoluteIndex dynamic table absolute index + * @param prefix field line section prefix + */ + void checkEntryIndex(long absoluteIndex, FieldSectionPrefix prefix) { + if (!fromStaticTable && absoluteIndex >= prefix.requiredInsertCount()) { + throw QPackException.decompressionFailed( + new IOException("header index is greater than RIC"), true); + } + } + + /** + * Return a header field entry for the specified entry index. The table type + * is selected according to the {@code fromStaticTable} value. + * @param index absolute index of the table entry. + * @return a header field corresponding to the specified entry + */ + final HeaderField entryAtIndex(long index) { + HeaderField f; + try { + if (fromStaticTable) { + f = StaticTable.HTTP3.get(index); + } else { + assert dynamicTable != null; + f = dynamicTable.get(index); + } + } catch (IndexOutOfBoundsException | IllegalStateException | IllegalArgumentException e) { + throw QPackException.decompressionFailed( + new IOException("header fields table index", e), true); + } + return f; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/HeaderFrameReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/HeaderFrameReader.java new file mode 100644 index 00000000000..a4ea55661fe --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/HeaderFrameReader.java @@ -0,0 +1,414 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.quic.streams.QuicStreamReader; + +import java.io.IOException; +import java.net.ProtocolException; +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static jdk.internal.net.http.http3.Http3Error.H3_INTERNAL_ERROR; +import static jdk.internal.net.http.http3.Http3Error.QPACK_DECOMPRESSION_FAILED; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.NORMAL; + +public class HeaderFrameReader { + + private enum State { + // Nothing has been read so-far, "Required Insert Count" (RIC) will be read next + INITIAL, + // "Required Insert Count" read is done, "S" and "Delta Base" are next + DELTA_BASE, + // Encoded Field Section Prefix read is done, ready to start reading header fields. + // In this state we only select a proper reader based on the field line encoding type, + // ie the first byte is analysed to select a proper reader. + SELECT_FIELD_READER, + INDEXED, + INDEX_WITH_POST_BASE, + LITERAL_WITH_LITERAL_NAME, + LITERAL_WITH_NAME_REF, + LITERAL_WITH_POST_BASE, + AWAITING_DT_INSERT_COUNT + } + + /* + 4.5.1. Encoded Field Section Prefix + Each encoded field section is prefixed with two integers. The Required Insert Count + is encoded as an integer with an 8-bit prefix using the encoding described in Section 4.5.1.1. + The Base is encoded as a Sign bit ('S') and a Delta Base value with a 7-bit prefix; + see Section 4.5.1.2. + + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | Required Insert Count (8+) | + +---+---------------------------+ + | S | Delta Base (7+) | + +---+---------------------------+ + */ + long requiredInsertCount; + long deltaBase; + int signBit; + volatile FieldSectionPrefix fieldSectionPrefix; + private final IntegerReader integerReader; + private FieldLineReader reader; + private final QPACK.Logger logger; + private final FieldLineIndexedReader indexedReader; + private final FieldLineIndexedPostBaseReader indexedPostBaseReader; + private final FieldLineNameReferenceReader literalWithNameReferenceReader; + private final FieldLineNameRefPostBaseReader literalWithNameRefPostBaseReader; + + private final FieldLineLiteralsReader literalWithLiteralNameReader; + // Need dynamic table reference for decoding field line section prefix + private final DynamicTable dynamicTable; + private final DecodingCallback decodingCallback; + + private volatile State state = State.INITIAL; + + private final SequentialScheduler headersScheduler = SequentialScheduler.lockingScheduler(this::readLoop); + private final ConcurrentLinkedQueue headersData = new ConcurrentLinkedQueue<>(); + + private final AtomicLong blockedStreamsCounter; + private final long maxBlockedStreams; + + // A tracker of header data received by the decoder, to check that the peer encoder + // honours the SETTINGS_MAX_FIELD_SECTION_SIZE value: + // RFC-9114: 4.2.2. Header Size Constraints + // "If an implementation wishes to advise its peer of this limit, it can + // be conveyed as a number of bytes in the SETTINGS_MAX_FIELD_SECTION_SIZE parameter. + // An implementation that has received this parameter SHOULD NOT send an HTTP message + // header that exceeds the indicated size" + // "A client can discard responses that it cannot process." + // + // Maximum allowed value is passed to FieldLineReader's implementations and not stored in + // HeaderFrameReader instance. + private final AtomicLong fieldSectionSizeTracker; + + private static final AtomicLong HEADER_FRAME_READER_IDS = new AtomicLong(); + + private void readLoop() { + try { + readLoop0(); + } catch (QPackException qPackException) { + Throwable cause = qPackException.getCause(); + if (qPackException.isConnectionError()) { + decodingCallback.onConnectionError(cause, qPackException.http3Error()); + } else { + decodingCallback.onStreamError(cause, qPackException.http3Error()); + } + } catch (Throwable throwable) { + decodingCallback.onConnectionError(throwable, H3_INTERNAL_ERROR); + } finally { + // Stop the scheduler, clear the reader's queue and + // remove all insert count notification events associated + // with current stream. + if (decodingCallback.hasError()) { + headersScheduler.stop(); + headersData.clear(); + dynamicTable.cleanupStreamInsertCountNotifications(decodingCallback.streamId()); + } + } + } + + private void readLoop0() { + ByteBuffer headerBlock; + OUTER: + while (!decodingCallback.hasError() && (headerBlock = headersData.peek()) != null) { + boolean endOfHeaderBlock = headerBlock == QuicStreamReader.EOF; + State state = this.state; + FieldSectionPrefix sectionPrefix = this.fieldSectionPrefix; + while (!decodingCallback.hasError() && headerBlock.hasRemaining()) { + if (state == State.SELECT_FIELD_READER) { + int b = headerBlock.get(headerBlock.position()) & 0xff; // absolute read + state = this.state = selectHeaderReaderState(b); + if (logger.isLoggable(EXTRA)) { + String message = format("next binary representation %s (first byte 0x%02x)", state, b); + logger.log(EXTRA, () -> message); + } + reader = switch (state) { + case INDEXED -> indexedReader; + case LITERAL_WITH_NAME_REF -> literalWithNameReferenceReader; + case LITERAL_WITH_LITERAL_NAME -> literalWithLiteralNameReader; + case INDEX_WITH_POST_BASE -> indexedPostBaseReader; + case LITERAL_WITH_POST_BASE -> literalWithNameRefPostBaseReader; + default -> throw QPackException.decompressionFailed( + new InternalError("Unexpected decoder state: " + state), false); + }; + reader.configure(b); + } else if (state == State.INITIAL) { + if (!integerReader.read(headerBlock)) { + continue; + } + // Required Insert Count was fully read + requiredInsertCount = integerReader.get(); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("Encoded Required Insert Count = %d", requiredInsertCount)); + } + // Continue reading S and Delta Base values + state = this.state = State.DELTA_BASE; + // Reset integer reader + integerReader.reset(); + // Prepare it for reading S and Delta Base (7+) + integerReader.configure(7); + continue; + } else if (state == State.DELTA_BASE) { + if (signBit == -1) { + int b = headerBlock.get(headerBlock.position()) & 0xff; // absolute read + signBit = (b & 0b1000_0000) == 0b1000_0000 ? 1 : 0; + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("Base Sign = %d", signBit)); + } + } + if (!integerReader.read(headerBlock)) { + continue; + } + deltaBase = integerReader.get(); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("Delta Base = %d", deltaBase)); + } + // Construct field section prefix from the parsed fields + sectionPrefix = this.fieldSectionPrefix = + FieldSectionPrefix.decode(requiredInsertCount, deltaBase, + signBit, dynamicTable); + + // Check if decoding of field section is blocked due to not yet received + // dynamic table entries + long insertCount = dynamicTable.insertCount(); + if (sectionPrefix.requiredInsertCount() > insertCount) { + long blocked = blockedStreamsCounter.incrementAndGet(); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, + () -> "Blocked stream observed. Total blocked: " + blocked + + " Max allowed: " + maxBlockedStreams); + } + // System property value is checked here instead of the HTTP3 settings because decoder uses its + // value to update connection settings. HTTP client's encoder implementation won't block the streams - + // only acknowledged entry references is used, therefore this connection setting is not consulted + // on encoder side. + if (blocked > maxBlockedStreams) { + var ioException = new IOException(("too many blocked streams: current=%d; max=%d; " + + "prefixCount=%d; tableCount=%d").formatted(blocked, maxBlockedStreams, + sectionPrefix.requiredInsertCount(), insertCount)); + // If a decoder encounters more blocked streams than it promised to support, + // it MUST treat this as a connection error of type QPACK_DECOMPRESSION_FAILED. + throw QPackException.decompressionFailed(ioException, true); + } else { + CompletableFuture future = + dynamicTable.awaitFutureInsertCount(decodingCallback.streamId(), + sectionPrefix.requiredInsertCount()); + state = this.state = State.AWAITING_DT_INSERT_COUNT; + future.thenRun(this::onInsertCountUpdate); + } + break OUTER; + } + // The stream is unblocked - field lines can be decoded now + state = this.state = State.SELECT_FIELD_READER; + continue; + } else if (state == State.AWAITING_DT_INSERT_COUNT) { + // If we're waiting for a specific dynamic table update + return; + } + if (reader.read(headerBlock, sectionPrefix, decodingCallback)) { + // Finished reading of one header field line + state = this.state = State.SELECT_FIELD_READER; + } + } + if (!headerBlock.hasRemaining()) { + var head = headersData.poll(); + assert head == headerBlock; + } + if (endOfHeaderBlock) { + if (state == State.SELECT_FIELD_READER) { + decodingCallback.onComplete(); + } else { + logger.log(NORMAL, () -> "unexpected end of representation"); + throw QPackException.decompressionFailed( + new ProtocolException("Unexpected end of header block"), true); + } + } + } + } + + private void onInsertCountUpdate() { + long blocked = blockedStreamsCounter.decrementAndGet(); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> "Stream Unblocked - number of blocked streams: " + blocked); + } + state = State.SELECT_FIELD_READER; + headersScheduler.runOrSchedule(); + } + + public HeaderFrameReader(DynamicTable dynamicTable, DecodingCallback callback, + AtomicLong blockedStreamsCounter, long maxBlockedStreams, + long maxFieldSectionSize, QPACK.Logger logger) { + this.blockedStreamsCounter = blockedStreamsCounter; + this.logger = logger.subLogger("HeaderFrameReader#" + + HEADER_FRAME_READER_IDS.incrementAndGet()); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("New HeaderFrameReader, dynamic table capacity = %s", + dynamicTable.capacity())); + /* To correlate with logging outside QPACK, knowing + hashCode/toString is important */ + logger.log(NORMAL, () -> { + String hashCode = Integer.toHexString(System.identityHashCode(this)); + return format("toString='%s', identityHashCode=%s", this, hashCode); + }); + } + this.fieldSectionSizeTracker = new AtomicLong(); + indexedReader = new FieldLineIndexedReader(dynamicTable, + maxFieldSectionSize, fieldSectionSizeTracker, + this.logger.subLogger("FieldLineIndexedReader")); + indexedPostBaseReader = new FieldLineIndexedPostBaseReader(dynamicTable, + maxFieldSectionSize, fieldSectionSizeTracker, + this.logger.subLogger("FieldLineIndexedPostBaseReader")); + literalWithNameReferenceReader = new FieldLineNameReferenceReader(dynamicTable, + maxFieldSectionSize, fieldSectionSizeTracker, + this.logger.subLogger("FieldLineNameReferenceReader")); + literalWithNameRefPostBaseReader = new FieldLineNameRefPostBaseReader(dynamicTable, + maxFieldSectionSize, fieldSectionSizeTracker, + this.logger.subLogger("FieldLineNameRefPostBaseReader")); + literalWithLiteralNameReader = new FieldLineLiteralsReader( + maxFieldSectionSize, fieldSectionSizeTracker, + this.logger.subLogger("FieldLineLiteralsReader")); + integerReader = new IntegerReader(new ReaderError(QPACK_DECOMPRESSION_FAILED, false)); + resetPrefixVars(); + // Since reader is constructed in Initial state - it means that the + // "Required Insert Count" will be read first. + integerReader.configure(8); + decodingCallback = callback; + this.dynamicTable = dynamicTable; + this.maxBlockedStreams = maxBlockedStreams; + } + + private void resetPrefixVars() { + requiredInsertCount = -1L; + deltaBase = -1L; + signBit = -1; + fieldSectionPrefix = null; + fieldSectionSizeTracker.set(0); + } + + public FieldSectionPrefix decodedSectionPrefix() { + if (deltaBase == -1L) { + throw new IllegalStateException("Field Section Prefix not parsed yet"); + } + return fieldSectionPrefix; + } + + public void read(ByteBuffer headerBlock, boolean endOfHeaderBlock) { + requireNonNull(headerBlock, "headerBlock"); + if (logger.isLoggable(NORMAL)) { + logger.log(NORMAL, () -> format("reading %s, end of header block? %s", + headerBlock, endOfHeaderBlock)); + } + headersData.add(headerBlock); + if (endOfHeaderBlock) { + headersData.add(QuicStreamReader.EOF); + } + headersScheduler.runOrSchedule(); + } + + private State selectHeaderReaderState(int b) { + // First non-zero bit in lower 8 bits (see the caller) + int pos = Integer.numberOfLeadingZeros(b) - 24; + return switch (pos) { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 1 | T | Index (6+) | + +---+---+-----------------------+ + */ + case 0 -> State.INDEXED; + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 1 | N | T |Name Index (4+)| + +---+---+---+---+---------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + case 1 -> State.LITERAL_WITH_NAME_REF; + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | 1 | N | H |NameLen(3+)| + +---+---+---+---+---+-----------+ + | Name String (Length bytes) | + +---+---------------------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + case 2 -> State.LITERAL_WITH_LITERAL_NAME; + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | 0 | 1 | Index (4+) | + +---+---+---+---+---------------+ + */ + case 3 -> State.INDEX_WITH_POST_BASE; + // "Literal Field Line with Post-Base Name Reference": + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 0 | N |NameIdx(3+)| + // +---+---+---+---+---------------+ + default -> { + if ((b & 0xF0) == 0) { + yield State.LITERAL_WITH_POST_BASE; + } + throw QPackException.decompressionFailed( + new IOException("Unknown frame reader line prefix: " + b), + false); + } + }; + } + + /** + * Reset the state of the HeaderFrameReader so that it's ready + * to parse a new HeaderFrame. + */ + public void reset() { + state = State.INITIAL; + reader = null; + resetPrefixVars(); + integerReader.reset(); + integerReader.configure(8); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/IntegerReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/IntegerReader.java new file mode 100644 index 00000000000..35025c3f06a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/IntegerReader.java @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.http3.Http3Error; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import static java.lang.String.format; + +/** + * This class able to decode integers up to and including 62 bits long values. + * https://www.rfc-editor.org/rfc/rfc9204.html#name-prefixed-integers + */ +public final class IntegerReader { + + private static final int NEW = 0; + private static final int CONFIGURED = 1; + private static final int FIRST_BYTE_READ = 2; + private static final int DONE = 4; + + private int state = NEW; + + private int N; + private long maxValue; + private long value; + private long r; + private long b = 1; + private final ReaderError readError; + + public IntegerReader(ReaderError readError) { + this.readError = readError; + } + + public IntegerReader() { + this(new ReaderError(Http3Error.H3_INTERNAL_ERROR, true)); + } + + // "QPACK implementations MUST be able to decode integers up to and including 62 bits long." + // https://www.rfc-editor.org/rfc/rfc9204.html#name-prefixed-integers + public static final long QPACK_MAX_INTEGER_VALUE = (1L << 62) - 1; + + public IntegerReader configure(int N) { + return configure(N, QPACK_MAX_INTEGER_VALUE); + } + + // + // Why is it important to configure 'maxValue' here. After all we can wait + // for the integer to be fully read and then check it. Can't we? + // + // Two reasons. + // + // 1. Value wraps around long won't be unnoticed. + // 2. It can spit out an exception as soon as it becomes clear there's + // an overflow. Therefore, no need to wait for the value to be fully read. + // + public IntegerReader configure(int N, long maxValue) { + if (state != NEW) { + throw new IllegalStateException("Already configured"); + } + checkPrefix(N); + if (maxValue < 0) { + throw new IllegalArgumentException( + "maxValue >= 0: maxValue=" + maxValue); + } + this.maxValue = maxValue; + this.N = N; + state = CONFIGURED; + return this; + } + + public boolean read(ByteBuffer input) { + if (state == NEW) { + throw new IllegalStateException("Configure first"); + } + if (state == DONE) { + return true; + } + if (!input.hasRemaining()) { + return false; + } + if (state == CONFIGURED) { + int max = (2 << (N - 1)) - 1; + int n = input.get() & max; + if (n != max) { + value = n; + state = DONE; + return true; + } else { + r = max; + } + state = FIRST_BYTE_READ; + } + if (state == FIRST_BYTE_READ) { + try { + // variable-length quantity (VLQ) + byte i; + boolean continuationFlag; + do { + if (!input.hasRemaining()) { + return false; + } + i = input.get(); + // RFC 7541: 5.1. Integer Representation + // "The most significant bit of each octet is used + // as a continuation flag: its value is set to 1 except + // for the last octet in the list" + continuationFlag = (i & 0b10000000) != 0; + long increment = Math.multiplyExact(b, i & 127); + if (continuationFlag) { + b = Math.multiplyExact(b, 128); + } + if (r > maxValue - increment) { + throw readError.toQPackException( + new IOException(format( + "Integer overflow: maxValue=%,d, value=%,d", + maxValue, r + increment))); + } + r += increment; + } while (continuationFlag); + value = r; + state = DONE; + return true; + } catch (ArithmeticException arithmeticException) { + // Sequence of bytes encodes value greater + // than QPACK_MAX_INTEGER_VALUE + throw readError.toQPackException(new IOException("Integer overflow", + arithmeticException)); + } + } + throw new InternalError(Arrays.toString( + new Object[]{state, N, maxValue, value, r, b})); + } + + public long get() throws IllegalStateException { + if (state != DONE) { + throw new IllegalStateException("Has not been fully read yet"); + } + return value; + } + + private static void checkPrefix(int N) { + if (N < 1 || N > 8) { + throw new IllegalArgumentException("1 <= N <= 8: N= " + N); + } + } + + public IntegerReader reset() { + b = 1; + state = NEW; + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/ReaderError.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/ReaderError.java new file mode 100644 index 00000000000..3be6033f827 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/ReaderError.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack.readers; + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.qpack.QPackException; + +/** + * QPack readers configuration record to be used by the readers to + * report errors. + * @param http3Error corresponding HTTP/3 error code. + * @param isConnectionError if the reader error should be treated + * as connection error. + */ +record ReaderError(Http3Error http3Error, boolean isConnectionError) { + + /** + * Construct a {@link QPackException} from on {@code http3Error}, + * {@code isConnectionError} and provided {@code "cause"} values. + * @param cause cause of the constructed {@link QPackException} + * @return a {@code QPackException} instance. + */ + QPackException toQPackException(Throwable cause) { + return new QPackException(http3Error, cause, isConnectionError); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/StringReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/StringReader.java new file mode 100644 index 00000000000..0cd0c92e91a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/readers/StringReader.java @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.readers; + +import java.io.IOException; +import java.net.ProtocolException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import jdk.internal.net.http.hpack.ISO_8859_1; +import jdk.internal.net.http.hpack.Huffman; +import jdk.internal.net.http.hpack.QuickHuffman; +import jdk.internal.net.http.http3.Http3Error; + +// +// 0 1 2 3 4 5 6 7 +// +---+---+---+---+---+---+---+---+ +// | H | String Length (7+) | +// +---+---------------------------+ +// | String Data (Length octets) | +// +-------------------------------+ +// +public final class StringReader { + + private static final int NEW = 0; + private static final int FIRST_BYTE_READ = 1; + private static final int LENGTH_READ = 2; + private static final int DONE = 4; + + private final ReaderError readError; + private final IntegerReader intReader; + private final Huffman.Reader huffmanReader = new QuickHuffman.Reader(); + private final ISO_8859_1.Reader plainReader = new ISO_8859_1.Reader(); + + private int state = NEW; + private boolean huffman; + private int remainingLength; + + public StringReader() { + this(new ReaderError(Http3Error.H3_INTERNAL_ERROR, true)); + } + + public StringReader(ReaderError readError) { + this.readError = readError; + this.intReader = new IntegerReader(readError); + } + + public boolean read(ByteBuffer input, Appendable output, int maxLength) { + return read(7, input, output, maxLength); + } + + boolean read(int N, ByteBuffer input, Appendable output, int maxLength) { + if (state == DONE) { + return true; + } + if (!input.hasRemaining()) { + return false; + } + if (state == NEW) { + int huffmanBit = switch (N) { + case 7 -> 0b1000_0000; // for all value strings + case 5 -> 0b0010_0000; // in name string for insert literal + case 3 -> 0b0000_1000; // in name string for literal + default -> throw new IllegalStateException("Unexpected value: " + N); + }; + int p = input.position(); + huffman = (input.get(p) & huffmanBit) != 0; + state = FIRST_BYTE_READ; + intReader.configure(N); + } + if (state == FIRST_BYTE_READ) { + boolean lengthRead = intReader.read(input); + if (!lengthRead) { + return false; + } + long remainingLengthLong = intReader.get(); + if (maxLength >= 0) { + long huffmanEstimate = huffman ? + remainingLengthLong / 4 : remainingLengthLong; + if (huffmanEstimate > maxLength) { + throw readError.toQPackException(new ProtocolException( + "Size exceeds MAX_FIELD_SECTION_SIZE or dynamic table capacity.")); + } + } + remainingLength = (int) remainingLengthLong; + state = LENGTH_READ; + } + if (state == LENGTH_READ) { + boolean isLast = input.remaining() >= remainingLength; + int oldLimit = input.limit(); + if (isLast) { + input.limit(input.position() + remainingLength); + } + remainingLength -= Math.min(input.remaining(), remainingLength); + try { + if (huffman) { + huffmanReader.read(input, output, isLast); + } else { + plainReader.read(input, output); + } + } catch (IOException ioe) { + throw readError.toQPackException(ioe); + } + if (isLast) { + input.limit(oldLimit); + state = DONE; + } + return isLast; + } + throw new InternalError(Arrays.toString( + new Object[]{state, huffman, remainingLength})); + } + + public boolean isHuffmanEncoded() { + if (state < FIRST_BYTE_READ) { + throw new IllegalStateException("Has not been fully read yet"); + } + return huffman; + } + + public void reset() { + if (huffman) { + huffmanReader.reset(); + } else { + plainReader.reset(); + } + intReader.reset(); + state = NEW; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/BinaryRepresentationWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/BinaryRepresentationWriter.java new file mode 100644 index 00000000000..b028b994df2 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/BinaryRepresentationWriter.java @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import java.nio.ByteBuffer; + +interface BinaryRepresentationWriter { + boolean write(ByteBuffer destination); + + BinaryRepresentationWriter reset(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/DecoderInstructionsWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/DecoderInstructionsWriter.java new file mode 100644 index 00000000000..c27f46e2761 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/DecoderInstructionsWriter.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +public class DecoderInstructionsWriter { + private final QPACK.Logger logger; + private boolean encoding; + private static final AtomicLong IDS = new AtomicLong(); + + private final IntegerWriter integerWriter = new IntegerWriter(); + + public DecoderInstructionsWriter() { + long id = IDS.incrementAndGet(); + this.logger = QPACK.getLogger().subLogger("DecoderInstructionsWriter#" + id); + } + + /* + * Configure the writer for encoding "Section Acknowledgment" decoder instruction: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 1 | Stream ID (7+) | + * +---+---------------------------+ + */ + public int configureForSectionAck(long streamId) { + checkIfEncodingInProgress(); + encoding = true; + integerWriter.configure(streamId, 7, 0b1000_0000); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Section Acknowledgment for stream id=%s", + streamId)); + } + return IntegerWriter.requiredBufferSize(7, streamId); + } + + /* + * Configure the writer for encoding "Stream Cancellation" decoder instruction: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 0 | 1 | Stream ID (6+) | + * +---+---+-----------------------+ + */ + public int configureForStreamCancel(long streamId) { + checkIfEncodingInProgress(); + encoding = true; + integerWriter.configure(streamId, 6, 0b0100_0000); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Stream Cancellation for stream id=%s", + streamId)); + } + return IntegerWriter.requiredBufferSize(6, streamId); + } + + /* + * Configure the writer for encoding "Insert Count Increment" decoder instruction: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 0 | 0 | Increment (6+) | + * +---+---+-----------------------+ + */ + public int configureForInsertCountInc(long increment) { + checkIfEncodingInProgress(); + encoding = true; + integerWriter.configure(increment, 6, 0b0000_0000); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Insert Count Increment value=%s", + increment)); + } + return IntegerWriter.requiredBufferSize(6, increment); + } + + public boolean write(ByteBuffer byteBuffer) { + if (!encoding) { + throw new IllegalStateException("Writer hasn't been configured"); + } + boolean done = integerWriter.write(byteBuffer); + if (done) { + integerWriter.reset(); + encoding = false; + } + return done; + } + + private void checkIfEncodingInProgress() { + if (encoding) { + throw new IllegalStateException( + "Previous encoding operation hasn't finished yet"); + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDuplicateEntryWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDuplicateEntryWriter.java new file mode 100644 index 00000000000..29f6cac9ed8 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDuplicateEntryWriter.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import java.nio.ByteBuffer; + +final class EncoderDuplicateEntryWriter implements BinaryRepresentationWriter { + + private final IntegerWriter intWriter; + + public EncoderDuplicateEntryWriter() { + this.intWriter = new IntegerWriter(); + } + + public EncoderDuplicateEntryWriter configure(long relativeId) { + // IntegerWriter.configure checks if the relative id value is not negative + intWriter.configure(relativeId, 5, 0b0000_0000); + // Need to store entry id for adding a duplicate to the dynamic table + // once write operation is completed + return this; + } + + @Override + public boolean write(ByteBuffer destination) { + // IntegerWriter.write checks if it was properly configured + return intWriter.write(destination); + } + + @Override + public BinaryRepresentationWriter reset() { + intWriter.reset(); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDynamicTableCapacityWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDynamicTableCapacityWriter.java new file mode 100644 index 00000000000..1a920ff5b7b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderDynamicTableCapacityWriter.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import java.nio.ByteBuffer; + +final class EncoderDynamicTableCapacityWriter implements BinaryRepresentationWriter { + + private final IntegerWriter intWriter; + + public EncoderDynamicTableCapacityWriter() { + this.intWriter = new IntegerWriter(); + } + + public EncoderDynamicTableCapacityWriter configure(long capacity) { + // IntegerWriter.configure checks if the capacity value is not negative + intWriter.configure(capacity, 5, 0b0010_0000); + return this; + } + + @Override + public boolean write(ByteBuffer destination) { + // IntegerWriter.write checks if it was properly configured + return intWriter.write(destination); + } + + @Override + public BinaryRepresentationWriter reset() { + intWriter.reset(); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertIndexedNameWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertIndexedNameWriter.java new file mode 100644 index 00000000000..d1851a0c204 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertIndexedNameWriter.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +class EncoderInsertIndexedNameWriter implements BinaryRepresentationWriter { + private int state = NEW; + private final QPACK.Logger logger; + private final IntegerWriter intWriter = new IntegerWriter(); + private final StringWriter valueWriter = new StringWriter(); + private static final int NEW = 0; + private static final int NAME_PART_WRITTEN = 1; + private static final int VALUE_WRITTEN = 2; + + public EncoderInsertIndexedNameWriter(QPACK.Logger logger) { + this.logger = logger; + } + + public BinaryRepresentationWriter configure(TableEntry e) throws IndexOutOfBoundsException { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format( + "Encoder Insert With %s Table Name Reference (%s, '%s', huffman=%b)", + e.isStaticTable() ? "Static" : "Dynamic", e.index(), e.value(), e.huffmanValue())); + } + return this.index(e).value(e); + } + + @Override + public boolean write(ByteBuffer destination) { + if (state < NAME_PART_WRITTEN) { + if (!intWriter.write(destination)) { + return false; + } + state = NAME_PART_WRITTEN; + } + if (state < VALUE_WRITTEN) { + if (!valueWriter.write(destination)) { + return false; + } + state = VALUE_WRITTEN; + } + return state == VALUE_WRITTEN; + } + + @Override + public BinaryRepresentationWriter reset() { + intWriter.reset(); + valueWriter.reset(); + state = NEW; + return this; + } + + private EncoderInsertIndexedNameWriter index(TableEntry e) { + int N = 6; + int payload = 0b1000_0000; + long index = e.index(); + if (e.isStaticTable()) { + payload |= 0b0100_0000; + } + intWriter.configure(index, N, payload); + return this; + } + + private EncoderInsertIndexedNameWriter value(TableEntry e) { + int N = 7; + int payload = 0b0000_0000; + if (e.huffmanValue()) { + payload |= 0b1000_0000; + } + valueWriter.configure(e.value(), N, payload, e.huffmanValue()); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertLiteralNameWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertLiteralNameWriter.java new file mode 100644 index 00000000000..1783f60b062 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInsertLiteralNameWriter.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +final class EncoderInsertLiteralNameWriter implements BinaryRepresentationWriter { + private int state = NEW; + private final QPACK.Logger logger; + private final StringWriter nameWriter = new StringWriter(); + private final StringWriter valueWriter = new StringWriter(); + private static final int NEW = 0; + private static final int NAME_PART_WRITTEN = 1; + private static final int VALUE_WRITTEN = 2; + + EncoderInsertLiteralNameWriter(QPACK.Logger logger) { + this.logger = logger; + } + + public BinaryRepresentationWriter configure(TableEntry e) throws IndexOutOfBoundsException { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format( + "Insert With Literal Name (%s, '%s', huffmanName=%b, huffmanValue=%b)", + e.name(), e.value(), e.huffmanName(), e.huffmanValue())); + } + return this.name(e).value(e); + } + + @Override + public boolean write(ByteBuffer destination) { + if (state < NAME_PART_WRITTEN) { + if (!nameWriter.write(destination)) { + return false; + } + state = NAME_PART_WRITTEN; + } + if (state < VALUE_WRITTEN) { + if (!valueWriter.write(destination)) { + return false; + } + state = VALUE_WRITTEN; + } + return state == VALUE_WRITTEN; + } + + @Override + public BinaryRepresentationWriter reset() { + nameWriter.reset(); + valueWriter.reset(); + state = NEW; + return this; + } + + private EncoderInsertLiteralNameWriter name(TableEntry e) { + int N = 5; + int payload = 0b0100_0000; + if (e.huffmanName()) { + payload |= 0b0010_0000; + } + nameWriter.configure(e.name(), N, payload, e.huffmanName()); + return this; + } + + private EncoderInsertLiteralNameWriter value(TableEntry e) { + int N = 7; + int payload = 0b0000_0000; + if (e.huffmanValue()) { + payload |= 0b1000_0000; + } + valueWriter.configure(e.value(), N, payload, e.huffmanValue()); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInstructionsWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInstructionsWriter.java new file mode 100644 index 00000000000..c4a22436083 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/EncoderInstructionsWriter.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import jdk.internal.net.http.hpack.QuickHuffman; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +public class EncoderInstructionsWriter { + private BinaryRepresentationWriter writer; + private final QPACK.Logger logger; + private final EncoderInsertIndexedNameWriter insertIndexedNameWriter; + private final EncoderInsertLiteralNameWriter insertLiteralNameWriter; + private final EncoderDuplicateEntryWriter duplicateWriter; + private final EncoderDynamicTableCapacityWriter capacityWriter; + private boolean encoding; + private static final AtomicLong ENCODERS_IDS = new AtomicLong(); + + public EncoderInstructionsWriter() { + this(QPACK.getLogger()); + } + + public EncoderInstructionsWriter(QPACK.Logger parentLogger) { + long id = ENCODERS_IDS.incrementAndGet(); + this.logger = parentLogger.subLogger("EncoderInstructionsWriter#" + id); + // Writer for "Insert with Name Reference" encoder instruction + insertIndexedNameWriter = new EncoderInsertIndexedNameWriter( + logger.subLogger("EncoderInsertIndexedNameWriter")); + // Writer for "Insert with Literal Name" encoder instruction + insertLiteralNameWriter = new EncoderInsertLiteralNameWriter( + logger.subLogger("EncoderInsertLiteralNameWriter")); + // Writer for "Set Dynamic Table Capacity" encoder instruction + capacityWriter = new EncoderDynamicTableCapacityWriter(); + // Writer for "Duplicate" encoder instruction + duplicateWriter = new EncoderDuplicateEntryWriter(); + } + + /* + * Configure EncoderInstructionsWriter for encoding "Insert with Name Reference" or "Insert with Literal Name" + * encoder instruction. The instruction is selected based on TableEntry.type() value: + * "Insert with Name Reference" is selected for TableEntry.EntryType.NAME: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 1 | T | Name Index (6+) | + * +---+---+-----------------------+ + * | H | Value Length (7+) | + * +---+---------------------------+ + * | Value String (Length bytes) | + * +-------------------------------+ + * + * "Insert with Literal Name" is selected for TableEntry.EntryType.NEITHER: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 0 | 1 | H | Name Length (5+) | + * +---+---+---+-------------------+ + * | Name String (Length bytes) | + * +---+---------------------------+ + * | H | Value Length (7+) | + * +---+---------------------------+ + * | Value String (Length bytes) | + * +-------------------------------+ + */ + public int configureForEntryInsertion(TableEntry e) { + checkIfEncodingInProgress(); + encoding = true; + writer = switch (e.type()) { + case NAME -> insertIndexedNameWriter.configure(e); + case NEITHER -> insertLiteralNameWriter.configure(e); + default -> throw new IllegalArgumentException("Unsupported table entry insertion type: " + e.type()); + }; + return calculateEntryInsertionSize(e); + } + + /* + * Configure EncoderInstructionsWriter for encoding "Duplicate" encoder instruction: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 0 | 0 | 0 | Index (5+) | + * +---+---+---+-------------------+ + */ + public int configureForEntryDuplication(long entryIndexToDuplicate) { + checkIfEncodingInProgress(); + encoding = true; + duplicateWriter.configure(entryIndexToDuplicate); + writer = duplicateWriter; + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("duplicate entry with id=%s", entryIndexToDuplicate)); + } + return IntegerWriter.requiredBufferSize(5, entryIndexToDuplicate); + } + + /* + * Configure EncoderInstructionsWriter for encoding "Set Dynamic Table Capacity" encoder instruction: + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * | 0 | 0 | 1 | Capacity (5+) | + * +---+---+---+-------------------+ + */ + public int configureForTableCapacityUpdate(long tableCapacity) { + checkIfEncodingInProgress(); + encoding = true; + capacityWriter.configure(tableCapacity); + writer = capacityWriter; + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("set dynamic table capacity to %s", tableCapacity)); + } + return IntegerWriter.requiredBufferSize(5, tableCapacity); + } + + + public boolean write(ByteBuffer byteBuffer) { + if (!encoding) { + throw new IllegalStateException("Writer hasn't been configured"); + } + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("writing to %s", byteBuffer)); + } + boolean done = writer.write(byteBuffer); + if (done) { + writer.reset(); + encoding = false; + } + return done; + } + + private int calculateEntryInsertionSize(TableEntry e) { + int vlen = Math.min(QuickHuffman.lengthOf(e.value()), e.value().length()); + int integerValuesSize; + return switch (e.type()) { + case NAME -> { + // Calculate how many bytes are needed to encode the index part: + // | 1 | T | Name Index (6+) | + integerValuesSize = IntegerWriter.requiredBufferSize(6, e.index()); + // Calculate how many bytes are needed to encode the value length part: + // | H | Value Length (7+) | + integerValuesSize += IntegerWriter.requiredBufferSize(7, vlen); + // We also need vlen bytes for the value string content + yield integerValuesSize + vlen; + } + case NEITHER -> { + int nlen = Math.min(QuickHuffman.lengthOf(e.name()), e.name().length()); + // Calculate how many bytes are needed to encode the name length part: + // | 0 | 1 | H | Name Length (5+) | + integerValuesSize = IntegerWriter.requiredBufferSize(5, nlen); + // Calculate how many bytes are needed to encode the value length part: + // | H | Value Length (7+) | + integerValuesSize += IntegerWriter.requiredBufferSize(7, vlen); + // We also need nlen + vlen bytes for the name and the value strings + // content + yield integerValuesSize + nlen + vlen; + } + default -> throw new IllegalArgumentException("Unsupported table entry type: " + e.type()); + }; + } + + private void checkIfEncodingInProgress() { + if (encoding) { + throw new IllegalStateException( + "Previous encoding operation hasn't finished yet"); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedNameWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedNameWriter.java new file mode 100644 index 00000000000..6e268f32acc --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedNameWriter.java @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +final class FieldLineIndexedNameWriter implements BinaryRepresentationWriter { + private int state = NEW; + private final QPACK.Logger logger; + private final IntegerWriter intWriter = new IntegerWriter(); + private final StringWriter valueWriter = new StringWriter(); + private static final int NEW = 0; + private static final int NAME_PART_WRITTEN = 1; + private static final int VALUE_WRITTEN = 2; + + FieldLineIndexedNameWriter(QPACK.Logger logger) { + this.logger = logger; + } + + public BinaryRepresentationWriter configure(TableEntry e, boolean hideIntermediary, long base) + throws IndexOutOfBoundsException { + return e.isStaticTable() ? configureStatic(e, hideIntermediary) : + configureDynamic(e, hideIntermediary, base); + } + + private BinaryRepresentationWriter configureStatic(TableEntry e, boolean hideIntermediary) + throws IndexOutOfBoundsException { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format( + "Field Line With Static Table Name Reference" + + " (%s, '%s', huffman=%b, hideIntermediary=%b)", + e.index(), e.value(), e.huffmanValue(), hideIntermediary)); + } + return this.staticIndex(e.index(), hideIntermediary).value(e); + } + + private BinaryRepresentationWriter configureDynamic(TableEntry e, boolean hideIntermediary, long base) + throws IndexOutOfBoundsException { + boolean usePostBase = e.index() >= base; + long index = usePostBase ? e.index() - base : base - 1 - e.index(); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format( + "Field Line With %s Dynamic Table Name Reference" + + " (%s, '%s', huffman=%b, hideIntermediary=%b)", + usePostBase ? "Post-Base" : "", index, e.value(), e.huffmanValue(), + hideIntermediary)); + } + if (usePostBase) { + return this.dynamicPostBaseIndex(index, hideIntermediary).value(e); + } else { + return this.dynamicIndex(index, hideIntermediary).value(e); + } + } + + @Override + public boolean write(ByteBuffer destination) { + if (state < NAME_PART_WRITTEN) { + if (!intWriter.write(destination)) { + return false; + } + state = NAME_PART_WRITTEN; + } + if (state < VALUE_WRITTEN) { + if (!valueWriter.write(destination)) { + return false; + } + state = VALUE_WRITTEN; + } + return state == VALUE_WRITTEN; + } + + @Override + public FieldLineIndexedNameWriter reset() { + intWriter.reset(); + valueWriter.reset(); + state = NEW; + return this; + } + + private FieldLineIndexedNameWriter staticIndex(long absoluteIndex, boolean hideIntermediary) { + int payload = 0b0101_0000; + if (hideIntermediary) { + payload |= 0b0010_0000; + } + intWriter.configure(absoluteIndex, 4, payload); + return this; + } + + private FieldLineIndexedNameWriter dynamicIndex(long relativeIndex, boolean hideIntermediary) { + int payload = 0b0100_0000; + if (hideIntermediary) { + payload |= 0b0010_0000; + } + intWriter.configure(relativeIndex, 4, payload); + return this; + } + + private FieldLineIndexedNameWriter dynamicPostBaseIndex(long relativeIndex, boolean hideIntermediary) { + int payload = 0b0000_0000; + if (hideIntermediary) { + payload |= 0b0000_1000; + } + intWriter.configure(relativeIndex, 3, payload); + return this; + } + + private FieldLineIndexedNameWriter value(TableEntry e) { + int N = 7; + int payload = 0b0000_0000; + if (e.huffmanValue()) { + payload |= 0b1000_0000; + } + valueWriter.configure(e.value(), N, payload, e.huffmanValue()); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedWriter.java new file mode 100644 index 00000000000..c829966980a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineIndexedWriter.java @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +final class FieldLineIndexedWriter implements BinaryRepresentationWriter { + private final QPACK.Logger logger; + private final IntegerWriter intWriter = new IntegerWriter(); + + public FieldLineIndexedWriter(QPACK.Logger logger) { + this.logger = logger; + } + + public BinaryRepresentationWriter configure(TableEntry e, long base) { + return e.isStaticTable() ? configureStatic(e) : configureDynamic(e, base); + } + + private BinaryRepresentationWriter configureStatic(TableEntry e) { + assert e.isStaticTable(); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Indexed Field Line Static Table reference" + + " (%s, '%s', '%s')", e.index(), e.name(), e.value())); + } + return this.staticIndex(e.index()); + } + + private BinaryRepresentationWriter configureDynamic(TableEntry e, long base) { + assert !e.isStaticTable(); + // RFC-9204: 3.2.6. Post-Base Indexing + // Post-Base indices are used in field line representations for entries with absolute + // indices greater than or equal to Base, starting at 0 for the entry with absolute index + // equal to Base and increasing in the same direction as the absolute index. + boolean usePostBase = e.index() >= base; + long index = usePostBase ? e.index() - base : base - 1 - e.index(); + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("Indexed Field Line Dynamic Table reference %s (%s[%s], '%s', '%s')", + usePostBase ? "with Post-Base Index" : "", index, e.index(), e.name(), e.value())); + } + if (usePostBase) { + return dynamicPostBaseIndex(index); + } else { + return dynamicIndex(index); + } + } + + @Override + public boolean write(ByteBuffer destination) { + return intWriter.write(destination); + } + + @Override + public BinaryRepresentationWriter reset() { + intWriter.reset(); + return this; + } + + private FieldLineIndexedWriter staticIndex(long absoluteIndex) { + int N = 6; + intWriter.configure(absoluteIndex, N, 0b1100_0000); + return this; + } + + private FieldLineIndexedWriter dynamicIndex(long relativeIndex) { + assert relativeIndex >= 0; + int N = 6; + intWriter.configure(relativeIndex, N, 0b1000_0000); + return this; + } + + private FieldLineIndexedWriter dynamicPostBaseIndex(long relativeIndex) { + assert relativeIndex >= 0; + int N = 4; + intWriter.configure(relativeIndex, N, 0b0001_0000); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineLiteralsWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineLiteralsWriter.java new file mode 100644 index 00000000000..42187c256d7 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineLiteralsWriter.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import java.nio.ByteBuffer; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +class FieldLineLiteralsWriter implements BinaryRepresentationWriter { + private int state = NEW; + private final QPACK.Logger logger; + private final StringWriter nameWriter = new StringWriter(); + private final StringWriter valueWriter = new StringWriter(); + private static final int NEW = 0; + private static final int NAME_PART_WRITTEN = 1; + private static final int VALUE_WRITTEN = 2; + + public FieldLineLiteralsWriter(QPACK.Logger logger) { + this.logger = logger; + } + + public BinaryRepresentationWriter configure(TableEntry e, boolean hideIntermediary) throws IndexOutOfBoundsException { + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format( + "Field Line With Name and Value Literals ('%s', '%s', huffmanName=%b, huffmanValue=%b, hideIntermediary=%b)", + e.name(), e.value(), e.huffmanName(), e.huffmanValue(), hideIntermediary)); + } + return this.name(e, hideIntermediary).value(e); + } + + @Override + public boolean write(ByteBuffer destination) { + if (state < NAME_PART_WRITTEN) { + if (!nameWriter.write(destination)) { + return false; + } + state = NAME_PART_WRITTEN; + } + if (state < VALUE_WRITTEN) { + if (!valueWriter.write(destination)) { + return false; + } + state = VALUE_WRITTEN; + } + return state == VALUE_WRITTEN; + } + + @Override + public BinaryRepresentationWriter reset() { + nameWriter.reset(); + valueWriter.reset(); + state = NEW; + return this; + } + + private FieldLineLiteralsWriter name(TableEntry e, boolean hideIntermediary) { + int N = 3; + int payload = 0b0010_0000; + if (hideIntermediary) { + payload |= 0b0001_0000; + } + if (e.huffmanName()) { + payload |= 0b0000_1000; + } + nameWriter.configure(e.name(), N, payload, e.huffmanName()); + return this; + } + + private FieldLineLiteralsWriter value(TableEntry e) { + int N = 7; + int payload = 0b0000_0000; + if (e.huffmanValue()) { + payload |= 0b1000_0000; + } + valueWriter.configure(e.value(), N, payload, e.huffmanValue()); + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineSectionPrefixWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineSectionPrefixWriter.java new file mode 100644 index 00000000000..18e1d1e2676 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/FieldLineSectionPrefixWriter.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.FieldSectionPrefix; + +import java.nio.ByteBuffer; + +public class FieldLineSectionPrefixWriter { + enum State {NEW, CONFIGURED, RIC_WRITTEN, DONE} + + private final IntegerWriter intWriter; + private State state = State.NEW; + private long encodedRic; + private int signBit; + private long deltaBase; + + public FieldLineSectionPrefixWriter() { + this.intWriter = new IntegerWriter(); + } + + private void encodeFieldSectionPrefixFields(FieldSectionPrefix fsp, long maxEntries) { + // Required Insert Count encoded according to RFC-9204 "4.5.1.1: Required Insert Count" + // Base and Sign encoded according to RFC-9204: "4.5.1.2. Base" + long ric = fsp.requiredInsertCount(); + long base = fsp.base(); + + if (ric == 0) { + encodedRic = 0; + deltaBase = 0; + signBit = 0; + } else { + encodedRic = (ric % (2 * maxEntries)) + 1; + signBit = base >= ric ? 0 : 1; + deltaBase = base >= ric ? base - ric : ric - base - 1; + } + } + + public int configure(FieldSectionPrefix sectionPrefix, long maxEntries) { + intWriter.reset(); + encodeFieldSectionPrefixFields(sectionPrefix, maxEntries); + intWriter.configure(encodedRic, 8, 0); + state = State.CONFIGURED; + return IntegerWriter.requiredBufferSize(8, encodedRic) + + IntegerWriter.requiredBufferSize(7, deltaBase); + } + + public boolean write(ByteBuffer destination) { + if (state == State.NEW) { + throw new IllegalStateException("Configure first"); + } + + if (state == State.CONFIGURED) { + if (!intWriter.write(destination)) { + return false; + } + // Required Insert Count part is written, + // prepare integer writer for delta base and + // base sign write + intWriter.reset(); + int signPayload = signBit == 1 ? 0b1000_0000 : 0b0000_0000; + intWriter.configure(deltaBase, 7, signPayload); + state = State.RIC_WRITTEN; + } + + if (state == State.RIC_WRITTEN) { + if (!intWriter.write(destination)) { + return false; + } + state = State.DONE; + } + return state == State.DONE; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/HeaderFrameWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/HeaderFrameWriter.java new file mode 100644 index 00000000000..3841ca7c5ff --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/HeaderFrameWriter.java @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.TableEntry; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.String.format; +import static jdk.internal.net.http.qpack.QPACK.Logger.Level.EXTRA; + +public class HeaderFrameWriter { + private BinaryRepresentationWriter writer; + private final QPACK.Logger logger; + private final FieldLineIndexedWriter indexedWriter; + private final FieldLineIndexedNameWriter literalWithNameReferenceWriter; + private final FieldLineLiteralsWriter literalWithLiteralNameWriter; + private boolean encoding; + private static final AtomicLong HEADER_FRAME_WRITER_IDS = new AtomicLong(); + + public HeaderFrameWriter() { + this(QPACK.getLogger()); + } + + public HeaderFrameWriter(QPACK.Logger parentLogger) { + long id = HEADER_FRAME_WRITER_IDS.incrementAndGet(); + this.logger = parentLogger.subLogger("HeaderFrameWriter#" + id); + + indexedWriter = new FieldLineIndexedWriter(logger.subLogger("FieldLineIndexedWriter")); + literalWithNameReferenceWriter = new FieldLineIndexedNameWriter( + logger.subLogger("FieldLineIndexedNameWriter")); + literalWithLiteralNameWriter = new FieldLineLiteralsWriter( + logger.subLogger("FieldLineLiteralsWriter")); + } + + public void configure(TableEntry e, boolean sensitive, long base) { + checkIfEncodingInProgress(); + encoding = true; + writer = switch (e.type()) { + case NAME_VALUE -> indexedWriter.configure(e, base); + case NAME -> literalWithNameReferenceWriter.configure(e, sensitive, base); + case NEITHER -> literalWithLiteralNameWriter.configure(e, sensitive); + }; + } + + /** + * Writes the {@linkplain #configure(TableEntry, boolean, long) + * set up} header into the given buffer. + * + *

The method writes as much as possible of the header's binary + * representation into the given buffer, starting at the buffer's position, + * and increments its position to reflect the bytes written. The buffer's + * mark and limit will not be modified. + * + *

Once the method has returned {@code true}, the configured header is + * deemed encoded. A new header may be set up. + * + * @param headerFrame the buffer to encode the header into, may be empty + * @return {@code true} if the current header has been fully encoded, + * {@code false} otherwise + * @throws NullPointerException if the buffer is {@code null} + * @throws IllegalStateException if there is no set up header + */ + public boolean write(ByteBuffer headerFrame) { + if (!encoding) { + throw new IllegalStateException("A header hasn't been set up"); + } + if (logger.isLoggable(EXTRA)) { + logger.log(EXTRA, () -> format("writing to %s", headerFrame)); + } + boolean done = writer.write(headerFrame); + if (done) { + writer.reset(); + encoding = false; + } + return done; + } + + private void checkIfEncodingInProgress() { + if (encoding) { + throw new IllegalStateException("Previous encoding operation hasn't finished yet"); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/IntegerWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/IntegerWriter.java new file mode 100644 index 00000000000..f042135b4a1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/IntegerWriter.java @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +public final class IntegerWriter { + + private static final int NEW = 0; + private static final int CONFIGURED = 1; + private static final int FIRST_BYTE_WRITTEN = 2; + private static final int DONE = 4; + + private int state = NEW; + + private int payload; + private int N; + private long value; + + // + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | | | | | | | | | + // +---+---+---+-------------------+ + // |<--------->|<----------------->| + // payload N=5 + // + // payload is the contents of the left-hand side part of the first octet; + // it is truncated to fit into 8-N bits, where 1 <= N <= 8; + // + public IntegerWriter configure(long value, int N, int payload) { + if (state != NEW) { + throw new IllegalStateException("Already configured"); + } + if (value < 0) { + throw new IllegalArgumentException("value >= 0: value=" + value); + } + checkPrefix(N); + this.value = value; + this.N = N; + this.payload = payload & 0xFF & (0xFFFFFFFF << N); + state = CONFIGURED; + return this; + } + + public boolean write(ByteBuffer output) { + if (state == NEW) { + throw new IllegalStateException("Configure first"); + } + if (state == DONE) { + return true; + } + + if (!output.hasRemaining()) { + return false; + } + if (state == CONFIGURED) { + int max = (2 << (N - 1)) - 1; + if (value < max) { + output.put((byte) (payload | value)); + state = DONE; + return true; + } + output.put((byte) (payload | max)); + value -= max; + state = FIRST_BYTE_WRITTEN; + } + if (state == FIRST_BYTE_WRITTEN) { + while (value >= 128 && output.hasRemaining()) { + output.put((byte) ((value & 127) + 128)); + value /= 128; + } + if (!output.hasRemaining()) { + return false; + } + output.put((byte) value); + state = DONE; + return true; + } + throw new InternalError(Arrays.toString( + new Object[]{state, payload, N, value})); + } + + private static void checkPrefix(int N) { + if (N < 1 || N > 8) { + throw new IllegalArgumentException("1 <= N <= 8: N= " + N); + } + } + + public static int requiredBufferSize(int N, long value) { + checkPrefix(N); + int size = 1; + int max = (2 << (N - 1)) - 1; + if (value < max) { + return size; + } + size++; + value -= max; + while (value >= 128) { + value /= 128; + size++; + } + return size; + } + + public IntegerWriter reset() { + state = NEW; + return this; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/StringWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/StringWriter.java new file mode 100644 index 00000000000..298c3d8f9c1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/qpack/writers/StringWriter.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.qpack.writers; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import jdk.internal.net.http.hpack.ISO_8859_1; +import jdk.internal.net.http.hpack.Huffman; +import jdk.internal.net.http.hpack.QuickHuffman; + +// +// 0 1 2 3 4 5 6 7 +// +---+---+---+---+---+---+---+---+ +// | H | String Length (7+) | +// +---+---------------------------+ +// | String Data (Length octets) | +// +-------------------------------+ +// +// StringWriter does not require a notion of endOfInput (isLast) in 'write' +// methods due to the nature of string representation in HPACK. Namely, the +// length of the string is put before string's contents. Therefore the length is +// always known beforehand. +// +// Expected use: +// +// configure write* (reset configure write*)* +// +public final class StringWriter { + private static final int DEFAULT_PREFIX = 7; + private static final int DEFAULT_PAYLOAD = 0b0000_0000; + private static final int HUFFMAN_PAYLOAD = 0b1000_0000; + private static final int NEW = 0; + private static final int CONFIGURED = 1; + private static final int LENGTH_WRITTEN = 2; + private static final int DONE = 4; + + private final IntegerWriter intWriter = new IntegerWriter(); + private final Huffman.Writer huffmanWriter = new QuickHuffman.Writer(); + private final ISO_8859_1.Writer plainWriter = new ISO_8859_1.Writer(); + + private int state = NEW; + private boolean huffman; + + public StringWriter configure(CharSequence input, boolean huffman) { + return configure(input, 0, input.length(), DEFAULT_PREFIX, huffman ? HUFFMAN_PAYLOAD : DEFAULT_PAYLOAD, huffman); + } + + public StringWriter configure(CharSequence input, int N, int payload, boolean huffman) { + return configure(input, 0, input.length(), N, payload, huffman); + } + + StringWriter configure(CharSequence input, + int start, + int end, + int N, + int payload, + boolean huffman) { + if (start < 0 || end < 0 || end > input.length() || start > end) { + throw new IndexOutOfBoundsException( + String.format("input.length()=%s, start=%s, end=%s", + input.length(), start, end)); + } + if (!huffman) { + plainWriter.configure(input, start, end); + intWriter.configure(end - start, N, payload); + } else { + huffmanWriter.from(input, start, end); + intWriter.configure(huffmanWriter.lengthOf(input, start, end), N, payload); + } + + this.huffman = huffman; + state = CONFIGURED; + return this; + } + + public boolean write(ByteBuffer output) { + if (state == DONE) { + return true; + } + if (state == NEW) { + throw new IllegalStateException("Configure first"); + } + if (!output.hasRemaining()) { + return false; + } + if (state == CONFIGURED) { + if (intWriter.write(output)) { + state = LENGTH_WRITTEN; + } else { + return false; + } + } + if (state == LENGTH_WRITTEN) { + boolean written = huffman + ? huffmanWriter.write(output) + : plainWriter.write(output); + if (written) { + state = DONE; + return true; + } else { + return false; + } + } + throw new InternalError(Arrays.toString(new Object[]{state, huffman})); + } + + public void reset() { + intWriter.reset(); + if (huffman) { + huffmanWriter.reset(); + } else { + plainWriter.reset(); + } + state = NEW; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/BuffersReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/BuffersReader.java new file mode 100644 index 00000000000..9f2ebf17264 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/BuffersReader.java @@ -0,0 +1,707 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * A class that allows to read data from an aggregation of {@code ByteBuffer}. + * This is mostly geared to reading Quic or HTTP/3 frames that are composed + * of an aggregation of {@linkplain VariableLengthEncoder Variable Length Integers}. + * This class is not multi-thread safe. + *

+ * The {@code BuffersReader} class is an abstract class with two concrete + * implementations: {@link SingleBufferReader} and {@link ListBuffersReader}. + *

+ * The {@link SingleBufferReader} presents a simple lightweight view of a single + * {@link ByteBuffer}. Instances of {@code SingleBufferReader} can be created by + * calling {@link BuffersReader#single(ByteBuffer) BuffersReader.single(buffer)}; + *

+ * The {@link ListBuffersReader} view can be created from a (possibly empty) + * list of byte buffers. New byte buffers can be later {@linkplain + * ListBuffersReader#add(ByteBuffer) added} to the {@link ListBuffersReader} instance + * as they become available. Once a frame has been fully received, + * {@link BuffersReader#release()} or {@link BuffersReader#getAndRelease(long)} should + * be called to forget and relinquish all bytes buffers up to the current + * {@linkplain #position() position} of the {@code BuffersReader}. + * Released buffers are removed from {@code BuffersReader} list, and the position + * of the reader is reset to 0, allowing to read the next frame from the remaining + * data. + */ +public abstract sealed class BuffersReader { + + /** + * Release all buffers held by this {@code BuffersReader}, whether + * consumed or unconsumed. Released buffer are all set to their + * limit. + */ + public abstract void clear(); + + // Used to store the original position and limit of a + // buffer at the time it's added to the reader's list + // It is not possible to beyond that position or limit when + // using the reader + private record Buffer(ByteBuffer buffer, int offset, int limit) { + Buffer { + assert offset <= limit; + assert offset >= 0; + assert limit == buffer.limit(); + } + Buffer(ByteBuffer buffer) { + this(buffer, buffer.position(), buffer.limit()); + } + } + + /** + * {@return the current position of the reader} + * The semantic is similar to {@link ByteBuffer#position()}. + */ + public abstract long position(); + + /** + * {@return the limit of the reader} + * The semantic is similar to {@link ByteBuffer#limit()}. + */ + public abstract long limit(); + + /** + * Reads one byte from the reader. This method increase + * the position by one. The semantic is similar to + * {@link ByteBuffer#get()}. + * @return the byte at the current position + * @throws BufferUnderflowException if trying to read past + * the limit. + */ + public abstract byte get(); + + /** + * Reads the byte located at the given position in the + * reader. The semantic is similar to {@link ByteBuffer#get(int)}. + * This method doesn't change the position of the reader. + * + * @param position the position of the byte + * @return the byte at the given position in the reader + * + * @throws IndexOutOfBoundsException if trying to read before + * the reader's position or after the reader's limit + */ + public abstract byte get(long position); + + /** + * Sets the position of the reader. + * The semantic is similar to {@link ByteBuffer#position(int)}. + * + * @param newPosition the new position + * + * @throws IllegalArgumentException if trying to set + * the position to a negative value, or to a value + * past the limit + */ + public abstract void position(long newPosition); + + /** + * Releases all the data that has been read, sets the + * reader's position to 0 and its limit to the amount + * of data remaining. + */ + public abstract void release(); + + /** + * Returns a list of {@code ByteBuffer} containing the + * requested amount of bytes, starting at the current + * position, then release all the data up to the new + * position, and reset the reader's position to 0 and + * the reader's limit to the amount of remaining data. + * . + * @param bytes the amount of bytes to read and move + * to the returned list. + * + * @return a list of {@code ByteBuffer} containing the next + * {@code bytes} of data, starting at the current position. + * + * @throws BufferUnderflowException if attempting to read past + * the limit + */ + public abstract List getAndRelease(long bytes); + + /** + * {@return true if the reader has remaining bytes to the read} + * The semantic is similar to {@link ByteBuffer#hasRemaining()}. + */ + public boolean hasRemaining() { + return position() < limit(); + } + + /** + * {@return the number of bytes that remain to read} + * The semantic is similar to {@link ByteBuffer#remaining()}. + */ + public long remaining() { + long rem = limit() - position(); + return rem > 0 ? rem : 0; + } + + /** + * {@return the cumulated amount of data that has been read in this + * {@code BuffersReader} since its creation} + * This number is not reset when calling {@link #release()}. + */ + public abstract long read(); + + /** + * {@return The offset of this {@code BuffersReader}} + * This is the position in the first {@code ByteBuffer} that + * was set on the reader. The {@code BuffersReader} will not + * allow to get or set a position lower than the offset. + */ + public abstract long offset(); + + /** + * {@return true if this {@code BuffersReader} is empty} + * A {@code BuffersReader} is empty if it has been {@linkplain + * #list() created empty, or if it has been {@linkplain #release() + * released} after all data has been read. + */ + public abstract boolean isEmpty(); + + /** + * A lightweight view allowing to see a {@link ByteBuffer} as a + * {@link BuffersReader}. This class wrap a single {@link ByteBuffer} + * and cannot be reused after {@link #release()}. + */ + public static final class SingleBufferReader extends BuffersReader { + ByteBuffer single; + long read = 0; + long start; + SingleBufferReader(ByteBuffer single) { + this.single = single; + start = single.position(); + } + + @Override + public void release() { + single = null; + } + + @Override + public List getAndRelease(long bytes) { + return List.of(getAndReleaseBuffer(bytes)); + } + + @Override + public byte get() { + if (single == null) throw new BufferUnderflowException(); + return single.get(); + } + + @Override + public byte get(long position) { + if (single == null || position < start || position >= single.limit()) + throw new IndexOutOfBoundsException(); + return single.get((int) position); + } + + @Override + public long limit() { + return single == null ? 0 : single.limit(); + } + + @Override + public long position() { + return single == null ? 0 : single.position(); + } + + @Override + public boolean hasRemaining() { + return single != null && single.hasRemaining(); + } + + @Override + public void position(long pos) { + if (single == null || pos < start || pos > single.limit()) + throw new BufferUnderflowException(); + single.position((int) pos); + } + + /** + * This method has the same semantics than {@link #getAndRelease(long)} + * except that it avoids creating a list. + * @return a buffer containing the next {@code bytes}. + */ + public ByteBuffer getAndReleaseBuffer(long bytes) { + var released = single; + int remaining = released.remaining(); + if (bytes > remaining) + throw new BufferUnderflowException(); + if (bytes == remaining) { + read = single.limit() - start; + single = null; + } else { + read = single.position() - start; + single = released.slice(released.position() + (int)bytes, released.limit()); + start = 0; + released = released.slice(released.position(), (int) bytes); + } + return released; + } + + @Override + public long read() { + return single == null ? read : (read + single.position() - start); + } + + @Override + public long offset() { + return start; + } + + @Override + public boolean isEmpty() { + return single == null; + } + + @Override + public void clear() { + if (single == null) return; + single.position(single.limit()); + single = null; + } + } + + /** + * A {@code BuffersReader} that iterates over a list of {@code ByteBuffers}. + * New {@code ByteBuffers} can be added at the end of list by calling + * {@link #add(ByteBuffer)} or {@link #addAll(List)}, which increases + * the {@linkplain #limit() limit} accordingly. + *

+ * When {@link #release() released}, the data prior to the current + * {@linkplain #position()} is discarded, the {@linkplain #position() position} + * and {@linkplain #offset() offset} are reset to {@code 0}, and the + * {@linkplain #limit() limit} is set to the amount of remaining data. + *

+ * A {@code ListBuffersReader} can be reused after being released. + * If it still contains data, the {@linkplain #offset() offset} will + * be {@code 0}. Otherwise, the offset will be set to the position + * of the first buffer {@linkplain #add(ByteBuffer) added} to the + * {@code ListBuffersReader}. + */ + public static final class ListBuffersReader extends BuffersReader { + private final List buffers = new ArrayList<>(); + private Buffer current; + private int nextIndex; + private long currentOffset; + private long position; + private long limit; + private long start; + private long readAndReleased = 0; + + ListBuffersReader() { + } + + /** + * Adds a new {@code ByteBuffer} to this {@code BuffersReader}. + * If the reader is {@linkplain #isEmpty() empty}, the reader's + * {@linkplain #offset() offset} and {@linkplain #position() position} + * is set to the buffer's position, and the reader {@linkplain #limit() + * limit} is set to the buffer's limit. + * Otherwise, the reader's limit is simply increased by the buffer's + * remaining bytes. The reader will only allow to read those bytes + * between the current position and limit of the buffer. + * + * @apiNote + * This class doesn't make defensive copies of the provided buffers, + * so the caller must not modify the buffer's position or limit + * after it's been added to the reader. + * + * @param buffer a byte buffer + * @return this reader + */ + public ListBuffersReader add(ByteBuffer buffer) { + if (buffers.isEmpty()) { + int lim = buffer.limit(); + buffers.add(new Buffer(buffer, 0, lim)); + start = buffer.position(); + position = limit = start; + currentOffset = 0; + } else { + buffers.add(new Buffer(buffer)); + } + limit += buffer.remaining(); + return this; + } + + /** + * Adds a list of byte buffers to this reader. + * This is equivalent to calling: + * {@snippet : + * ListBuffersReader reader = ...; + * for (var buffer : buffers) { + * reader.add(buffer); // @link substring="add" target="#add(ByteBuffer)" + * } + * } + * @param buffers a list of {@link ByteBuffer ByteBuffers} + * @return this reader + */ + public ListBuffersReader addAll(List buffers) { + for (var buffer : buffers) { + if (isEmpty()) { + add(buffer); + continue; + } + this.buffers.add(new Buffer(buffer)); + limit += buffer.remaining(); + } + return this; + } + + @Override + public boolean isEmpty() { + return buffers.isEmpty(); + } + + @Override + public byte get() { + ByteBuffer buffer = current(true); + byte res = buffer.get(); + position++; + return res; + } + + @Override + public byte get(long pos) { + if (pos >= limit || pos < start) + throw new IndexOutOfBoundsException(); + ByteBuffer buffer = current(false); + if (position == limit && current != null) { + // let the current buffer throw + buffer = current.buffer; + } + assert buffer != null : "limit check failed"; + if (pos == position) { + return buffer.get(buffer.position()); + } + long offset = currentOffset; + int index = nextIndex; + Buffer cur = current; + while (pos >= offset) { + int bpos = buffer.position(); + int boffset = cur.offset; + int blimit = buffer.limit(); + assert index == nextIndex || bpos == boffset; + if (pos - offset < blimit - boffset) { + return buffer.get((int) (pos - offset + boffset)); + } + if (index >= buffers.size()) { + assert false : "buffers exhausted"; + throw new IndexOutOfBoundsException(); + } + int skipped = cur.limit - cur.offset; + offset += skipped; + cur = buffers.get(index++); + buffer = cur.buffer; + } + assert pos <= offset; + int blimit = cur.offset; + int boffset = cur.offset; + while (pos < offset) { + assert blimit == cur.limit || index == nextIndex && blimit == boffset; + if (index <= 1) { + assert false : "buffers exhausted"; + throw new IndexOutOfBoundsException(); + } + cur = buffers.get(--index - 1); + buffer = cur.buffer; + int bpos = buffer.position(); + blimit = buffer.limit(); + boffset = cur.offset; + int skipped = blimit - boffset; + offset -= skipped; + assert index == nextIndex || bpos == blimit; + if (pos - offset >= 0 && pos - offset < blimit - boffset) { + return buffer.get((int) (pos - offset + boffset)); + } + } + assert false : "buffer not found"; + throw new IndexOutOfBoundsException(); // should not reach here + } + + /** + * {@return the current {@code ByteBuffer} in which to find + * the byte at the current {@link #position()}} + * + * @param throwIfUnderflow if true, calling this method + * will throw {@link BufferUnderflowException} if + * the position is past the limit. + * + * @throws BufferUnderflowException if attempting to read past + * the limit and {@code throwIfUnderflow == true} + */ + private ByteBuffer current(boolean throwIfUnderflow) { + while (current == null || !current.buffer.hasRemaining()) { + if (buffers.size() > nextIndex) { + if (nextIndex != 0) { + currentOffset = position; + } else { + currentOffset = 0; + } + current = buffers.get(nextIndex++); + } else if (throwIfUnderflow) { + throw new BufferUnderflowException(); + } else { + return null; + } + } + return current.buffer; + } + + @Override + public List getAndRelease(long bytes) { + release(); + if (bytes > limit - position) { + throw new BufferUnderflowException(); + } + ByteBuffer buf = current(false); + if (buf == null || bytes == 0) return List.of(); + List list = null; + assert position == 0; + assert currentOffset == 0; + while (bytes > 0) { + buf = current(false); + assert nextIndex == 1; + assert buf != null; + assert buf.position() == current.offset; + int remaining = buf.remaining(); + if (remaining <= bytes) { + var b = buffers.remove(--nextIndex); + assert b == current; + long relased = buf.remaining(); + assert b.buffer.limit() == b.limit; + bytes -= relased; + limit -= relased; + readAndReleased += relased; + current = null; + + // if a buffer has no remaining bytes it + // may be EOF. Let's not skip it here + // if (!buf.hasRemaining()) continue; + + if (bytes == 0 && list == null) { + list = List.of(buf); + } else { + if (list == null) { + list = new ArrayList<>(); + } + list.add(buf); + } + } else { + var b = current; + long relased = bytes; + bytes = 0; + limit -= relased; + var pos = buf.position(); + assert b.limit == buf.limit(); + assert pos == b.offset; + var slice = buf.slice(pos, (int)relased); + buf.position(pos + (int) relased); + buffers.set(nextIndex - 1, current = new Buffer(buf)); + readAndReleased += relased; + if (list != null) { + list.add(slice); + } else { + list = List.of(slice); + } + assert bytes == 0; + } + } + return list; + } + + @Override + public long position() { + return position; + } + + @Override + public long limit() { + return limit; + } + + @Override + public void release() { + long released = - start; + for (var it = buffers.listIterator(); it.hasNext(); ) { + var b = it.next(); + var buf = b.buffer; + released += (buf.position() - b.offset); + if (buf.hasRemaining()) { + it.set(new Buffer(buf)); + break; + } + it.remove(); + } + assert released == position - start + : "start=%s, position=%s, released=%s" + .formatted(start, position, released); + readAndReleased += released; + limit -= position; + current = null; + position = 0; + currentOffset = 0; + nextIndex = 0; + start = 0; + } + + @Override + public void position(long pos) { + if (pos > limit) throw new IllegalArgumentException(pos + " > " + limit); + if (pos < start) throw new IllegalArgumentException(pos + " < " + start); + if (pos == position) return; // happy case! + // look forward, starting from the current position: + // - identify the ByteBuffer that contains the requested position + // - set the local position in that ByteBuffer to + // match the requested position + if (pos > position) { + long skip = pos - position; + assert skip > 0; + while (skip > 0) { + var buffer = current(true); + int remaining = buffer.remaining(); + if (remaining == 0) continue; + if (skip > remaining) { + // somewhere after the current buffer + buffer.position(buffer.limit()); + position += remaining; + skip -= remaining; + } else { + // somewhere in the current buffer + buffer.position(buffer.position() + (int) skip); + position += skip; + skip = 0; + } + } + } else { + // look backward, starting from the current position: + // - identify the ByteBuffer that contains the requested position + // - set the local position in that ByteBuffer to + // match the requested position + long skip = pos - position; + assert skip < 0; + if (current == null) { + current(false); + if (current == null) + throw new IllegalArgumentException(); + } + while (skip < 0) { + var buffer = current.buffer; + assert buffer.limit() == current.limit; + var remaining = buffer.position() - current.offset; + var rest = skip + remaining; + if (rest >= 0) { + // somewhere in this byte buffer, between the + // buffer offset and the buffer position + buffer.position(buffer.position() + (int)skip); + position += skip; + assert position >= start; + skip = 0; + } else { + // in some buffer prior to the current byte buffer + buffer.position(current.offset); + skip += remaining; + position -= remaining; + assert skip < 0; + assert position >= start; + assert nextIndex > 1; + current = buffers.get(--nextIndex - 1); + currentOffset -= current.limit - current.offset; + assert currentOffset >= 0; + assert current.buffer.position() == current.limit; + } + } + } + } + + @Override + public long read() { + return readAndReleased + (position - start); + } + + @Override + public long offset() { + return start; + } + + @Override + public void clear() { + release(); + position(limit()); + release(); + } + } + + /** + * Creates a lightweight {@link SingleBufferReader} view over + * a single {@link ByteBuffer}. + * @param buffer a byte buffer + * @return a lightweight {@link SingleBufferReader} view over + * a single {@link ByteBuffer} + */ + public static SingleBufferReader single(ByteBuffer buffer) { + return new SingleBufferReader(Objects.requireNonNull(buffer)); + } + + /** + * Creates an {@linkplain #isEmpty() empty} {@link ListBuffersReader}. + * @return an empty {@code ListBuffersReader} + */ + public static ListBuffersReader list() { + return new ListBuffersReader(); + } + + /** + * Creates a {@link ListBuffersReader} with the given + * {@code buffer}. More buffers can be later {@linkplain + * ListBuffersReader#add(ByteBuffer) added} as they become + * available. + * @return a {@code ListBuffersReader} + */ + public static ListBuffersReader list(ByteBuffer buffer) { + return new ListBuffersReader().add(buffer); + } + + /** + * Creates a {@link ListBuffersReader} with the given + * {@code buffers} list. More buffers can be later {@linkplain + * ListBuffersReader#add(ByteBuffer) added} as they become + * available. + * @return a {@code ListBuffersReader} + */ + public static ListBuffersReader list(List buffers) { + return new ListBuffersReader().addAll(buffers); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/CodingContext.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/CodingContext.java new file mode 100644 index 00000000000..d2daa6fadaf --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/CodingContext.java @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.quic; + +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; + +import java.io.IOException; +import java.nio.ByteBuffer; + +public interface CodingContext { + + /** + * {@return the largest incoming packet number successfully processed + * in the given packet number space} + * + * @apiNote + * This method is used when decoding the packet number of an incoming packet. + * + * @param packetSpace the packet number space + */ + long largestProcessedPN(QuicPacket.PacketNumberSpace packetSpace); + + /** + * {@return the largest outgoing packet number acknowledged by the peer + * in the given packet number space} + * + * @apiNote + * This method is used when encoding the packet number of an outgoing packet. + * + * @param packetSpace the packet number space + */ + long largestAckedPN(QuicPacket.PacketNumberSpace packetSpace); + + /** + * {@return the length of the local connection ids expected + * to be found in incoming short header packets} + */ + int connectionIdLength(); + + /** + * {@return the largest incoming packet number successfully processed + * in the packet number space corresponding to the given packet type} + *

+ * This is equivalent to calling:

+     *     {@code largestProcessedPN(QuicPacket.PacketNumberSpace.of(packetType));}
+     * 
+ * + * @apiNote + * This method is used when decoding the packet number of an incoming packet. + * + * @param packetType the packet type + */ + default long largestProcessedPN(QuicPacket.PacketType packetType) { + return largestProcessedPN(QuicPacket.PacketNumberSpace.of(packetType)); + } + + /** + * {@return the largest outgoing packet number acknowledged by the peer + * in the packet number space corresponding to the given packet type} + *

+ * This is equivalent to calling:

+     *     {@code largestAckedPN(QuicPacket.PacketNumberSpace.of(packetType));}
+     * 
+ * + * @apiNote + * This method is used when encoding the packet number of an outgoing packet. + * + * @param packetType the packet type + */ + default long largestAckedPN(QuicPacket.PacketType packetType) { + return largestAckedPN(QuicPacket.PacketNumberSpace.of(packetType)); + } + + /** + * Writes the given outgoing packet in the given byte buffer. + * This method moves the position of the byte buffer. + * @param packet the outgoing packet to write + * @param buffer the byte buffer to write the packet into + * @return the number of bytes written + * @throws java.nio.BufferOverflowException if the buffer doesn't have + * enough space to write the packet + */ + int writePacket(QuicPacket packet, ByteBuffer buffer) + throws QuicKeyUnavailableException, QuicTransportException; + + /** + * Reads an encrypted packet from the given byte buffer. + * This method moves the position of the byte buffer. + * @param src a byte buffer containing a non encrypted packet + * @return the packet read + * @throws IOException if the packet couldn't be read + * @throws QuicTransportException if packet is correctly signed but malformed + */ + QuicPacket parsePacket(ByteBuffer src) throws IOException, QuicKeyUnavailableException, QuicTransportException; + + /** + * Returns the original destination connection id, required for + * calculating the retry integrity tag. + *

+ * This is only of interest when protecting/unprotecting a {@linkplain + * QuicPacket.PacketType#RETRY Retry Packet}. + * + * @return the original destination connection id, required for calculating + * the retry integrity tag + */ + QuicConnectionId originalServerConnId(); + + /** + * Returns the TLS engine associated with this context + * @return the TLS engine associated with this context + */ + QuicTLSEngine getTLSEngine(); + + /** + * Checks if the provided token is valid for the given context and connection ID. + * @param destinationID destination connection ID found in the packet + * @param token token to verify + * @return true if token is valid, false otherwise + */ + boolean verifyToken(QuicConnectionId destinationID, byte[] token); + + /** + * {@return The minimum payload size for short packet payloads}. + * Padding will be added to match that size if needed. + * @param destConnectionIdLength the length of the destination + * connectionId included in the packet + */ + default int minShortPacketPayloadSize(int destConnectionIdLength) { + // See RFC 9000, Section 10.3 + // https://www.rfc-editor.org/rfc/rfc9000#section-10.3 + // [..] the endpoint SHOULD ensure that all packets it sends + // are at least 22 bytes longer than the minimum connection + // ID length that it requests the peer to include in its + // packets [...] + // + // A 1-RTT packet contains the peer connection id + // (whose length is destConnectionIdLength), therefore the + // payload should be at least 5 - (destConnectionIdLength + // - connectionIdLength()) - where connectionIdLength is the + // length of the local connection ID. + return 5 - (destConnectionIdLength - connectionIdLength()); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminator.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminator.java new file mode 100644 index 00000000000..24230a0883d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminator.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +// responsible for managing the connection termination of a QUIC connection +public sealed interface ConnectionTerminator permits ConnectionTerminatorImpl { + + // lets the terminator know that the connection is still alive and should not be + // idle timed out + void keepAlive(); + + void terminate(TerminationCause cause); + + boolean tryReserveForUse(); + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminatorImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminatorImpl.java new file mode 100644 index 00000000000..150d6233953 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/ConnectionTerminatorImpl.java @@ -0,0 +1,475 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.quic.QuicConnectionImpl.HandshakeFlow; +import jdk.internal.net.http.quic.QuicConnectionImpl.ProtectionRecord; +import jdk.internal.net.http.quic.TerminationCause.AppLayerClose; +import jdk.internal.net.http.quic.TerminationCause.SilentTermination; +import jdk.internal.net.http.quic.TerminationCause.TransportError; +import jdk.internal.net.http.quic.frames.ConnectionCloseFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTLSEngine.KeySpace; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import static jdk.internal.net.http.quic.QuicConnectionImpl.QuicConnectionState.CLOSED; +import static jdk.internal.net.http.quic.QuicConnectionImpl.QuicConnectionState.CLOSING; +import static jdk.internal.net.http.quic.QuicConnectionImpl.QuicConnectionState.DRAINING; +import static jdk.internal.net.http.quic.TerminationCause.appLayerClose; +import static jdk.internal.net.http.quic.TerminationCause.forSilentTermination; +import static jdk.internal.net.http.quic.TerminationCause.forTransportError; +import static jdk.internal.net.quic.QuicTransportErrors.INTERNAL_ERROR; +import static jdk.internal.net.quic.QuicTransportErrors.NO_ERROR; + +final class ConnectionTerminatorImpl implements ConnectionTerminator { + + private final QuicConnectionImpl connection; + private final Logger debug; + private final String logTag; + private final AtomicReference terminationCause = new AtomicReference<>(); + private final CompletableFuture futureTC = new MinimalFuture<>(); + + ConnectionTerminatorImpl(final QuicConnectionImpl connection) { + this.connection = Objects.requireNonNull(connection, "connection"); + this.debug = connection.debug; + this.logTag = connection.logTag(); + } + + @Override + public void keepAlive() { + this.connection.idleTimeoutManager.keepAlive(); + } + + @Override + public boolean tryReserveForUse() { + return this.connection.idleTimeoutManager.tryReserveForUse(); + } + + @Override + public void terminate(final TerminationCause cause) { + Objects.requireNonNull(cause); + try { + doTerminate(cause); + } catch (Throwable t) { + // make sure we do fail the handshake CompletableFuture(s) + // even when the connection termination itself failed. that way + // the dependent CompletableFuture(s) tasks don't keep waiting forever + failHandshakeCFs(t); + } + } + + TerminationCause getTerminationCause() { + return this.terminationCause.get(); + } + + private void doTerminate(final TerminationCause cause) { + final ConnectionCloseFrame frame; + KeySpace keySpace; + switch (cause) { + case SilentTermination st -> { + silentTerminate(st); + return; + } + case TransportError te -> { + frame = new ConnectionCloseFrame(te.getCloseCode(), te.frameType, + te.getPeerVisibleReason()); // 0x1c + keySpace = te.keySpace; + } + case TerminationCause.InternalError ie -> { + frame = new ConnectionCloseFrame(ie.getCloseCode(), 0, + ie.getPeerVisibleReason()); // 0x1c + keySpace = null; + } + case AppLayerClose alc -> { + // application layer triggered connection close + frame = new ConnectionCloseFrame(alc.getCloseCode(), + alc.getPeerVisibleReason()); // 0x1d + keySpace = null; + } + } + if (keySpace == null) { + // TODO: review this + keySpace = connection.getTLSEngine().getCurrentSendKeySpace(); + } + immediateClose(frame, keySpace, cause); + } + + void incomingConnectionCloseFrame(final ConnectionCloseFrame frame) { + Objects.requireNonNull(frame); + if (debug.on()) { + debug.log("Received close frame: %s", frame); + } + drain(frame); + } + + void incomingStatelessReset() { + // if local endpoint is a client, then our peer is a server + final boolean peerIsServer = connection.isClientConnection(); + if (Log.errors()) { + Log.logError("{0}: stateless reset from peer ({1})", connection.logTag(), + (peerIsServer ? "server" : "client")); + } + final SilentTermination st = forSilentTermination("stateless reset from peer (" + + (peerIsServer ? "server" : "client") + ")"); + terminate(st); + } + + /** + * Called only when the connection is expected to be discarded without being required + * to inform the peer. + * Discards all state, no CONNECTION_CLOSE is sent, nor does the connection enter closing + * or discarding state. + */ + private void silentTerminate(final SilentTermination terminationCause) { + // shutdown the idle timeout manager since we no longer bother with idle timeout + // management for this connection + connection.idleTimeoutManager.shutdown(); + // mark the connection state as closed (we don't enter closing or draining state + // during silent termination) + if (!markClosed(terminationCause)) { + // previously already closed + return; + } + if (Log.quic()) { + Log.logQuic("{0} silently terminating connection due to: {1}", + logTag, terminationCause.getLogMsg()); + } else if (debug.on()) { + debug.log("silently terminating connection due to: " + terminationCause.getLogMsg()); + } + if (debug.on() || Log.quic()) { + String message = connection.loggableState(); + if (message != null) { + Log.logQuic("{0} connection state: {1}", logTag, message); + debug.log("connection state: %s", message); + } + } + failHandshakeCFs(); + // remove from the endpoint + unregisterConnFromEndpoint(); + discardConnectionState(); + // terminate the streams + connection.streams.terminate(terminationCause); + } + + CompletableFuture futureTerminationCause() { + return this.futureTC; + } + + private void unregisterConnFromEndpoint() { + final QuicEndpoint endpoint = this.connection.endpoint(); + if (endpoint == null) { + // this can happen if the connection is being terminated before + // an endpoint has been established (which is OK) + return; + } + endpoint.removeConnection(this.connection); + } + + private void immediateClose(final ConnectionCloseFrame closeFrame, + final KeySpace keySpace, + final TerminationCause terminationCause) { + assert closeFrame != null : "connection close frame is null"; + assert keySpace != null : "keyspace is null"; + final String logMsg = terminationCause.getLogMsg(); + // if the connection has already been closed (for example: through silent termination) + // then the local state of the connection is already discarded and thus + // there's nothing more we can do with the connection. + if (connection.stateHandle().isMarked(CLOSED)) { + return; + } + // switch to closing state + if (!markClosing(terminationCause)) { + // has previously already gone into closing state + return; + } + // shutdown the idle timeout manager since we no longer bother with idle timeout + // management for a closing connection + connection.idleTimeoutManager.shutdown(); + + if (connection.stateHandle().draining()) { + if (Log.quic()) { + Log.logQuic("{0} skipping immediate close, since connection is already" + + " in draining state", logTag, logMsg); + } else if (debug.on()) { + debug.log("skipping immediate close, since connection is already" + + " in draining state"); + } + // we are already (in the subsequent) draining state, no need to anything more + return; + } + try { + final String closeCodeHex = (terminationCause.isAppLayer() ? "(app layer) " : "") + + "0x" + Long.toHexString(closeFrame.errorCode()); + if (Log.quic()) { + Log.logQuic("{0} entering closing state, code {1} - {2}", logTag, closeCodeHex, logMsg); + } else if (debug.on()) { + debug.log("entering closing state, code " + closeCodeHex + " - " + logMsg); + } + pushConnectionCloseFrame(keySpace, closeFrame); + } catch (Exception e) { + if (Log.errors()) { + Log.logError("{0} removing connection from endpoint after failure to send" + + " CLOSE_CONNECTION: {1}", logTag, e); + } else if (debug.on()) { + debug.log("removing connection from endpoint after failure to send" + + " CLOSE_CONNECTION"); + } + // we failed to send a CONNECTION_CLOSE frame. this implies that the QuicEndpoint + // won't detect that the QuicConnectionImpl has transitioned to closing connection + // and thus won't remap it to closing. we thus discard such connection from the + // endpoint. + unregisterConnFromEndpoint(); + } + failHandshakeCFs(); + discardConnectionState(); + connection.streams.terminate(terminationCause); + if (Log.quic()) { + Log.logQuic("{0} connection has now transitioned to closing state", logTag); + } else if (debug.on()) { + debug.log("connection has now transitioned to closing state"); + } + } + + private void drain(final ConnectionCloseFrame incomingFrame) { + // if the connection has already been closed (for example: through silent termination) + // then the local state of the connection is already discarded and thus + // there's nothing more we can do with the connection. + if (connection.stateHandle().isMarked(CLOSED)) { + return; + } + final boolean isAppLayerClose = incomingFrame.variant(); + final String closeCodeString = isAppLayerClose ? + "[app]" + connection.quicInstance().appErrorToString(incomingFrame.errorCode()) : + QuicTransportErrors.toString(incomingFrame.errorCode()); + final String reason = incomingFrame.reasonString(); + final String peer = connection.isClientConnection() ? "server" : "client"; + final String msg = "Connection closed by " + peer + " peer: " + + closeCodeString + + (reason == null || reason.isEmpty() ? "" : (" " + reason)); + final TerminationCause terminationCause; + if (isAppLayerClose) { + terminationCause = appLayerClose(incomingFrame.errorCode(), msg) + .peerVisibleReason(reason); + } else { + terminationCause = forTransportError(incomingFrame.errorCode(), msg, + incomingFrame.errorFrameType()) + .peerVisibleReason(reason); + } + // switch to draining state + if (!markDraining(terminationCause)) { + // has previously already gone into draining state + return; + } + // shutdown the idle timeout manager since we no longer bother with idle timeout + // management for a closing connection + connection.idleTimeoutManager.shutdown(); + + if (Log.quic()) { + Log.logQuic("{0} entering draining state, {1}", logTag, + terminationCause.getLogMsg()); + } else if (debug.on()) { + debug.log("entering draining state, " + + terminationCause.getLogMsg()); + } + // RFC-9000, section 10.2.2: + // An endpoint that receives a CONNECTION_CLOSE frame MAY send a single packet containing + // a CONNECTION_CLOSE frame before entering the draining state, using a NO_ERROR code if + // appropriate. An endpoint MUST NOT send further packets. + // if we had previously marked our state as closing, then that implies + // we would have already sent a connection close frame. we won't send + // another when draining in such a case. + if (markClosing(terminationCause)) { + try { + if (Log.quic()) { + Log.logQuic("{0} sending CONNECTION_CLOSE frame before entering draining state", + logTag); + } else if (debug.on()) { + debug.log("sending CONNECTION_CLOSE frame before entering draining state"); + } + final ConnectionCloseFrame outgoingFrame = + new ConnectionCloseFrame(NO_ERROR.code(), incomingFrame.getTypeField(), null); + final KeySpace currentKeySpace = connection.getTLSEngine().getCurrentSendKeySpace(); + pushConnectionCloseFrame(currentKeySpace, outgoingFrame); + } catch (Exception e) { + // just log and ignore, since sending the CONNECTION_CLOSE when entering + // draining state is optional + if (Log.errors()) { + Log.logError(logTag + " Failed to send CONNECTION_CLOSE frame," + + " when entering draining state: {0}", e); + } else if (debug.on()) { + debug.log("failed to send CONNECTION_CLOSE frame, when entering" + + " draining state: " + e); + } + } + } + failHandshakeCFs(); + // remap the connection to a draining connection + final QuicEndpoint endpoint = this.connection.endpoint(); + assert endpoint != null : "QUIC endpoint is null"; + endpoint.draining(connection); + discardConnectionState(); + connection.streams.terminate(terminationCause); + if (Log.quic()) { + Log.logQuic("{0} connection has now transitioned to draining state", logTag); + } else if (debug.on()) { + debug.log("connection has now transitioned to draining state"); + } + } + + private void discardConnectionState() { + // close packet spaces + connection.packetNumberSpaces().close(); + // close the incoming packets buffered queue + connection.closeIncoming(); + } + + private void failHandshakeCFs() { + final TerminationCause tc = this.terminationCause.get(); + assert tc != null : "termination cause is null"; + failHandshakeCFs(tc.getCloseCause()); + } + + private void failHandshakeCFs(final Throwable cause) { + final HandshakeFlow handshakeFlow = connection.handshakeFlow(); + handshakeFlow.failHandshakeCFs(cause); + } + + private boolean markClosing(final TerminationCause terminationCause) { + return mark(CLOSING, terminationCause); + } + + private boolean markDraining(final TerminationCause terminationCause) { + return mark(DRAINING, terminationCause); + } + + private boolean markClosed(final TerminationCause terminationCause) { + return mark(CLOSED, terminationCause); + } + + private boolean mark(final int mask, final TerminationCause cause) { + assert cause != null : "termination cause is null"; + final boolean causeSet = this.terminationCause.compareAndSet(null, cause); + // first mark the state appropriately, before completing the futureTerminationCause + // completable future, so that any dependent actions on the completable future + // will see the right state + final boolean marked = this.connection.stateHandle().mark(mask); + if (causeSet) { + this.futureTC.completeAsync(() -> cause, connection.quicInstance().executor()); + } + return marked; + } + + /** + * CONNECTION_CLOSE frame is not congestion controlled (RFC-9002 section 3 + * and RFC-9000 section 12.4, table 3), nor is it queued or scheduled for sending. + * This method constructs a {@link QuicPacket} containing the {@code frame} and immediately + * {@link QuicConnectionImpl#pushDatagram(ProtectionRecord) pushes the datagram} through + * the connection. + * + * @param keySpace the KeySpace to use for sending the packet + * @param frame the CONNECTION_CLOSE frame + * @throws QuicKeyUnavailableException if the keys for the KeySpace aren't available + * @throws QuicTransportException for any QUIC transport exception when sending the packet + */ + private void pushConnectionCloseFrame(final KeySpace keySpace, + final ConnectionCloseFrame frame) + throws QuicKeyUnavailableException, QuicTransportException { + // ConnectionClose frame is allowed in Initial, Handshake, 0-RTT, 1-RTT spaces. + // for Initial and Handshake space, the frame is expected to be of type 0x1c. + // see RFC-9000, section 12.4, Table 3 for additional details + final ConnectionCloseFrame toSend = switch (keySpace) { + case ONE_RTT, ZERO_RTT -> frame; + case INITIAL, HANDSHAKE -> { + // RFC 9000 - section 10.2.3: + // A CONNECTION_CLOSE of type 0x1d MUST be replaced by a CONNECTION_CLOSE + // of type 0x1c when sending the frame in Initial or Handshake packets. + // Otherwise, information about the application state might be revealed. + // Endpoints MUST clear the value of the Reason Phrase field and SHOULD + // use the APPLICATION_ERROR code when converting to a CONNECTION_CLOSE + // of type 0x1c. + yield frame.clearApplicationState(); + } + default -> { + throw new IllegalStateException("cannot send a connection close frame" + + " in keyspace: " + keySpace); + } + }; + final QuicPacket packet = connection.newQuicPacket(keySpace, List.of(toSend)); + final ProtectionRecord protectionRecord = ProtectionRecord.single(packet, + connection::allocateDatagramForEncryption); + // while sending the packet containing the CONNECTION_CLOSE frame, the pushDatagram will + // remap (or remove) the QuicConnectionImpl in QuicEndpoint. + connection.pushDatagram(protectionRecord); + } + + /** + * Returns a {@link ByteBuffer} which contains an encrypted QUIC packet containing + * a {@linkplain ConnectionCloseFrame CONNECTION_CLOSE frame}. The CONNECTION_CLOSE + * frame will have a frame type of {@code 0x1c} and error code of {@code NO_ERROR}. + *

+ * This method should only be invoked when the {@link QuicEndpoint} is being closed + * and the endpoint wants to send out a {@code CONNECTION_CLOSE} frame on a best-effort + * basis (in a fire and forget manner). + * + * @return the datagram containing the QUIC packet with a CONNECTION_CLOSE frame + * @throws QuicKeyUnavailableException + * @throws QuicTransportException + */ + ByteBuffer makeConnectionCloseDatagram() + throws QuicKeyUnavailableException, QuicTransportException { + // in theory we don't need this assert, but given the knowledge that this method + // should only be invoked by a closing QuicEndpoint, we have this assert here to + // prevent misuse of this makeConnectionCloseDatagram() method + assert connection.endpoint().isClosed() : "QUIC endpoint isn't closed"; + final ConnectionCloseFrame connCloseFrame = new ConnectionCloseFrame(NO_ERROR.code(), + QuicFrame.CONNECTION_CLOSE, null); + final KeySpace keySpace = connection.getTLSEngine().getCurrentSendKeySpace(); + // we don't want the connection's ByteBuffer pooling infrastructure + // (through the QuicConnectionImpl::allocateDatagramForEncryption) for + // this packet, so we use a simple custom allocator. + final Function allocator = (pkt) -> ByteBuffer.allocate(pkt.size()); + final QuicPacket packet = connection.newQuicPacket(keySpace, List.of(connCloseFrame)); + final ProtectionRecord encrypted = ProtectionRecord.single(packet, allocator) + .encrypt(connection.codingContext()); + final ByteBuffer datagram = encrypted.datagram(); + final int firstPacketOffset = encrypted.firstPacketOffset(); + // flip the datagram + datagram.limit(datagram.position()); + datagram.position(firstPacketOffset); + return datagram; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/IdleTimeoutManager.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/IdleTimeoutManager.java new file mode 100644 index 00000000000..a7469f18ed8 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/IdleTimeoutManager.java @@ -0,0 +1,528 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.TimeLine; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.quic.QuicTLSEngine; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static jdk.internal.net.http.quic.TerminationCause.forSilentTermination; + +/** + * Keeps track of activity on a {@code QuicConnectionImpl} and manages + * the idle timeout of the QUIC connection + */ +final class IdleTimeoutManager { + + private static final long NO_IDLE_TIMEOUT = 0; + + private final QuicConnectionImpl connection; + private final Logger debug; + private final AtomicBoolean shutdown = new AtomicBoolean(); + private final AtomicLong idleTimeoutDurationMs = new AtomicLong(); + private final ReentrantLock stateLock = new ReentrantLock(); + // must be accessed only when holding stateLock + private IdleTimeoutEvent idleTimeoutEvent; + // must be accessed only when holding stateLock + private StreamDataBlockedEvent streamDataBlockedEvent; + // the time at which the last outgoing packet was sent or an + // incoming packet processed on the connection + private volatile long lastPacketActivityAt; + + private final ReentrantLock idleTerminationLock = new ReentrantLock(); + // true if it has been decided to terminate the connection due to being idle, + // false otherwise. should be accessed only when holding the idleTerminationLock + private boolean chosenForIdleTermination; + // the time at which the connection was last reserved for use. + // should be accessed only when holding the idleTerminationLock + private long lastUsageReservationAt; + + IdleTimeoutManager(final QuicConnectionImpl connection) { + this.connection = Objects.requireNonNull(connection, "connection"); + this.debug = connection.debug; + } + + /** + * Starts the idle timeout management for the connection. This should be called + * after the handshake is complete for the connection. + * + * @throw IllegalStateException if handshake hasn't yet completed or if the handshake + * has failed for the connection + */ + void start() { + final CompletableFuture handshakeCF = + this.connection.handshakeFlow().handshakeCF(); + // start idle management only for successfully completed handshake + if (!handshakeCF.isDone()) { + throw new IllegalStateException("handshake isn't yet complete," + + " cannot start idle connection management"); + } + if (handshakeCF.isCompletedExceptionally()) { + throw new IllegalStateException("cannot start idle connection management for a failed" + + " connection"); + } + startTimers(); + } + + /** + * Starts the idle timeout timer of the QUIC connection, if not already started. + */ + private void startTimers() { + if (shutdown.get()) { + return; + } + this.stateLock.lock(); + try { + if (shutdown.get()) { + return; + } + startIdleTerminationTimer(); + startStreamDataBlockedTimer(); + } finally { + this.stateLock.unlock(); + } + } + + private void startIdleTerminationTimer() { + assert stateLock.isHeldByCurrentThread() : "not holding state lock"; + final Optional idleTimeoutMillis = getIdleTimeout(); + if (idleTimeoutMillis.isEmpty()) { + if (debug.on()) { + debug.log("idle connection management disabled for connection"); + } else { + Log.logQuic("{0} idle connection management disabled for connection", + connection.logTag()); + } + return; + } + final QuicTimerQueue timerQueue = connection.endpoint().timer(); + final Deadline deadline = timeLine().instant().plusMillis(idleTimeoutMillis.get()); + // we don't expect idle timeout management to be started more than once + assert this.idleTimeoutEvent == null : "idle timeout management" + + " already started for connection"; + // create the idle timeout event and register with the QuicTimerQueue. + this.idleTimeoutEvent = new IdleTimeoutEvent(deadline); + timerQueue.offer(this.idleTimeoutEvent); + if (debug.on()) { + debug.log("started QUIC idle timeout management for connection," + + " idle timeout event: " + this.idleTimeoutEvent + + " deadline: " + deadline); + } else { + Log.logQuic("{0} started QUIC idle timeout management for connection," + + " idle timeout event: {1} deadline: {2}", + connection.logTag(), this.idleTimeoutEvent, deadline); + } + } + + private void stopIdleTerminationTimer() { + assert stateLock.isHeldByCurrentThread() : "not holding state lock"; + if (this.idleTimeoutEvent == null) { + return; + } + final QuicEndpoint endpoint = this.connection.endpoint(); + assert endpoint != null : "QUIC endpoint is null"; + // disable the event (refreshDeadline() of IdleTimeoutEvent will return Deadline.MAX) + final Deadline nextDeadline = this.idleTimeoutEvent.nextDeadline; + if (!nextDeadline.equals(Deadline.MAX)) { + this.idleTimeoutEvent.nextDeadline = Deadline.MAX; + endpoint.timer().reschedule(this.idleTimeoutEvent, Deadline.MIN); + } + this.idleTimeoutEvent = null; + } + + private void startStreamDataBlockedTimer() { + assert stateLock.isHeldByCurrentThread() : "not holding state lock"; + // 75% of idle timeout or if idle timeout is not configured, then 30 seconds + final long timeoutMillis = getIdleTimeout() + .map((v) -> (long) (0.75 * v)) + .orElse(30000L); + final QuicTimerQueue timerQueue = connection.endpoint().timer(); + final Deadline deadline = timeLine().instant().plusMillis(timeoutMillis); + // we don't expect the timer to be started more than once + assert this.streamDataBlockedEvent == null : "STREAM_DATA_BLOCKED timer already started"; + // create the timeout event and register with the QuicTimerQueue. + this.streamDataBlockedEvent = new StreamDataBlockedEvent(deadline, timeoutMillis); + timerQueue.offer(this.streamDataBlockedEvent); + if (debug.on()) { + debug.log("started STREAM_DATA_BLOCKED timer for connection," + + " event: " + this.streamDataBlockedEvent + + " deadline: " + deadline); + } else { + Log.logQuic("{0} started STREAM_DATA_BLOCKED timer for connection," + + " event: {1} deadline: {2}", + connection.logTag(), this.streamDataBlockedEvent, deadline); + } + } + + private void stopStreamDataBlockedTimer() { + assert stateLock.isHeldByCurrentThread() : "not holding state lock"; + if (this.streamDataBlockedEvent == null) { + return; + } + final QuicEndpoint endpoint = this.connection.endpoint(); + assert endpoint != null : "QUIC endpoint is null"; + // disable the event (refreshDeadline() of StreamDataBlockedEvent will return Deadline.MAX) + final Deadline nextDeadline = this.streamDataBlockedEvent.nextDeadline; + if (!nextDeadline.equals(Deadline.MAX)) { + this.streamDataBlockedEvent.nextDeadline = Deadline.MAX; + endpoint.timer().reschedule(this.streamDataBlockedEvent, Deadline.MIN); + } + this.streamDataBlockedEvent = null; + } + + /** + * Attempts to notify the idle connection management that this connection should + * be considered "in use". This way the idle connection management doesn't close + * this connection during the time the connection is handed out from the pool and any + * new stream created on that connection. + * + * @return true if the connection has been successfully reserved and is {@link #isOpen()}. false + * otherwise; in which case the connection must not be handed out from the pool. + */ + boolean tryReserveForUse() { + this.idleTerminationLock.lock(); + try { + if (chosenForIdleTermination) { + // idle termination has been decided for this connection, don't use it + return false; + } + // if the connection is nearing idle timeout due to lack of traffic then + // don't use it + final long lastPktActivity = lastPacketActivityAt; + final long currentNanos = System.nanoTime(); + final long inactivityMs = MILLISECONDS.convert((currentNanos - lastPktActivity), + NANOSECONDS); + final boolean nearingIdleTimeout = getIdleTimeout() + .map((timeoutMillis) -> inactivityMs >= (0.8 * timeoutMillis)) // 80% of idle timeout + .orElse(false); + if (nearingIdleTimeout) { + return false; + } + // express interest in using the connection + this.lastUsageReservationAt = System.nanoTime(); + return true; + } finally { + this.idleTerminationLock.unlock(); + } + } + + + /** + * Returns the idle timeout duration, in milliseconds, negotiated for the connection represented + * by this {@code IdleTimeoutManager}. The negotiated idle timeout of a connection + * is the minimum of the idle connection timeout that is advertised by the + * endpoint represented by this {@code IdleTimeoutManager} and the idle + * connection timeout advertised by the peer. If neither endpoints have advertised + * any idle connection timeout then this method returns an + * {@linkplain Optional#empty() empty} value. + * + * @return the idle timeout in milliseconds or {@linkplain Optional#empty() empty} + */ + Optional getIdleTimeout() { + final long val = this.idleTimeoutDurationMs.get(); + return val == NO_IDLE_TIMEOUT ? Optional.empty() : Optional.of(val); + } + + void keepAlive() { + lastPacketActivityAt = System.nanoTime(); // TODO: timeline().instant()? + } + + void shutdown() { + if (!shutdown.compareAndSet(false, true)) { + // already shutdown + return; + } + this.stateLock.lock(); + try { + // unregister the timeout events from the QuicTimerQueue + stopIdleTerminationTimer(); + stopStreamDataBlockedTimer(); + } finally { + this.stateLock.unlock(); + } + if (debug.on()) { + debug.log("idle timeout manager shutdown"); + } + } + + void localIdleTimeout(final long timeoutMillis) { + checkUpdateIdleTimeout(timeoutMillis); + } + + void peerIdleTimeout(final long timeoutMillis) { + checkUpdateIdleTimeout(timeoutMillis); + } + + private void checkUpdateIdleTimeout(final long newIdleTimeoutMillis) { + if (newIdleTimeoutMillis <= 0) { + // idle timeout should be non-zero value, we disregard other values + return; + } + long current; + boolean updated = false; + // update the idle timeout if the new timeout is lesser + // than the previously set value + while ((current = this.idleTimeoutDurationMs.get()) == NO_IDLE_TIMEOUT + || current > newIdleTimeoutMillis) { + updated = this.idleTimeoutDurationMs.compareAndSet(current, newIdleTimeoutMillis); + if (updated) { + break; + } + } + if (!updated) { + return; + } + if (debug.on()) { + debug.log("idle connection timeout updated to " + + newIdleTimeoutMillis + " milli seconds"); + } else { + Log.logQuic("{0} idle connection timeout updated to {1} milli seconds", + connection.logTag(), newIdleTimeoutMillis); + } + } + + private TimeLine timeLine() { + return this.connection.endpoint().timeSource(); + } + + // called when the connection has been idle past its idle timeout duration + private void idleTimedOut() { + if (shutdown.get()) { + return; // nothing to do - the idle timeout manager has been shutdown + } + final Optional timeoutVal = getIdleTimeout(); + assert timeoutVal.isPresent() : "unexpectedly idle timing" + + " out connection, when no idle timeout is configured"; + final long timeoutMillis = timeoutVal.get(); + if (Log.quic() || debug.on()) { + // log idle timeout, with packet space statistics + final String msg = "silently terminating connection due to idle timeout (" + + timeoutMillis + " milli seconds)"; + StringBuilder sb = new StringBuilder(); + for (PacketNumberSpace sp : PacketNumberSpace.values()) { + if (sp == PacketNumberSpace.NONE) continue; + if (connection.packetNumberSpaces().get(sp) instanceof PacketSpaceManager m) { + sb.append("\n PacketSpace: ").append(sp).append('\n'); + m.debugState(" ", sb); + } + } + if (Log.quic()) { + Log.logQuic("{0} {1}: {2}", connection.logTag(), msg, sb.toString()); + } else if (debug.on()) { + debug.log("%s: %s", msg, sb); + } + } + // silently close the connection and discard all its state + final TerminationCause cause = forSilentTermination("connection idle timed out (" + + timeoutMillis + " milli seconds)"); + connection.terminator.terminate(cause); + } + + private long computeInactivityMillis() { + final long currentNanos = System.nanoTime(); + final long lastActiveNanos = Math.max(lastPacketActivityAt, lastUsageReservationAt); + return MILLISECONDS.convert((currentNanos - lastActiveNanos), NANOSECONDS); + } + + final class IdleTimeoutEvent implements QuicTimedEvent { + private final long eventId; + private volatile Deadline deadline; + private volatile Deadline nextDeadline; + + private IdleTimeoutEvent(final Deadline deadline) { + assert deadline != null : "timeout deadline is null"; + this.deadline = this.nextDeadline = deadline; + this.eventId = QuicTimerQueue.newEventId(); + } + + @Override + public Deadline deadline() { + return this.deadline; + } + + @Override + public Deadline refreshDeadline() { + if (shutdown.get()) { + return this.deadline = this.nextDeadline = Deadline.MAX; + } + return this.deadline = this.nextDeadline; + } + + @Override + public Deadline handle() { + if (shutdown.get()) { + // timeout manager is shutdown, nothing more to do + return this.nextDeadline = Deadline.MAX; + } + final Optional idleTimeout = getIdleTimeout(); + if (idleTimeout.isEmpty()) { + // nothing to do, don't reschedule + return Deadline.MAX; + } + final long idleTimeoutMillis = idleTimeout.get(); + // check whether the connection has indeed been idle for the idle timeout duration + idleTerminationLock.lock(); + try { + Deadline postponed = maybePostponeDeadline(idleTimeoutMillis); + if (postponed != null) { + // not idle long enough, reschedule + this.nextDeadline = postponed; + return postponed; + } + chosenForIdleTermination = true; + } finally { + idleTerminationLock.unlock(); + } + // the connection has been idle for the idle timeout duration, go + // ahead and terminate it. + terminateNow(); + assert shutdown.get() : "idle timeout manager was expected to be shutdown"; + this.nextDeadline = Deadline.MAX; + return Deadline.MAX; + } + + private Deadline maybePostponeDeadline(final long expectedIdleDurationMs) { + assert idleTerminationLock.isHeldByCurrentThread() : "not holding idle termination lock"; + final long inactivityMs = computeInactivityMillis(); + if (inactivityMs >= expectedIdleDurationMs) { + // the connection has been idle long enough, don't postpone the timeout. + return null; + } + // not idle long enough, compute the deadline when it's expected to reach + // idle timeout + final long remainingMs = expectedIdleDurationMs - inactivityMs; + final Deadline next = timeLine().instant().plusMillis(remainingMs); + if (debug.on()) { + debug.log("postponing timeout event: " + this + " to fire" + + " in " + remainingMs + " milli seconds, deadline: " + next); + } + return next; + } + + private void terminateNow() { + try { + idleTimedOut(); + } finally { + shutdown(); + } + } + + @Override + public long eventId() { + return this.eventId; + } + + @Override + public String toString() { + return "QuicIdleTimeoutEvent-" + this.eventId; + } + } + + final class StreamDataBlockedEvent implements QuicTimedEvent { + private final long eventId; + private final long timeoutMillis; + private volatile Deadline deadline; + private volatile Deadline nextDeadline; + + private StreamDataBlockedEvent(final Deadline deadline, final long timeoutMillis) { + assert deadline != null : "timeout deadline is null"; + this.deadline = this.nextDeadline = deadline; + this.timeoutMillis = timeoutMillis; + this.eventId = QuicTimerQueue.newEventId(); + } + + @Override + public Deadline deadline() { + return this.deadline; + } + + @Override + public Deadline refreshDeadline() { + if (shutdown.get()) { + return this.deadline = this.nextDeadline = Deadline.MAX; + } + return this.deadline = this.nextDeadline; + } + + @Override + public Deadline handle() { + if (shutdown.get()) { + // timeout manager is shutdown, nothing more to do + return this.nextDeadline = Deadline.MAX; + } + // check whether the connection has indeed been idle for the idle timeout duration + idleTerminationLock.lock(); + try { + if (chosenForIdleTermination) { + // connection is already chosen for termination, no need to send + // a STREAM_DATA_BLOCKED + this.nextDeadline = Deadline.MAX; + return this.nextDeadline; + } + final long inactivityMs = computeInactivityMillis(); + if (inactivityMs >= timeoutMillis && connection.streams.hasBlockedStreams()) { + // has been idle long enough, but there are streams that are blocked due to + // flow control limits and that could have lead to the idleness. + // trigger sending a STREAM_DATA_BLOCKED frame for the streams + // to try and have their limits increased by the peer. + connection.streams.enqueueStreamDataBlocked(); + if (debug.on()) { + debug.log("enqueued a STREAM_DATA_BLOCKED frame since connection" + + " has been idle due to blocked stream(s)"); + } else { + Log.logQuic("{0} enqueued a STREAM_DATA_BLOCKED frame" + + " since connection has been idle due to" + + " blocked stream(s)", connection.logTag()); + } + } + this.nextDeadline = timeLine().instant().plusMillis(timeoutMillis); + return this.nextDeadline; + } finally { + idleTerminationLock.unlock(); + } + } + + @Override + public long eventId() { + return this.eventId; + } + + @Override + public String toString() { + return "StreamDataBlockedEvent-" + this.eventId; + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/LocalConnIdManager.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/LocalConnIdManager.java new file mode 100644 index 00000000000..76a1f251e75 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/LocalConnIdManager.java @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.frames.NewConnectionIDFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.RetireConnectionIDFrame; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.NavigableMap; +import java.util.TreeMap; +import java.util.concurrent.locks.ReentrantLock; + +import static jdk.internal.net.quic.QuicTransportErrors.PROTOCOL_VIOLATION; + +/** + * Manages the connection ids advertised by the local endpoint of a connection. + * - Produces outgoing NEW_CONNECTION_ID frames, + * - handles incoming RETIRE_CONNECTION_ID frames, + * - registers produced connection IDs with the QuicEndpoint + * Handshake connection ID is created and registered by QuicConnection. + */ +final class LocalConnIdManager { + private final Logger debug; + private final QuicConnectionImpl connection; + private long nextConnectionIdSequence; + private final ReentrantLock lock = new ReentrantLock(); + private boolean closed; // when true, no more connection IDs are registered + + // the connection ids (there can be more than one) with which the endpoint identifies this connection. + // the key of this Map is a (RFC defined) sequence number for the connection id + private final NavigableMap localConnectionIds = + Collections.synchronizedNavigableMap(new TreeMap<>()); + + LocalConnIdManager(final QuicConnectionImpl connection, final String dbTag, + QuicConnectionId handshakeConnectionId) { + this.debug = Utils.getDebugLogger(() -> dbTag); + this.connection = connection; + this.localConnectionIds.put(nextConnectionIdSequence++, handshakeConnectionId); + } + + private QuicConnectionId newConnectionId() { + return connection.endpoint().idFactory().newConnectionId(); + + } + + private byte[] statelessTokenFor(QuicConnectionId cid) { + return connection.endpoint().idFactory().statelessTokenFor(cid); + } + + void handleRetireConnectionIdFrame(final QuicConnectionId incomingPacketDestConnId, + final QuicPacket.PacketType packetType, + final RetireConnectionIDFrame retireFrame) + throws QuicTransportException { + if (debug.on()) { + debug.log("Received RETIRE_CONNECTION_ID frame: %s", retireFrame); + } + final QuicConnectionId toRetire; + lock.lock(); + try { + final long seqNumber = retireFrame.sequenceNumber(); + if (seqNumber >= nextConnectionIdSequence) { + // RFC-9000, section 19.16: Receipt of a RETIRE_CONNECTION_ID frame containing a + // sequence number greater than any previously sent to the peer MUST be treated + // as a connection error of type PROTOCOL_VIOLATION + throw new QuicTransportException("Invalid sequence number " + seqNumber + + " in RETIRE_CONNECTION_ID frame", + packetType.keySpace().orElse(null), + retireFrame.getTypeField(), PROTOCOL_VIOLATION); + } + toRetire = this.localConnectionIds.get(seqNumber); + if (toRetire == null) { + return; + } + if (toRetire.equals(incomingPacketDestConnId)) { + // RFC-9000, section 19.16: The sequence number specified in a RETIRE_CONNECTION_ID + // frame MUST NOT refer to the Destination Connection ID field of the packet in which + // the frame is contained. The peer MAY treat this as a connection error of type + // PROTOCOL_VIOLATION. + throw new QuicTransportException("Invalid connection id in RETIRE_CONNECTION_ID frame", + packetType.keySpace().orElse(null), + retireFrame.getTypeField(), PROTOCOL_VIOLATION); + } + // forget this id from our local store + this.localConnectionIds.remove(seqNumber); + this.connection.endpoint().removeConnectionId(toRetire, connection); + } finally { + lock.unlock(); + } + if (debug.on()) { + debug.log("retired connection id " + toRetire); + } + } + + public QuicFrame nextFrame(int remaining) { + if (localConnectionIds.size() >= 2) { + return null; + } + int cidlen = connection.endpoint().idFactory().connectionIdLength(); + if (cidlen == 0) { + return null; + } + // frame: + // type - 1 byte + // sequence number - var int + // retire prior to - 1 byte (always zero) + // connection id: + 1 byte + // stateless reset token - 16 bytes + int len = 19 + cidlen + VariableLengthEncoder.getEncodedSize(nextConnectionIdSequence); + if (len > remaining) { + return null; + } + NewConnectionIDFrame newCidFrame; + QuicConnectionId cid = newConnectionId(); + byte[] token = statelessTokenFor(cid); + lock.lock(); + try { + if (closed) return null; + newCidFrame = new NewConnectionIDFrame(nextConnectionIdSequence++, 0, + cid.asReadOnlyBuffer(), ByteBuffer.wrap(token)); + this.localConnectionIds.put(newCidFrame.sequenceNumber(), cid); + this.connection.endpoint().addConnectionId(cid, connection); + if (debug.on()) { + debug.log("Sending NEW_CONNECTION_ID frame"); + } + return newCidFrame; + } finally { + lock.unlock(); + } + } + + public List connectionIds() { + lock.lock(); + try { + // copy to avoid ConcurrentModificationException + return List.copyOf(localConnectionIds.values()); + } finally { + lock.unlock(); + } + } + + public void close() { + lock.lock(); + closed = true; + lock.unlock(); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/OrderedFlow.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/OrderedFlow.java new file mode 100644 index 00000000000..52821a0fac2 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/OrderedFlow.java @@ -0,0 +1,389 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.util.Comparator; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.function.ToIntFunction; +import java.util.function.ToLongFunction; + +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; + +/** + * A class to take care of frames reordering in an ordered flow. + * + * Frames that are {@linkplain #receive(QuicFrame) received} out of order + * will be either buffered or dropped, depending on their {@linkplain + * #OrderedFlow(Comparator, ToLongFunction, ToIntFunction) position} + * with respect to the current ordered flow {@linkplain #offset() offset}. + * The buffered frames are returned by later calls to {@linkplain #poll()} + * when the flow offset matches the frame offset. + * + * Frames that are {@linkplain #receive(QuicFrame) received} in order + * are immediately returned. + * + * This class is not thread-safe and concurrent access needs to be synchronized + * externally. + * @param A frame type that defines an offset and a {@linkplain + * #OrderedFlow(Comparator, ToLongFunction, ToIntFunction) + * length}. The offset of the frame + * indicates its {@linkplain + * #OrderedFlow(Comparator, ToLongFunction, ToIntFunction) + * position} in the ordered flow. + */ +public sealed abstract class OrderedFlow { + + /** + * A subclass of {@link OrderedFlow} used to reorder instances of + * {@link CryptoFrame}. + */ + public static final class CryptoDataFlow extends OrderedFlow { + /** + * Constructs a new instance of {@code CryptoDataFlow} to reorder + * a flow of {@code CryptoFrame} instances. + */ + public CryptoDataFlow() { + super(CryptoFrame::compareOffsets, + CryptoFrame::offset, + CryptoFrame::length); + } + + @Override + protected CryptoFrame slice(CryptoFrame frame, long offset, int length) { + if (length == 0) return null; + return frame.slice(offset, length); + } + } + + /** + * A subclass of {@link OrderedFlow} used to reorder instances of + * {@link StreamFrame}. + */ + public static final class StreamDataFlow extends OrderedFlow { + /** + * Constructs a new instance of {@code StreamDataFlow} to reorder + * a flow of {@code StreamFrame} instances. + */ + public StreamDataFlow() { + super(StreamFrame::compareOffsets, + StreamFrame::offset, + StreamFrame::dataLength); + } + + @Override + protected StreamFrame slice(StreamFrame frame, long offset, int length) { + if (length == 0) return null; + return frame.slice(offset, length); + } + } + + private final ConcurrentSkipListSet queue; + private final ToLongFunction position; + private final ToIntFunction length; + long offset; + long buffered; + + /** + * Constructs a new instance of ordered flow to reorder frames in a given + * flow. + * @param comparator A comparator to order the frames according to their position in + * the ordered flow. Typically, this will compare the + * frame's offset: the frame with the smaller offset will be sorted + * before the frame with the greater offset + * @param position A method reference that returns the position of the frame in the + * flow. For instance, this would be {@link CryptoFrame#offset() + * CryptoFrame::offset} if {@code } is {@code CryptoFrame}, or + * {@link StreamFrame#offset() StreamFrame::offset} if {@code } + * is {@code StreamFrame} + * @param length A method reference that returns the number of bytes in the frame data. + * This is used to compute the expected position of the next + * frame in the flow. For instance, this would be {@link CryptoFrame#length() + * CryptoFrame::length} if {@code } is {@code CryptoFrame}, or + * {@link StreamFrame#dataLength() StreamFrame::dataLength} if {@code } + * is {@code StreamFrame} + */ + public OrderedFlow(Comparator comparator, ToLongFunction position, + ToIntFunction length) { + queue = new ConcurrentSkipListSet<>(comparator); + this.position = position; + this.length = length; + } + + /** + * {@return a slice of the given frame} + * @param frame the frame to slice + * @param offset the new frame offset + * @param length the new frame length + * @throws IndexOutOfBoundsException if the new offset or length + * fall outside of the frame's bounds + */ + protected abstract T slice(T frame, long offset, int length); + + /** + * Receives a new frame. If the frame is below the current + * offset the frame is dropped. If it is above the current offset, + * it is queued. + * If the frame is exactly at the current offset, it is + * returned. + * + * @param frame a frame that was received + * @return the next frame in the flow, or {@code null} if it is not + * available yet. + */ + public T receive(T frame) { + if (frame == null) return null; + + long start = this.position.applyAsLong(frame); + int length = this.length.applyAsInt(frame); + long end = start + length; + assert length >= 0; + assert start >= 0; + long offset = this.offset; + if (end <= offset || length == 0) { + // late arrival or empty frame. Just drop it; No overlap + // if we reach here! + return null; + } else if (start > offset) { + // the frame is after the offset. + // insert or slice it, depending on what we + // have already received. + enqueue(frame, start, length, offset); + return null; + } else { + // case where the frame is either at offset, or is below + // offset but has a length that provides bytes that + // overlap with the current offset. In the later case + // we will return a slice. + int todeliver = (int)(end - offset); + + assert end == offset + todeliver; + // update the offset with the new position + this.offset = end; + // cleanup the queue + dropuntil(end); + if (start == offset) return frame; + return slice(frame, offset, todeliver); + } + } + + private T peekFirst() { + if (queue.isEmpty()) return null; + // why is there no peekFirst? + try { + return queue.first(); + } catch (NoSuchElementException nse) { + return null; + } + } + + private void enqueue(T frame, long pos, int length, long after) { + assert pos == position.applyAsLong(frame); + assert length == this.length.applyAsInt(frame); + assert pos > after; + long offset = this.offset; + assert offset >= after; + long newpos = pos; + int newlen = length; + long limit = Math.addExact(pos, length); + + // look at the closest frame, if any, whose offset is <= to + // the new frame offset. Try to see if the new frame overlaps + // with that frame, and if so, drops the part that overlaps + // in the new frame. + T floor = queue.floor(frame); + if (floor != null) { + long foffset = position.applyAsLong(floor); + long flen = this.length.applyAsInt(floor); + if (limit <= foffset + flen) { + // bytes already all buffered! + // just drop the frame + return; + } + assert foffset <= pos; + // foffset == pos case handled as ceiling below + if (foffset < pos && pos - foffset < flen) { + // reduce the frame if it overlaps with the + // one that sits just before in the queue + newpos = foffset + flen; + newlen = length - (int) (newpos - pos); + } + } + assert limit == newpos + newlen; + + // Look at the frames that have an offset higher or equal to + // the new frame offset, and see if any overlap with the new + // frame. Remove frames that are entirely contained in the new one, + // slice the current frame if the frames overlap. + while (true) { + T ceil = queue.ceiling(frame); + if (ceil != null) { + long coffset = position.applyAsLong(ceil); + assert coffset >= newpos : "overlapping frames in queue"; + if (coffset < limit) { + long clen = this.length.applyAsInt(ceil); + if (clen <= limit - coffset) { + // ceiling frame completely contained in the new frame: + // remove the ceiling frame + queue.remove(ceil); + buffered -= clen; + continue; + } + // safe cast, since newlen <= len + newlen = (int) (coffset - newpos); + } + } + break; + } + assert newlen >= 0; + if (newlen == length) { + assert newpos == pos; + queue.add(frame); + } else if (newlen > 0) { + queue.add(slice(frame, newpos, newlen)); + } + buffered += newlen; + } + + /** + * Removes and return the head of the queue if it is at the + * current offset. Otherwise, returns null. + * @return the head of the queue if it is at the current offset, + * or {@code null} + */ + public T poll() { + return poll(offset); + } + + /** + * {@return the number of buffered frames} + */ + public int size() { + return queue.size(); + } + + /** + * {@return the number of bytes buffered} + */ + public long buffered() { + return buffered; + } + + /** + * {@return true if there are no buffered frames} + */ + public boolean isEmpty() { + return queue.isEmpty(); + } + + /** + * {@return the current offset of this buffer} + */ + public long offset() { + return offset; + } + + /** + * Drops all buffered frames + */ + public void clear() { + queue.clear(); + } + + /** + * Drop all frames in the buffer whose position is strictly + * below offset. + * + * @param offset the offset below which frames should be dropped + * @return the amount of dropped data + */ + private long dropuntil(long offset) { + T head; + long pos; + long dropped = 0; + do { + head = peekFirst(); + if (head == null) break; + pos = position.applyAsLong(head); + if (pos < offset) { + var length = this.length.applyAsInt(head); + var consumed = offset - pos; + if (length <= consumed) { + // drop it + if (head == queue.pollFirst()) { + buffered -= length; + dropped += length; + } else { + throw new AssertionError("Concurrent modification"); + } + } else { + // safe cast: consumed < length if we reach here + int newlen = length - (int)consumed; + var newhead = slice(head, offset, newlen); + if (head == queue.pollFirst()) { + queue.add(newhead); + buffered -= consumed; + dropped += consumed; + } else { + throw new AssertionError("Concurrent modification"); + } + } + } + } while (pos < offset); + return dropped; + } + + /** + * Pretends to {@linkplain #receive(QuicFrame) receive} the head of the queue, + * if it is at the provided offset + * + * @param offset the minimal offset + * + * @return a received frame at the current flow offset, or {@code null} + */ + private T poll(long offset) { + long current = this.offset; + assert offset <= current; + dropuntil(offset); + T head = peekFirst(); + if (head != null) { + long pos = position.applyAsLong(head); + if (pos == offset) { + // the frame we wanted was in the queue! + // well, let's handle it... + if (head == queue.pollFirst()) { + long length = this.length.applyAsInt(head); + buffered -= length; + } else { + throw new AssertionError("Concurrent modification"); + } + return receive(head); + } + } + return null; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketEmitter.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketEmitter.java new file mode 100644 index 00000000000..35b42333a1e --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketEmitter.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.util.concurrent.Executor; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.packets.PacketSpace; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTransportException; + +/** + * This interface is a useful abstraction used to tie + * {@link PacketSpaceManager} and {@link QuicConnectionImpl}. + * The {@link PacketSpaceManager} uses functionalities provided + * by a {@link PacketEmitter} when it deems that a packet needs + * to be retransmitted, or that an acknowledgement is due. + * It also uses the emitter's {@linkplain #timer() timer facility} + * when it needs to register a {@link QuicTimedEvent}. + * + * @apiNote + * All these methods are actually implemented by {@link QuicConnectionImpl} + * but the {@code PacketEmitter} interface makes it possible to write + * unit tests against a {@link PacketSpaceManager} without involving + * any {@code QuicConnection} instance. + * + */ +public interface PacketEmitter { + /** + * {@return the timer queue used by this packet emitter} + */ + QuicTimerQueue timer(); + + /** + * Retransmit the given packet on behalf of the given packet space + * manager. + * @param packetSpaceManager the packet space manager on behalf of + * which the packet is being retransmitted + * @param packet the unacknowledged packet which should be retransmitted + * @param attempts the number of previous retransmission of this packet. + * A value of 0 indicates the first retransmission. + */ + void retransmit(PacketSpace packetSpaceManager, QuicPacket packet, int attempts) + throws QuicKeyUnavailableException, QuicTransportException; + + /** + * Emit a possibly non ACK-eliciting packet containing the given ACK frame. + * @param packetSpaceManager the packet space manager on behalf + * of which the acknowledgement should + * be sent. + * @param ackFrame the ACK frame to be sent. + * @param sendPing whether a PING frame should be sent. + * @return the emitted packet number, or -1L if not applicable or not emitted + */ + long emitAckPacket(PacketSpace packetSpaceManager, AckFrame ackFrame, boolean sendPing) + throws QuicKeyUnavailableException, QuicTransportException; + + /** + * Called when a packet has been acknowledged. + * @param packet the acknowledged packet + */ + void acknowledged(QuicPacket packet); + + /** + * Called when congestion controller allows sending one packet + * @param packetNumberSpace current packet number space + * @return true if a packet was sent, false otherwise + */ + boolean sendData(PacketNumberSpace packetNumberSpace) + throws QuicKeyUnavailableException, QuicTransportException; + + /** + * {@return an executor to use when {@linkplain + * jdk.internal.net.http.common.SequentialScheduler#runOrSchedule(Executor) + * offloading loops to another thread} is required} + */ + Executor executor(); + + /** + * Reschedule the given event on the {@link #timer() timer} + * @param event the event to reschedule + */ + default void reschedule(QuicTimedEvent event) { + timer().reschedule(event); + } + + /** + * Reschedule the given event on the {@link #timer() timer} + * @param event the event to reschedule + */ + default void reschedule(QuicTimedEvent event, Deadline deadline) { + timer().reschedule(event, deadline); + } + + /** + * Abort the connection if needed, for example if the peer is not responding + * or max idle time was reached + */ + void checkAbort(PacketNumberSpace packetNumberSpace); + + /** + * {@return true if this emitter is open for transmitting packets, else returns false} + */ + boolean isOpen(); + + default void ptoBackoffIncreased(PacketSpaceManager space, long backoff) { }; + + default String logTag() { return toString(); } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketSpaceManager.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketSpaceManager.java new file mode 100644 index 00000000000..494af85447e --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PacketSpaceManager.java @@ -0,0 +1,2370 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodHandles.Lookup; +import java.lang.invoke.VarHandle; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.LongStream; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.TimeLine; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.frames.AckFrame.AckFrameBuilder; +import jdk.internal.net.http.quic.frames.ConnectionCloseFrame; +import jdk.internal.net.http.quic.packets.PacketSpace; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicOneRttContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +/** + * A {@code PacketSpaceManager} takes care of acknowledgement and + * retransmission of packets for a given {@link PacketNumberSpace}. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9002 + * RFC 9002: QUIC Loss Detection and Congestion Control + */ + +// See also: RFC 9000, https://www.rfc-editor.org/rfc/rfc9000#name-sending-ack-frames +// Every packet SHOULD be acknowledged at least once, and +// ack-eliciting packets MUST be acknowledged at least once within +// the maximum delay an endpoint communicated using the max_ack_delay +// transport parameter [...]; +// [...] +// In order to assist loss detection at the sender, an endpoint +// SHOULD generate and send an ACK frame without delay when it +// receives an ack-eliciting packet either: +// - when the received packet has a packet number less +// than another ack-eliciting packet that has been received, or +// - when the packet has a packet number larger than the +// highest-numbered ack-eliciting packet that has been received +// and there are missing packets between that packet and this +// packet. [...] +public sealed class PacketSpaceManager implements PacketSpace + permits PacketSpaceManager.OneRttPacketSpaceManager, + PacketSpaceManager.HandshakePacketSpaceManager { + + private final QuicCongestionController congestionController; + private volatile boolean blockedByCC; + // packet threshold for loss detection; RFC 9002 suggests 3 + private static final long kPacketThreshold = 3; + // Multiplier for persistent congestion; RFC 9002 suggests 3 + private static final int kPersistentCongestionThreshold = 3; + + /** + * A record that stores the next AckFrame that should be sent + * within this packet number space. + * + * @param ackFrame the ACK frame to send. + * @param deadline the deadline by which to send this ACK frame. + * @param lastUpdated the time at which the {@link AckFrame}'s + * {@link AckFrame#largestAcknowledged()} was + * last updated. Used for calculating ack delay. + * @param sent the time at which the {@link AckFrame} was sent, + * or {@code null} if it has not been sent yet. + */ + record NextAckFrame(AckFrame ackFrame, + Deadline deadline, + Deadline lastUpdated, + Deadline sent) { + /** + * {@return an identical {@code NextAckFrame} record, with an updated + * {@code deadline}} + * @param deadline the new deadline + * @param sent the point in time at which the ack frame was sent, or null. + */ + public NextAckFrame withDeadline(Deadline deadline, Deadline sent) { + return new NextAckFrame(ackFrame, deadline, lastUpdated, sent); + } + } + + // true if transmit timer should fire now + private volatile boolean transmitNow; + + // These two numbers control whether an PING frame will be + // sent with the next ACK frame, to turn the packet that + // contains the ACK frame into an ACK-eliciting packet. + // These numbers are *not* defined in RFC 9000, but are used + // to implement a strategy for sending occasional PING frames + // in order to prevent ACK frames from growing too big. + // See RFC 9000 section 13.2.4 + // https://www.rfc-editor.org/rfc/rfc9000#name-limiting-ranges-by-tracking + public static final int MAX_ACKRANGE_COUNT_BEFORE_PING = 10; + + protected final Logger debug; + private final Supplier debugStrSupplier; + private final PacketNumberSpace packetNumberSpace; + private final PacketEmitter packetEmitter; + private final ReentrantLock transferLock = new ReentrantLock(); + // The next packet number to use in this space + private final AtomicLong nextPN = new AtomicLong(); + private final TimeLine instantSource; + private final QuicRttEstimator rttEstimator; + // first packet number sent after handshake confirmed + private long handshakeConfirmedPN; + + // A priority queue containing a record for each unacknowledged PingRequest. + // PingRequest are removed from this queue when they are acknowledged, that + // is when any packet whose number is greater than the request packet + // is acknowledged. + // Note: this is used to implement {@link #requestSendPing()} which is + // used to implement out of band ping requests triggered by the + // application. + private final ConcurrentLinkedQueue pendingPingRequests = + new ConcurrentLinkedQueue<>(); + + // A priority queue containing a record for each unacknowledged packet. + // Packets are removed from this queue when they are acknowledged, or when they + // are being retransmitted. In which case, they will be in the pendingRetransmission + // queue + private final ConcurrentLinkedQueue pendingAcknowledgements = + new ConcurrentLinkedQueue<>(); + + // A map containing send times of ack-eliciting packets. + // Packets are removed from this map when they can't contribute to RTT sample, + // i.e. when they are acknowledged, or when a higher-numbered packet is acknowledged. + private final ConcurrentSkipListMap sendTimes = + new ConcurrentSkipListMap<>(); + + // A priority queue containing a record for each unacknowledged packet whose deadline + // is due, and which is currently being retransmitted. + // Packets are removed from this queue when they have been scheduled for retransmission + // with the quic endpoint + private final ConcurrentLinkedQueue pendingRetransmission = + new ConcurrentLinkedQueue<>(); + + // A priority queue containing a record for each unacknowledged packet whose deadline + // is due, and which should be retransmitted. + // Packets are removed from this queue when they have been scheduled for encryption. + private final ConcurrentLinkedQueue triggeredForRetransmission = + new ConcurrentLinkedQueue<>(); + + // lost packets + private final ConcurrentLinkedQueue lostPackets = + new ConcurrentLinkedQueue<>(); + + // A task invoked by the QuicTimerQueue when some packet retransmission are + // due. This task will move packets from the pendingAcknowledgement queue + // into the triggeredForRetransmission queue (and pendingRetransmission queue) + private final PacketTransmissionTask packetTransmissionTask; + // Used to synchronize transmission with handshake restarts + private final ReentrantLock transmitLock = new ReentrantLock(); + private volatile boolean fastRetransmitDone; + private volatile boolean fastRetransmit; + + /** + * A record to store previous numbers with which a packet has been + * retransmitted. If such a packet is acknowledged, we can stop + * retransmission. + * + * @param number A packet number with which the content of this + * packet was previously sent. + * @param largestAcknowledged the largest packet number acknowledged by this + * previous packet, or {@code -1L} if no packet was + * acknowledged by this packet. + * @param previous Further previous packet numbers, or {@code null}. + */ + private record PreviousNumbers(long number, + long largestAcknowledged, PreviousNumbers previous) {} + + /** + * A record used to implement {@link #requestSendPing()}. + * @param sent when the ping frame was sent + * @param packetNumber the packet number of the packet containing the pingframe + * @param response the response, which will be complete as soon as a packet whose number is + * >= to {@code packetNumber} is received. + */ + private record PingRequest(Deadline sent, long packetNumber, CompletableFuture response) {} + + /** + * A record to store a packet that hasn't been acknowledged, and should + * be scheduled for retransmission if not acknowledged when the deadline + * is reached. + * + * @param packet the unacknowledged quic packet + * @param sent the instant when the packet was sent. + * @param packetNumber the packet number of the {@code packet} + * @param largestAcknowledged the largest packet number acknowledged by this + * packet, or {@code -1L} if no packet is acknowledged + * by this packet. + * @param previousNumbers previous packet numbers with which the packet was + * transmitted, if any, {@code null} otherwise. + */ + private record PendingAcknowledgement(QuicPacket packet, Deadline sent, + long packetNumber, long largestAcknowledged, + PreviousNumbers previousNumbers) { + + PendingAcknowledgement(QuicPacket packet, Deadline sent, + long packetNumber, PreviousNumbers previousNumbers) { + this(packet, sent, packetNumber, + AckFrame.largestAcknowledgedInPacket(packet), previousNumbers); + } + + boolean hasPreviousNumber(long packetNumber) { + if (this.packetNumber <= packetNumber) return false; + var pn = previousNumbers; + while (pn != null) { + if (pn.number == packetNumber) { + return true; + } + pn = pn.previous; + } + return false; + } + boolean hasExactNumber(long packetNumber) { + return this.packetNumber == packetNumber; + } + boolean hasNumber(long packetNumber) { + return this.packetNumber == packetNumber || hasPreviousNumber(packetNumber); + } + PreviousNumbers findPreviousAcknowledged(AckFrame frame) { + var pn = previousNumbers; + while (pn != null) { + if (frame.isAcknowledging(pn.number)) return pn; + pn = pn.previous; + } + return null; + } + boolean isAcknowledgedBy(AckFrame frame) { + if (frame.isAcknowledging(packetNumber)) return true; + else return findPreviousAcknowledged(frame) != null; + } + + public int attempts() { + var pn = previousNumbers; + int count = 0; + while (pn != null) { + count++; + pn = pn.previous; + } + return count; + } + + String prettyPrint() { + StringBuilder b = new StringBuilder(); + b.append("pn:").append(packetNumber); + var ppn = previousNumbers; + if (ppn != null) { + var sep = " ["; + while (ppn != null) { + b.append(sep).append(ppn.number); + ppn = ppn.previous; + sep = ", "; + } + b.append("]"); + } + return b.toString(); + } + } + + /** + * A task that sends packets to the peer. + * + * Packets are sent after a delay when: + * - ack delay timer expires + * - PTO timer expires + * They can also be sent without delay when: + * - we are unblocked by the peer + * - new data is available for sending, and we are not blocked + * - need to send ack without delay + */ + final class PacketTransmissionTask implements QuicTimedEvent { + private final SequentialScheduler handleScheduler = + SequentialScheduler.lockingScheduler(this::handleLoop); + private final long id = QuicTimerQueue.newEventId(); + private volatile Deadline nextDeadline; // updated through VarHandle + private PacketTransmissionTask() { + nextDeadline = Deadline.MAX; + } + + @Override + public long eventId() { return id; } + + @Override + public Deadline deadline() { + return nextDeadline; + } + + @Override + public Deadline handle() { + if (closed) { + if (debug.on()) { + debug.log("packet space already closed, PacketTransmissionTask will" + + " no longer be scheduled"); + } + return Deadline.MAX; + } + handleScheduler.runOrSchedule(packetEmitter.executor()); + return Deadline.MAX; + } + + /** + * The handle loop takes care of sending ACKs, packaging stream data + * (if applicable), and retransmitting on PTO. It is never invoked + * directly - but can be triggered by {@link #handle()} or {@link + * #runTransmitter()} + */ + private void handleLoop() { + transmitLock.lock(); + try { + handleLoop0(); + } catch (Throwable t) { + if (Log.errors()) { + Log.logError("{0}: {1} handleLoop failed: {2}", + packetEmitter.logTag(), packetNumberSpace, t); + Log.logError(t); + } else if (debug.on()) { + debug.log("handleLoop failed", t); + } + } finally { + transmitLock.unlock(); + } + } + + private void handleLoop0() throws IOException, QuicTransportException { + // while congestion control allows, or if PTO expired: + // - send lost packet or new packet + // if PTO still expired (== nothing was sent) + // - resend oldest packet, if available + // - otherwise send ping (+ack, if available) + // if ACK still not sent, send ack + if (debug.on()) { + debug.log("PacketTransmissionTask::handle"); + } + packetEmitter.checkAbort(PacketSpaceManager.this.packetNumberSpace); + // Handle is called from within the executor + var nextDeadline = this.nextDeadline; + Deadline now = now(); + do { + transmitNow = false; + var closed = !isOpenForTransmission(); + if (closed) { + if (debug.on()) { + debug.log("PacketTransmissionTask::handle: %s closed", + PacketSpaceManager.this.packetNumberSpace); + } + return; + } + if (debug.on()) debug.log("PacketTransmissionTask::handle"); + // this may update congestion controller + int lost = detectAndRemoveLostPackets(now); + if (lost > 0 && debug.on()) debug.log("handle: found %s lost packets", lost); + // if we're sending on PTO, we need to double backoff afterwards + boolean needBackoff = isPTO(now); + int packetsSent = 0; + boolean cwndAvailable; + while ((cwndAvailable = congestionController.canSendPacket()) || + (needBackoff && packetsSent < 2)) { // if PTO, try to send 2 packets + if (!isOpenForTransmission()) { + break; + } + final boolean retransmitted; + try { + retransmitted = retransmit(); + } catch (QuicKeyUnavailableException qkue) { + if (!isOpenForTransmission()) { + if (debug.on()) { + debug.log("already closed; not re-transmitting any more data"); + } + clearAll(); + return; + } + throw new IOException("failed to retransmit data, reason: " + + qkue.getMessage()); + } + if (retransmitted) { + packetsSent++; + continue; + } + final boolean sentNew; + // nothing was retransmitted - check for new data + try { + sentNew = sendNewData(); + } catch (QuicKeyUnavailableException qkue) { + if (!isOpenForTransmission()) { + if (debug.on()) { + debug.log("already closed; not transmitting any more data"); + } + return; + } + throw new IOException("failed to send new data, reason: " + + qkue.getMessage()); + } + if (!sentNew) { + break; + } else { + if (needBackoff && packetsSent == 0 && Log.quicRetransmit()) { + Log.logQuic("%s OUT: transmitted new packet on PTO".formatted( + packetEmitter.logTag())); + } + } + packetsSent++; + } + blockedByCC = !cwndAvailable; + if (!cwndAvailable && isOpenForTransmission()) { + if (debug.on()) debug.log("handle: blocked by CC"); + // CC might be available already + if (congestionController.canSendPacket()) { + if (debug.on()) debug.log("handle: unblocked immediately"); + transmitNow = true; + } + } + try { + if (isPTO(now) && isOpenForTransmission()) { + if (debug.on()) debug.log("handle: retransmit on PTO"); + // nothing was sent by the above loop - try to resend the oldest packet + retransmitPTO(); + } else if (fastRetransmit) { + assert packetNumberSpace == PacketNumberSpace.INITIAL; + fastRetransmitDone = true; + fastRetransmit = false; + if (debug.on()) debug.log("handle: fast retransmit"); + // try to resend the oldest packet + retransmitPTO(); + } + } catch (QuicKeyUnavailableException qkue) { + if (!isOpenForTransmission()) { + if (debug.on()) { + debug.log("already closed; not re-transmitting any more data"); + } + return; + } + throw new IOException("failed to retransmitPTO data, key space, reason: " + + qkue.getMessage()); + } + boolean stillPTO = isPTO(now); + // if the ack frame is not sent yet, send it now + var ackFrame = getNextAckFrame(!stillPTO); + var pingRequested = PacketSpaceManager.this.pingRequested; + boolean sendPing = pingRequested != null || stillPTO + || shouldSendPing(now, ackFrame); + if (sendPing || ackFrame != null) { + if (debug.on()) debug.log("handle: generate ACK packet or PING ack:%s ping:%s", + ackFrame != null, sendPing); + final long emitted; + try { + emitted = emitAckPacket(ackFrame, sendPing); + } catch (QuicKeyUnavailableException qkue) { + if (!isOpenForTransmission()) { + if (debug.on()) { + debug.log("already closed; not sending ack/ping packet"); + } + return; + } + throw new IOException("failed to send ack/ping data, reason: " + + qkue.getMessage()); + } + if (sendPing && pingRequested != null) { + if (emitted < 0) pingRequested.complete(-1L); + else registerPingRequest(new PingRequest(now, emitted, pingRequested)); + synchronized (PacketSpaceManager.this) { + PacketSpaceManager.this.pingRequested = null; + } + } + } + if (needBackoff) { + long backoff = rttEstimator.increasePtoBackoff(); + if (debug.on()) { + debug.log("handle: %s increase backoff to %s", + PacketSpaceManager.this.packetNumberSpace, + backoff); + } + packetEmitter.ptoBackoffIncreased(PacketSpaceManager.this, backoff); + } + + // if nextDeadline is not Deadline.MAX the task will be + // automatically rescheduled. + if (debug.on()) debug.log("handle: refreshing deadline"); + nextDeadline = computeNextDeadline(); + } while(!nextDeadline.isAfter(now)); + + logNoDeadline(nextDeadline, true); + if (Deadline.MAX.equals(nextDeadline)) return; + // we have a new deadline + packetEmitter.reschedule(this, nextDeadline); + } + + /** + * Create and send a new packet + * @return true if packet was sent, false if there is no more data to send + */ + private boolean sendNewData() throws QuicKeyUnavailableException, QuicTransportException { + if (debug.on()) debug.log("handle: sending data..."); + boolean sent = packetEmitter.sendData(packetNumberSpace); + if (!sent) { + if (debug.on()) debug.log("handle: no more data to send"); + } + return sent; + } + + @Override + public Deadline refreshDeadline() { + Deadline previousDeadline, newDeadline; + do { + previousDeadline = this.nextDeadline; + newDeadline = computeNextDeadline(); + } while (!Handles.DEADLINE.compareAndSet(this, previousDeadline, newDeadline)); + + if (!newDeadline.equals(previousDeadline)) { + if (debug.on()) { + var now = now(); + if (newDeadline.equals(Deadline.MAX)) { + debug.log("Deadline refreshed: no new deadline"); + } else if (newDeadline.equals(Deadline.MIN)) { + debug.log("Deadline refreshed: run immediately"); + } else if (previousDeadline.equals(Deadline.MAX) || previousDeadline.equals(Deadline.MIN)) { + var delay = now.until(newDeadline, ChronoUnit.MILLIS); + if (delay < 0) { + debug.log("Deadline refreshed: new deadline passed by %dms", delay); + } else { + debug.log("Deadline refreshed: new deadline in %dms", delay); + } + } else { + var delay = now.until(newDeadline, ChronoUnit.MILLIS); + if (delay < 0) { + debug.log("Deadline refreshed: new deadline passed by %dms (diff: %dms)", + delay, previousDeadline.until(newDeadline, ChronoUnit.MILLIS)); + } else { + debug.log("Deadline refreshed: new deadline in %dms (diff: %dms)", + instantSource.instant().until(newDeadline, ChronoUnit.MILLIS), + previousDeadline.until(newDeadline, ChronoUnit.MILLIS)); + } + } + } + } else { + debug.log("Deadline not refreshed: no change"); + } + logNoDeadline(newDeadline, false); + return newDeadline; + } + + void logNoDeadline(Deadline newDeadline, boolean onlyNoDeadline) { + if (Log.quicRetransmit()) { + if (Deadline.MAX.equals(newDeadline)) { + if (shouldLogWhenNoDeadline()) { + Log.logQuic("{0}: {1} no deadline, task unscheduled", + packetEmitter.logTag(), packetNumberSpace); + } // else: no changes... + } else if (!onlyNoDeadline && shouldLogWhenNewDeadline()) { + if (Deadline.MIN.equals(newDeadline)) { + Log.logQuic("{0}: {1} Deadline.MIN, task will be rescheduled immediately", + packetEmitter.logTag(), packetNumberSpace); + } else { + try { + Log.logQuic("{0}: {1} new deadline computed, deadline in {2}ms", + packetEmitter.logTag(), packetNumberSpace, + Long.toString(now().until(newDeadline, ChronoUnit.MILLIS))); + } catch (ArithmeticException ae) { + Log.logError("Unexpected exception while logging deadline " + + newDeadline + ": " + ae); + Log.logError(ae); + assert false : "Unexpected ArithmeticException: " + ae; + } + } + } + } + } + + private boolean hadNoDeadline; + private synchronized boolean shouldLogWhenNoDeadline() { + if (!hadNoDeadline) { + hadNoDeadline = true; + return true; + } + return false; + } + + private synchronized boolean shouldLogWhenNewDeadline() { + if (hadNoDeadline) { + hadNoDeadline = false; + return true; + } + return false; + } + + boolean hasNoDeadline() { + return Deadline.MAX.equals(nextDeadline); + } + + // reschedule this task + void reschedule() { + Deadline deadline = computeNextDeadline(); + Deadline nextDeadline = this.nextDeadline; + if (Deadline.MAX.equals(deadline)) { + debug.log("no deadline, don't reschedule"); + } else if (deadline.equals(nextDeadline)) { + debug.log("deadline unchanged, don't reschedule"); + } else { + packetEmitter.reschedule(this, deadline); + debug.log("retransmission task: rescheduled"); + } + } + + @Override + public String toString() { + return "PacketTransmissionTask(" + debugStrSupplier.get() + ")"; + } + } + + Deadline deadline() { + return packetTransmissionTask.deadline(); + } + + Deadline prospectiveDeadline() { + return computeNextDeadline(false); + } + + // remove all pending acknowledgements and retransmissions. + private void clearAll() { + transferLock.lock(); + try { + pendingAcknowledgements.forEach(ack -> congestionController.packetDiscarded(List.of(ack.packet))); + if (debug.on()) { + final StringBuilder sb = new StringBuilder(); + pendingAcknowledgements.forEach((p) -> sb.append(" ").append(p)); + if (!sb.isEmpty()) { + debug.log("forgetting pending acks: " + sb); + } + } + pendingAcknowledgements.clear(); + + if (debug.on()) { + final StringBuilder sb = new StringBuilder(); + pendingRetransmission.forEach((p) -> sb.append(" ").append(p)); + if (!sb.isEmpty()) { + debug.log("forgetting pending retransmissions: " + sb); + } + } + pendingRetransmission.clear(); + + if (debug.on()) { + final StringBuilder sb = new StringBuilder(); + triggeredForRetransmission.forEach((p) -> sb.append(" ").append(p)); + if (!sb.isEmpty()) { + debug.log("forgetting triggered-for-retransmissions: " + sb.toString()); + } + } + triggeredForRetransmission.clear(); + + if (debug.on()) { + final StringBuilder sb = new StringBuilder(); + lostPackets.forEach((p) -> sb.append(" ").append(p)); + if (!sb.isEmpty()) { + debug.log("forgetting lost-packets: " + sb.toString()); + } + } + lostPackets.clear(); + } finally { + transferLock.unlock(); + } + } + + private void retransmitPTO() throws QuicKeyUnavailableException, QuicTransportException { + if (!isOpenForTransmission()) { + if (debug.on()) { + debug.log("already closed; retransmission on PTO dropped", packetNumberSpace); + } + clearAll(); + return; + } + + PendingAcknowledgement pending; + transferLock.lock(); + try { + if ((pending = pendingAcknowledgements.poll()) != null) { + if (debug.on()) debug.log("Retransmit on PTO: looking for candidate"); + // TODO should keep this packet on the list until it's either acked or lost + congestionController.packetDiscarded(List.of(pending.packet)); + pendingRetransmission.add(pending); + } + } finally { + transferLock.unlock(); + } + if (pending != null) { + packetEmitter.retransmit(this, pending.packet(), pending.attempts()); + } + } + + /** + * {@return true if this packet space isn't closed and if the underlying packet emitter + * is open, else returns false} + */ + private boolean isOpenForTransmission() { + return !this.closed && this.packetEmitter.isOpen(); + } + + /** + * A class to keep track of the largest packet that was acknowledged by + * a packet that is being acknowledged. + * This information is used to implement the algorithm described in + * RFC 9000 13.2.4. Limiting Ranges by Tracking ACK Frames + */ + private final class EmittedAckTracker { + volatile long ignoreAllPacketsBefore = -1; + /** + * Record the {@link AckFrame#largestAcknowledged() + * largest acknowledged} packet that was sent in an + * {@link AckFrame} that the peer has acknowledged. + * @param largestAcknowledged the packet number to record + * @return the largest {@code largestAcknowledged} + * packet number that was recorded. + * This is necessarily smaller than (or equal to) the + * {@link #getLargestSentAckedPN()}. + */ + private long record(long largestAcknowledged) { + long witness; + long largestSentAckedPN = PacketSpaceManager.this.largestSentAckedPN; + do { + witness = largestAckedPNReceivedByPeer; + if (witness >= largestAcknowledged) { + largestSentAckedPN = PacketSpaceManager.this.largestSentAckedPN; + assert witness <= largestSentAckedPN || ignoreAllPacketsBefore > largestSentAckedPN + : "largestAckedPNReceivedByPeer: %s, ignoreAllPacketsBefore: %s, largestSentAckedPN: %s" + .formatted(witness, ignoreAllPacketsBefore, largestSentAckedPN); + return witness; + } + } while (!Handles.LARGEST_ACK_ACKED_PN.compareAndSet( + PacketSpaceManager.this, witness, largestAcknowledged)); + assert witness <= largestAcknowledged; + assert largestAcknowledged <= largestSentAckedPN || ignoreAllPacketsBefore > largestSentAckedPN + : "largestAcknowledged: %s, ignoreAllPacketsBefore: %s, largestSentAckedPN: %s" + .formatted(largestSentAckedPN, ignoreAllPacketsBefore, largestSentAckedPN); + return largestAcknowledged; + } + + private boolean ignoreAllPacketsBefore(long packetNumber) { + long ignoreAllPacketsBefore; + do { + ignoreAllPacketsBefore = this.ignoreAllPacketsBefore; + if (packetNumber <= ignoreAllPacketsBefore) return false; + } while (!Handles.IGNORE_ALL_PN_BEFORE.compareAndSet( + this, ignoreAllPacketsBefore, packetNumber)); + return true; + } + + /** + * Tracks the largest packet acknowledged by the packets acknowledged in the + * given AckFrame. This helps to implement the algorithm described in + * RFC 9000, 13.2.4. Limiting Ranges by Tracking ACK Frames. + * @param pending a yet unacknowledged packet that may be acknowledged + * by the given{@link AckFrame}. + * @param frame a received {@code AckFrame} + * @return whether the given pending unacknowledged packet is being + * acknowledged by this ack frame. + */ + public boolean trackAcknowlegment(PendingAcknowledgement pending, AckFrame frame) { + if (frame.isAcknowledging(pending.packetNumber)) { + record(pending.largestAcknowledged); + packetEmitter.acknowledged(pending.packet()); + return true; + } + // There is a potential for a never ending retransmission + // loop here if we don't treat the ack of a previous packet just + // as the ack of the tip of the chain. + // So we call packetEmitter.acknowledged(pending.packet()) here too, + // and return `true` in this case as well. + var previous = pending.findPreviousAcknowledged(frame); + if (previous != null) { + record(previous.largestAcknowledged); + packetEmitter.acknowledged(pending.packet()); + return true; + } + return false; + } + + public long largestAckAcked() { + return largestAckedPNReceivedByPeer; + } + + public void dropPacketNumbersSmallerThan(long newLargestIgnored) { + // this method is called after arbitrarily reducing the ack range + // to this value; This mean we will drop packets whose packet + // number is smaller than the given packet number. + if (ignoreAllPacketsBefore(newLargestIgnored)) { + record(newLargestIgnored); + } + } + } + + private final QuicTLSEngine quicTLSEngine; + private final EmittedAckTracker emittedAckTracker; + private volatile NextAckFrame nextAckFrame; // assigned through VarHandle + // exponent for outgoing packets; defaults to 3 + public static final int ACK_DELAY_EXPONENT = 3; + // max ack delay sent in quic transport parameters, in millis + public static final int ADVERTISED_MAX_ACK_DELAY = 25; + // max timer delay, i.e. how late selector.select returns; 15.6 millis on Windows + public static final int TIMER_DELAY = 16; + // effective max ack delay for outgoing application packets + public static final int MAX_ACK_DELAY = ADVERTISED_MAX_ACK_DELAY - TIMER_DELAY; + + // exponent for incoming packets + private volatile long peerAckDelayExponent; + // max peer ack delay; zero on initial and handshake, + // initialized from transport parameters on application + private volatile long peerMaxAckDelayMillis; // ms + // max ack delay; zero on initial and handshake, MAX_ACK_DELAY on application + private final long maxAckDelay; // ms + volatile boolean closed; + + // The last time an ACK eliciting packet was sent. + // May be null before any such packet is sent... + private volatile Deadline lastAckElicitingTime; + + // not null if sending ping has been requested. + private volatile CompletableFuture pingRequested; + + // The largest packet number successfully processed in this space. + // Needed to decode received packet numbers, see RFC 9000 appendix A.3 + private volatile long largestProcessedPN; // assigned through VarHandle + + // The largest ACK-eliciting packet number received in this space. + // Needed to determine if we should send ACK without delay, see RFC 9000 section 13.2.1 + private volatile long largestAckElicitingReceivedPN; // assigned through VarHandle + + // The largest ACK-eliciting packet number sent in this space. + // Needed to determine if we should arm PTO timer + private volatile long largestAckElicitingSentPN; + + // The largest packet number acknowledged by peer. + // Needed to determine packet number length, see RFC 9000 appendix A.2 + private volatile long largestReceivedAckedPN; // assigned through VarHandle + + // The largest packet number acknowledged in this space + // This is the largest packet number we have acknowledged. + // This should be less or equal to the largestProcessedPN always. + // Not used. + private volatile long largestSentAckedPN; // assigned through VarHandle + + // The largest packet number that this instance has included + // in an AckFrame sent to the peer, and of which the peer has + // acknowledged reception. + // Used to limit ack ranges, see RFC 9000 section 13.2.4 + private volatile long largestAckedPNReceivedByPeer; // assigned through VarHandle + + /** + * Creates a new {@code PacketSpaceManager} for the given + * packet number space. + * @param connection The connection for which this manager + * is created. + * @param packetNumberSpace The packet number space. + */ + public PacketSpaceManager(final QuicConnectionImpl connection, + final PacketNumberSpace packetNumberSpace) { + this(packetNumberSpace, connection.emitter(), TimeSource.source(), + connection.rttEstimator, connection.congestionController, connection.getTLSEngine(), + () -> connection.dbgTag() + "[" + packetNumberSpace.name() + "]"); + } + + /** + * Creates a new {@code PacketSpaceManager} for the given + * packet number space. + * + * @param packetNumberSpace the packet number space. + * @param packetEmitter the packet emitter + * @param congestionController the congestion controller + * @param debugStrSupplier a supplier for a debug tag to use for logging purposes + */ + public PacketSpaceManager(PacketNumberSpace packetNumberSpace, + PacketEmitter packetEmitter, + TimeLine instantSource, + QuicRttEstimator rttEstimator, + QuicCongestionController congestionController, + QuicTLSEngine quicTLSEngine, + Supplier debugStrSupplier) { + largestProcessedPN = largestReceivedAckedPN = largestAckElicitingReceivedPN + = largestAckElicitingSentPN = largestSentAckedPN = largestAckedPNReceivedByPeer = -1L; + this.debugStrSupplier = debugStrSupplier; + this.debug = Utils.getDebugLogger(debugStrSupplier); + this.instantSource = instantSource; + this.rttEstimator = rttEstimator; + this.congestionController = congestionController; + this.packetNumberSpace = packetNumberSpace; + this.packetEmitter = packetEmitter; + this.emittedAckTracker = new EmittedAckTracker(); + this.packetTransmissionTask = new PacketTransmissionTask(); + this.quicTLSEngine = quicTLSEngine; + maxAckDelay = (packetNumberSpace == PacketNumberSpace.APPLICATION) + ? MAX_ACK_DELAY : 0; + } + + /** + * {@return the max delay before emitting a non ACK-eliciting packet to + * acknowledge a received ACK-eliciting packet, in milliseconds} + */ + public long getMaxAckDelay() { + return maxAckDelay; + } + + /** + * {@return the max ACK delay of the peer, in milliseconds} + */ + public long getPeerMaxAckDelayMillis() { + return peerMaxAckDelayMillis; + } + + /** + * Changes the value of the {@linkplain #getPeerMaxAckDelayMillis() + * peer max ACK delay} and ack delay exponent + * + * @param peerDelay the new delay, in milliseconds + * @param ackDelayExponent the new ack delay exponent + */ + @Override + public void updatePeerTransportParameters(long peerDelay, long ackDelayExponent) { + this.peerAckDelayExponent = ackDelayExponent; + this.peerMaxAckDelayMillis = peerDelay; + } + + @Override + public PacketNumberSpace packetNumberSpace() { + return packetNumberSpace; + } + + @Override + public long allocateNextPN() { + return nextPN.getAndIncrement(); + } + + @Override + public long getLargestPeerAckedPN() { + return largestReceivedAckedPN; + } + + @Override + public long getLargestProcessedPN() { + return largestProcessedPN; + } + + @Override + public long getMinPNThreshold() { + return largestAckedPNReceivedByPeer; + } + + @Override + public long getLargestSentAckedPN() { + return largestSentAckedPN; + } + + /** + * This method is called by {@link QuicConnectionImpl} upon reception of + * and successful negotiation of a new version. + * In that case we should stop retransmitting packet that have the + * "wrong" version: they will never be acknowledged. + */ + public void versionChanged() { + // don't retransmit packet with "bad" version + assert packetNumberSpace == PacketNumberSpace.INITIAL; + if (debug.on()) { + debug.log("version changed - clearing pending acks"); + } + clearAll(); + } + + public void retry() { + assert packetNumberSpace == PacketNumberSpace.INITIAL; + if (debug.on()) { + debug.log("retry received - clearing pending acks"); + } + clearAll(); + } + + @Override + public ReentrantLock getTransmitLock() { + return transmitLock; + } + + // adds the PingRequest to the pendingPingRequests queue so + // that it can be completed when the packet is ACK'ed. + private void registerPingRequest(PingRequest pingRequest) { + if (closed) { + pingRequest.response().completeExceptionally(new IOException("closed")); + return; + } + pendingPingRequests.add(pingRequest); + // could be acknowledged already! + processPingResponses(largestReceivedAckedPN); + } + + @Override + public void close() { + if (closed) { + return; + } + synchronized (this) { + if (closed) return; + closed = true; + } + if (Log.quicControl() || Log.quicRetransmit()) { + Log.logQuic("{0} closing packet space {1}", + packetEmitter.logTag(), packetNumberSpace); + } + if (debug.on()) { + debug.log("closing packet space"); + } + // stop the internal scheduler + packetTransmissionTask.handleScheduler.stop(); + // make sure the task gets eventually removed from the timer + packetEmitter.reschedule(packetTransmissionTask); + // clear pending acks, retransmissions + transferLock.lock(); + try { + clearAll(); + // discard the (TLS) keys + if (debug.on()) { + debug.log("discarding TLS keys"); + } + this.quicTLSEngine.discardKeys(tlsEncryptionLevel()); + } finally { + transferLock.unlock(); + } + rttEstimator.resetPtoBackoff(); + // complete any ping request that hasn't been completed + IOException io = null; + try { + for (var pr : pendingPingRequests) { + if (io == null) { + io = new IOException("Not sending ping because " + + this.packetNumberSpace + " packet space is being closed"); + } + // TODO: is it necessary for this to be an exceptional completion? + pr.response().completeExceptionally(io); + } + } finally { + pendingPingRequests.clear(); + } + } + + @Override + public boolean isClosed() { + return closed; + } + + @Override + public void runTransmitter() { + transmitNow = true; + // run the handle loop + packetTransmissionTask.handle(); + } + + @Override + public void packetReceived(PacketType packet, long packetNumber, boolean isAckEliciting) { + assert PacketNumberSpace.of(packet) == packetNumberSpace; + assert packetNumber > largestAckedPNReceivedByPeer; + + if (closed) { + if (debug.on()) { + debug.log("%s closed, ignoring %s(pn: %s)", packetNumberSpace, packet, packetNumber); + } + return; + } + + if (debug.on()) { + debug.log("packetReceived %s(pn:%d, needsAck:%s)", + packet, packetNumber, isAckEliciting); + } + + // whether the packet is ack eliciting or not, we need to add its packet + // number to the ack frame. + packetProcessed(packetNumber); + addToAckFrame(packetNumber, isAckEliciting); + } + + // used in tests + public T triggeredForRetransmission(Function walker) { + return walker.apply(triggeredForRetransmission.stream() + .mapToLong(PendingAcknowledgement::packetNumber)); + } + + public T pendingRetransmission(Function walker) { + return walker.apply(pendingRetransmission.stream() + .mapToLong(PendingAcknowledgement::packetNumber)); + } + + // used in tests + public T pendingAcknowledgements(Function walker) { + return walker.apply(pendingAcknowledgements.stream() + .mapToLong(PendingAcknowledgement::packetNumber)); + } + + // used in tests + public AtomicLong getNextPN() { return nextPN; } + + // Called by the retransmitLoop scheduler. + // Retransmit one packet for which retransmission has been triggered by + // the PacketTransmissionTask. + // return true if something was retransmitted, or false if there was nothing to retransmit + private boolean retransmit() throws QuicKeyUnavailableException, QuicTransportException { + PendingAcknowledgement pending; + final var closed = !this.isOpenForTransmission(); + if (closed) { + if (debug.on()) { + debug.log("already closed; retransmission dropped"); + } + clearAll(); + return false; + } + transferLock.lock(); + try { + pending = triggeredForRetransmission.poll(); + } finally { + transferLock.unlock(); + } + + if (pending != null) { + // allocate new packet number + // create new packet + // encrypt packet + // send packet + if (debug.on()) debug.log("handle: retransmitting..."); + packetEmitter.retransmit(this, pending.packet(), pending.attempts()); + return true; + } + return false; + } + + /** + * Called by the {@link PacketTransmissionTask} to + * generate a non ACK eliciting packet containing only the given + * ACK frame. + * + *

If a received packet is ACK-eliciting, then it will be either + * directly acknowledged by {@link QuicConnectionImpl} - which will + * call {@link #getNextAckFrame(boolean)} to embed the {@link AckFrame} + * in a packet, or by a non-eliciting ACK packet which will be + * triggered {@link #getMaxAckDelay() maxAckDelay} after the reception + * of the ACK-eliciting packet (this method, triggered by the {@link + * PacketTransmissionTask}). + * + *

This method doesn't reset the {@linkplain #getNextAckFrame(boolean) + * next ack frame} to be sent, but reset its delay so that only + * one non ACK-eliciting packet is emitted to acknowledge a given + * packet. + * + * @param ackFrame The ACK frame to send. + * @return the packet number of the emitted packet + */ + private long emitAckPacket(AckFrame ackFrame, boolean sendPing) + throws QuicKeyUnavailableException, QuicTransportException { + final boolean closed = !this.isOpenForTransmission(); + if (closed) { + if (debug.on()) { + debug.log("Packet space closed, ack/ping won't be sent" + + (ackFrame != null ? ": " + ackFrame : "")); + } + return -1L; + } + try { + return packetEmitter.emitAckPacket(this, ackFrame, sendPing); + } catch (QuicKeyUnavailableException | QuicTransportException e) { + if (!this.isOpenForTransmission()) { + // possible race condition where the packet space was closed (and keys discarded) + // while there was an attempt to send an ACK/PING frame. + // Ignore such cases, since it's OK to not send those frames when the packet space + // is already closed + if (debug.on()) { + debug.log("ack/ping wasn't sent since packet space was closed" + + (ackFrame != null ? ": " + ackFrame : "")); + } + return -1L; + } + throw e; + } + } + + boolean isClosing(QuicPacket packet) { + var frames = packet.frames(); + if (frames == null || frames.isEmpty()) return false; + return frames.stream() + .anyMatch(ConnectionCloseFrame.class::isInstance); + } + + private synchronized void lastAckElicitingSent(long packetNumber) { + if (largestAckElicitingSentPN < packetNumber) { + largestAckElicitingSentPN = packetNumber; + } + } + + @Override + public void packetSent(QuicPacket packet, long previousPacketNumber, long packetNumber) { + if (packetNumber < 0) { + throw new IllegalArgumentException("Invalid packet number: " + packetNumber); + } + largestAckSent(AckFrame.largestAcknowledgedInPacket(packet)); + if (previousPacketNumber >= 0) { + if (debug.on()) { + debug.log("retransmitted packet %s(%d) as %d", + packet.packetType(), previousPacketNumber, packetNumber); + } + + boolean found = false; + transferLock.lock(); + try { + // check for close and addAcknowledgement in the same lock + // to avoid races with close / clearAll + final var closed = !this.isOpenForTransmission(); + if (closed) { + if (debug.on()) { + debug.log("%s already closed: ignoring packet pn:%s", + packetNumberSpace, packet.packetNumber()); + } + return; + } + // TODO: should use a tail set here to skip all pending acks + // whose packet number is < previousPacketNumber? + var iterator = pendingRetransmission.iterator(); + PendingAcknowledgement replacement; + while (iterator.hasNext()) { + PendingAcknowledgement pending = iterator.next(); + if (pending.hasPreviousNumber(previousPacketNumber)) { + // no need to retransmit twice, but can this happen? + iterator.remove(); + } else if (!found && pending.hasExactNumber(previousPacketNumber)) { + PreviousNumbers previous = new PreviousNumbers( + previousPacketNumber, + pending.largestAcknowledged, pending.previousNumbers); + replacement = + new PendingAcknowledgement(packet, now(), packetNumber, previous); + if (debug.on()) { + debug.log("Packet %s(pn:%s) previous %s(pn:%s) is pending acknowledgement", + packet.packetType(), packetNumber, packet.packetType(), previousPacketNumber); + } + var rep = replacement; + if (lostPackets.removeIf(p -> rep.hasPreviousNumber(p.packetNumber))) { + lostPackets.add(rep); + } + addAcknowledgement(replacement); + iterator.remove(); + found = true; + } + } + } finally { + transferLock.unlock(); + } + if (found && packetTransmissionTask.hasNoDeadline()) { + packetTransmissionTask.reschedule(); + } + if (!found) { + if (debug.on()) { + debug.log("packetRetransmitted: packet not found - previous: %s for %s(%s)", + previousPacketNumber, packet.packetType(), packetNumber); + } + } + } else { + if (packet.isAckEliciting()) { + // This method works with the following assumption: + // - Non ACK eliciting packet do not need to be retransmitted because: + // - they only contain ack frames - which may/will we be retransmitted + // anyway with the next ack eliciting packet + // - they will not be acknowledged directly - we don't want to + // resend them constantly + if (debug.on()) { + debug.log("Packet %s(pn:%s) is pending acknowledgement", + packet.packetType(), packetNumber); + } + PendingAcknowledgement pending = new PendingAcknowledgement(packet, + now(), packetNumber, null); + transferLock.lock(); + try { + // check for close and addAcknowledgement in the same lock + // to avoid races with close / clearAll + final var closed = !this.isOpenForTransmission(); + if (closed) { + if (debug.on()) { + debug.log("%s already closed: ignoring packet pn:%s", + packetNumberSpace, packet.packetNumber()); + } + return; + } + addAcknowledgement(pending); + if (packetTransmissionTask.hasNoDeadline()) { + packetTransmissionTask.reschedule(); + } + } finally { + transferLock.unlock(); + } + } + } + + + } + + private void addAcknowledgement(PendingAcknowledgement ack) { + lastAckElicitingSent(ack.sent); + lastAckElicitingSent(ack.packetNumber); + pendingAcknowledgements.add(ack); + sendTimes.put(ack.packetNumber, ack.sent); + congestionController.packetSent(ack.packet().size()); + } + + /** + * Computes the next deadline for generating a non ACK eliciting + * packet containing the next ACK frame, or for retransmitting + * unacknowledged packets for which retransmission is due. + * This may be different to the {@link #nextScheduledDeadline()} + * if newer changes have not been taken into account yet. + * @return the deadline at which the scheduler's task for this packet + * space should be scheduled to wake up + */ + public Deadline computeNextDeadline() { + return computeNextDeadline(true); + } + + public Deadline computeNextDeadline(boolean verbose) { + + if (closed) { + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] closed - no deadline", + packetEmitter.logTag(), packetNumberSpace)); + } + return Deadline.MAX; + } + if (transmitNow) { + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] transmit now", + packetEmitter.logTag(), packetNumberSpace)); + } + return Deadline.MIN; + } + if (pingRequested != null) { + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] ping requested", + packetEmitter.logTag(), packetNumberSpace)); + } + return Deadline.MIN; + } + var ack = nextAckFrame; + + Deadline ackDeadline = (ack == null || ack.sent() != null) + ? Deadline.MAX // if the ack frame has already been sent, getNextAck() returns null + : ack.deadline(); + Deadline lossDeadline = getLossTimer(); + // TODO: consider removing the debug traces in this method when integrating + // if both loss deadline and PTO timer are set, loss deadline is always earlier + if (verbose && debug.on() && lossDeadline != Deadline.MIN) debug.log("lossDeadline is: " + lossDeadline); + if (lossDeadline != null) { + if (verbose && debug.on()) { + if (lossDeadline == Deadline.MIN) { + debug.log("lossDeadline is immediate"); + } else if (!ackDeadline.isBefore(lossDeadline)) { + debug.log("lossDeadline in %s ms", + Deadline.between(now(), lossDeadline).toMillis()); + } else { + debug.log("ackDeadline before lossDeadline in %s ms", + Deadline.between(now(), ackDeadline).toMillis()); + } + } + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] loss deadline: %s, ackDeadline: %s, deadline in %s", + packetEmitter.logTag(), packetNumberSpace, lossDeadline, ackDeadline, + Utils.debugDeadline(now(), min(ackDeadline, lossDeadline)))); + } + return min(ackDeadline, lossDeadline); + } + Deadline ptoDeadline = getPtoDeadline(); + if (verbose && debug.on()) debug.log("ptoDeadline is: " + ptoDeadline); + if (ptoDeadline != null) { + if (verbose && debug.on()) { + if (!ackDeadline.isBefore(ptoDeadline)) { + debug.log("ptoDeadline in %s ms", + Deadline.between(now(), ptoDeadline).toMillis()); + } else { + debug.log("ackDeadline before ptoDeadline in %s ms", + Deadline.between(now(), ackDeadline).toMillis()); + } + } + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] PTO deadline: %s, ackDeadline: %s, deadline in %s", + packetEmitter.logTag(), packetNumberSpace, ptoDeadline, ackDeadline, + Utils.debugDeadline(now(), min(ackDeadline, ptoDeadline)))); + } + return min(ackDeadline, ptoDeadline); + } + if (verbose && debug.on()) { + if (ackDeadline == Deadline.MAX) { + debug.log("ackDeadline is: Deadline.MAX"); + } else { + debug.log("ackDeadline in %s ms", + Deadline.between(now(), ackDeadline).toMillis()); + } + } + if (ackDeadline.equals(Deadline.MAX)) { + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] no deadline: " + + "pendingAcks: %s, triggered: %s, pendingRetransmit: %s", + packetEmitter.logTag(), packetNumberSpace, pendingAcknowledgements.size(), + triggeredForRetransmission.size(), pendingRetransmission.size())); + } + } else { + if (verbose && Log.quicTimer()) { + Log.logQuic(String.format("%s: [%s] deadline is %s", + packetEmitter.logTag(), packetNumberSpace(), + Utils.debugDeadline(now(), ackDeadline))); + } + } + return ackDeadline; + } + + /** + * {@return the next deadline at which the scheduler's task for this packet + * space is currently scheduled to wake up} + */ + public Deadline nextScheduledDeadline() { + return packetTransmissionTask.nextDeadline; + } + + private Deadline now() { + return instantSource.instant(); + } + + /** + * Tracks the largest packet acknowledged by the packets acknowledged in the + * given AckFrame. This helps to implement the algorithm described in + * RFC 9000, 13.2.4. Limiting Ranges by Tracking ACK Frames. + * @param pending a yet unacknowledged packet that may be acknowledged + * by the given{@link AckFrame}. + * @param frame a received {@code AckFrame} + * @return whether the given pending unacknowledged packet is being + * acknowledged by this ack frame. + */ + private boolean trackAcknowlegment(PendingAcknowledgement pending, AckFrame frame) { + return emittedAckTracker.trackAcknowlegment(pending, frame); + } + + private boolean isAcknowledgingLostPacket(PendingAcknowledgement pending, AckFrame frame, + List[] recovered) { + if (frame.isAcknowledging(pending.packetNumber)) { + if (recovered != null) { + if (recovered[0] == null) { + recovered[0] = new ArrayList<>(); + } + recovered[0].add(pending); + } + return true; + } + // There is a potential for a never ending retransmission + // loop here if we don't treat the ack of a previous packet just + // as the ack of the tip of the chain. + // So we call packetEmitter.acknowledged(pending.packet()) here too, + // and return `true` in this case as well. + var previous = pending.findPreviousAcknowledged(frame); + if (previous != null) { + if (recovered != null) { + if (recovered[0] == null) { + recovered[0] = new ArrayList<>(); + } + recovered[0].add(pending); + } + return true; + } + return false; + + } + + @Override + public void processAckFrame(AckFrame frame) throws QuicTransportException { + // for each acknowledged packet, remove it from the + // list of packets pending acknowledgement, or from the + // list of packets pending retransmission + long largestAckAckedBefore = emittedAckTracker.largestAckAcked(); + long largestAcknowledged = frame.largestAcknowledged(); + Deadline now = now(); + if (largestAcknowledged >= nextPN.get()) { + throw new QuicTransportException("Acknowledgement for a nonexistent packet", + null, frame.getTypeField(), QuicTransportErrors.PROTOCOL_VIOLATION); + } + + int lostCount; + transferLock.lock(); + try { + if (largestAckReceived(largestAcknowledged)) { + // if the largest acknowledged PN is newly acknowledged + // and at least one of the newly acked packets is ack-eliciting + // -> use the new RTT sample + // the below code only checks if largest acknowledged is ack-eliciting + Deadline sentTime = sendTimes.get(largestAcknowledged); + if (sentTime != null) { + long ackDelayMicros; + if (isApplicationSpace()) { + confirmHandshake(); + long baseAckDelay = peerAckDelayToMicros(frame.ackDelay()); + // if packet was sent after handshake confirmed, use max ack delay + if (largestAcknowledged >= handshakeConfirmedPN) { + ackDelayMicros = Math.min( + baseAckDelay, + TimeUnit.MILLISECONDS.toMicros(peerMaxAckDelayMillis)); + } else { + ackDelayMicros = baseAckDelay; + } + } else { + // acks are not delayed during handshake + ackDelayMicros = 0; + } + long rttSample = sentTime.until(now, ChronoUnit.MICROS); + if (debug.on()) { + debug.log("New RTT sample on packet %s: %s us (delay %s us)", + largestAcknowledged, rttSample, + ackDelayMicros); + } + rttEstimator.consumeRttSample( + rttSample, + ackDelayMicros, + now + ); + } else { + if (debug.on()) { + debug.log("RTT sample on packet %s ignored: not ack eliciting", + largestAcknowledged); + } + } + if (packetNumberSpace != PacketNumberSpace.INITIAL) { + rttEstimator.resetPtoBackoff(); + } + purgeSendTimes(largestAcknowledged); + // complete PingRequests if needed + processPingResponses(largestAcknowledged); + } else { + if (debug.on()) { + debug.log("RTT sample on packet %s ignored: not largest", + largestAcknowledged); + } + } + + pendingRetransmission.removeIf((p) -> trackAcknowlegment(p, frame)); + triggeredForRetransmission.removeIf((p) -> trackAcknowlegment(p, frame)); + for (Iterator iterator = pendingAcknowledgements.iterator(); iterator.hasNext(); ) { + PendingAcknowledgement p = iterator.next(); + if (trackAcknowlegment(p, frame)) { + iterator.remove(); + congestionController.packetAcked(p.packet.size(), p.sent); + } + } + lostCount = detectAndRemoveLostPackets(now); + @SuppressWarnings({"unchecked","rawtypes"}) + List[] recovered= Log.quicRetransmit() ? new List[1] : null; + lostPackets.removeIf((p) -> isAcknowledgingLostPacket(p, frame, recovered)); + if (recovered != null && recovered[0] != null) { + Log.logQuic("{0} lost packets recovered: {1}({2}) total unrecovered {3}, unacknowledged {4}", + packetEmitter.logTag(), packetType(), + recovered[0].stream().map(PendingAcknowledgement::packetNumber).toList(), + lostPackets.size(), pendingAcknowledgements.size() + pendingRetransmission.size()); + } + } finally { + transferLock.unlock(); + } + + long largestAckAcked = emittedAckTracker.largestAckAcked(); + if (largestAckAcked > largestAckAckedBefore) { + if (debug.on()) { + debug.log("%s: largestAckAcked=%d - cleaning up AckFrame", + packetNumberSpace, largestAckAcked); + } + // remove ack ranges that we no longer need to acknowledge. + // this implements the algorithm described in RFC 9000, + // 13.2.4. Limiting Ranges by Tracking ACK Frames + cleanupAcks(); + } + + if (lostCount > 0) { + if (debug.on()) + debug.log("Found %s lost packets", lostCount); + // retransmit if possible + runTransmitter(); + } else if (blockedByCC && congestionController.canSendPacket()) { + // CC just got unblocked... send more data + blockedByCC = false; + runTransmitter(); + } else { + // RTT was updated, some packets might be lost, recompute timers + packetTransmissionTask.reschedule(); + } + } + + @Override + public void confirmHandshake() { + assert isApplicationSpace(); + if (handshakeConfirmedPN == 0) { + handshakeConfirmedPN = nextPN.get(); + } + } + + private void purgeSendTimes(long largestAcknowledged) { + sendTimes.headMap(largestAcknowledged, true).clear(); + } + + private long peerAckDelayToMicros(long ackDelay) { + return ackDelay << peerAckDelayExponent; + } + + private NextAckFrame getNextAck(boolean onlyOverdue, int maxSize) { + Deadline now = now(); + // This method is called to retrieve the AckFrame that will + // be embedded in the next packet sent to the peer. + // We therefore need to disarm the timer that will send a + // non-ACK eliciting packet with that AckFrame (if any) before + // returning the AckFrame. This is the purpose of the loop + // below... + while (true) { + NextAckFrame ack = nextAckFrame; + if (ack == null + || ack.deadline() == Deadline.MAX + || (onlyOverdue && ack.deadline().isAfter(now)) + || ack.sent() != null) { + return null; + } + // also reserve 3 bytes for the ack delay + if (ack.ackFrame().size() > maxSize - 3) return null; + NextAckFrame newAck = ack.withDeadline(Deadline.MAX, now); + boolean respin = !Handles.NEXTACK.compareAndSet(this, ack, newAck); + if (!respin) { + return ack; + } + } + } + + @Override + public AckFrame getNextAckFrame(boolean onlyOverdue) { + return getNextAckFrame(onlyOverdue, Integer.MAX_VALUE); + } + + @Override + public AckFrame getNextAckFrame(boolean onlyOverdue, int maxSize) { + if (closed) { + return null; + } + NextAckFrame ack = getNextAck(onlyOverdue, maxSize); + if (ack == null) { + return null; + } + long delay = ack.lastUpdated() + .until(now(), ChronoUnit.MICROS) >> ACK_DELAY_EXPONENT; + return ack.ackFrame().withAckDelay(delay); + } + + /** + * Returns the count of unacknowledged packets that were declared lost. + * The lost packets are moved from the pendingAcknowledgements + * into the pendingRetransmission. + * + * @param now current time, used for time-based loss detection. + */ + private int detectAndRemoveLostPackets(Deadline now) { + Deadline lossSendTime = now.minus(rttEstimator.getLossThreshold()); + int count = 0; + // debug.log("preparing for retransmission"); + transferLock.lock(); + try { + List lost = Log.quicRetransmit() ? new ArrayList<>() : null; + List packets = new ArrayList<>(); + Deadline firstSendTime = null, lastSendTime = null; + for (PendingAcknowledgement head = pendingAcknowledgements.peek(); + head != null && head.packetNumber < largestReceivedAckedPN; + head = pendingAcknowledgements.peek()) { + if (head.packetNumber < largestReceivedAckedPN - kPacketThreshold || + !lossSendTime.isBefore(head.sent)) { + if (debug.on()) { + debug.log("retransmit:head pn:" + head.packetNumber + + ",largest acked PN:" + largestReceivedAckedPN + + ",sent:" + head.sent + + ",lossSendTime:" + lossSendTime + ); + } + if (pendingAcknowledgements.remove(head)) { + pendingRetransmission.add(head); + triggeredForRetransmission.add(head); + packets.add(head.packet); + if (firstSendTime == null) { + firstSendTime = head.sent; + } + lastSendTime = head.sent; + var lp = head; + lostPackets.removeIf(p -> lp.hasPreviousNumber(p.packetNumber)); + lostPackets.add(head); + count++; + if (lost != null) lost.add(head); + } + } else { + if (debug.on()) { + debug.log("no retransmit:head pn:" + head.packetNumber + + ",largest acked PN:" + largestReceivedAckedPN + + ",sent:" + head.sent + + ",lossSendTime:" + lossSendTime + ); + } + break; + } + } + if (!packets.isEmpty()) { + // Persistent congestion is detected more aggressively than mandated by RFC 9002: + // - may be reported even if there's no prior RTT sample + // - may be reported even if there are acknowledged packets between the lost ones + boolean persistent = Deadline.between(firstSendTime, lastSendTime) + .compareTo(getPersistentCongestionDuration()) > 0; + congestionController.packetLost(packets, lastSendTime, persistent); + } + if (lost != null && !lost.isEmpty()) { + Log.logQuic("{0} lost packet {1}({2}) total unrecovered {3}, unacknowledged {4}", + packetEmitter.logTag(), + packetType(), lost.stream().map(PendingAcknowledgement::packetNumber).toList(), + lostPackets.size(), pendingAcknowledgements.size() + pendingRetransmission.size()); + } + } finally { + transferLock.unlock(); + } + return count; + } + + PacketType packetType() { + return switch (packetNumberSpace) { + case INITIAL -> PacketType.INITIAL; + case HANDSHAKE -> PacketType.HANDSHAKE; + case APPLICATION -> PacketType.ONERTT; + case NONE -> PacketType.NONE; + }; + } + + /** + * {@return true if PTO timer expired, false otherwise} + */ + private boolean isPTO(Deadline now) { + Deadline ptoDeadline = getPtoDeadline(); + return ptoDeadline != null && !ptoDeadline.isAfter(now); + } + + // returns true if this space is the APPLICATION space + private boolean isApplicationSpace() { + return packetNumberSpace == PacketNumberSpace.APPLICATION; + } + + // returns the PTO duration + Duration getPtoDuration() { + var pto = rttEstimator.getBasePtoDuration() + .plusMillis(peerMaxAckDelayMillis) + .multipliedBy(rttEstimator.getPtoBackoff()); + var max = QuicRttEstimator.MAX_PTO_BACKOFF_TIMEOUT; + // don't allow PTO > 240s + return pto.compareTo(max) > 0 ? max : pto; + } + + // returns the persistent congestion duration + Duration getPersistentCongestionDuration() { + return rttEstimator.getBasePtoDuration() + .plusMillis(peerMaxAckDelayMillis) + .multipliedBy(kPersistentCongestionThreshold); + } + + private Deadline getPtoDeadline() { + if (packetNumberSpace == PacketNumberSpace.INITIAL && lastAckElicitingTime != null) { + if (!quicTLSEngine.keysAvailable(QuicTLSEngine.KeySpace.HANDSHAKE)) { + // if handshake keys are not available, initial PTO must be set + return lastAckElicitingTime.plus(getPtoDuration()); + } + } + if (packetNumberSpace == PacketNumberSpace.HANDSHAKE) { + // set anti-deadlock timer + if (lastAckElicitingTime == null) { + lastAckElicitingTime = now(); + } + if (largestAckElicitingSentPN == -1) { + return lastAckElicitingTime.plus(getPtoDuration()); + } + } + if (largestAckElicitingSentPN <= largestReceivedAckedPN) { + return null; + } + // Application space deadline can only be set when handshake is confirmed + if (isApplicationSpace() && quicTLSEngine.getHandshakeState() != QuicTLSEngine.HandshakeState.HANDSHAKE_CONFIRMED) { + return null; + } + return lastAckElicitingTime.plus(getPtoDuration()); + } + + private Deadline getLossTimer() { + PendingAcknowledgement head = pendingAcknowledgements.peek(); + if (head == null || head.packetNumber >= largestReceivedAckedPN) { + return null; + } + if (head.packetNumber < largestReceivedAckedPN - kPacketThreshold) { + return Deadline.MIN; + } + return head.sent.plus(rttEstimator.getLossThreshold()); + } + + // Compute the new deadline when adding an ack-eliciting packet number + // to an ack frame which is not empty. + private Deadline computeNewDeadlineFor(AckFrame frame, Deadline now, Deadline deadline, + long packetNumber, long previousLargest, + long ackDelay) { + + boolean previousEliciting = !deadline.equals(Deadline.MAX); + + if (closed) return Deadline.MAX; + + if (previousEliciting) { + // RFC 9000 #13.2.2: + // We should send an ACK immediately after receiving two + // ACK-eliciting packets + if (debug.on()) { + debug.log("two ACK-Eliciting packets received: " + + "next ack deadline now"); + } + return now; + } else if (packetNumber < previousLargest) { + // RFC 9000 #13.2.1: + // if the packet has PN less than another ack-eliciting packet, + // send ACK frame as soon as possible + if (debug.on()) { + debug.log("ACK-Eliciting packet received out of order: " + + "next ack deadline now"); + } + return now; + } else if (packetNumber - 1 > previousLargest && previousLargest > -1) { + // RFC 9000 #13.2.1: + // Check whether there are gaps between this packet and the + // previous ACK-eliciting packet that was received: + // if we find any gap we should send an ACK frame as soon + // as possible + if (!frame.isRangeAcknowledged(previousLargest + 1, packetNumber)) { + if (debug.on()) { + debug.log("gaps detected between this packet" + + " and the previous ACK eliciting packet: " + + "next ack deadline now"); + } + return now; + } + } + // send ACK within max delay + return now.plusMillis(ackDelay); + } + + /** + * Used to request sending of a ping frame, for instance, to verify that + * the connection is alive. + * @return a completable future that will be completed with the time it + * took, in milliseconds, for the peer to acknowledge the packet that + * contained the PingFrame (or any packet that was sent after) + */ + @Override + public CompletableFuture requestSendPing() { + CompletableFuture pingRequested; + synchronized (this) { + if ((pingRequested = this.pingRequested) == null) { + pingRequested = this.pingRequested = new MinimalFuture<>(); + } + } + runTransmitter(); + return pingRequested; + } + + // Look at whether a ping frame should be sent with the + // next ACK frame... + // If a PING frame should be sent, return the new deadline (now) + // Otherwise, return Deadline.MAX; + // A PING frame will be sent if: + // - the AckFrame contains more than (10) ACK Ranges + // - and no ACK eliciting packet was sent, or the last ACK-eliciting was + // sent long enough ago - typically 1 PTO delay + // These numbers are implementation dependent and not defined in the RFC, but + // help implement a strategy that sends occasional PING frames to limit the size + // of the ACK frames - as described in RFC 9000. + // + // See RFC 9000 Section 13.2.4 + private boolean shouldSendPing(Deadline now, AckFrame frame) { + Deadline last = lastAckElicitingTime; + if (frame != null && + (last == null || + last.isBefore(now.minus(rttEstimator.getBasePtoDuration()))) + && frame.ackRanges().size() > MAX_ACKRANGE_COUNT_BEFORE_PING) { + return true; + } + return false; + } + + // TODO: store the builder instead of storing the AckFrame? + // storing a builder would mean locking - so it might not be a good + // idea. But creating a new builder and AckFrame each time means + // producing more garbage for the GC to collect. + // This method is called when a new packet is received, and it adds the + // received packet number to the next ACK frame to send out. + // If the packet is ACK eliciting it also arms a timeout (if needed) + // to make sure the packet will be acknowledged within the committed + // time frame. + private void addToAckFrame(long packetNumber, boolean isAckEliciting) { + + long largestAckEliciting = largestAckElicitingReceivedPN; + if (isAckEliciting) ackElicitingPacketProcessed(packetNumber); + + if (debug.on()) { + if (packetNumber < largestAckEliciting) { + debug.log("already received a larger ACK eliciting packet"); + } + } + + // compute a new AckFrame that includes the + // provided packet number + NextAckFrame nextAckFrame, ack = null; + boolean reschedule; + long largestAckAcked; + long newLargestAckAcked = -1; + do { + Deadline now = now(); + nextAckFrame = this.nextAckFrame; + var frame = nextAckFrame == null ? null : nextAckFrame.ackFrame(); + largestAckAcked = emittedAckTracker.largestAckAcked(); + boolean needNewFrame = (frame == null || !frame.isAcknowledging(packetNumber)) + && packetNumber > largestAckAcked; + if (needNewFrame) { + if (debug.on()) { + debug.log("Adding %s(%d) to ackFrame %s (ackEliciting %s)", + packetNumberSpace, packetNumber, nextAckFrame, isAckEliciting); + } + var builder = AckFrameBuilder + .ofNullable(frame) + .dropAcksBefore(largestAckAcked) + .addAck(packetNumber); + assert !builder.isEmpty(); + frame = builder.build(); + + // Note: we could optimize this if needed by simply using a max number of + // ranges: we could pre-compute the approximate size of a frame that has N ranges + // and use that. + final int maxFrameSize = QuicConnectionImpl.SMALLEST_MAXIMUM_DATAGRAM_SIZE - 100; + if (frame.size() > maxFrameSize) { + // frame is too big. We will drop some ranges + int ranges = frame.ackRanges().size(); + int index = ranges/3; + builder.dropAckRangesAfter(index); + newLargestAckAcked = builder.getLargestAckAcked(); + var newFrame = builder.build(); + if (Log.quicCC() || Log.quicRetransmit()) { + Log.logQuic("{0}: frame too big ({1} bytes) dropping ack ranges after {2}, " + + "will ignore packets smaller than {3} (new frame: {4} bytes)", + debugStrSupplier.get(), Integer.toString(frame.size()), + Integer.toString(index), Long.toString(newLargestAckAcked), + Integer.toString(newFrame.size())); + } + frame = newFrame; + assert frame.size() <= maxFrameSize; + } + assert frame.isAcknowledging(packetNumber); + if (nextAckFrame == null) { + if (debug.on()) debug.log("no previous ackframe"); + Deadline deadline = isAckEliciting + ? now.plusMillis(maxAckDelay) + : Deadline.MAX; + ack = new NextAckFrame(frame, deadline, now, null); + reschedule = isAckEliciting; + if (debug.on()) debug.log("next deadline: " + maxAckDelay); + } else { + Deadline deadline = nextAckFrame.deadline(); + Deadline nextDeadline = deadline; + boolean deadlineNotExpired = now.isBefore(deadline); + if (isAckEliciting && deadlineNotExpired) { + if (debug.on()) debug.log("computing new deadline for ackframe"); + nextDeadline = computeNewDeadlineFor(frame, now, deadline, + packetNumber, largestAckEliciting, maxAckDelay); + } + long millisToNext = nextDeadline.equals(Deadline.MAX) + ? Long.MAX_VALUE + : now.until(nextDeadline, ChronoUnit.MILLIS); + if (debug.on()) { + if (nextDeadline == Deadline.MAX) { + debug.log("next deadline is: Deadline.MAX"); + } else { + debug.log("next deadline is: " + millisToNext); + } + } + ack = new NextAckFrame(frame, nextDeadline, now, null); + reschedule = !nextDeadline.equals(deadline) + || millisToNext <= 0; + } + if (debug.on()) { + String delay = reschedule ? Utils.millis(now(), ack.deadline()) + : "not rescheduled"; + debug.log("%s: new ackFrame composed: %s - reschedule=%s", + packetNumberSpace, ack.ackFrame(), delay); + } + } else { + reschedule = false; + if (debug.on()) { + debug.log("packet %s(%d) is already in ackFrame %s", + packetNumberSpace, packetNumber, nextAckFrame); + } + break; + } + } while (!Handles.NEXTACK.compareAndSet(this, nextAckFrame, ack)); + + if (newLargestAckAcked >= 0) { + // we reduced the frame because it was too big: we need to ignore + // packets that are larger than the new largest ignored packet. + // this is now our new de-facto 'largestAckAcked' even if it wasn't + // really acked by the peer + emittedAckTracker.dropPacketNumbersSmallerThan(newLargestAckAcked); + } + + var ackFrame = ack == null ? null : ack.ackFrame(); + assert packetNumber <= largestAckAcked + || ackFrame != null && ackFrame.isAcknowledging(packetNumber) + || nextAckFrame != null && nextAckFrame.ackFrame() != null + && nextAckFrame.ackFrame.isAcknowledging(packetNumber) + : "packet %s(%s) should be in ackFrame" + .formatted(packetNumberSpace, packetNumber); + + if (reschedule) { + runTransmitter(); + } + } + + void debugState() { + if (debug.on()) { + debug.log("state: %s", isClosed() ? "closed" : "opened" ); + debug.log("AckFrame: " + nextAckFrame); + String pendingAcks = pendingAcknowledgements.stream() + .map(PendingAcknowledgement::prettyPrint) + .collect(Collectors.joining(", ", "(", ")")); + String pendingRetransmit = pendingRetransmission.stream() + .map(PendingAcknowledgement::prettyPrint) + .collect(Collectors.joining(", ", "(", ")")); + debug.log("Pending acks: %s", pendingAcks); + debug.log("Pending retransmit: %s", pendingRetransmit); + } + } + + void debugState(String prefix, StringBuilder sb) { + String state = isClosed() ? "closed" : "opened"; + sb.append(prefix).append("State: ").append(state).append('\n'); + sb.append(prefix).append("AckFrame: ").append(nextAckFrame).append('\n'); + String pendingAcks = pendingAcknowledgements.stream() + .map(PendingAcknowledgement::prettyPrint) + .collect(Collectors.joining(", ", "(", ")")); + String pendingRetransmit = pendingRetransmission.stream() + .map(PendingAcknowledgement::prettyPrint) + .collect(Collectors.joining(", ", "(", ")")); + sb.append(prefix).append("Pending acks: ").append(pendingAcks).append('\n'); + sb.append(prefix).append("Pending retransmit: ").append(pendingRetransmit); + } + + @Override + public boolean isAcknowledged(long packetNumber) { + var ack = nextAckFrame; + var ackFrame = ack == null ? null : ack.ackFrame(); + var largestProcessed = largestProcessedPN; + // if ackFrame is null it means all packets <= largestProcessedPN + // have been acked. + if (ackFrame == null) return packetNumber <= largestProcessed; + if (packetNumber > largestProcessed) return false; + var largestAckedPNReceivedByPeer = this.largestAckedPNReceivedByPeer; + if (packetNumber <= largestAckedPNReceivedByPeer) return true; + return ackFrame.isAcknowledging(packetNumber); + } + + @Override + public void fastRetransmit() { + assert packetNumberSpace == PacketNumberSpace.INITIAL; + if (closed || fastRetransmitDone) { + return; + } + fastRetransmit = true; + if (Log.quicControl() || Log.quicRetransmit()) { + Log.logQuic("Scheduling fast retransmit"); + } else if (debug.on()) { + debug.log("Scheduling fast retransmit"); + } + runTransmitter(); + + } + + private static Deadline min(Deadline one, Deadline two) { + return two.isAfter(one) ? one : two; + } + + // This implements the algorithm described in RFC 9000: + // 13.2.4. Limiting Ranges by Tracking ACK Frames + private void cleanupAcks() { + // clean up the next ACK frame, removing all packets <= largestAckAcked + NextAckFrame nextAckFrame, ack = null; + long largestAckAcked; + do { + nextAckFrame = this.nextAckFrame; + if (nextAckFrame == null) return; // nothing to do! + var frame = nextAckFrame.ackFrame(); + largestAckAcked = emittedAckTracker.largestAckAcked(); + boolean needNewFrame = frame != null + && frame.smallestAcknowledged() <= largestAckAcked; + if (needNewFrame) { + if (debug.on()) { + debug.log("Dropping all acks below %s(%d) in ackFrame %s", + packetNumberSpace, largestAckAcked, nextAckFrame); + } + var builder = AckFrameBuilder + .ofNullable(frame) + .dropAcksBefore(largestAckAcked); + frame = builder.isEmpty() ? null : builder.build(); + if (frame == null) { + ack = null; + if (debug.on()) { + debug.log("%s: ackFrame cleared - nothing to acknowledge", + packetNumberSpace); + } + } else { + Deadline deadline = nextAckFrame.deadline(); + ack = new NextAckFrame(frame, deadline, + nextAckFrame.lastUpdated(), nextAckFrame.sent()); + if (debug.on()) { + debug.log("%s: ackFrame cleaned up: %s", + packetNumberSpace, ack.ackFrame()); + } + } + } else { + if (debug.on()) { + debug.log("%s: no packet smaller than %d in ackFrame %s", + packetNumberSpace, largestAckAcked, nextAckFrame); + } + break; + } + } while (!Handles.NEXTACK.compareAndSet(this, nextAckFrame, ack)); + + var ackFrame = ack == null ? null : ack.ackFrame(); + assert ackFrame == null || ackFrame.smallestAcknowledged() > largestAckAcked + : "%s(pn > %s) should not acknowledge packet <= %s" + .formatted(packetNumberSpace, ackFrame.smallestAcknowledged(), largestAckAcked); + } + + private long ackElicitingPacketProcessed(long packetNumber) { + long largestPN; + do { + largestPN = largestAckElicitingReceivedPN; + if (largestPN >= packetNumber) return largestPN; + } while (!Handles.LARGEST_ACK_ELICITING_RECEIVED_PN + .compareAndSet(this, largestPN, packetNumber)); + return packetNumber; + } + + private long packetProcessed(long packetNumber) { + long largestPN; + do { + largestPN = largestProcessedPN; + if (largestPN >= packetNumber) return largestPN; + } while (!Handles.LARGEST_PROCESSED_PN + .compareAndSet(this, largestPN, packetNumber)); + return packetNumber; + } + + /** + * Theoretically we should wait for the packet that contains the + * ping frame to be acknowledged, but if we receive the ack of a + * packet with a larger number, we can assume that the connection + * is still alive, and therefore complete the ping response. + * @param packetNumber the acknowledged packet number + */ + private void processPingResponses(long packetNumber) { + if (pendingPingRequests.isEmpty()) return; + var iterator = pendingPingRequests.iterator(); + while (iterator.hasNext()) { + var pr = iterator.next(); + if (pr.packetNumber() <= packetNumber) { + iterator.remove(); + pr.response().complete(pr.sent().until(now(), ChronoUnit.MILLIS)); + } else { + // this is a queue, so the PingRequest with the smaller + // packet number will be at the head. We can stop iterating + // as soon as we find a PingRequest that has a packet + // number larger than the one acknowledged. + break; + } + } + } + + private long largestAckSent(long packetNumber) { + long largestPN; + do { + largestPN = largestSentAckedPN; + if (largestPN >= packetNumber) return largestPN; + } while (!Handles.LARGEST_SENT_ACKED_PN + .compareAndSet(this, largestPN, packetNumber)); + return packetNumber; + } + + private boolean largestAckReceived(long packetNumber) { + long largestPN; + do { + largestPN = largestReceivedAckedPN; + if (largestPN >= packetNumber) return false; // already up to date + } while (!Handles.LARGEST_RECEIVED_ACKED_PN + .compareAndSet(this, largestPN, packetNumber)); + return true; // updated + } + + // records the time at which the last ACK-eliciting packet was sent. + // This has the side effect of resetting the nextPingTime to Deadline.MAX + // The logic is that a PING frame only need to be sent if no ACK-eliciting + // packet has been sent for some time (and the AckFrame has grown big enough). + // See RFC 9000 - Section 13.2.4 + private Deadline lastAckElicitingSent(Deadline now) { + Deadline max; + if (debug.on()) + debug.log("Updating last send time to %s", now); + do { + max = lastAckElicitingTime; + if (max != null && !now.isAfter(max)) return max; + } while (!Handles.LAST_ACK_ELICITING_TIME + .compareAndSet(this, max, now)); + return now; + } + + /** + * returns the TLS encryption level of this packet space as specified + * in RFC-9001, section 4, table 1. + */ + private QuicTLSEngine.KeySpace tlsEncryptionLevel() { + return switch (this.packetNumberSpace) { + case INITIAL -> QuicTLSEngine.KeySpace.INITIAL; + // APPLICATION packet space could even mean 0-RTT, but currently we don't support 0-RTT + case APPLICATION -> QuicTLSEngine.KeySpace.ONE_RTT; + case HANDSHAKE -> QuicTLSEngine.KeySpace.HANDSHAKE; + default -> throw new IllegalStateException("No known TLS encryption level" + + " for packet space: " + this.packetNumberSpace); + }; + } + + // VarHandle provide the same atomic compareAndSet functionality + // than AtomicXXXXX classes, but without the additional cost in + // footprint. + private static final class Handles { + private Handles() {throw new InternalError();} + static final VarHandle DEADLINE; + static final VarHandle NEXTACK; + static final VarHandle LARGEST_PROCESSED_PN; + static final VarHandle LARGEST_ACK_ELICITING_RECEIVED_PN; + static final VarHandle LARGEST_RECEIVED_ACKED_PN; + static final VarHandle LARGEST_SENT_ACKED_PN; + static final VarHandle LARGEST_ACK_ACKED_PN; + static final VarHandle LAST_ACK_ELICITING_TIME; + static final VarHandle IGNORE_ALL_PN_BEFORE; + static { + Lookup lookup = MethodHandles.lookup(); + try { + Class srt = PacketTransmissionTask.class; + DEADLINE = lookup.findVarHandle(srt, "nextDeadline", Deadline.class); + + Class pmc = PacketSpaceManager.class; + LAST_ACK_ELICITING_TIME = lookup.findVarHandle(pmc, + "lastAckElicitingTime", Deadline.class); + NEXTACK = lookup.findVarHandle(pmc, "nextAckFrame", NextAckFrame.class); + LARGEST_RECEIVED_ACKED_PN = lookup + .findVarHandle(pmc, "largestReceivedAckedPN", long.class); + LARGEST_SENT_ACKED_PN = lookup + .findVarHandle(pmc, "largestSentAckedPN", long.class); + LARGEST_PROCESSED_PN = lookup + .findVarHandle(pmc, "largestProcessedPN", long.class); + LARGEST_ACK_ELICITING_RECEIVED_PN = lookup + .findVarHandle(pmc, "largestAckElicitingReceivedPN", long.class); + LARGEST_ACK_ACKED_PN = lookup + .findVarHandle(pmc, "largestAckedPNReceivedByPeer", long.class); + + Class eat = EmittedAckTracker.class; + IGNORE_ALL_PN_BEFORE = lookup + .findVarHandle(eat, "ignoreAllPacketsBefore", long.class); + } catch (Exception e) { + throw new ExceptionInInitializerError(e); + + } + } + } + + static final class OneRttPacketSpaceManager extends PacketSpaceManager + implements QuicOneRttContext { + + OneRttPacketSpaceManager(final QuicConnectionImpl connection) { + super(connection, PacketNumberSpace.APPLICATION); + } + } + + static final class HandshakePacketSpaceManager extends PacketSpaceManager { + private final PacketSpaceManager initialPktSpaceMgr; + private final boolean isClientConnection; + private final AtomicBoolean firstPktSent = new AtomicBoolean(); + + HandshakePacketSpaceManager(final QuicConnectionImpl connection, + final PacketSpaceManager initialPktSpaceManager) { + super(connection, PacketNumberSpace.HANDSHAKE); + this.isClientConnection = connection.isClientConnection(); + this.initialPktSpaceMgr = initialPktSpaceManager; + } + + @Override + public void packetSent(QuicPacket packet, long previousPacketNumber, long packetNumber) { + super.packetSent(packet, previousPacketNumber, packetNumber); + if (!isClientConnection) { + // nothing additional to be done for server connections + return; + } + if (firstPktSent.compareAndSet(false, true)) { + // if this is the first packet we sent in the HANDSHAKE keyspace + // then we close the INITIAL space discard the INITIAL keys. + // RFC-9000, section 17.2.2.1: + // A client stops both sending and processing Initial packets when it sends + // its first Handshake packet. ... Though packets might still be in flight or + // awaiting acknowledgment, no further Initial packets need to be exchanged + // beyond this point. Initial packet protection keys are discarded along with + // any loss recovery and congestion control state + if (debug.on()) { + debug.log("first handshake packet sent by client, initiating close of" + + " INITIAL packet space"); + } + this.initialPktSpaceMgr.close(); + } + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnIdManager.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnIdManager.java new file mode 100644 index 00000000000..065d045b57c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnIdManager.java @@ -0,0 +1,520 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.NavigableMap; +import java.util.NavigableSet; +import java.util.Queue; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.frames.NewConnectionIDFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.RetireConnectionIDFrame; +import jdk.internal.net.http.quic.packets.InitialPacket; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import static jdk.internal.net.http.quic.QuicConnectionId.MAX_CONNECTION_ID_LENGTH; +import static jdk.internal.net.quic.QuicTransportErrors.PROTOCOL_VIOLATION; + +/** + * Manages the connection ids advertised by a peer of a connection. + * - Handles incoming NEW_CONNECTION_ID frames, + * - produces outgoing RETIRE_CONNECTION_ID frames, + * - registers received stateless reset tokens with the QuicEndpoint + * Additionally on the client side: + * - handles incoming transport parameters (preferred_address, stateless_reset_token) + * - stores original and retry peer IDs + */ +// TODO implement voluntary switching of connection IDs +final class PeerConnIdManager { + private final Logger debug; + private final QuicConnectionImpl connection; + private final String logTag; + private final boolean isClient; + + private enum State { + INITIAL_PKT_NOT_RECEIVED_FROM_PEER, + RETRY_PKT_RECEIVED_FROM_PEER, + PEER_CONN_ID_FINALIZED + } + + private volatile State state = State.INITIAL_PKT_NOT_RECEIVED_FROM_PEER; + + private QuicConnectionId clientSelectedDestConnId; + private QuicConnectionId peerDecidedRetryConnId; + // sequence number of active connection ID + private long activeConnIdSeq = -1; + private QuicConnectionId activeConnId; + + // the connection ids (there can be more than one) with which the peer identifies this connection. + // the key of this Map is a (RFC defined) sequence number for the connection id + private final NavigableMap peerConnectionIds = + Collections.synchronizedNavigableMap(new TreeMap<>()); + // the connection id sequence numbers that we haven't received yet. + // We need to know which sequence numbers are retired, and which are not assigned yet + private final NavigableSet gaps = + Collections.synchronizedNavigableSet(new TreeSet<>()); + // the connection id sequence numbers that are awaiting retirement. + private final Queue toRetire = new ArrayDeque<>(); + // the largest retirePriorTo value received across NEW_CONNECTION_ID frames + private volatile long largestReceivedRetirePriorTo = -1; // -1 implies none received so far + // the largest sequenceNumber value received across NEW_CONNECTION_ID frames + private volatile long largestReceivedSequenceNumber; + private final ReentrantLock lock = new ReentrantLock(); + + PeerConnIdManager(final QuicConnectionImpl connection, final String dbTag) { + this.isClient = connection.isClientConnection(); + this.debug = Utils.getDebugLogger(() -> dbTag); + this.logTag = connection.logTag(); + this.connection = connection; + } + + /** + * Save the client-selected original server connection ID + * + * @param peerConnId the client-selected original server connection ID + */ + void originalServerConnId(final QuicConnectionId peerConnId) { + lock.lock(); + try { + final var st = this.state; + if (st != State.INITIAL_PKT_NOT_RECEIVED_FROM_PEER) { + throw new IllegalStateException("Cannot associate a client selected peer id" + + " in current state " + st); + } + this.clientSelectedDestConnId = peerConnId; + this.activeConnId = peerConnId; + } finally { + lock.unlock(); + } + } + + /** + * {@return the client-selected original server connection ID} + */ + QuicConnectionId originalServerConnId() { + lock.lock(); + try { + final var id = this.clientSelectedDestConnId; + if (id == null) { + throw new IllegalArgumentException("Original (peer) connection id not yet set"); + } + return id; + } finally { + lock.unlock(); + } + } + + /** + * Save the server-selected retry connection ID + * + * @param peerConnId the server-selected retry connection ID + */ + void retryConnId(final QuicConnectionId peerConnId) { + if (!isClient) { + throw new IllegalStateException("Should not be used on the server"); + } + lock.lock(); + try { + final var st = this.state; + if (st != State.INITIAL_PKT_NOT_RECEIVED_FROM_PEER) { + throw new IllegalStateException("Cannot associate a peer id, from retry packet," + + " in current state " + st); + } + this.peerDecidedRetryConnId = peerConnId; + this.activeConnId = peerConnId; + this.state = State.RETRY_PKT_RECEIVED_FROM_PEER; + } finally { + lock.unlock(); + } + } + + /** + * Returns the connectionId the server included in the Source Connection ID field of a + * Retry packet. May be null. + * + * @return the connection id sent in the server's retry packet + */ + QuicConnectionId retryConnId() { + lock.lock(); + try { + return this.peerDecidedRetryConnId; + } finally { + lock.unlock(); + } + } + + /** + * The peer in its INITIAL packet would have sent a connection id representing itself. That + * connection id may not be the same that we might have sent in the INITIAL packet. If it isn't + * the same, then we switch the peer connection id, that we keep track off, to the one that + * the peer has chosen. + * + * @param initialPacket the INITIAL packet from the peer + */ + void finalizeHandshakePeerConnId(final InitialPacket initialPacket) throws QuicTransportException { + lock.lock(); + try { + final QuicConnectionId sourceId = initialPacket.sourceId(); + final var st = this.state; + if (st == State.PEER_CONN_ID_FINALIZED) { + // we have already finalized the peer connection id, through a previous INITIAL + // packet receipt (there can be more than one INITIAL packets). + // now we just verify that this INITIAL packet too has the finalized peer connection + // id and if it doesn't then we throw an exception + final QuicConnectionId handshakePeerConnId = this.peerConnectionIds.get(0L); + assert handshakePeerConnId != null : "Handshake peer connection id is unavailable"; + if (!handshakePeerConnId.equals(sourceId)) { + throw new QuicTransportException("Invalid source connection id in INITIAL packet", + QuicTLSEngine.KeySpace.INITIAL, 0, PROTOCOL_VIOLATION); + } + return; + } + // this is the first INITIAL packet from the peer, so we finalize the peer connection id + final PeerConnectionId handshakePeerConnId = new PeerConnectionId(sourceId.getBytes()); + // at this point we have either switched to a new peer connection id (chosen by the peer) + // or have agreed to use the one we chose for the peer. In either case, we register this + // as the handshake peer connection id with sequence number 0. + // RFC-9000, section 5.1.1: The initial connection ID issued by an endpoint is sent in + // the Source Connection ID field of the long packet header during the handshake. + // The sequence number of the initial connection ID is 0. + this.peerConnectionIds.put(0L, handshakePeerConnId); + this.state = State.PEER_CONN_ID_FINALIZED; + this.activeConnIdSeq = 0; + this.activeConnId = handshakePeerConnId; + if (debug.on()) { + debug.log("scid: %s finalized handshake peerConnectionId as: %s", + connection.localConnectionId().toHexString(), + handshakePeerConnId.toHexString()); + } + } finally { + lock.unlock(); + } + } + + /** + * Save the connection ID from the preferred address QUIC transport parameter + * + * @param preferredConnId preferred connection ID + * @param preferredStatelessResetToken preferred stateless reset token + */ + void handlePreferredAddress(final ByteBuffer preferredConnId, + final byte[] preferredStatelessResetToken) { + if (!isClient) { + throw new IllegalStateException("Should not be used on the server"); + } + lock.lock(); + try { + final PeerConnectionId peerConnId = new PeerConnectionId(preferredConnId, + preferredStatelessResetToken); + // keep track of this peer connection id + // RFC-9000, section 5.1.1: If the preferred_address transport parameter is sent, + // the sequence number of the supplied connection ID is 1 + assert largestReceivedSequenceNumber == 0; + this.peerConnectionIds.put(1L, peerConnId); + largestReceivedSequenceNumber = 1; + } finally { + lock.unlock(); + } + } + + /** + * Save the stateless reset token QUIC transport parameter + * + * @param statelessResetToken stateless reset token + */ + void handshakeStatelessResetToken(final byte[] statelessResetToken) { + if (!isClient) { + throw new IllegalStateException("Should not be used on the server"); + } + lock.lock(); + try { + final QuicConnectionId handshakeConnId = this.peerConnectionIds.get(0L); + if (handshakeConnId == null) { + throw new IllegalStateException("No handshake peer connection available"); + } + // recreate the conn id with the stateless token + this.peerConnectionIds.put(0L, new PeerConnectionId(handshakeConnId.asReadOnlyBuffer(), + statelessResetToken)); + // register with the endpoint + connection.endpoint().associateStatelessResetToken(statelessResetToken, connection); + } finally { + lock.unlock(); + } + } + + /** + * {@return the active peer connection ID} + */ + QuicConnectionId getPeerConnId() { + lock.lock(); + try { + if (activeConnIdSeq < largestReceivedRetirePriorTo) { + // stop using the old connection ID + switchConnectionId(); + } + return activeConnId; + } finally { + lock.unlock(); + } + } + + private QuicConnectionId getPeerConnId(final long sequenceNum) { + assert lock.isHeldByCurrentThread(); + return this.peerConnectionIds.get(sequenceNum); + } + + /** + * Process the incoming NEW_CONNECTION_ID frame. + * + * @param newCid the NEW_CONNECTION_ID frame + * @throws QuicTransportException if the frame violates the protocol + */ + void handleNewConnectionIdFrame(final NewConnectionIDFrame newCid) + throws QuicTransportException { + if (debug.on()) { + debug.log("Received NEW_CONNECTION_ID frame: %s", newCid); + } + // pre-checks + final long sequenceNumber = newCid.sequenceNumber(); + assert sequenceNumber >= 0 : "negative sequence number disallowed in new connection id frame"; + final long retirePriorTo = newCid.retirePriorTo(); + if (retirePriorTo > sequenceNumber) { + // RFC 9000, section 19.15: Receiving a value in the Retire Prior To field that is greater + // than that in the Sequence Number field MUST be treated as a connection error of + // type FRAME_ENCODING_ERROR + throw new QuicTransportException("Invalid retirePriorTo " + retirePriorTo, + QuicTLSEngine.KeySpace.ONE_RTT, + newCid.getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + final ByteBuffer connectionId = newCid.connectionId(); + final int connIdLength = connectionId.remaining(); + if (connIdLength < 1 || connIdLength > MAX_CONNECTION_ID_LENGTH) { + // RFC-9000, section 19.15: Values less than 1 and greater than 20 are invalid and + // MUST be treated as a connection error of type FRAME_ENCODING_ERROR + throw new QuicTransportException("Invalid connection id length " + connIdLength, + QuicTLSEngine.KeySpace.ONE_RTT, + newCid.getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + final ByteBuffer statelessResetToken = newCid.statelessResetToken(); + assert statelessResetToken.remaining() == QuicConnectionImpl.RESET_TOKEN_LENGTH; + lock.lock(); + try { + // see if we have received any connection ids for this same sequence number. + // this is possible if the packet containing the new connection id frame was retransmitted. + // the connection id for such a (duplicate) sequence number is expected to be the same. + // RFC-9000, section 19.15: if a sequence number is used for different connection IDs, + // the endpoint MAY treat that receipt as a connection error of type PROTOCOL_VIOLATION + final QuicConnectionId previousConnIdForSeqNum = getPeerConnId(sequenceNumber); + if (previousConnIdForSeqNum != null) { + if (previousConnIdForSeqNum.matches(connectionId)) { + // frame with same sequence number and connection id, probably a retransmission. + // ignore this frame + if (Log.trace()) { + Log.logTrace("{0} Ignoring (duplicate) new connection id frame with" + + " sequence number {1}", logTag, sequenceNumber); + } + if (debug.on()) { + debug.log("Ignoring (duplicate) new connection id frame with" + + " sequence number %d", sequenceNumber); + } + return; + } + // mismatch, throw protocol violation error + throw new QuicTransportException("Invalid connection id in (duplicated)" + + " new connection id frame with sequence number " + sequenceNumber, + QuicTLSEngine.KeySpace.ONE_RTT, + newCid.getTypeField(), PROTOCOL_VIOLATION); + } + if ((sequenceNumber <= largestReceivedSequenceNumber && !gaps.contains(sequenceNumber)) + || sequenceNumber < largestReceivedRetirePriorTo) { + if (Log.trace()) { + Log.logTrace("{0} Ignoring (retired) new connection id frame with" + + " sequence number {1}", logTag, sequenceNumber); + } + if (debug.on()) { + debug.log("Ignoring (retired) new connection id frame with" + + " sequence number %d", sequenceNumber); + } + return; + } + long numConnIdsToAdd = Math.max(sequenceNumber - largestReceivedSequenceNumber, 0); + final long numCurrentActivePeerConnIds = this.peerConnectionIds.size() + this.gaps.size(); + // we can temporarily store up to 3x the active connection ID limit, + // including active and retired IDs. + if (numCurrentActivePeerConnIds + numConnIdsToAdd + toRetire.size() + > 3 * this.connection.getLocalActiveConnIDLimit()) { + // RFC-9000, section 5.1.1: After processing a NEW_CONNECTION_ID frame and adding and + // retiring active connection IDs, if the number of active connection IDs exceeds + // the value advertised in its active_connection_id_limit transport parameter, + // an endpoint MUST close the connection with an error of type CONNECTION_ID_LIMIT_ERROR + throw new QuicTransportException("Connection id limit reached", + QuicTLSEngine.KeySpace.ONE_RTT, newCid.getTypeField(), + QuicTransportErrors.CONNECTION_ID_LIMIT_ERROR); + } + // end pre-checks + // if we reached here, the number of connection IDs is less than twice the active limit. + // Insert gaps for the sequence numbers we haven't seen yet + insertGaps(sequenceNumber); + // Update the list of sequence numbers to retire + retirePriorTo(retirePriorTo); + // insert the new connection ID + final byte[] statelessResetTokenBytes = new byte[QuicConnectionImpl.RESET_TOKEN_LENGTH]; + statelessResetToken.get(statelessResetTokenBytes); + final PeerConnectionId newPeerConnId = new PeerConnectionId(connectionId, statelessResetTokenBytes); + final var previous = this.peerConnectionIds.putIfAbsent(sequenceNumber, newPeerConnId); + assert previous == null : "A peer connection id already exists for sequence number " + + sequenceNumber; + // post-checks + // now we can accurately check the number of active and retired connection IDs + if (peerConnectionIds.size() + gaps.size() + > this.connection.getLocalActiveConnIDLimit()) { + // RFC-9000, section 5.1.1: After processing a NEW_CONNECTION_ID frame and adding and + // retiring active connection IDs, if the number of active connection IDs exceeds + // the value advertised in its active_connection_id_limit transport parameter, + // an endpoint MUST close the connection with an error of type CONNECTION_ID_LIMIT_ERROR + throw new QuicTransportException("Active connection id limit reached", + QuicTLSEngine.KeySpace.ONE_RTT, newCid.getTypeField(), + QuicTransportErrors.CONNECTION_ID_LIMIT_ERROR); + } + if (toRetire.size() > 2 * this.connection.getLocalActiveConnIDLimit()) { + // RFC-9000, section 5.1.2: + // An endpoint SHOULD limit the number of connection IDs it has retired locally for + // which RETIRE_CONNECTION_ID frames have not yet been acknowledged. + // An endpoint SHOULD allow for sending and tracking a number + // of RETIRE_CONNECTION_ID frames of at least twice the value + // of the active_connection_id_limit transport parameter + throw new QuicTransportException("Retired connection id limit reached: " + toRetire, + QuicTLSEngine.KeySpace.ONE_RTT, newCid.getTypeField(), + QuicTransportErrors.CONNECTION_ID_LIMIT_ERROR); + } + if (this.largestReceivedRetirePriorTo < retirePriorTo) { + this.largestReceivedRetirePriorTo = retirePriorTo; + } + if (this.largestReceivedSequenceNumber < sequenceNumber) { + this.largestReceivedSequenceNumber = sequenceNumber; + } + } finally { + lock.unlock(); + } + } + + private void switchConnectionId() { + assert lock.isHeldByCurrentThread(); + // the caller is expected to retire the active connection id prior to calling this + assert !peerConnectionIds.containsKey(activeConnIdSeq); + Map.Entry entry = peerConnectionIds.ceilingEntry(largestReceivedRetirePriorTo); + activeConnIdSeq = entry.getKey(); + activeConnId = entry.getValue(); + // link the peer issued stateless reset token to this connection + final QuicEndpoint endpoint = this.connection.endpoint(); + endpoint.associateStatelessResetToken(entry.getValue().getStatelessResetToken(), this.connection); + + if (Log.trace()) { + Log.logTrace("{0} Switching to connection ID {1}", logTag, activeConnIdSeq); + } + if (debug.on()) { + debug.log("Switching to connection ID %d", activeConnIdSeq); + } + } + + private void insertGaps(long sequenceNumber) { + assert lock.isHeldByCurrentThread(); + for (long i = largestReceivedSequenceNumber + 1; i < sequenceNumber; i++) { + gaps.add(i); + } + } + + private void retirePriorTo(final long priorTo) { + assert lock.isHeldByCurrentThread(); + // remove/retire (in preparation of sending a RETIRE_CONNECTION_ID frames) + for (Iterator> iterator = peerConnectionIds.entrySet().iterator(); iterator.hasNext(); ) { + Map.Entry entry = iterator.next(); + final long seqNumToRetire = entry.getKey(); + if (seqNumToRetire >= priorTo) { + break; + } + iterator.remove(); + toRetire.add(seqNumToRetire); + // Note that the QuicEndpoint only stores local connection ids and doesn't store peer + // connection ids. It does however store the peer-issued stateless reset token of a + // peer connection id, so we let the endpoint know that the stateless reset token needs + // to be forgotten since the corresponding peer connection id is being retired + final byte[] resetTokenToForget = entry.getValue().getStatelessResetToken(); + if (resetTokenToForget != null) { + this.connection.endpoint().forgetStatelessResetToken(resetTokenToForget); + } + } + for (Iterator iterator = gaps.iterator(); iterator.hasNext(); ) { + Long gap = iterator.next(); + if (gap >= priorTo) { + return; + } + iterator.remove(); + toRetire.add(gap); + } + } + + /** + * Produce a queued RETIRE_CONNECTION_ID frame, if it fits in the packet + * + * @param remaining bytes remaining in the packet + * @return a RetireConnectionIdFrame, or null if none is queued or remaining is too low + */ + public QuicFrame nextFrame(int remaining) { + // retire connection id: + // type - 1 byte + // sequence number - var int + if (remaining < 9) { + return null; + } + lock.lock(); + try { + final Long seqNumToRetire = toRetire.poll(); + if (seqNumToRetire != null) { + if (seqNumToRetire == activeConnIdSeq) { + // can't send this connection ID yet, we will send it in the next packet + toRetire.add(seqNumToRetire); + return null; + } + return new RetireConnectionIDFrame(seqNumToRetire); + } + return null; + } finally { + lock.unlock(); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnectionId.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnectionId.java new file mode 100644 index 00000000000..0c6f946d1a2 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/PeerConnectionId.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.nio.ByteBuffer; +import java.util.HexFormat; + + +/** + * A free-form connection ID to wrap the connection ID bytes + * sent by the peer. + * Client and server might impose some structure on the + * connection ID bytes. For instance, they might choose to + * encode the connection ID length in the connection ID bytes. + * This class makes no assumption on the structure of the + * connection id bytes. + */ +public final class PeerConnectionId extends QuicConnectionId { + private final byte[] statelessResetToken; + + /** + * A new {@link QuicConnectionId} represented by the given bytes. + * @param connId The connection ID bytes. + */ + public PeerConnectionId(final byte[] connId) { + super(ByteBuffer.wrap(connId.clone())); + this.statelessResetToken = null; + } + + /** + * A new {@link QuicConnectionId} represented by the given bytes. + * @param connId The connection ID bytes. + * @param statelessResetToken The stateless reset token to be associated with this connection id. + * Can be null. + * @throws IllegalArgumentException If the {@code statelessResetToken} is non-null and if its + * length isn't 16 bytes + * + */ + public PeerConnectionId(final ByteBuffer connId, final byte[] statelessResetToken) { + super(cloneBuffer(connId)); + if (statelessResetToken != null) { + if (statelessResetToken.length != 16) { + throw new IllegalArgumentException("Invalid stateless reset token length " + + statelessResetToken.length); + } + this.statelessResetToken = statelessResetToken.clone(); + } else { + this.statelessResetToken = null; + } + } + + private static ByteBuffer cloneBuffer(ByteBuffer src) { + final byte[] idBytes = new byte[src.remaining()]; + src.get(idBytes); + return ByteBuffer.wrap(idBytes); + } + + /** + * {@return the stateless reset token associated with this connection id. returns null if no + * token exists} + */ + public byte[] getStatelessResetToken() { + return this.statelessResetToken == null ? null : this.statelessResetToken.clone(); + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + "(length:" + length() + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicClient.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicClient.java new file mode 100644 index 00000000000..58a07c22f66 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicClient.java @@ -0,0 +1,585 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.LongFunction; + +import javax.net.ssl.SSLParameters; + +import jdk.internal.net.http.AltServicesRegistry; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.QuicEndpoint.QuicEndpointFactory; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; + +/** + * This class represents a QuicClient. + * The QuicClient is responsible for creating/returning instances + * of QuicConnection for a given AltService, and for linking them + * with an instance of QuicEndpoint and QuicSelector for reading + * and writing Datagrams off the network. + * A QuicClient is also a factory for QuicConnectionIds. + * There is a 1-1 relationship between a QuicClient and an Http3Client. + * A QuicClient can be closed: closing a QuicClient will close all + * quic connections opened on that client. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public final class QuicClient implements QuicInstance, AutoCloseable { + private static final AtomicLong IDS = new AtomicLong(); + private static final AtomicLong CONNECTIONS = new AtomicLong(); + + private final Logger debug = Utils.getDebugLogger(this::name); + + // See RFC 9000 section 14 + static final int SMALLEST_MAXIMUM_DATAGRAM_SIZE = 1200; + static final int INITIAL_SERVER_CONNECTION_ID_LENGTH = 17; + static final int MAX_ENDPOINTS_LIMIT = 16; + static final int DEFAULT_MAX_ENDPOINTS = Utils.getIntegerNetProperty( + "jdk.httpclient.quic.maxEndpoints", 1); + + private final String clientId; + private final String name; + private final Executor executor; + private final QuicTLSContext quicTLSContext; + private final SSLParameters sslParameters; + // QUIC versions in their descending order of preference + private final List availableVersions; + private final InetSocketAddress bindAddress; + private final QuicTransportParameters transportParams; + private final ReentrantLock lock = new ReentrantLock(); + private final QuicEndpoint[] endpoints = new QuicEndpoint[computeMaxEndpoints()]; + private int insertionPoint; + private volatile QuicSelector selector; + private volatile boolean closed; + // keep track of any initial tokens that a server has advertised for use. The key in this + // map is the server's host and port representation and the value is the token to use. + private final Map initialTokens = new ConcurrentHashMap<>(); + private final QuicEndpointFactory endpointFactory = new QuicEndpointFactory(); + private final LongFunction appErrorCodeToString; + + private QuicClient(final QuicClient.Builder builder) { + Objects.requireNonNull(builder, "Quic client builder"); + if (builder.availableVersions == null) { + throw new IllegalStateException("Need at least one available Quic version"); + } + if (builder.tlsContext == null) { + throw new IllegalStateException("No QuicTLSContext set"); + } + this.clientId = builder.clientId == null ? nextName() : builder.clientId; + this.name = "QuicClient(%s)".formatted(clientId); + this.appErrorCodeToString = builder.appErrorCodeToString == null + ? QuicInstance.super::appErrorToString + : builder.appErrorCodeToString; + // verify that QUIC TLS supports all requested QUIC versions + var test = new ArrayList<>(builder.availableVersions); + test.removeAll(builder.tlsContext.createEngine().getSupportedQuicVersions()); + if (!test.isEmpty()) { + throw new IllegalArgumentException( + "Requested QUIC versions not supported by TLS: " + test); + } + this.availableVersions = builder.availableVersions; + this.quicTLSContext = builder.tlsContext; + this.bindAddress = builder.bindAddr == null ? new InetSocketAddress(0) : builder.bindAddr; + this.executor = builder.executor; + this.sslParameters = builder.sslParams == null + ? new SSLParameters() + : requireTLS13(builder.sslParams); + this.transportParams = builder.transportParams; + if (debug.on()) debug.log("created"); + } + + + private static int computeMaxEndpoints() { + // available processors may change according to the API doc, + // so recompute this for each new client... + int availableProcessors = Runtime.getRuntime().availableProcessors(); + int max = DEFAULT_MAX_ENDPOINTS <= 0 ? availableProcessors >> 1 : DEFAULT_MAX_ENDPOINTS; + return Math.clamp(max, 1, MAX_ENDPOINTS_LIMIT); + } + + // verifies that the TLS protocol(s) configured in SSLParameters, if any, + // allows TLSv1.3 + private static SSLParameters requireTLS13(final SSLParameters parameters) { + final String[] protos = parameters.getProtocols(); + if (protos == null || protos.length == 0) { + // no specific protocols specified, so it's OK + return parameters; + } + for (final String proto : protos) { + if ("TLSv1.3".equals(proto)) { + // TLSv1.3 is allowed, that's good + return parameters; + } + } + // explicit TLS protocols have been configured in SSLParameters and it doesn't + // include TLSv1.3. QUIC mandates TLSv1.3, so we can't use this SSLParameters + throw new IllegalArgumentException("TLSv1.3 is required for QUIC," + + " but SSLParameters is configured with " + Arrays.toString(protos)); + } + + @Override + public String appErrorToString(long code) { + return appErrorCodeToString.apply(code); + } + + @Override + public QuicTransportParameters getTransportParameters() { + if (this.transportParams == null) { + return null; + } + // return a copy + return new QuicTransportParameters(this.transportParams); + } + + private static String nextName() { + return "quic-client-" + IDS.incrementAndGet(); + } + + /** + * The address that the QuicEndpoint will bind to. + * @implNote By default, this is wildcard:0 + * @return the address that the QuicEndpoint will bind to. + */ + public InetSocketAddress bindAddress() { + return bindAddress; + } + + @Override + public boolean isVersionAvailable(final QuicVersion quicVersion) { + return this.availableVersions.contains(quicVersion); + } + + /** + * {@return the versions that are available for use on this instance, in the descending order + * of their preference} + */ + @Override + public List getAvailableVersions() { + return this.availableVersions; + } + + /** + * Creates a new unconnected {@code QuicConnection} to the given + * {@code service}. + * + * @param service the alternate service for which to create the connection for + * @return a new unconnected {@code QuicConnection} + * @throws IllegalArgumentException if the ALPN of this transport isn't the same as that of the + * passed alternate service + * @apiNote The caller is expected to call {@link QuicConnectionImpl#startHandshake()} to + * initiate the handshaking. The connection is considered "connected" when + * the handshake is successfully completed. + */ + public QuicConnectionImpl createConnectionFor(final AltServicesRegistry.AltService service) { + final InetSocketAddress peerAddress = new InetSocketAddress(service.identity().host(), + service.identity().port()); + final String alpn = service.alpn(); + if (alpn == null) { + throw new IllegalArgumentException("missing ALPN on alt service"); + } + final SSLParameters sslParameters = createSSLParameters(new String[]{alpn}); + return new QuicConnectionImpl(null, this, peerAddress, + service.origin().host(), service.origin().port(), sslParameters, "QuicClientConnection(%s)", + CONNECTIONS.incrementAndGet()); + } + + /** + * Creates a new unconnected {@code QuicConnection} to the given + * {@code peerAddress}. + * + * @param peerAddress the address of the peer + * @return a new unconnected {@code QuicConnection} + * @apiNote The caller is expected to call {@link QuicConnectionImpl#startHandshake()} to + * initiate the handshaking. The connection is considered "connected" when + * the handshake is successfully completed. + */ + public QuicConnectionImpl createConnectionFor(final InetSocketAddress peerAddress, + final String[] alpns) { + Objects.requireNonNull(peerAddress); + Objects.requireNonNull(alpns); + if (alpns.length == 0) { + throw new IllegalArgumentException("at least one ALPN is needed"); + } + final SSLParameters sslParameters = createSSLParameters(alpns); + return new QuicConnectionImpl(null, this, peerAddress, peerAddress.getHostString(), + peerAddress.getPort(), sslParameters, "QuicClientConnection(%s)", CONNECTIONS.incrementAndGet()); + } + + private SSLParameters createSSLParameters(final String[] alpns) { + final SSLParameters sslParameters = Utils.copySSLParameters(this.getSSLParameters()); + sslParameters.setApplicationProtocols(alpns); + // section 4.2, RFC-9001 (QUIC) Clients MUST NOT offer TLS versions older than 1.3 + sslParameters.setProtocols(new String[] {"TLSv1.3"}); + return sslParameters; + } + + @Override + public String instanceId() { + return clientId; + } + + @Override + public QuicTLSContext getQuicTLSContext() { + return quicTLSContext; + } + + @Override + public SSLParameters getSSLParameters() { + return Utils.copySSLParameters(sslParameters); + } + + /** + * The name identifying this QuicClient, used in debug traces. + * @implNote This is {@code "QuicClient()"}. + * @return the name identifying this QuicClient. + */ + public String name() { + return name; + } + + /** + * The HttpClientImpl Id. used to identify the client in + * debug traces. + * @return A string identifying the HttpClientImpl instance. + */ + public String clientId() { + return clientId; + } + + /** + * The executor used by this QuicClient when a task needs to + * be offloaded to a separate thread. + * @implNote This is the HttpClientImpl internal executor. + * @return the executor used by this QuicClient. + */ + @Override + public Executor executor() { + return executor; + } + + @Override + public QuicEndpoint getEndpoint() throws IOException { + return chooseEndpoint(); + } + + private QuicEndpoint chooseEndpoint() throws IOException { + QuicEndpoint endpoint; + lock.lock(); + try { + if (closed) throw new IllegalStateException("QuicClient is closed"); + int index = insertionPoint; + if (index >= endpoints.length) index = 0; + endpoint = endpoints[index]; + if (endpoint != null) { + if (endpoints.length == 1) return endpoint; + if (endpoint.connectionCount() < 2) return endpoint; + for (int i = 1; i < endpoints.length - 1; i++) { + var nexti = (index + i) % endpoints.length; + var next = endpoints[nexti]; + if (next == null) continue; + if (next.connectionCount() < endpoint.connectionCount()) { + endpoint = next; + index = nexti; + } + } + if (++index >= endpoints.length) index = 0; + insertionPoint = index; + + if (Log.quicControl()) { + Log.logQuic("Selecting endpoint: " + endpoint.name()); + } else if (debug.on()) { + debug.log("Selecting endpoint: " + endpoint.name()); + } + + return endpoint; + } + + final var endpointName = "QuicEndpoint(" + clientId + "-" + index + ")"; + if (Log.quicControl()) { + Log.logQuic("Adding new endpoint: " + endpointName); + } else if (debug.on()) { + debug.log("Adding new endpoint: " + endpointName); + } + endpoint = createEndpoint(endpointName); + assert endpoints[index] == null; + endpoints[index] = endpoint; + insertionPoint = index + 1; + } finally { + lock.unlock(); + } + // register the newly created endpoint with the selector + QuicEndpoint.registerWithSelector(endpoint, selector, debug); + return endpoint; + } + + /** + * Creates an endpoint with the given name, and register it with a selector. + * @return the new QuicEndpoint + * @throws IOException if an error occurs when setting up the selector + * or linking the transport with the selector. + * @throws IllegalStateException if the client is closed. + */ + private QuicEndpoint createEndpoint(final String endpointName) throws IOException { + var selector = this.selector; + boolean newSelector = false; + final QuicEndpoint.ChannelType configuredChannelType = QuicEndpoint.CONFIGURED_CHANNEL_TYPE; + if (selector == null) { + // create a selector first + lock.lock(); + try { + if (closed) { + throw new IllegalStateException("QuicClient is closed"); + } + selector = this.selector; + if (selector == null) { + final String selectorName = "QuicSelector(" + clientId + ")"; + selector = this.selector = switch (configuredChannelType) { + case NON_BLOCKING_WITH_SELECTOR -> + QuicSelector.createQuicNioSelector(this, selectorName); + case BLOCKING_WITH_VIRTUAL_THREADS -> + QuicSelector.createQuicVirtualThreadPoller(this, selectorName); + }; + newSelector = true; + } + } finally { + lock.unlock(); + } + } + if (newSelector) { + // we may be closed when we reach here. It doesn't matter though. + // if the selector is closed before it's started the thread will + // immediately exit (or exit after the first wakeup) + selector.start(); + } + final QuicEndpoint endpoint = switch (configuredChannelType) { + case NON_BLOCKING_WITH_SELECTOR -> + endpointFactory.createSelectableEndpoint(this, endpointName, + bindAddress(), selector.timer()); + case BLOCKING_WITH_VIRTUAL_THREADS -> + endpointFactory.createVirtualThreadedEndpoint(this, endpointName, + bindAddress(), selector.timer()); + }; + assert endpoint.channelType() == configuredChannelType + : "bad endpoint for " + configuredChannelType + ": " + endpoint.getClass(); + return endpoint; + } + + @Override + public void unmatchedQuicPacket(SocketAddress source, QuicPacket.HeadersType type, ByteBuffer buffer) { + if (debug.on()) { + debug.log("dropping unmatched packet in buffer [%s, %d bytes, %s]", + type, buffer.remaining(), source); + } + } + + /** + * @param peerAddress The address of the server + * @return the initial token to use in INITIAL packets during connection establishment + * against a server represented by the {@code peerAddress}. Returns null if no token exists for + * the server. + */ + byte[] initialTokenFor(final InetSocketAddress peerAddress) { + if (peerAddress == null) { + return null; + } + final InitialTokenRecipient recipient = new InitialTokenRecipient(peerAddress.getHostString(), + peerAddress.getPort()); + // an initial token (obtained through NEW_TOKEN frame) can be used only once against the + // peer which advertised it. Hence, we remove it. + return this.initialTokens.remove(recipient); + } + + /** + * Registers a token to use in INITIAL packets during connection establishment against a server + * represented by the {@code peerAddress}. + * + * @param peerAddress The address of the server + * @param token The token to use + * @throws NullPointerException If either of {@code peerAddress} or {@code token} is null + * @throws IllegalArgumentException If the token is of zero length + */ + void registerInitialToken(final InetSocketAddress peerAddress, final byte[] token) { + Objects.requireNonNull(peerAddress); + Objects.requireNonNull(token); + if (token.length == 0) { + throw new IllegalArgumentException("Empty token"); + } + final InitialTokenRecipient recipient = new InitialTokenRecipient(peerAddress.getHostString(), + peerAddress.getPort()); + // multiple initial tokens (through NEW_TOKEN frame) can be sent by the same peer, but as + // per RFC-9000, section 8.1.3, it's OK for clients to just use the last received token, + // since the rest are less likely to be useful + this.initialTokens.put(recipient, token); + } + + @Override + public void close() { + // TODO: ignore exceptions while closing? + lock.lock(); + try { + if (closed) return; + closed = true; + } finally { + lock.unlock(); + } + for (int i = 0 ; i < endpoints.length ; i++) { + var endpoint = endpoints[i]; + if (endpoint != null) closeEndpoint(endpoint); + } + var selector = this.selector; + if (selector != null) selector.close(); + } + + private void closeEndpoint(QuicEndpoint endpoint) { + try { endpoint.close(); } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to close endpoint: %s: %s", endpoint.name(), t); + } + } + } + + // Called in case of RejectedExecutionException, or shutdownNow; + public void abort(Throwable t) { + lock.lock(); + try { + if (closed) return; + closed = true; + } finally { + lock.unlock(); + } + for (int i = 0 ; i < endpoints.length ; i++) { + var endpoint = endpoints[i]; + if (endpoint != null) abortEndpoint(endpoint, t); + } + var selector = this.selector; + if (selector != null) selector.abort(t); + } + + private void abortEndpoint(QuicEndpoint endpoint, Throwable cause) { + try { endpoint.abort(cause); } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to abort endpoint: %s: %s", endpoint.name(), t); + } + } + } + + private record InitialTokenRecipient (String host, int port) { + } + + public static final class Builder { + private String clientId; + private List availableVersions; + private Executor executor; + private SSLParameters sslParams; + private QuicTLSContext tlsContext; + private QuicTransportParameters transportParams; + private InetSocketAddress bindAddr; + private LongFunction appErrorCodeToString; + + public Builder availableVersions(final List versions) { + Objects.requireNonNull(versions, "Quic versions"); + if (versions.isEmpty()) { + throw new IllegalArgumentException("Need at least one available Quic version"); + } + this.availableVersions = List.copyOf(versions); + return this; + } + + public Builder applicationErrors(LongFunction errorCodeToString) { + this.appErrorCodeToString = errorCodeToString; + return this; + } + + public Builder availableVersions(final QuicVersion version, final QuicVersion... more) { + Objects.requireNonNull(version, "Quic version"); + if (more == null) { + this.availableVersions = List.of(version); + return this; + } + final List versions = new ArrayList<>(); + versions.add(version); + for (final QuicVersion v : more) { + Objects.requireNonNull(v, "Quic version"); + versions.add(v); + } + this.availableVersions = List.copyOf(versions); + return this; + } + + public Builder clientId(final String clientId) { + this.clientId = clientId; + return this; + } + + public Builder tlsContext(final QuicTLSContext tlsContext) { + this.tlsContext = tlsContext; + return this; + } + + public Builder sslParameters(final SSLParameters sslParameters) { + this.sslParams = sslParameters; + return this; + } + + public Builder bindAddress(final InetSocketAddress bindAddr) { + this.bindAddr = bindAddr; + return this; + } + + public Builder executor(final Executor executor) { + this.executor = executor; + return this; + } + + public Builder transportParameters(final QuicTransportParameters transportParams) { + this.transportParams = transportParams; + return this; + } + + public QuicClient build() { + return new QuicClient(this); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicCongestionController.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicCongestionController.java new file mode 100644 index 00000000000..4bfad2c5560 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicCongestionController.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.quic; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.quic.packets.QuicPacket; + +import java.util.Collection; + +public interface QuicCongestionController { + + /** + * {@return true if a new non-ACK packet can be sent at this time} + */ + boolean canSendPacket(); + + /** + * Update the maximum datagram size + * @param newSize new maximum datagram size. + */ + void updateMaxDatagramSize(int newSize); + + /** + * Update CC with a non-ACK packet + * @param packetBytes packet size in bytes + */ + void packetSent(int packetBytes); + + /** + * Update CC after a non-ACK packet is acked + * + * @param packetBytes acked packet size in bytes + * @param sentTime time when packet was sent + */ + void packetAcked(int packetBytes, Deadline sentTime); + + /** + * Update CC after packets are declared lost + * + * @param lostPackets collection of lost packets + * @param sentTime time when the most recent lost packet was sent + * @param persistent true if persistent congestion detected, false otherwise + */ + void packetLost(Collection lostPackets, Deadline sentTime, boolean persistent); + + /** + * Update CC after packets are discarded + * @param discardedPackets collection of discarded packets + */ + void packetDiscarded(Collection discardedPackets); + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnection.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnection.java new file mode 100644 index 00000000000..05bfa1adc8c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnection.java @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStreams; +import jdk.internal.net.quic.QuicTLSEngine; + +/** + * This class implements a QUIC connection. + * A QUIC connection is established between a client and a server + * over a QuicEndpoint endpoint. + * A QUIC connection can then multiplex multiple QUIC streams to the + * same server. + * This abstract class exposes public methods used by the higher level + * protocol. + * + *

A typical call flow to establish a connection would be: + * {@snippet : + * AltService service = ...; + * QuicClient client = ...; + * QuicConnection connection = client.createConnectionFor(service); + * connection.startHandshake() + * .thenApply((r) -> { ... }) + * ...; + * } + * + */ +public abstract class QuicConnection { + + /** + * Starts the Quic Handshake. + * @return A completable future which will be completed when the + * handshake is completed. + * @throws UnsupportedOperationException If this connection isn't a client connection + */ + public abstract CompletableFuture startHandshake(); + + /** + * Creates a new locally initiated bidirectional stream. + *

+ * Creation of streams is limited to the maximum limit advertised by the peer. If a new stream + * cannot be created due to this limitation, then this method will use the + * {@code limitIncreaseDuration} to decide how long to wait for a potential increase in the + * limit. + *

+ * If the limit has been reached and the {@code limitIncreaseDuration} is not + * {@link Duration#isPositive() positive} then this method returns a {@code CompletableFuture} + * which has been completed exceptionally with {@link QuicStreamLimitException}. Else, this + * method returns a {@code CompletableFuture} which waits for that duration for a potential + * increase in the limit. If, during this period, the stream creation limit does increase and + * stream creation succeeds then the returned {@code CompletableFuture} will be completed + * successfully, else it will complete exceptionally with {@link QuicStreamLimitException}. + * + * @param limitIncreaseDuration Amount of time to wait for the bidirectional stream creation + * limit to be increased by the peer, if this connection has + * currently reached its limit + * @return a CompletableFuture which completes either with a new locally initiated + * bidirectional stream or exceptionally if the stream creation failed + */ + public abstract CompletableFuture openNewLocalBidiStream( + Duration limitIncreaseDuration); + + /** + * Creates a new locally initiated unidirectional stream. Locally created unidirectional streams + * are write-only streams. + *

+ * Creation of streams is limited to the maximum limit advertised by the peer. If a new stream + * cannot be created due to this limitation, then this method will use the + * {@code limitIncreaseDuration} to decide how long to wait for a potential increase in the + * limit. + *

+ * If the limit has been reached and the {@code limitIncreaseDuration} is not + * {@link Duration#isPositive() positive} then this method returns a {@code CompletableFuture} + * which has been completed exceptionally with {@link QuicStreamLimitException}. Else, this + * method returns a {@code CompletableFuture} which waits for that duration for a potential + * increase in the limit. If, during this period, the stream creation limit does increase and + * stream creation succeeds then the returned {@code CompletableFuture} will be completed + * successfully, else it will complete exceptionally with {@link QuicStreamLimitException}. + * + * @param limitIncreaseDuration Amount of time to wait for the unidirectional stream creation + * limit to be increased by the peer, if this connection has + * currently reached its limit + * @return a CompletableFuture which completes either with a new locally initiated + * unidirectional stream or exceptionally if the stream creation failed + */ + public abstract CompletableFuture openNewLocalUniStream( + Duration limitIncreaseDuration); + + /** + * Adds a listener that will be invoked when a remote stream is + * created. + * + * @apiNote + * The listener will be invoked with any remote streams + * already opened, and not yet acquired by another listener. + * Any stream passed to the listener is either a {@link QuicBidiStream} + * or a {@link QuicReceiverStream} depending on the + * {@linkplain QuicStreams#streamType(long) stream type} of the given + * streamId. + * The listener should return {@code true} if it wishes to acquire + * the stream. + * + * @param streamConsumer the listener + */ + public abstract void addRemoteStreamListener(Predicate streamConsumer); + + /** + * Removes a listener previously added with {@link #addRemoteStreamListener(Predicate)} + * @return {@code true} if the listener was found and removed, {@code false} otherwise + */ + public abstract boolean removeRemoteStreamListener(Predicate streamConsumer); + + /** + * {@return a stream of all currently opened {@link QuicStream} in the connection} + * + * @apiNote + * All current quic streams are included, whether local or remote, and whether they + * have been acquired or not. + * + * @see #addRemoteStreamListener(Predicate) + */ + public abstract Stream quicStreams(); + + /** + * {@return true if this connection is open} + */ + public abstract boolean isOpen(); + + /** + * {@return a long identifier that can be used to uniquely + * identify a quic connection in the context of the + * {@link QuicInstance} that created it} + */ + public long uniqueId() { return 0; } + + /** + * {@return a debug tag to be used with {@linkplain + * jdk.internal.net.http.common.Logger lower level logging}} + * This typically includes both the connection {@link #uniqueId()} + * and the {@link QuicInstance#instanceId()}. + */ + public abstract String dbgTag(); + + /** + * {@return a debug tag} + * Typically used with {@linkplain jdk.internal.net.http.common.Log + * higher level logging} + */ + public abstract String logTag(); + + /** + * {@return the {@link TerminationCause} if the connection has + * closed or is being closed, otherwise returns null} + */ + public abstract TerminationCause terminationCause(); + + public abstract QuicTLSEngine getTLSEngine(); + + public abstract InetSocketAddress peerAddress(); + + public abstract SocketAddress localAddress(); + + /** + * {@return a {@code CompletableFuture} that gets completed when + * the peer has acknowledged, or replied to the first {@link + * QuicPacket.PacketType#INITIAL INITIAL} + * packet + */ + public abstract CompletableFuture handshakeReachedPeer(); + + /** + * Requests to send a PING frame to the peer. + * An implementation may decide to support sending of out-of-band ping + * frames (triggered by the application layer) only for a subset of the + * {@linkplain jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace + * packet number spaces}. It may complete with -1 if it doesn't want to request + * sending of a ping frame at the time {@code requestSendPing()} is called. + * @return A completable future that will be completed with the number of + * milliseconds it took to get a valid response. It may also complete + * exceptionally, or with {@code -1L} if the ping was not sent. + */ + public abstract CompletableFuture requestSendPing(); + + /** + * {@return this connection {@code QuicConnectionId} or null} + * @implSpec + * The default implementation of this method returns null + */ + public QuicConnectionId localConnectionId() { return null; } + + /** + * {@return the {@link ConnectionTerminator} for this connection} + */ + public abstract ConnectionTerminator connectionTerminator(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionId.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionId.java new file mode 100644 index 00000000000..21c40a69588 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionId.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.nio.ByteBuffer; +import java.util.HexFormat; + +/** + * Models a Quic Connection id. + * QuicConnectionId instance are typically created by a Quic client or server. + */ +// Connection IDs are used as keys in an ID to connection map. +// They implement Comparable to mitigate the penalty of hash collisions. +public abstract class QuicConnectionId implements Comparable { + + /** + * The maximum length, in bytes, of a connection id. + * This is supposed to be version-specific, but for now, we + * are going to treat that as a universal constant. + */ + public static final int MAX_CONNECTION_ID_LENGTH = 20; + protected final int hashCode; + protected final ByteBuffer buf; + + protected QuicConnectionId(ByteBuffer buf) { + this.buf = buf.asReadOnlyBuffer(); + hashCode = this.buf.hashCode(); + } + + /** + * Returns the length of this connection id, in bytes; + * @return the length of this connection id + */ + public int length() { + return buf.remaining(); + } + + /** + * Returns this connection id bytes as a read-only buffer. + * @return A new read only buffer containing this connection id bytes. + */ + public ByteBuffer asReadOnlyBuffer() { + return buf.asReadOnlyBuffer(); + } + + /** + * Returns this connection id bytes as a byte array. + * @return A new byte array containing this connection id bytes. + */ + public byte[] getBytes() { + var length = length(); + byte[] bytes = new byte[length]; + buf.get(buf.position(), bytes, 0, length); + return bytes; + } + + /** + * Compare this connection id bytes with the bytes in the + * given byte buffer. + *

The given byte buffer is expected to have + * its {@linkplain ByteBuffer#position() position} set at the start + * of the connection id, and its {@linkplain ByteBuffer#limit() limit} + * at the end. In other words, {@code Buffer.remaining()} should + * indicate the connection id length. + *

This method does not advance the buffer position. + * + * @implSpec This is equivalent to:

{@code
+     *  this.asReadOnlyBuffer().comparesTo(idbytes)
+     *  }
+ * + * @param idbytes A byte buffer containing the id bytes of another + * connection id. + * @return {@code -1}, {@code 0}, or {@code 1} if this connection's id + * bytes are less, equal, or greater than the provided bytes. + */ + public int compareBytes(ByteBuffer idbytes) { + return buf.compareTo(idbytes); + } + + /** + * Tells whether the given byte buffer matches this connection id. + * The given byte buffer is expected to have + * its {@linkplain ByteBuffer#position() position} set at the start + * of the connection id, and its {@linkplain ByteBuffer#limit() limit} + * at the end. In other words, {@code Buffer.remaining()} should + * indicate the connection id length. + *

This method does not advance the buffer position. + * + * @implSpec + * This is equivalent to:

{@code
+     *  this.asReadOnlyBuffer().mismatch(idbytes) == -1
+     *  }
+ * + * @param idbytes A buffer that delimits a connection id. + * @return true if the bytes in the given buffer match this + * connection id bytes. + */ + public boolean matches(ByteBuffer idbytes) { + return buf.equals(idbytes); + } + + @Override + public int compareTo(QuicConnectionId o) { + return buf.compareTo(o.buf); + } + + + @Override + public final boolean equals(Object o) { + if (o instanceof QuicConnectionId that) { + return buf.equals(that.buf); + } + return false; + } + + @Override + public final int hashCode() { + return hashCode; + } + + /** + * {@return an hexadecimal string representing this connection id bytes, + * as returned by {@code HexFormat.of().formatHex(getBytes())}} + */ + public String toHexString() { + return HexFormat.of().formatHex(getBytes()); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionIdFactory.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionIdFactory.java new file mode 100644 index 00000000000..04cb2e6c263 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionIdFactory.java @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import javax.crypto.KeyGenerator; +import javax.crypto.Mac; +import java.nio.ByteBuffer; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; + +import static jdk.internal.net.http.quic.QuicConnectionId.MAX_CONNECTION_ID_LENGTH; + +/** + * A class to generate connection ids bytes. + * This algorithm is specific to our implementation - it's not defined + * in any RFC (connection id bytes are free form). + * For the purpose of validation we encode the length of + * the connection id into the connection id bytes. + * For the purpose of uniqueness we encode a unique id. + * The rest of the connection id are random bytes. + */ +public class QuicConnectionIdFactory { + private static final Random RANDOM = new SecureRandom(); + private static final String CLIENT_DESC = "QuicClientConnectionId"; + private static final String SERVER_DESC = "QuicServerConnectionId"; + + private static final int MIN_CONNECTION_ID_LENGTH = 9; + + private final AtomicLong tokens = new AtomicLong(); + private volatile boolean wrapped; + private final byte[] scrambler; + private final Key statelessTokenKey; + private final String simpleDesc; + private final int connectionIdLength = RANDOM.nextInt(MIN_CONNECTION_ID_LENGTH, MAX_CONNECTION_ID_LENGTH+1); + + public static QuicConnectionIdFactory getClient() { + return new QuicConnectionIdFactory(CLIENT_DESC); + } + + public static QuicConnectionIdFactory getServer() { + return new QuicConnectionIdFactory(SERVER_DESC); + } + + private QuicConnectionIdFactory(String simpleDesc) { + this.simpleDesc = simpleDesc; + byte[] temp = new byte[MAX_CONNECTION_ID_LENGTH]; + RANDOM.nextBytes(temp); + scrambler = temp; + try { + KeyGenerator kg = KeyGenerator.getInstance("HmacSHA256"); + statelessTokenKey = kg.generateKey(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("HmacSHA256 key generator not available", e); + } + } + + /** + * The connection ID length used by this Quic instance. + * This is the source connection id length for outgoing packets, + * and the destination connection id length for incoming packets. + * @return the connection ID length used by this instance + */ + public int connectionIdLength() { + return connectionIdLength; + } + + /** + * Creates a new connection ID for a connection. + * @return a new connection ID + */ + public QuicConnectionId newConnectionId() { + long token = newToken(); + return new QuicLocalConnectionId(token, simpleDesc, + newConnectionId(connectionIdLength, token)); + } + + /** + * Quick validation to see if the buffer can contain a connection + * id generated by this instance. The byte buffer is expected to have + * its {@linkplain ByteBuffer#position() position} set at the start + * of the connection id, and its {@linkplain ByteBuffer#limit() limit} + * at the end. In other words, {@code Buffer.remaining()} should + * indicate the connection id length. + *

This method does not advance the buffer position, and + * returns a connection id that wraps the given buffer. + * The returned connection id is only safe to use as long as + * the buffer is not modified. + *

It is usually only used temporarily as a lookup key + * to locate an existing {@code QuicConnection}. + * + * @param buffer A buffer that delimits a connection id. + * @return a new QuicConnectionId if the buffer can contain + * a connection id generated by this instance, {@code null} + * otherwise. + */ + public QuicConnectionId unsafeConnectionIdFor(ByteBuffer buffer) { + int expectedLength = connectionIdLength; + + int remaining = buffer.remaining(); + if (remaining < MIN_CONNECTION_ID_LENGTH) return null; + if (remaining != expectedLength) return null; + + byte first = buffer.get(0); + int len = extractConnectionIdLength(first); + if (len < MIN_CONNECTION_ID_LENGTH) return null; + if (len > MAX_CONNECTION_ID_LENGTH) return null; + if (len != expectedLength) return null; + + long token = peekConnectionIdToken(buffer); + if (!isValidToken(token)) return null; + var cid = new QuicLocalConnectionId(buffer, token, simpleDesc); + assert cid.length() == expectedLength; + return cid; + } + + /** + * Returns a stateless reset token for the given connection ID + * @param connectionId connection ID + * @return stateless reset token for the given connection ID + * @throws IllegalArgumentException if the connection ID was not generated by this factory + */ + public byte[] statelessTokenFor(QuicConnectionId connectionId) { + if (!(connectionId instanceof QuicLocalConnectionId)) { + throw new IllegalArgumentException("Not a locally-generated connection ID"); + } + Mac mac; + try { + mac = Mac.getInstance("HmacSHA256"); + mac.init(statelessTokenKey); + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new RuntimeException("HmacSHA256 is not available", e); + } + byte[] result = mac.doFinal(connectionId.getBytes()); + return Arrays.copyOf(result, 16); + } + + // visible for testing + public long newToken() { + var token = tokens.incrementAndGet(); + if (token < 0) { + token = -token - 1; + wrapped = true; + } + return token; + } + + // visible for testing + public byte[] newConnectionId(int length, long token) { + length = Math.clamp(length, MIN_CONNECTION_ID_LENGTH, MAX_CONNECTION_ID_LENGTH); + assert length <= MAX_CONNECTION_ID_LENGTH; + assert length >= MIN_CONNECTION_ID_LENGTH; + byte[] bytes = new byte[length]; + RANDOM.nextBytes(bytes); + + if (token < 0) token = -token - 1; + assert token >= 0; + int len = variableLengthLength(token); + assert len < 8; + + bytes[0] = (byte) ((length << 3) & 0xF8); + bytes[0] = (byte) (bytes[0] | len); + assert (bytes[0] & 0x07) == len; + assert ((bytes[0] & 0xFF) >> 3) == length : + "%s != %s".formatted(bytes[0] & 0xFF, length); + int shift = 8 * len; + for (int i = 0; i <= len; i++) { + assert shift <= 56; + bytes[i + 1] = (byte) ((token >> shift) & 0xFF); + shift -= 8; + } + for (int i = 0; i < length; i++) { + bytes[i] = (byte) ((bytes[i] & 0xFF) ^ (scrambler[i] & 0xFF)); + } + + assert length == getConnectionIdLength(bytes); + assert token == getConnectionIdToken(bytes); + return bytes; + } + + // visible for testing + public int getConnectionIdLength(byte[] bytes) { + assert bytes.length >= MIN_CONNECTION_ID_LENGTH; + var length = extractConnectionIdLength(bytes[0]); + assert length <= MAX_CONNECTION_ID_LENGTH; + return length; + } + + // visible for testing + public long getConnectionIdToken(byte[] bytes) { + assert bytes.length >= MIN_CONNECTION_ID_LENGTH; + int len = extractTokenLength(bytes[0]); + long token = 0; + int shift = len * 8; + for (int i = 0; i <= len; i++) { + assert shift >= 0; + assert shift <= 56; + int j = i + 1; + long l = ((bytes[j] & 0xFF) ^ (scrambler[j] & 0xFF)) & 0xFF; + l = l << shift; + token += l; + shift -= 8; + } + assert token >= 0; + return token; + } + + private long peekConnectionIdToken(ByteBuffer bytes) { + assert bytes.remaining() >= MIN_CONNECTION_ID_LENGTH; + int len = extractTokenLength(bytes.get(0)); + long token = 0; + int shift = len * 8; + for (int i = 0; i <= len; i++) { + assert shift >= 0; + assert shift <= 56; + int j = i + 1; + long l = ((bytes.get(j) & 0xFF) ^ (scrambler[j] & 0xFF)) & 0xFF; + l = l << shift; + token += l; + shift -= 8; + } + return token; + } + + private boolean isValidToken(long token) { + if (token < 0) return false; + long prevToken = tokens.get(); + boolean wrapped = prevToken < 0 || this.wrapped; + // if `tokens` has wrapped, we can say nothing... + // otherwise, we can say it should not be coded on more bytes than + // the previous token that was distributed + if (!wrapped) { + return token <= prevToken; + } + return true; + } + + private int extractConnectionIdLength(byte b) { + var bits = ((b & 0xFF) ^ (scrambler[0] & 0xFF)) & 0xFF; + bits = bits >> 3; + return bits; + } + + private int extractTokenLength(byte b) { + var bits = ((b & 0xFF) ^ (scrambler[0] & 0xFF)) & 0xFF; + return bits & 0x07; + } + + private static int variableLengthLength(long token) { + assert token >= 0; + int len = 0; + int shift = 0; + for (int i = 1; i < 8; i++) { + shift += 8; + if ((token >> shift) == 0) break; + len++; + } + assert len < 8; + return len; + } + + /** + * Checks if {@code connId} looks like a connection ID we could possibly generate. + * If it does, returns a stateless reset datagram. + * @param connId the destination connection id that was received on the packet + * @param length maximum length of the stateless reset packet + * @return stateless reset datagram payload, or null + */ + public ByteBuffer statelessReset(ByteBuffer connId, int length) { + // 43 bytes max: + // first byte bits 01xx xxxx + // followed by random bytes + // terminated by 16 bytes reset token + length = Math.min(length, 43); + if (length < 21) { // minimum QUIC short datagram length + return null; + } + + var cid = (QuicLocalConnectionId)unsafeConnectionIdFor(connId); + if (cid != null) { + var localToken = statelessTokenFor(cid); + assert localToken != null; + ByteBuffer buf = ByteBuffer.allocate(length); + buf.put((byte)(0x40 + RANDOM.nextInt(0x40))); + byte[] random = new byte[length - 17]; + RANDOM.nextBytes(random); + buf.put(random); + buf.put(localToken); + assert !buf.hasRemaining() : buf.remaining(); + buf.flip(); + return buf; + } + return null; + } + + // A connection id generated by this instance. + private static final class QuicLocalConnectionId extends QuicConnectionId { + private final long token; + private final String simpleDesc; + + // Connection Ids created with this constructor are safer + // to use in maps as the buffer wraps a safe byte array in + // this constructor. + private QuicLocalConnectionId(long token, String simpleDesc, byte[] bytes) { + super(ByteBuffer.wrap(bytes)); + this.token = token; + this.simpleDesc = simpleDesc; + } + + // Connection Ids created with this constructor are only + // safe to use as long as the caller abstain from mutating + // the provided byte buffer. + // Typically, they will be transiently used to look up some + // connection in a map indexed by a connection id. + private QuicLocalConnectionId(ByteBuffer buffer, long token, String simpleDesc) { + super(buffer); + assert token >= 0; + this.token = token; + this.simpleDesc = simpleDesc; + } + + @Override + public String toString() { + return "%s(length=%s, token=%s, hash=%s)" + .formatted(simpleDesc, length(), token, hashCode); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionImpl.java new file mode 100644 index 00000000000..f05519d339b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicConnectionImpl.java @@ -0,0 +1,4353 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodHandles.Lookup; +import java.lang.invoke.VarHandle; +import java.net.ConnectException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; +import java.nio.channels.UnresolvedAddressException; +import java.security.SecureRandom; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLParameters; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.OrderedFlow.CryptoDataFlow; +import jdk.internal.net.http.quic.QuicEndpoint.QuicDatagram; +import jdk.internal.net.http.quic.QuicTransportParameters.VersionInformation; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.frames.ConnectionCloseFrame; +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.DataBlockedFrame; +import jdk.internal.net.http.quic.frames.HandshakeDoneFrame; +import jdk.internal.net.http.quic.frames.MaxDataFrame; +import jdk.internal.net.http.quic.frames.MaxStreamDataFrame; +import jdk.internal.net.http.quic.frames.MaxStreamsFrame; +import jdk.internal.net.http.quic.frames.NewConnectionIDFrame; +import jdk.internal.net.http.quic.frames.NewTokenFrame; +import jdk.internal.net.http.quic.frames.PaddingFrame; +import jdk.internal.net.http.quic.frames.PathChallengeFrame; +import jdk.internal.net.http.quic.frames.PathResponseFrame; +import jdk.internal.net.http.quic.frames.PingFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.ResetStreamFrame; +import jdk.internal.net.http.quic.frames.RetireConnectionIDFrame; +import jdk.internal.net.http.quic.frames.StopSendingFrame; +import jdk.internal.net.http.quic.frames.StreamDataBlockedFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; +import jdk.internal.net.http.quic.frames.StreamsBlockedFrame; +import jdk.internal.net.http.quic.packets.HandshakePacket; +import jdk.internal.net.http.quic.packets.InitialPacket; +import jdk.internal.net.http.quic.packets.LongHeader; +import jdk.internal.net.http.quic.packets.OneRttPacket; +import jdk.internal.net.http.quic.packets.PacketSpace; +import jdk.internal.net.http.quic.packets.QuicPacketDecoder; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder.OutgoingQuicPacket; +import jdk.internal.net.http.quic.packets.RetryPacket; +import jdk.internal.net.http.quic.packets.VersionNegotiationPacket; +import jdk.internal.net.http.quic.streams.CryptoWriterQueue; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicBidiStreamImpl; +import jdk.internal.net.http.quic.streams.QuicConnectionStreams; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStream.StreamState; +import jdk.internal.net.http.quic.streams.QuicStreams; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicOneRttContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTLSEngine.HandshakeState; +import jdk.internal.net.quic.QuicTLSEngine.KeySpace; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.http.quic.QuicTransportParameters.ParameterId; +import jdk.internal.net.quic.QuicVersion; + +import static jdk.internal.net.http.quic.QuicClient.INITIAL_SERVER_CONNECTION_ID_LENGTH; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.active_connection_id_limit; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_data; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_local; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_remote; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_uni; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_streams_bidi; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_streams_uni; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_source_connection_id; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.max_idle_timeout; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.max_udp_payload_size; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.version_information; +import static jdk.internal.net.http.quic.TerminationCause.forException; +import static jdk.internal.net.http.quic.TerminationCause.forTransportError; +import static jdk.internal.net.http.quic.QuicConnectionId.MAX_CONNECTION_ID_LENGTH; +import static jdk.internal.net.http.quic.QuicRttEstimator.MAX_PTO_BACKOFF_TIMEOUT; +import static jdk.internal.net.http.quic.QuicRttEstimator.MIN_PTO_BACKOFF_TIMEOUT; +import static jdk.internal.net.http.quic.frames.QuicFrame.MAX_VL_INTEGER; +import static jdk.internal.net.http.quic.packets.QuicPacketNumbers.computePacketNumberLength; +import static jdk.internal.net.http.quic.streams.QuicStreams.isUnidirectional; +import static jdk.internal.net.http.quic.streams.QuicStreams.streamType; +import static jdk.internal.net.quic.QuicTransportErrors.PROTOCOL_VIOLATION; + +/** + * This class implements a QUIC connection. + * A QUIC connection is established between a client and a server over a + * QuicEndpoint endpoint. + * A QUIC connection can then multiplex multiple QUIC streams to the same server. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9001 + * RFC 9001: Using TLS to Secure QUIC + * @spec https://www.rfc-editor.org/info/rfc9002 + * RFC 9002: QUIC Loss Detection and Congestion Control + */ +public class QuicConnectionImpl extends QuicConnection implements QuicPacketReceiver { + + private static final int MAX_IPV6_MTU = 65527; + private static final int MAX_IPV4_MTU = 65507; + + // Quic assumes a minimum packet size of 1200 + // See https://www.rfc-editor.org/rfc/rfc9000#name-datagram-size + public static final int SMALLEST_MAXIMUM_DATAGRAM_SIZE = + QuicClient.SMALLEST_MAXIMUM_DATAGRAM_SIZE; + + public static final int DEFAULT_MAX_INITIAL_TIMEOUT = Math.clamp( + Utils.getIntegerProperty("jdk.httpclient.quic.maxInitialTimeout", 30), + 1, Integer.MAX_VALUE); + public static final long DEFAULT_INITIAL_MAX_DATA = Math.clamp( + Utils.getLongProperty("jdk.httpclient.quic.maxInitialData", 15 << 20), + 0, 1L << 60); + public static final long DEFAULT_INITIAL_STREAM_MAX_DATA = Math.clamp( + Utils.getIntegerProperty("jdk.httpclient.quic.maxStreamInitialData", 6 << 20), + 0, 1L << 60); + public static final int DEFAULT_MAX_BIDI_STREAMS = + Utils.getIntegerProperty("jdk.httpclient.quic.maxBidiStreams", 100); + public static final int DEFAULT_MAX_UNI_STREAMS = + Utils.getIntegerProperty("jdk.httpclient.quic.maxUniStreams", 100); + public static final boolean USE_DIRECT_BUFFER_POOL = Utils.getBooleanProperty( + "jdk.internal.httpclient.quic.poolDirectByteBuffers", !QuicEndpoint.DGRAM_SEND_ASYNC); + + public static final int RESET_TOKEN_LENGTH = 16; // RFC states 16 bytes for stateless token + public static final long MAX_STREAMS_VALUE_LIMIT = 1L << 60; // cannot exceed 2^60 as per RFC + + // VarHandle provide the same atomic compareAndSet functionality + // than AtomicXXXXX classes, but without the additional cost in + // footprint. + private static final VarHandle VERSION_NEGOTIATED; + private static final VarHandle STATE; + private static final VarHandle MAX_SND_DATA; + private static final VarHandle MAX_RCV_DATA; + public static final int DEFAULT_DATAGRAM_SIZE; + private static final int MAX_INCOMING_CRYPTO_CAPACITY = 64 << 10; + + static { + try { + Lookup lookup = MethodHandles.lookup(); + VERSION_NEGOTIATED = lookup + .findVarHandle(QuicConnectionImpl.class, "versionNegotiated", boolean.class); + STATE = lookup.findVarHandle(QuicConnectionImpl.class, "state", int.class); + MAX_SND_DATA = lookup.findVarHandle(OneRttFlowControlledSendingQueue.class, "maxData", long.class); + MAX_RCV_DATA = lookup.findVarHandle(OneRttFlowControlledReceivingQueue.class, "maxData", long.class); + } catch (Exception x) { + throw new ExceptionInInitializerError(x); + } + int size = Utils.getIntegerProperty("jdk.httpclient.quic.defaultMTU", + SMALLEST_MAXIMUM_DATAGRAM_SIZE); + // don't allow the value to be below 1200 and above 65527, to conform with RFC-9000, + // section 18.2: + // The default for this parameter is the maximum permitted UDP payload of 65527. + // Values below 1200 are invalid. + if (size < SMALLEST_MAXIMUM_DATAGRAM_SIZE || size > MAX_IPV6_MTU) { + // fallback to SMALLEST_MAXIMUM_DATAGRAM_SIZE + size = SMALLEST_MAXIMUM_DATAGRAM_SIZE; + } + DEFAULT_DATAGRAM_SIZE = size; + } + + protected final Logger debug = Utils.getDebugLogger(this::dbgTag); + + final QuicRttEstimator rttEstimator = new QuicRttEstimator(); + final QuicCongestionController congestionController; + /** + * The state of the quic connection. + * The handshake is confirmed when HANDSHAKE_DONE has been received, + * or when the first 1-RTT packet has been successfully decrypted. + * See RFC 9001 section 4.1.2 + * https://www.rfc-editor.org/rfc/rfc9001#name-handshake-confirmed + */ + private final StateHandle stateHandle = new StateHandle(); + private final AtomicBoolean startHandshakeCalled = new AtomicBoolean(); + private final InetSocketAddress peerAddress; + private final QuicInstance quicInstance; + private final String dbgTag; + private final QuicTLSEngine quicTLSEngine; + private final CodingContext codingContext; + private final PacketSpaces packetSpaces; + private final OneRttFlowControlledSendingQueue oneRttSndQueue = + new OneRttFlowControlledSendingQueue(); + private final OneRttFlowControlledReceivingQueue oneRttRcvQueue = + new OneRttFlowControlledReceivingQueue(this::logTag); + protected final QuicConnectionStreams streams; + protected final Queue outgoing1RTTFrames = new ConcurrentLinkedQueue<>(); + // for one-rtt crypto data (session tickets) + private final CryptoDataFlow peerCryptoFlow = new CryptoDataFlow(); + private final CryptoWriterQueue localCryptoFlow = new CryptoWriterQueue(); + private final HandshakeFlow handshakeFlow = new HandshakeFlow(); + final ConnectionTerminatorImpl terminator; + protected final IdleTimeoutManager idleTimeoutManager; + protected final QuicTransportParameters transportParams; + // the initial (local) connection ID + private final QuicConnectionId connectionId; + private final PeerConnIdManager peerConnIdManager; + private final LocalConnIdManager localConnIdManager; + private volatile QuicConnectionId incomingInitialPacketSourceId; + protected final QuicEndpoint endpoint; + private volatile QuicTransportParameters localTransportParameters; + private volatile QuicTransportParameters peerTransportParameters; + private volatile byte[] initialToken; + // the number of (active) connection ids the peer is willing to accept for a given connection + private volatile long peerActiveConnIdsLimit = 2; // default is 2 as per RFC + + private volatile int state; + // the quic version currently in use + private volatile QuicVersion quicVersion; + // the quic version from the first packet + private final QuicVersion originalVersion; + private volatile QuicPacketDecoder decoder; + private volatile QuicPacketEncoder encoder; + // (client-only) if true, we no longer accept VERSIONS packets + private volatile boolean versionCompatible; + // if true, we no longer accept version changes + private volatile boolean versionNegotiated; + // true if we changed version in response to VERSIONS packet + private volatile boolean processedVersionsPacket; + // start off with 1200 or whatever is configured through + // jdk.net.httpclient.quic.defaultPDU system property + private int maxPeerAdvertisedPayloadSize = DEFAULT_DATAGRAM_SIZE; + // max MTU size on the connection: either MAX_IPV4_MTU or MAX_IPV6_MTU, + // depending on whether the peer address is IPv6 or IPv4 + private final int maxConnectionMTU; + // we start with a pathMTU that is 1200 or whatever is configured through + // jdk.net.httpclient.quic.defaultPDU system property + private int pathMTU = DEFAULT_DATAGRAM_SIZE; + private final SequentialScheduler handshakeScheduler = + SequentialScheduler.lockingScheduler(this::continueHandshake0); + private final ReentrantLock handshakeLock = new ReentrantLock(); + private final String cachedToString; + private final String logTag; + private final long labelId; + // incoming PATH_CHALLENGE frames waiting for PATH_RESPONSE + private final Queue pathChallengeFrameQueue = new ConcurrentLinkedQueue<>(); + + private volatile MaxInitialTimer maxInitialTimer; + + static String dbgTag(QuicInstance quicInstance, String logTag) { + return String.format("QuicConnection(%s, %s)", + quicInstance.instanceId(), logTag); + } + + protected QuicConnectionImpl(final QuicVersion firstFlightVersion, + final QuicInstance quicInstance, + final InetSocketAddress peerAddress, + final String peerName, + final int peerPort, + final SSLParameters sslParameters, + final String logTagFormat, + final long labelId) { + this.labelId = labelId; + this.quicInstance = Objects.requireNonNull(quicInstance, "quicInstance"); + try { + this.endpoint = quicInstance.getEndpoint(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + this.peerAddress = peerAddress; + this.maxConnectionMTU = peerAddress.getAddress() instanceof Inet6Address + ? MAX_IPV6_MTU + : MAX_IPV4_MTU; + this.pathMTU = Math.clamp(DEFAULT_DATAGRAM_SIZE, SMALLEST_MAXIMUM_DATAGRAM_SIZE, maxConnectionMTU); + this.cachedToString = String.format(logTagFormat.formatted("quic:%s:%s:%s"), labelId, + Arrays.toString(sslParameters.getApplicationProtocols()), peerAddress); + this.connectionId = this.endpoint.idFactory().newConnectionId(); + this.logTag = logTagFormat.formatted(labelId); + this.dbgTag = dbgTag(quicInstance, logTag); + this.congestionController = new QuicRenoCongestionController(dbgTag); + this.originalVersion = this.quicVersion = firstFlightVersion == null + ? QuicVersion.firstFlightVersion(quicInstance.getAvailableVersions()) + : firstFlightVersion; + final boolean isClientConn = isClientConnection(); + this.peerConnIdManager = new PeerConnIdManager(this, dbgTag); + this.localConnIdManager = new LocalConnIdManager(this, dbgTag, connectionId); + this.decoder = QuicPacketDecoder.of(this.quicVersion); + this.encoder = QuicPacketEncoder.of(this.quicVersion); + this.codingContext = new QuicCodingContext(); + final QuicTLSEngine engine = this.quicInstance.getQuicTLSContext() + .createEngine(peerName, peerPort); + engine.setUseClientMode(isClientConn); + engine.setSSLParameters(sslParameters); + this.quicTLSEngine = engine; + quicTLSEngine.setRemoteQuicTransportParametersConsumer(this::consumeQuicParameters); + packetSpaces = PacketSpaces.forConnection(this); + quicTLSEngine.setOneRttContext(packetSpaces.getOneRttContext()); + streams = new QuicConnectionStreams(this, debug); + if (quicInstance instanceof QuicClient quicClient) { + // use the (INITIAL) token that a server might have sent to this client (through + // NEW_TOKEN frame) on a previous connection against that server + this.initialToken = quicClient.initialTokenFor(this.peerAddress); + } + terminator = new ConnectionTerminatorImpl(this); + idleTimeoutManager = new IdleTimeoutManager(this); + transportParams = quicInstance.getTransportParameters() == null + ? new QuicTransportParameters() + : quicInstance.getTransportParameters(); + if (debug.on()) debug.log("Quic Connection Created"); + } + + @Override + public final long uniqueId() { + return labelId; + } + + /** + * An abstraction to represent the connection state as a bit mask. + * This is not an enum as some stages can overlap. + */ + public abstract static class QuicConnectionState { + public static final int + NEW = 0, // the connection is new + HISENT = 1, // first initial hello packet sent + HSCOMPLETE = 16, // handshake completed + CLOSING = 128, // connection has entered "Closing" state as defined in RFC-9000 + DRAINING = 256, // connection has entered "Draining" state as defined in RFC-9000 + CLOSED = 512; // CONNECTION_CLOSE ACK sent or received + public abstract int state(); + public boolean helloSent() {return isMarked(HISENT);} + public boolean handshakeComplete() { return isMarked(HSCOMPLETE);} + public boolean closing() { return isMarked(CLOSING);} + public boolean draining() { return isMarked(DRAINING);} + public boolean opened() { return (state() & (CLOSED | DRAINING | CLOSING)) == 0; } + public boolean isMarked(int mask) { return isMarked(state(), mask); } + public String toString() { return toString(state()); } + public static boolean isMarked(int state, int mask) { + return mask == 0 ? state == 0 : (state & mask) == mask; + } + public static String toString(int state) { + if (state == NEW) return "new"; + if (isMarked(state, CLOSED)) return "closed"; + if (isMarked(state, DRAINING)) return "draining"; + if (isMarked(state, CLOSING)) return "closing"; + if (isMarked(state, HSCOMPLETE)) return "handshakeComplete"; + if (isMarked(state, HISENT)) return "helloSent"; + return "Unknown(" + state + ")"; + } + } + + /** + * A {link QuicTimedEvent} used to interrupt the handshake + * if no response to the first initial packet is received within + * a reasonable delay (default is ~ 30s). + * This avoids waiting more than 30s for ConnectionException + * to be raised if no server is available at the peer address. + * This class is only used on the client side. + */ + final class MaxInitialTimer implements QuicTimedEvent { + private final Deadline maxInitialDeadline; + private final QuicTimerQueue timerQueue; + private final long eventId; + private volatile Deadline deadline; + private volatile boolean initialPacketReceived; + private volatile boolean connectionClosed; + + // optimization: if done is true it avoids volatile read + // of initialPacketReceived and/or connectionClosed + // from initialPacketReceived() + private boolean done; + private MaxInitialTimer(QuicTimerQueue timerQueue, Deadline maxDeadline) { + this.eventId = QuicTimerQueue.newEventId(); + this.timerQueue = timerQueue; + maxInitialDeadline = deadline = maxDeadline; + assert isClientConnection() : "MaxInitialTimer should only be used on QuicClients"; + } + + /** + * Called when an initial packet is received from the + * peer. At this point the MaxInitialTimer is disarmed, + * and further calls to this method are no-op. + */ + void initialPacketReceived() { + if (done) return; // races are OK - avoids volatile read + boolean firsPacketReceived = initialPacketReceived; + boolean closed = connectionClosed; + if (done = (firsPacketReceived || closed)) return; + initialPacketReceived = true; + if (debug.on()) { + debug.log("Quic initial timer disarmed after %s seconds", + DEFAULT_MAX_INITIAL_TIMEOUT - + Deadline.between(now(), maxInitialDeadline).toSeconds()); + } + if (!closed) { + // rescheduling with Deadline.MAX will take the + // MaxInitialTimer out of the timer queue. + timerQueue.reschedule(this, Deadline.MAX); + } + } + + @Override + public Deadline deadline() { + return deadline; + } + + /** + * This method is called if the timer expires. + * If no initial packet has been received ( + * {@link #initialPacketReceived()} was never called), + * the connection's handshakeCF is completed with a + * {@link ConnectException}. + * Calling this method a second time is a no-op. + * @return {@link Deadline#MAX}, always. + */ + @Override + public Deadline handle() { + if (done) return Deadline.MAX; + boolean firsPacketReceived = initialPacketReceived; + boolean closed = connectionClosed; + if (!firsPacketReceived && !closed) { + assert !now().isBefore(maxInitialDeadline); + var connectException = new ConnectException("No response from peer for %s seconds" + .formatted(DEFAULT_MAX_INITIAL_TIMEOUT)); + if (QuicConnectionImpl.this.handshakeFlow.handshakeCF() + .completeExceptionally(connectException)) { + // abandon the connection, but sends ConnectionCloseFrame + TerminationCause cause = TerminationCause.forException( + new QuicTransportException(connectException.getMessage(), + KeySpace.INITIAL, 0, QuicTransportErrors.APPLICATION_ERROR)); + terminator.terminate(cause); + } + connectionClosed = done = closed = true; + } + assert firsPacketReceived || closed; + return Deadline.MAX; + } + + @Override + public long eventId() { + return eventId; + } + + @Override + public Deadline refreshDeadline() { + boolean firstPacketReceived = initialPacketReceived; + boolean closed = connectionClosed; + Deadline newDeadlne = deadline; + if (closed || firstPacketReceived) newDeadlne = deadline = Deadline.MAX; + return newDeadlne; + } + + private Deadline now() { + return QuicConnectionImpl.this.endpoint().timeSource().instant(); + } + } + + /** + * A state handle is a mutable implementation of {@link QuicConnectionState} + * that allows to view the volatile connection int variable {@code state} as + * a {@code QuicConnectionState}, and provides methods to mutate it in + * a thread safe atomic way. + */ + protected final class StateHandle extends QuicConnectionState { + public int state() { return state;} + + /** + * Updates the state to a new state value with the passed bit {@code mask} set. + * + * @param mask The state mask + * @return true if previously the state value didn't have the {@code mask} set and this + * method successfully updated the state value to set the {@code mask} + */ + final boolean mark(final int mask) { + int state, desired; + do { + state = desired = state(); + if ((state & mask) == mask) return false; // already set + desired = state | mask; + } while (!STATE.compareAndSet(QuicConnectionImpl.this, state, desired)); + return true; // compareAndSet switched the old state to the desired state + } + public boolean markHelloSent() { return mark(HISENT); } + public boolean markHandshakeComplete() { return mark(HSCOMPLETE); } + } + + /** + * Keeps track of: + * - handshakeCF the handshake completable future + * - localInitial the local initial crypto writer queue + * - peerInitial the peer initial crypto flow + * - localHandshake the local handshake crypto queue + * - peerHandshake the peer handshake crypto flow + */ + protected final class HandshakeFlow { + + private final CompletableFuture handshakeCF; + // a CompletableFuture which will get completed when the handshake initiated locally, + // has "reached" the peer i.e. when the peer acknowledges or replies to the first + // INITIAL packet sent by an endpoint + final CompletableFuture handshakeReachedPeerCF; + private final CryptoWriterQueue localInitial = new CryptoWriterQueue(); + private final CryptoDataFlow peerInitial = new CryptoDataFlow(); + private final CryptoWriterQueue localHandshake = new CryptoWriterQueue(); + private final CryptoDataFlow peerHandshake = new CryptoDataFlow(); + private final AtomicBoolean handshakeStarted = new AtomicBoolean(); + + private HandshakeFlow() { + this.handshakeCF = new MinimalFuture<>(); + this.handshakeReachedPeerCF = new MinimalFuture<>(); + // ensure that the handshakeReachedPeerCF gets completed exceptionally + // if an exception is raised before the first INITIAL packet is + // acked by the peer. + handshakeCF.whenComplete((r, t) -> { + if (Log.quicHandshake()) { + Log.logQuic("{0} handshake completed {1}", + logTag(), + t == null ? "successfully" : ("exceptionally: " + t)); + } + if (t != null) { + handshakeReachedPeerCF.completeExceptionally(t); + } + }); + } + + /** + * {@return the CompletableFuture representing a handshake} + */ + public CompletableFuture handshakeCF() { + return this.handshakeCF; + } + + public void failHandshakeCFs(final Throwable cause) { + assert cause != null : "missing cause when failing handshake CFs"; + SSLHandshakeException sslHandshakeException = null; + if (!handshakeCF.isDone()) { + sslHandshakeException = sslHandshakeException(cause); + handshakeCF.completeExceptionally(sslHandshakeException); + } + if (!handshakeReachedPeerCF.isDone()) { + if (sslHandshakeException == null) { + sslHandshakeException = sslHandshakeException(cause); + } + handshakeReachedPeerCF.completeExceptionally(sslHandshakeException); + } + } + + private SSLHandshakeException sslHandshakeException(final Throwable cause) { + if (cause instanceof SSLHandshakeException ssl) { + return ssl; + } + return new SSLHandshakeException("QUIC connection establishment failed", cause); + } + + /** + * Marks the start of a handshake. + * @throws IllegalStateException If handshake has already started + */ + private void markHandshakeStart() { + if (!handshakeStarted.compareAndSet(false, true)) { + throw new IllegalStateException("Handshake has already started on " + + QuicConnectionImpl.this); + } + } + } + + public record PacketSpaces(PacketSpace initial, PacketSpace handshake, PacketSpace app) { + public PacketSpace get(PacketNumberSpace pnspace) { + return switch (pnspace) { + case INITIAL -> initial(); + case HANDSHAKE -> handshake(); + case APPLICATION -> app(); + default -> throw new IllegalArgumentException(String.valueOf(pnspace)); + }; + } + + private QuicOneRttContext getOneRttContext() { + final var appPacketSpaceMgr = app(); + assert appPacketSpaceMgr instanceof QuicOneRttContext + : "unexpected 1-RTT packet space manager"; + return (QuicOneRttContext) appPacketSpaceMgr; + } + + public static PacketSpaces forConnection(final QuicConnectionImpl connection) { + final var initialPktSpaceMgr = new PacketSpaceManager(connection, PacketNumberSpace.INITIAL); + return new PacketSpaces(initialPktSpaceMgr, + new PacketSpaceManager.HandshakePacketSpaceManager(connection, initialPktSpaceMgr), + new PacketSpaceManager.OneRttPacketSpaceManager(connection)); + } + + public void close() { + initial.close(); + handshake.close(); + app.close(); + } + } + + private final ConcurrentLinkedQueue incoming = new ConcurrentLinkedQueue<>(); + private final SequentialScheduler incomingLoopScheduler = + SequentialScheduler.lockingScheduler(this::incoming); + + + /* + * delegate handling of the datagrams to the executor to free up + * the endpoint readLoop. Helps with processing ACKs in a more + * timely fashion, which avoids too many retransmission. + * The endpoint readLoop runs on a single thread, while this loop + * will have one thread per connection which helps with a better + * utilization of the system resources. + */ + private void scheduleForDecryption(IncomingDatagram datagram) { + // Processes an incoming encrypted packet that has just been + // read off the network. + var received = datagram.buffer.remaining(); + if (incomingLoopScheduler.isStopped()) { + if (debug.on()) { + debug.log("scheduleForDecryption closed: dropping datagram (%d bytes)", + received); + } + return; + } + if (debug.on()) { + debug.log("scheduleForDecryption: %d bytes", received); + } + endpoint.buffer(received); + incoming.add(datagram); + + incomingLoopScheduler.runOrSchedule(quicInstance().executor()); + } + + private void incoming() { + try { + IncomingDatagram datagram; + while ((datagram = incoming.poll()) != null) { + ByteBuffer buffer = datagram.buffer; + int remaining = buffer.remaining(); + try { + if (incomingLoopScheduler.isStopped()) { + // we still need to unbuffer, continue here will + // ensure we skip directly to the finally-block + // below. + continue; + } + + internalProcessIncoming(datagram.source(), + datagram.destConnId(), + datagram.headersType(), + datagram.buffer()); + } catch (Throwable t) { + if (Log.errors() || debug.on()) { + String msg = "Failed to process datagram: " + t; + Log.logError(logTag() + " " + msg); + debug.log(msg, t); + } + } finally { + endpoint.unbuffer(remaining); + } + } + } catch (Throwable t) { + terminator.terminate(TerminationCause.forException(t)); + } + } + + /** + * Schedule an incoming quic packet for decryption. + * The ByteBuffer should contain a single packet, and its + * limit should be set at the end of the packet. + * + * @param buffer a byte buffer containing the incoming packet + */ + private void decrypt(ByteBuffer buffer) { + // Processes an incoming encrypted packet that has just been + // read off the network. + PacketType packetType = decoder.peekPacketType(buffer); + var received = buffer.remaining(); + var pos = buffer.position(); + if (debug.on()) { + debug.log("decrypt %s(pos=%d, remaining=%d)", + packetType, pos, received); + } + try { + assert packetType != PacketType.VERSIONS; + var packet = codingContext.parsePacket(buffer); + if (packet != null) { + processDecrypted(packet); + } else { + if (packetType == PacketType.HANDSHAKE) { + packetSpaces.initial.fastRetransmit(); + } + } + } catch (QuicTransportException qte) { + // close the connection on this fatal error + if (Log.errors() || debug.on()) { + final String msg = "closing connection due to error while decoding" + + " packet (type=" + packetType + "): " + qte; + Log.logError(logTag() + " " + msg); + debug.log(msg, qte); + } + terminator.terminate(TerminationCause.forException(qte)); + } catch (Throwable t) { + if (Log.errors() || debug.on()) { + String msg = "Failed to decode packet (type=" + packetType + "): " + t; + Log.logError(logTag() + " " + msg); + debug.log(msg, t); + } + } + } + + public void closeIncoming() { + incomingLoopScheduler.stop(); + IncomingDatagram icd; + // we still need to unbuffer all datagrams in the queue + while ((icd = incoming.poll()) != null ) { + endpoint.unbuffer(icd.buffer().remaining()); + } + } + + /** + * A protection record contains a packet to encrypt, and a datagram that may already + * contain encrypted packets. The firstPacketOffset indicates the position of the + * first encrypted packet in the datagram. The packetOffset indicates the position + * at which this packet will be - or has been - written in the datagram. + * Before the packet is encrypted and written to the datagram, the packetOffset + * should be the same as the datagram buffer position. + * After the packet has been written, the packetOffset should indicate + * at which position the packet has been written. The datagram position + * indicates where to write the next packet. + *

+ * Additionally, a {@code ProtectionRecord} may carry some flags indicating the + * intended usage of the datagram. The following flags are supported: + *

    + *
  • {@link #SINGLE_PACKET}: the default - it is not expected that the + * datagram will contain more packets
  • + *
  • {@link #COALESCED}: should be used if it is expected that the + * datagram will contain more than one packet
  • + *
  • {@link #LAST_PACKET}: should be used in conjunction with {@link #COALESCED} + * to indicate that the packet being protected is the last that will be + * added to the datagram
  • + *
+ * + * @apiNote + * Flag values can be combined, but some combinations + * may not make sense. A single packet can also be identified as any + * packet that doesn't have the {@code COALESCED} bit on. + * The flag is used to convey information that may be used to figure + * out whether to send the datagram right away, or whether to wait for + * more packet to be coalesced inside it. + * + * @param packet the packet to encrypt + * @param datagram the datagram in which the encrypted packet should be written + * @param firstPacketOffset the position of the first encrypted packet in the datagram + * @param packetOffset the offset at which the packet should be / has been written in the datagram + * @param flags a bit mask containing some details about the datagram being sent out + */ + public record ProtectionRecord(QuicPacket packet, ByteBuffer datagram, + int firstPacketOffset, int packetOffset, + long retransmittedPacketNumber, int flags) { + /** + * This is the default. + * This protection record is adding a single packet to be sent into + * the datagram and the datagram can be sent as soon as the packet + * has been encrypted. + */ + public static final int SINGLE_PACKET = 0; + /** + * This can be used when it is expected that more than one packet + * will be added to this datagram. We should wait until the last packet + * has been added before sending the datagram out. + */ + public static final int COALESCED = 1; + /** + * This protection record is adding the last packet to be sent into + * the datagram and the datagram can be sent as soon as the packet + * has been encrypted. + */ + public static final int LAST_PACKET = 2; + + // indicate that the packet is not retransmitted + private static final long NOT_RETRANSMITTED = -1L; + + ProtectionRecord withOffset(int packetOffset) { + if (this.packetOffset == packetOffset) { + return this; + } + return new ProtectionRecord(packet, datagram, firstPacketOffset, + packetOffset, retransmittedPacketNumber, flags); + } + + public ProtectionRecord encrypt(final CodingContext codingContext) + throws QuicKeyUnavailableException, QuicTransportException { + final PacketType packetType = packet.packetType(); + assert packetType != PacketType.VERSIONS; + // keep track of position before encryption + final int preEncryptPos = datagram.position(); + codingContext.writePacket(packet, datagram); + final ProtectionRecord encrypted = withOffset(preEncryptPos); + return encrypted; + } + + /** + * Records the intent of protecting a packet that will be sent as soon + * as it has been encrypted, without waiting for more packets to be + * coalesced into the datagram. + * + * @param packet the packet to protect + * @param allocator an allocator to allocate the datagram + * @return a protection record to submit for packet protection + */ + public static ProtectionRecord single(QuicPacket packet, + Function allocator) { + ByteBuffer datagram = allocator.apply(packet); + int offset = datagram.position(); + return new ProtectionRecord(packet, datagram, + offset, offset, NOT_RETRANSMITTED, 0); + } + + /** + * Records the intent of protecting a packet that retransmits + * a previously transmitted packet. The packet will be sent as soon + * as it has been encrypted, without waiting for more packets to be + * coalesced into the datagram. + * + * @param packet the packet to protect + * @param retransmittedPacketNumber the packet number of the original + * packet that was considered lost + * @param allocator an allocator to allocate the datagram + * @return a protection record to submit for packet protection + */ + public static ProtectionRecord retransmitting(QuicPacket packet, + long retransmittedPacketNumber, + Function allocator) { + ByteBuffer datagram = allocator.apply(packet); + int offset = datagram.position(); + return new ProtectionRecord(packet, datagram, offset, offset, + retransmittedPacketNumber, 0); + } + + /** + * Records the intent of protecting a packet that will be followed by + * more packets to be coalesced in the same datagram. The datagram + * should not be sent until the last packet has been coalesced. + * + * @param packet the packet to protect + * @param datagram the datagram in which packet will be coalesced + * @param firstPacketOffset the offset of the first packet in the datagram + * @return a protection record to submit for packet protection + */ + public static ProtectionRecord more(QuicPacket packet, ByteBuffer datagram, int firstPacketOffset) { + return new ProtectionRecord(packet, datagram, firstPacketOffset, + datagram.position(), NOT_RETRANSMITTED, COALESCED); + } + + /** + * Records the intent of protecting the last packet that will be + * coalesced in the given datagram. The datagram can be sent as soon + * as the packet has been encrypted and coalesced into the given + * datagram. + * + * @param packet the packet to protect + * @param datagram the datagram in which packet will be coalesced + * @param firstPacketOffset the offset of the first packet in the datagram + * @return a protection record to submit for packet protection + */ + public static ProtectionRecord last(QuicPacket packet, ByteBuffer datagram, int firstPacketOffset) { + return new ProtectionRecord(packet, datagram, firstPacketOffset, + datagram.position(), NOT_RETRANSMITTED, LAST_PACKET | COALESCED); + } + } + + final QuicPacket newQuicPacket(final KeySpace keySpace, final List frames) { + final PacketSpace packetSpace = packetSpaces.get(PacketNumberSpace.of(keySpace)); + return encoder.newOutgoingPacket(keySpace, packetSpace, + localConnectionId(), peerConnectionId(), initialToken(), + frames, + codingContext); + } + + /** + * Encrypt an outgoing quic packet. + * The ProtectionRecord indicates the position at which the encrypted packet + * should be written in the datagram, as well as the position of the + * first packet in the datagram. After encrypting the packet, this method calls + * {@link #pushEncryptedDatagram(ProtectionRecord)} + * + * @param protectionRecord a record containing a quic packet to encrypt, + * a destination byte buffer, and various offset information. + */ + final void pushDatagram(final ProtectionRecord protectionRecord) + throws QuicKeyUnavailableException, QuicTransportException { + final QuicPacket packet = protectionRecord.packet(); + if (debug.on()) { + debug.log("encrypting packet into datagram %s(pn:%s, %s)", packet.packetType(), + packet.packetNumber(), packet.frames()); + } + // Processes an outgoing unencrypted packet that needs to be + // encrypted before being packaged in a datagram. + final ProtectionRecord encrypted; + try { + encrypted = protectionRecord.encrypt(codingContext); + } catch (Throwable e) { + // release the datagram ByteBuffer on failure to encrypt + datagramDiscarded(new QuicDatagram(this, peerAddress, protectionRecord.datagram())); + if (Log.errors()) { + Log.logError("Failed to encrypt packet: " + e); + // certain failures like key not being available are OK + // in some situations. log the stacktrace only if this + // was an unexpected failure. + boolean skipStackTrace = false; + if (e instanceof QuicKeyUnavailableException) { + final PacketSpace packetSpace = packetSpace(protectionRecord.packet().numberSpace()); + skipStackTrace = packetSpace.isClosed(); + } + if (!skipStackTrace) { + Log.logError(e); + } + } + throw e; + } + // we currently don't support a ProtectionRecord with more than one QuicPacket + assert (encrypted.flags & ProtectionRecord.COALESCED) == 0 : "coalesced packets not supported"; + // encryption of the datagram is complete, now push the encrypted + // datagram through the endpoint + if (Log.quicPacketOutLoggable(packet)) { + Log.logQuicPacketOut(logTag(), packet); + } + pushEncryptedDatagram(encrypted); + } + + protected void completeHandshakeCF() { + // This can be called from the decrypt loop, and can trigger + // sending of 1-RTT application data from within the same + // thread: we use an executor here to avoid running the application + // sending loop from within the Quic decrypt loop. + completeHandshakeCF(quicInstance().executor()); + } + + protected final void completeHandshakeCF(Executor executor) { + final var handshakeCF = handshakeFlow.handshakeCF(); + if (handshakeCF.isDone()) { + return; + } + var handshakeState = quicTLSEngine.getHandshakeState(); + if (executor != null) { + handshakeCF.completeAsync(() -> handshakeState, executor); + } else { + handshakeCF.complete(handshakeState); + } + } + + /** + * A class used to check that 1-RTT received data doesn't exceed + * the MAX_DATA of the connection + */ + class OneRttFlowControlledReceivingQueue { + private static final long MIN_BUFFER_SIZE = 16L << 10; // 16k + private volatile long receivedData; + private volatile long maxData; + private volatile long processedData; + // Desired buffer size; used when updating maxStreamData + private final long desiredBufferSize = Math.clamp(DEFAULT_INITIAL_MAX_DATA, MIN_BUFFER_SIZE, MAX_VL_INTEGER); + private final Supplier logTag; + + OneRttFlowControlledReceivingQueue(Supplier logTag) { + this.logTag = Objects.requireNonNull(logTag); + } + + /** + * Called when new local parameters are available + * @param localParameters the new local paramaters + */ + void newLocalParameters(QuicTransportParameters localParameters) { + if (localParameters.isPresent(ParameterId.initial_max_data)) { + long maxData = this.maxData; + long newMaxData = localParameters.getIntParameter(ParameterId.initial_max_data); + while (maxData < newMaxData) { + if (MAX_RCV_DATA.compareAndSet(this, maxData, newMaxData)) break; + maxData = this.maxData; + } + } + } + + /** + * Checks whether the give frame would cause the connection max data + * to be exceeded. If no, increase the amount of data processed by + * this connection by the length of the frame. If yes, sends a + * ConnectionCloseFrame with FLOW_CONTROL_ERROR. + * + * @param diff number of bytes newly received + * @param frameType type of frame received + * @throws QuicTransportException if processing this frame would cause the connection + * max data to be exceeded + */ + void checkAndIncreaseReceivedData(long diff, long frameType) throws QuicTransportException { + assert diff > 0; + long max, processed; + boolean exceeded; + synchronized (this) { + max = maxData; + processed = receivedData; + if (max - processed < diff) { + exceeded = true; + } else { + try { + receivedData = processed = Math.addExact(processed, diff); + exceeded = false; + } catch (ArithmeticException x ) { + // should not happen - flow control should have + // caught that + receivedData = processed = Long.MAX_VALUE; + exceeded = true; + } + } + } + if (exceeded) { + String reason = "Connection max data exceeded: max data processed=%s, max connection data=%s" + .formatted(processed, max); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, frameType, QuicTransportErrors.FLOW_CONTROL_ERROR); + } + } + + public void increaseProcessedData(long diff) { + long processed, received, max; + synchronized (this) { + processed = processedData += diff; + received = receivedData; + max = maxData; + } + if (Log.quicProcessed()) { + Log.logQuic(logTag()+ " Processed: " + processed + + ", received: " + received + + ", max:" + max); + } + if (needSendMaxData()) { + runAppPacketSpaceTransmitter(); + } + } + + private long bumpMaxData() { + long newMaxData = processedData + desiredBufferSize; + long maxData = this.maxData; + if (newMaxData - maxData < (desiredBufferSize / 5)) { + return 0; + } + while (maxData < newMaxData) { + if (MAX_RCV_DATA.compareAndSet(this, maxData, newMaxData)) + return newMaxData; + maxData = this.maxData; + } + return 0; + } + + public boolean needSendMaxData() { + return maxData - processedData < desiredBufferSize/2; + } + + String logTag() { return logTag.get(); } + } + + /** + * An event loop triggered when stream data is available for sending. + * We use a sequential scheduler here to make sure we don't send + * more data than allowed by the connection's flow control. + * This guarantee that only one thread composes flow controlled + * OneRTT packets at a given time, which in turn guarantees that the + * credit computed at the beginning of the loop will still be + * available after the packet has been composed. + */ + class OneRttFlowControlledSendingQueue { + private volatile long dataProcessed; + private volatile long maxData; + + /** + * Called when a MAX_DATA frame is received. + * This method is a no-op if the given value is less than the + * current max stream data for the connection. + * + * @param maxData the maximum data offset that the peer is prepared + * to accept for the whole connection + * @param isInitial true when processing transport parameters, + * false when processing MaxDataFrame + * @return the actual max data after taking the given value into account + */ + public long setMaxData(long maxData, boolean isInitial) { + long max; + long processed; + boolean wasblocked, unblocked = false; + do { + synchronized (this) { + max = this.maxData; + processed = dataProcessed; + } + wasblocked = max <= processed; + if (max < maxData) { + if (MAX_SND_DATA.compareAndSet(this, max, maxData)) { + max = maxData; + unblocked = (wasblocked && max > processed); + } + } + } while (max < maxData); + if (unblocked && !isInitial) { + packetSpaces.app.runTransmitter(); + } + return max; + } + + /** + * {@return the remaining credit for this connection} + */ + public long credit() { + synchronized (this) { + return maxData - dataProcessed; + } + } + + // We can continue sending if we have credit and data is available to send + private boolean canSend() { + return credit() > 0 && streams.hasAvailableData() + || streams.hasControlFrames() + || hasQueuedFrames() + || oneRttRcvQueue.needSendMaxData(); + } + + // implementation of the sending loop. + private boolean send1RTTData() { + Throwable failure; + try { + return doSend1RTTData(); + } catch (Throwable t) { + failure = t; + } + if (failure instanceof QuicKeyUnavailableException qkue) { + if (!QuicConnectionImpl.this.stateHandle().opened()) { + // connection is already being closed and that explains the + // key unavailability (they might have been discarded). just log + // and return + if (debug.on()) { + debug.log("failed to send stream data, reason: " + qkue.getMessage()); + } + return false; + } + // connection is still open but a key unavailability exception was raised. + // close the connection and use an IOException instead of the internal + // QuicKeyUnavailableException as the cause for the connection close. + failure = new IOException(qkue.getMessage()); + } + if (debug.on()) { + debug.log("failed to send stream data", failure); + } + // close the connection to make sure it's not just ignored + terminator.terminate(TerminationCause.forException(failure)); + return false; + } + + private boolean doSend1RTTData() throws QuicKeyUnavailableException, QuicTransportException { + // Loop over all sending streams to see if data is available - include + // as much data as possible in the quic packet before sending it. + // The QuicConnectionStreams make sure that streams are polled in a fair + // manner (using round-robin?) + // This loop is called through a sequential scheduler to make + // sure we only have one thread emitting flow control data for + // this connection + final PacketSpace space = packetSpace(PacketNumberSpace.APPLICATION); + final int maxDatagramSize = getMaxDatagramSize(); + final QuicConnectionId peerConnectionId = peerConnectionId(); + final int dstIdLength = peerConnectionId().length(); + if (!canSend()) { + return false; + } + final long packetNumber = space.allocateNextPN(); + final long largestPeerAckedPN = space.getLargestPeerAckedPN(); + int remaining = QuicPacketEncoder.computeMaxOneRTTPayloadSize( + codingContext, packetNumber, dstIdLength, maxDatagramSize, largestPeerAckedPN); + if (remaining == 0) { + // not enough space to send available data + return false; + } + final List frames = new ArrayList<>(); + remaining -= addConnectionControlFrames(remaining, frames); + assert remaining >= 0 : remaining; + long produced = streams.produceFramesToSend(encoder, remaining, credit(), frames); + if (frames.isEmpty()) { + // produced cannot be > 0 unless there are some frames to send + assert produced == 0; + return false; + } + // non-atomic operation should be OK since sendStreamData0 is called + // only from the sending loop, and this is the only place where we + // mutate dataProcessed. + dataProcessed += produced; + final OneRttPacket packet = encoder.newOneRttPacket(peerConnectionId, + packetNumber, largestPeerAckedPN, frames, codingContext); + QuicConnectionImpl.this.send1RTTPacket(packet); + return true; + } + + /** + * Produces connection-level control frames for sending in the next one-rtt + * packet. The frames are added to the provided list. + * + * @param maxAllowedBytes maximum number of bytes the method is allowed to add + * @param frames list where the frames are added + * @return number of bytes added + */ + private int addConnectionControlFrames(final int maxAllowedBytes, + final List frames) { + assert maxAllowedBytes > 0 : "unexpected max allowed bytes: " + maxAllowedBytes; + int added = 0; + int remaining = maxAllowedBytes; + QuicFrame f; + while ((f = outgoing1RTTFrames.peek()) != null) { + final int frameSize = f.size(); + if (frameSize <= remaining) { + outgoing1RTTFrames.remove(); + frames.add(f); + added += frameSize; + remaining -= frameSize; + } else { + break; + } + } + PathChallengeFrame pcf; + while (remaining >= 9 && (pcf = pathChallengeFrameQueue.poll()) != null) { + f = new PathResponseFrame(pcf.data()); + final int frameSize = f.size(); + assert frameSize <= remaining : "Frame too large"; + frames.add(f); + added += frameSize; + remaining -= frameSize; + } + + // NEW_CONNECTION_ID + while ((f = localConnIdManager.nextFrame(remaining)) != null) { + final int frameSize = f.size(); + assert frameSize <= remaining : "Frame too large"; + frames.add(f); + added += frameSize; + remaining -= frameSize; + } + // RETIRE_CONNECTION_ID + while ((f = peerConnIdManager.nextFrame(remaining)) != null) { + final int frameSize = f.size(); + assert frameSize <= remaining : "Frame too large"; + frames.add(f); + added += frameSize; + remaining -= frameSize; + } + + if (remaining == 0) { + return added; + } + final PacketSpace space = packetSpace(PacketNumberSpace.APPLICATION); + final AckFrame ack = space.getNextAckFrame(false, remaining); + if (ack != null) { + final int ackFrameSize = ack.size(); + assert ackFrameSize <= remaining; + if (debug.on()) { + debug.log("Adding AckFrame"); + } + frames.add(ack); + added += ackFrameSize; + remaining -= ackFrameSize; + } + final long credit = credit(); + if (credit < remaining && remaining > 10) { + if (debug.on()) { + debug.log("Adding DataBlockedFrame"); + } + DataBlockedFrame dbf = new DataBlockedFrame(maxData); + frames.add(dbf); + added += dbf.size(); + remaining -= dbf.size(); + } + // max data + if (remaining > 10) { + long maxData = oneRttRcvQueue.bumpMaxData(); + if (maxData != 0) { + if (debug.on()) { + debug.log("Adding MaxDataFrame (processed: %s)", + oneRttRcvQueue.processedData); + } + MaxDataFrame mdf = new MaxDataFrame(maxData); + frames.add(mdf); + added += mdf.size(); + remaining -= mdf.size(); + } + } + // session ticket + if (quicTLSEngine.getCurrentSendKeySpace() == KeySpace.ONE_RTT) { + try { + ByteBuffer payloadBuffer = quicTLSEngine.getHandshakeBytes(KeySpace.ONE_RTT); + if (payloadBuffer != null) { + localCryptoFlow.enqueue(payloadBuffer); + } + } catch (IOException e) { + throw new AssertionError("Should not happen!", e); + } + if (localCryptoFlow.remaining() > 0 && remaining > 3) { + CryptoFrame frame = localCryptoFlow.produceFrame(remaining); + if (frame != null) { + if (debug.on()) { + debug.log("Adding CryptoFrame"); + } + frames.add(frame); + added += frame.size(); + remaining -= frame.size(); + assert remaining >= 0; + } + } + } + return added; + } + } + + /** + * Invoked to send a ONERTT packet containing stream data or + * control frames. + * + * @apiNote + * This method can be overridden if some action needs to be + * performed after sending a packet containing certain type + * of frames. Typically, a server side connection may want + * to close the HANDSHAKE space only after sending the + * HANDSHAKE_DONE frame. + * + * @param packet The ONERTT packet to send. + */ + protected void send1RTTPacket(final OneRttPacket packet) + throws QuicKeyUnavailableException, QuicTransportException { + pushDatagram(ProtectionRecord.single(packet, + QuicConnectionImpl.this::allocateDatagramForEncryption)); + } + + /** + * Schedule a frame for sending in a 1-RTT packet. + *

+ * For use with frames that do not change with time + * (like MAX_* / *_BLOCKED / ACK), + * or with remaining datagram capacity (like STREAM or CRYPTO), + * and do not require certain path (PATH_CHALLENGE / RESPONSE). + *

+ * Use with frames like HANDSHAKE_DONE, NEW_TOKEN, + * NEW_CONNECTION_ID, RETIRE_CONNECTION_ID. + *

+ * Maximum accepted frame size is 1000 bytes to ensure that the frame + * will fit in a 1-RTT datagram in the foreseeable future. + * @param frame frame to send + * @throws IllegalArgumentException if frame is larger than 1000 bytes + */ + protected void enqueue1RTTFrame(final QuicFrame frame) { + if (frame.size() > 1000) { + throw new IllegalArgumentException("Frame too big"); + } + assert frame.isValidIn(PacketType.ONERTT) : "frame " + frame + " is not" + + " eligible in 1-RTT space"; + outgoing1RTTFrames.add(frame); + } + + /** + * {@return true if queued frames are available for sending} + */ + private boolean hasQueuedFrames() { + return !outgoing1RTTFrames.isEmpty(); + } + + protected QuicPacketEncoder encoder() { return encoder;} + protected QuicPacketDecoder decoder() { return decoder; } + public QuicEndpoint endpoint() { return endpoint; } + protected final StateHandle stateHandle() { return stateHandle; } + protected CodingContext codingContext() { + return codingContext; + } + + public long largestAckedPN(PacketNumberSpace packetSpace) { + var space = packetSpaces.get(packetSpace); + return space.getLargestPeerAckedPN(); + } + + public long largestProcessedPN(PacketNumberSpace packetSpace) { + var space = packetSpaces.get(packetSpace); + return space.getLargestProcessedPN(); + } + + public int connectionIdLength() { + return localConnectionId().length(); + } + + public QuicInstance quicInstance() { + return this.quicInstance; + } + + public QuicVersion quicVersion() { + return this.quicVersion; + } + + protected class QuicCodingContext implements CodingContext { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return QuicConnectionImpl.this.largestProcessedPN(packetSpace); + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return QuicConnectionImpl.this.largestAckedPN(packetSpace); + } + @Override public int connectionIdLength() { + return QuicConnectionImpl.this.connectionIdLength(); + } + @Override public int writePacket(QuicPacket packet, ByteBuffer buffer) + throws QuicKeyUnavailableException, QuicTransportException { + int start = buffer.position(); + encoder.encode(packet, buffer, this); + return buffer.position() - start; + } + @Override public QuicPacket parsePacket(ByteBuffer src) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + return decoder.decode(src, this); + } + @Override + public QuicConnectionId originalServerConnId() { + return QuicConnectionImpl.this.originalServerConnId(); + } + + @Override + public QuicTLSEngine getTLSEngine() { + return quicTLSEngine; + } + + @Override + public boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + return QuicConnectionImpl.this.verifyToken(destinationID, token); + } + } + + protected boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + // server must send zero-length token + return token == null; + } + + protected PacketEmitter emitter() { + return new PacketEmitter() { + @Override + public QuicTimerQueue timer() { + return QuicConnectionImpl.this.endpoint().timer(); + } + + @Override + public void retransmit(PacketSpace packetSpaceManager, QuicPacket packet, int attempts) + throws QuicKeyUnavailableException, QuicTransportException { + QuicConnectionImpl.this.retransmit(packetSpaceManager, packet, attempts); + } + + @Override + public long emitAckPacket(PacketSpace packetSpaceManager, + AckFrame frame, + boolean sendPing) + throws QuicKeyUnavailableException, QuicTransportException { + return QuicConnectionImpl.this.emitAckPacket(packetSpaceManager, frame, sendPing); + } + + @Override + public void acknowledged(QuicPacket packet) { + QuicConnectionImpl.this.packetAcknowledged(packet); + } + + @Override + public boolean sendData(PacketNumberSpace packetNumberSpace) + throws QuicKeyUnavailableException, QuicTransportException { + return QuicConnectionImpl.this.sendData(packetNumberSpace); + } + + @Override + public Executor executor() { + return quicInstance().executor(); + } + + @Override + public void reschedule(QuicTimedEvent task) { + var endpoint = QuicConnectionImpl.this.endpoint(); + if (endpoint == null) return; + endpoint.timer().reschedule(task); + } + + @Override + public void reschedule(QuicTimedEvent task, Deadline deadline) { + var endpoint = QuicConnectionImpl.this.endpoint(); + if (endpoint == null) return; + endpoint.timer().reschedule(task, deadline); + } + + @Override + public void checkAbort(PacketNumberSpace packetNumberSpace) { + QuicConnectionImpl.this.checkAbort(packetNumberSpace); + } + + @Override + public void ptoBackoffIncreased(PacketSpaceManager space, long backoff) { + if (Log.quicRetransmit()) { + Log.logQuic("%s OUT: [%s] increase backoff to %s, duration %s ms: %s" + .formatted(QuicConnectionImpl.this.logTag(), + space.packetNumberSpace(), backoff, + space.getPtoDuration().toMillis(), + rttEstimator.state())); + } + } + + @Override + public String logTag() { + return QuicConnectionImpl.this.logTag(); + } + + @Override + public boolean isOpen() { + return QuicConnectionImpl.this.stateHandle.opened(); + } + }; + } + + private void checkAbort(PacketNumberSpace packetNumberSpace) { + // if pto backoff > 32 (i.e. PTO expired 5 times in a row), abort, + // unless we haven't reached MIN_PTO_BACKOFF_TIMEOUT + var backoff = rttEstimator.getPtoBackoff(); + if (backoff > QuicRttEstimator.MAX_PTO_BACKOFF) { + // If the maximum backoff is exceeded, we close the connection + // only if the associated backoff timeout exceeds the + // MIN_PTO_BACKOFF_TIMEOUT. Otherwise, we allow the backoff + // factor to grow again past the MAX_PTO_BACKOFF + if (rttEstimator.isMinBackoffTimeoutExceeded()) { + if (debug.on()) { + debug.log("%s Too many probe time outs: %s", packetNumberSpace, backoff); + debug.log(String.valueOf(rttEstimator.state())); + debug.log("State: %s", stateHandle().toString()); + } + if (Log.quicRetransmit() || Log.quicCC()) { + Log.logQuic("%s OUT: %s: Too many probe timeouts %s" + .formatted(logTag(), packetNumberSpace, + rttEstimator.state())); + StringBuilder sb = new StringBuilder(logTag()); + sb.append(" State: ").append(stateHandle().toString()); + for (PacketNumberSpace sp : PacketNumberSpace.values()) { + if (sp == PacketNumberSpace.NONE) continue; + if (packetSpaces.get(sp) instanceof PacketSpaceManager m) { + sb.append("\nPacketSpace: ").append(sp).append('\n'); + m.debugState(" ", sb); + } + } + Log.logQuic(sb.toString()); + } else if (debug.on()) { + for (PacketNumberSpace sp : PacketNumberSpace.values()) { + if (sp == PacketNumberSpace.NONE) continue; + if (packetSpaces.get(sp) instanceof PacketSpaceManager m) { + m.debugState(); + } + } + } + var pto = rttEstimator.getBasePtoDuration(); + var to = pto.multipliedBy(backoff); + if (to.compareTo(MAX_PTO_BACKOFF_TIMEOUT) > 0) to = MAX_PTO_BACKOFF_TIMEOUT; + String msg = "%s: Too many probe time outs (%s: backoff %s, duration %s, %s)" + .formatted(logTag(), packetNumberSpace, backoff, + to, rttEstimator.state()); + final TerminationCause terminationCause; + if (packetNumberSpace == PacketNumberSpace.HANDSHAKE) { + terminationCause = TerminationCause.forException(new SSLHandshakeException(msg)); + } else if (packetNumberSpace == PacketNumberSpace.INITIAL) { + terminationCause = TerminationCause.forException(new ConnectException(msg)); + } else { + terminationCause = TerminationCause.forException(new IOException(msg)); + } + terminator.terminate(terminationCause); + } else { + if (debug.on()) { + debug.log("%s: Max PTO backoff reached (%s) before min probe timeout exceeded (%s)," + + " allow more backoff %s", + packetNumberSpace, backoff, MIN_PTO_BACKOFF_TIMEOUT, rttEstimator.state()); + } + if (Log.quicRetransmit() || Log.quicCC()) { + Log.logQuic("%s OUT: %s: Max PTO backoff reached (%s) before min probe timeout exceeded (%s) - %s" + .formatted(QuicConnectionImpl.this.logTag(), packetNumberSpace, backoff, + MIN_PTO_BACKOFF_TIMEOUT, rttEstimator.state())); + } + } + } + } + + // this method is called when a packet has been acknowledged + private void packetAcknowledged(QuicPacket packet) { + // process packet frames to track acknowledgement + // of RESET_STREAM frames etc... + if (debug.on()) { + debug.log("Packet %s(pn:%s) is acknowledged by peer", + packet.packetType(), + packet.packetNumber()); + } + packet.frames().forEach(this::frameAcknowledged); + } + + // this method is called when a frame has been acknowledged + private void frameAcknowledged(QuicFrame frame) { + if (frame instanceof ResetStreamFrame reset) { + long streamId = reset.streamId(); + if (streams.isSendingStream(streamId)) { + streams.streamResetAcknowledged(reset); + } + } else if (frame instanceof StreamFrame streamFrame) { + if (streamFrame.isLast()) { + streams.streamDataSentAcknowledged(streamFrame); + } + } + } + + protected PacketSpaces packetNumberSpaces() { + return packetSpaces; + } + protected PacketSpace packetSpace(PacketNumberSpace packetNumberSpace) { + return packetSpaces.get(packetNumberSpace); + } + + public String dbgTag() { return dbgTag; } + + public String streamDbgTag(long streamId, String direction) { + String dir = direction == null || direction.isEmpty() + ? "" : ("(" + direction + ")"); + return dbgTag + "[streamId" + dir + "=" + streamId + "]"; + } + + + @Override + public CompletableFuture openNewLocalBidiStream(final Duration limitIncreaseDuration) { + if (!stateHandle.opened()) { + return MinimalFuture.failedFuture(new ClosedChannelException()); + } + final CompletableFuture> streamCF = + this.handshakeFlow.handshakeCF().thenApply((ignored) -> + streams.createNewLocalBidiStream(limitIncreaseDuration)); + return streamCF.thenCompose(Function.identity()); + } + + @Override + public CompletableFuture openNewLocalUniStream(final Duration limitIncreaseDuration) { + if (!stateHandle.opened()) { + return MinimalFuture.failedFuture(new ClosedChannelException()); + } + final CompletableFuture> streamCF = + this.handshakeFlow.handshakeCF().thenApply((ignored) + -> streams.createNewLocalUniStream(limitIncreaseDuration)); + return streamCF.thenCompose(Function.identity()); + } + + @Override + public void addRemoteStreamListener(Predicate streamConsumer) { + streams.addRemoteStreamListener(streamConsumer); + } + + @Override + public boolean removeRemoteStreamListener(Predicate streamConsumer) { + return streams.removeRemoteStreamListener(streamConsumer); + } + + @Override + public Stream quicStreams() { + return streams.quicStreams(); + } + + @Override + public List connectionIds() { + return localConnIdManager.connectionIds(); + } + + LocalConnIdManager localConnectionIdManager() { + return localConnIdManager; + } + + /** + * {@return the local connection id} + */ + public QuicConnectionId localConnectionId() { + return connectionId; + } + + /** + * {@return the peer connection id} + */ + public QuicConnectionId peerConnectionId() { + return this.peerConnIdManager.getPeerConnId(); + } + + /** + * Returns the original connection id. + * This is the original destination connection id that + * the client generated when connecting to the server for + * the first time. + * @return the original connection id + */ + protected QuicConnectionId originalServerConnId() { + return this.peerConnIdManager.originalServerConnId(); + } + + private record IncomingDatagram(SocketAddress source, ByteBuffer destConnId, + QuicPacket.HeadersType headersType, ByteBuffer buffer) {} + + @Override + public boolean accepts(SocketAddress source) { + // The client ever accepts packets from two sources: + // => the original peer address + // => the preferred peer address (not implemented) + if (!source.equals(peerAddress)) { + // We only accept packets from the endpoint to + // which we send them. + if (debug.on()) { + debug.log("unexpected sender %s, skipping packet", source); + } + return false; + } + return true; + } + + public void processIncoming(SocketAddress source, ByteBuffer destConnId, + QuicPacket.HeadersType headersType, ByteBuffer buffer) { + // Processes an incoming datagram that has just been + // read off the network. + if (debug.on()) { + debug.log("processIncoming %s(pos=%d, remaining=%d)", + headersType, buffer.position(), buffer.remaining()); + } + if (!stateHandle.opened()) { + if (debug.on()) { + debug.log("connection closed, skipping packet"); + } + return; + } + + assert accepts(source); + + scheduleForDecryption(new IncomingDatagram(source, destConnId, headersType, buffer)); + } + + public void internalProcessIncoming(SocketAddress source, ByteBuffer destConnId, + QuicPacket.HeadersType headersType, ByteBuffer buffer) { + try { + int packetIndex = 0; + while(buffer.hasRemaining()) { + int startPos = buffer.position(); + packetIndex++; + boolean isLongHeader = QuicPacketDecoder.peekHeaderType(buffer, startPos) == QuicPacket.HeadersType.LONG; + // It's only safe to check version here if versionNegotiated is true. + // We might be receiving an INITIAL packet before the version negotiation + // has been handled. + if (isLongHeader) { + LongHeader header = QuicPacketDecoder.peekLongHeader(buffer); + if (header == null) { + if (debug.on()) { + debug.log("Dropping long header packet (%s in datagram): too short", packetIndex); + } + return; + } + if (!header.destinationId().matches(destConnId)) { + if (debug.on()) { + debug.log("Dropping long header packet (%s in datagram):" + + " wrong connection id (%s vs %s)", + packetIndex, + header.destinationId().toHexString(), + Utils.asHexString(destConnId)); + } + return; + } + var peekedVersion = header.version(); + final var version = this.quicVersion.versionNumber(); + if (version != peekedVersion) { + if (peekedVersion == 0) { + if (!versionCompatible) { + VersionNegotiationPacket packet = (VersionNegotiationPacket) codingContext.parsePacket(buffer); + processDecrypted(packet); + } else { + if (debug.on()) { + debug.log("Versions packet (%s in datagram) ignored", packetIndex); + } + } + return; + } + QuicVersion packetVersion = QuicVersion.of(peekedVersion).orElse(null); + if (packetVersion == null) { + if (debug.on()) { + debug.log("Unknown Quic version in long header packet" + + " (%s in datagram) %s: 0x%x", + packetIndex, headersType, peekedVersion); + } + return; + } else if (versionNegotiated) { + if (debug.on()) { + debug.log("Dropping long header packet (%s in datagram)" + + " with version %s, already negotiated %s", + packetIndex, packetVersion, quicVersion); + } + return; + } else if (!quicInstance().isVersionAvailable(packetVersion)) { + if (debug.on()) { + debug.log("Dropping long header packet (%s in datagram)" + + " with disabled version %s", + packetIndex, packetVersion); + } + return; + } else { + // do we need to be less trusting here? + if (debug.on()) { + debug.log("Switching version to %s, previous: %s", + packetVersion, quicVersion); + } + switchVersion(packetVersion); + } + } + if (decoder.peekPacketType(buffer) == PacketType.INITIAL && + !quicTLSEngine.keysAvailable(KeySpace.INITIAL)) { + if (debug.on()) { + debug.log("Dropping INITIAL packet (%s in datagram): %s", + packetIndex, "keys discarded"); + } + decoder.skipPacket(buffer, startPos); + continue; + } + } else { + var cid = QuicPacketDecoder.peekShortConnectionId(buffer, destConnId.remaining()); + if (cid == null) { + if (debug.on()) { + debug.log("Dropping short header packet (%s in datagram):" + + " too short", packetIndex); + } + return; + } + if (cid.mismatch(destConnId) != -1) { + if (debug.on()) { + debug.log("Dropping short header packet (%s in datagram):" + + " wrong connection id (%s vs %s)", + packetIndex, Utils.asHexString(cid), Utils.asHexString(destConnId)); + } + return; + } + + } + ByteBuffer packet = decoder.nextPacketSlice(buffer, buffer.position()); + PacketType packetType = decoder.peekPacketType(packet); + if (debug.on()) { + debug.log("unprotecting packet (%s in datagram) %s(%s bytes)", + packetIndex, packetType, packet.remaining()); + } + decrypt(packet); + } + } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to process incoming packet", t); + } + } + } + + /** + * Called when an incoming packet has been decrypted. + *

+ * @param quicPacket the decrypted quic packet + */ + public void processDecrypted(QuicPacket quicPacket) { + PacketType packetType = quicPacket.packetType(); + long packetNumber = quicPacket.packetNumber(); + if (debug.on()) { + debug.log("processDecrypted %s(%d)", packetType, packetNumber); + } + if (Log.quicPacketInLoggable(quicPacket)) { + Log.logQuicPacketIn(logTag(), quicPacket); + } + if (packetType != PacketType.VERSIONS) { + versionCompatible = true; + // versions will also set versionCompatible later + } + if (isClientConnection() + && quicPacket instanceof InitialPacket longPacket + && quicPacket.frames().stream().anyMatch(CryptoFrame.class::isInstance)) { + markVersionNegotiated(longPacket.version()); + } + PacketSpace packetSpace = null; + if (packetNumber >= 0) { + packetSpace = packetSpace(quicPacket.numberSpace()); + + // From RFC 9000, Section 13.2.3: + // A receiver MUST retain an ACK Range unless it can ensure that + // it will not subsequently accept packets with numbers in + // that range. Maintaining a minimum packet number that increases + // as ranges are discarded is one way to achieve this with minimal + // state. + long threshold = packetSpace.getMinPNThreshold(); + if (packetNumber <= threshold) { + // discard the packet, as we are no longer acknowledging + // packets in this range. + if (debug.on()) + debug.log("discarding packet %s(%d) - threshold: %d", + packetType, packetNumber, threshold); + return; + } + if (packetSpace.isAcknowledged(packetNumber)) { + if (debug.on()) + debug.log("discarding packet %s(%d) - duplicated", + packetType, packetNumber, threshold); + } + + if (debug.on()) { + debug.log("receiving packet %s(pn:%s, %s)", packetType, + packetNumber, quicPacket.frames()); + } + } + switch (packetType) { + case VERSIONS -> processVersionNegotiationPacket(quicPacket); + case INITIAL -> processInitialPacket(quicPacket); + case ONERTT -> processOneRTTPacket(quicPacket); + case HANDSHAKE -> processHandshakePacket(quicPacket); + case RETRY -> processRetryPacket(quicPacket); + case ZERORTT -> { + if (debug.on()) { + debug.log("Dropping unhandled quic packet %s", packetType); + } + } + case NONE -> throw new InternalError("Unrecognized packet type"); + } + // packet has been processed successfully - connection isn't idle (RFC-9000, section 10.1) + this.terminator.keepAlive(); + if (packetSpace != null) { + packetSpace.packetReceived( + packetType, + packetNumber, + quicPacket.isAckEliciting()); + } + } + + /** + * {@return true if this is a stream initiated locally, and false if + * this is a stream initiated by the peer}. + * @param streamId a stream ID. + */ + protected final boolean isLocalStream(long streamId) { + return isClientConnection() == QuicStreams.isClientInitiated(streamId); + } + + /** + * If a stream with this streamId was already created, returns it. + * @param streamId the stream ID + * @return the stream identified by the given {@code streamId}, or {@code null}. + */ + protected QuicStream findStream(long streamId) { + return streams.findStream(streamId); + } + + /** + * @return true if this stream ID identifies a stream that was + * already opened + * @param streamId the stream id + */ + protected boolean isExistingStreamId(long streamId) { + long next = streams.peekNextStreamId(streamType(streamId)); + return streamId < next; + } + + /** + * Get or open a peer initiated stream with the given stream ID + * @param streamId the id of the remote stream + * @param frameType type of the frame received, used in exceptions + * @return the remote initiated stream identified by the given + * stream ID, or null + * @throws QuicTransportException if the streamID is higher than allowed + */ + protected QuicStream openOrGetRemoteStream(long streamId, long frameType) throws QuicTransportException { + assert !isLocalStream(streamId); + return streams.getOrCreateRemoteStream(streamId, frameType); + } + + /** + * Called to process a {@link OneRttPacket} after it has been successfully decrypted + * @param quicPacket the Quic packet + * @throws IllegalArgumentException if the {@code quicPacket} isn't a 1-RTT packet + * @throws NullPointerException if {@code quicPacket} is null + */ + protected void processOneRTTPacket(final QuicPacket quicPacket) { + Objects.requireNonNull(quicPacket); + if (quicPacket.packetType() != PacketType.ONERTT) { + throw new IllegalArgumentException("Not a ONERTT packet: " + quicPacket.packetType()); + } + assert quicPacket instanceof OneRttPacket : "Unexpected ONERTT packet class type: " + + quicPacket.getClass(); + final OneRttPacket oneRTT = (OneRttPacket) quicPacket; + try { + if (debug.on()) { + debug.log("processing packet ONERTT(%s)", quicPacket.packetNumber()); + } + final var frames = oneRTT.frames(); + if (debug.on()) { + debug.log("processing frames: " + frames.stream() + .map(Object::getClass).map(Class::getSimpleName) + .collect(Collectors.joining(", ", "[", "]"))); + } + for (var frame : oneRTT.frames()) { + if (!frame.isValidIn(PacketType.ONERTT)) { + throw new QuicTransportException("Invalid frame in ONERTT packet", + KeySpace.ONE_RTT, frame.getTypeField(), + PROTOCOL_VIOLATION); + } + if (debug.on()) { + debug.log("received 1-RTT frame %s", frame); + } + switch (frame) { + case AckFrame ackFrame -> { + incoming1RTTFrame(ackFrame); + } + case StreamFrame streamFrame -> { + incoming1RTTFrame(streamFrame); + } + case CryptoFrame crypto -> { + incoming1RTTFrame(crypto); + } + case ResetStreamFrame resetStreamFrame -> { + incoming1RTTFrame(resetStreamFrame); + } + case DataBlockedFrame dataBlockedFrame -> { + incoming1RTTFrame(dataBlockedFrame); + } + case StreamDataBlockedFrame streamDataBlockedFrame -> { + incoming1RTTFrame(streamDataBlockedFrame); + } + case StreamsBlockedFrame streamsBlockedFrame -> { + incoming1RTTFrame(streamsBlockedFrame); + } + case PaddingFrame paddingFrame -> { + incoming1RTTFrame(paddingFrame); + } + case MaxDataFrame maxData -> { + incoming1RTTFrame(maxData); + } + case MaxStreamDataFrame maxStreamData -> { + incoming1RTTFrame(maxStreamData); + } + case MaxStreamsFrame maxStreamsFrame -> { + incoming1RTTFrame(maxStreamsFrame); + } + case StopSendingFrame stopSendingFrame -> { + incoming1RTTFrame(stopSendingFrame); + } + case PingFrame ping -> { + incoming1RTTFrame(ping); + } + case ConnectionCloseFrame close -> { + incoming1RTTFrame(close); + } + case HandshakeDoneFrame handshakeDoneFrame -> { + incoming1RTTFrame(handshakeDoneFrame); + } + case NewConnectionIDFrame newCid -> { + incoming1RTTFrame(newCid); + } + case RetireConnectionIDFrame retireCid -> { + incoming1RTTFrame(oneRTT, retireCid); + } + case NewTokenFrame newTokenFrame -> { + incoming1RTTFrame(newTokenFrame); + } + case PathResponseFrame pathResponseFrame -> { + incoming1RTTFrame(pathResponseFrame); + } + case PathChallengeFrame pathChallengeFrame -> { + incoming1RTTFrame(pathChallengeFrame); + } + default -> { + if (debug.on()) { + debug.log("Frame type: %s not supported yet", frame.getClass()); + } + } + } + } + } catch (Throwable t) { + onProcessingError(quicPacket, t); + } + } + + /** + * Gets a receiving stream instance for the given ID, used for processing + * incoming STREAM, RESET_STREAM and STREAM_DATA_BLOCKED frames. + * Returns null if the instance is gone already. Throws an exception if the stream ID is incorrect. + * @param streamId stream ID + * @param frameType received frame type. Used in QuicTransportException + * @return receiver stream, or null if stream is already gone + * @throws QuicTransportException if the stream ID is not a valid receiving stream + */ + private QuicReceiverStream getReceivingStream(long streamId, long frameType) throws QuicTransportException { + var stream = findStream(streamId); + boolean isLocalStream = isLocalStream(streamId); + boolean isUnidirectional = isUnidirectional(streamId); + if (isLocalStream && isUnidirectional) { + // stream is write-only + throw new QuicTransportException("Stream %s (type %s) is unidirectional" + .formatted(streamId, streamType(streamId)), + KeySpace.ONE_RTT, frameType, QuicTransportErrors.STREAM_STATE_ERROR); + } + if (stream == null && isLocalStream) { + // the stream is either closed or bad stream + if (!isExistingStreamId(streamId)) { + throw new QuicTransportException("No such stream %s (type %s)" + .formatted(streamId, streamType(streamId)), + KeySpace.ONE_RTT, frameType, + QuicTransportErrors.STREAM_STATE_ERROR); + } + return null; + } + + if (stream == null) { + assert !isLocalStream; + // Note: The quic protocol allows any peer to open + // a bidirectional remote stream. + // The HTTP/3 protocol does not allow a server to open a + // bidirectional stream on the client. If this is a client + // connection and the stream type is bidirectional and + // remote, the connection will be closed by the HTTP/3 + // higher level protocol but not here, since this is + // not a Quic protocol error. + stream = openOrGetRemoteStream(streamId, frameType); + if (stream == null) { + return null; + } + } + return (QuicReceiverStream)stream; + } + + /** + * Gets a sending stream instance for the given ID, used for processing + * incoming MAX_STREAM_DATA and STOP_SENDING frames. + * Returns null if the instance is gone already. Throws an exception if the stream ID is incorrect. + * @param streamId stream ID + * @param frameType received frame type. Used in QuicTransportException + * @return sender stream, or null if stream is already gone + * @throws QuicTransportException if the stream ID is not a valid sending stream + */ + private QuicSenderStream getSendingStream(long streamId, long frameType) throws QuicTransportException { + var stream = findStream(streamId); + boolean isLocalStream = isLocalStream(streamId); + boolean isUnidirectional = isUnidirectional(streamId); + if (!isLocalStream && isUnidirectional) { + // stream is read-only + throw new QuicTransportException("Stream %s (type %s) is unidirectional" + .formatted(streamId, streamType(streamId)), + QuicTLSEngine.KeySpace.ONE_RTT, frameType, QuicTransportErrors.STREAM_STATE_ERROR); + } + if (stream == null && isLocalStream) { + // the stream is either closed or bad stream + if (!isExistingStreamId(streamId)) { + throw new QuicTransportException("No such stream %s (type %s)" + .formatted(streamId, streamType(streamId)), + QuicTLSEngine.KeySpace.ONE_RTT, frameType, + QuicTransportErrors.STREAM_STATE_ERROR); + } + return null; + } + + if (stream == null) { + assert !isLocalStream; + stream = openOrGetRemoteStream(streamId, frameType); + if (stream == null) { + return null; + } + } + return (QuicSenderStream)stream; + } + + /** + * Called to process an {@link InitialPacket} after it has been decrypted. + * @param quicPacket the Quic packet + * @throws IllegalArgumentException if {@code quicPacket} isn't a INITIAL packet + * @throws NullPointerException if {@code quicPacket} is null + */ + protected void processInitialPacket(final QuicPacket quicPacket) { + Objects.requireNonNull(quicPacket); + if (quicPacket.packetType() != PacketType.INITIAL) { + throw new IllegalArgumentException("Not a INITIAL packet: " + quicPacket.packetType()); + } + try { + if (quicPacket instanceof InitialPacket initial) { + MaxInitialTimer initialTimer = this.maxInitialTimer; + if (initialTimer != null) { + // will be a no-op after the first call; + initialTimer.initialPacketReceived(); + // we no longer need the timer + this.maxInitialTimer = null; + } + int total; + updatePeerConnectionId(initial); + total = processInitialPacketPayload(initial); + assert total == initial.payloadSize(); + // received initial packet from server - we won't need to replay anything now + handshakeFlow.localInitial.discardReplayData(); + continueHandshake(); + if (quicTLSEngine.getHandshakeState() == HandshakeState.NEED_RECV_CRYPTO && + quicTLSEngine.keysAvailable(KeySpace.HANDSHAKE)) { + // arm the anti-deadlock PTO timer + packetSpaces.handshake.runTransmitter(); + } + } else { + throw new InternalError("Bad packet type: " + quicPacket); + } + } catch (Throwable t) { + terminator.terminate(TerminationCause.forException(t)); + } + } + + protected void updatePeerConnectionId(InitialPacket initial) throws QuicTransportException { + this.incomingInitialPacketSourceId = initial.sourceId(); + this.peerConnIdManager.finalizeHandshakePeerConnId(initial); + } + + public QuicConnectionId getIncomingInitialPacketSourceId() { + return incomingInitialPacketSourceId; + } + + @Override + public CompletableFuture handshakeReachedPeer() { + return this.handshakeFlow.handshakeReachedPeerCF; + } + + /** + * Process the payload of an incoming initial packet + * @param packet the incoming packet + * @return the total number of bytes consumed + * @throws SSLHandshakeException if the handshake failed + * @throws IOException if a frame couldn't be decoded, or the payload + * wasn't entirely consumed. + */ + protected int processInitialPacketPayload(final InitialPacket packet) + throws IOException, QuicTransportException { + int provided=0, total=0; + int initialPayloadSize = packet.payloadSize(); + if (debug.on()) { + debug.log("Processing initial packet pn:%s payload:%s", + packet.packetNumber(), initialPayloadSize); + } + for (final var frame: packet.frames()) { + if (debug.on()) { + debug.log("received INITIAL frame %s", frame); + } + int size = frame.size(); + total += size; + switch (frame) { + case AckFrame ack -> { + incomingInitialFrame(ack); + } + case CryptoFrame crypto -> { + provided = incomingInitialFrame(crypto); + } + case PaddingFrame paddingFrame -> { + incomingInitialFrame(paddingFrame); + } + case PingFrame ping -> { + incomingInitialFrame(ping); + } + case ConnectionCloseFrame close -> { + incomingInitialFrame(close); + } + default -> { + if (debug.on()) { + debug.log("Received invalid frame: " + frame); + } + assert !frame.isValidIn(packet.packetType()) : frame.getClass(); + throw new QuicTransportException("Invalid frame in this packet type", + packet.packetType().keySpace().orElse(null), frame.getTypeField(), + PROTOCOL_VIOLATION); + } + } + } + if (total != initialPayloadSize) { + throw new IOException("Initial payload wasn't fully consumed: %s read, of which %s crypto, from %s size" + .formatted(total, provided, initialPayloadSize)); + } + return total; + } + /** + * Process the payload of an incoming handshake packet + * @param packet the incoming packet + * @return the total number of bytes consumed + * @throws SSLHandshakeException if the handshake failed + * @throws IOException if a frame couldn't be decoded, or the payload + * wasn't entirely consumed. + */ + protected int processHandshakePacketPayload(final HandshakePacket packet) + throws IOException, QuicTransportException { + int provided=0, total=0; + int payloadSize = packet.payloadSize(); + for (final var frame: packet.frames()) { + if (debug.on()) { + debug.log("received HANDSHAKE frame %s", frame); + } + int size = frame.size(); + total += size; + switch (frame) { + case AckFrame ack -> { + incomingHandshakeFrame(ack); + } + case CryptoFrame crypto -> { + provided = incomingHandshakeFrame(crypto); + } + case PaddingFrame paddingFrame -> { + incomingHandshakeFrame(paddingFrame); + } + case PingFrame ping -> { + incomingHandshakeFrame(ping); + } + case ConnectionCloseFrame close -> { + incomingHandshakeFrame(close); + } + default -> { + assert !frame.isValidIn(packet.packetType()) : frame.getClass(); + throw new QuicTransportException("Invalid frame in this packet type", + packet.packetType().keySpace().orElse(null), frame.getTypeField(), + PROTOCOL_VIOLATION); + } + } + } + if (total != payloadSize) { + throw new IOException("Handshake payload wasn't fully consumed: %s read, of which %s crypto, from %s size" + .formatted(total, provided, payloadSize)); + } + return total; + } + + /** + * Called to process an {@link HandshakePacket} after it has been decrypted. + * @param quicPacket the handshake quic packet + * @throws IllegalArgumentException if {@code quicPacket} is not a HANDSHAKE packet + * @throws NullPointerException if {@code quicPacket} is null + */ + protected void processHandshakePacket(final QuicPacket quicPacket) { + Objects.requireNonNull(quicPacket); + if (quicPacket.packetType() != PacketType.HANDSHAKE) { + throw new IllegalArgumentException("Not a HANDSHAKE packet: " + quicPacket.packetType()); + } + final var handshake = this.handshakeFlow.handshakeCF(); + if (handshake.isDone() && debug.on()) { + debug.log("Receiving HandshakePacket(%s) after handshake is done: %s", + quicPacket.packetNumber(), quicPacket.frames()); + } + try { + if (quicPacket instanceof HandshakePacket hs) { + int total; + total = processHandshakePacketPayload(hs); + assert total == hs.payloadSize(); + continueHandshake(); + } else { + throw new InternalError("Bad packet type: " + quicPacket); + } + } catch (Throwable t) { + terminator.terminate(TerminationCause.forException(t)); + } + } + + /** + * Called to process a {@link RetryPacket} after it has been decrypted. + * @param quicPacket the retry quic packet + * @throws IllegalArgumentException if {@code quicPacket} is not a RETRY packet + * @throws NullPointerException if {@code quicPacket} is null + */ + protected void processRetryPacket(final QuicPacket quicPacket) { + Objects.requireNonNull(quicPacket); + if (quicPacket.packetType() != PacketType.RETRY) { + throw new IllegalArgumentException("Not a RETRY packet: " + quicPacket.packetType()); + } + try { + if (!(quicPacket instanceof RetryPacket rt)) { + throw new InternalError("Bad packet type: " + quicPacket); + } + assert stateHandle.helloSent() : "unexpected message"; + if (rt.retryToken().length == 0) { + if (debug.on()) { + debug.log("Invalid retry, empty token"); + } + return; + } + final QuicConnectionId currentPeerConnId = this.peerConnIdManager.getPeerConnId(); + if (rt.sourceId().equals(currentPeerConnId)) { + if (debug.on()) { + debug.log("Invalid retry, same connection ID"); + } + return; + } + if (this.peerConnIdManager.retryConnId() != null) { + if (debug.on()) { + debug.log("Ignoring retry, already got one"); + } + return; + } + // ignore retry if we already received initial packets + if (incomingInitialPacketSourceId != null) { + if (debug.on()) { + debug.log("Already received initial, ignoring retry"); + } + return; + } + final int version = rt.version(); + final QuicVersion retryVersion = QuicVersion.of(version).orElse(null); + if (retryVersion == null) { + if (debug.on()) { + debug.log("Ignoring retry packet with unknown version 0x" + + Integer.toHexString(version)); + } + // ignore the packet + return; + } + final QuicVersion originalVersion = this.quicVersion; // the original version used to establish the connection + if (originalVersion != retryVersion) { + if (debug.on()) { + debug.log("Ignoring retry packet with version 0x" + + Integer.toHexString(version) + + " since it doesn't match the original version 0x" + + Integer.toHexString(originalVersion.versionNumber())); + } + // ignore the packet + return; + } + ReentrantLock tl = packetSpaces.initial.getTransmitLock(); + tl.lock(); + try { + initialToken = rt.retryToken(); + final QuicConnectionId retryConnId = rt.sourceId(); + this.peerConnIdManager.retryConnId(retryConnId); + quicTLSEngine.deriveInitialKeys(originalVersion, retryConnId.asReadOnlyBuffer()); + this.packetSpace(PacketNumberSpace.INITIAL).retry(); + handshakeFlow.localInitial.replayData(); + } finally { + tl.unlock(); + } + packetSpaces.initial.runTransmitter(); + } catch (Throwable t) { + terminator.terminate(TerminationCause.forException(t)); + } + } + + /** + * {@return the next (higher) max streams limit that should be advertised to the remote peer. + * Returns {@code 0} if the limit should not be increased} + * + * @param bidi true if bidirectional stream, false otherwise + */ + public long nextMaxStreamsLimit(final boolean bidi) { + if (isClientConnection() && bidi) return 0; // server does not open bidi streams + return streams.nextMaxStreamsLimit(bidi); + } + + /** + * Called when a stateless reset token is received. + */ + @Override + public void processStatelessReset() { + terminator.incomingStatelessReset(); + } + + /** + * Called to process a received {@link VersionNegotiationPacket} + * @param quicPacket the {@link VersionNegotiationPacket} + * @throws IllegalArgumentException if {@code quicPacket} is not a {@link PacketType#VERSIONS} + * packet + * @throws NullPointerException if {@code quicPacket} is null + */ + protected void processVersionNegotiationPacket(final QuicPacket quicPacket) { + Objects.requireNonNull(quicPacket); + if (quicPacket.packetType() != PacketType.VERSIONS) { + throw new IllegalArgumentException("Not a VERSIONS packet type: " + quicPacket.packetType()); + } + // servers aren't expected to receive version negotiation packet + if (!this.isClientConnection()) { + if (debug.on()) { + debug.log("(server) ignoring version negotiation packet"); + } + return; + } + try { + final var handshakeCF = this.handshakeFlow.handshakeCF(); + // we must ignore version negotiation if we already had a successful exchange + var versionCompatible = this.versionCompatible; + if (versionCompatible || handshakeCF.isDone()) { + if (debug.on()) { + debug.log("ignoring version negotiation packet (neg: %s, state: %s, hs: %s)", + versionCompatible, stateHandle, handshakeCF); + } + return; + } + // we shouldn't receive unsolicited version negotiation packets + assert stateHandle.helloSent(); + if (!(quicPacket instanceof VersionNegotiationPacket negotiate)) { + if (debug.on()) { + debug.log("Bad packet type %s for %s", + quicPacket.getClass().getName(), quicPacket); + } + return; + } + if (!negotiate.sourceId().equals(originalServerConnId())) { + if (debug.on()) { + debug.log("Received version negotiation packet with wrong connection id"); + debug.log("expected source id: %s, received source id: %s", + originalServerConnId(), negotiate.sourceId()); + debug.log("ignoring version negotiation packet (wrong id)"); + } + return; + } + final int[] serverSupportedVersions = negotiate.supportedVersions(); + if (debug.on()) { + debug.log("Received version negotiation packet with supported=%s", + Arrays.toString(serverSupportedVersions)); + } + assert this.quicInstance() instanceof QuicClient : "Not a quic client"; + final QuicClient client = (QuicClient) this.quicInstance(); + QuicVersion negotiatedVersion = null; + for (final int v : serverSupportedVersions) { + final QuicVersion serverVersion = QuicVersion.of(v).orElse(null); + if (serverVersion == null) { + if (debug.on()) { + debug.log("Ignoring unrecognized server supported version %d", v); + } + continue; + } + if (serverVersion == this.quicVersion) { + // RFC-9000, section 6.2: + // A client MUST discard a Version Negotiation packet that lists + // the QUIC version selected by the client. + if (debug.on()) { + debug.log("ignoring version negotiation packet since the version" + + " %d matches the current quic version selected by the client", v); + } + return; + } + // check if the current quic client is enabled for this version + if (!client.isVersionAvailable(serverVersion)) { + if (debug.on()) { + debug.log("Ignoring server supported version %d because the " + + "client isn't enabled for it", v); + } + continue; + } + if (debug.on()) { + if (negotiatedVersion == null) { + debug.log("Accepting server supported version %d", + serverVersion.versionNumber()); + negotiatedVersion = serverVersion; + } else { + // currently all versions are equal + debug.log("Skipping server supported version %d", + serverVersion.versionNumber()); + } + } + } + // at this point if negotiatedVersion is null, then it implies that none of the server + // supported versions are supported by the client. The spec expects us to abandon the + // current connection attempt in such cases (RFC-9000, section 6.2) + if (negotiatedVersion == null) { + final String msg = "No support for any of the QUIC versions being negotiated: " + + Arrays.toString(serverSupportedVersions); + if (debug.on()) { + debug.log("No version could be negotiated: %s", msg); + } + terminator.terminate(forException(new IOException(msg))); + return; + } + // a different version than the current client chosen version has been negotiated, + // switch the client connection to use this negotiated version + ReentrantLock tl = packetSpaces.initial.getTransmitLock(); + tl.lock(); + try { + if (switchVersion(negotiatedVersion)) { + final ByteBuffer quicInitialParameters = buildInitialParameters(); + quicTLSEngine.setLocalQuicTransportParameters(quicInitialParameters); + quicTLSEngine.restartHandshake(); + handshakeFlow.localInitial.reset(); + continueHandshake(); + packetSpaces.initial.runTransmitter(); + this.versionCompatible = true; + processedVersionsPacket = true; + } + } finally { + tl.unlock(); + } + } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to handle packet", t); + } + } + + } + + /** + * Switch to a new version after receiving a version negotiation + * packet. This method checks that no version was previously + * negotiated, in which case it switches the connection to the + * new version and returns true. + * Otherwise, it returns false. + * + * @param negotiated the new version that was negotiated + * @return true if switching to the new version was successful + */ + protected boolean switchVersion(QuicVersion negotiated) { + try { + assert !versionNegotiated; + if (debug.on()) + debug.log("switch to negotiated version %s", negotiated); + this.quicVersion = negotiated; + this.decoder = QuicPacketDecoder.of(negotiated); + this.encoder = QuicPacketEncoder.of(negotiated); + this.packetSpace(PacketNumberSpace.INITIAL).versionChanged(); + // regenerate the INITIAL keys using the new negotiated Quic version + this.quicTLSEngine.deriveInitialKeys(negotiated, originalServerConnId().asReadOnlyBuffer()); + return true; + } catch (Throwable t) { + terminator.terminate(forException(t)); + throw new RuntimeException("failed to switch to version", t); + } + } + + /** + * Mark the version as negotiated. No further version changes are possible. + * + * @param packetVersion the packet version + */ + protected void markVersionNegotiated(int packetVersion) { + int version = this.quicVersion.versionNumber(); + assert packetVersion == version; + if (!versionNegotiated) { + if (VERSION_NEGOTIATED.compareAndSet(this, false, true)) { + // negotiated version finalized + quicTLSEngine.versionNegotiated(QuicVersion.of(version).get()); + } + } + } + + /** + * {@return a boolean value telling whether the datagram in the + * protection record is complete} + * The datagram is complete when no other packet need to be coalesced + * in the datagram. + * If a datagram is complete, it is ready to be sent. + * + * @param protectionRecord the protection record + */ + protected boolean isDatagramComplete(ProtectionRecord protectionRecord) { + return protectionRecord.datagram.remaining() == 0 + || protectionRecord.flags == ProtectionRecord.SINGLE_PACKET + || (protectionRecord.flags & ProtectionRecord.LAST_PACKET) != 0 + || (protectionRecord.flags & ProtectionRecord.COALESCED) == 0; + } + + /** + * {@return the peer address that should be used when sending datagram + * to the peer} + */ + public InetSocketAddress peerAddress() { + return peerAddress; + } + + /** + * {@return the local address of the quic endpoint} + * @throws UncheckedIOException if the address is not available + */ + public SocketAddress localAddress() { + try { + var endpoint = this.endpoint; + if (endpoint == null) { + throw new IOException("no endpoint defined"); + } + return endpoint.getLocalAddress(); + } catch (IOException io) { + throw new UncheckedIOException(io); + } + } + + /** + * Pushes the {@linkplain ProtectionRecord#datagram() datagram} contained in + * the {@code protectionRecord}, through the {@linkplain QuicEndpoint endpoint}. + * + * @param protectionRecord the ProtectionRecord containing the datagram + */ + private void pushEncryptedDatagram(final ProtectionRecord protectionRecord) { + final long packetNumber = protectionRecord.packet().packetNumber(); + assert packetNumber >= 0 : "unexpected packet number: " + packetNumber; + final long retransmittedPacketNumber = protectionRecord.retransmittedPacketNumber(); + assert packetNumber > retransmittedPacketNumber : "packet number: " + packetNumber + + " was expected to be greater than packet the packet being retransmitted: " + + retransmittedPacketNumber; + final boolean pktContainsConnClose = containsConnectionClose(protectionRecord.packet()); + // if the connection isn't open then except for the packet containing a CONNECTION_CLOSE + // frame, we don't push any other packets. + if (!isOpen() && !pktContainsConnClose) { + if (debug.on()) { + debug.log("connection isn't open - ignoring %s(pn:%s): frames:%s", + protectionRecord.packet.packetType(), + protectionRecord.packet.packetNumber(), + protectionRecord.packet.frames()); + } + datagramDropped(new QuicDatagram(this, peerAddress, protectionRecord.datagram)); + return; + } + // TODO: revisit this: we need to figure out how best to emit coalesced packet, + // and having one protection record per packet may not be the the best. + // Maybe a protection record should have a list of coalesced packets + // instead of a single packet? + final ByteBuffer datagram = protectionRecord.datagram(); + final int firstPacketOffset = protectionRecord.firstPacketOffset(); + // flip the datagram + datagram.limit(datagram.position()); + datagram.position(firstPacketOffset); + if (debug.on()) { + final PacketType packetType = protectionRecord.packet().packetType(); + final int packetOffset = protectionRecord.packetOffset(); + if (packetOffset == firstPacketOffset) { + debug.log("Pushing datagram([%s(%d)], %d)", packetType, packetNumber, + datagram.remaining()); + } else { + debug.log("Pushing coalesced datagram([%s(%d)], %d)", + packetType, packetNumber, datagram.remaining()); + } + } + + // upon successful sending of the datagram, notify that the packet was sent + // we call packetSent just before sending the packet here, to make sure + // that the PendingAcknowledgement will be present in the queue before + // we receive the ACK frame from the server. Not doing this would create + // a race where the peer might be able to send the ack, and we might process + // it, before the PendingAcknowledgement is added. + final QuicPacket packet = protectionRecord.packet(); + final PacketSpace packetSpace = packetSpace(packet.numberSpace()); + packetSpace.packetSent(packet, retransmittedPacketNumber, packetNumber); + + // if we are sending a packet containing a CONNECTION_CLOSE frame, then we + // also switch/remove the current connection instance in the endpoint. + if (pktContainsConnClose) { + if (stateHandle.isMarked(QuicConnectionState.DRAINING)) { + // a CONNECTION_CLOSE frame is being sent to the peer when the local + // connection state is in DRAINING. This implies that the local endpoint + // is responding to an incoming CONNECTION_CLOSE frame from the peer. + // we remove the connection from the endpoint for such cases. + endpoint.pushClosedDatagram(this, peerAddress(), datagram); + } else if (stateHandle.isMarked(QuicConnectionState.CLOSING)) { + // a CONNECTION_CLOSE frame is being sent to the peer when the local + // connection state is in CLOSING. For such cases, we switch this + // connection in the endpoint to one which responds with + // CONNECTION_CLOSE frame for any subsequent incoming packets + // from the peer. + endpoint.pushClosingDatagram(this, peerAddress(), datagram); + } else { + // should not happen + throw new IllegalStateException("connection is neither draining nor closing," + + " cannot send a connection close frame"); + } + } else { + pushDatagram(peerAddress(), datagram); + } + // RFC-9000, section 10.1: An endpoint also restarts its idle timer when sending + // an ack-eliciting packet ... + if (packet.isAckEliciting()) { + this.terminator.keepAlive(); + } + } + + /** + * Calls the {@link QuicEndpoint#pushDatagram(QuicPacketReceiver, SocketAddress, ByteBuffer)} + * + * @param destination The destination of this datagram + * @param datagram The datagram + */ + protected void pushDatagram(final SocketAddress destination, final ByteBuffer datagram) { + endpoint.pushDatagram(this, destination, datagram); + } + + /** + * Called when a datagram scheduled for writing by this connection + * could not be written to the network. + * @param t the error that occurred + */ + @Override + public void onWriteError(Throwable t) { + // log exception if still opened + if (stateHandle.opened()) { + if (Log.errors()) { + Log.logError("%s: Failed to write datagram: %s", dbgTag(), t ); + Log.logError(t); + } else if (debug.on()) { + debug.log("Failed to write datagram", t); + } + } + } + + /** + * Called when a packet couldn't be processed + * @param t the error that occurred + */ + public void onProcessingError(QuicPacket packet, Throwable t) { + terminator.terminate(TerminationCause.forException(t)); + } + + /** + * Starts the Quic Handshake. + * @return A completable future which will be completed when the + * handshake is completed. + * @throws UnsupportedOperationException If this connection isn't a client connection + */ + public final CompletableFuture startHandshake() { + if (!isClientConnection()) { + throw new UnsupportedOperationException("Not a client connection, cannot start handshake"); + } + if (!this.startHandshakeCalled.compareAndSet(false, true)) { + throw new IllegalStateException("handshake has already been started on connection"); + } + if (this.peerAddress.isUnresolved()) { + // fail if address is unresolved + return MinimalFuture.failedFuture( + Utils.toConnectException(new UnresolvedAddressException())); + } + CompletableFuture cf; + try { + // register the connection with an endpoint + assert this.quicInstance instanceof QuicClient : "Not a QuicClient"; + endpoint.registerNewConnection(this); + cf = MinimalFuture.completedFuture(null); + } catch (Throwable t) { + cf = MinimalFuture.failedFuture(t); + } + return cf.thenApply(this::sendFirstInitialPacket) + .exceptionally((t) -> { + // complete the handshake CFs with the failure + handshakeFlow.failHandshakeCFs(t); + return handshakeFlow; + }) + .thenCompose(HandshakeFlow::handshakeCF) + .thenApply(this::onHandshakeCompletion); + } + + /** + * This method is called when the handshake is successfully completed. + * @param result the result of the handshake + */ + protected QuicEndpoint onHandshakeCompletion(final HandshakeState result) { + if (debug.on()) { + debug.log("Quic handshake successfully completed with %s(%s)", + quicTLSEngine.getApplicationProtocol(), peerAddress()); + } + // now that the handshake has successfully completed, start the + // idle timeout management for this connection + this.idleTimeoutManager.start(); + return this.endpoint; + } + + protected HandshakeFlow handshakeFlow() { + return handshakeFlow; + } + + protected void startInitialTimer() { + if (!isClientConnection()) return; + MaxInitialTimer initialTimer = maxInitialTimer; + if (initialTimer == null && DEFAULT_MAX_INITIAL_TIMEOUT < Integer.MAX_VALUE) { + Deadline maxInitialDeadline = null; + synchronized (this) { + initialTimer = maxInitialTimer; + if (initialTimer == null) { + Deadline now = this.endpoint().timeSource().instant(); + maxInitialDeadline = now.plusSeconds(DEFAULT_MAX_INITIAL_TIMEOUT); + initialTimer = maxInitialTimer = new MaxInitialTimer(this.endpoint().timer(), maxInitialDeadline); + } + } + if (maxInitialDeadline != null) { + if (Log.quic()) { + Log.logQuic("{0}: Arming quic initial timer for {1}", logTag(), + Deadline.between(this.endpoint().timeSource().instant(), maxInitialDeadline)); + } + if (debug.on()) { + debug.log("Arming quic initial timer for %s seconds", + Deadline.between(this.endpoint().timeSource().instant(), maxInitialDeadline).toSeconds()); + } + initialTimer.timerQueue.reschedule(initialTimer, maxInitialDeadline); + } + } + } + + // adaptation to Function + private HandshakeFlow sendFirstInitialPacket(Void unused) { + // may happen if connection cancelled before endpoint is + // created + final TerminationCause tc = terminationCause(); + if (tc != null) { + throw new CompletionException(tc.getCloseCause()); + } + try { + startInitialTimer(); + if (Log.quic()) { + Log.logQuic(logTag() + ": connectionId: " + + connectionId.toHexString() + + ", " + endpoint + ": " + endpoint.name() + + " - " + endpoint.getLocalAddressString()); + } else if (debug.on()) { + debug.log(logTag() + ": connectionId: " + + connectionId.toHexString() + + ", " + endpoint + ": " + endpoint.name() + + " - " + endpoint.getLocalAddressString()); + } + var localAddress = endpoint.getLocalAddress(); + var conflict = Utils.addressConflict(localAddress, peerAddress); + if (conflict != null) { + String msg = conflict; + if (debug.on()) { + debug.log("%s (local: %s, remote: %s)", msg, localAddress, peerAddress); + } + Log.logError("{0} {1} (local: {2}, remote: {3})", logTag(), + msg, localAddress, peerAddress); + throw new SSLHandshakeException(msg); + } + final QuicConnectionId clientSelectedPeerId = initialServerConnectionId(); + this.peerConnIdManager.originalServerConnId(clientSelectedPeerId); + handshakeFlow.markHandshakeStart(); + stateHandle.markHelloSent(); + // the "original version" used to establish the connection + final QuicVersion originalVersion = this.quicVersion; + quicTLSEngine.deriveInitialKeys(originalVersion, clientSelectedPeerId.asReadOnlyBuffer()); + final ByteBuffer quicInitialParameters = buildInitialParameters(); + quicTLSEngine.setLocalQuicTransportParameters(quicInitialParameters); + handshakeFlow.localInitial.keepReplayData(); + continueHandshake(); + packetSpaces.initial.runTransmitter(); + } catch (Throwable t) { + terminator.terminate(forException(t)); + throw new CompletionException(terminationCause().getCloseCause()); + } + return handshakeFlow; + } + + private static final Random RANDOM = new SecureRandom(); + + private QuicConnectionId initialServerConnectionId() { + byte[] bytes = new byte[INITIAL_SERVER_CONNECTION_ID_LENGTH]; + RANDOM.nextBytes(bytes); + return new PeerConnectionId(bytes); + } + + /** + * Compose a list of Quic frames containing a crypto frame and an ack frame, + * omitting null frames. + * @param crypto the crypto frame + * @param ack the ack frame + * @return A list of {@link QuicFrame}. + */ + private List makeList(CryptoFrame crypto, AckFrame ack) { + List frames = new ArrayList<>(2); + if (crypto != null) { + frames.add(crypto); + } + if (ack != null) { + frames.add(ack); + } + return frames; + } + + /** + * Allocate a {@link ByteBuffer} that can be used to encrypt the + * given packet. + * @param packet the packet to encrypt + * @return a new {@link ByteBuffer} with sufficient space to encrypt + * the given packet. + */ + protected ByteBuffer allocateDatagramForEncryption(QuicPacket packet) { + int size = packet.size(); + if (packet.hasLength()) { // packet can be coalesced + size = Math.max(size, getMaxDatagramSize()); + } + if (size > getMaxDatagramSize()) { + + if (Log.errors()) { + var error = new AssertionError("%s: Size too big: %s > %s".formatted( + logTag(), + size, getMaxDatagramSize())); + Log.logError(logTag() + ": Packet too big: " + packet.prettyPrint()); + Log.logError(error); + } else if (debug.on()) { + var error = new AssertionError("%s: Size too big: %s > %s".formatted( + logTag(), + size, getMaxDatagramSize())); + debug.log("Packet too big: " + packet.prettyPrint()); + debug.log(error); + } + // Revisit: if we implement Path MTU detection, then the max datagram size + // may evolve, increasing or decreasing as the path change. + // In which case - we may want to tune this, down and only + // log an error or warning? + final String errMsg = "Failed to encode packet, too big: " + size; + terminator.terminate(forTransportError(PROTOCOL_VIOLATION).loggedAs(errMsg)); + throw new UncheckedIOException(terminator.getTerminationCause().getCloseCause()); + } + return getOutgoingByteBuffer(size); + } + + /** + * {@return the maximum datagram size that can be used on the + * connection path} + * @implSpec + * Initially this is {@link #DEFAULT_DATAGRAM_SIZE}, but the + * value will then be decided if the peer sends a specific size + * in the transport parameters and the value can further evolve based + * on path MTU. + */ + // this is public for use in tests + public int getMaxDatagramSize() { + // TODO: we should implement path MTU detection, or maybe let + // this be configurable. Sizes of 32256 or 64512 seem to + // be giving much better throughput when downloading. + // large files + return Math.min(maxPeerAdvertisedPayloadSize, pathMTU); + } + + /** + * Retrieves cryptographic messages from TLS engine, enqueues them for sending + * and starts the transmitter. + */ + protected void continueHandshake() { + handshakeScheduler.runOrSchedule(); + } + + protected void continueHandshake0() { + try { + continueHandshake1(); + } catch (Throwable t) { + var flow = handshakeFlow; + flow.handshakeReachedPeerCF.completeExceptionally(t); + flow.handshakeCF.completeExceptionally(t); + } + } + + private void continueHandshake1() throws IOException { + HandshakeFlow flow = handshakeFlow; + // make sure the localInitialQueue is not modified concurrently + // while we are in this loop + boolean handshakeDataAvailable = false; + boolean initialDataAvailable = false; + for (;;) { + var handshakeState = quicTLSEngine.getHandshakeState(); + if (debug.on()) { + debug.log("continueHandshake: state: %s", handshakeState); + } + if (handshakeState == QuicTLSEngine.HandshakeState.NEED_SEND_CRYPTO) { + // buffer next TLS message + KeySpace keySpace = quicTLSEngine.getCurrentSendKeySpace(); + ByteBuffer payloadBuffer; + handshakeLock.lock(); + try { + payloadBuffer = quicTLSEngine.getHandshakeBytes(keySpace); + assert payloadBuffer != null; + assert payloadBuffer.hasRemaining(); + if (keySpace == KeySpace.INITIAL) { + flow.localInitial.enqueue(payloadBuffer); + initialDataAvailable = true; + } else if (keySpace == KeySpace.HANDSHAKE) { + flow.localHandshake.enqueue(payloadBuffer); + handshakeDataAvailable = true; + } + } finally { + handshakeLock.unlock(); + } + + assert payloadBuffer != null; + if (debug.on()) { + debug.log("continueHandshake: buffered %s bytes in %s keyspace", + payloadBuffer.remaining(), keySpace); + } + } else if (handshakeState == QuicTLSEngine.HandshakeState.NEED_TASK) { + quicTLSEngine.getDelegatedTask().run(); + } else { + if (debug.on()) { + debug.log("continueHandshake: nothing to do (state: %s)", handshakeState); + } + if (initialDataAvailable) { + packetSpaces.initial.runTransmitter(); + } + if (handshakeDataAvailable && flow.localInitial.remaining() == 0) { + packetSpaces.handshake.runTransmitter(); + } + return; + } + } + } + + private boolean sendData(PacketNumberSpace packetNumberSpace) + throws QuicKeyUnavailableException, QuicTransportException { + if (packetNumberSpace != PacketNumberSpace.APPLICATION) { + // This method can be called by two packet spaces: INITIAL and HANDSHAKE. + // We need to lock to make sure that the method is not run concurrently. + handshakeLock.lock(); + try { + return sendInitialOrHandshakeData(packetNumberSpace); + } finally { + handshakeLock.unlock(); + } + } else { + return oneRttSndQueue.send1RTTData(); + } + } + + private boolean sendInitialOrHandshakeData(final PacketNumberSpace packetNumberSpace) + throws QuicKeyUnavailableException, QuicTransportException { + if (Log.quicCrypto()) { + Log.logQuic(String.format("%s: Send %s data", logTag(), packetNumberSpace)); + } + final HandshakeFlow flow = handshakeFlow; + final QuicConnectionId peerConnId = peerConnectionId(); + if (packetNumberSpace == PacketNumberSpace.INITIAL && flow.localInitial.remaining() > 0) { + // process buffered initial data + byte[] token = initialToken(); + int tksize = token == null ? 0 : token.length; + PacketSpace packetSpace = packetSpaces.get(PacketNumberSpace.INITIAL); + int maxDstIdLength = isClientConnection() ? + MAX_CONNECTION_ID_LENGTH : // reserve space for the id to grow + peerConnId.length(); + int maxSrcIdLength = connectionId.length(); + // compute maxPayloadSize given maxSizeBeforeEncryption + var largestAckedPN = packetSpace.getLargestPeerAckedPN(); + var packetNumber = packetSpace.allocateNextPN(); + int maxPayloadSize = QuicPacketEncoder.computeMaxInitialPayloadSize(codingContext, 4, tksize, + maxSrcIdLength, maxDstIdLength, SMALLEST_MAXIMUM_DATAGRAM_SIZE); + // compute how many bytes were reserved to allow smooth retransmission + // of packets + int reserved = QuicPacketEncoder.computeMaxInitialPayloadSize(codingContext, + computePacketNumberLength(packetNumber, + codingContext.largestAckedPN(PacketNumberSpace.INITIAL)), + tksize, connectionId.length(), peerConnId.length(), + SMALLEST_MAXIMUM_DATAGRAM_SIZE) - maxPayloadSize; + assert reserved >= 0 : "reserved is negative: " + reserved; + if (debug.on()) { + debug.log("reserved %s byte in initial packet", reserved); + } + if (maxPayloadSize < 5) { + // token too long, can't fit a crypto frame in this packet. Abort. + final String msg = "Initial token too large, maxPayload: " + maxPayloadSize; + terminator.terminate(TerminationCause.forException(new IOException(msg))); + return false; + } + AckFrame ackFrame = packetSpace.getNextAckFrame(false, maxPayloadSize); + int ackSize = ackFrame == null ? 0 : ackFrame.size(); + if (debug.on()) { + debug.log("ack frame size: %d", ackSize); + } + + CryptoFrame crypto = flow.localInitial.produceFrame(maxPayloadSize - ackSize); + int cryptoSize = crypto == null ? 0 : crypto.size(); + assert cryptoSize <= maxPayloadSize : cryptoSize - maxPayloadSize; + List frames = makeList(crypto, ackFrame); + + if (debug.on()) { + debug.log("building initial packet: source=%s, dest=%s", + connectionId, peerConnId); + } + OutgoingQuicPacket packet = encoder.newInitialPacket( + connectionId, peerConnId, token, + packetNumber, largestAckedPN, frames, codingContext); + int size = packet.size(); + if (debug.on()) { + debug.log("initial packet size is %d, max is %d", + size, SMALLEST_MAXIMUM_DATAGRAM_SIZE); + } + assert size == SMALLEST_MAXIMUM_DATAGRAM_SIZE : size - SMALLEST_MAXIMUM_DATAGRAM_SIZE; + + stateHandle.markHelloSent(); + if (debug.on()) { + debug.log("protecting initial quic hello packet for %s(%s) - %d bytes", + Arrays.toString(quicTLSEngine.getSSLParameters().getApplicationProtocols()), + peerAddress(), packet.size()); + } + pushDatagram(ProtectionRecord.single(packet, this::allocateDatagramForEncryption)); + if (flow.localHandshake.remaining() > 0) { + if (Log.quicCrypto()) { + Log.logQuic(String.format("%s: local handshake has remaining, starting HANDSHAKE transmitter", logTag())); + } + packetSpaces.handshake.runTransmitter(); + } + } else if (packetNumberSpace == PacketNumberSpace.HANDSHAKE && flow.localHandshake.remaining() > 0) { + // process buffered handshake data + PacketSpace packetSpace = packetSpaces.get(PacketNumberSpace.HANDSHAKE); + AckFrame ackFrame = packetSpace.getNextAckFrame(false); + int ackSize = ackFrame == null ? 0 : ackFrame.size(); + if (debug.on()) { + debug.log("ack frame size: %d", ackSize); + } + // compute maxPayloadSize given maxSizeBeforeEncryption + var largestAckedPN = packetSpace.getLargestPeerAckedPN(); + var packetNumber = packetSpace.allocateNextPN(); + int maxPayloadSize = QuicPacketEncoder.computeMaxHandshakePayloadSize(codingContext, + packetNumber, connectionId.length(), peerConnId.length(), + SMALLEST_MAXIMUM_DATAGRAM_SIZE); + maxPayloadSize = maxPayloadSize - ackSize; + + final CryptoFrame crypto = flow.localHandshake.produceFrame(maxPayloadSize); + assert crypto != null : "Handshake data was available (" + + flow.localHandshake.remaining() + " bytes) for sending, but no CRYPTO" + + " frame was produced, for max frame size: " + maxPayloadSize; + int cryptoSize = crypto.size(); + assert cryptoSize <= maxPayloadSize : cryptoSize - maxPayloadSize; + List frames = makeList(crypto, ackFrame); + + if (debug.on()) { + debug.log("building handshake packet: source=%s, dest=%s", + connectionId, peerConnId); + } + OutgoingQuicPacket packet = encoder.newHandshakePacket( + connectionId, peerConnId, + packetNumber, largestAckedPN, frames, codingContext); + int size = packet.size(); + if (debug.on()) { + debug.log("handshake packet size is %d, max is %d", + size, SMALLEST_MAXIMUM_DATAGRAM_SIZE); + } + assert size <= SMALLEST_MAXIMUM_DATAGRAM_SIZE : size - SMALLEST_MAXIMUM_DATAGRAM_SIZE; + + if (debug.on()) { + debug.log("protecting handshake quic hello packet for %s(%s) - %d bytes", + Arrays.toString(quicTLSEngine.getSSLParameters().getApplicationProtocols()), + peerAddress(), packet.size()); + } + pushDatagram(ProtectionRecord.single(packet, this::allocateDatagramForEncryption)); + var handshakeState = quicTLSEngine.getHandshakeState(); + if (debug.on()) { + debug.log("Handshake state is now: %s", handshakeState); + } + if (flow.localHandshake.remaining() == 0 + && quicTLSEngine.isTLSHandshakeComplete() + && !flow.handshakeCF.isDone()) { + if (stateHandle.markHandshakeComplete()) { + if (debug.on()) { + debug.log("Handshake completed"); + } + completeHandshakeCF(); + } + } + if (!packetSpaces.initial().isClosed() && flow.localInitial.remaining() > 0) { + if (Log.quicCrypto()) { + Log.logQuic(String.format("%s: local initial has remaining, starting INITIAL transmitter", logTag())); + } + packetSpaces.initial.runTransmitter(); + } + } else { + return false; + } + return true; + } + + public QuicTransportParameters peerTransportParameters() { + return peerTransportParameters; + } + + public QuicTransportParameters localTransportParameters() { + return localTransportParameters; + } + + protected void consumeQuicParameters(final ByteBuffer byteBuffer) throws QuicTransportException { + final QuicTransportParameters params = QuicTransportParameters.decode(byteBuffer); + if (debug.on()) { + debug.log("Received peer Quic transport params: %s", params.toStringWithValues()); + } + final QuicConnectionId retryConnId = this.peerConnIdManager.retryConnId(); + if (params.isPresent(ParameterId.retry_source_connection_id)) { + if (retryConnId == null) { + throw new QuicTransportException("Retry connection ID was set even though no retry was performed", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } else if (!params.matches(ParameterId.retry_source_connection_id, retryConnId)) { + throw new QuicTransportException("Retry connection ID does not match", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + } else { + if (retryConnId != null) { + throw new QuicTransportException("Retry connection ID was expected but absent", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + } + if (!params.isPresent(ParameterId.original_destination_connection_id)) { + throw new QuicTransportException( + "Original connection ID transport parameter missing", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + + } + if (!params.isPresent(initial_source_connection_id)) { + throw new QuicTransportException( + "Initial source connection ID transport parameter missing", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + + } + final QuicConnectionId clientSelectedPeerConnId = this.peerConnIdManager.originalServerConnId(); + if (!params.matches(ParameterId.original_destination_connection_id, clientSelectedPeerConnId)) { + throw new QuicTransportException( + "Original connection ID does not match", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (!params.matches(initial_source_connection_id, incomingInitialPacketSourceId)) { + throw new QuicTransportException( + "Initial source connection ID does not match", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + // RFC-9000, section 18.2: A server that chooses a zero-length connection ID MUST NOT + // provide a preferred address. + if (peerConnectionId().length() == 0 && + params.isPresent(ParameterId.preferred_address)) { + throw new QuicTransportException( + "Preferred address present but connection ID has zero length", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (params.isPresent(active_connection_id_limit)) { + final long limit = params.getIntParameter(active_connection_id_limit); + if (limit < 2) { + throw new QuicTransportException( + "Invalid active_connection_id_limit " + limit, + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + } + if (params.isPresent(ParameterId.stateless_reset_token)) { + final byte[] statelessResetToken = params.getParameter(ParameterId.stateless_reset_token); + if (statelessResetToken.length != RESET_TOKEN_LENGTH) { + // RFC states 16 bytes for stateless token + throw new QuicTransportException( + "Invalid stateless reset token length " + statelessResetToken.length, + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + } + VersionInformation vi = + params.getVersionInformationParameter(version_information); + if (vi != null) { + if (vi.chosenVersion() != quicVersion().versionNumber()) { + throw new QuicTransportException( + "[version_information] Chosen Version does not match version in use", + null, 0, QuicTransportErrors.VERSION_NEGOTIATION_ERROR); + } + if (processedVersionsPacket) { + if (vi.availableVersions().length == 0) { + throw new QuicTransportException( + "[version_information] available versions empty", + null, 0, QuicTransportErrors.VERSION_NEGOTIATION_ERROR); + } + if (Arrays.stream(vi.availableVersions()) + .anyMatch(i -> i == originalVersion.versionNumber())) { + throw new QuicTransportException( + "[version_information] original version was available", + null, 0, QuicTransportErrors.VERSION_NEGOTIATION_ERROR); + } + } + } else { + if (processedVersionsPacket && quicVersion != QuicVersion.QUIC_V1) { + throw new QuicTransportException( + "version_information parameter absent", + null, 0, QuicTransportErrors.VERSION_NEGOTIATION_ERROR); + } + } + handleIncomingPeerTransportParams(params); + + // params.setIntParameter(ParameterId.max_idle_timeout, TimeUnit.SECONDS.toMillis(30)); + // params.setParameter(ParameterId.stateless_reset_token, ...); // no token + // params.setIntParameter(ParameterId.initial_max_data, DEFAULT_INITIAL_MAX_DATA); + // params.setIntParameter(ParameterId.initial_max_stream_data_bidi_local, DEFAULT_INITIAL_STREAM_MAX_DATA); + // params.setIntParameter(ParameterId.initial_max_stream_data_bidi_remote, DEFAULT_INITIAL_STREAM_MAX_DATA); + // params.setIntParameter(ParameterId.initial_max_stream_data_uni, DEFAULT_INITIAL_STREAM_MAX_DATA); + // params.setIntParameter(ParameterId.initial_max_streams_bidi, DEFAULT_MAX_STREAMS); + // params.setIntParameter(ParameterId.initial_max_streams_uni, DEFAULT_MAX_STREAMS); + // params.setIntParameter(ParameterId.ack_delay_exponent, 3); // unit 2^3 microseconds + // params.setIntParameter(ParameterId.max_ack_delay, 25); //25 millis + // params.setBooleanParameter(ParameterId.disable_active_migration, false); + // params.setPreferredAddressParameter(ParameterId.preferred_address, ...); + // params.setIntParameter(ParameterId.active_connection_id_limit, 2); + } + + /** + * {@return the number of (active) connection ids that this endpoint is willing + * to accept from the peer for a given connection} + */ + protected long getLocalActiveConnIDLimit() { + // currently we don't accept anything more than 2 (the RFC defined default minimum) + return 2; + } + + /** + * {@return the number of (active) connection ids that the peer is willing to accept + * for a given connection} + */ + protected long getPeerActiveConnIDLimit() { + return this.peerActiveConnIdsLimit; + } + + protected ByteBuffer buildInitialParameters() { + final QuicTransportParameters params = new QuicTransportParameters(this.transportParams); + setIntParamIfNotSet(params, active_connection_id_limit, this::getLocalActiveConnIDLimit); + final long idleTimeoutMillis = TimeUnit.SECONDS.toMillis( + Utils.getLongProperty("jdk.httpclient.quic.idleTimeout", 30)); + setIntParamIfNotSet(params, max_idle_timeout, () -> idleTimeoutMillis); + setIntParamIfNotSet(params, max_udp_payload_size, () -> { + assert this.endpoint != null : "Endpoint hasn't been set"; + return (long) this.endpoint.getMaxUdpPayloadSize(); + }); + setIntParamIfNotSet(params, initial_max_data, () -> DEFAULT_INITIAL_MAX_DATA); + setIntParamIfNotSet(params, initial_max_stream_data_bidi_local, + () -> DEFAULT_INITIAL_STREAM_MAX_DATA); + setIntParamIfNotSet(params, initial_max_stream_data_uni, () -> DEFAULT_INITIAL_STREAM_MAX_DATA); + setIntParamIfNotSet(params, initial_max_stream_data_bidi_remote, () -> DEFAULT_INITIAL_STREAM_MAX_DATA); + setIntParamIfNotSet(params, initial_max_streams_uni, () -> (long) DEFAULT_MAX_UNI_STREAMS); + setIntParamIfNotSet(params, initial_max_streams_bidi, () -> (long) DEFAULT_MAX_BIDI_STREAMS); + if (!params.isPresent(initial_source_connection_id)) { + params.setParameter(initial_source_connection_id, connectionId.getBytes()); + } + if (!params.isPresent(version_information)) { + final VersionInformation vi = + QuicTransportParameters.buildVersionInformation(quicVersion, + quicInstance().getAvailableVersions()); + params.setVersionInformationParameter(version_information, vi); + } + // params.setIntParameter(ParameterId.ack_delay_exponent, 3); // unit 2^3 microseconds + // params.setIntParameter(ParameterId.max_ack_delay, 25); //25 millis + // params.setBooleanParameter(ParameterId.disable_active_migration, false); + final ByteBuffer buf = ByteBuffer.allocate(params.size()); + params.encode(buf); + buf.flip(); + if (debug.on()) { + debug.log("local transport params: %s", params.toStringWithValues()); + } + newLocalTransportParameters(params); + return buf; + } + + protected static void setIntParamIfNotSet(final QuicTransportParameters params, + final ParameterId paramId, + final Supplier valueSupplier) { + if (params.isPresent(paramId)) { + return; + } + params.setIntParameter(paramId, valueSupplier.get()); + } + + // the token to be included in initial packets, if any. + private byte[] initialToken() { + return initialToken; + } + + protected void newLocalTransportParameters(QuicTransportParameters params) { + localTransportParameters = params; + oneRttRcvQueue.newLocalParameters(params); + streams.newLocalTransportParameters(params); + final long idleTimeout = params.getIntParameter(max_idle_timeout, 0); + this.idleTimeoutManager.localIdleTimeout(idleTimeout); + } + + private List ackOrPing(AckFrame ack, boolean sendPing) { + if (sendPing) { + return ack == null ? List.of(new PingFrame()) : List.of(new PingFrame(), ack); + } + assert ack != null; + return List.of(ack); + } + + /** + * Emit a possibly non ACK-eliciting packet containing the given ACK frame. + * @param packetSpaceManager the packet space manager on behalf + * of which the acknowledgement should + * be sent. + * @param ackFrame the ACK frame to be sent. + * @param sendPing whether a PING frame should be sent. + * @return the emitted packet number, or -1L if not applicable or not emitted + */ + private long emitAckPacket(final PacketSpace packetSpaceManager, final AckFrame ackFrame, + final boolean sendPing) + throws QuicKeyUnavailableException, QuicTransportException { + if (ackFrame == null && !sendPing) { + return -1L; + } + if (debug.on()) { + if (sendPing) { + debug.log("Sending PING packet %s ack", + ackFrame == null ? "without" : "with"); + } else { + debug.log("sending ACK packet"); + } + } + final List frames = ackOrPing(ackFrame, sendPing); + final PacketNumberSpace packetNumberSpace = packetSpaceManager.packetNumberSpace(); + if (debug.on()) { + debug.log("Sending packet for %s, frame=%s", packetNumberSpace, frames); + } + final KeySpace keySpace = switch (packetNumberSpace) { + case APPLICATION -> KeySpace.ONE_RTT; + case HANDSHAKE -> KeySpace.HANDSHAKE; + case INITIAL -> KeySpace.INITIAL; + default -> throw new UnsupportedOperationException( + "Invalid packet number space: " + packetNumberSpace); + }; + if (sendPing && Log.quicRetransmit()) { + Log.logQuic("{0} {1}: sending PingFrame", logTag(), keySpace); + } + final QuicPacket ackpacket = encoder.newOutgoingPacket(keySpace, + packetSpaceManager, localConnectionId(), + peerConnectionId(), initialToken(), frames, codingContext); + pushDatagram(ProtectionRecord.single(ackpacket, this::allocateDatagramForEncryption)); + return ackpacket.packetNumber(); + } + + private LinkedList removeOutdatedFrames(List frames) { + // Remove frames that should not be retransmitted + LinkedList result = new LinkedList<>(); + for (QuicFrame f : frames) { + if (!(f instanceof PaddingFrame) && + !(f instanceof AckFrame) && + !(f instanceof PathChallengeFrame) && + !(f instanceof PathResponseFrame)) { + result.add(f); + } + } + return result; + } + + /** + * Retransmit the given packet on behalf of the given packet space + * manager. + * @param packetSpaceManager the packet space manager on behalf of + * which the packet is being retransmitted + * @param packet the unacknowledged packet which should be retransmitted + */ + private void retransmit(PacketSpace packetSpaceManager, QuicPacket packet, int attempts) + throws QuicKeyUnavailableException, QuicTransportException { + if (debug.on()) { + debug.log("Retransmitting packet [type=%s, pn=%d, attempts:%d]: %s", + packet.packetType(), packet.packetNumber(), attempts, packet); + } + + assert packetSpaceManager.packetNumberSpace() == packet.numberSpace(); + long oldPacketNumber = packet.packetNumber(); + assert oldPacketNumber >= 0; + + long largestAckedPN = packetSpaceManager.getLargestPeerAckedPN(); + long newPacketNumber = packetSpaceManager.allocateNextPN(); + final int maxDatagramSize = getMaxDatagramSize(); + final QuicConnectionId peerConnectionId = peerConnectionId(); + final int dstIdLength = peerConnectionId.length(); + final PacketNumberSpace packetNumberSpace = packetSpaceManager.packetNumberSpace(); + final int initialDstIdLength = MAX_CONNECTION_ID_LENGTH; // reserve space for the ID to grow + + int maxPayloadSize = switch (packetNumberSpace) { + case APPLICATION -> QuicPacketEncoder.computeMaxOneRTTPayloadSize( + codingContext, newPacketNumber, dstIdLength, maxDatagramSize, largestAckedPN); + case INITIAL -> QuicPacketEncoder.computeMaxInitialPayloadSize( + codingContext, computePacketNumberLength(newPacketNumber, + codingContext.largestAckedPN(PacketNumberSpace.INITIAL)), + ((InitialPacket) packet).tokenLength(), + localConnectionId().length(), initialDstIdLength, maxDatagramSize); + case HANDSHAKE -> QuicPacketEncoder.computeMaxHandshakePayloadSize( + codingContext, newPacketNumber, localConnectionId().length(), + dstIdLength, maxDatagramSize); + default -> throw new IllegalArgumentException( + "Invalid packet number space: " + packetNumberSpace); + }; + + // The new packet may have larger size(), which might no longer fit inside + // the maximum datagram size supported on the path. To avoid that, we + // strip the padding and old ack frame from the original packet, and + // include the new ack frame only if it fits in the available size. + LinkedList frames = removeOutdatedFrames(packet.frames()); + int size = frames.stream().mapToInt(QuicFrame::size).sum(); + int remaining = maxPayloadSize - size; + AckFrame ack = packetSpaceManager.getNextAckFrame(false, remaining); + if (ack != null) { + assert ack.size() <= remaining : "AckFrame size %s is bigger than %s" + .formatted(ack.size(), remaining); + frames.addFirst(ack); + } + QuicPacket retransmitted = + switch (packet.packetType()) { + case INITIAL -> encoder.newInitialPacket(localConnectionId(), + peerConnectionId, ((InitialPacket) packet).token(), + newPacketNumber, largestAckedPN, frames, + codingContext); + case HANDSHAKE -> encoder.newHandshakePacket(localConnectionId(), + peerConnectionId, newPacketNumber, largestAckedPN, + frames, codingContext); + case ONERTT, ZERORTT -> encoder.newOneRttPacket( + peerConnectionId, newPacketNumber, largestAckedPN, + frames, codingContext); + default -> throw new IllegalArgumentException("packetType: %s, packet: %s" + .formatted(packet.packetType(), packet.packetNumber())); + }; + + if (Log.quicRetransmit()) { + Log.logQuic("%s OUT: retransmitting packet [%s] pn:%s as pn:%s".formatted( + logTag(), packet.packetType(), oldPacketNumber, newPacketNumber)); + } + pushDatagram(ProtectionRecord.retransmitting(retransmitted, + oldPacketNumber, + this::allocateDatagramForEncryption)); + } + + @Override + public CompletableFuture requestSendPing() { + final KeySpace space = quicTLSEngine.getCurrentSendKeySpace(); + final PacketSpace spaceManager = packetSpaces.get(PacketNumberSpace.of(space)); + return spaceManager.requestSendPing(); + } + + /** + * {@return the underlying {@code NetworkChannel} used by this connection, + * which may be {@code null} if the endpoint has not been configured yet} + */ + public NetworkChannel channel() { + QuicEndpoint endpoint = this.endpoint; + return endpoint == null ? null : endpoint.channel(); + } + + @Override + public String toString() { + return cachedToString; + } + + @Override + public boolean isOpen() { + return stateHandle.opened(); + } + + @Override + public TerminationCause terminationCause() { + return terminator.getTerminationCause(); + } + + public final CompletableFuture futureTerminationCause() { + return terminator.futureTerminationCause(); + } + + @Override + public ConnectionTerminator connectionTerminator() { + return this.terminator; + } + + /** + * {@return true if this connection is a client connection} + * Server side connections will return false. + */ + public boolean isClientConnection() { + return true; + } + + /** + * Called when new quic transport parameters are available from the peer. + * @param params the peer's new quic transport parameter + */ + protected void handleIncomingPeerTransportParams(final QuicTransportParameters params) { + peerTransportParameters = params; + this.idleTimeoutManager.peerIdleTimeout(params.getIntParameter(max_idle_timeout)); + // when we reach here, the value for max_udp_payload_size has already been + // asserted that it isn't outside the allowed range of 1200 to 65527. That has + // happened in QuicTransportParameters.checkParameter(). + // intentional cast to int since the value will be within int range + maxPeerAdvertisedPayloadSize = (int) params.getIntParameter(max_udp_payload_size); + congestionController.updateMaxDatagramSize(getMaxDatagramSize()); + if (params.isPresent(ParameterId.initial_max_data)) { + oneRttSndQueue.setMaxData(params.getIntParameter(ParameterId.initial_max_data), true); + } + streams.newPeerTransportParameters(params); + packetSpaces.app().updatePeerTransportParameters( + params.getIntParameter(ParameterId.max_ack_delay), + params.getIntParameter(ParameterId.ack_delay_exponent)); + // param value for this param is already validated outside of this method, so we just + // set the value without any validations + this.peerActiveConnIdsLimit = params.getIntParameter(active_connection_id_limit); + if (params.isPresent(ParameterId.stateless_reset_token)) { + // the stateless reset token for the handshake connection id + final byte[] statelessResetToken = params.getParameter(ParameterId.stateless_reset_token); + // register with peer connid manager + this.peerConnIdManager.handshakeStatelessResetToken(statelessResetToken); + } + if (params.isPresent(ParameterId.preferred_address)) { + final byte[] val = params.getParameter(ParameterId.preferred_address); + final ByteBuffer preferredConnId = QuicTransportParameters.getPreferredConnectionId(val); + final byte[] preferredStatelessResetToken = QuicTransportParameters + .getPreferredStatelessResetToken(val); + this.peerConnIdManager.handlePreferredAddress(preferredConnId, preferredStatelessResetToken); + } + if (debug.on()) { + debug.log("incoming peer parameters handled"); + } + } + + protected void incomingInitialFrame(final AckFrame frame) throws QuicTransportException { + packetSpaces.initial.processAckFrame(frame); + if (!handshakeFlow.handshakeReachedPeerCF.isDone()) { + if (debug.on()) debug.log("completing handshakeStartedCF normally"); + handshakeFlow.handshakeReachedPeerCF.complete(null); + } + } + + protected int incomingInitialFrame(final CryptoFrame frame) throws QuicTransportException { + // make sure to provide the frames in order, and + // buffer them if at the wrong offset + if (!handshakeFlow.handshakeReachedPeerCF.isDone()) { + if (debug.on()) debug.log("completing handshakeStartedCF normally"); + handshakeFlow.handshakeReachedPeerCF.complete(null); + } + final CryptoDataFlow peerInitial = handshakeFlow.peerInitial; + final long buffer = frame.offset() + frame.length() - peerInitial.offset(); + if (buffer > MAX_INCOMING_CRYPTO_CAPACITY) { + throw new QuicTransportException( + "Crypto buffer exceeded, required: " + buffer, + KeySpace.INITIAL, frame.frameType(), + QuicTransportErrors.CRYPTO_BUFFER_EXCEEDED); + } + int provided = 0; + CryptoFrame nextFrame = peerInitial.receive(frame); + while (nextFrame != null) { + if (debug.on()) { + debug.log("Provide crypto frame to engine: %s", nextFrame); + } + quicTLSEngine.consumeHandshakeBytes(KeySpace.INITIAL, nextFrame.payload()); + provided += nextFrame.length(); + nextFrame = peerInitial.poll(); + if (debug.on()) { + debug.log("Provided: " + provided); + } + } + return provided; + } + + protected void incomingInitialFrame(final PaddingFrame frame) throws QuicTransportException { + // nothing to do + } + + protected void incomingInitialFrame(final PingFrame frame) throws QuicTransportException { + // nothing to do + } + + protected void incomingInitialFrame(final ConnectionCloseFrame frame) + throws QuicTransportException { + terminator.incomingConnectionCloseFrame(frame); + } + + protected void incomingHandshakeFrame(final AckFrame frame) throws QuicTransportException { + packetSpaces.handshake.processAckFrame(frame); + } + + protected int incomingHandshakeFrame(final CryptoFrame frame) throws QuicTransportException { + final CryptoDataFlow peerHandshake = handshakeFlow.peerHandshake; + // make sure to provide the frames in order, and + // buffer them if at the wrong offset + final long buffer = frame.offset() + frame.length() - peerHandshake.offset(); + if (buffer > MAX_INCOMING_CRYPTO_CAPACITY) { + throw new QuicTransportException( + "Crypto buffer exceeded, required: " + buffer, + KeySpace.HANDSHAKE, frame.frameType(), + QuicTransportErrors.CRYPTO_BUFFER_EXCEEDED); + } + int provided = 0; + CryptoFrame nextFrame = peerHandshake.receive(frame); + while (nextFrame != null) { + quicTLSEngine.consumeHandshakeBytes(KeySpace.HANDSHAKE, nextFrame.payload()); + provided += nextFrame.length(); + nextFrame = peerHandshake.poll(); + } + return provided; + } + + protected void incomingHandshakeFrame(final PaddingFrame frame) throws QuicTransportException { + // nothing to do + } + + protected void incomingHandshakeFrame(final PingFrame frame) throws QuicTransportException { + // nothing to do + } + + protected void incomingHandshakeFrame(final ConnectionCloseFrame frame) + throws QuicTransportException { + terminator.incomingConnectionCloseFrame(frame); + } + + protected void incoming1RTTFrame(final AckFrame ackFrame) throws QuicTransportException { + packetSpaces.app.processAckFrame(ackFrame); + } + + protected void incoming1RTTFrame(final StreamFrame frame) throws QuicTransportException { + final long streamId = frame.streamId(); + final QuicReceiverStream stream = getReceivingStream(streamId, frame.getTypeField()); + if (stream != null) { + assert frame.streamId() == stream.streamId(); + streams.processIncomingFrame(stream, frame); + } + } + + protected void incoming1RTTFrame(final CryptoFrame frame) throws QuicTransportException { + final long buffer = frame.offset() + frame.length() - peerCryptoFlow.offset(); + if (buffer > MAX_INCOMING_CRYPTO_CAPACITY) { + throw new QuicTransportException( + "Crypto buffer exceeded, required: " + buffer, + KeySpace.ONE_RTT, frame.frameType(), + QuicTransportErrors.CRYPTO_BUFFER_EXCEEDED); + } + CryptoFrame nextFrame = peerCryptoFlow.receive(frame); + while (nextFrame != null) { + quicTLSEngine.consumeHandshakeBytes(KeySpace.ONE_RTT, nextFrame.payload()); + nextFrame = peerCryptoFlow.poll(); + } + } + + protected void incoming1RTTFrame(final ResetStreamFrame frame) throws QuicTransportException { + final long streamId = frame.streamId(); + final QuicReceiverStream stream = getReceivingStream(streamId, frame.getTypeField()); + if (stream != null) { + assert frame.streamId() == stream.streamId(); + streams.processIncomingFrame(stream, frame); + } + } + + protected void incoming1RTTFrame(final StreamDataBlockedFrame frame) + throws QuicTransportException { + final QuicReceiverStream stream = getReceivingStream(frame.streamId(), frame.getTypeField()); + if (stream != null) { + assert frame.streamId() == stream.streamId(); + streams.processIncomingFrame(stream, frame); + } + } + + protected void incoming1RTTFrame(final DataBlockedFrame frame) throws QuicTransportException { + // TODO implement similar logic as STREAM_DATA_BLOCKED frame receipt + // and increment gradually if consumption is more than 1/4th the window size of the + // connection + } + + protected void incoming1RTTFrame(final StreamsBlockedFrame frame) + throws QuicTransportException { + if (frame.maxStreams() > MAX_STREAMS_VALUE_LIMIT) { + throw new QuicTransportException("Invalid maxStreams value %s" + .formatted(frame.maxStreams()), + KeySpace.ONE_RTT, + frame.getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + streams.peerStreamsBlocked(frame); + } + + protected void incoming1RTTFrame(final PaddingFrame frame) throws QuicTransportException { + // nothing to do + } + + protected void incoming1RTTFrame(final MaxDataFrame frame) throws QuicTransportException { + oneRttSndQueue.setMaxData(frame.maxData(), false); + } + + protected void incoming1RTTFrame(final MaxStreamDataFrame frame) + throws QuicTransportException { + final long streamId = frame.streamID(); + final QuicSenderStream stream = getSendingStream(streamId, frame.getTypeField()); + if (stream != null) { + streams.setMaxStreamData(stream, frame.maxStreamData()); + } + } + + protected void incoming1RTTFrame(final MaxStreamsFrame frame) throws QuicTransportException { + if (frame.maxStreams() >> 60 != 0) { + throw new QuicTransportException("Invalid maxStreams value %s" + .formatted(frame.maxStreams()), + KeySpace.ONE_RTT, + frame.getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + final boolean increased = streams.tryIncreaseStreamLimit(frame); + if (debug.on()) { + debug.log((increased ? "increased" : "did not increase") + + " " + (frame.isBidi() ? "bidi" : "uni") + + " stream limit to " + frame.maxStreams()); + } + } + + protected void incoming1RTTFrame(final StopSendingFrame frame) throws QuicTransportException { + final long streamId = frame.streamID(); + final QuicSenderStream stream = getSendingStream(streamId, frame.getTypeField()); + if (stream != null) { + streams.stopSendingReceived(stream, + frame.errorCode()); + } + } + + protected void incoming1RTTFrame(final PingFrame frame) throws QuicTransportException { + // nothing to do + } + + protected void incoming1RTTFrame(final ConnectionCloseFrame frame) + throws QuicTransportException { + terminator.incomingConnectionCloseFrame(frame); + } + + protected void incoming1RTTFrame(final HandshakeDoneFrame frame) + throws QuicTransportException { + if (quicTLSEngine.tryReceiveHandshakeDone()) { + // now that HANDSHAKE_DONE is received (and thus handshake confirmed), + // close the HANDSHAKE packet space (and thus discard the keys) + if (debug.on()) { + debug.log("received HANDSHAKE_DONE from server, initiating close of" + + " HANDSHAKE packet space"); + } + packetSpaces.handshake.close(); + } + packetSpaces.app.confirmHandshake(); + } + + protected void incoming1RTTFrame(final NewConnectionIDFrame frame) + throws QuicTransportException { + if (peerConnectionId().length() == 0) { + throw new QuicTransportException( + "NEW_CONNECTION_ID not allowed here", + null, frame.getTypeField(), PROTOCOL_VIOLATION); + } + this.peerConnIdManager.handleNewConnectionIdFrame(frame); + } + + protected void incoming1RTTFrame(final OneRttPacket oneRttPacket, + final RetireConnectionIDFrame frame) + throws QuicTransportException { + this.localConnIdManager.handleRetireConnectionIdFrame(oneRttPacket.destinationId(), + PacketType.ONERTT, frame); + } + + protected void incoming1RTTFrame(final NewTokenFrame frame) throws QuicTransportException { + // as per RFC 9000, section 19.7, token cannot be empty and if it is, then + // a connection error of type FRAME_ENCODING_ERROR needs to be raised + final byte[] newToken = frame.token(); + if (newToken.length == 0) { + throw new QuicTransportException("Empty token in NEW_TOKEN frame", + KeySpace.ONE_RTT, + frame.getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + assert this.quicInstance instanceof QuicClient : "Not a QuicClient"; + final QuicClient quicClient = (QuicClient) this.quicInstance; + // set this as the initial token to be used in INITIAL packets when attempting + // any new subsequent connections against this same target server + quicClient.registerInitialToken(this.peerAddress, newToken); + if (debug.on()) { + debug.log("Registered a new (initial) token for peer " + this.peerAddress); + } + } + + protected void incoming1RTTFrame(final PathResponseFrame frame) + throws QuicTransportException { + throw new QuicTransportException("Unmatched PATH_RESPONSE frame", + KeySpace.ONE_RTT, + frame.getTypeField(), PROTOCOL_VIOLATION); + } + + protected void incoming1RTTFrame(final PathChallengeFrame frame) + throws QuicTransportException { + pathChallengeFrameQueue.offer(frame); + if (pathChallengeFrameQueue.size() > 3) { + // we don't expect to hold more than 1 PathChallenge per path. + // If there's more than 3 outstanding challenges, drop the oldest one. + pathChallengeFrameQueue.poll(); + } + } + + /** + * Signal the connection that some stream data is available for sending on one or more streams. + * @param streamIds the stream ids + */ + public void streamDataAvailableForSending(final Set streamIds) { + for (final long id : streamIds) { + streams.enqueueForSending(id); + } + packetSpaces.app.runTransmitter(); + } + + /** + * Called when the receiving part or the sending part of a stream + * reaches a terminal state. + * @param streamId the id of the stream + * @param state the terminal state + */ + public void notifyTerminalState(long streamId, StreamState state) { + assert state.isTerminal() : state; + streams.notifyTerminalState(streamId, state); + } + + /** + * Called to request sending of a RESET_STREAM frame. + * + * @apiNote + * Should only be called for sending streams. For stopping a + * receiving stream then {@link #scheduleStopSendingFrame(long, long)} should be called. + * This method should only be called from {@code QuicSenderStreamImpl}, after + * switching the state of the stream to RESET_SENT. + * + * @param streamId the id of the stream that should be reset + * @param errorCode the application error code + */ + public void requestResetStream(long streamId, long errorCode) { + assert streams.isSendingStream(streamId); + streams.requestResetStream(streamId, errorCode); + packetSpaces.app.runTransmitter(); + } + + /** + * Called to request sending of a STOP_SENDING frame. + * @apiNote + * Should only be called for receiving streams. For stopping a + * sending stream then {@link #requestResetStream(long, long)} + * should be called. + * This method should only be called from {@code QuicReceiverStreamImpl} + * @param streamId the stream id to be cancelled + * @param errorCode the application error code + */ + public void scheduleStopSendingFrame(long streamId, long errorCode) { + assert streams.isReceivingStream(streamId); + streams.scheduleStopSendingFrame(new StopSendingFrame(streamId, errorCode)); + packetSpaces.app.runTransmitter(); + } + + /** + * Called to request sending of a MAX_STREAM_DATA frame. + * @apiNote + * Should only be called for receiving streams. + * This method should only be called from {@code QuicReceiverStreamImpl} + * @param streamId the stream id to be cancelled + * @param maxStreamData the new max data we are prepared to receive on + * this stream + */ + public void requestSendMaxStreamData(long streamId, long maxStreamData) { + assert streams.isReceivingStream(streamId); + streams.requestSendMaxStreamData(new MaxStreamDataFrame(streamId, maxStreamData)); + packetSpaces.app.runTransmitter(); + } + + /** + * Called when frame data can be safely added to the amount of + * data received by the connection for MAX_DATA flow control + * purpose. + * @throws QuicTransportException if flow control was exceeded + * @param frameType type of frame received + */ + public void increaseReceivedData(long diff, long frameType) throws QuicTransportException { + oneRttRcvQueue.checkAndIncreaseReceivedData(diff, frameType); + } + + /** + * Called when frame data is removed from the connection + * and the amount of data can be added to MAX_DATA window. + * @param diff amount of data processed + */ + public void increaseProcessedData(long diff) { + oneRttRcvQueue.increaseProcessedData(diff); + } + + public QuicTLSEngine getTLSEngine() { + return quicTLSEngine; + } + + /** + * {@return the computed PTO for the current packet number space, + * adjusted by our max ack delay} + */ + public long peerPtoMs() { + return rttEstimator.getBasePtoDuration().toMillis() + + (quicTLSEngine.getCurrentSendKeySpace() == KeySpace.ONE_RTT ? + PacketSpaceManager.ADVERTISED_MAX_ACK_DELAY : 0); + } + + public void runAppPacketSpaceTransmitter() { + this.packetSpaces.app.runTransmitter(); + } + + public void shutdown() { + packetSpaces.close(); + } + + public final String logTag() { + return logTag; + } + + /* ======================================================== + * Direct Byte Buffer Pool + * ======================================================== */ + + // Maximum size of the connection's Direct ByteBuffer Pool. + // For a connection configured to attempt sending datagrams in thread + // (QuicEndpoint.SEND_DGRAM_ASYNC == false), 2 should be enough, as we + // shouldn't have more than 2 packet number spaces active at the same time. + private static final int MAX_DBB_POOL_SIZE = 3; + // The ByteBuffer pool, which contains available byte buffers + private final ConcurrentLinkedQueue bbPool = new ConcurrentLinkedQueue<>(); + // The number of Direct Byte Buffers allocated for sending and managed by the pool. + // This is the number of Direct Byte Buffers currently in flight, plus the number + // of available byte buffers present in the pool. It will never exceed + // MAX_DBB_POOL_SIZE. + private final AtomicInteger bbAllocated = new AtomicInteger(); + + // Some counters used for printing debug statistics when Log quic:dbb is enabled + // Byte Buffers in flight: the number of byte buffers that were returned by + // getOutgoingByteBuffer() minus the number of byte buffers that were released + // through datagramReleased() + private final AtomicInteger bbInFlight = new AtomicInteger(); + // Peak number of byte buffers in flight. Never decreases. + private final AtomicInteger bbPeak = new AtomicInteger(); + // Number of unreleased byte buffers. This should eventually reach 0. + final AtomicInteger bbUnreleased = new AtomicInteger(); + + /** + * {@return a new {@code ByteBuffer} to encode and encrypt packets in a datagram} + * This method may either allocate a new heap BteBuffer or return a (possibly + * new) Direct ByteBuffer from the connection's Direct Byte Buffer Pool. + * @param size the maximum size of the datagram + */ + protected ByteBuffer getOutgoingByteBuffer(int size) { + bbUnreleased.incrementAndGet(); + if (USE_DIRECT_BUFFER_POOL) { + if (size <= getMaxDatagramSize()) { + ByteBuffer buffer = bbPool.poll(); + if (buffer != null) { + if (buffer.limit() >= getMaxDatagramSize()) { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: got direct buffer from pool" + + ", inFlight: " + bbInFlight.get() + ", peak: " + bbPeak.get() + + ", unreleased:" + bbUnreleased.get()); + } + int inFlight = bbInFlight.incrementAndGet(); + if (inFlight > bbPeak.get()) { + synchronized (this) { + if (inFlight > bbPeak.get()) bbPeak.set(inFlight); + } + } + return buffer; + } + bbAllocated.decrementAndGet(); + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: releasing direct buffer"); + } + buffer = null; + } + + assert buffer == null; + int allocated; + while ((allocated = bbAllocated.get()) < MAX_DBB_POOL_SIZE) { + if (bbAllocated.compareAndSet(allocated, allocated + 1)) { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: allocating direct buffer #" + (allocated + 1) + + ", inFlight: " + bbInFlight.get() + ", peak: " + + bbPeak.get() + ", unreleased:" + bbUnreleased.get()); + } + int inFlight = bbInFlight.incrementAndGet(); + if (inFlight > bbPeak.get()) { + synchronized (this) { + if (inFlight > bbPeak.get()) bbPeak.set(inFlight); + } + } + return ByteBuffer.allocateDirect(getMaxDatagramSize()); + } + } + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: too many buffers allocated: " + allocated + + ", inFlight: " + bbInFlight.get() + ", peak: " + + bbPeak.get() + ", unreleased:" + bbUnreleased.get()); + } + + } else { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: wrong size " + size); + } + } + } + int inFlight = bbInFlight.incrementAndGet(); + if (inFlight > bbPeak.get()) { + synchronized (this) { + if (inFlight > bbPeak.get()) bbPeak.set(inFlight); + } + } + return ByteBuffer.allocate(size); + } + @Override + public void datagramSent(QuicDatagram datagram) { + datagramReleased(datagram); + } + + @Override + public void datagramDiscarded(QuicDatagram datagram) { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: datagram discarded " + datagram.payload().isDirect() + + ", inFlight: " + bbInFlight.get() + ", peak: " + bbPeak.get() + + ", unreleased:" + bbUnreleased.get()); + } + datagramReleased(datagram); + } + + public void datagramDropped(QuicDatagram datagram) { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: datagram dropped " + datagram.payload().isDirect() + + ", inFlight: " + bbInFlight.get() + ", peak: " + bbPeak.get() + + ", unreleased:" + bbUnreleased.get()); + } + datagramReleased(datagram); + } + + /** + * Returns a {@link jdk.internal.net.http.quic.QuicEndpoint.Datagram} which contains + * an encrypted QUIC packet containing + * a {@linkplain ConnectionCloseFrame CONNECTION_CLOSE frame}. The CONNECTION_CLOSE + * frame will have a frame type of {@code 0x1c} and error code of {@code NO_ERROR}. + *

+ * This method should only be invoked when the {@link QuicEndpoint} is being closed + * and the endpoint wants to send out a {@code CONNECTION_CLOSE} frame on a best-effort + * basis (in a fire and forget manner). + * + * @return the datagram containing the QUIC packet with a CONNECTION_CLOSE frame or + * an {@linkplain Optional#empty() empty Optional} if the datagram couldn't + * be constructed. + */ + final Optional connectionCloseDatagram() { + try { + final ByteBuffer quicPktPayload = this.terminator.makeConnectionCloseDatagram(); + return Optional.of(new QuicDatagram(this, peerAddress, quicPktPayload)); + } catch (Exception e) { + // ignore any exception because providing the connection close datagram + // when the endpoint is being closed, is on best-effort basis + return Optional.empty(); + } + } + + /** + * Called when a datagram is being released, either from + * {@link #datagramSent(QuicDatagram)}, {@link #datagramDiscarded(QuicDatagram)}, + * or {@link #datagramDropped(QuicDatagram)}. + * This method may either release the datagram and let it get garbage collected, + * or return it to the pool. + * @param datagram the released datagram + */ + protected void datagramReleased(QuicDatagram datagram) { + bbUnreleased.decrementAndGet(); + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: datagram released " + datagram.payload().isDirect() + + ", inFlight: " + bbInFlight.get() + ", peak: " + bbPeak.get() + + ", unreleased:" + bbUnreleased.get()); + } + bbInFlight.decrementAndGet(); + if (USE_DIRECT_BUFFER_POOL) { + ByteBuffer buffer = datagram.payload(); + buffer.clear(); + if (buffer.isDirect()) { + if (buffer.limit() >= getMaxDatagramSize()) { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: offering buffer to pool"); + } + bbPool.offer(buffer); + } else { + if (Log.quicDBB()) { + Log.logQuic("[" + Thread.currentThread().getName() + "] " + + logTag() + ": DIRECTBB: releasing direct buffer (too small)"); + } + bbAllocated.decrementAndGet(); + } + } + } + } + + public String loggableState() { + // for HTTP3 debugging + // If the connection was active (open bidi streams), log connection state + if (streams.quicStreams().noneMatch(QuicStream::isBidirectional)) { + // no active requests + return "No active requests"; + } + Deadline now = TimeSource.now(); + StringBuilder result = new StringBuilder("sending: {canSend:" + oneRttSndQueue.canSend() + + ", credit: " + oneRttSndQueue.credit() + + ", sendersReady: " + streams.hasAvailableData() + + ", hasControlFrames: " + streams.hasControlFrames() + + "}, cc: { backoff: " + rttEstimator.getPtoBackoff() + + ", duration: " + ((PacketSpaceManager) packetSpaces.app).getPtoDuration() + + ", current deadline: " + Utils.debugDeadline(now, + ((PacketSpaceManager) packetSpaces.app).deadline()) + + ", prospective deadline: " + Utils.debugDeadline(now, + ((PacketSpaceManager) packetSpaces.app).prospectiveDeadline()) + + "}, streams: ["); + streams.quicStreams().filter(QuicStream::isBidirectional).forEach( + s -> { + QuicBidiStreamImpl qb = (QuicBidiStreamImpl) s; + result.append("{id:" + s.streamId() + + ", available: " + qb.senderPart().available() + + ", blocked: " + qb.senderPart().isBlocked() + "}," + ); + } + ); + result.append("]"); + return result.toString(); + } + + /** + * {@return true if the packet contains a CONNECTION_CLOSE frame, false otherwise} + * @param packet the QUIC packet + */ + private static boolean containsConnectionClose(final QuicPacket packet) { + for (final QuicFrame frame : packet.frames()) { + if (frame instanceof ConnectionCloseFrame) { + return true; + } + } + return false; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicEndpoint.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicEndpoint.java new file mode 100644 index 00000000000..3df9b599f82 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicEndpoint.java @@ -0,0 +1,2062 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.SocketOption; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.HexFormat; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.TimeLine; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.QuicSelector.QuicNioSelector; +import jdk.internal.net.http.quic.QuicSelector.QuicVirtualThreadPoller; +import jdk.internal.net.http.quic.packets.QuicPacket.HeadersType; +import jdk.internal.net.http.quic.packets.QuicPacketDecoder; +import jdk.internal.util.OperatingSystem; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.KeyGenerator; +import javax.crypto.NoSuchPaddingException; + +import static jdk.internal.net.http.quic.QuicEndpoint.ChannelType.BLOCKING_WITH_VIRTUAL_THREADS; +import static jdk.internal.net.http.quic.QuicEndpoint.ChannelType.NON_BLOCKING_WITH_SELECTOR; +import static jdk.internal.net.http.quic.TerminationCause.forSilentTermination; + + +/** + * A QUIC Endpoint. A QUIC endpoint encapsulate a DatagramChannel + * and is registered with a Selector. It subscribes for read and + * write events from the selector, and implements a readLoop and + * a writeLoop. + *

+ * The read event or write event are triggered by the selector + * thread. When the read event is triggered, all available datagrams + * are read from the channel and pushed into a read queue. + * Then the readLoop is triggered. + * When the write event is triggered, the key interestOps are + * modified to pause write events, and the writeLoop is triggered. + *

+ * The readLoop and writeLoop should never execute on the selector + * thread, but rather, in the client's executor. + *

+ * When the writeLoop is triggered, it polls the writeQueue and + * writes as many datagram as it can to the channel. At the end, + * if there still remains some datagrams in the writeQueue, the + * write event is resumed. Otherwise, the writeLoop is next + * triggered when new datagrams are added to the writeQueue. + *

+ * When the readLoop is triggered, it polls the read queue + * and attempts to match each received packet with a + * QuicConnection. If no connection matches, it attempts + * to match the packet with stateless reset tokens. + * If no stateless reset token match, the packet is + * discarded. + *

+ */ +public abstract sealed class QuicEndpoint implements AutoCloseable + permits QuicEndpoint.QuicSelectableEndpoint, QuicEndpoint.QuicVirtualThreadedEndpoint { + + private static final int INCOMING_MAX_DATAGRAM; + static final boolean DGRAM_SEND_ASYNC; + static final int MAX_BUFFERED_HIGH; + static final int MAX_BUFFERED_LOW; + static { + // This default value is the maximum payload size of + // an IPv6 datagram, which is 65527 (which is bigger + // than that of an IPv4). + // We have only one direct buffer of this size per endpoint. + final int defSize = 65527; + // This is the value that will be transmitted to the server in the + // max_udp_payload_size parameter + int size = Utils.getIntegerProperty("jdk.httpclient.quic.maxUdpPayloadSize", defSize); + // don't allow the value to be below 1200 and above 65527, to conform with RFC-9000, + // section 18.2. + if (size < 1200 || size > 65527) { + // fallback to default size + size = defSize; + } + INCOMING_MAX_DATAGRAM = size; + // TODO: evaluate pros and cons WRT performance and decide for one or the other + // before GA. + DGRAM_SEND_ASYNC = Utils.getBooleanProperty("jdk.internal.httpclient.quic.sendAsync", false); + int maxBufferHigh = Math.clamp(Utils.getIntegerProperty("jdk.httpclient.quic.maxBufferedHigh", + 512 << 10), 128 << 10, 6 << 20); + int maxBufferLow = Math.clamp(Utils.getIntegerProperty("jdk.httpclient.quic.maxBufferedLow", + 384 << 10), 64 << 10, 6 << 20); + if (maxBufferLow >= maxBufferHigh) maxBufferLow = maxBufferHigh >> 1; + MAX_BUFFERED_HIGH = maxBufferHigh; + MAX_BUFFERED_LOW = maxBufferLow; + } + + /** + * This interface represent a UDP Datagram. This could be + * either an incoming datagram or an outgoing datagram. + */ + public sealed interface Datagram + permits QuicDatagram, StatelessReset, SendStatelessReset, UnmatchedDatagram { + /** + * {@return the peer address} + * For incoming datagrams, this is the sender address. + * For outgoing datagrams, this is the destination address. + */ + SocketAddress address(); + + /** + * {@return the datagram payload} + */ + ByteBuffer payload(); + } + + /** + * An incoming UDP Datagram for which no connection was found. + * On the server side it may represent a new connection attempt. + * @param address the {@linkplain Datagram#address() sender address} + * @param payload {@inheritDoc} + */ + public record UnmatchedDatagram(SocketAddress address, ByteBuffer payload) implements Datagram {} + + /** + * A stateless reset that should be sent in response + * to an incoming datagram targeted at a deleted connection. + * @param address the {@linkplain Datagram#address() destination address} + * @param payload the outgoing stateless reset + */ + public record SendStatelessReset(SocketAddress address, ByteBuffer payload) implements Datagram {} + + /** + * An incoming datagram containing a stateless reset + * @param connection the connection to reset + * @param address the {@linkplain Datagram#address() sender address} + * @param payload the datagram payload + */ + public record StatelessReset(QuicPacketReceiver connection, SocketAddress address, ByteBuffer payload) implements Datagram {} + + /** + * An outgoing datagram, or an incoming datagram for which + * a connection was identified. + * @param connection the sending or receiving connection + * @param address {@inheritDoc} + * @param payload {@inheritDoc} + */ + public record QuicDatagram(QuicPacketReceiver connection, SocketAddress address, ByteBuffer payload) + implements Datagram {} + + /** + * An enum identifying the type of channels used and supported by QuicEndpoint and + * QuicSelector + */ + public enum ChannelType { + NON_BLOCKING_WITH_SELECTOR, + BLOCKING_WITH_VIRTUAL_THREADS; + public boolean isBlocking() { + return this == BLOCKING_WITH_VIRTUAL_THREADS; + } + } + + // A temporary internal property to switch between two QuicSelector implementation: + // - if true, uses QuicNioSelector, an implementation based non-blocking and channels + // and an NIO Selector + // - if false, uses QuicVirtualThreadPoller, an implementation that use Virtual Threads + // to poll blocking channels + // On windows, we default to using non-blocking IO with a Selector in order + // to avoid a potential deadlock in WEPoll (see JDK-8334574). + private static final boolean USE_NIO_SELECTOR = + Utils.getBooleanProperty("jdk.internal.httpclient.quic.useNioSelector", + OperatingSystem.isWindows()); + /** + * The configured channel type + */ + public static final ChannelType CONFIGURED_CHANNEL_TYPE = USE_NIO_SELECTOR + ? NON_BLOCKING_WITH_SELECTOR + : BLOCKING_WITH_VIRTUAL_THREADS; + + final Logger debug = Utils.getDebugLogger(this::name); + private final QuicInstance quicInstance; + private final String name; + private final DatagramChannel channel; + private final ByteBuffer receiveBuffer; + final Executor executor; + final ConcurrentLinkedQueue readQueue = new ConcurrentLinkedQueue<>(); + final ConcurrentLinkedQueue writeQueue = new ConcurrentLinkedQueue<>(); + final QuicTimerQueue timerQueue; + private volatile boolean closed; + + // A synchronous scheduler to consume the readQueue list; + final SequentialScheduler readLoopScheduler = + SequentialScheduler.lockingScheduler(this::readLoop); + + // A synchronous scheduler to consume the writeQueue list; + final SequentialScheduler writeLoopScheduler = + SequentialScheduler.lockingScheduler(this::writeLoop); + + // A ConcurrentMap to store registered connections. + // The connection IDs might come from external sources. They implement Comparable + // to mitigate collision attacks. + // This map must not share the idFactory with other maps, + // see RFC 9000 section 21.11. Stateless Reset Oracle + private final ConcurrentMap connections = + new ConcurrentHashMap<>(); + + // a factory of local connection IDs. + private final QuicConnectionIdFactory idFactory; + + // Key used to encrypt tokens before storing in {@link #peerIssuedResetTokens} + private final Key tokenEncryptionKey; + + // keeps a link of the peer issued stateless reset token to the corresponding connection that + // will be closed if the specific stateless reset token is received + private final ConcurrentMap peerIssuedResetTokens = + new ConcurrentHashMap<>(); + + private static ByteBuffer allocateReceiveBuffer() { + return ByteBuffer.allocateDirect(INCOMING_MAX_DATAGRAM); + } + + private final AtomicInteger buffered = new AtomicInteger(); + volatile boolean readingStalled; + + public QuicConnectionIdFactory idFactory() { + return idFactory; + } + + public int buffer(int bytes) { + return buffered.addAndGet(bytes); + } + + public int unbuffer(int bytes) { + var newval = buffered.addAndGet(-bytes); + assert newval >= 0; + if (newval <= MAX_BUFFERED_LOW) { + resumeReading(); + } + return newval; + } + + boolean bufferTooBig() { + return buffered.get() >= MAX_BUFFERED_HIGH; + } + + public int buffered() { + return buffered.get(); + } + + boolean readingPaused() { + return readingStalled; + } + + abstract void resumeReading(); + + abstract void pauseReading(); + + private QuicEndpoint(QuicInstance quicInstance, + DatagramChannel channel, + String name, + QuicTimerQueue timerQueue) { + this.quicInstance = quicInstance; + this.name = name; + this.channel = channel; + this.receiveBuffer = allocateReceiveBuffer(); + this.executor = quicInstance.executor(); + this.timerQueue = timerQueue; + if (debug.on()) debug.log("created for %s", channel); + try { + KeyGenerator kg = KeyGenerator.getInstance("AES"); + tokenEncryptionKey = kg.generateKey(); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("AES key generator not available", e); + } + idFactory = isServer() + ? QuicConnectionIdFactory.getServer() + : QuicConnectionIdFactory.getClient(); + } + + public String name() { + return name; + } + + public DatagramChannel channel() { + return channel; + } + + Executor writeLoopExecutor() { return executor; } + + public SocketAddress getLocalAddress() throws IOException { + return channel.getLocalAddress(); + } + + public String getLocalAddressString() { + try { + return String.valueOf(channel.getLocalAddress()); + } catch (IOException io) { + return "No address available"; + } + } + + int getMaxUdpPayloadSize() { + return INCOMING_MAX_DATAGRAM; + } + + protected abstract ChannelType channelType(); + + /** + * A {@link QuicEndpoint} implementation based on non blocking + * {@linkplain DatagramChannel Datagram Channels} and using a + * NIO {@link Selector}. + * This implementation is tied to a {@link QuicNioSelector}. + */ + static final class QuicSelectableEndpoint extends QuicEndpoint { + volatile SelectionKey key; + + private QuicSelectableEndpoint(QuicInstance quicInstance, + DatagramChannel channel, + String name, + QuicTimerQueue timerQueue) { + super(quicInstance, channel, name, timerQueue); + assert !channel.isBlocking() : "SelectableQuicEndpoint channel is blocking"; + } + + @Override + public ChannelType channelType() { + return NON_BLOCKING_WITH_SELECTOR; + } + + /** + * Attaches this endpoint to a selector. + * + * @param selector the selector to attach to + * @throws ClosedChannelException if the channel is already closed + */ + public void attach(Selector selector) throws ClosedChannelException { + var key = this.key; + assert key == null; + // this block is needed to coordinate with detach() and + // selected(). See comment in selected(). + synchronized (this) { + this.key = super.channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE, this); + } + } + + @Override + void resumeReading() { + boolean resumed = false; + SelectionKey key; + synchronized (this) { + key = this.key; + if (key != null && key.isValid()) { + if (isClosed() || isChannelClosed()) return; + int ops = key.interestOps(); + int newops = ops | SelectionKey.OP_READ; + if (ops != newops) { + key.interestOpsOr(SelectionKey.OP_READ); + readingStalled = false; + resumed = true; + } + } + } + if (resumed) { + // System.out.println(this + " endpoint resumed reading"); + if (debug.on()) debug.log("endpoint resumed reading"); + key.selector().wakeup(); + } + } + + @Override + void pauseReading() { + boolean paused = false; + synchronized (this) { + if (readingStalled) return; + if (key != null && key.isValid() && bufferTooBig()) { + if (isClosed() || isChannelClosed()) return; + int ops = key.interestOps(); + int newops = ops & ~SelectionKey.OP_READ; + if (ops != newops) { + key.interestOpsAnd(~SelectionKey.OP_READ); + readingStalled = true; + paused = true; + } + } + } + if (paused) { + // System.out.println(this + " endpoint paused reading"); + if (debug.on()) debug.log("endpoint paused reading"); + } + } + + /** + * Invoked by the {@link QuicSelector} when this endpoint's channel + * is selected. + * + * @param readyOps The operations that are ready for this endpoint. + */ + public void selected(int readyOps) { + var key = this.key; + try { + if (key == null) { + // null keys have been observed here. + // key can only be null if it's been cancelled, by detach() + // or if the call to channel::register hasn't returned yet + // the synchronized block below will block until + // channel::register returns if needed. + // This can only happen once, when attaching the channel, + // so there should be no performance issue in synchronizing + // here. + synchronized (this) { + key = this.key; + } + } + + if (key == null) { + if (debug.on()) { + debug.log("key is null"); + if (QuicEndpoint.class.desiredAssertionStatus()) { + Thread.dumpStack(); + } + } + return; + } + + if (debug.on()) { + debug.log("selected(interest=%s, ready=%s)", + Utils.interestOps(key), + Utils.readyOps(key)); + } + + int interestOps = key.interestOps(); + + // Some operations may be ready even when we are not interested. + // Typically, a channel may be ready for writing even if we have + // nothing to write. The events we need to invoke are therefore + // at the intersection of the ready set with the interest set. + int event = readyOps & interestOps; + if ((event & SelectionKey.OP_READ) == SelectionKey.OP_READ) { + onReadEvent(); + if (isClosed()) { + key.interestOpsAnd(~SelectionKey.OP_READ); + } + } + if ((event & SelectionKey.OP_WRITE) == SelectionKey.OP_WRITE) { + onWriteEvent(); + } + if (debug.on()) { + debug.log("interestOps: %s", Utils.interestOps(key)); + } + } finally { + if (!channel().isOpen()) { + if (key != null) key.cancel(); + close(); + } + } + } + + private void onReadEvent() { + var key = this.key; + try { + if (debug.on()) debug.log("onReadEvent"); + channelReadLoop(); + } finally { + if (debug.on()) { + debug.log("Leaving readEvent: ops=%s", Utils.interestOps(key)); + } + } + } + + private void onWriteEvent() { + // trigger code that will process the received + // datagrams asynchronously + // => Use a sequential scheduler, making sure it never + // runs on this thread. + // Do we need a pub/sub mechanism here? + // The write event will be paused/resumed by the + // writeLoop if needed + if (debug.on()) debug.log("onWriteEvent"); + var key = this.key; + if (key != null && key.isValid()) { + int previous; + synchronized (this) { + previous = key.interestOpsAnd(~SelectionKey.OP_WRITE); + } + if (debug.on()) debug.log("key changed from %s to: %s", + Utils.describeOps(previous), Utils.interestOps(key)); + } + writeLoopScheduler.runOrSchedule(writeLoopExecutor()); + if (debug.on() && key != null) { + debug.log("Leaving writeEvent: ops=%s", Utils.interestOps(key)); + } + } + + @Override + void writeLoop() { + super.writeLoop(); + // update selection key if needed + var key = this.key; + try { + if (key != null && key.isValid()) { + int ops, newops; + synchronized (this) { + ops = newops = key.interestOps(); + if (writeQueue.isEmpty()) { + // we have nothing else to write for now + newops &= ~SelectionKey.OP_WRITE; + } else { + // there's more to write + newops |= SelectionKey.OP_WRITE; + } + if (newops != ops && key.selector().isOpen()) { + key.interestOps(newops); + key.selector().wakeup(); + } + } + if (debug.on()) { + debug.log("leaving writeLoop: ops=%s", Utils.describeOps(newops)); + } + } + } catch (CancelledKeyException x) { + if (debug.on()) debug.log("key cancelled"); + if (writeQueue.isEmpty()) return; + else { + closeWriteQueue(x); + } + } + } + + @Override + void readLoop() { + try { + super.readLoop(); + } finally { + if (debug.on()) { + debug.log("leaving readLoop: ops=%s", Utils.interestOps(key)); + } + } + } + + @Override + public void detach() { + var key = this.key; + if (key == null) return; + if (debug.on()) { + debug.log("cancelling key: " + key); + } + // this block is needed to coordinate with attach() and + // selected(). See comment in selected(). + synchronized (this) { + key.cancel(); + this.key = null; + } + } + } + + /** + * A {@link QuicEndpoint} implementation based on blocking + * {@linkplain DatagramChannel Datagram Channels} and using a + * Virtual Threads to poll the channel. + * This implementation is tied to a {@link QuicVirtualThreadPoller}. + */ + static final class QuicVirtualThreadedEndpoint extends QuicEndpoint { + Future key; + volatile QuicVirtualThreadPoller poller; + boolean readingDone; + + private QuicVirtualThreadedEndpoint(QuicInstance quicInstance, + DatagramChannel channel, + String name, + QuicTimerQueue timerQueue) { + super(quicInstance, channel, name, timerQueue); + } + + @Override + boolean readingPaused() { + synchronized (this) { + return readingDone = super.readingPaused(); + } + } + + @Override + void resumeReading() { + boolean resumed; + boolean resumedInOtherThread = false; + QuicVirtualThreadPoller poller; + synchronized (this) { + resumed = readingStalled; + readingStalled = false; + poller = this.poller; + // readingDone is false here, it means reading already resumed + // no need to start a new reading thread + if (poller != null && (resumedInOtherThread = readingDone)) { + readingDone = false; + attach(poller); + } + } + if (resumedInOtherThread) { + // last time readingPaused() was called it returned true, so we know + // the previous poller thread has stopped reading and will exit. + // We attached a new poller above, so reading will resume in that + // other thread + // System.out.println(this + " endpoint resumed reading in new virtual thread"); + if (debug.on()) debug.log("endpoint resumed reading in new virtual thread"); + } else if (resumed) { + // readingStalled was true, and readingDone was false - which means some + // poller thread is already active, and will find readingStalled == true + // and will continue reading. So reading will resume in the currently + // active poller thread + // System.out.println(this + " endpoint resumed reading in same virtual thread"); + if (debug.on()) debug.log("endpoint resumed reading in same virtual thread"); + } // if readingStalled was false and readingDone was false there is nothing to do. + } + + @Override + void pauseReading() { + boolean paused = false; + synchronized (this) { + if (bufferTooBig()) paused = readingStalled = true; + } + if (paused) { + // System.out.println(this + " endpoint paused reading"); + if (debug.on()) debug.log("endpoint paused reading"); + } + } + + @Override + public ChannelType channelType() { + return BLOCKING_WITH_VIRTUAL_THREADS; + } + + void attach(QuicVirtualThreadPoller poller) { + this.poller = poller; + var future = poller.startReading(this); + synchronized (this) { + this.key = future; + } + } + + Executor writeLoopExecutor() { + QuicVirtualThreadPoller poller = this.poller; + if (poller == null) return executor; + return poller.readLoopExecutor(); + } + + private final SequentialScheduler channelScheduler = SequentialScheduler.lockingScheduler(this::channelReadLoop0); + + @Override + void channelReadLoop() { + channelScheduler.runOrSchedule(); + } + + private void channelReadLoop0() { + super.channelReadLoop(); + } + + @Override + public void detach() { + var key = this.key; + try { + if (key != null) { + // do not interrupt the reading task if running: + // closing the channel later on will ensure that the + // task eventually terminates. + key.cancel(false); + } + } catch (Throwable e) { + if (debug.on()) { + debug.log("Failed to cancel future: " + e); + } + } + } + } + + private ByteBuffer copyOnHeap(ByteBuffer buffer) { + ByteBuffer onHeap = ByteBuffer.allocate(buffer.remaining()); + return onHeap.put(buffer).flip(); + } + + void channelReadLoop() { + // we can't prevent incoming datagram from being received + // at this level of the stack. If there is a datagram available, + // we must read it immediately and put it in the read queue. + // + // We maintain a counter of the number of bytes currently + // in the read queue. If that number exceeds a high watermark + // threshold, we will pause reading, and thus stop adding + // to the queue. + // + // As the read queue gets emptied, reading will be resumed + // when a low watermark threshold is crossed in the other + // direction. + // + // At the moment we have a single channel per endpoint, + // and we're using a single endpoint by default. + // + // We have a single selector thread, and we copy off + // the data from off-heap to on-heap before adding it + // to the queue. + // + // We can therefore do the reading directly in the + // selector thread and offload the parsing (the readLoop) + // to the executor. + // + // The readLoop will in turn resume the reading, if needed, + // when it crosses the low watermark threshold. + // + boolean nonBlocking = channelType() == NON_BLOCKING_WITH_SELECTOR; + int count; + final var buffer = this.receiveBuffer; + buffer.clear(); + final int initialStart = 1; // start readloop at first buffer + // if blocking we want to nudge the scheduler after each read since we don't + // know how much the next receive will take. If non-blocking, we nudge it + // after three consecutive read. + final int maxBeforeStart = nonBlocking ? 3 : 1; // nudge again after 3 buffers + int readLoopStarted = initialStart; + int totalpkt = 0; + try { + int sincepkt = 0; + while (!isClosed() && !readingPaused()) { + var pos = buffer.position(); + var limit = buffer.limit(); + if (debug.on()) + debug.log("receiving with buffer(pos=%s, limit=%s)", pos, limit); + assert pos == 0; + assert limit > pos; + + final SocketAddress source = channel.receive(buffer); + assert source != null || !channel.isBlocking(); + if (source == null) { + if (debug.on()) debug.log("nothing to read..."); + if (nonBlocking) break; + } + + totalpkt++; + sincepkt++; + buffer.flip(); + count = buffer.remaining(); + if (debug.on()) { + debug.log("received %s bytes from %s", count, source); + } + if (count > 0) { + // Optimization: add some basic check here to drop the packet here if: + // - it is too small, it is not a quic packet we would handle + Datagram datagram; + if ((datagram = matchDatagram(source, buffer)) == null) { + if (debug.on()) { + debug.log("dropping invalid packet for this instance (%s bytes)", count); + } + buffer.clear(); + continue; + } + // at this point buffer has been copied. We only buffer what's + // needed. + int rcv = datagram.payload().remaining(); + int buffered = buffer(rcv); + if (debug.on()) { + debug.log("adding %s in read queue from %s, queue size %s, buffered %s, type %s", + rcv, source, readQueue.size(), buffered, datagram.getClass().getSimpleName()); + } + readQueue.add(datagram); + buffer.clear(); + if (--readLoopStarted == 0 || buffered >= MAX_BUFFERED_HIGH) { + readLoopStarted = maxBeforeStart; + if (debug.on()) debug.log("triggering readLoop"); + readLoopScheduler.runOrSchedule(executor); + Deadline now; + Deadline pending; + if (nonBlocking && totalpkt > 1 && (pending = timerQueue.pendingScheduledDeadline()) + .isBefore(now = timeSource().instant())) { + // we have read 3 packets, some events are pending, return + // to the selector to process the event queue + assert this instanceof QuicEndpoint.QuicSelectableEndpoint + : "unexpected endpoint type: " + this.getClass() + "@[" + name + "]"; + assert Thread.currentThread() instanceof QuicSelector.QuicSelectorThread; + if (Log.quicRetransmit() || Log.quicTimer()) { + Log.logQuic(name() + ": reschedule needed: " + Utils.debugDeadline(now, pending) + + ", totalpkt: " + totalpkt + + ", sincepkt: " + sincepkt); + } else if (debug.on()) { + debug.log("reschedule needed: " + Utils.debugDeadline(now, pending) + + ", totalpkt: " + totalpkt + + ", sincepkt: " + sincepkt); + } + timerQueue.processEventsAndReturnNextDeadline(now, executor); + sincepkt = 0; + } + } + // check buffered.get() directly as it may have + // been decremented by the read loop already + if (this.buffered.get() >= MAX_BUFFERED_HIGH) { + // we passed the high watermark, let's pause reading. + // the read loop should already have been kicked + // of above, or will be below when we exit the while + // loop + pauseReading(); + } + } else { + if (debug.on()) debug.log("Dropped empty datagram"); + } + } + // trigger code that will process the received + // datagrams asynchronously + // => Use a sequential scheduler, making sure it never + // runs on this thread. + if (!readQueue.isEmpty() && readLoopStarted != maxBeforeStart) { + if (debug.on()) debug.log("triggering readLoop: queue size " + readQueue.size()); + readLoopScheduler.runOrSchedule(executor); + } + } catch (Throwable t) { + // TODO: special handling for interrupts? + onReadError(t); + } finally { + if (nonBlocking) { + if ((Log.quicRetransmit() && Log.channel()) || Log.quicTimer()) { + Log.logQuic(name() + ": channelReadLoop totalpkt:" + totalpkt); + } else if (debug.on()) { + debug.log("channelReadLoop totalpkt:" + totalpkt); + } + } + } + } + + /** + * This method tries to figure out whether the received packet + * matches a connection, or a stateless reset. + * @param source the source address + * @param buffer the incoming datagram payload + * @return a {@link Datagram} to be processed by the read loop + * if a match is found, or null if the datagram can be dropped + * immediately + */ + private Datagram matchDatagram(SocketAddress source, ByteBuffer buffer) { + HeadersType headersType = QuicPacketDecoder.peekHeaderType(buffer, buffer.position()); + // short header packets whose length is < 21 are never valid + if (headersType == HeadersType.SHORT && buffer.remaining() < 21) { + return null; + } + final ByteBuffer cidbytes = switch (headersType) { + case LONG, SHORT -> peekConnectionBytes(headersType, buffer); + default -> null; + }; + if (cidbytes == null) { + return null; + } + int length = cidbytes.remaining(); + if (length > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) { + return null; + } + if (debug.on()) { + debug.log("headers(%s), connectionId(%d), datagram(%d)", + headersType, cidbytes.remaining(), buffer.remaining()); + } + QuicPacketReceiver connection = findQuicConnectionFor(source, cidbytes, headersType == HeadersType.LONG); + // check stateless reset + if (connection == null) { + if (headersType == HeadersType.SHORT) { + // a short packet may be a stateless reset, or may + // trigger a stateless reset + connection = checkStatelessReset(source, buffer); + if (connection != null) { + // We received a stateless reset, process it later in the readLoop + return new StatelessReset(connection, source, copyOnHeap(buffer)); + } else if (buffer.remaining() > 21) { + // check if we should send a stateless reset + final ByteBuffer reset = idFactory.statelessReset(cidbytes, buffer.remaining() - 1); + if (reset != null) { + // will send stateless reset later from the read loop + return new SendStatelessReset(source, reset); + } + } + return null; // drop unmatched short packets + } + // client can drop all unmatched long quic packets here + if (isClient()) return null; + } + + if (connection != null) { + if (!connection.accepts(source)) return null; + return new QuicDatagram(connection, source, copyOnHeap(buffer)); + } else { + return new UnmatchedDatagram(source, copyOnHeap(buffer)); + } + } + + + private int send(ByteBuffer datagram, SocketAddress destination) throws IOException { + return channel.send(datagram, destination); + } + + void writeLoop() { + try { + writeLoop0(); + } catch (Throwable error) { + if (!expectExceptions && !closed) { + if (Log.errors()) { + Log.logError(name + ": failed to write to channel: " + error); + Log.logError(error); + } + abort(error); + } + } + } + + boolean sendDatagram(QuicDatagram datagram) throws IOException { + int sent; + var payload = datagram.payload(); + var tosend = payload.remaining(); + final var dest = datagram.address(); + if (isClosed() || isChannelClosed()) { + if (debug.on()) { + debug.log("endpoint or channel closed; skipping sending of datagram(%d) to %s", + tosend, dest); + } + return false; + } + if (debug.on()) { + debug.log("sending datagram(%d) to %s", + tosend, dest); + } + sent = send(payload, dest); + if (debug.on()) debug.log("sent %d bytes to %s", sent, dest); + if (sent == 0 && sent != tosend) return false; + assert sent == tosend; + if (datagram.connection != null) { + datagram.connection.datagramSent(datagram); + } + return true; + } + + void onSendError(QuicDatagram datagram, int tosend, IOException x) { + // close the connection this came from? + // close all the connections whose destination is that address? + var connection = datagram.connection(); + var dest = datagram.address(); + String msg = x.getMessage(); + if (msg != null && msg.contains("too big")) { + int max = -1; + if (connection instanceof QuicConnectionImpl cimpl) { + max = cimpl.getMaxDatagramSize(); + } + msg = "on endpoint %s: Failed to send datagram (%s bytes, max: %s) to %s: %s" + .formatted(this.name, tosend, max, dest, x); + if (connection == null && debug.on()) debug.log(msg); + x = new SocketException(msg, x); + } + if (connection != null) { + connection.datagramDiscarded(datagram); + connection.onWriteError(x); + if (!channel.isOpen()) { + abort(x); + } + } + } + + private void writeLoop0() { + // write as much as we can + while (!writeQueue.isEmpty()) { + var datagram = writeQueue.peek(); + var payload = datagram.payload(); + var tosend = payload.remaining(); + try { + if (sendDatagram(datagram)) { + var rem = writeQueue.poll(); + assert rem == datagram; + } else break; + } catch (IOException x) { + // close the connection this came from? + // close all the connections whose destination is that address? + onSendError(datagram, tosend, x); + var rem = writeQueue.poll(); + assert rem == datagram; + } + } + + } + + void closeWriteQueue(Throwable t) { + QuicDatagram qd; + while ((qd = writeQueue.poll()) != null) { + if (qd.connection != null) { + qd.connection.onWriteError(t); + } + } + } + + private ByteBuffer peekConnectionBytes(HeadersType headersType, ByteBuffer payload) { + var cidlen = idFactory.connectionIdLength(); + return switch (headersType) { + case LONG -> QuicPacketDecoder.peekLongConnectionId(payload); + case SHORT -> QuicPacketDecoder.peekShortConnectionId(payload, cidlen); + default -> null; + }; + } + + // The readloop is triggered whenever new datagrams are + // added to the read queue. + void readLoop() { + try { + if (debug.on()) debug.log("readLoop"); + while (!readQueue.isEmpty()) { + var datagram = readQueue.poll(); + var payload = datagram.payload(); + var source = datagram.address(); + int remaining = payload.remaining(); + var pos = payload.position(); + unbuffer(remaining); + if (debug.on()) { + debug.log("readLoop: type(%x) %d from %s", + payload.hasRemaining() ? payload.get(0) : 0, + remaining, + source); + } + try { + if (closed) { + if (debug.on()) { + debug.log("closed: ignoring incoming datagram"); + } + return; + } + switch (datagram) { + case QuicDatagram(var connection, var _, var _) -> { + var headersType = QuicPacketDecoder.peekHeaderType(payload, pos); + var destConnId = peekConnectionBytes(headersType, payload); + connection.processIncoming(source, destConnId, headersType, payload); + } + case UnmatchedDatagram(var _, var _) -> { + var headersType = QuicPacketDecoder.peekHeaderType(payload, pos); + unmatchedQuicPacket(datagram, headersType, payload); + } + case StatelessReset(var connection, var _, var _) -> { + if (debug.on()) { + debug.log("Processing stateless reset from %s", source); + } + connection.processStatelessReset(); + } + case SendStatelessReset(var _, var _) -> { + if (debug.on()) { + debug.log("Sending stateless reset to %s", source); + } + send(payload, source); + } + } + + } catch (Throwable t) { + if (debug.on()) debug.log("Failed to handle datagram: " + t, t); + Log.logError(t); + } + } + } catch (Throwable t) { + onReadError(t); + } + } + + private void onReadError(Throwable t) { + if (!expectExceptions) { + if (debug.on()) { + debug.log("Error handling event: ", t); + } + Log.logError(t); + if (t instanceof RejectedExecutionException + || t instanceof ClosedChannelException + || t instanceof AssertionError) { + expectExceptions = true; + abort(t); + } + } + } + + /** + * checks if the received datagram contains a stateless reset token; + * returns the associated connection if true, null otherwise + * @param source the sender's address + * @param buffer datagram contents + * @return connection associated with the stateless token, or {@code null} + */ + protected QuicPacketReceiver checkStatelessReset(SocketAddress source, final ByteBuffer buffer) { + // We couldn't identify the connection: maybe that's a stateless reset? + if (closed) return null; + if (debug.on()) { + debug.log("Check if received datagram could be stateless reset (datagram[%d, %s])", + buffer.remaining(), source); + } + if (buffer.remaining() < 21) { + // too short to be a stateless reset: + // RFC 9000: + // Endpoints MUST discard packets that are too small to be valid QUIC packets. + // To give an example, with the set of AEAD functions defined in [QUIC-TLS], + // short header packets that are smaller than 21 bytes are never valid. + if (debug.on()) { + debug.log("Packet too short for a stateless reset (%s bytes < 21)", + buffer.remaining()); + } + return null; + } + final byte[] tokenBytes = new byte[16]; + buffer.get(buffer.limit() - 16, tokenBytes); + final var token = makeToken(tokenBytes); + QuicPacketReceiver connection = peerIssuedResetTokens.get(token); + if (closed) return null; + if (connection != null) { + if (debug.on()) { + debug.log("Received reset token: %s for connection: %s", + HexFormat.of().formatHex(tokenBytes), connection); + } + } else { + if (debug.on()) { + debug.log("Not a stateless reset"); + } + } + return connection; + } + + private StatelessResetToken makeToken(byte[] tokenBytes) { + // encrypt token to block timing attacks, see RFC 9000 section 10.3.1 + try { + Cipher cipher = Cipher.getInstance("AES/ECB/NoPadding"); + cipher.init(Cipher.ENCRYPT_MODE, tokenEncryptionKey); + byte[] encryptedBytes = cipher.doFinal(tokenBytes); + return new StatelessResetToken(encryptedBytes); + } catch (NoSuchAlgorithmException | NoSuchPaddingException | + IllegalBlockSizeException | BadPaddingException | + InvalidKeyException e) { + throw new RuntimeException("AES encryption failed", e); + } + } + + /** + * Called when parsing a quic packet that couldn't be matched to any registered + * connection. + * + * @param datagram The datagram containing the packet + * @param headersType The quic packet type + * @param buffer A buffer positioned at the start of the unmatched quic packet. + * The buffer may contain more coalesced quic packets. + */ + protected void unmatchedQuicPacket(Datagram datagram, + HeadersType headersType, + ByteBuffer buffer) throws IOException { + QuicInstance instance = quicInstance; + if (closed) { + if (debug.on()) { + debug.log("closed: ignoring unmatched datagram"); + } + return; + } + + var address = datagram.address(); + if (isServer() && headersType == HeadersType.LONG ) { + // long packets need to be rematched here for servers. + // we read packets in one thread and process them here in + // a different thread: + // the connection may have been added later on when processing + // a previous long packet in this thread, so we need to + // check the connection map again here. + var idbytes = peekConnectionBytes(headersType, buffer); + var connection = findQuicConnectionFor(address, idbytes, true); + if (connection != null) { + // a matching connection was found, this packet is no longer + // unmatched + if (connection.accepts(address)) { + connection.processIncoming(address, idbytes, headersType, buffer); + } + return; + } + } + + if (debug.on()) { + debug.log("Unmatched packet in datagram [%s, %d, %s] for %s", headersType, + buffer.remaining(), address, instance); + debug.log("Unmatched packet: delegating to instance"); + } + instance.unmatchedQuicPacket(address, headersType, buffer); + } + + private boolean isServer() { + return !isClient(); + } + + private boolean isClient() { + return quicInstance instanceof QuicClient; + } + + // Parses the list of active connection + // Attempts to find one that matches + // If none match return null + // Revisit: + // if we had an efficient sorted tree where we could locate a connection id + // from the idbytes we wouldn't need to use an "unsafe connection id" + // quick and dirty solution for now: we use a ConcurrentHashMap and construct + // a throw away QuicConnectionId that wrap our mutable idbytes. + // This is OK since the classes that may see these bytes are all internal + // and won't mutate them. + QuicPacketReceiver findQuicConnectionFor(SocketAddress peerAddress, ByteBuffer idbytes, boolean longHeaders) { + if (idbytes == null) return null; + var cid = idFactory.unsafeConnectionIdFor(idbytes); + if (cid == null) { + if (!longHeaders || isClient()) { + if (debug.on()) { + debug.log("No connection match for: %s", Utils.asHexString(idbytes)); + } + return null; + } + // this is a long headers packet and we're the server; + // the client might still be using the original connection ID + cid = new PeerConnectionId(idbytes, null); + } + if (debug.on()) { + debug.log("Looking up QuicConnection for: %s", cid); + } + var quicConnection = connections.get(cid); + assert quicConnection == null || allConnectionIds(quicConnection).anyMatch(cid::equals); + return quicConnection; + } + + private static Stream allConnectionIds(QuicPacketReceiver quicConnection) { + return Stream.concat(quicConnection.connectionIds().stream(), quicConnection.initialConnectionId().stream()); + } + + /** + * Detach the channel from the selector implementation + */ + public abstract void detach(); + + private void silentTerminateConnection(QuicPacketReceiver c) { + try { + if (c instanceof QuicConnectionImpl connection) { + final TerminationCause st = forSilentTermination("QUIC endpoint closed - no error"); + connection.terminator.terminate(st); + } + } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to close connection %s: %s", c, t); + } + } finally { + if (c != null) c.shutdown(); + } + } + + // Called in case of RejectedExecutionException, or shutdownNow; + void abortConnection(QuicPacketReceiver c, Throwable error) { + try { + if (c instanceof QuicConnectionImpl connection) { + connection.terminator.terminate(TerminationCause.forException(error)); + } + } catch (Throwable t) { + if (debug.on()) { + debug.log("Failed to close connection %s: %s", c, t); + } + } finally { + if (c != null) c.shutdown(); + } + } + + boolean isClosed() { + return closed; + } + + private void detachAndCloseChannel() throws IOException { + try { + detach(); + } finally { + channel.close(); + } + } + + volatile boolean expectExceptions; + + @Override + public void close() { + if (closed) return; + synchronized (this) { + if (closed) return; + closed = true; + } + try { + while (!connections.isEmpty()) { + if (debug.on()) { + debug.log("closing %d connections", connections.size()); + } + final Set connCloseSent = new HashSet<>(); + for (var cid : connections.keySet()) { + // endpoint is closing, so (on a best-effort basis) we send out a datagram + // containing a QUIC packet with a CONNECTION_CLOSE frame to the peer. + // Immediately after that, we silently terminate the connection since + // there's no point maintaining the connection's infrastructure for + // sending (or receiving) additional packets when the endpoint itself + // won't be around for dealing with the packets. + final QuicPacketReceiver rcvr = connections.remove(cid); + if (rcvr instanceof QuicConnectionImpl quicConn) { + final boolean shouldSendConnClose = connCloseSent.add(quicConn); + // send the datagram containing the CONNECTION_CLOSE frame only once + // per connection + if (shouldSendConnClose) { + sendConnectionCloseQuietly(quicConn); + } + } + silentTerminateConnection(rcvr); + } + } + } finally { + try { + // TODO: do we need to wait for something (ACK?) + // before actually stopping all loop and closing the channel? + if (debug.on()) { + debug.log("Closing channel " + channel + " of endpoint " + this); + } + writeLoopScheduler.stop(); + readLoopScheduler.stop(); + QuicDatagram datagram; + while ((datagram = writeQueue.poll()) != null) { + if (datagram.connection != null) { + datagram.connection.datagramDropped(datagram); + } + } + expectExceptions = true; + detachAndCloseChannel(); + } catch (IOException io) { + if (debug.on()) + debug.log("Failed to detach and close channel: " + io); + } + } + } + + // sends a datagram with a CONNECTION_CLOSE frame for the connection and ignores + // any exceptions that may occur while trying to do so. + private void sendConnectionCloseQuietly(final QuicConnectionImpl quicConn) { + try { + final Optional datagram = quicConn.connectionCloseDatagram(); + if (datagram.isEmpty()) { + return; + } + if (debug.on()) { + debug.log("sending CONNECTION_CLOSE datagram for connection %s", quicConn); + } + send(datagram.get().payload(), datagram.get().address()); + } catch (Exception e) { + // ignore + if (debug.on()) { + debug.log("failed to send CONNECTION_CLOSE datagram for" + + " connection %s due to %s", quicConn, e); + } + } + } + + // Called in case of RejectedExecutionException, or shutdownNow; + public void abort(Throwable error) { + + if (closed) return; + synchronized (this) { + if (closed) return; + closed = true; + } + assert closed; + if (debug.on()) { + debug.log("aborting: " + error); + } + writeLoopScheduler.stop(); + readLoopScheduler.stop(); + QuicDatagram datagram; + while ((datagram = writeQueue.poll()) != null) { + if (datagram.connection != null) { + datagram.connection.datagramDropped(datagram); + } + } + try { + while (!connections.isEmpty()) { + if (debug.on()) + debug.log("closing %d connections", connections.size()); + for (var cid : connections.keySet()) { + abortConnection(connections.remove(cid), error); + } + } + } finally { + try { + if (debug.on()) { + debug.log("Closing channel " + channel + " of endpoint " + this); + } + detachAndCloseChannel(); + } catch (IOException io) { + if (debug.on()) + debug.log("Failed to detach and close channel: " + io); + } + } + } + + + @Override + public String toString() { + return name; + } + + boolean forceSendAsync() { + return DGRAM_SEND_ASYNC || !writeQueue.isEmpty(); + // TODO remove + // perform all writes in a virtual thread. This should trigger + // JDK-8334574 more frequently. + // || (IS_WINDOWS + // && channelType().isBlocking() + // && !Thread.currentThread().isVirtual()); + } + + /** + * Schedule a datagram for writing to the underlying channel. + * If any datagram is pending the given datagram is appended + * to the list of pending datagrams for writing. + * @param source the source connection + * @param destination the destination address + * @param payload the encrypted datagram + */ + public void pushDatagram(QuicPacketReceiver source, SocketAddress destination, ByteBuffer payload) { + int tosend = payload.remaining(); + if (debug.on()) { + debug.log("attempting to send datagram [%s bytes]", tosend); + } + var datagram = new QuicDatagram(source, destination, payload); + try { + // if DGRAM_SEND_ASYNC is true we don't attempt to send from the current + // thread but push the datagram on the queue and invoke the write loop. + if (forceSendAsync() || !sendDatagram(datagram)) { + if (tosend == payload.remaining()) { + writeQueue.add(datagram); + if (debug.on()) { + debug.log("datagram [%s bytes] added to write queue, queue size %s", + tosend, writeQueue.size()); + } + writeLoopScheduler.runOrSchedule(writeLoopExecutor()); + } else { + source.datagramDropped(datagram); + if (debug.on()) { + debug.log("datagram [%s bytes] dropped: payload partially consumed, remaining %s", + tosend, payload.remaining()); + } + } + } + } catch (IOException io) { + onSendError(datagram, tosend, io); + } + } + + /** + * Called to schedule sending of a datagram that contains a {@code ConnectionCloseFrame}. + * This will replace the {@link QuicConnectionImpl} with a {@link ClosedConnection} that + * will replay the datagram containing the {@code ConnectionCloseFrame} whenever a packet + * for that connection is received. + * @param connection the connection being closed + * @param destination the peer address + * @param datagram the datagram + */ + public void pushClosingDatagram(QuicConnectionImpl connection, InetSocketAddress destination, ByteBuffer datagram) { + if (debug.on()) debug.log("Pushing closing datagram for " + connection.logTag()); + closing(connection, datagram.slice()); + pushDatagram(connection, destination, datagram); + } + + /** + * Called to schedule sending of a datagram that contains a single {@code ConnectionCloseFrame} + * sent in response to a {@code ConnectionClose} frame. + * This will completely remove the connection from the connection map. + * @param connection the connection being closed + * @param destination the peer address + * @param datagram the datagram + */ + public void pushClosedDatagram(QuicConnectionImpl connection, + InetSocketAddress destination, + ByteBuffer datagram) { + if (debug.on()) debug.log("Pushing closed datagram for " + connection.logTag()); + removeConnection(connection); + pushDatagram(connection, destination, datagram); + } + + /** + * This will completely remove the connection from the endpoint. Any subsequent packets + * directed to this connection from a peer, may end up receiving a stateless reset + * from this endpoint. + * + * @param connection the connection to be removed + */ + void removeConnection(final QuicPacketReceiver connection) { + if (debug.on()) debug.log("removing connection " + connection); + // remove the connection completely + connection.connectionIds().forEach(connections::remove); + assert !connections.containsValue(connection) : connection; + // remove references to this connection from the map which holds the peer issued + // reset tokens + dropPeerIssuedResetTokensFor(connection); + } + + /** + * Add the cid to connection mapping to the endpoint. + * + * @param cid the connection ID to be added + * @param connection the connection that should be mapped to the cid + * @return true if connection ID was added, false otherwise + */ + public boolean addConnectionId(QuicConnectionId cid, QuicPacketReceiver connection) { + var old = connections.putIfAbsent(cid, connection); + return old == null; + } + + /** + * Remove the cid to connection mapping from the endpoint. + * + * @param cid the connection ID to be removed + * @param connection the connection that is mapped to the cid + * @return true if connection ID was removed, false otherwise + */ + public boolean removeConnectionId(QuicConnectionId cid, QuicPacketReceiver connection) { + if (debug.on()) debug.log("removing connection ID " + cid); + return connections.remove(cid, connection); + } + + public final int connectionCount() { + return connections.size(); + } + + // drop peer issued stateless tokes for the given connection + private void dropPeerIssuedResetTokensFor(QuicPacketReceiver connection) { + // remove references to this connection from the map which holds the peer issued + // reset tokens + peerIssuedResetTokens.values().removeIf(conn -> connection == conn); + } + + // remap peer issued stateless token from connection `from` to connection `to` + private void remapPeerIssuedResetToken(QuicPacketReceiver from, QuicPacketReceiver to) { + assert from != null; + assert to != null; + peerIssuedResetTokens.replaceAll((tok, c) -> c == from ? to : c); + } + + public void draining(final QuicPacketReceiver connection) { + // remap the connection to a DrainingConnection + if (closed) return; + connection.connectionIds().forEach((id) -> + connections.compute(id, this::remapDraining)); + assert !connections.containsValue(connection) : connection; + } + + private DrainingConnection remapDraining(QuicConnectionId id, QuicPacketReceiver conn) { + if (closed) return null; + var debugOn = debug.on() && !Thread.currentThread().isVirtual(); + if (conn instanceof ClosingConnection closing) { + if (debugOn) debug.log("remapping %s to DrainingConnection", id); + final var draining = closing.toDraining(); + remapPeerIssuedResetToken(closing, draining); + draining.startTimer(); + return draining; + } else if (conn instanceof DrainingConnection draining) { + return draining; + } else if (conn instanceof QuicConnectionImpl impl) { + final long idleTimeout = impl.peerPtoMs() * 3; // 3 PTO + impl.localConnectionIdManager().close(); + if (debugOn) debug.log("remapping %s to DrainingConnection", id); + var draining = new DrainingConnection(conn.connectionIds(), idleTimeout); + // we can ignore stateless reset in the draining state. + remapPeerIssuedResetToken(impl, draining); + draining.startTimer(); + return draining; + } else if (conn == null) { + // connection absent (was probably removed), don't remap to draining + if (debugOn) { + debug.log("no existing connection present for %s, won't remap to draining", id); + } + return null; + } else { + assert false : "unexpected connection type: " + conn; // just remove + return null; + } + } + + protected void closing(QuicConnectionImpl connection, ByteBuffer datagram) { + if (closed) return; + ByteBuffer closing = ByteBuffer.allocate(datagram.limit()); + closing.put(datagram.slice()); + closing.flip(); + connection.connectionIds().forEach((id) -> + connections.compute(id, (i, r) -> remapClosing(i, r, closing))); + assert !connections.containsValue(connection) : connection; + } + + private ClosedConnection remapClosing(QuicConnectionId id, QuicPacketReceiver conn, ByteBuffer datagram) { + if (closed) return null; + var debugOn = debug.on() && !Thread.currentThread().isVirtual(); + if (conn instanceof ClosingConnection closing) { + // we already have a closing datagram, drop the new one + return closing; + } else if (conn instanceof DrainingConnection draining) { + return draining; + } else if (conn instanceof QuicConnectionImpl impl) { + final long idleTimeout = impl.peerPtoMs() * 3; // 3 PTO + impl.localConnectionIdManager().close(); + if (debugOn) debug.log("remapping %s to ClosingConnection", id); + var closing = new ClosingConnection(conn.connectionIds(), idleTimeout, datagram); + remapPeerIssuedResetToken(impl, closing); + closing.startTimer(); + return closing; + } else if (conn == null) { + // connection absent (was probably removed), don't remap to closing + if (debugOn) { + debug.log("no existing connection present for %s, won't remap to closing", id); + } + return null; + } else { + assert false : "unexpected connection type: " + conn; // just remove + return null; + } + } + + public void registerNewConnection(QuicConnectionImpl quicConnection) throws IOException { + if (closed) throw new ClosedChannelException(); + quicConnection.connectionIds().forEach((id) -> putConnection(id, quicConnection)); + } + + /** + * A peer issues a stateless reset token which it can then send to close the connection. This + * method links the peer issued token against the connection that needs to be closed if/when + * that stateless reset token arrives in the packet. + * + * @param statelessResetToken the peer issued (16 byte) stateless reset token + * @param connection the connection to link the token against + */ + void associateStatelessResetToken(final byte[] statelessResetToken, final QuicPacketReceiver connection) { + Objects.requireNonNull(connection); + Objects.requireNonNull(statelessResetToken); + final int tokenLength = statelessResetToken.length; + if (statelessResetToken.length != 16) { + throw new IllegalArgumentException("Invalid stateless reset token length " + tokenLength); + } + if (debug.on()) { + debug.log("associating stateless reset token with connection %s", connection); + } + this.peerIssuedResetTokens.put(makeToken(statelessResetToken), connection); + } + + /** + * Discard the stateless reset token that this endpoint might have previously + * {@link #associateStatelessResetToken(byte[], QuicPacketReceiver) associated any connection} + * with + * @param statelessResetToken The stateless reset token + */ + void forgetStatelessResetToken(final byte[] statelessResetToken) { + // just a tiny optimization - we know stateless reset token must be of 16 bytes, if the passed + // value isn't, then no point doing any more work + if (statelessResetToken.length != 16) { + return; + } + this.peerIssuedResetTokens.remove(makeToken(statelessResetToken)); + } + + /** + * {@return the timer queue associated with this endpoint} + */ + public QuicTimerQueue timer() { + return timerQueue; + } + + public boolean isChannelClosed() { + return !channel().isOpen(); + } + + /** + * {@return the time source associated with this endpoint} + * @apiNote + * There is a unique global {@linkplain TimeSource#source()} for the whole + * JVM, but this method can be overridden in tests to define an alternative + * timeline for the test. + */ + protected TimeLine timeSource() { + return TimeSource.source(); + } + + private void putConnection(QuicConnectionId id, QuicConnectionImpl quicConnection) { + // ideally we'd want to use an immutable byte buffer as a key here. + // but we don't have that. So we use the connection id instead. + var old = connections.put(id, quicConnection); + assert old == null : "%s already registered with %s (%s)" + .formatted(old, id, old == quicConnection ? "old == new" : "old != new"); + } + + + /** + * Represent a closing or draining quic connection: if we receive any packet + * for this connection we ignore them (if in draining state) or replay the + * closed packets in decreasing frequency: we reply to the + * first packet, then to the third, then to the seventh, etc... + * We stop replying after 16*16/2. + */ + sealed abstract class ClosedConnection implements QuicPacketReceiver, QuicTimedEvent + permits QuicEndpoint.ClosingConnection, QuicEndpoint.DrainingConnection { + + // default time we keep the ClosedConnection alive while closing/draining - if + // PTO information is not available (if 0 is passed as idleTimeoutMs when creating + // an instance of this class) + final static long NO_IDLE_TIMEOUT = 2000; + final List localConnectionIds; + final long maxIdleTimeMs; + final long id; + int more = 1; + int waitformore; + volatile Deadline deadline; + volatile Deadline updatedDeadline; + + ClosedConnection(List localConnectionIds, long maxIdleTimeMs) { + this.id = QuicTimerQueue.newEventId(); + this.maxIdleTimeMs = maxIdleTimeMs == 0 ? NO_IDLE_TIMEOUT : maxIdleTimeMs; + this.deadline = Deadline.MAX; + this.updatedDeadline = Deadline.MAX; + this.localConnectionIds = List.copyOf(localConnectionIds); + } + + @Override + public List connectionIds() { + return localConnectionIds; + } + + @Override + public final void processIncoming(SocketAddress source, ByteBuffer destConnId, HeadersType headersType, ByteBuffer buffer) { + Deadline updated = updatedDeadline; + var waitformore = this.waitformore; + // Deadline.MIN will be set in case of write errors + if (updated != Deadline.MIN && waitformore == 0) { + var more = this.more; + this.waitformore = more; + this.more = more = more << 1; + if (more > 16) { + // the server doesn't seem to take into account our + // connection close frame. Just stop responding + updatedDeadline = Deadline.MIN; + } else { + updatedDeadline = updated.plusMillis(maxIdleTimeMs); + } + handleIncoming(source, destConnId, headersType, buffer); + } else { + this.waitformore = waitformore - 1; + dropIncoming(source, destConnId, headersType, buffer); + } + + timer().reschedule(this, updatedDeadline); + } + + protected void handleIncoming(SocketAddress source, ByteBuffer idbytes, + HeadersType headersType, ByteBuffer buffer) { + dropIncoming(source, idbytes, headersType, buffer); + } + + protected abstract void dropIncoming(SocketAddress source, ByteBuffer idbytes, + HeadersType headersType, ByteBuffer buffer); + + @Override + public final void onWriteError(Throwable t) { + if (debug.on()) + debug.log("failed to write close packet", t); + removeConnection(this); + // handle() will be called, which will cause + // the timer queue to remove this object + updatedDeadline = Deadline.MIN; + timer().reschedule(this); + } + + public final void startTimer() { + deadline = updatedDeadline = timeSource().instant().plusMillis(maxIdleTimeMs); + timer().offer(this); + } + + @Override + public final Deadline deadline() { + return deadline; + } + + @Override + public final Deadline handle() { + removeConnection(this); + // Deadline.MAX means do not reschedule + return updatedDeadline = Deadline.MAX; + } + + @Override + public final Deadline refreshDeadline() { + // Returning Deadline.MIN here will cause handle() to + // be called and will remove this task from the timer queue. + return deadline = updatedDeadline; + } + + @Override + public final long eventId() { + return id; + } + + @Override + public final void processStatelessReset() { + // the peer has sent us a stateless reset: no need to + // replay CloseConnectionFrame. Just remove this connection. + removeConnection(this); + // handle() will be called, which will cause + // the timer queue to remove this object + updatedDeadline = Deadline.MIN; + timer().reschedule(this); + } + + public void shutdown() { + updatedDeadline = Deadline.MIN; + timer().reschedule(this); + } + } + + + /** + * Represent a closing quic connection: if we receive any packet for this + * connection we simply replay the packet(s) that contained the + * ConnectionCloseFrame frame. + * Packets are replayed in decreasing frequency. We reply to the + * first packet, then to the third, then to the seventh, etc... + * We stop replying after 16*16/2. + */ + final class ClosingConnection extends ClosedConnection { + + final ByteBuffer closePacket; + + ClosingConnection(List localConnIdManager, long maxIdleTimeMs, + ByteBuffer closePacket) { + super(localConnIdManager, maxIdleTimeMs); + this.closePacket = Objects.requireNonNull(closePacket); + } + + @Override + public void handleIncoming(SocketAddress source, ByteBuffer idbytes, + HeadersType headersType, ByteBuffer buffer) { + if (isClosed() || isChannelClosed()) { + // don't respond with any more datagrams and instead just drop + // the incoming ones since the channel is closed + dropIncoming(source, idbytes, headersType, buffer); + return; + } + if (debug.on()) { + debug.log("ClosingConnection(%s): sending closed packets", localConnectionIds); + } + pushDatagram(this, source, closePacket.asReadOnlyBuffer()); + } + + @Override + protected void dropIncoming(SocketAddress source, ByteBuffer idbytes, HeadersType headersType, ByteBuffer buffer) { + if (debug.on()) { + debug.log("ClosingConnection(%s): dropping %s packet", localConnectionIds, headersType); + } + } + + private DrainingConnection toDraining() { + return new DrainingConnection(localConnectionIds, maxIdleTimeMs); + } + } + + /** + * Represent a draining quic connection: if we receive any packet for this + * connection we simply ignore them. + */ + final class DrainingConnection extends ClosedConnection { + + DrainingConnection(List localConnIdManager, long maxIdleTimeMs) { + super(localConnIdManager, maxIdleTimeMs); + } + + @Override + public void dropIncoming(SocketAddress source, ByteBuffer idbytes, HeadersType headersType, ByteBuffer buffer) { + if (debug.on()) { + debug.log("DrainingConnection(%s): dropping %s packet", + localConnectionIds, headersType); + } + } + + } + + private record StatelessResetToken (byte[] token) { + StatelessResetToken(final byte[] token) { + this.token = token.clone(); + } + @Override + public int hashCode() { + return Arrays.hashCode(token); + } + + @Override + public boolean equals(final Object obj) { + if (obj instanceof StatelessResetToken other) { + return Arrays.equals(token, other.token); + } + return false; + } + } + + /** + * {@return a new {@link QuicEndpoint} of the given {@code endpointType}} + * @param endpointType the concrete endpoint type, one of {@link QuicSelectableEndpoint + * QuicSelectableEndpoint.class} or {@link QuicVirtualThreadedEndpoint + * QuicVirtualThreadedEndpoint.class}. + * @param quicInstance the quic instance + * @param name the endpoint name + * @param bindAddress the address to bind to + * @param timerQueue the timer queue + * @param the concrete endpoint type, one of {@link QuicSelectableEndpoint} + * or {@link QuicVirtualThreadedEndpoint} + * @throws IOException if an IOException occurs + * @throws IllegalArgumentException if the given endpoint type is not one of + * {@link QuicSelectableEndpoint QuicSelectableEndpoint.class} or + * {@link QuicVirtualThreadedEndpoint QuicVirtualThreadedEndpoint.class} + */ + private static T create(Class endpointType, + QuicInstance quicInstance, + String name, + SocketAddress bindAddress, + QuicTimerQueue timerQueue) throws IOException { + DatagramChannel channel = DatagramChannel.open(); + // avoid dependency on extnet + Optional> df = channel.supportedOptions().stream(). + filter(o -> "IP_DONTFRAGMENT".equals(o.name())).findFirst(); + if (df.isPresent()) { + // TODO on some platforms this doesn't work on dual stack sockets + // see Net#shouldSetBothIPv4AndIPv6Options + @SuppressWarnings("unchecked") + var option = (SocketOption) df.get(); + channel.setOption(option, true); + } + if (QuicSelectableEndpoint.class.isAssignableFrom(endpointType)) { + channel.configureBlocking(false); + } + Consumer logSink = Log.quic() ? Log::logQuic : null; + Utils.configureChannelBuffers(logSink, channel, + quicInstance.getReceiveBufferSize(), quicInstance.getSendBufferSize()); + channel.bind(bindAddress); // could do that on attach instead? + + if (endpointType.isAssignableFrom(QuicSelectableEndpoint.class)) { + return endpointType.cast(new QuicSelectableEndpoint(quicInstance, channel, name, timerQueue)); + } else if (endpointType.isAssignableFrom(QuicVirtualThreadedEndpoint.class)) { + return endpointType.cast(new QuicVirtualThreadedEndpoint(quicInstance, channel, name, timerQueue)); + } else { + throw new IllegalArgumentException(endpointType.getName()); + } + } + + public static final class QuicEndpointFactory { + + public QuicEndpointFactory() { + } + + /** + * {@return a new {@code QuicSelectableEndpoint}} + * + * @param quicInstance the quic instance + * @param name the endpoint name + * @param bindAddress the address to bind to + * @param timerQueue the timer queue + * @throws IOException if an IOException occurs + */ + public QuicSelectableEndpoint createSelectableEndpoint(QuicInstance quicInstance, + String name, + SocketAddress bindAddress, + QuicTimerQueue timerQueue) + throws IOException { + return create(QuicSelectableEndpoint.class, quicInstance, name, bindAddress, timerQueue); + } + + /** + * {@return a new {@code QuicVirtualThreadedEndpoint}} + * + * @param quicInstance the quic instance + * @param name the endpoint name + * @param bindAddress the address to bind to + * @param timerQueue the timer queue + * @throws IOException if an IOException occurs + */ + public QuicVirtualThreadedEndpoint createVirtualThreadedEndpoint(QuicInstance quicInstance, + String name, + SocketAddress bindAddress, + QuicTimerQueue timerQueue) + throws IOException { + return create(QuicVirtualThreadedEndpoint.class, quicInstance, name, bindAddress, timerQueue); + } + } + + /** + * Registers the given endpoint with the given selector. + *

+ * An endpoint of class {@link QuicSelectableEndpoint} is only + * compatible with a selector of type {@link QuicNioSelector}. + * An endpoint of tyoe {@link QuicVirtualThreadedEndpoint} is only + * compatible with a selector of type {@link QuicVirtualThreadPoller}. + *
+ * If the given endpoint implementation is not compatible with + * the given selector implementation an {@link IllegalStateException} + * is thrown. + * + * @param endpoint the endpoint + * @param selector the selector + * @param debug a logger for debugging + * + * @throws IOException if an IOException occurs + * @throws IllegalStateException if the endpoint and selector implementations + * are not compatible + */ + public static void registerWithSelector(QuicEndpoint endpoint, QuicSelector selector, Logger debug) + throws IOException { + if (selector instanceof QuicVirtualThreadPoller poller) { + var loopingEndpoint = (QuicVirtualThreadedEndpoint) endpoint; + poller.register(loopingEndpoint); + } else if (selector instanceof QuicNioSelector selectable) { + var selectableEndpoint = (QuicEndpoint.QuicSelectableEndpoint) endpoint; + selectable.register(selectableEndpoint); + } else { + throw new IllegalStateException("Incompatible selector and endpoint implementations: %s <-> %s" + .formatted(selector.getClass(), endpoint.getClass())); + } + if (debug.on()) debug.log("endpoint registered with selector"); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicInstance.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicInstance.java new file mode 100644 index 00000000000..a963625d7f5 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicInstance.java @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.HexFormat; +import java.util.List; +import java.util.concurrent.Executor; + +import javax.net.ssl.SSLParameters; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; + +/** + * A {@code QuicInstance} represents a common abstraction which is + * either a {@code QuicClient} or a {@code QuicServer}, or possibly + * both. It defines the subset of public methods that a + * {@code QuicEndpoint} and a {@code QuicSelector} need to operate + * with a quic client, or a quic server; + */ +public interface QuicInstance { + + /** + * The executor used by this quic instance when a task needs to + * be offloaded to a separate thread. + * @implNote This is the HttpClientImpl internal executor. + * @return the executor used by this QuicClient. + */ + Executor executor(); + + /** + * {@return an endpoint to associate with a connection} + * @throws IOException + */ + QuicEndpoint getEndpoint() throws IOException; + + /** + * This method is called when a quic packet that couldn't be attributed + * to a registered connection is received. + * @param source the source address of the datagram + * @param type the packet type + * @param buffer A buffer positioned at the start of the quic packet + */ + void unmatchedQuicPacket(SocketAddress source, QuicPacket.HeadersType type, ByteBuffer buffer); + + /** + * {@return true if the passed version is available for use on this instance, false otherwise} + */ + boolean isVersionAvailable(QuicVersion quicVersion); + + /** + * {@return the versions that are available for use on this instance} + */ + List getAvailableVersions(); + + /** + * Instance ID used for debugging traces. + * @return A string uniquely identifying this instance. + */ + String instanceId(); + + /** + * Get the QuicTLSContext used by this quic instance. + * @return the QuicTLSContext used by this quic instance. + */ + QuicTLSContext getQuicTLSContext(); + + QuicTransportParameters getTransportParameters(); + + /** + * The {@link SSLParameters} for this Quic instance. + * May be {@code null} if no parameters have been specified. + * + * @implSpec + * The default implementation of this method returns {@code null}. + * + * @return The {@code SSLParameters} for this quic instance or {@code null}. + */ + default SSLParameters getSSLParameters() { return new SSLParameters(); } + + /** + * {@return the configured {@linkplain java.net.StandardSocketOptions#SO_RCVBUF + * UDP receive buffer} size this instance should use} + */ + default int getReceiveBufferSize() { + return Utils.getIntegerNetProperty( + "jdk.httpclient.quic.receiveBufferSize", + 0 // only set the size if > 0 + ); + } + + /** + * {@return the configured {@linkplain java.net.StandardSocketOptions#SO_SNDBUF + * UDP send buffer} size this instance should use} + */ + default int getSendBufferSize() { + return Utils.getIntegerNetProperty( + "jdk.httpclient.quic.sendBufferSize", + 0 // only set the size if > 0 + ); + } + + /** + * {@return a string describing the given application error code} + * @param errorCode an application error code + * @implSpec By default, this method returns a generic + * string containing the hexadecimal value of the given errorCode. + * Subclasses built for supporting a given application protocol, + * such as HTTP/3, may override this method to return more + * specific names, such as for instance, {@code "H3_REQUEST_CANCELLED"} + * for {@code 0x010c}. + * @apiNote This method is typically used for logging and/or debugging + * purposes, to generate a more user-friendly log message. + */ + default String appErrorToString(long errorCode) { + return "ApplicationError(code=0x" + HexFormat.of().toHexDigits(errorCode) + ")"; + } + + default String name() { + return String.format("%s(%s)", this.getClass().getSimpleName(), instanceId()); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicPacketReceiver.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicPacketReceiver.java new file mode 100644 index 00000000000..6652b44de3a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicPacketReceiver.java @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import jdk.internal.net.http.quic.QuicEndpoint.QuicDatagram; +import jdk.internal.net.http.quic.packets.QuicPacket; + +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Optional; + +/** + * The {@code QuicPacketReceiver} is an abstraction that defines the + * interface between a {@link QuicEndpoint} and a {@link QuicConnection}. + * This defines the minimum set of methods that the endpoint will need + * in order to be able to dispatch a received {@link jdk.internal.net.http.quic.packets.QuicPacket} + * to its destination. This abstraction is typically useful when dealing with + * {@linkplain QuicEndpoint.ClosedConnection + * closed connections, which need to remain alive for a certain time + * after being closed in order to satisfy the requirement of the quic + * protocol (typically for retransmitting the CLOSE_CONNECTION frame + * if needed). + */ +public interface QuicPacketReceiver { + + /** + * {@return a list of local connectionIds for this connection) + */ + List connectionIds(); + + /** + * {@return the initial connection id assigned by the peer} + * On the client side, this is always {@link Optional#empty()}. + * On the server side, it contains the initial connection id + * that was assigned by the client in the first INITIAL packet. + * + * @implSpec + * The default implementation of this method returns {@link Optional#empty()} + */ + default Optional initialConnectionId() { + return Optional.empty(); + } + + /** + * Called when an incoming datagram is received. + *

+ * The buffer is positioned at the start of the datagram to process. + * The buffer may contain more than one QUIC packet. + * + * @param source The peer address, as received from the UDP stack + * @param destConnId Destination connection id bytes included in the packet + * @param headersType The quic packet type + * @param buffer A buffer positioned at the start of the quic packet, + * not yet decrypted, and possibly containing coalesced + * packets. + */ + void processIncoming(SocketAddress source, ByteBuffer destConnId, + QuicPacket.HeadersType headersType, ByteBuffer buffer); + + /** + * Called when a datagram scheduled for writing by this connection + * could not be written to the network. + * @param t the error that occurred + */ + void onWriteError(Throwable t); + + /** + * Called when a stateless reset token is received. + */ + void processStatelessReset(); + + /** + * Called to shut a closed connection down. + * This is the last step when closing a connection, and typically + * only release resources held by all packet spaces. + */ + void shutdown(); + + /** + * Called after a datagram has been written to the socket. + * At this point the datagram's ByteBuffer can typically be released, + * or returned to a buffer pool. + * @implSpec + * The default implementation of this method does nothing. + * @param datagram the datagram that was sent + */ + default void datagramSent(QuicDatagram datagram) { } + + /** + * Called after a datagram has been discarded as a result of + * some error being raised, for instance, when an attempt + * to write it to the socket has failed, or if the encryption + * of a packet in the datagram has failed. + * At this point the datagram's ByteBuffer can typically be released, + * or returned to a buffer pool. + * @implSpec + * The default implementation of this method does nothing. + * @param datagram the datagram that was discarded + */ + default void datagramDiscarded(QuicDatagram datagram) { } + + /** + * Called after a datagram has been dropped. Typically, this + * could happen if the datagram was only partly written, or if + * the connection was closed before the datagram could be sent. + * At this point the datagram's ByteBuffer can typically be released, + * or returned to a buffer pool. + * @implSpec + * The default implementation of this method does nothing. + * @param datagram the datagram that was dropped + */ + default void datagramDropped(QuicDatagram datagram) { } + + /** + * {@return whether this receiver accepts packets from the given source} + * @param source the sender address + */ + default boolean accepts(SocketAddress source) { + return true; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRenoCongestionController.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRenoCongestionController.java new file mode 100644 index 00000000000..fde253740d1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRenoCongestionController.java @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.quic; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.TimeLine; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.packets.QuicPacket; + +import java.util.Collection; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Implementation of QUIC congestion controller based on RFC 9002. + * This is a QUIC variant of New Reno algorithm. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9002 + * RFC 9002: QUIC Loss Detection and Congestion Control + */ +class QuicRenoCongestionController implements QuicCongestionController { + // higher of 14720 and 2*maxDatagramSize; we use fixed maxDatagramSize + private static final int INITIAL_WINDOW = Math.max(14720, 2 * QuicConnectionImpl.DEFAULT_DATAGRAM_SIZE); + private static final int MAX_BYTES_IN_FLIGHT = Math.clamp( + Utils.getLongProperty("jdk.httpclient.quic.maxBytesInFlight", 1 << 24), + 1 << 14, 1 << 24); + private final TimeLine timeSource; + private final String dbgTag; + private final Lock lock = new ReentrantLock(); + private long congestionWindow = INITIAL_WINDOW; + private int maxDatagramSize = QuicConnectionImpl.DEFAULT_DATAGRAM_SIZE; + private int minimumWindow = 2 * maxDatagramSize; + private long bytesInFlight; + // maximum bytes in flight seen since the last congestion event + private long maxBytesInFlight; + private Deadline congestionRecoveryStartTime; + private long ssThresh = Long.MAX_VALUE; + + public QuicRenoCongestionController(String dbgTag) { + this.dbgTag = dbgTag; + this.timeSource = TimeSource.source(); + } + + private boolean inCongestionRecovery(Deadline sentTime) { + return (congestionRecoveryStartTime != null && + !sentTime.isAfter(congestionRecoveryStartTime)); + } + + private void onCongestionEvent(Deadline sentTime) { + if (inCongestionRecovery(sentTime)) { + return; + } + congestionRecoveryStartTime = timeSource.instant(); + ssThresh = congestionWindow / 2; + congestionWindow = Math.max(minimumWindow, ssThresh); + maxBytesInFlight = 0; + if (Log.quicCC()) { + Log.logQuic(dbgTag+ " Congestion: ssThresh: " + ssThresh + + ", in flight: " + bytesInFlight + + ", cwnd:" + congestionWindow); + } + } + + private static boolean inFlight(QuicPacket packet) { + // packet is in flight if it contains anything other than a single ACK frame + // specifically, a packet containing padding is considered to be in flight. + return packet.frames().size() != 1 || + !(packet.frames().get(0) instanceof AckFrame); + } + + @Override + public boolean canSendPacket() { + lock.lock(); + try { + if (bytesInFlight >= MAX_BYTES_IN_FLIGHT) { + return false; + } + var canSend = congestionWindow - bytesInFlight >= maxDatagramSize; + return canSend; + } finally { + lock.unlock(); + } + } + + @Override + public void updateMaxDatagramSize(int newSize) { + lock.lock(); + try { + if (minimumWindow != newSize * 2) { + minimumWindow = newSize * 2; + maxDatagramSize = newSize; + congestionWindow = Math.max(congestionWindow, minimumWindow); + } + } finally { + lock.unlock(); + } + } + + @Override + public void packetSent(int packetBytes) { + lock.lock(); + try { + bytesInFlight += packetBytes; + if (bytesInFlight > maxBytesInFlight) { + maxBytesInFlight = bytesInFlight; + } + } finally { + lock.unlock(); + } + } + + @Override + public void packetAcked(int packetBytes, Deadline sentTime) { + lock.lock(); + try { + bytesInFlight -= packetBytes; + // RFC 9002 says we should not increase cwnd when application limited. + // The concept itself is poorly defined. + // Here we limit cwnd growth based on the maximum bytes in flight + // observed since the last congestion event + if (inCongestionRecovery(sentTime)) { + if (Log.quicCC()) { + Log.logQuic(dbgTag+ " Acked, in recovery: bytes: " + packetBytes + + ", in flight: " + bytesInFlight); + } + return; + } + boolean isAppLimited; + if (congestionWindow < ssThresh) { + isAppLimited = congestionWindow >= 2 * maxBytesInFlight; + if (!isAppLimited) { + congestionWindow += packetBytes; + } + } else { + isAppLimited = congestionWindow > maxBytesInFlight + 2L * maxDatagramSize; + if (!isAppLimited) { + congestionWindow += Math.max((long) maxDatagramSize * packetBytes / congestionWindow, 1L); + } + } + if (Log.quicCC()) { + if (isAppLimited) { + Log.logQuic(dbgTag+ " Acked, not blocked: bytes: " + packetBytes + + ", in flight: " + bytesInFlight); + } else { + Log.logQuic(dbgTag + " Acked, increased: bytes: " + packetBytes + + ", in flight: " + bytesInFlight + + ", new cwnd:" + congestionWindow); + } + } + } finally { + lock.unlock(); + } + } + + @Override + public void packetLost(Collection lostPackets, Deadline sentTime, boolean persistent) { + lock.lock(); + try { + for (QuicPacket packet : lostPackets) { + if (inFlight(packet)) { + bytesInFlight -= packet.size(); + } + } + onCongestionEvent(sentTime); + if (persistent) { + congestionWindow = minimumWindow; + congestionRecoveryStartTime = null; + if (Log.quicCC()) { + Log.logQuic(dbgTag+ " Persistent congestion: ssThresh: " + ssThresh + + ", in flight: " + bytesInFlight + + ", cwnd:" + congestionWindow); + } + } + } finally { + lock.unlock(); + } + } + + @Override + public void packetDiscarded(Collection discardedPackets) { + lock.lock(); + try { + for (QuicPacket packet : discardedPackets) { + if (inFlight(packet)) { + bytesInFlight -= packet.size(); + } + } + } finally { + lock.unlock(); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRttEstimator.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRttEstimator.java new file mode 100644 index 00000000000..0e2f9401bc0 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicRttEstimator.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.TimeUnit; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Utils; + +/** + * Estimator for quic connection round trip time. + * Defined in + * RFC 9002 section 5. + * Takes RTT samples as input (max 1 sample per ACK frame) + * Produces: + * - minimum RTT over a period of time (minRtt) for internal use + * - exponentially weighted moving average (smoothedRtt) + * - mean deviation / variation in the observed samples (rttVar) + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9002 + * RFC 9002: QUIC Loss Detection and Congestion Control + */ +public class QuicRttEstimator { + + // The property indicates the maximum number of retries. The constant holds the value + // 2^N where N is the value of the property (clamped between 2 and 20, default 8): + // 1=>2 2=>4 3=>8 4=>16 5=>32 6=>64 7=>128 8=>256 ... etc + public static final long MAX_PTO_BACKOFF = 1L << Math.clamp( + Utils.getIntegerNetProperty("jdk.httpclient.quic.maxPtoBackoff", 8), + 2, 20); + // The timeout calculated for PTO will stay clamped at MAX_PTO_BACKOFF_TIMEOUT if + // the calculated value exceeds MAX_PTO_BACKOFF_TIMEOUT + public static final Duration MAX_PTO_BACKOFF_TIMEOUT = Duration.ofSeconds(Math.clamp( + Utils.getIntegerNetProperty("jdk.httpclient.quic.maxPtoBackoffTime", 240), + 1, 1200)); + // backoff will continue to be increased past MAX_PTO_BACKOFF if the timeout calculated + // for PTO is less than MIN_PTO_BACKOFF_TIMEOUT + public static final Duration MIN_PTO_BACKOFF_TIMEOUT = Duration.ofSeconds(Math.clamp( + Utils.getIntegerNetProperty("jdk.httpclient.quic.minPtoBackoffTime", 15), + 0, 1200)); + + private static final long INITIAL_RTT = TimeUnit.MILLISECONDS.toMicros(Math.clamp( + Utils.getIntegerNetProperty("jdk.httpclient.quic.initialRTT", 333), + 50, 1000)); + + // kGranularity, 1ms is recommended by RFC 9002 section 6.1.2 + private static final long GRANULARITY_MICROS = TimeUnit.MILLISECONDS.toMicros(1); + private Deadline firstSample; + private long latestRttMicros; + private long minRttMicros; + private long smoothedRttMicros = INITIAL_RTT; + private long rttVarMicros = INITIAL_RTT / 2; + private long ptoBackoffFactor = 1; + private long rttSampleCount = 0; + + public record QuicRttEstimatorState(long latestRttMicros, + long minRttMicros, + long smoothedRttMicros, + long rttVarMicros, + long rttSampleCount) {} + + public synchronized QuicRttEstimatorState state() { + return new QuicRttEstimatorState(latestRttMicros, minRttMicros, smoothedRttMicros, rttVarMicros, rttSampleCount); + } + + /** + * Update the estimator with latest RTT sample. + * Use only samples where: + * - the largest acknowledged PN is newly acknowledged + * - at least one of the newly acked packets is ack-eliciting + * @param latestRttMicros time between when packet was sent + * and ack was received, in microseconds + * @param ackDelayMicros ack delay received in ack frame, decoded to microseconds + * @param now time at which latestRttMicros was calculated + */ + public synchronized void consumeRttSample(long latestRttMicros, long ackDelayMicros, Deadline now) { + this.rttSampleCount += 1; + this.latestRttMicros = latestRttMicros; + if (firstSample == null) { + firstSample = now; + minRttMicros = latestRttMicros; + smoothedRttMicros = latestRttMicros; + rttVarMicros = latestRttMicros / 2; + } else { + minRttMicros = Math.min(minRttMicros, latestRttMicros); + long adjustedRtt; + if (latestRttMicros >= minRttMicros + ackDelayMicros) { + adjustedRtt = latestRttMicros - ackDelayMicros; + } else { + adjustedRtt = latestRttMicros; + } + rttVarMicros = (3 * rttVarMicros + Math.abs(smoothedRttMicros - adjustedRtt)) / 4; + smoothedRttMicros = (7 * smoothedRttMicros + adjustedRtt) / 8; + } + } + + /** + * {@return time threshold for time-based loss detection} + * See + * RFC 9002 section 6.1.2 + * + */ + public synchronized Duration getLossThreshold() { + // max(kTimeThreshold * max(smoothed_rtt, latest_rtt), kGranularity) + long maxRttMicros = Math.max(smoothedRttMicros, latestRttMicros); + long lossThresholdMicros = Math.max(9*maxRttMicros / 8, GRANULARITY_MICROS); + return Duration.of(lossThresholdMicros, ChronoUnit.MICROS); + } + + /** + * {@return the amount of time to wait for acknowledgement of a sent packet, + * excluding max ack delay} + * See + * RFC 9002 section 6.1.2 + */ + public synchronized Duration getBasePtoDuration() { + // PTO = smoothed_rtt + max(4*rttvar, kGranularity) + max_ack_delay + // max_ack_delay is applied by the caller + long basePtoMicros = smoothedRttMicros + + Math.max(4 * rttVarMicros, GRANULARITY_MICROS); + return Duration.of(basePtoMicros, ChronoUnit.MICROS); + } + + public synchronized boolean isMinBackoffTimeoutExceeded() { + return MIN_PTO_BACKOFF_TIMEOUT.compareTo(getBasePtoDuration().multipliedBy(ptoBackoffFactor)) < 0; + } + + public synchronized long getPtoBackoff() { + return ptoBackoffFactor; + } + + public synchronized long increasePtoBackoff() { + // limit to make sure we don't accidentally overflow + if (ptoBackoffFactor <= MAX_PTO_BACKOFF || !isMinBackoffTimeoutExceeded()) { + ptoBackoffFactor *= 2; // can go up to 2 * MAX_PTO_BACKOFF + } + return ptoBackoffFactor; + } + + public synchronized void resetPtoBackoff() { + ptoBackoffFactor = 1; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicSelector.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicSelector.java new file mode 100644 index 00000000000..2bce3415b17 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicSelector.java @@ -0,0 +1,536 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.DatagramChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.TimeLine; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.QuicEndpoint.QuicVirtualThreadedEndpoint; +import jdk.internal.net.http.quic.QuicEndpoint.QuicSelectableEndpoint; + + +/** + * A QUIC selector to select over one or several quic transport + * endpoints. + */ +public abstract sealed class QuicSelector implements Runnable, AutoCloseable + permits QuicSelector.QuicNioSelector, QuicSelector.QuicVirtualThreadPoller { + + /** + * The maximum timeout passed to Selector::select. + */ + public static final long IDLE_PERIOD_MS = 1500; + + private static final TimeLine source = TimeSource.source(); + final Logger debug = Utils.getDebugLogger(this::name); + + private final String name; + private volatile boolean done; + private final QuicInstance instance; + private final QuicSelectorThread thread; + private final QuicTimerQueue timerQueue; + + private QuicSelector(QuicInstance instance, String name) { + this.instance = instance; + this.name = name; + this.timerQueue = new QuicTimerQueue(this::wakeup, debug); + this.thread = new QuicSelectorThread(this); + } + + public String name() { + return name; + } + + // must be overridden by subclasses + public void register(T endpoint) throws ClosedChannelException { + if (debug.on()) debug.log("attaching %s", endpoint); + } + + // must be overridden by subclasses + public void wakeup() { + if (debug.on()) debug.log("waking up selector"); + } + + public QuicTimerQueue timer() { + return timerQueue; + } + + /** + * A {@link QuicSelector} implementation based on blocking + * {@linkplain DatagramChannel Datagram Channels} and using a + * Virtual Threads to poll the channels. + * This implementation is tied to {@link QuicVirtualThreadedEndpoint} instances. + */ + static final class QuicVirtualThreadPoller extends QuicSelector { + + static final boolean usePlatformThreads = + Utils.getBooleanProperty("jdk.internal.httpclient.quic.poller.usePlatformThreads", false); + + static final class EndpointTask implements Runnable { + + final QuicVirtualThreadedEndpoint endpoint; + final ConcurrentLinkedQueue endpoints; + EndpointTask(QuicVirtualThreadedEndpoint endpoint, + ConcurrentLinkedQueue endpoints) { + this.endpoint = endpoint; + this.endpoints = endpoints; + } + + public void run() { + try { + endpoint.channelReadLoop(); + } finally { + endpoints.remove(this); + } + } + } + + private final Object waiter = new Object(); + private final ConcurrentLinkedQueue endpoints = new ConcurrentLinkedQueue<>(); + private final ReentrantLock stateLock = new ReentrantLock(); + private final ExecutorService virtualThreadService; + + private volatile long wakeups; + + + private QuicVirtualThreadPoller(QuicInstance instance, String name) { + super(instance, name); + virtualThreadService = usePlatformThreads + ? Executors.newThreadPerTaskExecutor(Thread.ofPlatform() + .name(name + "-pt-worker", 1).factory()) + : Executors.newThreadPerTaskExecutor(Thread.ofVirtual() + .name(name + "-vt-worker-", 1).factory()); + if (debug.on()) debug.log("created"); + } + + ExecutorService readLoopExecutor() { + return virtualThreadService; + } + + @Override + public void register(QuicVirtualThreadedEndpoint endpoint) throws ClosedChannelException { + super.register(endpoint); + endpoint.attach(this); + } + + public Future startReading(QuicVirtualThreadedEndpoint endpoint) { + EndpointTask task; + stateLock.lock(); + try { + if (done()) throw new ClosedSelectorException(); + task = new EndpointTask(endpoint, endpoints); + endpoints.add(task); + return virtualThreadService.submit(task); + } finally { + stateLock.unlock(); + } + } + + void markDone() { + // use stateLock to prevent startReading + // to be called *after* shutdown. + stateLock.lock(); + try { + super.shutdown(); + } finally { + stateLock.unlock(); + } + } + + @Override + public void shutdown() { + markDone(); + try { + virtualThreadService.shutdown(); + } finally { + wakeup(); + } + } + + @Override + public void wakeup() { + super.wakeup(); + synchronized (waiter) { + wakeups++; + // there's only one thread that can be waiting + // on waiter - the thread that executes the run() + // method. + waiter.notify(); + } + } + + @Override + public void run() { + try { + if (debug.on()) debug.log("started"); + long waited = 0; + while (!done()) { + var wakeups = this.wakeups; + long timeout = Math.min(computeNextDeadLine(), IDLE_PERIOD_MS); + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: wait(%s) wakeups:%s (+%s), waited:%s", + name(), timeout, this.wakeups, this.wakeups - wakeups, waited)); + } else if (debug.on()) { + debug.log("wait(%s) wakeups:%s (+%s), waited: %s", + timeout, this.wakeups, this.wakeups - wakeups, waited); + } + long wwaited = -1; + synchronized (waiter) { + if (done()) return; + if (wakeups == this.wakeups) { + var start = System.nanoTime(); + waiter.wait(timeout); + var stop = System.nanoTime(); + wwaited = waited = (stop - start) / 1000_000; + } else waited = 0; + } + if (wwaited != -1 && wwaited < timeout) { + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: waked up early: waited %s, timeout %s", + name(), waited, timeout)); + } + } + } + } catch (Throwable t) { + if (done()) return; + if (debug.on()) debug.log("Selector failed", t); + if (Log.errors()) { + Log.logError("QuicVirtualThreadPoller: selector exiting due to " + t); + Log.logError(t); + } + abort(t); + } finally { + if (debug.on()) debug.log("exiting"); + if (!done()) markDone(); + timer().stop(); + endpoints.removeIf(this::close); + virtualThreadService.close(); + } + } + + boolean close(EndpointTask task) { + try { + task.endpoint.close(); + } catch (Throwable e) { + if (debug.on()) { + debug.log("Failed to close endpoint %s: %s", task.endpoint.name(), e); + } + } + return true; + } + + boolean abort(EndpointTask task, Throwable error) { + try { + task.endpoint.abort(error); + } catch (Throwable e) { + if (debug.on()) { + debug.log("Failed to close endpoint %s: %s", task.endpoint.name(), e); + } + } + return true; + } + + @Override + public void abort(Throwable t) { + super.shutdown(); + endpoints.removeIf(task -> abort(task, t)); + super.abort(t); + } + } + + /** + * A {@link QuicSelector} implementation based on non-blocking + * {@linkplain DatagramChannel Datagram Channels} and using a + * NIO {@link Selector}. + * This implementation is tied to {@link QuicSelectableEndpoint} instances. + */ + static final class QuicNioSelector extends QuicSelector { + final Selector selector; + + private QuicNioSelector(QuicInstance instance, Selector selector, String name) { + super(instance, name); + this.selector = selector; + if (debug.on()) debug.log("created"); + } + + + public void register(QuicSelectableEndpoint endpoint) throws ClosedChannelException { + super.register(endpoint); + endpoint.attach(selector); + selector.wakeup(); + } + + public void wakeup() { + super.wakeup(); + selector.wakeup(); + } + + /** + * Shuts down the {@code QuicSelector} by marking the + * {@linkplain QuicSelector#shutdown() selector done}, + * and {@linkplain Selector#wakeup() waking up the + * selector thread}. + * Upon waking up, the selector thread will invoke + * {@link Selector#close()}. + * This method doesn't wait for the selector thread to terminate. + * @see #awaitTermination(long, TimeUnit) + */ + public void shutdown() { + super.shutdown(); + selector.wakeup(); + } + + @Override + public void run() { + try { + if (debug.on()) debug.log("started"); + while (!done()) { + long timeout = Math.min(computeNextDeadLine(), IDLE_PERIOD_MS); + // selected = 0 indicates that no key had its ready ops changed: + // it doesn't mean that no key is ready. Therefore - if a key + // was ready to read, and is again ready to read, it doesn't + // increment the selected count. + if (debug.on()) debug.log("select(%s)", timeout); + int selected = selector.select(timeout); + var selectedKeys = selector.selectedKeys(); + if (debug.on()) { + debug.log("Selected: changes=%s, keys=%s", selected, selectedKeys.size()); + } + + // We do not synchronize on selectedKeys: selectedKeys is only + // modified in this thread, whether directly, by calling selectedKeys.clear() below, + // or indirectly, by calling selector.close() below. + for (var key : selectedKeys) { + QuicSelectableEndpoint endpoint = (QuicSelectableEndpoint) key.attachment(); + if (debug.on()) { + debug.log("selected(%s): %s", Utils.readyOps(key), endpoint); + } + try { + endpoint.selected(key.readyOps()); + } catch (CancelledKeyException x) { + if (debug.on()) { + debug.log("Key for %s cancelled: %s", endpoint.name(), x); + } + } + } + // need to clear the selected keys. select won't do that. + selectedKeys.clear(); + } + } catch (Throwable t) { + if (done()) return; + if (debug.on()) debug.log("Selector failed", t); + if (Log.errors()) { + Log.logError("QuicNioSelector: selector exiting due to " + t); + Log.logError(t); + } + abort(t); + } finally { + if (debug.on()) debug.log("exiting"); + timer().stop(); + + try { + selector.close(); + } catch (IOException io) { + if (debug.on()) debug.log("failed to close selector: " + io); + } + } + } + + boolean abort(SelectionKey key, Throwable error) { + try { + QuicSelectableEndpoint endpoint = (QuicSelectableEndpoint) key.attachment(); + endpoint.abort(error); + } catch (Throwable e) { + if (debug.on()) { + debug.log("Failed to close endpoint associated with key %s: %s", key, error); + } + } + return true; + } + + @Override + public void abort(Throwable error) { + super.shutdown(); + try { + if (selector.isOpen()) { + for (var k : selector.keys()) { + abort(k, error); + } + } + } catch (ClosedSelectorException cse) { + // ignore + } finally { + super.abort(error); + } + } + } + + public long computeNextDeadLine() { + Deadline now = source.instant(); + Deadline deadline = timerQueue.processEventsAndReturnNextDeadline(now, instance.executor()); + if (deadline.equals(Deadline.MAX)) return IDLE_PERIOD_MS; + if (deadline.equals(Deadline.MIN)) { + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: %s millis until %s", name, 1, "now")); + } + return 1; + } + now = source.instant(); + long millis = now.until(deadline, ChronoUnit.MILLIS); + // millis could be 0 if the next deadline is within 1ms of now. + // in that case, round up millis to 1ms since returning 0 + // means the selector would block indefinitely + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: %s millis until %s", + name, (millis <= 0L ? 1L : millis), deadline)); + } + return millis <= 0L ? 1L : millis; + } + + public void start() { + thread.start(); + } + + /** + * Shuts down the {@code QuicSelector} by invoking {@link Selector#close()}. + * This method doesn't wait for the selector thread to terminate. + * @see #awaitTermination(long, TimeUnit) + */ + public void shutdown() { + if (debug.on()) debug.log("closing"); + done = true; + } + + boolean done() { + return done; + } + + /** + * Awaits termination of the selector thread, up until + * the given timeout has elapsed. + * If the current thread is the selector thread, returns + * immediately without waiting. + * + * @param timeout the maximum time to wait for termination + * @param unit the timeout unit + */ + public void awaitTermination(long timeout, TimeUnit unit) { + if (Thread.currentThread() == thread) { + return; + } + try { + thread.join(unit.toMillis(timeout)); + } catch (InterruptedException ie) { + if (debug.on()) debug.log("awaitTermination interrupted: " + ie); + Thread.currentThread().interrupt(); + } + } + + /** + * Closes this {@code QuicSelector}. + * This method calls {@link #shutdown()} and then {@linkplain + * #awaitTermination(long, TimeUnit) waits for the selector thread + * to terminate}, up to two {@link #IDLE_PERIOD_MS}. + */ + @Override + public void close() { + shutdown(); + awaitTermination(IDLE_PERIOD_MS * 2, TimeUnit.MILLISECONDS); + } + + @Override + public String toString() { + return name; + } + + // Called in case of RejectedExecutionException, or shutdownNow; + public void abort(Throwable t) { + shutdown(); + } + + static class QuicSelectorThread extends Thread { + QuicSelectorThread(QuicSelector selector) { + super(null, selector, + "Thread(%s)".formatted(selector.name()), + 0, false); + this.setDaemon(true); + } + } + + /** + * {@return a new instance of {@code QuicNioSelector}} + *

+ * A {@code QuicNioSelector} is an implementation of {@link QuicSelector} + * based on non blocking {@linkplain DatagramChannel Datagram Channels} and + * using an underlying {@linkplain Selector NIO Selector}. + *

+ * The returned implementation can only be used with + * {@link QuicSelectableEndpoint} endpoints. + * + * @param quicInstance the quic instance + * @param name the selector name + * @throws IOException if an IOException occurs when creating the underlying {@link Selector} + */ + public static QuicSelector createQuicNioSelector(QuicInstance quicInstance, String name) + throws IOException { + Selector selector = Selector.open(); + return new QuicNioSelector(quicInstance, selector, name); + } + + /** + * {@return a new instance of {@code QuicVirtualThreadPoller}} + * A {@code QuicVirtualThreadPoller} is an implementation of + * {@link QuicSelector} based on blocking {@linkplain DatagramChannel + * Datagram Channels} and using {@linkplain Thread#ofVirtual() + * Virtual Threads} to poll the datagram channels. + *

+ * The returned implementation can only be used with + * {@link QuicVirtualThreadedEndpoint} endpoints. + * + * @param quicInstance the quic instance + * @param name the selector name + */ + public static QuicSelector createQuicVirtualThreadPoller(QuicInstance quicInstance, String name) { + return new QuicVirtualThreadPoller(quicInstance, name); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicStreamLimitException.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicStreamLimitException.java new file mode 100644 index 00000000000..e5802fef20c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicStreamLimitException.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +/** + * Used internally to indicate Quic stream limit has been reached + */ +public final class QuicStreamLimitException extends Exception { + + @java.io.Serial + private static final long serialVersionUID = 4181770819022847041L; + + public QuicStreamLimitException(String message) { + super(message); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimedEvent.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimedEvent.java new file mode 100644 index 00000000000..9269b12bf64 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimedEvent.java @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.util.Comparator; + +import jdk.internal.net.http.common.Deadline; + +/** + * Models an event which is triggered upon reaching a + * deadline. {@code QuicTimedEvent} instances are designed to be + * registered with a single {@link QuicTimerQueue}. + * + * @implSpec + * Implementations should make sure that each instance of + * {@code QuicTimedEvent} is only present once in a single + * {@link QuicTimerQueue} at any given time. It is however + * allowed to register the event again with the same {@link QuicTimerQueue} + * after it has been handled, or if it is no longer registered in any + * queue. + */ +sealed interface QuicTimedEvent + permits PacketSpaceManager.PacketTransmissionTask, + QuicTimerQueue.Marker, + QuicEndpoint.ClosedConnection, + IdleTimeoutManager.IdleTimeoutEvent, + IdleTimeoutManager.StreamDataBlockedEvent, + QuicConnectionImpl.MaxInitialTimer { + + /** + * {@return the deadline at which the event should be triggered, + * or {@link Deadline#MAX} if the event does not need + * to be scheduled} + * @implSpec + * Care should be taken to not change the deadline while the + * event is registered with a {@link QuicTimerQueue timer queue}. + * The only safe time when the deadline can be changed is: + *

    + *
  • when {@link #refreshDeadline()} method, since the event + * is not in any queue at that point,
  • + *
  • when the deadline is {@link Deadline#MAX}, since the + * event should not be in any queue if it has no + * deadline
  • + *
+ * + */ + Deadline deadline(); + + /** + * Handles the triggered event. + * Returns a new deadline, if the event needs to be + * rescheduled, or {@code Deadline.MAX} otherwise. + * + * @implSpec + * The {@link #deadline() deadline} should not be + * changed before {@link #refreshDeadline()} is called. + * + * @return a new deadline if the event should be + * rescheduled right away, {@code Deadline.MAX} + * otherwise. + */ + Deadline handle(); + + /** + * An event id, obtained at construction time from + * {@link QuicTimerQueue#newEventId()}. This is used + * to implement a total order among subclasses. + * @return this event's id. + */ + long eventId(); + + /** + * {@return true if this event's deadline is before the + * other's event deadline} + * + * @implSpec + * The default implementation of this method is to return {@code + * deadline().isBefore(other.deadline())}. + * + * @param other the other event + */ + default boolean isBefore(QuicTimedEvent other) { + return deadline().isBefore(other.deadline()); + } + + /** + * Compares this event's deadline with the other event's deadline. + * + * @implSpec + * The default implementation of this method compares deadlines in the same manner as + * {@link Deadline#compareTo(Deadline) this.deadline().compareTo(other.deadline())} would. + * + * @param other the other event + * + * @return {@code -1}, {@code 0}, or {@code 1} depending on whether this + * event's deadline is before, equals to, or after, the other event's + * deadline. + */ + default int compareDeadlines(QuicTimedEvent other) { return deadline().compareTo(other.deadline());} + + /** + * Called to cause an event to refresh its deadline. + * This method is called by the {@link QuicTimerQueue} + * when rescheduling an event. + * @apiNote + * The value returned by {@link #deadline()} can only be safely + * updated when this method is called. + */ + Deadline refreshDeadline(); + + /** + * Compares two instance of {@link QuicTimedEvent}. + * First compared their {@link #deadline()}, then their {@link #eventId()}. + * It is expected that two elements with same deadline and same event id + * must the same {@link QuicTimedEvent} instance. + * + * @param one a first QuicTimedEvent instance + * @param two a second QuicTimedEvent instance + * @return whether the first element is less, same, or greater than the + * second. + */ + static int compare(QuicTimedEvent one, QuicTimedEvent two) { + if (one == two) return 0; + int cmp = one.compareDeadlines(two); + cmp = cmp == 0 ? Long.compare(one.eventId(), two.eventId()) : cmp; + // ensure total ordering; + assert cmp != 0 || one.equals(two) && two.equals(one); + return cmp; + } + + /** + * A comparator that compares {@code QuicTimedEvent} instances by their deadline, in the same + * manner as {@link #compare(QuicTimedEvent, QuicTimedEvent) QuicTimedEvent::compare}. + */ + // public static final (are redundant) + Comparator COMPARATOR = QuicTimedEvent::compare; + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimerQueue.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimerQueue.java new file mode 100644 index 00000000000..830415593cb --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTimerQueue.java @@ -0,0 +1,522 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicLong; + +import jdk.internal.net.http.common.Deadline; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.TimeSource; +import jdk.internal.net.http.common.Utils; + +/** + * A timer queue that can process events which are due, and possibly + * reschedule them if needed. An instance of a {@link QuicTimerQueue} + * is usually associated with an instance of {@link QuicSelector} which + * provides the timer/wakeup facility. + */ +public final class QuicTimerQueue { + + // A queue that contains scheduled events + private final ConcurrentSkipListSet scheduled = + new ConcurrentSkipListSet<>(QuicTimedEvent.COMPARATOR); + + // A queue that contains events which are due. The queue is + // filled by processAndReturnNextDeadline() + private final ConcurrentLinkedQueue due = + new ConcurrentLinkedQueue<>(); + + // A queue that contains events that need to be rescheduled. + // The event may already be in the scheduled queue - in which + // case it will be removed before being added back. + private final Set rescheduled = + ConcurrentHashMap.newKeySet(); + + // A callback to tell the timer thread to wake up + private final Runnable notifier; + + // A loop to process events which are due, or which need to + // be rescheduled. + private final SequentialScheduler processor = + SequentialScheduler.lockingScheduler(this::processDue); + + private final Logger debug; + private volatile boolean closed; + private volatile Deadline scheduledDeadline = Deadline.MAX; + private volatile Deadline returnedDeadline = Deadline.MAX; + + /** + * Creates a new timer queue with the given notifier. + * A notifier is used to notify the timer thread that + * new events have been added to the queue of scheduled + * event. The notifier should wake up the thread and + * trigger a call to either {@link + * #processEventsAndReturnNextDeadline(Deadline, Executor)} + * or {@link #nextDeadline()}. + * + * @param notifier A notifier to wake up the timer thread when + * new event have been added and the next + * deadline has changed. + */ + public QuicTimerQueue(Runnable notifier, Logger debug) { + this.notifier = notifier; + this.debug = debug; + } + + // For debug purposes only + private String d(Deadline deadline) { + return Utils.debugDeadline(debugNow(), deadline); + } + + // For debug purposes only + private String d(Deadline now, Deadline deadline) { + return Utils.debugDeadline(now, deadline); + } + + // For debug purposes only + private Deadline debugNow() { + return TimeSource.now(); + } + + /** + * Schedule the given event by adding it to the timer queue. + * + * @param event an event to be scheduled + */ + public void offer(QuicTimedEvent event) { + if (event instanceof Marker marker) + throw new IllegalArgumentException(marker.name()); + assert QuicTimedEvent.COMPARATOR.compare(event, FLOOR) > 0; + assert QuicTimedEvent.COMPARATOR.compare(event, CEILING) < 0; + Deadline deadline = event.deadline(); + scheduled.add(event); + scheduled(deadline); + if (debug.on()) debug.log("QuicTimerQueue: event %s offered", event); + if (notify(deadline)) { + if (debug.on()) debug.log("QuicTimerQueue: event %s will be rescheduled", event); + if (Log.quicTimer()) { + var now = debugNow(); + Log.logQuic(String.format("%s: QuicTimerQueue: event %s will be scheduled" + + " at %s (returned deadline: %s, nextDeadline: %s)", + Thread.currentThread().getName(), event, d(now, deadline), + d(now, returnedDeadline), d(now, nextDeadline()))); + } + notifier.run(); + } else { + if (Log.quicTimer()) { + var now = debugNow(); + Log.logQuic(String.format("%s: QuicTimerQueue: event %s will not be scheduled" + + " at %s (returned deadline: %s, nextDeadline: %s)", + Thread.currentThread().getName(), event, d(now, deadline), + d(now, returnedDeadline), d(now, nextDeadline()))); + } + } + } + + /** + * The next deadline for this timer queue. This is only weakly + * consistent. If the queue is empty, {@link Deadline#MAX} is + * returned. + * + * @return The next deadline, or {@code Deadline.MAX}. + */ + public Deadline nextDeadline() { + var event = scheduled.ceiling(FLOOR); + return event == null ? Deadline.MAX : event.deadline(); + } + + public Deadline pendingScheduledDeadline() { + return scheduledDeadline; + } + + /** + * Process all events that were due before {@code now}, and + * returns the next deadline. The events are processed within + * an executor's thread, so this method may return before all + * events have been processed. The events are processed in + * order, with respect to their deadline. Processing an event + * involves invoking its {@link QuicTimedEvent#handle() handle} + * method. If that method returns a new deadline different from + * {@link Deadline#MAX} the processed event is rescheduled + * immediately. Otherwise, it will not be rescheduled. + * + * @param now The point in time before which events are + * considered to be due. Usually, that's now. + * @param executor An executor to process events which are due. + * + * @return the next unexpired deadline, or {@link Deadline#MAX} + * if the queue is empty. + */ + public Deadline processEventsAndReturnNextDeadline(Deadline now, Executor executor) { + QuicTimedEvent event; + int drained = 0; + int dues; + synchronized (this) { + scheduledDeadline = Deadline.MAX; + } + // moved scheduled / rescheduled tasks to due, until + // nothing else is due. Then process dues. + do { + dues = processRescheduled(now); + dues = dues + processScheduled(now); + drained += dues; + } while (dues > 0); + Deadline newDeadline = (event = scheduled.ceiling(FLOOR)) == null ? Deadline.MAX : event.deadline(); + assert event == null || newDeadline.isBefore(Deadline.MAX) : "Invalid deadline for " + event; + if (debug.on()) { + debug.log("QuicTimerQueue: newDeadline: " + d(now, newDeadline) + + (event == null ? "no event scheduled" : (" for " + event))); + } + Deadline next; + synchronized (this) { + var scheduled = scheduledDeadline; + scheduledDeadline = Deadline.MAX; + // if some task is being rescheduled with a deadline + // that is before any scheduled deadline, use that deadline. + next = returnedDeadline = min(newDeadline, scheduled); + } + if (next.equals(Deadline.MAX)) { + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: TimerQueue: no deadline" + + " (scheduled: %s, rescheduled: %s, dues %s)", + Thread.currentThread().getName(), this.scheduled.size(), + this.rescheduled.size(), this.due.size())); + } + } + if (drained > 0) { + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: TimerQueue: %s events to handle (%s in dues)", + Thread.currentThread().getName(), drained, this.due.size())); + } + processor.runOrSchedule(executor); + } + return next; + } + + // return the deadline which is before the other + private Deadline min(Deadline one, Deadline two) { + return one.isBefore(two) ? one : two; + } + + // walk through the rescheduled tasks and moves any + // that are due to `due`. Otherwise, move them to + // `scheduled` + private int processRescheduled(Deadline now) { + int drained = 0; + for (var it = rescheduled.iterator(); it.hasNext(); ) { + QuicTimedEvent event = it.next(); + it.remove(); // remove before processing to avoid race + scheduled.remove(event); + Deadline deadline = event.refreshDeadline(); + if (deadline.equals(Deadline.MAX)) { + continue; + } + if (deadline.isAfter(now)) { + scheduled.add(event); + } else { + due.add(event); + drained++; + } + } + if (drained > 0) { + if (debug.on()) { + debug.log("QuicTimerQueue: %s rescheduled tasks are due", drained); + } + } + return drained; + } + + // walk through the scheduled tasks and moves any + // that are due to `due`. + private int processScheduled(Deadline now) { + QuicTimedEvent event; + int drained = 0; + while ((event = scheduled.ceiling(FLOOR)) != null) { + Deadline deadline = event.deadline(); + if (!isDue(deadline, now)) { + break; + } + event = scheduled.pollFirst(); + if (event == null) { + break; + } + drained++; + due.add(event); + } + if (drained > 0 && debug.on()) { + debug.log("QuicTimerQueue: %s scheduled tasks are due", drained); + } + return drained; + } + + private static boolean isDue(final Deadline deadline, final Deadline now) { + return deadline.compareTo(now) <= 0; + } + + // process all due events in order + private void processDue() { + try { + QuicTimedEvent event; + if (closed) return; + if (debug.on()) debug.log("QuicTimerQueue: processDue"); + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: TimerQueue: process %s events", + Thread.currentThread().getName(), due.size())); + } + Deadline minDeadLine = Deadline.MAX; + while ((event = due.poll()) != null) { + if (closed) return; + Deadline nextDeadline = event.handle(); + if (Deadline.MAX.equals(nextDeadline)) continue; + rescheduled.add(event); + if (nextDeadline.isBefore(minDeadLine)) minDeadLine = nextDeadline; + } + + // record the minimal deadline that was rescheduled + scheduled(minDeadLine); + + // wake up the selector thread if necessary + if (notify(minDeadLine)) { + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: TimerQueue: notify: minDeadline: %s", + Thread.currentThread().getName(), d(minDeadLine))); + } + notifier.run(); + } else if (!minDeadLine.equals(Deadline.MAX)) { + if (Log.quicTimer()) { + Log.logQuic(String.format("%s: TimerQueue: no need to notify: minDeadline: %s", + Thread.currentThread().getName(), d(minDeadLine))); + } + } + + } catch (Throwable t) { + if (!closed) { + if (Log.errors()) { + Log.logError(Thread.currentThread().getName() + + ": Unexpected exception while processing due events: " + t); + Log.logError(t); + } else if (debug.on()) { + debug.log("Unexpected exception while processing due events", t); + } + throw t; + } else { + if (Log.errors()) { + Log.logError(Thread.currentThread().getName() + + ": Ignoring exception while closing: " + t); + Log.logError(t); + } else if (debug.on()) { + debug.log("Ignoring exception while closing: " + t); + } + } + } + } + + // We do not need to notify the selector thread if the next scheduled + // deadline is before the given deadline, or if it is after + // the last returned deadline. + private boolean notify(Deadline deadline) { + synchronized (this) { + if (deadline.isBefore(nextDeadline()) + || deadline.isBefore(returnedDeadline)) { + return true; + } + } + return false; + } + + // Record a prospective attempt to reschedule an event at + // the given deadline + private Deadline scheduled(Deadline deadline) { + synchronized (this) { + var scheduled = scheduledDeadline; + if (deadline.isBefore(scheduled)) { + scheduledDeadline = deadline; + return deadline; + } + return scheduled; + } + } + + /** + * Reschedule the given {@code QuicTimedEvent}. + * + * @apiNote + * This method is used if the prospective future deadline at which the event + * should be scheduled is not known by the caller. + * This may cause an idle wakeup in the selector thread owning this + * {@code QuicTimerQueue}. Use {@link #reschedule(QuicTimedEvent, Deadline)} + * to minimize idle wakeup. + * + * @param event an event to reschedule + */ + public void reschedule(QuicTimedEvent event) { + if (event instanceof Marker marker) + throw new IllegalArgumentException(marker.name()); + assert QuicTimedEvent.COMPARATOR.compare(event, FLOOR) > 0; + assert QuicTimedEvent.COMPARATOR.compare(event, CEILING) < 0; + rescheduled.add(event); + if (debug.on()) debug.log("QuicTimerQueue: event %s will be rescheduled", event); + if (Log.quicTimer()) { + var now = debugNow(); + Log.logQuic(String.format("%s: QuicTimerQueue: event %s will be rescheduled" + + " (returned deadline: %s, nextDeadline: %s)", + Thread.currentThread().getName(), event, d(now, returnedDeadline), + d(now, nextDeadline()))); + } + notifier.run(); + } + + /** + * Reschedule the given {@code QuicTimedEvent}. + * + * @apiNote + * This method should be used in preference of {@link #reschedule(QuicTimedEvent)} + * if the prospective future deadline at which the event should be scheduled is + * already known by the caller. Using this method will minimize idle wakeup + * of the selector thread, in comparison of {@link #reschedule(QuicTimedEvent)}. + * + * @param event an event to reschedule + * @param deadline the prospective future deadline at which the event should + * be rescheduled + */ + public void reschedule(QuicTimedEvent event, Deadline deadline) { + if (event instanceof Marker marker) + throw new IllegalArgumentException(marker.name()); + assert QuicTimedEvent.COMPARATOR.compare(event, FLOOR) > 0; + assert QuicTimedEvent.COMPARATOR.compare(event, CEILING) < 0; + rescheduled.add(event); + scheduled(deadline); + // no need to wake up the selector thread if the next deadline + // is already before the new deadline + + if (notify(deadline)) { + if (Log.quicTimer()) { + var now = debugNow(); + Log.logQuic(String.format("%s: QuicTimerQueue: event %s will be rescheduled" + + " at %s (returned deadline: %s, nextDeadline: %s)", + Thread.currentThread().getName(), event, d(now, deadline), + d(now, returnedDeadline), d(now, nextDeadline()))); + } else if (debug.on()) { + debug.log("QuicTimerQueue: event %s will be rescheduled", event); + } + notifier.run(); + } else { + if (Log.quicTimer()) { + var now = debugNow(); + Log.logQuic(String.format("%s: QuicTimerQueue: event %s will not be rescheduled" + + " at %s (returned deadline: %s, nextDeadline: %s)", + Thread.currentThread().getName(), event, d(now, deadline), + d(now, returnedDeadline), d(now, nextDeadline()))); + } + } + } + + private static final AtomicLong EVENTIDS = new AtomicLong(); + + /** + * {@return a unique id for a new {@link QuicTimedEvent}} + * Each new instance of {@link QuicTimedEvent} is created with a long + * ID returned by this method to ensure a total ordering of + * {@code QuicTimedEvent} instances, even when their deadlines + * are equal. + */ + public static long newEventId() { + return EVENTIDS.getAndIncrement(); + } + + // aliases + private static final Marker FLOOR = Marker.FLOOR; + private static final Marker CEILING = Marker.CEILING; + + /** + * Called to clean up the timer queue when it is no longer needed. + * Makes sure that all pending tasks are cleared from the various lists. + */ + public void stop() { + closed = true; + do { + processor.stop(); + due.clear(); + rescheduled.clear(); + scheduled.clear(); + } while (!due.isEmpty() || !rescheduled.isEmpty() || !scheduled.isEmpty()); + } + + // This class is used to work around the lack of a peek() method + // in ConcurrentSkipListSet. ConcurrentSkipListSet has a method + // called first(), but it throws NoSuchElementException if the + // set isEmpty() - whereas peek() would return {@code null}. + // The next best thing is to use ConcurrentSkipListSet::ceiling, + // but for that we need to define a minimum event which is lower + // than any other event: we do this by defining Marker.FLOOR + // which has deadline=Deadline.MIN and eventId=Long.MIN_VALUE; + // Note: it would be easier to use a record, but an enum ensures that we + // can only have the two instances FLOOR and CEILING. + enum Marker implements QuicTimedEvent { + /** + * A {@code Marker} event to pass to {@link ConcurrentSkipListSet#ceiling(Object) + * ConcurrentSkipListSet::ceiling} in order to get the first event in the list, + * or {@code null}. + * + * @apiNote + * The intended usage is:
{@code
+         *       var head = scheduled.ceiling(FLOOR);
+         * }
+ * + */ + FLOOR(Deadline.MIN, Long.MIN_VALUE), + /** + * A {@code Marker} event to pass to {@link ConcurrentSkipListSet#floor(Object) + * ConcurrentSkipListSet::floor} in order to get the last event in the list, + * or {@code null}. + * + * @apiNote + * The intended usage is:
{@code
+         *       var head = scheduled.floor(CEILING);
+         * }
+ * + */ + CEILING(Deadline.MAX, Long.MAX_VALUE); + private final Deadline deadline; + private final long eventId; + private Marker(Deadline deadline, long eventId) { + this.deadline = deadline; + this.eventId = eventId; + } + + @Override public Deadline deadline() { return deadline; } + @Override public Deadline refreshDeadline() {return Deadline.MAX;} + @Override public Deadline handle() { return Deadline.MAX; } + @Override public long eventId() { return eventId; } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTransportParameters.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTransportParameters.java new file mode 100644 index 00000000000..36832575add --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/QuicTransportParameters.java @@ -0,0 +1,1319 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.EnumMap; +import java.util.HexFormat; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicTransportParametersConsumer; +import jdk.internal.net.quic.QuicVersion; + +/** + * This class models a collection of Quic transport parameters. This class is mutable + * and not thread safe. + * + * A parameter is considered absent if {@link #getParameter(TransportParameterId)} + * yields {@code null}. The parameter is present otherwise. + * Parameters can be removed by calling {@link + * #setParameter(TransportParameterId, byte[]) setParameter(id, null)}. + * The methods {@link #getBooleanParameter(TransportParameterId)} and + * {@link #getIntParameter(TransportParameterId)} allow easy access to + * parameters whose type is boolean or int, respectively. + * When such a parameter is absent, its default value is returned by + * those methods. + + * From + * RFC 9000, section 18.2: + * + *
+ *
{@code
+ * Many transport parameters listed here have integer values.
+ * Those transport parameters that are identified as integers use a
+ * variable-length integer encoding; see Section 16. Transport parameters
+ * have a default value of 0 if the transport parameter is absent, unless
+ * otherwise stated.
+ * }
+ * + *

[...] + * + *

{@code
+ * If present, transport parameters that set initial per-stream flow control limits
+ * (initial_max_stream_data_bidi_local, initial_max_stream_data_bidi_remote, and
+ * initial_max_stream_data_uni) are equivalent to sending a MAX_STREAM_DATA frame
+ * (Section 19.10) on every stream of the corresponding type immediately after opening.
+ * If the transport parameter is absent, streams of that type start with a flow control
+ * limit of 0.
+ *
+ * A client MUST NOT include any server-only transport parameter:
+ *        original_destination_connection_id,
+ *        preferred_address,
+ *        retry_source_connection_id, or
+ *        stateless_reset_token.
+ *
+ * A server MUST treat receipt of any of these transport parameters as a connection error
+ * of type TRANSPORT_PARAMETER_ERROR.
+ * }
+ *
+ * + * @see ParameterId + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class QuicTransportParameters { + + /** + * An interface to model a transport parameter ID. + * A transport parameter ID has a {@linkplain #name() name} (which is + * not transmitted) and an {@linkplain #idx() identifier}. + * Standard parameters are modeled by enum values in + * {@link ParameterId}. + */ + public sealed interface TransportParameterId { + /** + * {@return the transport parameter name} + * This a human-readable string. + */ + String name(); + + /** + * {@return the transport parameter identifier} + */ + int idx(); + + /** + * {@return the parameter id corresponding to the given identifier, if + * defined, an empty optional otherwise} + * @param idx a parameter identifier + */ + static Optional valueOf(long idx) { + return ParameterId.valueOf(idx); + } + } + + /** + * Standard Quic transport parameter names and ids. + * These are the transport parameters defined in IANA + * "QUIC Transport Parameters" registry. + * @see + * RFC 9000, Section 22.3 + */ + public enum ParameterId implements TransportParameterId { + /** + * original_destination_connection_id (0x00). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     This parameter is the value of the Destination Connection ID field
+         *     from the first Initial packet sent by the client; see Section 7.3.
+         *     This transport parameter is only sent by a server.
+         * }
+ * @see RFC 9000, Section 7.3 + */ + original_destination_connection_id(0x00), + + /** + * max_idle_timeout (0x01). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The maximum idle timeout is a value in milliseconds that is encoded
+         *     as an integer; see (Section 10.1).
+         *     Idle timeout is disabled when both endpoints omit this transport
+         *     parameter or specify a value of 0.
+         * }
+ * @see RFC 9000, Section 10.1 + */ + max_idle_timeout(0x01), + + /** + * stateless_reset_token (0x02). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     A stateless reset token is used in verifying a stateless reset;
+         *     see Section 10.3.
+         *     This parameter is a sequence of 16 bytes. This transport parameter MUST NOT
+         *     be sent by a client but MAY be sent by a server. A server that does not send
+         *     this transport parameter cannot use stateless reset (Section 10.3) for
+         *     the connection ID negotiated during the handshake.
+         * }
+ * @see RFC 9000, Section 10.3 + */ + stateless_reset_token(0x02), + + /** + * max_udp_payload_size (0x03). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The maximum UDP payload size parameter is an integer value that limits the
+         *     size of UDP payloads that the endpoint is willing to receive. UDP datagrams
+         *     with payloads larger than this limit are not likely to be processed by
+         *     the receiver.
+         *
+         *     The default for this parameter is the maximum permitted UDP payload of 65527.
+         *     Values below 1200 are invalid.
+         *
+         *     This limit does act as an additional constraint on datagram size
+         *     in the same way as the path MTU, but it is a property of the endpoint
+         *     and not the path; see Section 14.
+         *     It is expected that this is the space an endpoint dedicates to
+         *     holding incoming packets.
+         * }
+ * @see RFC 9000, Section 14 + */ + max_udp_payload_size(0x03), + + /** + * initial_max_data (0x04). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The initial maximum data parameter is an integer value that contains
+         *     the initial value for the maximum amount of data that can be sent on
+         *     the connection. This is equivalent to sending a MAX_DATA (Section 19.9)
+         *     for the connection immediately after completing the handshake.
+         * }
+ * @see RFC 9000, Section 19.9 + */ + initial_max_data(0x04), + + /** + * initial_max_stream_data_bidi_local (0x05). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     This parameter is an integer value specifying the initial flow control
+         *     limit for locally initiated bidirectional streams. This limit applies to
+         *     newly created bidirectional streams opened by the endpoint that
+         *     sends the transport parameter.
+         *     In client transport parameters, this applies to streams with an identifier
+         *     with the least significant two bits set to 0x00;
+         *     in server transport parameters, this applies to streams with the least
+         *     significant two bits set to 0x01.
+         * }
+ */ + initial_max_stream_data_bidi_local(0x05), + + /** + * initial_max_stream_data_bidi_remote (0x06). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     This parameter is an integer value specifying the initial flow control
+         *     limit for peer-initiated bidirectional streams. This limit applies to
+         *     newly created bidirectional streams opened by the endpoint that receives
+         *     the transport parameter. In client transport parameters, this applies to
+         *     streams with an identifier with the least significant two bits set to 0x01;
+         *     in server transport parameters, this applies to streams with the least
+         *     significant two bits set to 0x00.
+         * }
+ */ + initial_max_stream_data_bidi_remote(0x06), + + /** + * initial_max_stream_data_uni (0x07). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     This parameter is an integer value specifying the initial flow control
+         *     limit for unidirectional streams. This limit applies to newly created
+         *     unidirectional streams opened by the endpoint that receives the transport
+         *     parameter. In client transport parameters, this applies to streams with
+         *     an identifier with the least significant two bits set to 0x03; in server
+         *     transport parameters, this applies to streams with the least significant
+         *     two bits set to 0x02.
+         * }
+ */ + initial_max_stream_data_uni(0x07), + + /** + * initial_max_streams_bidi (0x08). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The initial maximum bidirectional streams parameter is an integer value
+         *     that contains the initial maximum number of bidirectional streams the
+         *     endpoint that receives this transport parameter is permitted to initiate.
+         *     If this parameter is absent or zero, the peer cannot open bidirectional
+         *     streams until a MAX_STREAMS frame is sent. Setting this parameter is equivalent
+         *     to sending a MAX_STREAMS (Section 19.11) of the corresponding type with the
+         *     same value.
+         * }
+ * @see RFC 9000, Section 19.11 + */ + initial_max_streams_bidi(0x08), + + /** + * initial_max_streams_uni (0x09). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The initial maximum unidirectional streams parameter is an integer value that
+         *     contains the initial maximum number of unidirectional streams the endpoint
+         *     that receives this transport parameter is permitted to initiate. If this parameter
+         *     is absent or zero, the peer cannot open unidirectional streams until a MAX_STREAMS
+         *     frame is sent. Setting this parameter is equivalent to sending a MAX_STREAMS
+         *     (Section 19.11) of the corresponding type with the same value.
+         * }
+ * @see RFC 9000, Section 19.11 + */ + initial_max_streams_uni(0x09), + + /** + * ack_delay_exponent (0x0a). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The acknowledgment delay exponent is an integer value indicating an exponent
+         *     used to decode the ACK Delay field in the ACK frame (Section 19.3). If this
+         *     value is absent, a default value of 3 is assumed (indicating a multiplier of 8).
+         *     Values above 20 are invalid.
+         * }
+ * @see RFC 9000, Section 19.3 + */ + ack_delay_exponent(0x0a), + + /** + * max_ack_delay (0x0b). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The maximum acknowledgment delay is an integer value indicating the maximum
+         *     amount of time in milliseconds by which the endpoint will delay sending acknowledgments.
+         *     This value SHOULD include the receiver's expected delays in alarms firing. For example,
+         *     if a receiver sets a timer for 5ms and alarms commonly fire up to 1ms late, then it
+         *     should send a max_ack_delay of 6ms. If this value is absent, a default of 25
+         *     milliseconds is assumed. Values of 2^14 or greater are invalid.
+         * }
+ */ + max_ack_delay(0x0b), + + /** + * disable_active_migration (0x0c). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The disable active migration transport parameter is included if the endpoint does not
+         *     support active connection migration (Section 9) on the address being used during the
+         *     handshake. An endpoint that receives this transport parameter MUST NOT use a new local
+         *     address when sending to the address that the peer used during the handshake. This transport
+         *     parameter does not prohibit connection migration after a client has acted on a
+         *     preferred_address transport parameter. This parameter is a zero-length value.
+         * }
+ * @see RFC 9000, Section 9 + */ + disable_active_migration(0x0c), + + /** + * preferred_address (0x0d). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     The server's preferred address is used to effect a change in server address at the
+         *     end of the handshake, as described in Section 9.6. This transport parameter is only
+         *     sent by a server.
+         *     Servers MAY choose to only send a preferred address of one address family
+         *     by sending an all-zero address and port (0.0.0.0:0 or [::]:0) for the
+         *     other family. IP addresses are encoded in network byte order.
+         *
+         *     The preferred_address transport parameter contains an address and port for both
+         *     IPv4 and IPv6. The four-byte IPv4 Address field is followed by the associated
+         *     two-byte IPv4 Port field. This is followed by a 16-byte IPv6 Address field and
+         *     two-byte IPv6 Port field. After address and port pairs, a Connection ID Length
+         *     field describes the length of the following Connection ID field.
+         *     Finally, a 16-byte Stateless Reset Token field includes the stateless reset
+         *     token associated with the connection ID. The format of this transport parameter
+         *     is shown in Figure 22 below.
+         *
+         *     The Connection ID field and the Stateless Reset Token field contain an alternative
+         *     connection ID that has a sequence number of 1; see Section 5.1.1. Having these values
+         *     sent alongside the preferred address ensures that there will be at least one
+         *     unused active connection ID when the client initiates migration to the preferred
+         *     address.
+         *
+         *     The Connection ID and Stateless Reset Token fields of a preferred address are
+         *     identical in syntax and semantics to the corresponding fields of a NEW_CONNECTION_ID
+         *     frame (Section 19.15). A server that chooses a zero-length connection ID MUST NOT
+         *     provide a preferred address. Similarly, a server MUST NOT include a zero-length
+         *     connection ID in this transport parameter. A client MUST treat a violation of
+         *     these requirements as a connection error of type TRANSPORT_PARAMETER_ERROR.
+         *
+         * Preferred Address {
+         *   IPv4 Address (32),
+         *   IPv4 Port (16),
+         *   IPv6 Address (128),
+         *   IPv6 Port (16),
+         *   Connection ID Length (8),
+         *   Connection ID (..),
+         *   Stateless Reset Token (128),
+         * }
+         *
+         * Figure 22: Preferred Address Format
+         * }
+ * @see RFC 9000, Section 5.1.1 + * @see RFC 9000, Section 9.6 + * @see RFC 9000, Section 19.15 + */ + preferred_address(0x0d), + + /** + * active_connection_id_limit (0x0e). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     This is an integer value specifying the maximum number of connection IDs from
+         *     the peer that an endpoint is willing to store. This value includes the connection
+         *     ID received during the handshake, that received in the preferred_address transport
+         *     parameter, and those received in NEW_CONNECTION_ID frames. The value of the
+         *     active_connection_id_limit parameter MUST be at least 2. An endpoint that receives
+         *     a value less than 2 MUST close the connection with an error of type
+         *     TRANSPORT_PARAMETER_ERROR. If this transport parameter is absent, a default of 2 is
+         *     assumed. If an endpoint issues a zero-length connection ID, it will never send a
+         *     NEW_CONNECTION_ID frame and therefore ignores the active_connection_id_limit value
+         *     received from its peer.
+         * }
+ */ + active_connection_id_limit(0x0e), + + /** + * initial_source_connection_id (0x0f). + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+         *     This is the value that the endpoint included in the Source Connection ID field of
+         *     the first Initial packet it sends for the connection; see Section 7.3.
+         * }
+ * @see RFC 9000, Section 7.3 + */ + initial_source_connection_id(0x0f), + + /** + * retry_source_connection_id (0x10). + *

+ * From + * RFC 9000, Section 18.2 + *

{@code
+         *     This is the value that the server included in the Source Connection ID field of a
+         *     Retry packet; see Section 7.3. This transport parameter is only sent by a server.
+         * }
+ * @see RFC 9000, Section 7.3 + */ + retry_source_connection_id(0x10), + + /** + * version_information (0x11). + *

+ * From + * RFC 9368, Section 3 + *

{@code
+         *     During the handshake, endpoints will exchange Version Information,
+         *     which consists of a Chosen Version and a list of Available Versions.
+         *     Any version of QUIC that supports this mechanism MUST provide a mechanism
+         *     to exchange Version Information in both directions during the handshake,
+         *     such that this data is authenticated.
+         * }
+ */ + version_information(0x11); + + /* + * Reserved Transport Parameters (31 * N + 27 for int values of N) + *

+ * From + * RFC 9000, Section 18.1 + *

{@code
+         *     Transport parameters with an identifier of the form 31 * N + 27
+         *     for integer values of N are reserved to exercise the requirement
+         *     that unknown transport parameters be ignored. These transport
+         *     parameters have no semantics and can carry arbitrary values.
+         * }
+ */ + // No values are defined here, but these will be + // ignored if received (see + // sun.security.ssl.QuicTransportParametersExtension). + + /** + * The number of known transport parameters. + * This is also the number of enum values defined by the + * {@link ParameterId} enumeration. + */ + private static final int PARAMETERS_COUNT = ParameterId.values().length; + + ParameterId(int idx) { + // idx() and valueOf() assume that idx = ordinal; + // if that's no longer the case, update the implementation + // and remove this assert. + assert idx == ordinal(); + } + + @Override + public int idx() { + return ordinal(); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + private static Optional valueOf(long idx) { + if (idx < 0 || idx >= PARAMETERS_COUNT) return Optional.empty(); + return Optional.of(values()[(int)idx]); + } + } + + public record VersionInformation(int chosenVersion, int[] availableVersions) { } + + /** + * A map to store transport parameter values. + * Contains a byte array corresponding to the encoded value + * of the parameter. + */ + private final Map values; + + /** + * Constructs a new empty array of Quic transport parameters. + */ + public QuicTransportParameters() { + values = new EnumMap<>(ParameterId.class); + } + + /** + * Constructs a new collection of Quic transport parameters initialized + * from the specified collection. + * @param params the parameter collection used to initialize this object * + */ + public QuicTransportParameters(QuicTransportParameters params) { + values = new EnumMap<>(params.values); + } + + /** + * {@return true if the given parameter is present, false otherwise} + * @apiNote + * This is equivalent to {@link #getParameter(TransportParameterId) + * getParameter(id) != null}, but avoids cloning the parameter value. + * @param id the parameter id + */ + public boolean isPresent(TransportParameterId id) { + byte[] value = values.get((ParameterId) id); + return value != null; + } + + /** + * {@return the value of the given parameter, as a byte array, or + * {@code null} if the parameter is absent. + * @param id the parameter id + */ + public byte[] getParameter(TransportParameterId id) { + byte[] value = values.get((ParameterId) id); + return value == null ? null : value.clone(); + } + + /** + * {@return true if the value of the given parameter matches the given connection ID} + * @param id the transport parameter id + * @param connectionId the connection id to match against + */ + public boolean matches(TransportParameterId id, QuicConnectionId connectionId) { + byte[] value = values.get((ParameterId) id); + return connectionId.matches(ByteBuffer.wrap(value).asReadOnlyBuffer()); + } + + /** + * Sets the value of the given parameter. + * If the given value is {@code null}, the parameter is removed. + * @param id the parameter id + * @param value the new parameter value, or {@code null}. + * @throws IllegalArgumentException if the given value is invalid for + * the given parameter id + */ + public void setParameter(TransportParameterId id, byte[] value) { + ParameterId pid = checkParameterValue(id, value); + if (value != null) { + values.put(pid, value.clone()); + } else { + values.remove(pid); + } + } + + /** + * {@return the value of the given parameter, as an unsigned int + * in the range {@code [0, 2^62 - 1]}} + * If the parameter is not present its default value (as specified in the RFC) is returned. + * @param id the parameter id + * @throws IllegalArgumentException if the value of the given parameter + * cannot be decoded as a variable length unsigned int + */ + public long getIntParameter(TransportParameterId id) { + return getIntParameter((ParameterId)id); + } + + private long getIntParameter(final ParameterId pid) { + return switch (pid) { + case max_idle_timeout, max_udp_payload_size, initial_max_data, + initial_max_stream_data_bidi_local, initial_max_stream_data_bidi_remote, + initial_max_stream_data_uni, initial_max_streams_bidi, + initial_max_streams_uni, ack_delay_exponent, max_ack_delay, + active_connection_id_limit -> { + byte[] value = values.get(pid); + final long res; + if (value == null) { + res = switch (pid) { + case active_connection_id_limit -> 2; + case max_udp_payload_size -> 65527; + case ack_delay_exponent -> 3; + case max_ack_delay -> 25; + default -> 0; + }; + } else { + res = decodeVLIntFully(pid, ByteBuffer.wrap(value)); + } + yield res; + } + default -> throw new IllegalArgumentException(String.valueOf(pid)); + + }; + } + + /** + * {@return the value of the given parameter, as an unsigned int + * in the range {@code [0, 2^62 - 1]}} + * If the parameter is not present then {@code defaultValue} is returned. + * @param id the parameter id + * @throws IllegalArgumentException if the value of the given parameter + * cannot be decoded as a variable length unsigned int or if the {@code defaultValue} + * exceeds the maximum allowed value for variable length integer + */ + public long getIntParameter(TransportParameterId id, long defaultValue) { + if (defaultValue > VariableLengthEncoder.MAX_ENCODED_INTEGER) { + throw new IllegalArgumentException("default value " + defaultValue + + " exceeds maximum allowed variable length" + + " integer value " + VariableLengthEncoder.MAX_ENCODED_INTEGER); + } + ParameterId pid = (ParameterId)id; + return switch (pid) { + case max_idle_timeout, max_udp_payload_size, initial_max_data, + initial_max_stream_data_bidi_local, initial_max_stream_data_bidi_remote, + initial_max_stream_data_uni, initial_max_streams_bidi, + initial_max_streams_uni, ack_delay_exponent, max_ack_delay, + active_connection_id_limit -> { + byte[] value = values.get(pid); + final long res; + if (value == null) { + res = defaultValue; + } else { + res = decodeVLIntFully(pid, ByteBuffer.wrap(value)); + } + yield res; + } + default -> throw new IllegalArgumentException(String.valueOf(pid)); + }; + } + + /** + * Sets the value of the given parameter, as an unsigned int. + * If a negative value is provided, the parameter is removed. + * + * @param id the parameter id + * @param value the new value of the parameter, or a negative value + * + * @throws IllegalArgumentException if the value of the given parameter is + * not an int, or if the provided value is out of range + */ + public void setIntParameter(TransportParameterId id, long value) { + ParameterId pid = (ParameterId)id; + switch (pid) { + case max_idle_timeout, max_udp_payload_size, initial_max_data, + initial_max_stream_data_bidi_local, initial_max_stream_data_bidi_remote, + initial_max_stream_data_uni, initial_max_streams_bidi, + initial_max_streams_uni, ack_delay_exponent, max_ack_delay, + active_connection_id_limit -> { + byte[] v = null; + if (value >= 0) { + int length = VariableLengthEncoder.getEncodedSize(value); + if (length <= 0) throw new IllegalArgumentException("failed to encode " + value); + int size = VariableLengthEncoder.encode(ByteBuffer.wrap(v = new byte[length]), value); + assert size == length; + checkParameterValue(pid, v); + } + setOrRemove(pid, v); + } + default -> throw new IllegalArgumentException(String.valueOf(pid)); + } + } + + /** + * {@return the value of the given parameter, as a boolean} + * If the parameter is not present its default value (false) + * is returned. + * + * @param id the parameter id + * + * @throws IllegalArgumentException if the value of the given parameter + * is not a boolean + */ + public boolean getBooleanParameter(TransportParameterId id) { + ParameterId pid = (ParameterId)id; + if (pid != ParameterId.disable_active_migration) { + throw new IllegalArgumentException(String.valueOf(id)); + } + return values.get(pid) != null; + } + + /** + * Sets the value of the given parameter, as a boolean. + * @apiNote + * It is not possible to distinguish between a boolean parameter + * whose value is absent and a parameter whose value is false. + * Both are represented by a {@code null} value in the parameter + * array. + * @param id the parameter id + * @param value the new value of the parameter + * @throws IllegalArgumentException if the value of the given parameter is + * not a boolean + */ + public void setBooleanParameter(TransportParameterId id, boolean value) { + ParameterId pid = (ParameterId)id; + if (pid != ParameterId.disable_active_migration) { + throw new IllegalArgumentException(String.valueOf(id)); + } + setOrRemove(pid, value ? NOBYTES : null); + } + + private void setOrRemove(ParameterId pid, byte[] value) { + if (value != null) { + values.put(pid, value); + } else { + values.remove(pid); + } + } + + /** + * {@return the value of the given parameter, as {@link VersionInformation}} + * If the parameter is not present {@code null} is returned + * + * @param id the parameter id + * + * @throws IllegalArgumentException if the value of the given parameter + * is not a version information + * @throws QuicTransportException if the parameter value has incorrect length, + * or if any version is equal to zero + */ + public VersionInformation getVersionInformationParameter(TransportParameterId id) + throws QuicTransportException { + ParameterId pid = (ParameterId)id; + if (pid != ParameterId.version_information) { + throw new IllegalArgumentException(String.valueOf(id)); + } + byte[] val = values.get(pid); + if (val == null) { + return null; + } + if (val.length < 4 || (val.length & 3) != 0) { + throw new QuicTransportException( + "Invalid version information length " + val.length, + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + ByteBuffer bbval = ByteBuffer.wrap(val); + assert bbval.order() == ByteOrder.BIG_ENDIAN; + int chosen = bbval.getInt(); + if (chosen == 0) { + throw new QuicTransportException( + "[version_information] Chosen Version = 0", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + int[] available = new int[bbval.remaining() / 4]; + for (int i = 0; i < available.length; i++) { + int version = bbval.getInt(); + if (version == 0) { + throw new QuicTransportException( + "[version_information] Available Version = 0", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + available[i] = version; + } + return new VersionInformation(chosen, available); + } + + /** + * Sets the value of the given parameter, as {@link VersionInformation}. + * @param id the parameter id + * @param value the new value of the parameter + * @throws IllegalArgumentException if the value of the given parameter is + * not a version information + */ + public void setVersionInformationParameter(TransportParameterId id, VersionInformation value) { + ParameterId pid = (ParameterId)id; + if (pid != ParameterId.version_information) { + throw new IllegalArgumentException(String.valueOf(id)); + } + byte[] val = new byte[value.availableVersions.length * 4 + 4]; + ByteBuffer bbval = ByteBuffer.wrap(val); + assert bbval.order() == ByteOrder.BIG_ENDIAN; + bbval.putInt(value.chosenVersion); + for (int available : value.availableVersions) { + bbval.putInt(available); + } + assert !bbval.hasRemaining(); + values.put(pid, val); + } + + /** + * {@return a {@link VersionInformation} object corresponding to the specified versions} + * @param chosenVersion chosen version + * @param availableVersions available versions + */ + public static VersionInformation buildVersionInformation( + QuicVersion chosenVersion, List availableVersions) { + int[] available = new int[availableVersions.size()]; + for (int i = 0; i < available.length; i++) { + available[i] = availableVersions.get(i).versionNumber(); + } + return new VersionInformation(chosenVersion.versionNumber(), available); + } + + /** + * Sets the value of a parameter whose format corresponds to the + * {@link ParameterId#preferred_address} parameter. + * @param id the parameter id + * @param ipv4 the preferred IPv4 address (or the IPv4 wildcard address) + * @param port4 the preferred IPv4 port (or 0) + * @param ipv6 the preferred IPv6 address (or the IPv6 wildcard address) + * @param port6 the preferred IPv6 port (or 0) + * @param connectionId the connection id bytes + * @param statelessToken the stateless token + * @throws IllegalArgumentException if any of the given parameters has an + * illegal value, or if the given parameter value is not of the + * {@link ParameterId#preferred_address} format + * @see ParameterId#preferred_address + */ + public void setPreferredAddressParameter(TransportParameterId id, + Inet4Address ipv4, int port4, + Inet6Address ipv6, int port6, + ByteBuffer connectionId, + ByteBuffer statelessToken) { + ParameterId pid = (ParameterId)id; + if (pid != ParameterId.preferred_address) { + throw new IllegalArgumentException(String.valueOf(id)); + } + int cidlen = connectionId.remaining(); + if (cidlen == 0 || cidlen > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) { + throw new IllegalArgumentException( + "connection id len out of range [1..20]: " + cidlen); + } + int tklen = statelessToken.remaining(); + if (tklen != TOKEN_SIZE) { + throw new IllegalArgumentException("bad stateless token length: expected 16, found " + tklen); + } + if (port4 < 0 || port4 > MAX_PORT) + throw new IllegalArgumentException("IPv4 port out of range: " + port4); + if (port6 < 0 || port6 > MAX_PORT) + throw new IllegalArgumentException("IPv6 port out of range: " + port6); + int size = MIN_PREF_ADDR_SIZE + cidlen; + byte[] value = new byte[size]; + ByteBuffer buffer = ByteBuffer.wrap(value); + if (!ipv4.isAnyLocalAddress()) { + buffer.put(IPV4_ADDR_OFFSET, ipv4.getAddress()); + } + buffer.putShort(IPV4_PORT_OFFSET, (short) port4); + if (!ipv6.isAnyLocalAddress()) { + buffer.put(IPV6_ADDR_OFFSET, ipv6.getAddress()); + } + buffer.putShort(IPV6_PORT_OFFSET, (short)port6); + buffer.put(CID_LEN_OFFSET, (byte) cidlen); + buffer.put(CID_OFFSET, connectionId, connectionId.position(), cidlen); + assert size - CID_OFFSET - cidlen == TOKEN_SIZE : (size - CID_OFFSET - cidlen); + assert tklen == TOKEN_SIZE; + buffer.put(CID_OFFSET + cidlen, statelessToken, statelessToken.position(), tklen); + values.put(pid, value); + } + + /** + * {@return the size in bytes required to encode the parameter + * array} + */ + public int size() { + int size = 0; + for (var kv : values.entrySet()) { + var i = kv.getKey().idx(); + var value = kv.getValue(); + if (value == null) continue; + assert value.length > 0 || i == ParameterId.disable_active_migration.idx(); + size += VariableLengthEncoder.getEncodedSize(i); + size += VariableLengthEncoder.getEncodedSize(value.length); + size += value.length; + } + return size; + } + + /** + * Encodes the transport parameters into the given byte buffer. + *

+ * From + * RFC 9000, Section 18.2: + *

{@code
+     * The extension_data field of the quic_transport_parameters
+     * extension defined in [QUIC-TLS] contains the QUIC transport
+     * parameters. They are encoded as a sequence of transport
+     * parameters, as shown in Figure 20:
+     *
+     * Transport Parameters {
+     *   Transport Parameter (..) ...,
+     * }
+     *
+     * Figure 20: Sequence of Transport Parameters
+     *
+     * Each transport parameter is encoded as an (identifier, length,
+     * value) tuple, as shown in Figure 21:
+     *
+     * Transport Parameter {
+     *   Transport Parameter ID (i),
+     *   Transport Parameter Length (i),
+     *   Transport Parameter Value (..),
+     * }
+     * }
+ * + * @param buffer a byte buffer in which to encode the transport parameters + * @return the number of bytes written + * @throws BufferOverflowException if there is not enough space in the + * provided buffer + * @see jdk.internal.net.quic.QuicTLSEngine#setLocalQuicTransportParameters(ByteBuffer) + * @see + * RFC 9000, Section 18 + * @see + * RFC 9001 [QUIC-TLS] + */ + public int encode(ByteBuffer buffer) { + int start = buffer.position(); + for (var kv : values.entrySet()) { + var i = kv.getKey().idx(); + var value = kv.getValue(); + if (value == null) continue; + + VariableLengthEncoder.encode(buffer, i); + VariableLengthEncoder.encode(buffer, value.length); + buffer.put(value); + } + var written = buffer.position() - start; + if (QuicTransportParameters.class.desiredAssertionStatus()) { + int size = size(); + assert written == size + : "unexpected number of bytes encoded: %d, expected %d" + .formatted(written, size); + } + return written; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("Quic Transport Params["); + for (var kv : values.entrySet()) { + var param = kv.getKey(); + var value = kv.getValue(); + if (value != null) { + // param is set + // we just return the string representation of the param ids and don't include + // the encoded values + sb.append(param); + sb.append(", "); + } + } + return sb.append("]").toString(); + } + + // values for (variable length) integer params are decoded, for other params + // that are set, the value is printed as a hex string. + public String toStringWithValues() { + final StringBuilder sb = new StringBuilder("Quic Transport Params["); + for (var kv : values.entrySet()) { + var param = kv.getKey(); + var value = kv.getValue(); + if (value != null) { + // param is set, so include it in the string representation + sb.append(param); + final String valAsString = valueToString(param); + sb.append("=").append(valAsString); + sb.append(", "); + } + } + return sb.append("]").toString(); + } + + private String valueToString(final ParameterId parameterId) { + assert this.values.get(parameterId) != null : "param " + parameterId + " not set"; + try { + return switch (parameterId) { + // int params + case max_idle_timeout, max_udp_payload_size, initial_max_data, + initial_max_stream_data_bidi_local, + initial_max_stream_data_bidi_remote, + initial_max_stream_data_uni, initial_max_streams_bidi, + initial_max_streams_uni, ack_delay_exponent, max_ack_delay, + active_connection_id_limit -> + String.valueOf(getIntParameter(parameterId)); + default -> + '"' + HexFormat.of().formatHex(values.get(parameterId)) + '"'; + }; + } catch (RuntimeException e) { + // if the value was a malformed integer, return the hex representation + return '"' + HexFormat.of().formatHex(values.get(parameterId)) + '"'; + } + } + + /** + * Decodes the quic transport parameters from the given buffer. + * Parameters which are not supported are silently discarded. + * + * @param buffer a byte buffer containing the transport parameters + * + * @return the decoded transport parameters + * @throws QuicTransportException if the parameters couldn't be decoded + * + * @see jdk.internal.net.quic.QuicTLSEngine#setRemoteQuicTransportParametersConsumer(QuicTransportParametersConsumer) (ByteBuffer) + * @see jdk.internal.net.quic.QuicTransportParametersConsumer#accept(ByteBuffer) + * @see #encode(ByteBuffer) + * @see + * RFC 9000, Section 18 + */ + public static QuicTransportParameters decode(ByteBuffer buffer) + throws QuicTransportException { + QuicTransportParameters parameters = new QuicTransportParameters(); + while (buffer.hasRemaining()) { + final long id = VariableLengthEncoder.decode(buffer); + final ParameterId pid = TransportParameterId.valueOf(id) + .orElse(null); + final String name = pid == null ? String.valueOf(id) : pid.toString(); + long length = VariableLengthEncoder.decode(buffer); + if (length < 0) { + throw new QuicTransportException( + "Can't decode length for transport parameter " + name, + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (length > buffer.remaining()) { + throw new QuicTransportException("Transport parameter truncated", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + byte[] value = new byte[(int) length]; + buffer.get(value); + if (pid == null) { + // RFC-9000, section 7.4.2: An endpoint MUST ignore transport parameters + // that it does not support. + if (Log.quicControl()) { + Log.logQuic("ignoring unsupported transport parameter: " + name); + } + continue; + } + try { + checkParameterValue(pid, value); + } catch (RuntimeException e) { + throw new QuicTransportException(e.getMessage(), + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + var oldValue = parameters.values.putIfAbsent(pid, value); + if (oldValue != null) { + throw new QuicTransportException( + "Duplicate transport parameter " + name, + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + } + return parameters; + } + + /** + * Reads the preferred address encoded in the value + * of a parameter whose format corresponds to the {@link + * ParameterId#preferred_address} parameter. + * If the given {@code value} is {@code null}, this + * method returns {@code null}. + * Otherwise, the returned list contains + * at most one IPv4 address and/or one IPv6 address. + * + * @apiNote + * To obtain the list of addresses encoded in the + * {@link ParameterId#preferred_address} parameter, use + * {@link #getPreferredAddress(TransportParameterId, byte[]) + * getPreferredAddress(ParameterId.preferred_address,} + * {@link #getParameter(TransportParameterId) + * parameters.getParameter(ParameterId.preferred_address)}. + * + * @param id the parameter id + * @param value the value of the parameter + * @return a list of {@link InetSocketAddress}, or {@code null} if the + * given value is {@code null}. + * @see ParameterId#preferred_address + */ + public static List getPreferredAddress( + TransportParameterId id, byte[] value) { + if (value == null) return null; + if (value.length < MIN_PREF_ADDR_SIZE) { + throw new IllegalArgumentException(id + + ": not enough bytes in value; found " + value.length); + } + ByteBuffer buffer = ByteBuffer.wrap(value); + int ipv4port = buffer.getShort(IPV4_PORT_OFFSET) & 0xFFFF; + int ipv6port = buffer.getShort(IPV6_PORT_OFFSET) & 0xFFFF; + + byte[] ipv4 = new byte[IPV4_SIZE]; + buffer.get(IPV4_ADDR_OFFSET, ipv4); + byte[] ipv6 = new byte[IPV6_SIZE]; + buffer.get(IPV6_ADDR_OFFSET, ipv6); + InetSocketAddress ipv4addr = new InetSocketAddress(getByAddress(id, ipv4), ipv4port); + InetSocketAddress ipv6addr = new InetSocketAddress(getByAddress(id, ipv6), ipv6port); + return Stream.of(ipv4addr, ipv6addr) + .filter((isa) -> !isa.getAddress().isAnyLocalAddress()) + .toList(); + } + + /** + * Reads the connection id bytes from the value of a parameter + * whose format corresponds to the {@link ParameterId#preferred_address} + * parameter. + * If the given {@code value} is {@code null}, this + * method returns {@code null}. + * + * @param preferredAddressValue the value of {@link ParameterId#preferred_address} param + * @return the connection id bytes + * @see ParameterId#preferred_address + */ + public static ByteBuffer getPreferredConnectionId(final byte[] preferredAddressValue) { + if (preferredAddressValue == null) { + return null; + } + final int length = getPreferredConnectionIdLength(ParameterId.preferred_address, + preferredAddressValue); + return ByteBuffer.wrap(preferredAddressValue, CID_OFFSET, length); + } + + /** + * Reads the stateless token bytes from the value of a parameter + * whose format corresponds to the {@link ParameterId#preferred_address} + * parameter. + * + * If the given {@code value} is {@code null}, this + * method returns {@code null}. + * + * @param preferredAddressValue the value of {@link ParameterId#preferred_address} param + * @return the stateless reset token bytes + * @see ParameterId#preferred_address + */ + public static byte[] getPreferredStatelessResetToken(final byte[] preferredAddressValue) { + if (preferredAddressValue == null) { + return null; + } + final int length = getPreferredConnectionIdLength(ParameterId.preferred_address, + preferredAddressValue); + final int offset = CID_OFFSET + length; + final byte[] statelessResetToken = new byte[TOKEN_SIZE]; + System.arraycopy(preferredAddressValue, offset, statelessResetToken, 0, TOKEN_SIZE); + return statelessResetToken; + } + + static final byte[] NOBYTES = new byte[0]; + static final int IPV6_SIZE = 16; + static final int IPV4_SIZE = 4; + static final int PORT_SIZE = 2; + static final int TOKEN_SIZE = 16; + static final int CIDLEN_SIZE = 1; + static final int IPV4_ADDR_OFFSET = 0; + static final int IPV4_PORT_OFFSET = IPV4_ADDR_OFFSET + IPV4_SIZE; + static final int IPV6_ADDR_OFFSET = IPV4_PORT_OFFSET + PORT_SIZE; + static final int IPV6_PORT_OFFSET = IPV6_ADDR_OFFSET + IPV6_SIZE; + static final int CID_LEN_OFFSET = IPV6_PORT_OFFSET + PORT_SIZE; + static final int CID_OFFSET = CID_LEN_OFFSET + CIDLEN_SIZE; + static final int MIN_PREF_ADDR_SIZE = CID_OFFSET + TOKEN_SIZE; + static final int MAX_PORT = 0xFFFF; + + private static int getPreferredConnectionIdLength(TransportParameterId id, byte[] value) { + if (value.length < MIN_PREF_ADDR_SIZE) { + throw new IllegalArgumentException(id + + ": not enough bytes in value; found " + value.length); + } + int length = value[CID_LEN_OFFSET] & 0xFF; + if (length > QuicConnectionId.MAX_CONNECTION_ID_LENGTH || length == 0) { + throw new IllegalArgumentException(id + + ": invalid preferred connection ID length: " + length); + } + if (length != value.length - MIN_PREF_ADDR_SIZE) { + throw new IllegalArgumentException(id + + ": invalid preferred address length: " + value.length + + ", expected: " + (MIN_PREF_ADDR_SIZE + length)); + } + return length; + } + + private static InetAddress getByAddress(TransportParameterId id, byte[] address) { + try { + return InetAddress.getByAddress(address); + } catch (UnknownHostException x) { + // should not happen + throw new IllegalArgumentException(id + + "Invalid address: " + HexFormat.of().formatHex(address)); + } + } + + /** + * verifies that the {@code value} is acceptable (as specified in the RFC) for the + * {@code tpid} + * + * @param tpid the transport parameter id + * @param value the value + * @return the corresponding parameter id if the value is acceptable, else throws a + * {@link IllegalArgumentException} + */ + private static ParameterId checkParameterValue(TransportParameterId tpid, byte[] value) { + ParameterId id = (ParameterId)tpid; + if (value != null) { + switch (id) { + case disable_active_migration -> { + if (value.length > 0) + throw new IllegalArgumentException(id + + ": value must be null or 0-length; found " + + value.length + " bytes"); + } + case stateless_reset_token -> { + if (value.length != 16) + throw new IllegalArgumentException(id + + ": value must be null or 16 bytes long; found " + + value.length + " bytes"); + } + case initial_source_connection_id, original_destination_connection_id, + retry_source_connection_id -> { + if (value.length > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) { + throw new IllegalArgumentException(id + + ": value must not exceed " + + QuicConnectionId.MAX_CONNECTION_ID_LENGTH + + "bytes; found " + value.length + " bytes"); + } + } + case preferred_address -> getPreferredConnectionIdLength(id, value); + case version_information -> { + if (value.length < 4 || value.length % 4 != 0) { + throw new IllegalArgumentException(id + + ": value length must be a positive multiple of 4 " + + "bytes; found " + value.length + " bytes"); + } + } + default -> { + long intvalue; + try { + intvalue = decodeVLIntFully(id, ByteBuffer.wrap(value)); + } catch (IllegalArgumentException x) { + throw x; + } catch (Exception x) { + throw new IllegalArgumentException(id + + ": value is not a valid variable length integer", x); + } + if (intvalue < 0) + throw new IllegalArgumentException(id + + ": value is not a valid variable length integer"); + switch (id) { + case max_udp_payload_size -> { + if (intvalue < 1200 || intvalue > 65527) { + throw new IllegalArgumentException(id + + ": value out of range [1200, 65527]; found " + + intvalue); + } + } + case ack_delay_exponent -> { + if (intvalue > 20) { + throw new IllegalArgumentException(id + + ": value out of range [0, 20]; found " + + intvalue); + } + } + case max_ack_delay -> { + if (intvalue >= (1 << 14)) { + throw new IllegalArgumentException(id + + ": value out of range [0, 2^14); found " + + intvalue); + } + } + case active_connection_id_limit -> { + if (intvalue < 2) { + throw new IllegalArgumentException(id + + ": value out of range [2...]; found " + + intvalue); + } + } + case initial_max_streams_bidi, initial_max_streams_uni -> { + if (intvalue >= 1L << 60) { + throw new IllegalArgumentException(id + + ": value out of range [0,2^60); found " + + intvalue); + } + } + } + } + } + } + return id; + } + + private static long decodeVLIntFully(ParameterId id, ByteBuffer buffer) { + long value = VariableLengthEncoder.decode(buffer); + if (value < 0 || value > (1L << 62) - 1) { + throw new IllegalArgumentException(id + + ": failed to decode variable length integer"); + } + if (buffer.hasRemaining()) + throw new IllegalArgumentException(id + + ": extra bytes in provided value at index " + + buffer.position()); + return value; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/TerminationCause.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/TerminationCause.java new file mode 100644 index 00000000000..9e441cf7873 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/TerminationCause.java @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.io.IOException; +import java.util.Objects; + +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import static jdk.internal.net.quic.QuicTransportErrors.NO_ERROR; + +// TODO: document this +public abstract sealed class TerminationCause { + private String logMsg; + private String peerVisibleReason; + private final long closeCode; + private final Throwable originalCause; + private final IOException reportedCause; + + private TerminationCause(final long closeCode, final Throwable closeCause) { + this.closeCode = closeCode; + this.originalCause = closeCause; + if (closeCause != null) { + this.logMsg = closeCause.toString(); + } + this.reportedCause = toReportedCause(this.originalCause, this.logMsg); + } + + private TerminationCause(final long closeCode, final String loggedAs) { + this.closeCode = closeCode; + this.originalCause = null; + this.logMsg = loggedAs; + this.reportedCause = toReportedCause(this.originalCause, this.logMsg); + } + + public final long getCloseCode() { + return this.closeCode; + } + + public final IOException getCloseCause() { + return this.reportedCause; + } + + public final String getLogMsg() { + return logMsg; + } + + public final TerminationCause loggedAs(final String logMsg) { + this.logMsg = logMsg; + return this; + } + + public final String getPeerVisibleReason() { + return this.peerVisibleReason; + } + + public final TerminationCause peerVisibleReason(final String reasonPhrase) { + this.peerVisibleReason = reasonPhrase; + return this; + } + + public abstract boolean isAppLayer(); + + public static TerminationCause forTransportError(final QuicTransportErrors err) { + return new TransportError(err); + } + + public static TerminationCause forTransportError(long errorCode, String loggedAs, long frameType) { + return new TransportError(errorCode, loggedAs, frameType); + } + + static SilentTermination forSilentTermination(final String loggedAs) { + return new SilentTermination(loggedAs); + } + + public static TerminationCause forException(final Throwable cause) { + Objects.requireNonNull(cause); + if (cause instanceof QuicTransportException qte) { + return new TransportError(qte); + } + return new InternalError(cause); + } + + // allows for higher (application) layer to inform the connection terminator + // that the higher layer had completed a graceful shutdown of the connection + // and the QUIC layer can now do an immediate close of the connection using + // the {@code closeCode} + public static TerminationCause appLayerClose(final long closeCode) { + return new AppLayerClose(closeCode, (Throwable)null); + } + + public static TerminationCause appLayerClose(final long closeCode, String loggedAs) { + return new AppLayerClose(closeCode, loggedAs); + } + + public static TerminationCause appLayerException(final long closeCode, + final Throwable cause) { + return new AppLayerClose(closeCode, cause); + } + + private static IOException toReportedCause(final Throwable original, + final String fallbackExceptionMsg) { + if (original == null) { + return fallbackExceptionMsg == null + ? new IOException("connection terminated") + : new IOException(fallbackExceptionMsg); + } else if (original instanceof QuicTransportException qte) { + return new IOException(qte.getMessage()); + } else if (original instanceof IOException ioe) { + return ioe; + } else { + return new IOException(original); + } + } + + + static final class TransportError extends TerminationCause { + final long frameType; + final QuicTLSEngine.KeySpace keySpace; + + private TransportError(final QuicTransportErrors err) { + super(err.code(), err.name()); + this.frameType = 0; // unknown frame type + this.keySpace = null; + } + + private TransportError(final QuicTransportException exception) { + super(exception.getErrorCode(), exception); + this.frameType = exception.getFrameType(); + this.keySpace = exception.getKeySpace(); + peerVisibleReason(exception.getReason()); + } + + public TransportError(long errorCode, String loggedAs, long frameType) { + super(errorCode, loggedAs); + this.frameType = frameType; + keySpace = null; + } + + @Override + public boolean isAppLayer() { + return false; + } + } + + static final class InternalError extends TerminationCause { + + private InternalError(final Throwable cause) { + super(QuicTransportErrors.INTERNAL_ERROR.code(), cause); + } + + @Override + public boolean isAppLayer() { + return false; + } + } + + static final class AppLayerClose extends TerminationCause { + private AppLayerClose(final long closeCode, String loggedAs) { + super(closeCode, loggedAs); + } + + // TODO: allow optionally to specify "name" of the close code for app layer + // like "H3_GENERAL_PROTOCOL_ERROR" (helpful in logging) + private AppLayerClose(final long closeCode, final Throwable cause) { + super(closeCode, cause); + } + + @Override + public boolean isAppLayer() { + return true; + } + } + + static final class SilentTermination extends TerminationCause { + + private SilentTermination(final String loggedAs) { + // the error code won't play any role, since silent termination + // doesn't cause any packets to be generated or sent to the peer + super(NO_ERROR.code(), loggedAs); + } + + @Override + public boolean isAppLayer() { + return false; // doesn't play a role in context of silent termination + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/VariableLengthEncoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/VariableLengthEncoder.java new file mode 100644 index 00000000000..91380fcfca4 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/VariableLengthEncoder.java @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +import java.nio.ByteBuffer; + +/** + * QUIC packets and frames commonly use a variable-length encoding for + * non-negative values. This encoding ensures that smaller values will use less + * in the packet or frame. + * + *

The QUIC variable-length encoding reserves the two most significant bits + * of the first byte to encode the size of the length value as a base 2 logarithm + * value. The length itself is then encoded on the remaining bits, in network + * byte order. This means that the length values will be encoded on 1, 2, 4, or + * 8 bytes and can encode 6-, 14-, 30-, or 62-bit values + * respectively, or a value within the range of 0 to 4611686018427387903 + * inclusive. + * + * @spec https://www.rfc-editor.org/rfc/rfc9000.html#integer-encoding + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public class VariableLengthEncoder { + + /** + * The maximum number of bytes on which a variable length + * integer can be encoded. + */ + public static final int MAX_INTEGER_LENGTH = 8; + + /** + * The maximum value a variable length integer can + * take. + */ + public static final long MAX_ENCODED_INTEGER = (1L << 62) - 1; + + static { + assert MAX_ENCODED_INTEGER == 4611686018427387903L; + } + + private VariableLengthEncoder() { + throw new InternalError("should not come here"); + } + + /** + * Decode a variable length value from {@code ByteBuffer}. This method assumes that the + * position of {@code buffer} has been set to the first byte where the length + * begins. If the methods completes successfully, the position will be set + * to the byte after the last byte read. + * + * @param buffer the {@code ByteBuffer} that the length will be decoded from + * + * @return the value. If an error occurs, {@code -1} is returned and + * the buffer position is left unchanged. + */ + public static long decode(ByteBuffer buffer) { + return decode(BuffersReader.single(buffer)); + } + + /** + * Decode a variable length value from {@code BuffersReader}. This method assumes that the + * position of {@code buffers} has been set to the first byte where the length + * begins. If the methods completes successfully, the position will be set + * to the byte after the last byte. + * + * @param buffers the {@code BuffersReader} that the length will be decoded from + * + * @return the value. If an error occurs, {@code -1} is returned and + * the buffer position is left unchanged. + */ + public static long decode(BuffersReader buffers) { + if (!buffers.hasRemaining()) + return -1; + + long pos = buffers.position(); + int lenByte = buffers.get(pos) & 0xFF; + pos++; + // read size of length from leading two bits + int prefix = lenByte >> 6; + int len = 1 << prefix; + // retrieve remaining bits that constitute the length + long result = lenByte & 0x3F; + long idx = 0, lim = buffers.limit(); + if (lim - pos < len - 1) return -1; + while (idx++ < len - 1) { + assert pos < lim; + result = ((result << Byte.SIZE) + (buffers.get(pos) & 0xFF)); + pos++; + } + // Set position of ByteBuffer to next byte following length + assert pos == buffers.position() + len; + assert pos <= buffers.limit(); + buffers.position(pos); + + assert (result >= 0) && (result < (1L << 62)); + return result; + } + + /** + * Encode (a variable length) value into {@code ByteBuffer}. This method assumes that the + * position of {@code buffer} has been set to the first byte where the length + * begins. If the methods completes successfully, the position will be set + * to the byte after the last length byte. + * + * @param buffer the {@code ByteBuffer} that the length will be encoded into + * @param value the variable length value + * + * @throws IllegalArgumentException + * if value supplied falls outside of acceptable bounds [0, 2^62-1], + * or if the given buffer doesn't contain enough space to encode the + * value + * + * @return the {@code position} of the buffer + */ + public static int encode(ByteBuffer buffer, long value) throws IllegalArgumentException { + // check for valid parameters + if (value < 0 || value > MAX_ENCODED_INTEGER) + throw new IllegalArgumentException( + "value supplied falls outside of acceptable bounds"); + if (!buffer.hasRemaining()) + throw new IllegalArgumentException( + "buffer does not contain enough bytes to store length"); + + // set length prefix to indicate size of length + int lengthPrefix = getVariableLengthPrefix(value); + assert lengthPrefix >= 0 && lengthPrefix <= 3; + lengthPrefix <<= (Byte.SIZE - 2); + + int lengthSize = getEncodedSize(value); + assert lengthSize > 0; + assert lengthSize <= 8; + + var limit = buffer.limit(); + var pos = buffer.position(); + + // check that it's possible to add length to buffer + if (lengthSize > limit - pos) + throw new IllegalArgumentException("buffer does not contain enough bytes to store length"); + + // create mask to use in isolating byte to transfer to buffer + long mask = 255L << (Byte.SIZE * (lengthSize - 1)); + // convert length to bytes and add to buffer + boolean isFirstByte = true; + for (int i = lengthSize; i > 0; i--) { + assert buffer.hasRemaining() : "no space left at " + (lengthSize - i); + assert mask != 0; + assert mask == (255L << ((i - 1) * 8)) + : "mask: %x, expected %x".formatted(mask, (255L << ((i - 1) * 8))); + + long b = value & mask; + for (int j = i - 1; j > 0; j--) { + b >>= Byte.SIZE; + } + + assert b == (value & mask) >> (8 * (i - 1)); + + if (isFirstByte) { + assert (b & 0xC0) == 0; + buffer.put((byte) (b | lengthPrefix)); + isFirstByte = false; + } else { + buffer.put((byte) b); + } + // move mask over to next byte - avoid carrying sign bit + mask = (mask >>> Byte.SIZE); + } + var bytes = buffer.position() - pos; + assert bytes == lengthSize; + return lengthSize; + } + + /** + * Returns the variable length prefix. + * The variable length prefix is the base 2 logarithm of + * the number of bytes required to encode + * a positive value as a variable length integer: + * [0, 1, 2, 3] for [1, 2, 4, 8] bytes. + * + * @param value the value to encode + * + * @throws IllegalArgumentException + * if the supplied value falls outside the acceptable bounds [0, 2^62-1] + * + * @return the base 2 logarithm of the number of bytes required to encode + * the value as a variable length integer. + */ + public static int getVariableLengthPrefix(long value) throws IllegalArgumentException { + if ((value > MAX_ENCODED_INTEGER) || (value < 0)) + throw new IllegalArgumentException("invalid length"); + + int lengthPrefix; + if (value > (1L << 30) - 1) + lengthPrefix = 3; // 8 bytes + else if (value > (1L << 14) - 1) + lengthPrefix = 2; // 4 bytes + else if (value > (1L << 6) - 1) + lengthPrefix = 1; // 2 bytes + else + lengthPrefix = 0; // 1 byte + + return lengthPrefix; + } + + /** + * Returns the number of bytes needed to encode + * the given value as a variable length integer. + * This a number between 1 and 8. + * + * @param value the value to encode + * + * @return the number of bytes needed to encode + * the given value as a variable length integer. + * + * @throws IllegalArgumentException + * if the value supplied falls outside of acceptable bounds [0, 2^62-1] + */ + public static int getEncodedSize(long value) throws IllegalArgumentException { + if (value < 0 || value > MAX_ENCODED_INTEGER) + throw new IllegalArgumentException("invalid variable length integer: " + value); + return 1 << getVariableLengthPrefix(value); + } + + /** + * Peeks at a variable length value encoded at the given offset. + * If the byte buffer doesn't contain enough bytes to read the + * variable length value, -1 is returned. + * + *

This method doesn't advance the buffer position. + * + * @param buffer the buffer to read from + * @param offset the offset in the buffer to start reading from + * + * @return the variable length value encoded at the given offset, or -1 + */ + public static long peekEncodedValue(ByteBuffer buffer, int offset) { + return peekEncodedValue(BuffersReader.single(buffer), offset); + } + + /** + * Peeks at a variable length value encoded at the given offset. + * If the byte buffer doesn't contain enough bytes to read the + * variable length value, -1 is returned. + * + * This method doesn't advance the buffer position. + * + * @param buffers the buffer to read from + * @param offset the offset in the buffer to start reading from + * + * @return the variable length value encoded at the given offset, or -1 + */ + public static long peekEncodedValue(BuffersReader buffers, long offset) { + + // figure out on how many bytes the length is encoded. + int size = peekEncodedValueSize(buffers, offset); + if (size <= 0) return -1L; + assert size > 0 && size <= 8; + + // check that we have enough bytes in the buffer + long limit = buffers.limit(); + long pos = offset; + if (limit - size < pos) return -1L; + + // peek at the variable length: + // - read first byte + int first = buffers.get(pos++); + long res = first & 0x3F; + if (size == 1) return res; + + // - read the rest of the bytes + size -= 1; + assert size > 0; + for (int i=0 ; i < size; i++) { + if (limit <= pos) return -1L; + res = (res << 8) | (long) (buffers.get(pos++) & 0xFF); + } + return res; + } + + /** + * Peeks at a variable length value encoded at the given offset, + * and return the number of bytes on which this value is encoded. + * If the byte buffer is empty or the offset is past + * the limit -1 is returned. + * This method doesn't advance the buffer position. + * + * @param buffer the buffer to read from + * @param offset the offset in the buffer to start reading from + * + * @return the number of bytes on which the variable length + * value is encoded at the given offset, or -1 + */ + public static int peekEncodedValueSize(ByteBuffer buffer, int offset) { + return peekEncodedValueSize(BuffersReader.single(buffer), offset); + } + + /** + * Peeks at a variable length value encoded at the given offset, + * and return the number of bytes on which this value is encoded. + * If the byte buffer is empty or the offset is past + * the limit -1 is returned. + * This method doesn't advance the buffer position. + * + * @param buffers the buffers to read from + * @param offset the offset in the buffer to start reading from + * + * @return the number of bytes on which the variable length + * value is encoded at the given offset, or -1 + */ + public static int peekEncodedValueSize(BuffersReader buffers, long offset) { + long limit = buffers.limit(); + long pos = offset; + if (limit <= pos) return -1; + int first = buffers.get(pos); + int prefix = (first & 0xC0) >>> 6; + int size = 1 << prefix; + assert size > 0 && size <= 8; + return size; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/AckFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/AckFrame.java new file mode 100644 index 00000000000..7983d1be4f0 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/AckFrame.java @@ -0,0 +1,931 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.Objects; +import java.util.Spliterator; +import java.util.function.LongConsumer; +import java.util.stream.LongStream; +import java.util.stream.StreamSupport; + +/** + * An ACK Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class AckFrame extends QuicFrame { + + private final long largestAcknowledged; + private final long ackDelay; + private final int ackRangeCount; + private final List ackRanges; + + private final boolean countsPresent; + private final long ect0Count; + private final long ect1Count; + private final long ecnCECount; + private final int size; + + private static final int COUNTS_PRESENT = 0x1; + + /** + * Reads an {@code AckFrame} from the given buffer. When entering + * this method the buffer position is supposed to be just past + * after the frame type. That, is the frame type has already + * been read. This method moves the position of the buffer to the + * first byte after the read ACK frame. + * @param buffer a buffer containing the ACK frame + * @param type the frame type read from the buffer + * @throws QuicTransportException if the ACK frame was malformed + */ + AckFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(ACK); + int pos = buffer.position(); + largestAcknowledged = decodeVLField(buffer, "largestAcknowledged"); + ackDelay = decodeVLField(buffer, "ackDelay"); + ackRangeCount = decodeVLFieldAsInt(buffer, "ackRangeCount"); + long firstAckRange = decodeVLField(buffer, "firstAckRange"); + long smallestAcknowledged = largestAcknowledged - firstAckRange; + if (smallestAcknowledged < 0) { + throw new QuicTransportException("Negative PN acknowledged", + null, type, + QuicTransportErrors.FRAME_ENCODING_ERROR); + } + var ackRanges = new ArrayList(ackRangeCount + 1); + AckRange first = AckRange.of(0, firstAckRange); + ackRanges.add(0, first); + for (int i=1; i <= ackRangeCount; i++) { + long gap = decodeVLField(buffer, "gap"); + long len = decodeVLField(buffer, "range length"); + ackRanges.add(i, AckRange.of(gap, len)); + smallestAcknowledged -= gap + len + 2; + if (smallestAcknowledged < 0) { + // verify after each range to avoid wrap around + throw new QuicTransportException("Negative PN acknowledged", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + } + this.ackRanges = List.copyOf(ackRanges); + if (type % 2 == 1) { + // packet contains ECN counts + countsPresent = true; + ect0Count = decodeVLField(buffer, "ect0Count"); + ect1Count = decodeVLField(buffer, "ect1Count"); + ecnCECount = decodeVLField(buffer, "ecnCECount"); + } else { + countsPresent = false; + ect0Count = -1; + ect1Count = -1; + ecnCECount = -1; + } + size = computeSize(); + int wireSize = buffer.position() - pos + getVLFieldLengthFor(getTypeField()); + assert size <= wireSize : "parsed: %s, computed size: %s" + .formatted(wireSize, size); + } + + /** + * Creates the short formed ACK frame with no count totals + */ + public AckFrame(long largestAcknowledged, long ackDelay, List ackRanges) + { + this(largestAcknowledged, ackDelay, ackRanges, -1, -1, -1); + } + + /** + * Creates the long formed ACK frame with count totals + */ + public AckFrame( + long largestAcknowledged, + long ackDelay, + List ackRanges, + long ect0Count, + long ect1Count, + long ecnCECount) + { + super(ACK); + this.largestAcknowledged = requireVLRange(largestAcknowledged, "largestAcknowledged"); + this.ackDelay = requireVLRange(ackDelay, "ackDelay"); + if (ackRanges.size() < 1) { + throw new IllegalArgumentException("insufficient ackRanges"); + } + if (ackRanges.get(0).gap() != 0) { + throw new IllegalArgumentException("first range must have zero gap"); + } + this.ackRanges = List.copyOf(ackRanges); + this.ackRangeCount = ackRanges.size() - 1; + this.countsPresent = ect0Count != -1 || ect1Count != -1 || ecnCECount != -1; + if (countsPresent) { + this.ect0Count = requireVLRange(ect0Count,"ect0Count"); + this.ect1Count = requireVLRange(ect1Count, "ect1Count"); + this.ecnCECount = requireVLRange(ecnCECount, "ecnCECount"); + } else { + this.ect0Count = ect0Count; + this.ect1Count = ect1Count; + this.ecnCECount = ecnCECount; + } + this.size = computeSize(); + } + + @Override + public long getTypeField() { + return ACK | (countsPresent ? COUNTS_PRESENT : 0); + } + + @Override + public boolean isAckEliciting() { return false; } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, getTypeField(), "type"); + encodeVLField(buffer, largestAcknowledged, "largestAcknowledged"); + encodeVLField(buffer, ackDelay, "ackDelay"); + encodeVLField(buffer, ackRangeCount, "ackRangeCount"); + encodeVLField(buffer, ackRanges.get(0).range(), "firstAckRange"); + for (int i=1; i <= ackRangeCount; i++) { + AckRange ar = ackRanges.get(i); + encodeVLField(buffer, ar.gap(), "gap"); + encodeVLField(buffer, ar.range(), "range"); + } + if (countsPresent) { + // encode the counts + encodeVLField(buffer, ect0Count, "ect0Count"); + encodeVLField(buffer, ect1Count, "ect1Count"); + encodeVLField(buffer, ecnCECount, "ecnCECount"); + } + assert buffer.position() - pos == size(); + } + + private int computeSize() { + int size = getVLFieldLengthFor(getTypeField()) + + getVLFieldLengthFor(largestAcknowledged) + + getVLFieldLengthFor(ackDelay) + + getVLFieldLengthFor(ackRangeCount) + + getVLFieldLengthFor(ackRanges.get(0).range()) + + ackRanges.stream().skip(1).mapToInt(AckRange::size).sum(); + if (countsPresent) { + size = size + getVLFieldLengthFor(ect0Count) + + getVLFieldLengthFor(ect1Count) + + getVLFieldLengthFor(ecnCECount); + } + return size; + } + + @Override + public int size() { return size; } + + /** + * {@return largest packet number acknowledged by this frame} + */ + public long largestAcknowledged() { + return largestAcknowledged; + } + + /** + * The ACK delay + */ + public long ackDelay() { + return ackDelay; + } + + /** + * {@return the number of ack ranges} + * This corresponds to {@link #ackRanges() ackRange.size() -1}. + */ + public long ackRangeCount() { + return ackRangeCount; + } + + /** + * {@return a new {@code AckFrame} identical to this one, but + * with the given {@code ackDelay}}; + * @param ackDelay + */ + public AckFrame withAckDelay(long ackDelay) { + if (ackDelay == this.ackDelay) return this; + return new AckFrame(largestAcknowledged, ackDelay, ackRanges, + ect0Count, ect1Count, ecnCECount); + } + + /** + * An ACK range, composed of a gap and a range. + */ + public record AckRange(long gap, long range) { + public static final AckRange INITIAL = new AckRange(0, 0); + public AckRange { + requireVLRange(gap, "gap"); + requireVLRange(range, "range"); + } + public int size() { + return getVLFieldLengthFor(gap) + getVLFieldLengthFor(range); + } + public static AckRange of(long gap, long range) { + if (gap == 0 && range == 0) return INITIAL; + return new AckRange(gap, range); + } + } + + /** + * The ack ranges. First element is an actual range relative + * to highest acknowledged packet number. Second (if present) + * is a gap and a range following that gap, and so on until the last. + * @return the list of {@code AckRange} where the first ack range + * has a gap of {@code 0} and a range corresponding to + * the {@code First ACK Range}. + */ + public List ackRanges() { + return ackRanges; + } + + /** + * {@return the ECT0 count from this frame or -1 if not present} + */ + public long ect0Count() { + return ect0Count; + } + + /** + * {@return the ECT1 count from this frame or -1 if not present} + */ + public long ect1Count() { + return ect1Count; + } + + /** + * {@return the ECN-CE count from this frame or -1 if not present} + */ + public long ecnCECount() { + return ecnCECount; + } + + /** + * {@return true if this frame contains an acknowledgment for the + * given packet number} + * @param packetNumber a packet number + */ + public boolean isAcknowledging(long packetNumber) { + return isAcknowledging(largestAcknowledged, ackRanges, packetNumber); + } + + /** + * {@return true if the given range is acknowledged by this frame} + * @param first the first packet in the range, inclusive + * @param last the last packet in the range, inclusive + */ + public boolean isRangeAcknowledged(long first, long last) { + return isRangeAcknowledged(largestAcknowledged, ackRanges, first, last); + } + + + /** + * {@return the smallest packet number acknowledged by this {@code AckFrame}} + */ + public long smallestAcknowledged() { + return smallestAcknowledged(largestAcknowledged, ackRanges); + } + + /** + * @return a stream of packet numbers acknowledged by this frame + */ + public LongStream acknowledged() { + return StreamSupport.longStream(new AckFrameSpliterator(this), false); + } + + + private static class AckFrameSpliterator implements Spliterator.OfLong { + + final AckFrame ackFrame; + + AckFrameSpliterator(AckFrame ackFrame) { + this.ackFrame = ackFrame; + this.largest = ackFrame.largestAcknowledged(); + this.smallest = largest + 2; + this.ackRangeIterator = ackFrame.ackRanges.iterator(); + } + + @Override + public long estimateSize() { + // It is costly to compute an estimate, so we just + // return Long.MAX_VALUE instead + return Long.MAX_VALUE; + } + + @Override + public int characteristics() { + // NONNULL - nulls are not expected to be returned by this long spliterator + // IMMUTABLE - ackFrame.ackRanges() returns unmodifiable list, which cannot be + // structurally modified + return NONNULL | IMMUTABLE; + } + + @Override + public OfLong trySplit() { + // null - this spliterator cannot be split + return null; + } + private final Iterator ackRangeIterator; + private volatile long largest; + private volatile long smallest; + private volatile long pn; // the current packet number + + // The stream returns packet number in decreasing order + // (largest packet number is returned first) + private boolean ackAndDecId(LongConsumer action) { + assert ackFrame.isAcknowledging(pn) + : "%s is not acknowledging %s".formatted(ackFrame, pn); + action.accept(pn--); + return true; + } + + @Override + public boolean tryAdvance(LongConsumer action) { + // First call will see pn == 0 and smallest >= 2, + // which guarantees we will not enter the if below + // before pn has been initialized from the + // first ackRange value + if (pn >= smallest) { + return ackAndDecId(action); + } + if (ackRangeIterator.hasNext()) { + var ackRange = ackRangeIterator.next(); + largest = smallest - ackRange.gap() - 2; + smallest = largest - ackRange.range; + pn = largest; + return ackAndDecId(action); + } + return false; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + return o instanceof AckFrame ackFrame + && largestAcknowledged == ackFrame.largestAcknowledged + && ackDelay == ackFrame.ackDelay + && ackRangeCount == ackFrame.ackRangeCount + && countsPresent == ackFrame.countsPresent + && ect0Count == ackFrame.ect0Count + && ect1Count == ackFrame.ect1Count + && ecnCECount == ackFrame.ecnCECount + && ackRanges.equals(ackFrame.ackRanges); + } + + @Override + public int hashCode() { + return Objects.hash(largestAcknowledged, ackDelay, + ackRanges, ect0Count, ect1Count, ecnCECount); + } + + @Override + public String toString() { + String res = "AckFrame(" + + "largestAcknowledged=" + largestAcknowledged + + ", ackDelay=" + ackDelay + + ", ackRanges=[" + prettyRanges() + "]"; + if (countsPresent) res = res + + ", ect0Count=" + ect0Count + + ", ect1Count=" + ect1Count + + ", ecnCECount=" + ecnCECount; + res += ")"; + return res; + } + + private String prettyRanges() { + String result = null; + long largest; + long smallest = largestAcknowledged + 2; + for (var ackRange : ackRanges) { + largest = smallest - ackRange.gap - 2; + smallest = largest - ackRange.range; + result = smallest + ".." + largest + (result != null ? ", "+result : ""); + } + return result; + } + + /** + * {@return the largest packet acknowledged by an + * {@link QuicFrame#ACK ACK} frame contained in the + * given packet, or {@code -1L} if the packet + * contains no {@code ACK} frame} + * @param packet a packet that may contain an {@code ACK} frame + */ + public static long largestAcknowledgedInPacket(QuicPacket packet) { + return packet.frames().stream() + .filter(AckFrame.class::isInstance) + .map(AckFrame.class::cast) + .mapToLong(AckFrame::largestAcknowledged) + .max().orElse(-1L); + } + + /** + * A builder that allows to incrementally build the AckFrame + * that will need to be sent, as new packets are received. + * This class is not MT-thread safe. + */ + public static final class AckFrameBuilder { + long largestAckAcked = -1; + long largestAcknowledged = -1; + long ackDelay = 0; + List ackRanges = new ArrayList<>(); + long ect0Count = -1; + long ect1Count = -1; + long ecnCECount = -1; + + /** + * An empty builder. + */ + public AckFrameBuilder() {} + + /** + * A builder initialize from the content of an AckFrame. + * @param frame the {@code AckFrame} to initialize this builder with. + * Must not be {@code null}. + */ + public AckFrameBuilder(AckFrame frame) { + largestAckAcked = -1; + largestAcknowledged = frame.largestAcknowledged; + ackDelay = frame.ackDelay; + ackRanges.addAll(frame.ackRanges); + ect0Count = frame.ect0Count; + ect1Count = frame.ect1Count; + ecnCECount = frame.ecnCECount; + } + + public long getLargestAckAcked() { + return largestAckAcked; + } + + /** + * Drops all acks for packet whose number is smaller + * than the given {@code largestAckAcked}. + * @param largestAckAcked the smallest packet number that + * should be acknowledged by this + * {@link AckFrame}. + * @return this builder + */ + public AckFrameBuilder dropAcksBefore(long largestAckAcked) { + if (largestAckAcked > this.largestAckAcked) { + this.largestAckAcked = largestAckAcked; + return dropIfSmallerThan(largestAckAcked); + } else { + this.largestAckAcked = largestAckAcked; + } + return this; + } + + /** + * Drops all instances of {@link AckRange} after the given + * index in the {@linkplain #ackRanges() Ack Range List}, and compute + * the new smallest packet number now acknowledged by this + * {@link AckFrame}: this computed packet number will then be + * returned by {@link #getLargestAckAcked()}. + * This is a no-op if index is greater or equal to + * {@code ackRanges().size() -1}. + * @param index the index after which ranges should be dropped. + * @return this builder + */ + public AckFrameBuilder dropAckRangesAfter(int index) { + if (index < 0) { + throw new IllegalArgumentException("invalid index %s for size %s" + .formatted(index, ackRanges.size())); + } + if (index >= ackRanges.size() - 1) return this; + long newLargestAckAcked = dropRangesIfAfter(index); + assert newLargestAckAcked > largestAckAcked; + largestAckAcked = newLargestAckAcked; + return this; + } + + /** + * Sets the ack delay. + * @param ackDelay the ack delay. + * @return this builder. + */ + public AckFrameBuilder ackDelay(long ackDelay) { + this.ackDelay = ackDelay; + return this; + } + + /** + * Sets the ect0Count. Passing -1 unsets the ectOcount. + * @param ect0Count the ect0Count + * @return this builder. + */ + public AckFrameBuilder ect0Count(long ect0Count) { + this.ect0Count = ect0Count; + return this; + } + + /** + * Sets the ect1Count. Passing -1 unsets the ect1count. + * @param ect1Count the ect1Count + * @return this builder. + */ + public AckFrameBuilder ect1Count(long ect1Count) { + this.ect1Count = ect1Count; + return this; + } + + /** + * Sets the ecnCECount. Passing -1 unsets the ecnCEOcount. + * @param ecnCECount the ecnCECount + * @return this builder. + */ + public AckFrameBuilder ecnCECount(long ecnCECount) { + this.ecnCECount = ecnCECount; + return this; + } + + /** + * Adds the given packet number to the list of ack ranges. + * If the packet is already being acknowledged by this frame, + * do nothing. + * @param packetNumber the packet number + * @return this builder + */ + public AckFrameBuilder addAck(long packetNumber) { + // check if we need to acknowledge this packet + if (packetNumber <= largestAckAcked) return this; + // System.out.println("adding " + packetNumber); + if (ackRanges.isEmpty()) { + // easy case: we only have one packet to acknowledge! + return acknowledgeFirstPacket(packetNumber); + } else if (packetNumber > largestAcknowledged) { + return acknowledgeLargerPacket(packetNumber); + } else if (packetNumber < largestAcknowledged) { + // now is the complex case: we need to find out: + // - whether this packet is already acknowledged, in which case, + // there is nothing to do (great) + // - or whether we can extend an existing range + // - or whether we need to create a new range (if the packet falls + // within a gap whose value is > 0). + // - or whether we should merge two ranges if the packet falls + // on a gap whose value is 0 + ListIterator iterator = ackRanges.listIterator(); + long largest = largestAcknowledged; + long smallest = largest + 2; + int index = -1; + while (iterator.hasNext()) { + var ackRange = iterator.next(); + // index of the current ackRange element + index++; + // largest packet number acknowledged by this ackRange + largest = smallest - ackRange.gap - 2; + // smallest packet number acknowledged by this ackRange + smallest = largest - ackRange.range; + + // if the packet number we want to acknowledge is greater + // than the largest packet acknowledged by this ackRange + // there are two cases: + if (packetNumber > largest) { + // the packet number is just above the largest packet + if (packetNumber -1 == largest) { + // the current ackRange must have a gap, and we can simply + // reduce that gap by 1, and extend the range by 1. + // the case where the current ackrange doesn't have a gap + // and the packet number is the largest + 1 should have + // been handled when processing the previous ackRange. + assert ackRange.gap > 0; + var gap = ackRange.gap - 1; + var range = ackRange.range + 1; + var replaced = AckRange.of(gap, range); + ackRanges.set(index, replaced); + return this; + } else { + // the packet falls within the gap of this ack range. + // we need to split the ackRange in two... + // + // in the case where we have + // [31,31] [27,27] -> 31, AckRange[g=0, r=0], AckRange[g=2, r=0] + // and we want to acknowledge 29. + // we should end up with: + // [31,31] [29,29] [27,27] -> + // 31, AckRange[g=0, r=0], AckRange[g=0, r=0], AckRange[g=0, r=0] + assert ackRange.gap > 0 : "%s at index (prev:%s, next:%s)" + .formatted(ackRanges, iterator.previousIndex(), iterator.nextIndex()); + assert packetNumber - ackRange.gap -2 <= largest; + + // compute the smallest packet that was acknowledged by the + // previous ackRange. This should be: + var previousSmallest = largest + ackRange.gap + 2; + + // System.out.printf("ack: %s, largest:%s, previousSmallest:%d%n", + // ackRange, largest, previousSmallest); + + // compute the point at which we should split the current ackRange + // the current ackRange will be split in two: first, and second + // - first will replace the current ackRange + // - second will be inserted after first + var firstgap = previousSmallest - packetNumber -2; + AckRange first = AckRange.of(firstgap, 0); + AckRange second = AckRange.of(ackRange.gap - firstgap -2, ackRange.range); + ackRanges.set(index, first); + iterator.add(second); + return this; + } + } else if (packetNumber < smallest) { + // otherwise, if the packet number is smaller than + // the smallest packet acknowledged by the current ackRange, + // there are two cases: + + // If the current ackRange is the last: it's simple! + // But there are again two cases: + if (!iterator.hasNext()) { + // If the packet number we want to acknowledge is just below + // the smallest packet number acknowledge by the current + // ackRange, there is no gap between the packet number and + // the current range, so we can simply extend the current range + // Otherwise, we need to append a new ackRange. + if (packetNumber == smallest - 1) { + // no gap: we can extend the current range + AckRange replaced = AckRange + .of(ackRange.gap, ackRange.range + 1); + ackRanges.set(index, replaced); + } else { + // gap: we need to add a new AckRange + AckRange last = AckRange.of(smallest - packetNumber - 2, 0); + iterator.add(last); + } + return this; + } else if (packetNumber == smallest - 1) { + // Otherwise, if the packet number to be acknowledged is + // just below the smallest packet ackowledged by the current + // range, there are again two cases, depending on + // whether the next ackRange has a gap that can be reduced, + // or not + assert iterator.hasNext(); + AckRange next = ackRanges.get(index + 1); + // if the gap of the next packet can be reduced, that's great! + // just do it! We need to reduce that gap by one, and extend + // the range of the current ackRange + if (next.gap > 0) { + // reduce the gap in the next ackrange, and increase + // the range in the current ackrange. + // System.out.printf("ack: %s, next: %s%n", ackRange, next); + AckRange first = AckRange.of(ackRange.gap, ackRange.range + 1); + AckRange second = AckRange.of(next.gap - 1, next.range); + // System.out.printf("first: %s, second: %s%n", first, second); + ackRanges.set(index, first); + ackRanges.set(index + 1, second); + return this; + } else { + // Otherwise, that's the complex case again. + // we have a gap of 1 packet between 2 ackranges. + // our packet number falls exactly in that gap. + // We need to merge the two ranges! + // merge with next ackRange: remove the current ackRange, + // the ackRange at the current index is now the next ackRange, + // replace it with a merged ACK range. + var mergedRanges = ackRange.range + next.range + 2; + iterator.remove(); + ackRanges.set(index, AckRange.of(ackRange.gap, mergedRanges)); + return this; + } + } + } else { + // Otherwise, the packet is already acknowledged! + // nothing to do. + assert packetNumber <= largest && packetNumber >= smallest; + return this; + } + } + } else { + // already acknowledged! + assert packetNumber == largestAcknowledged; + return this; + } + return this; + } + + /** + * {@return true if this builder contains no ACK yet} + */ + public boolean isEmpty() { + return ackRanges.isEmpty(); + } + + /** + * {@return the number of ACK ranges in this builder, including the fake + * first ACK range} + */ + public int length() { + return ackRanges.size(); + } + + /** + * {@return true if the given packet number is already acknowledged + * by this builder} + * @param packetNumber a packet number + */ + public boolean isAcknowledging(long packetNumber) { + if (isEmpty()) return false; + return AckFrame.isAcknowledging(largestAcknowledged, ackRanges, packetNumber); + } + + /** + * {@return the smallest packet number acknowledged by this {@code AckFrame}} + */ + public long smallestAcknowledged() { + if (largestAcknowledged == -1L) return -1L; + return AckFrame.smallestAcknowledged(largestAcknowledged, ackRanges); + } + + // drop acknowledgement of all packet numbers acknowledged + // by AckRange instances coming after the given index, and + // return the smallest packet number now acked by this + // AckFrame. + private long dropRangesIfAfter(int ackIndex) { + assert ackIndex > 0 && ackIndex < ackRanges.size(); + long largest = largestAcknowledged; + long smallest = largest + 2; + ListIterator iterator = ackRanges.listIterator(); + int index = -1; + boolean removeRemainings = false; + long newLargestAckAcked = -1; + while (iterator.hasNext()) { + if (index == ackIndex) { + newLargestAckAcked = smallest; + removeRemainings = true; + } + AckRange ackRange = iterator.next(); + if (removeRemainings) { + iterator.remove(); + continue; + } + index++; + largest = smallest - ackRange.gap - 2; + smallest = largest - ackRange.range; + } + return newLargestAckAcked; + } + + + // drop acknowledgement of all packet numbers less or equal + // to `largestAckAcked; + private AckFrameBuilder dropIfSmallerThan(long largestAckAcked) { + if (largestAckAcked >= largestAcknowledged) { + largestAcknowledged = -1; + ackRanges.clear(); + return this; + } + long largest = largestAcknowledged; + long smallest = largest + 2; + ListIterator iterator = ackRanges.listIterator(); + int index = -1; + boolean removeRemainings = false; + while (iterator.hasNext()) { + AckRange ackRange = iterator.next(); + if (removeRemainings) { + iterator.remove(); + continue; + } + index++; + largest = smallest - ackRange.gap - 2; + smallest = largest - ackRange.range; + if (largest <= largestAckAcked) { + iterator.remove(); + removeRemainings = true; + } else if (smallest <= largestAckAcked) { + long removed = largestAckAcked - smallest + 1; + long gap = ackRange.gap; + long range = ackRange.range - removed; + assert gap >= 0; + assert range >= 0; + ackRanges.set(index, new AckRange(gap, range)); + removeRemainings = true; + } + } + return this; + } + + /** + * Builds an {@code AckFrame} from this builder's content. + * @return a new {@code AckFrame}. + */ + public AckFrame build() { + return new AckFrame(largestAcknowledged, ackDelay, ackRanges, + ect0Count, ect1Count, ecnCECount); + } + + private AckFrameBuilder acknowledgeFirstPacket(long packetNumber) { + assert ackRanges.isEmpty(); + largestAcknowledged = packetNumber; + ackRanges.add(AckRange.INITIAL); + return this; + } + + private AckFrameBuilder acknowledgeLargerPacket(long largerThanLargest) { + var packetNumber = largerThanLargest; + // the new packet is larger than the largest acknowledged + var firstAckRange = ackRanges.get(0); + if (largestAcknowledged == packetNumber -1) { + // if packetNumber is largestAcknowledged + 1, we can simply + // extend the first ack range by 1 + firstAckRange = AckRange.of(0, firstAckRange.range + 1); + ackRanges.set(0, firstAckRange); + } else { + // otherwise - we have a gap - we need to acknowledge the new packetNumber, + // and then add the gap that separate it from the previous largestAcknowledged... + ackRanges.add(0, AckRange.INITIAL); // acknowledge packetNumber only + long gap = packetNumber - largestAcknowledged -2; + var secondAckRange = AckRange.of(gap, firstAckRange.range); + ackRanges.set(1, secondAckRange); // add the gap + } + largestAcknowledged = packetNumber; + return this; + } + + public static AckFrameBuilder ofNullable(AckFrame frame) { + return frame == null ? new AckFrameBuilder() : new AckFrameBuilder(frame); + } + + } + + // This is described in RFC 9000, Section 19.3.1 ACK Ranges + // https://www.rfc-editor.org/rfc/rfc9000#name-ack-ranges + private static boolean isAcknowledging(long largestAcknowledged, + List ackRanges, + long packetNumber) { + if (packetNumber > largestAcknowledged) return false; + var largest = largestAcknowledged; + long smallest = largestAcknowledged + 2; + for (var ackRange : ackRanges) { + largest = smallest - ackRange.gap - 2; + if (packetNumber > largest) return false; + smallest = largest - ackRange.range; + if (packetNumber >= smallest) return true; + } + return false; + } + + private static boolean isRangeAcknowledged(long largestAcknowledged, + List ackRanges, + long first, + long last) { + assert last >= first; + if (last > largestAcknowledged) return false; + var largest = largestAcknowledged; + long smallest = largestAcknowledged + 2; + for (var ackRange : ackRanges) { + largest = smallest - ackRange.gap - 2; + if (last > largest) return false; + smallest = largest - ackRange.range; + if (first >= smallest) return true; + } + return false; + } + + // This is described in RFC 9000, Section 19.3.1 ACK Ranges + // https://www.rfc-editor.org/rfc/rfc9000#name-ack-ranges + private static long smallestAcknowledged(long largestAcknowledged, + List ackRanges) { + long largest = largestAcknowledged; + long smallest = largest + 2; + assert !ackRanges.isEmpty(); + for (AckRange ackRange : ackRanges) { + largest = smallest - ackRange.gap - 2; + smallest = largest - ackRange.range; + } + return smallest; + } + + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ConnectionCloseFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ConnectionCloseFrame.java new file mode 100644 index 00000000000..5e0f2abd8df --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ConnectionCloseFrame.java @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +import jdk.internal.net.quic.QuicTransportErrors; + +/** + * A CONNECTION_CLOSE Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class ConnectionCloseFrame extends QuicFrame { + + /** + * This variant indicates an error originating from the higher + * level protocol, for instance, HTTP/3. + */ + public static final int CONNECTION_CLOSE_VARIANT = 0x1d; + private final long errorCode; + private final long errorFrameType; + private final boolean variant; + private final byte[] reason; + private String cachedToString; + private String cachedReason; + + /** + * An immutable ConnectionCloseFrame of type 0x1c with no reason phrase + * and an error of type APPLICATION_ERROR. + * @apiNote + * From + * RFC 9000 - section 10.2.3: + *

+ * A CONNECTION_CLOSE of type 0x1d MUST be replaced by a CONNECTION_CLOSE + * of type 0x1c when sending the frame in Initial or Handshake packets. + * Otherwise, information about the application state might be revealed. + * Endpoints MUST clear the value of the Reason Phrase field and SHOULD + * use the APPLICATION_ERROR code when converting to a CONNECTION_CLOSE + * of type 0x1c. + *
+ */ + public static final ConnectionCloseFrame APPLICATION_ERROR = + new ConnectionCloseFrame(QuicTransportErrors.APPLICATION_ERROR.code(), 0,""); + + /** + * Incoming CONNECTION_CLOSE frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + ConnectionCloseFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(CONNECTION_CLOSE); + errorCode = decodeVLField(buffer, "errorCode"); + if (type == CONNECTION_CLOSE) { + variant = false; + errorFrameType = decodeVLField(buffer, "errorFrameType"); + } else { + assert type == CONNECTION_CLOSE_VARIANT; + errorFrameType = -1; + variant = true; + } + int reasonLength = decodeVLFieldAsInt(buffer, "reasonLength"); + validateRemainingLength(buffer, reasonLength, type); + reason = new byte[reasonLength]; + buffer.get(reason, 0, reasonLength); + } + + /** + * Outgoing CONNECTION_CLOSE frame (variant with errorFrameType - 0x1c). + * This indicates a {@linkplain jdk.internal.net.quic.QuicTransportErrors + * quic transport error}. + */ + public ConnectionCloseFrame(long errorCode, long errorFrameType, String reason) { + super(CONNECTION_CLOSE); + this.errorCode = requireVLRange(errorCode, "errorCode"); + this.errorFrameType = requireVLRange(errorFrameType, "errorFrameType"); + this.variant = false; + this.cachedReason = reason; + this.reason = getReasonBytes(reason); + } + + /** + * Outgoing CONNECTION_CLOSE frame (variant without errorFrameType). + * This indicates an error originating from the higher level protocol, + * for instance {@linkplain jdk.internal.net.http.http3.Http3Error HTTP/3}. + */ + public ConnectionCloseFrame(long errorCode, String reason) { + super(CONNECTION_CLOSE); + this.errorCode = requireVLRange(errorCode, "errorCode"); + this.errorFrameType = -1; + this.variant = true; + this.cachedReason = reason; + this.reason = getReasonBytes(reason); + } + + private static byte[] getReasonBytes(String reason) { + return reason != null ? + reason.getBytes(StandardCharsets.UTF_8) : + new byte[0]; + } + + /** + * {@return a ConnectionCloseFrame suitable for inclusion in + * an Initial or Handshake packet} + */ + public ConnectionCloseFrame clearApplicationState() { + return this.variant ? APPLICATION_ERROR : this; + } + + @Override + public long getTypeField() { + return variant ? CONNECTION_CLOSE_VARIANT : CONNECTION_CLOSE; + } + + @Override + public boolean isAckEliciting() { + return false; + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, getTypeField(), "type"); + encodeVLField(buffer, errorCode, "errorCode"); + if (!variant) { + encodeVLField(buffer, errorFrameType, "errorFrameType"); + } + encodeVLField(buffer, reason.length, "reasonLength"); + if (reason.length > 0) { + buffer.put(reason); + } + assert buffer.position() - pos == size(); + } + + @Override + public int size() { + return getVLFieldLengthFor(getTypeField()) + + getVLFieldLengthFor(errorCode) + + (variant ? 0 : getVLFieldLengthFor(errorFrameType)) + + getVLFieldLengthFor(reason.length) + + reason.length; + } + + public long errorCode() { + return errorCode; + } + + public long errorFrameType() { + return errorFrameType; + } + + public boolean variant() { + return variant; + } + + public boolean isQuicTransportCode() { + return !variant; + } + + public boolean isApplicationCode() { + return variant; + } + + public byte[] reason() { + return reason; + } + + public String reasonString() { + if (cachedReason != null) return cachedReason; + if (reason == null) return null; + if (reason.length == 0) return ""; + return cachedReason = new String(reason, StandardCharsets.UTF_8); + } + + @Override + public String toString() { + if (cachedToString == null) { + final StringBuilder sb = new StringBuilder("ConnectionCloseFrame[type=0x"); + final long type = getTypeField(); + sb.append(Long.toHexString(type)) + .append(", errorCode=0x").append(Long.toHexString(errorCode)); + // CRYPTO_ERROR codes ranging 0x0100-0x01ff + if (type == 0x1c) { + if (errorCode >= 0x0100 && errorCode <= 0x01ff) { + // this represents a CRYPTO_ERROR which as per RFC-9001, section 4.8: + // A TLS alert is converted into a QUIC connection error. The AlertDescription + // value is added to 0x0100 to produce a QUIC error code from the range reserved for + // CRYPTO_ERROR; ... The resulting value is sent in a QUIC CONNECTION_CLOSE + // frame of type 0x1c + + // find the tls alert code from the error code, by substracting 0x0100 from + // the error code + sb.append(", tlsAlertDescription=").append(errorCode - 0x0100); + } + sb.append(", errorFrameType=0x").append(Long.toHexString(errorFrameType)); + } + if (cachedReason == null) { + cachedReason = new String(reason, StandardCharsets.UTF_8); + } + sb.append(", reason=").append(cachedReason).append("]"); + + cachedToString = sb.toString(); + } + return cachedToString; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/CryptoFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/CryptoFrame.java new file mode 100644 index 00000000000..7b606527193 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/CryptoFrame.java @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.util.Objects; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * A CRYPTO Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class CryptoFrame extends QuicFrame { + + private final long offset; + private final int length; + private final ByteBuffer cryptoData; + + CryptoFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(CRYPTO); + offset = decodeVLField(buffer, "offset"); + length = decodeVLFieldAsInt(buffer, "length"); + if (offset + length > VariableLengthEncoder.MAX_ENCODED_INTEGER) { + throw new QuicTransportException("Maximum crypto offset exceeded", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + validateRemainingLength(buffer, length, type); + int pos = buffer.position(); + // The buffer is the datagram: we will make a copy if the datagram + // is larger than the crypto frame by 64 bytes. + cryptoData = Utils.sliceOrCopy(buffer, pos, length, 64); + buffer.position(pos + length); + } + + /** + * Creates CryptoFrame + */ + public CryptoFrame(long offset, int length, ByteBuffer cryptoData) { + this(offset, length, cryptoData, true); + } + + private CryptoFrame(long offset, int length, ByteBuffer cryptoData, boolean slice) + { + super(CRYPTO); + this.offset = requireVLRange(offset, "offset"); + if (length != cryptoData.remaining()) + throw new IllegalArgumentException("bad length: " + length); + this.length = length; + this.cryptoData = slice + ? cryptoData.slice(cryptoData.position(), length) + : cryptoData; + } + + /** + * Creates a new CryptoFrame which is a slice of this crypto frame. + * @param offset the new offset + * @param length the new length + * @return a slice of the current crypto frame + * @throws IndexOutOfBoundsException if the offset or length + * exceed the bounds of this crypto frame + */ + public CryptoFrame slice(long offset, int length) { + long offsetdiff = offset - offset(); + long oldlen = length(); + Objects.checkFromIndexSize(offsetdiff, length, oldlen); + int pos = cryptoData.position(); + // safe cast to int since offsetdiff < length + int newpos = Math.addExact(pos, (int)offsetdiff); + ByteBuffer slice = Utils.sliceOrCopy(cryptoData, newpos, length); + return new CryptoFrame(offset, length, slice, false); + } + + @Override + public void encode(ByteBuffer dest) { + if (size() > dest.remaining()) { + throw new BufferOverflowException(); + } + int pos = dest.position(); + encodeVLField(dest, CRYPTO, "type"); + encodeVLField(dest, offset, "offset"); + encodeVLField(dest, length, "length"); + assert cryptoData.remaining() == length; + putByteBuffer(dest, cryptoData); + assert dest.position() - pos == size(); + } + + @Override + public int size() { + return getVLFieldLengthFor(CRYPTO) + + getVLFieldLengthFor(offset) + + getVLFieldLengthFor(length) + + length; + } + + /** + * {@return the frame offset} + */ + public long offset() { + return offset; + } + + public int length() { + return length; + } + + /** + * {@return the frame payload} + */ + public ByteBuffer payload() { + return cryptoData.slice(); + } + + @Override + public String toString() { + return "CryptoFrame(" + + "offset=" + offset + + ", length=" + length + + ')'; + } + + public static int compareOffsets(CryptoFrame cf1, CryptoFrame cf2) { + return Long.compare(cf1.offset, cf2.offset); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/DataBlockedFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/DataBlockedFrame.java new file mode 100644 index 00000000000..b008e643cda --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/DataBlockedFrame.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A DATA_BLOCKED Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class DataBlockedFrame extends QuicFrame { + + private final long maxData; + + /** + * Incoming DATA_BLOCKED frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + DataBlockedFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(DATA_BLOCKED); + maxData = decodeVLField(buffer, "maxData"); + } + + /** + * Outgoing DATA_BLOCKED frame + */ + public DataBlockedFrame(long maxData) { + super(DATA_BLOCKED); + this.maxData = requireVLRange(maxData, "maxData"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, DATA_BLOCKED, "type"); + encodeVLField(buffer, maxData, "maxData"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long maxData() { + return maxData; + } + + @Override + public int size() { + return getVLFieldLengthFor(DATA_BLOCKED) + + getVLFieldLengthFor(maxData); + } + + @Override + public String toString() { + return "DataBlockedFrame(" + + "maxData=" + maxData + + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/HandshakeDoneFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/HandshakeDoneFrame.java new file mode 100644 index 00000000000..ffe6aff2f0d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/HandshakeDoneFrame.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A HANDSHAKE_DONE Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class HandshakeDoneFrame extends QuicFrame { + + /** + * Incoming HANDSHAKE_DONE frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + */ + HandshakeDoneFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(HANDSHAKE_DONE); + } + + /** + * Outgoing HANDSHAKE_DONE frame + */ + public HandshakeDoneFrame() { + super(HANDSHAKE_DONE); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, HANDSHAKE_DONE, "type"); + assert buffer.position() - pos == size(); + } + + @Override + public int size() { + return getVLFieldLengthFor(HANDSHAKE_DONE); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxDataFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxDataFrame.java new file mode 100644 index 00000000000..26720c34494 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxDataFrame.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A RESET_STREAM Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class MaxDataFrame extends QuicFrame { + + private final long maxData; + + /** + * Incoming MAX_DATA frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + MaxDataFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(MAX_DATA); + maxData = decodeVLField(buffer, "maxData"); + } + + /** + * Outgoing MAX_DATA frame + */ + public MaxDataFrame(long maxData) { + super(MAX_DATA); + this.maxData = requireVLRange(maxData, "maxData"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, MAX_DATA, "type"); + encodeVLField(buffer, maxData, "maxData"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long maxData() { + return maxData; + } + + @Override + public int size() { + return getVLFieldLengthFor(MAX_DATA) + + getVLFieldLengthFor(maxData); + } + + @Override + public String toString() { + return "MaxDataFrame(" + + "maxData=" + maxData + + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamDataFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamDataFrame.java new file mode 100644 index 00000000000..3fff70a377c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamDataFrame.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A MAX_STREAM_DATA Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class MaxStreamDataFrame extends QuicFrame { + + private final long streamID; + private final long maxStreamData; + + /** + * Incoming MAX_STREAM_DATA frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + MaxStreamDataFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(MAX_STREAM_DATA); + streamID = decodeVLField(buffer, "streamID"); + maxStreamData = decodeVLField(buffer, "maxData"); + } + + /** + * Outgoing MAX_STREAM_DATA frame + */ + public MaxStreamDataFrame(long streamID, long maxStreamData) { + super(MAX_STREAM_DATA); + this.streamID = requireVLRange(streamID, "streamID"); + this.maxStreamData = requireVLRange(maxStreamData, "maxStreamData"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, MAX_STREAM_DATA, "type"); + encodeVLField(buffer, streamID, "streamID"); + encodeVLField(buffer, maxStreamData, "maxStreamData"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long maxStreamData() { + return maxStreamData; + } + + public long streamID() { + return streamID; + } + + @Override + public int size() { + return getVLFieldLengthFor(MAX_STREAM_DATA) + + getVLFieldLengthFor(streamID) + + getVLFieldLengthFor(maxStreamData); + } + + @Override + public String toString() { + return "MaxStreamDataFrame(" + + "streamId=" + streamID + + ", maxStreamData=" + maxStreamData + + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamsFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamsFrame.java new file mode 100644 index 00000000000..e35b16195a6 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/MaxStreamsFrame.java @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A MAX_STREAM Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class MaxStreamsFrame extends QuicFrame { + static final long MAX_VALUE = 1L << 60; + + private final long maxStreams; + private final boolean bidi; + + /** + * Incoming MAX_STREAMS frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + MaxStreamsFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(MAX_STREAMS); + bidi = (type == MAX_STREAMS); + maxStreams = decodeVLField(buffer, "maxStreams"); + if (maxStreams > MAX_VALUE) { + throw new QuicTransportException("Invalid maximum streams", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + } + + /** + * Outgoing MAX_STREAMS frame + */ + public MaxStreamsFrame(boolean bidi, long maxStreams) { + super(MAX_STREAMS); + this.bidi = bidi; + this.maxStreams = requireVLRange(maxStreams, "maxStreams"); + } + + @Override + public long getTypeField() { + return MAX_STREAMS + (bidi?0:1); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, getTypeField(), "type"); + encodeVLField(buffer, maxStreams, "maxStreams"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long maxStreams() { + return maxStreams; + } + + public boolean isBidi() { + return bidi; + } + + @Override + public int size() { + return getVLFieldLengthFor(MAX_STREAMS) + + getVLFieldLengthFor(maxStreams); + } + + @Override + public String toString() { + return "MaxStreamsFrame(bidi=" + bidi + + ", maxStreams=" + maxStreams + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewConnectionIDFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewConnectionIDFrame.java new file mode 100644 index 00000000000..87e91b21e75 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewConnectionIDFrame.java @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A NEW_CONNECTION_ID Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class NewConnectionIDFrame extends QuicFrame { + + private final long sequenceNumber; + private final long retirePriorTo; + private final ByteBuffer connectionId; + private final ByteBuffer statelessResetToken; + + /** + * Incoming NEW_CONNECTION_ID frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + NewConnectionIDFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(NEW_CONNECTION_ID); + sequenceNumber = decodeVLField(buffer, "sequenceNumber"); + retirePriorTo = decodeVLField(buffer, "retirePriorTo"); + if (retirePriorTo > sequenceNumber) { + throw new QuicTransportException("Invalid retirePriorTo", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + validateRemainingLength(buffer, 17, type); + int length = Byte.toUnsignedInt(buffer.get()); + if (length < 1 || length > 20) { + throw new QuicTransportException("Invalid connection ID", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + validateRemainingLength(buffer, length + 16, type); + int position = buffer.position(); + connectionId = buffer.slice(position, length); + position += length; + statelessResetToken = buffer.slice(position, 16); + position += 16; + buffer.position(position); + } + + /** + * Outgoing NEW_CONNECTION_ID frame + */ + public NewConnectionIDFrame(long sequenceNumber, long retirePriorTo, ByteBuffer connectionId, ByteBuffer statelessResetToken) { + super(NEW_CONNECTION_ID); + this.sequenceNumber = requireVLRange(sequenceNumber, "sequenceNumber"); + this.retirePriorTo = requireVLRange(retirePriorTo, "retirePriorTo"); + int length = connectionId.remaining(); + if (length < 1 || length > 20) + throw new IllegalArgumentException("invalid length"); + this.connectionId = connectionId.slice(); + if (statelessResetToken.remaining() != 16) + throw new IllegalArgumentException("stateless reset token must be 16 bytes"); + this.statelessResetToken = statelessResetToken.slice(); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, NEW_CONNECTION_ID, "type"); + encodeVLField(buffer, sequenceNumber, "sequenceNumber"); + encodeVLField(buffer, retirePriorTo, "retirePriorTo"); + int length = connectionId.remaining(); + buffer.put((byte)length); + putByteBuffer(buffer, connectionId); + putByteBuffer(buffer, statelessResetToken); + assert buffer.position() - pos == size(); + } + + @Override + public int size() { + return getVLFieldLengthFor(NEW_CONNECTION_ID) + + getVLFieldLengthFor(sequenceNumber) + + getVLFieldLengthFor(retirePriorTo) + + 1 // connection length + + connectionId.remaining() + + statelessResetToken.remaining(); + } + + public long sequenceNumber() { + return sequenceNumber; + } + + public long retirePriorTo() { + return retirePriorTo; + } + + public ByteBuffer connectionId() { + return connectionId; + } + + public ByteBuffer statelessResetToken() { + return statelessResetToken; + } + + @Override + public String toString() { + return "NewConnectionIDFrame(seqNumber=" + sequenceNumber + + ", retirePriorTo=" + retirePriorTo + + ", connIdLength=" + connectionId.remaining() + + ")"; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewTokenFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewTokenFrame.java new file mode 100644 index 00000000000..08bc72ff1c2 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/NewTokenFrame.java @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.util.Objects; + +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +/** + * A NEW_TOKEN frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class NewTokenFrame extends QuicFrame { + private final byte[] token; + + /** + * Incoming NEW_TOKEN frame + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + NewTokenFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(NEW_TOKEN); + int length = decodeVLFieldAsInt(buffer, "token length"); + if (length == 0) { + throw new QuicTransportException("Empty token", + null, type, + QuicTransportErrors.FRAME_ENCODING_ERROR); + } + validateRemainingLength(buffer, length, type); + final byte[] t = new byte[length]; + buffer.get(t); + this.token = t; + } + + /** + * Outgoing NEW_TOKEN frame whose token is the given ByteBuffer + * (position to limit) + */ + public NewTokenFrame(final ByteBuffer tokenBuf) { + super(NEW_TOKEN); + Objects.requireNonNull(tokenBuf); + final int length = tokenBuf.remaining(); + if (length <= 0) { + throw new IllegalArgumentException("Invalid token length"); + } + final byte[] t = new byte[length]; + tokenBuf.get(t); + this.token = t; + } + + @Override + public void encode(final ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, NEW_TOKEN, "type"); + encodeVLField(buffer, token.length, "token length"); + buffer.put(token); + assert buffer.position() - pos == size(); + } + + public byte[] token() { + return this.token; + } + + @Override + public int size() { + final int tokenLength = token.length; + return getVLFieldLengthFor(NEW_TOKEN) + + getVLFieldLengthFor(tokenLength) + + tokenLength; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PaddingFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PaddingFrame.java new file mode 100644 index 00000000000..d0258aee84e --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PaddingFrame.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * PaddingFrames. Since padding frames comprise a single zero byte + * this class actually represents sequences of PaddingFrames. + * When decoding, the class consumes all the zero bytes that are + * available and when encoding, the number of required padding bytes + * is specified. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class PaddingFrame extends QuicFrame { + + private final int size; + + /** + * Incoming + */ + PaddingFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(PADDING); + int count = 1; + while (buffer.hasRemaining()) { + if (buffer.get() == 0) { + count++; + } else { + int pos = buffer.position(); + buffer.position(pos - 1); + break; + } + } + size = count; + } + + /** + * Outgoing + * @param size the number of padding frames that should be written + * to the buffer. Each frame is one byte long. + */ + public PaddingFrame(int size) { + super(PADDING); + if (size <= 0) { + throw new IllegalArgumentException("Size must be greater than zero"); + } + this.size = size; + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + for (int i=0; i buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, PATH_CHALLENGE, "type"); + putByteBuffer(buffer, data); + assert buffer.position() - pos == size(); + } + + /** + */ + public ByteBuffer data() { + return data; + } + + @Override + public int size() { + return getVLFieldLengthFor(PATH_CHALLENGE) + LENGTH; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PathResponseFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PathResponseFrame.java new file mode 100644 index 00000000000..2b1edcfe731 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PathResponseFrame.java @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A PATH_RESPONSE Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class PathResponseFrame extends QuicFrame { + + public static final int LENGTH = 8; + private final ByteBuffer data; + + /** + * Incoming PATH_RESPONSE frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + PathResponseFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(PATH_RESPONSE); + validateRemainingLength(buffer, LENGTH, type); + int position = buffer.position(); + data = buffer.slice(position, LENGTH); + buffer.position(position + LENGTH); + } + + /** + * Outgoing PATH_RESPONSE frame + */ + public PathResponseFrame(ByteBuffer data) { + super(PATH_RESPONSE); + if (data.remaining() != LENGTH) + throw new IllegalArgumentException("response data must be 8 bytes"); + this.data = data.slice(); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, PATH_RESPONSE, "type"); + putByteBuffer(buffer, data); + assert buffer.position() - pos == size(); + } + + /** + */ + public ByteBuffer data() { + return data; + } + + @Override + public int size() { + return getVLFieldLengthFor(PATH_RESPONSE) + LENGTH; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PingFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PingFrame.java new file mode 100644 index 00000000000..9717a3c31d1 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/PingFrame.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A PING frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class PingFrame extends QuicFrame { + + PingFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(PING); + } + + /** + */ + public PingFrame() { + super(PING); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, PING, "type"); + assert buffer.position() - pos == size(); + } + + @Override + public int size() { + return getVLFieldLengthFor(PING); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/QuicFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/QuicFrame.java new file mode 100644 index 00000000000..f4a230452d4 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/QuicFrame.java @@ -0,0 +1,387 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import java.nio.ByteBuffer; + +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * A QUIC Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public abstract sealed class QuicFrame permits + AckFrame, + DataBlockedFrame, + ConnectionCloseFrame, CryptoFrame, + HandshakeDoneFrame, + MaxDataFrame, MaxStreamDataFrame, MaxStreamsFrame, + NewConnectionIDFrame, NewTokenFrame, + PaddingFrame, PathChallengeFrame, PathResponseFrame, PingFrame, + ResetStreamFrame, RetireConnectionIDFrame, + StreamsBlockedFrame, StreamDataBlockedFrame, StreamFrame, StopSendingFrame { + + public static final long MAX_VL_INTEGER = (1L << 62) - 1; + /** + * Frame types + */ + public static final int PADDING=0x00; + public static final int PING=0x01; + public static final int ACK=0x02; + public static final int RESET_STREAM=0x04; + public static final int STOP_SENDING=0x05; + public static final int CRYPTO=0x06; + public static final int NEW_TOKEN=0x07; + public static final int STREAM=0x08; + public static final int MAX_DATA=0x10; + public static final int MAX_STREAM_DATA=0x11; + public static final int MAX_STREAMS=0x12; + public static final int DATA_BLOCKED=0x14; + public static final int STREAM_DATA_BLOCKED=0x15; + public static final int STREAMS_BLOCKED=0x16; + public static final int NEW_CONNECTION_ID=0x18; + public static final int RETIRE_CONNECTION_ID=0x19; + public static final int PATH_CHALLENGE=0x1a; + public static final int PATH_RESPONSE=0x1b; + public static final int CONNECTION_CLOSE=0x1c; + public static final int HANDSHAKE_DONE=0x1e; + private static final int MAX_KNOWN_FRAME_TYPE = HANDSHAKE_DONE; + private final int frameType; + + /** + * Concrete Frame types normally have two constructors which call this + * + * XXXFrame(ByteBuffer, int firstByte) which is called for incoming frames + * after reading the first byte to determine the type. The firstByte is also + * supplied to the constructor because it can contain additional state information + * + * XXXFrame(...) which is called to instantiate outgoing frames + * @param type the first byte of the frame, which encodes the frame type. + */ + QuicFrame(int type) { + frameType = type; + } + + @Override + public String toString() { + return this.getClass().getSimpleName(); + } + + /** + * decode given ByteBuffer and return a QUICFrame + */ + public static QuicFrame decode(ByteBuffer buffer) throws QuicTransportException { + long frameTypeLong = VariableLengthEncoder.decode(buffer); + if (frameTypeLong < 0) { + throw new QuicTransportException("Error decoding frame type", + null, 0, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + if (frameTypeLong > Integer.MAX_VALUE) { + throw new QuicTransportException("Unrecognized frame", + null, frameTypeLong, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + int frameType = (int)frameTypeLong; + var frame = switch (maskType(frameType)) { + case ACK -> new AckFrame(buffer, frameType); + case STREAM -> new StreamFrame(buffer, frameType); + case RESET_STREAM -> new ResetStreamFrame(buffer, frameType); + case PADDING -> new PaddingFrame(buffer, frameType); + case PING -> new PingFrame(buffer, frameType); + case STOP_SENDING -> new StopSendingFrame(buffer, frameType); + case CRYPTO -> new CryptoFrame(buffer, frameType); + case NEW_TOKEN -> new NewTokenFrame(buffer, frameType); + case DATA_BLOCKED -> new DataBlockedFrame(buffer, frameType); + case MAX_DATA -> new MaxDataFrame(buffer, frameType); + case MAX_STREAMS -> new MaxStreamsFrame(buffer, frameType); + case MAX_STREAM_DATA -> new MaxStreamDataFrame(buffer, frameType); + case STREAM_DATA_BLOCKED -> new StreamDataBlockedFrame(buffer, frameType); + case STREAMS_BLOCKED -> new StreamsBlockedFrame(buffer, frameType); + case NEW_CONNECTION_ID -> new NewConnectionIDFrame(buffer, frameType); + case RETIRE_CONNECTION_ID -> new RetireConnectionIDFrame(buffer, frameType); + case PATH_CHALLENGE -> new PathChallengeFrame(buffer, frameType); + case PATH_RESPONSE -> new PathResponseFrame(buffer, frameType); + case CONNECTION_CLOSE -> new ConnectionCloseFrame(buffer, frameType); + case HANDSHAKE_DONE -> new HandshakeDoneFrame(buffer, frameType); + default -> throw new QuicTransportException("Unrecognized frame", + null, frameType, QuicTransportErrors.FRAME_ENCODING_ERROR); + }; + assert frameClassOf(maskType(frameType)) == frame.getClass(); + assert frameTypeOf(frame.getClass()) == maskType(frameType); + assert frame.getTypeField() == frameType : "frame type mismatch: " + + frameType + "!=" + frame.getTypeField() + + " for frame: " + frame; + return frame; + } + + public static Class frameClassOf(int frameType) { + return switch (maskType(frameType)) { + case ACK -> AckFrame.class; + case STREAM -> StreamFrame.class; + case RESET_STREAM -> ResetStreamFrame.class; + case PADDING -> PaddingFrame.class; + case PING -> PingFrame.class; + case STOP_SENDING -> StopSendingFrame.class; + case CRYPTO -> CryptoFrame.class; + case NEW_TOKEN -> NewTokenFrame.class; + case DATA_BLOCKED -> DataBlockedFrame.class; + case MAX_DATA -> MaxDataFrame.class; + case MAX_STREAMS -> MaxStreamsFrame.class; + case MAX_STREAM_DATA -> MaxStreamDataFrame.class; + case STREAM_DATA_BLOCKED -> StreamDataBlockedFrame.class; + case STREAMS_BLOCKED -> StreamsBlockedFrame.class; + case NEW_CONNECTION_ID -> NewConnectionIDFrame.class; + case RETIRE_CONNECTION_ID -> RetireConnectionIDFrame.class; + case PATH_CHALLENGE -> PathChallengeFrame.class; + case PATH_RESPONSE -> PathResponseFrame.class; + case CONNECTION_CLOSE -> ConnectionCloseFrame.class; + case HANDSHAKE_DONE -> HandshakeDoneFrame.class; + default -> throw new IllegalArgumentException("Unrecognised frame"); + }; + } + + public static int frameTypeOf(Class frameClass) { + // we don't have class pattern matching yet - so switch + // on the class name instead + return switch (frameClass.getSimpleName()) { + case "AckFrame" -> ACK; + case "StreamFrame" -> STREAM; + case "ResetStreamFrame" -> RESET_STREAM; + case "PaddingFrame" -> PADDING; + case "PingFrame" -> PING; + case "StopSendingFrame" -> STOP_SENDING; + case "CryptoFrame" -> CRYPTO; + case "NewTokenFrame" -> NEW_TOKEN; + case "DataBlockedFrame" -> DATA_BLOCKED; + case "MaxDataFrame" -> MAX_DATA; + case "MaxStreamsFrame" -> MAX_STREAMS; + case "MaxStreamDataFrame" -> MAX_STREAM_DATA; + case "StreamDataBlockedFrame" -> STREAM_DATA_BLOCKED; + case "StreamsBlockedFrame" -> STREAMS_BLOCKED; + case "NewConnectionIDFrame" -> NEW_CONNECTION_ID; + case "RetireConnectionIDFrame" -> RETIRE_CONNECTION_ID; + case "PathChallengeFrame" -> PATH_CHALLENGE; + case "PathResponseFrame" -> PATH_RESPONSE; + case "ConnectionCloseFrame" -> CONNECTION_CLOSE; + case "HandshakeDoneFrame" -> HANDSHAKE_DONE; + default -> throw new IllegalArgumentException("Unrecognised frame"); + }; + } + + /** + * Writes src to dest, preserving position in src + */ + protected static void putByteBuffer(ByteBuffer dest, ByteBuffer src) { + dest.put(src.asReadOnlyBuffer()); + } + + /** + * Throws a QuicTransportException if the given buffer does not have enough bytes + * to finish decoding the frame + * + * @param buffer source buffer + * @param expected minimum number of bytes required + * @param type frame type to include in exception + * @throws QuicTransportException if the buffer is shorter than {@code expected} + */ + protected static void validateRemainingLength(ByteBuffer buffer, int expected, long type) + throws QuicTransportException + { + if (buffer.remaining() < expected) { + throw new QuicTransportException("Error decoding frame", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + } + + /** + * depending on the frame type, additional bits can be encoded + * in frameType(). This masks them out to return a unique value + * for each frame type. + */ + private static int maskType(int type) { + if (type >= ACK && type < RESET_STREAM) + return ACK; + if (type >= STREAM && type < MAX_DATA) + return STREAM; + if (type >= MAX_STREAMS && type < DATA_BLOCKED) + return MAX_STREAMS; + if (type >= STREAMS_BLOCKED && type < NEW_CONNECTION_ID) + return STREAMS_BLOCKED; + if (type >= CONNECTION_CLOSE && type < HANDSHAKE_DONE) + return CONNECTION_CLOSE; + // all others are unique + return type; + } + + /** + * {@return true if this frame is ACK-eliciting} + * A frame is ACK-eliciting if it is anything + * other than {@link QuicFrame#ACK}, + * {@link QuicFrame#PADDING} or + * {@link QuicFrame#CONNECTION_CLOSE} + * (or its variant). + */ + public boolean isAckEliciting() { return true; } + + /** + * {@return the minimum number of bytes needed to encode this frame} + */ + public abstract int size(); + + protected final long decodeVLField(ByteBuffer buffer, String name) throws QuicTransportException { + long v = VariableLengthEncoder.decode(buffer); + if (v < 0) { + throw new QuicTransportException("Error decoding field: " + name, + null, getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + return v; + } + + protected final int decodeVLFieldAsInt(ByteBuffer buffer, String name) throws QuicTransportException { + long l = decodeVLField(buffer, name); + int intval = (int)l; + if (((long)intval) != l) { + throw new QuicTransportException(name + ":field too long", + null, getTypeField(), QuicTransportErrors.FRAME_ENCODING_ERROR); + } + return intval; + } + + protected static int requireVLRange(int val, String message) { + if (val < 0) { + throw new IllegalArgumentException(message + " " + val + " not in range"); + } + return val; + } + + protected static long requireVLRange(long val, String fieldName) { + if (val < 0 || val > MAX_VL_INTEGER) { + throw new IllegalArgumentException( + String.format("%s not in VL range: %s", fieldName, val)); + } + return val; + } + + protected static void encodeVLField(ByteBuffer buffer, long val, String name) { + try { + VariableLengthEncoder.encode(buffer, val); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Error encoding " + name, e); + } + } + + protected static int getVLFieldLengthFor(long val) { + return VariableLengthEncoder.getEncodedSize(val); + } + + /** + * The type of this frame, ie. one of values above, which means it + * excludes the additional information that is encoded into the first field + * of some QUIC frames. That additional info has to be maintained by the sub + * class and used by its encode() method to generate the first field for outgoing frames. + * + * see maskType() below + */ + public int frameType() { + return frameType; + } + + /** + * Encode this QUIC Frame into given ByteBuffer + */ + public abstract void encode(ByteBuffer buffer); + + /** + * {@return the type field that was / should be encoded} + * This is the {@linkplain #frameType() frame type} with + * possibly some additional bits set, depending on the + * frame. + * @implSpec + * The default implementation of this method is to return + * {@link #frameType()}. + */ + public long getTypeField() { return frameType(); } + + /** + * Tells whether this particular frame is valid in the given + * packet type. + * + *

From + * RFC 9000, section 12.5. Frames and Number Spaces: + *

+ * Some frames are prohibited in different packet number space + * The rules here generalize those of TLS, in that frames associated + * with establishing the connection can usually appear in packets + * in any packet number space, whereas those associated with transferring + * data can only appear in the application data packet number space: + * + *
    + *
  • PADDING, PING, and CRYPTO frames MAY appear in any packet number + * space.
  • + *
  • CONNECTION_CLOSE frames signaling errors at the QUIC layer (type 0x1c) + * MAY appear in any packet number space.
  • + *
  • CONNECTION_CLOSE frames signaling application errors (type 0x1d) + * MUST only appear in the application data packet number space. + *
  • ACK frames MAY appear in any packet number space but can only + * acknowledge packets that appeared in that packet number space. + * However, as noted below, 0-RTT packets cannot contain ACK frames.
  • + *
  • All other frame types MUST only be sent in the application data + * packet number space.
  • + *
+ * + * Note that it is not possible to send the following frames in 0-RTT + * packets for various reasons: ACK, CRYPTO, HANDSHAKE_DONE, NEW_TOKEN, + * PATH_RESPONSE, and RETIRE_CONNECTION_ID. A server MAY treat receipt + * of these frames in 0-RTT packets as a connection error of + * type PROTOCOL_VIOLATION. + *
+ * + * @param packetType the packet type + * @return true if the frame can be embedded in a packet of that type + */ + public boolean isValidIn(QuicPacket.PacketType packetType) { + return switch (frameType) { + case PADDING, PING -> true; + case ACK, CRYPTO -> switch (packetType) { + case VERSIONS, ZERORTT -> false; + default -> true; + }; + case CONNECTION_CLOSE -> { + if ((getTypeField() & 0x1D) == 0x1C) yield true; + yield QuicPacket.PacketNumberSpace.of(packetType) == QuicPacket.PacketNumberSpace.APPLICATION; + } + case HANDSHAKE_DONE, NEW_TOKEN, PATH_RESPONSE, + RETIRE_CONNECTION_ID -> switch (packetType) { + case ZERORTT -> false; + default -> QuicPacket.PacketNumberSpace.of(packetType) == QuicPacket.PacketNumberSpace.APPLICATION; + }; + default -> QuicPacket.PacketNumberSpace.of(packetType) == QuicPacket.PacketNumberSpace.APPLICATION; + }; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ResetStreamFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ResetStreamFrame.java new file mode 100644 index 00000000000..7a3579fc831 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/ResetStreamFrame.java @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A RESET_STREAM Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class ResetStreamFrame extends QuicFrame { + + private final long streamID; + private final long errorCode; + private final long finalSize; + + ResetStreamFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(RESET_STREAM); + streamID = decodeVLField(buffer, "streamID"); + errorCode = decodeVLField(buffer, "errorCode"); + finalSize = decodeVLField(buffer, "finalSize"); + } + + /** + */ + public ResetStreamFrame( + long streamID, + long errorCode, + long finalSize) + { + super(RESET_STREAM); + this.streamID = requireVLRange(streamID, "streamID"); + this.errorCode = requireVLRange(errorCode, "errorCode"); + this.finalSize = requireVLRange(finalSize, "finalSize"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, RESET_STREAM, "type"); + encodeVLField(buffer, streamID, "streamID"); + encodeVLField(buffer, errorCode, "errorCode"); + encodeVLField(buffer, finalSize, "finalSize"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long streamId() { + return streamID; + } + + /** + */ + public long errorCode() { + return errorCode; + } + + /** + */ + public long finalSize() { + return finalSize; + } + + @Override + public int size() { + return getVLFieldLengthFor(RESET_STREAM) + + getVLFieldLengthFor(streamID) + + getVLFieldLengthFor(errorCode) + + getVLFieldLengthFor(finalSize); + } + + @Override + public String toString() { + return "ResetStreamFrame(stream=" + streamID + + ", errorCode=" + errorCode + + ", finalSize=" + finalSize + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/RetireConnectionIDFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/RetireConnectionIDFrame.java new file mode 100644 index 00000000000..bf448f6a301 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/RetireConnectionIDFrame.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A RETIRE_CONNECTION_ID Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class RetireConnectionIDFrame extends QuicFrame { + + private final long sequenceNumber; + + /** + * Incoming RETIRE_CONNECTION_ID frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + RetireConnectionIDFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(RETIRE_CONNECTION_ID); + sequenceNumber = decodeVLField(buffer, "sequenceNumber"); + } + + /** + * Outgoing RETIRE_CONNECTION_ID frame + */ + public RetireConnectionIDFrame(long sequenceNumber) { + super(RETIRE_CONNECTION_ID); + this.sequenceNumber = requireVLRange(sequenceNumber, "sequenceNumber"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, RETIRE_CONNECTION_ID, "type"); + encodeVLField(buffer, sequenceNumber, "sequenceNumber"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long sequenceNumber() { + return sequenceNumber; + } + + @Override + public int size() { + return getVLFieldLengthFor(RETIRE_CONNECTION_ID) + + getVLFieldLengthFor(sequenceNumber); + } + + @Override + public String toString() { + return "RetireConnectionIDFrame(" + + "sequenceNumber=" + sequenceNumber + + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StopSendingFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StopSendingFrame.java new file mode 100644 index 00000000000..4a7d6525685 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StopSendingFrame.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A STOP_SENDING Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class StopSendingFrame extends QuicFrame { + + private final long streamID; + private final long errorCode; + + StopSendingFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(STOP_SENDING); + streamID = decodeVLField(buffer, "streamID"); + errorCode = decodeVLField(buffer, "errorCode"); + } + + /** + */ + public StopSendingFrame(long streamID, long errorCode) { + super(STOP_SENDING); + this.streamID = requireVLRange(streamID, "streamID"); + this.errorCode = requireVLRange(errorCode, "errorCode"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, STOP_SENDING, "type"); + encodeVLField(buffer, streamID, "streamID"); + encodeVLField(buffer, errorCode, "errorCode"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long streamID() { + return streamID; + } + + /** + */ + public long errorCode() { + return errorCode; + } + + @Override + public int size() { + return getVLFieldLengthFor(STOP_SENDING) + + getVLFieldLengthFor(streamID) + + getVLFieldLengthFor(errorCode); + } + + @Override + public String toString() { + return "StopSendingFrame(stream=" + streamID + + ", errorCode=" + errorCode + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamDataBlockedFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamDataBlockedFrame.java new file mode 100644 index 00000000000..7dd95e2278b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamDataBlockedFrame.java @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A STREAM_DATA_BLOCKED Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class StreamDataBlockedFrame extends QuicFrame { + + private final long streamId; + private final long maxStreamData; + + /** + * Incoming STREAM_DATA_BLOCKED frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + StreamDataBlockedFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(STREAM_DATA_BLOCKED); + assert type == STREAM_DATA_BLOCKED : "STREAM_DATA_BLOCKED, unexpected frame type 0x" + + Integer.toHexString(type); + streamId = decodeVLField(buffer, "streamID"); + maxStreamData = decodeVLField(buffer, "maxData"); + } + + /** + * Outgoing STREAM_DATA_BLOCKED frame + */ + public StreamDataBlockedFrame(long streamId, long maxStreamData) { + super(STREAM_DATA_BLOCKED); + this.streamId = requireVLRange(streamId, "streamID"); + this.maxStreamData = requireVLRange(maxStreamData, "maxStreamData"); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, STREAM_DATA_BLOCKED, "type"); + encodeVLField(buffer, streamId, "streamID"); + encodeVLField(buffer, maxStreamData, "maxStreamData"); + assert buffer.position() - pos == size(); + } + + /** + */ + public long maxStreamData() { + return maxStreamData; + } + + public long streamId() { + return streamId; + } + + @Override + public int size() { + return getVLFieldLengthFor(STREAM_DATA_BLOCKED) + + getVLFieldLengthFor(streamId) + + getVLFieldLengthFor(maxStreamData); + } + + @Override + public String toString() { + return "StreamDataBlockedFrame(" + + "streamId=" + streamId + + ", maxStreamData=" + maxStreamData + + ')'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof StreamDataBlockedFrame that)) return false; + if (streamId != that.streamId) return false; + return maxStreamData == that.maxStreamData; + } + + @Override + public int hashCode() { + int result = (int) (streamId ^ (streamId >>> 32)); + result = 31 * result + (int) (maxStreamData ^ (maxStreamData >>> 32)); + return result; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamFrame.java new file mode 100644 index 00000000000..b597e53c06c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamFrame.java @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.util.Objects; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder; +import jdk.internal.net.quic.QuicTransportException; + +/** + * A STREAM Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class StreamFrame extends QuicFrame { + + // Flags in frameType() + private static final int OFF = 0x4; + private static final int LEN = 0x2; + private static final int FIN = 0x1; + + private final long streamID; + // true if the OFF bit in the type field has been set + private final boolean typeFieldHasOFF; + private final long offset; + private final int length; // -1 means consume all data in packet + private final int dataLength; + private final ByteBuffer streamData; + private final boolean fin; + + StreamFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(STREAM); + streamID = decodeVLField(buffer, "streamID"); + if ((type & OFF) > 0) { + typeFieldHasOFF = true; + offset = decodeVLField(buffer, "offset"); + } else { + typeFieldHasOFF = false; + offset = 0; + } + if ((type & LEN) > 0) { + length = decodeVLFieldAsInt(buffer, "length"); + } else { + length = -1; + } + if (length == -1) { + int remaining = buffer.remaining(); + streamData = Utils.sliceOrCopy(buffer, buffer.position(), remaining); + buffer.position(buffer.limit()); + dataLength = remaining; + } else { + validateRemainingLength(buffer, length, type); + int pos = buffer.position(); + streamData = Utils.sliceOrCopy(buffer, pos, length); + buffer.position(pos + length); + dataLength = length; + } + fin = (type & FIN) == 1; + } + + /** + * Creates StreamFrame (length == -1 means no length specified in frame + * and is assumed to occupy the remainder of the Quic/UDP packet. + * If a length is specified then it must correspond with the remaining bytes + * in streamData + */ + // It would be interesting to have a version of this constructor that can take + // a list of ByteBuffer. + public StreamFrame(long streamID, long offset, int length, boolean fin, ByteBuffer streamData) { + this(streamID, offset, length, fin, streamData, true); + } + + private StreamFrame(long streamID, long offset, int length, boolean fin, ByteBuffer streamData, boolean slice) + { + super(STREAM); + this.streamID = requireVLRange(streamID, "streamID"); + this.offset = requireVLRange(offset, "offset"); + // if offset is non-zero then we mark that the type field has OFF bit set + // to allow for that bit to be set when encoding this frame + this.typeFieldHasOFF = this.offset != 0; + if (length != -1 && length != streamData.remaining()) { + throw new IllegalArgumentException("bad length"); + } + this.length = length; + this.dataLength = streamData.remaining(); + this.fin = fin; + this.streamData = slice + ? streamData.slice(streamData.position(), dataLength) + : streamData; + } + + /** + * Creates a new StreamFrame which is a slice of this stream frame. + * @param offset the new offset + * @param length the new length + * @return a slice of the current stream frame + * @throws IndexOutOfBoundsException if the offset or length + * exceed the bounds of this stream frame + */ + public StreamFrame slice(long offset, int length) { + long oldoffset = offset(); + long offsetdiff = offset - oldoffset; + long oldlen = dataLength(); + Objects.checkFromIndexSize(offsetdiff, length, oldlen); + int pos = streamData.position(); + // safe cast to int since offsetdiff < length + int newpos = Math.addExact(pos, (int)offsetdiff); + // preserves the FIN bit if set + boolean fin = this.fin && offset + length == oldoffset + oldlen; + ByteBuffer slice = Utils.sliceOrCopy(streamData, newpos, length); + return new StreamFrame(streamID, offset, length, fin, slice, false); + } + + /** + * {@return the stream id} + */ + public long streamId() { + return streamID; + } + + /** + * {@return whether this frame has a length} + * A frame that doesn't have a length must be the last + * frame in the packet. + */ + public boolean hasLength() { + return length != -1; + } + + /** + * {@return true if this is the last frame in the stream} + * The last frame has the FIN bit set. + */ + public boolean isLast() { return fin; } + + @Override + public long getTypeField() { + return STREAM | (hasLength() ? LEN : 0) + | (typeFieldHasOFF ? OFF : 0) + | (fin ? FIN : 0); + } + + @Override + public void encode(ByteBuffer dest) { + if (size() > dest.remaining()) { + throw new BufferOverflowException(); + } + int pos = dest.position(); + encodeVLField(dest, getTypeField(), "type"); + encodeVLField(dest, streamID, "streamID"); + if (typeFieldHasOFF) { + encodeVLField(dest, offset, "offset"); + } + if (hasLength()) { + encodeVLField(dest, length, "length"); + assert streamData.remaining() == length; + } + putByteBuffer(dest, streamData); + assert dest.position() - pos == size(); + } + + @Override + public int size() { + int size = getVLFieldLengthFor(getTypeField()) + + getVLFieldLengthFor(streamID); + if (typeFieldHasOFF) { + size += getVLFieldLengthFor(offset); + } + if (hasLength()) { + return size + getVLFieldLengthFor(length) + length; + } else { + return size + streamData.remaining(); + } + } + + /** + * {@return the frame payload} + */ + public ByteBuffer payload() { + return streamData.slice(); + } + + /** + * {@return the frame offset} + */ + public long offset() { return offset; } + + /** + * {@return the number of data bytes in the frame} + * @apiNote + * This is equivalent to calling {@code payload().remaining()}. + */ + public int dataLength() { + return dataLength; + } + + public static int compareOffsets(StreamFrame sf1, StreamFrame sf2) { + return Long.compare(sf1.offset, sf2.offset); + } + + /** + * Computes the header size that would be required to encode a frame with + * the given streamId, offset, and length. + * @apiNote + * This method is useful to figure out how many bytes can be allocated for + * the frame data, given a size constraint imposed by the space available + * for the whole datagram payload. + * @param encoder the {@code QuicPacketEncoder} - which can be used in case + * some part of the computation is Quic-version dependent. + * @param streamId the stream id + * @param offset the stream offset + * @param length the estimated length of the frame, typically this will be + * the min between the data available in the stream with respect + * to flow control, and the maximum remaining size for the datagram + * payload + * @return the estimated size of the header for a {@code StreamFrame} that would + * be created with the given parameters. + */ + public static int headerSize(QuicPacketEncoder encoder, long streamId, long offset, long length) { + // the header length is the size needed to encode the frame type, + // plus the size needed to encode the streamId, plus the size needed + // to encode the offset (if not 0) and the size needed to encode the + // length (if present) + int headerLength = getVLFieldLengthFor(STREAM | OFF | LEN | FIN) + + getVLFieldLengthFor(streamId); + if (offset != 0) headerLength += getVLFieldLengthFor(offset); + if (length >= 0) headerLength += getVLFieldLengthFor(length); + return headerLength; + } + + @Override + public String toString() { + return "StreamFrame(stream=" + streamID + + ", offset=" + offset + + ", length=" + length + + ", fin=" + fin + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamsBlockedFrame.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamsBlockedFrame.java new file mode 100644 index 00000000000..69292ffbbc0 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/frames/StreamsBlockedFrame.java @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.frames; + +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; + +/** + * A STREAMS_BLOCKED Frame + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public final class StreamsBlockedFrame extends QuicFrame { + + private final long maxStreams; + private final boolean bidi; + + /** + * Incoming STREAMS_BLOCKED frame returned by QuicFrame.decode() + * + * @param buffer + * @param type + * @throws QuicTransportException if the frame was malformed + */ + StreamsBlockedFrame(ByteBuffer buffer, int type) throws QuicTransportException { + super(STREAMS_BLOCKED); + bidi = (type == STREAMS_BLOCKED); + maxStreams = decodeVLField(buffer, "maxStreams"); + if (maxStreams > MaxStreamsFrame.MAX_VALUE) { + throw new QuicTransportException("Invalid maximum streams", + null, type, QuicTransportErrors.FRAME_ENCODING_ERROR); + } + } + + /** + * Outgoing STREAMS_BLOCKED frame + */ + public StreamsBlockedFrame(boolean bidi, long maxStreams) { + super(STREAMS_BLOCKED); + this.bidi = bidi; + this.maxStreams = requireVLRange(maxStreams, "maxStreams"); + } + + @Override + public long getTypeField() { + return STREAMS_BLOCKED + (bidi?0:1); + } + + @Override + public void encode(ByteBuffer buffer) { + if (size() > buffer.remaining()) { + throw new BufferOverflowException(); + } + int pos = buffer.position(); + encodeVLField(buffer, getTypeField(), "type"); + encodeVLField(buffer, maxStreams, "maxStreams"); + assert buffer.position() - pos == size(); + } + + @Override + public int size() { + return getVLFieldLengthFor(STREAMS_BLOCKED) + + getVLFieldLengthFor(maxStreams); + } + + /** + */ + public long maxStreams() { + return maxStreams; + } + + public boolean isBidi() { + return bidi; + } + + @Override + public String toString() { + return "StreamsBlockedFrame(bidi=" + bidi + + ", maxStreams=" + maxStreams + ')'; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/package-info.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/package-info.java new file mode 100644 index 00000000000..dcdd040c0ed --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/package-info.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic; + +/** + *

Internal classes for the Quic protocol implementation

+ * + * @spec https://www.rfc-editor.org/info/rfc8999 + * RFC 8999: Version-Independent Properties of QUIC + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9001 + * RFC 9001: Using TLS to Secure QUIC + * @spec https://www.rfc-editor.org/info/rfc9002 + * RFC 9002: QUIC Loss Detection and Congestion Control + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/HandshakePacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/HandshakePacket.java new file mode 100644 index 00000000000..a037f2e9387 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/HandshakePacket.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.util.List; + +import jdk.internal.net.http.quic.frames.QuicFrame; + +/** + * This class models Quic Handshake Packets, as defined by + * RFC 9000, Section 17.2.4: + * + *
{@code
+ *    A Handshake packet uses long headers with a type value of 0x02, followed
+ *    by the Length and Packet Number fields; see Section 17.2. The first byte
+ *    contains the Reserved and Packet Number Length bits; see Section 17.2.
+ *    It is used to carry cryptographic handshake messages and acknowledgments
+ *    from the server and client.
+ *
+ *    Handshake Packet {
+ *      Header Form (1) = 1,
+ *      Fixed Bit (1) = 1,
+ *      Long Packet Type (2) = 2,
+ *      Reserved Bits (2),
+ *      Packet Number Length (2),
+ *      Version (32),
+ *      Destination Connection ID Length (8),
+ *      Destination Connection ID (0..160),
+ *      Source Connection ID Length (8),
+ *      Source Connection ID (0..160),
+ *      Length (i),
+ *      Packet Number (8..32),
+ *      Packet Payload (..),
+ *    }
+ * }
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * Note that Quic Version 2 uses the same Handshake Packet structure than + * Quic Version 1, but uses a different long packet type than that shown above. See + * RFC 9369, Section 3.2. + * + * @see + * RFC 9000, Section 17.2 + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface HandshakePacket extends LongHeaderPacket { + @Override + default PacketType packetType() { + return PacketType.HANDSHAKE; + } + + @Override + default PacketNumberSpace numberSpace() { + return PacketNumberSpace.HANDSHAKE; + } + + @Override + default boolean hasLength() { return true; } + + /** + * This packet number. + * @return this packet number. + */ + @Override + long packetNumber(); + + @Override + List frames(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/InitialPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/InitialPacket.java new file mode 100644 index 00000000000..7be864c8b9c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/InitialPacket.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.util.List; + +import jdk.internal.net.http.quic.frames.QuicFrame; + +/** + * This class models Quic Initial Packets, as defined by + * RFC 9000, Section 17.2.2: + * + *

{@code
+ *    An Initial packet uses long headers with a type value of 0x00.
+ *    It carries the first CRYPTO frames sent by the client and server to perform
+ *    key exchange, and it carries ACK frames in either direction.
+ *
+ *    Initial Packet {
+ *      Header Form (1) = 1,
+ *      Fixed Bit (1) = 1,
+ *      Long Packet Type (2) = 0,
+ *      Reserved Bits (2),         # Protected
+ *      Packet Number Length (2),  # Protected
+ *      Version (32),
+ *      DCID Len (8),
+ *      Destination Connection ID (0..160),
+ *      SCID Len (8),
+ *      Source Connection ID (0..160),
+ *      Token Length (i),
+ *      Token (..),
+ *      Length (i),
+ *      Packet Number (8..32),     # Protected
+ *      # Protected Packet Payload (..)
+ *      Protected Payload (0..24), # Skipped Part
+ *      Protected Payload (128),   # Sampled Part
+ *      Protected Payload (..)     # Remainder
+ *    }
+ * }
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * Note that Quic Version 2 uses the same Initial Packet structure than + * Quic Version 1, but uses a different long packet type than that shown above. See + * RFC 9369, Section 3.2. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface InitialPacket extends LongHeaderPacket { + @Override + default PacketType packetType() { + return PacketType.INITIAL; + } + + @Override + default PacketNumberSpace numberSpace() { + return PacketNumberSpace.INITIAL; + } + + @Override + default boolean hasLength() { return true; } + + /** + * {@return the length of the token field, if present, 0 if not} + */ + int tokenLength(); + + /** + * {@return the token bytes, if present, {@code null} if not} + * + * From + * RFC 9000, Section 17.2.2: + * + *

{@code
+     *    The value of the token that was previously provided
+     *    in a Retry packet or NEW_TOKEN frame; see Section 8.1.
+     * }
+ * + * @see + * RFC 9000, Section 8.1 + */ + byte[] token(); + + /** + * This packet number. + * @return this packet number. + */ + @Override + long packetNumber(); + + @Override + List frames(); + + @Override + default String prettyPrint() { + return String.format("%s(pn:%s, size=%s, token[%s], frames:%s)", packetType(), packetNumber(), + size(), tokenLength(), frames()); + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeader.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeader.java new file mode 100644 index 00000000000..9c9a6e6ef31 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeader.java @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import jdk.internal.net.http.quic.QuicConnectionId; + +/** + * This class models Quic Long Header Packet header, as defined by + * RFC 8999, Section 5.1: + * + *
{@code
+ *    Long Header Packet {
+ *       Header Form (1) = 1,
+ *       Version-Specific Bits (7),
+ *       Version (32),
+ *       Destination Connection ID Length (8),
+ *       Destination Connection ID (0..2040),
+ *       Source Connection ID Length (8),
+ *       Source Connection ID (0..2040),
+ *       Version-Specific Data (..),
+ *    }
+ * }
+ * + * @param version version + * @param destinationId Destination Connection ID + * @param sourceId Source Connection ID + * @param headerLength length in bytes of the packet header + * @spec https://www.rfc-editor.org/info/rfc8999 + * RFC 8999: Version-Independent Properties of QUIC + */ +public record LongHeader(int version, + QuicConnectionId destinationId, + QuicConnectionId sourceId, + int headerLength) { +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeaderPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeaderPacket.java new file mode 100644 index 00000000000..960aef6530b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/LongHeaderPacket.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import jdk.internal.net.http.quic.QuicConnectionId; + +/** + * This class models Quic Long Header Packets, as defined by + * RFC 8999, Section 5.1: + * + *
{@code
+ *    Long Header Packet {
+ *       Header Form (1) = 1,
+ *       Version-Specific Bits (7),
+ *       Version (32),
+ *       Destination Connection ID Length (8),
+ *       Destination Connection ID (0..2040),
+ *       Source Connection ID Length (8),
+ *       Source Connection ID (0..2040),
+ *       Version-Specific Data (..),
+ *    }
+ * }
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * + * @spec https://www.rfc-editor.org/info/rfc8999 + * RFC 8999: Version-Independent Properties of QUIC + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface LongHeaderPacket extends QuicPacket { + @Override + default HeadersType headersType() { return HeadersType.LONG; } + + /** + * {@return the packet's source connection ID} + */ + QuicConnectionId sourceId(); + + /** + * {@return the Quic version of the packet} + */ + int version(); + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/OneRttPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/OneRttPacket.java new file mode 100644 index 00000000000..7df9f904a82 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/OneRttPacket.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.util.List; + +import jdk.internal.net.http.quic.frames.QuicFrame; + +/** + * This class models Quic 1-RTT packets, as defined by + * RFC 9000, Section 17.3.1: + * + *

{@code
+ *    A 1-RTT packet uses a short packet header. It is used after the
+ *    version and 1-RTT keys are negotiated.
+ *
+ *    1-RTT Packet {
+ *      Header Form (1) = 0,
+ *      Fixed Bit (1) = 1,
+ *      Spin Bit (1),
+ *      Reserved Bits (2),         # Protected
+ *      Key Phase (1),             # Protected
+ *      Packet Number Length (2),  # Protected
+ *      Destination Connection ID (0..160),
+ *      Packet Number (8..32),     # Protected
+ *      # Protected Packet Payload:
+ *      Protected Payload (0..24), # Skipped Part
+ *      Protected Payload (128),   # Sampled Part
+ *      Protected Payload (..),    # Remainder
+ *    }
+ * }
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * Quic Version 2 uses the same 1-RTT packet structure than + * Quic Version 1. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface OneRttPacket extends ShortHeaderPacket { + + @Override + List frames(); + + @Override + default PacketNumberSpace numberSpace() { + return PacketNumberSpace.APPLICATION; + } + + @Override + default PacketType packetType() { + return PacketType.ONERTT; + } + + /** + * Returns the packet's Key Phase Bit: 0 or 1, if known. + * Returns -1 for outgoing packets. + * RFC 9000, Section 17.3.1: + * + *

{@code
+     *     Bit (0x04) of byte 0 indicates the key phase, which allows a recipient
+     *     of a packet to identify the packet protection keys that are used to
+     *     protect the packet. See [QUIC-TLS] for details.
+     *     This bit is protected using header protection; see Section 5.4 of [QUIC-TLS].
+     * }
+ * + * @return the packet's Key Phase Bit + * + * @see RFC 9001, [QUIC-TLS] + * @see RFC 9001, Section 5.4, [QUIC-TLS] + */ + default int keyPhase() { + return -1; + } + + /** + * Returns the packet's Latency Spin Bit: 0 or 1, if known. + * Returns -1 for outgoing packets. + * RFC 9000, Section 17.3.1: + * + *
{@code
+     *     The third most significant bit (0x20) of byte 0 is the latency spin
+     *     bit, set as described in Section 17.4.
+     * }
+ * + * @return the packet's Latency Spin Bit + * + * @see RFC 9000, Section 17.4 + */ + default int spin() { + return -1; + } + + @Override + default String prettyPrint() { + return String.format("%s(pn:%s, size=%s, phase:%s, spin:%s, frames:%s)", packetType(), packetNumber(), + size(), keyPhase(), spin(), frames()); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/PacketSpace.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/PacketSpace.java new file mode 100644 index 00000000000..cea0854e0e5 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/PacketSpace.java @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.http.quic.PacketSpaceManager; +import jdk.internal.net.quic.QuicTransportException; + +/** + * An interface implemented by classes which keep track of packet + * numbers for a given packet number space. + */ +public sealed interface PacketSpace permits PacketSpaceManager { + + /** + * called on application packet space to record peer's transport parameters + * @param peerDelay max_ack_delay + * @param ackDelayExponent ack_delay_exponent + */ + void updatePeerTransportParameters(long peerDelay, long ackDelayExponent); + + /** + * {@return the packet number space managed by this class} + */ + PacketNumberSpace packetNumberSpace(); + + /** + * The largest processed PN is used to compute + * the packet number of an incoming Quic packet. + * + * @return the largest incoming packet number that + * was successfully processed in this space. + */ + long getLargestProcessedPN(); + + /** + * The largest received acked PN is used to compute the + * packet number that we include in an outgoing Quic packet. + * + * @return the largest packet number that was acknowledged by + * the peer in this space. + */ + long getLargestPeerAckedPN(); + + /** + * {@return the largest packet number that we have acknowledged in this + * space} + * + * @apiNote This is necessarily greater or equal to the packet number + * returned by {@linkplain #getMinPNThreshold()}. + */ + long getLargestSentAckedPN(); + + /** + * {@return the packet number threshold below which packets should be + * discarded without being processed in this space} + * + * @apiNote + * This corresponds to the largest acknowledged packet number + * carried in an outgoing ACK frame whose packet number has + * been acknowledged by the peer. In other words, the largest + * packet number sent by the peer for which we know that the + * peer has received an acknowledgement. + *

+ * Note that we need to track the ACK of outgoing packets that + * contain ACK frames in order to figure out whether a peer + * knows that a particular packet number has been received and + * avoid retransmission. However - we don't want ACK frames to grow + * too big and therefore we can drop some of the information, + * based on the largestSentAckedPN - see RFC 9000 Section 13.2 + * + */ + long getMinPNThreshold(); + + /** + * {@return a new packet number atomically allocated in this space} + */ + long allocateNextPN(); + + /** + * This method is called by {@link QuicConnectionImpl} upon reception of + * and successful negotiation of a new version. + * In that case we should stop retransmitting packet that have the + * "wrong" version: they will never be acknowledged. + */ + void versionChanged(); + + /** + * This method is called by {@link QuicConnectionImpl} upon reception of + * and successful processing of retry packet. + * In that case we should treat all previously sent packets as lost. + */ + void retry(); + + /** + * {@return a lock used by the transmission task}. + * Used to ensure that the transmission task does not observe partial changes + * during processing of incoming Versions and Retry packets. + */ + ReentrantLock getTransmitLock(); + /** + * Called when a packet is received. Causes the next ack frame to be + * updated. If a packet contains an {@link AckFrame}, the caller is + * expected to also later call {@link #processAckFrame(AckFrame)} + * when processing the packet payload. + * + * @param packet the received packet + * @param packetNumber the received packet number + * @param isAckEliciting whether this packet is ack eliciting + */ + void packetReceived(PacketType packet, long packetNumber, boolean isAckEliciting); + + /** + * Signals that a packet has been sent. + * This method is called by {@link QuicConnectionImpl} when a packet has been + * pushed to the endpoint for sending. + *

The retransmitted packet is taken out the pendingRetransmission list and + * the new packet is inserted in the pendingAcknowledgement list. + * + * @param packet the new packet being retransmitted + * @param previousPacketNumber the packet number of the previous packet that was not acknowledged, + * or -1 if this is not a retransmission + * @param packetNumber the new packet number under which this packet is being retransmitted + * @throws IllegalArgumentException If {@code newPacketNumber} is lesser than 0 + */ + void packetSent(QuicPacket packet, long previousPacketNumber, long packetNumber); + + /** + * Processes a received ACK frame. + * This method is called by {@link QuicConnectionImpl}. + * + * @param frame the ACK frame received. + */ + void processAckFrame(AckFrame frame) throws QuicTransportException; + + /** + * Signals that the peer confirmed the handshake. Application space only. + */ + void confirmHandshake(); + + /** + * Get the next ack frame to send. + * This method returns the prepared ack frame if: + * - it was not sent yet + * - there are new ack-eliciting packets to acknowledge + * - optionally, if the ack frame is overdue + * + * @param onlyOverdue if true, the frame will only be returned if it's overdue + * @return The next AckFrame to send to the peer, or {@code null} + * if there is nothing to acknowledge. + */ + AckFrame getNextAckFrame(boolean onlyOverdue); + + /** + * Get the next ack frame to send. + * This method returns the prepared ack frame if: + * - it was not sent yet + * - there are new ack-eliciting packets to acknowledge + * - the ack frame size doesn't exceed {@code maxSize} + * - optionally, if the ack frame is overdue + * + * @param onlyOverdue if true, the frame will only be returned if it's overdue + * @param maxSize + * @return The next AckFrame to send to the peer, or {@code null} + * if there is nothing to acknowledge. + */ + AckFrame getNextAckFrame(boolean onlyOverdue, int maxSize); + + /** + * Used to request sending of a ping frame, for instance, to verify that + * the connection is alive. + * @return a completable future that will be completed with the time it + * took, in milliseconds, for the peer to acknowledge the packet that + * contained the PingFrame (or any packet that was sent after) + * + * @apiNote The returned completable future is actually completed + * if any packet whose packet number is greater than the packet number + * that contained the ping frame is acknowledged. + */ + CompletableFuture requestSendPing(); + + /** + * Stops retransmission for this packet space. + */ + void close(); + + /** + * {@return true if this packet space is closed} + */ + boolean isClosed(); + + /** + * Triggers immediate run of transmit loop. + * + * This method is called by {@link QuicConnectionImpl} when new data may be + * available for sending, for example: + * - new stream data is available + * - new receive credit is available + * - stream is forcibly closed + */ + void runTransmitter(); + + /** + * {@return true if a packet with that packet number + * is already being acknowledged (will be, or has been + * acknowledged)} + * @param packetNumber the packet number + */ + boolean isAcknowledged(long packetNumber); + + /** + * Immediately retransmit one unacknowledged initial packet + * @spec https://www.rfc-editor.org/rfc/rfc9002#name-speeding-up-handshake-compl + * RFC 9002 6.2.3. Speeding up Handshake Completion + */ + void fastRetransmit(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacket.java new file mode 100644 index 00000000000..fa04cb3c947 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacket.java @@ -0,0 +1,249 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.util.List; +import java.util.Optional; + +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.quic.QuicTLSEngine.KeySpace; + +/** + * A super-interface for all specific Quic packet implementation + * classes. + */ +public interface QuicPacket { + + /** + * {@return the packet's Destination Connection ID} + * + * @see + * RFC 9000, Section 7.2 + */ + QuicConnectionId destinationId(); + + /** + * The packet number space. + * NONE is for packets that don't have a packet number, + * such as Stateless Reset. + */ + enum PacketNumberSpace { + INITIAL, HANDSHAKE, APPLICATION, NONE; + + /** + * Maps a {@code PacketType} to the corresponding + * packet number space. + *

+ * For {@link PacketType#RETRY}, {@link PacketType#VERSIONS}, and + * {@link PacketType#NONE}, {@link PacketNumberSpace#NONE} is returned. + * + * @param packetType a packet type + * + * @return the packet number space that corresponds to the + * given packet type + */ + public static PacketNumberSpace of(PacketType packetType) { + return switch (packetType) { + case ONERTT, ZERORTT -> APPLICATION; + case INITIAL -> INITIAL; + case HANDSHAKE -> HANDSHAKE; + case RETRY, VERSIONS, NONE -> NONE; + }; + } + + /** + * Maps a {@code KeySpace} to the corresponding + * packet number space. + *

+ * For {@link KeySpace#RETRY}, {@link PacketNumberSpace#NONE} + * is returned. + * + * @param keySpace a key space + * + * @return the packet number space that corresponds to the given + * key space. + */ + public static PacketNumberSpace of(KeySpace keySpace) { + return switch (keySpace) { + case ONE_RTT, ZERO_RTT -> APPLICATION; + case HANDSHAKE -> HANDSHAKE; + case INITIAL -> INITIAL; + case RETRY -> NONE; + }; + } + } + + /** + * The packet type for Quic packets. + */ + enum PacketType { + NONE, INITIAL, VERSIONS, ZERORTT, HANDSHAKE, RETRY, ONERTT; + public boolean isLongHeaderType() { + return switch (this) { + case ONERTT, NONE, VERSIONS -> false; + default -> true; + }; + } + + /** + * {@return true if packets of this type are short-header packets} + */ + public boolean isShortHeaderType() { + return this == ONERTT; + } + + /** + * {@return the QUIC-TLS key space corresponding to this packet type} + * Some packet types, such as {@link #VERSIONS}, do not have an associated + * key space. + */ + public Optional keySpace() { + return switch (this) { + case INITIAL -> Optional.of(KeySpace.INITIAL); + case HANDSHAKE -> Optional.of(KeySpace.HANDSHAKE); + case RETRY -> Optional.of(KeySpace.RETRY); + case ZERORTT -> Optional.of(KeySpace.ZERO_RTT); + case ONERTT -> Optional.of(KeySpace.ONE_RTT); + case VERSIONS -> Optional.empty(); + case NONE -> Optional.empty(); + }; + } + } + + /** + * The Headers Type of the packet. + * This is either SHORT or LONG, or NONE when it can't be + * determined, or when we know that the packet is a stateless + * reset packet. A stateless reset packet is indistinguishable + * from a short header packet, so we only know that a packet + * is a stateless reset if we built it. In that case, the packet + * may advertise its header's type as NONE. + */ + enum HeadersType { NONE, SHORT, LONG} + + /** + * {@return this packet's number space} + */ + PacketNumberSpace numberSpace(); + + /** + * This packet size. + * @return the number of bytes needed to encode the packet. + * @see #payloadSize() + * @see #length() + */ + int size(); + + /** + * {@return true if this packet is ACK-eliciting} + * A packet is ACK-eliciting if it contains any + * {@linkplain QuicFrame#isAckEliciting() + * ACK-eliciting frame}. + */ + default boolean isAckEliciting() { + List frames = frames(); + if (frames == null || frames.isEmpty()) return false; + return frames.stream().anyMatch(QuicFrame::isAckEliciting); + } + + /** + * Whether this packet has a length field whose value can be read + * from the packet bytes. + * @return whether this packet has a length. + */ + default boolean hasLength() { + return switch (packetType()) { + case INITIAL, ZERORTT, HANDSHAKE -> true; + default -> false; + }; + } + + /** + * Returns the length of the payload and packet number. Includes encryption tag. + * + * This is the value stored in the {@code Length} field in Initial, + * Handshake and 0-RTT packets. + * @return the length of the payload and packet number. + * @throws UnsupportedOperationException if this packet type does not have + * the {@code Length} field. + * @see #hasLength() + * @see #size() + * @see #payloadSize() + */ + default int length() { + throw new UnsupportedOperationException(); + } + + /** + * This packet header's type. Either SHORT or LONG. + * @return this packet's header's type. + */ + HeadersType headersType(); + + /** + * {@return this packet's type} + */ + PacketType packetType(); + + /** + * {@return this packet's packet number, if applicable, {@code -1L} otherwise} + */ + default long packetNumber() { + return -1L; + } + + /** + * {@return this packet's frames} + */ + default List frames() { + return List.of(); + } + + /** + * {@return the packet's payload size} + * This is the number of bytes needed to encode the packet's + * {@linkplain #frames() frames}. + * @see #size() + * @see #length() + */ + default int payloadSize() { + List frames = frames(); + if (frames == null || frames.isEmpty()) return 0; + return frames.stream() + .mapToInt(QuicFrame::size) + .reduce(0, Math::addExact); + } + + default String prettyPrint() { + long pn = packetNumber(); + if (pn >= 0) { + return String.format("%s(pn:%s, size=%s, frames:%s)", packetType(), pn, size(), frames()); + } else { + return String.format("%s(size=%s)", packetType(), size()); + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketDecoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketDecoder.java new file mode 100644 index 00000000000..233a1a35778 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketDecoder.java @@ -0,0 +1,1748 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.quic.packets; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.PeerConnectionId; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicVersion; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.CodingContext; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +import javax.crypto.AEADBadTagException; +import javax.crypto.ShortBufferException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.HexFormat; +import java.util.Objects; +import java.util.List; +import java.nio.BufferUnderflowException; + +/** + * A {@code QuicPacketDecoder} encapsulates the logic to decode a + * quic packet. A {@code QuicPacketDecoder} is typically tied to + * a particular version of the QUIC protocol. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9001 + * RFC 9001: Using TLS to Secure QUIC + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public class QuicPacketDecoder { + + private static final Logger debug = Utils.getDebugLogger(() -> "QuicPacketDecoder"); + + private final QuicVersion quicVersion; + private QuicPacketDecoder(final QuicVersion quicVersion) { + this.quicVersion = quicVersion; + } + + /** + * Reads the headers type from the given byte. + * @param first the first byte of a quic packet + * @return the headers type encoded in the given byte. + */ + private static QuicPacket.HeadersType headersType(byte first) { + int type = first & 0x80; + return type == 0 ? QuicPacket.HeadersType.SHORT : QuicPacket.HeadersType.LONG; + } + + /** + * Peeks at the headers type in the given byte buffer. + * Does not advance the cursor. + * + * @apiNote This method starts reading at the offset but respects + * the buffer limit.The provided offset must be less than the buffer + * limit in order for this method to read the header + * bytes. + * + * @param buffer the byte buffer containing a packet. + * @param offset the offset at which the packet starts. + * + * @return the header's type of the packet contained in this + * byte buffer. NONE if the header's type cannot be determined. + */ + public static QuicPacket.HeadersType peekHeaderType(ByteBuffer buffer, int offset) { + if (offset < 0 || offset >= buffer.limit()) return QuicPacket.HeadersType.NONE; + return headersType(buffer.get(offset)); + } + + /** + * Reads a connection ID length from the connection ID length + * byte. + * @param length the connection ID length byte. + * @return the connection ID length + */ + private static int connectionIdLength(byte length) { + // length is represented by an unsigned byte. + return length & 0xFF; + } + + /** + * Peeks at the connection id in the long header packet bytes. + * This method doesn't advance the cursor. + * The buffer position must be at the start of the long header packet. + * + * @param buffer the buffer containing a long headers packet. + * @return A ByteBuffer slice containing the connection id bytes, + * or null if the packet is malformed and the connection id + * could not be read. + */ + public static ByteBuffer peekLongConnectionId(ByteBuffer buffer) { + // the connection id length starts at index 5 (1 byte for headers, + // 4 bytes for version) + var pos = buffer.position(); + var remaining = buffer.remaining(); + if (remaining < 6) return null; + int length = connectionIdLength(buffer.get(pos + 5)); + if (length > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) return null; + if (length > remaining - 6) return null; + return buffer.slice(pos + 6, length); + } + + /** + * Peeks at the header in the long header packet bytes. + * This method doesn't advance the cursor. + * The buffer position must be at the start of the long header packet. + * + * @param buffer the buffer containing a long header packet. + * @return A LongHeader containing the packet header data, + * or null if the packet is malformed + */ + public static LongHeader peekLongHeader(ByteBuffer buffer) { + return peekLongHeader(buffer, buffer.position()); + } + + /** + * Peeks at the header in the long header packet bytes. + * This method doesn't advance the cursor. + * + * @param buffer the buffer containing a long header packet. + * @param offset the position of the start of the packet + * @return A LongHeader containing the packet header data, + * or null if the packet is malformed + */ + public static LongHeader peekLongHeader(ByteBuffer buffer, int offset) { + // the destination connection id length starts at index 5 + // (1 byte for headers, 4 bytes for version) + // Therefore the packet needs at least 6 bytes to contain + // a DCID length (coded on 1 byte) + var remaining = buffer.remaining(); + var limit = buffer.limit(); + if (remaining < 7) return null; + if ((buffer.get(offset) & 0x80) == 0) { + // short header + return null; + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + int version = buffer.getInt(offset+1); + + + // read the DCID length (coded on 1 byte) + int length = connectionIdLength(buffer.get(offset + 5)); + if (length < 0 || length > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) return null; + QuicConnectionId destinationId = new PeerConnectionId(buffer.slice(offset + 6, length), null); + + // We need at least 6 + length + 1 byte to have + // a chance to read the SCID length (coded on 1 byte) + if (length > remaining - 7) return null; + int srcPos = offset + 6 + length; + + // read the SCID length + int srclength = connectionIdLength(buffer.get(srcPos)); + if (srclength < 0 || srclength > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) return null; + // we need at least pos + srclength + 1 byte in the + // packet to peek at the SCID + if (srclength > limit - srcPos - 1) return null; + QuicConnectionId sourceId = new PeerConnectionId(buffer.slice(srcPos + 1, srclength), null); + int headerLength = 7 + length + srclength; + + // Return the SCID as a buffer slice. + // The SCID begins at pos + 1 and has srclength bytes. + return new LongHeader(version, destinationId, sourceId, headerLength); + } + + /** + * Returns a bytebuffer containing the token from initial packet. + * This method doesn't advance the cursor. + * The buffer position must be at the start of an initial packet. + * @apiNote + * If the initial packet doesn't contain any token, an empty + * {@code ByteBuffer} is returned. + * @param buffer the buffer containing an initial packet. + * @return token or null if packet is malformed + */ + public static ByteBuffer peekInitialPacketToken(ByteBuffer buffer) { + + // the destination connection id length starts at index 5 + // (1 byte for headers, 4 bytes for version) + // Therefore the packet needs at least 6 bytes to contain + // a DCID length (coded on 1 byte) + var pos = buffer.position(); + var remaining = buffer.remaining(); + var limit = buffer.limit(); + if (remaining < 6) return null; + + // read the DCID length (coded on 1 byte) + int length = connectionIdLength(buffer.get(pos + 5)); + if (length > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) return null; + if (length < 0) return null; + + // skip the DCID, and read the SCID length + // We need at least 6 + length + 1 byte to have + // a chance to read the SCID length (coded on 1 byte) + pos = pos + 6 + length; + if (pos > limit - 1) return null; + + // read the SCID length + int srclength = connectionIdLength(buffer.get(pos)); + if (srclength > QuicConnectionId.MAX_CONNECTION_ID_LENGTH) return null; + if (srclength < 0) return null; + // we need at least pos + srclength + 1 byte in the + // packet to peek at the token + if (srclength > limit - pos - 1) return null; + + //skip the SCID, and read the token length + pos = pos + srclength + 1; + + // read the token length + int tokenLengthLength = VariableLengthEncoder.peekEncodedValueSize(buffer, pos); + assert tokenLengthLength <= 8; + if (pos > limit - tokenLengthLength -1) return null; + long tokenLength = VariableLengthEncoder.peekEncodedValue(buffer, pos); + if (tokenLength < 0 || tokenLength > Integer.MAX_VALUE) return null; + if (tokenLength > limit - pos - tokenLengthLength) return null; + + // return the token + return buffer.slice(pos + tokenLengthLength, (int)tokenLength); + } + + /** + * Peeks at the connection id in the short header packet bytes. + * This method doesn't advance the cursor. + * The buffer position must be at the start of the short header packet. + * + * @param buffer the buffer containing a short headers packet. + * @param length the connection id length. + * + * @return A ByteBuffer slice containing the connection id bytes, + * or null if the packet is malformed and the connection id + * could not be read. + */ + public static ByteBuffer peekShortConnectionId(ByteBuffer buffer, int length) { + int pos = buffer.position(); + int limit = buffer.limit(); + assert pos >= 0; + assert length <= QuicConnectionId.MAX_CONNECTION_ID_LENGTH; + if (limit - pos < length + 1) return null; + return buffer.slice(pos+1, length); + } + + /** + * Returns the version of the first packet in the buffer. + * This method doesn't advance the cursor. + * Returns 0 if the version is 0 (version negotiation packet), + * or if the version cannot be determined. + * The packet is expected to start at the buffer's current position. + * + * @implNote + * This is equivalent to calling: + * {@code peekVersion(buffer, buffer.position())}. + * + * @param buffer the buffer containing the packet. + * + * @return the version of the packet in the buffer, or 0. + * @see + * RFC 8999: Version-Independent Properties of QUIC + */ + public static int peekVersion(ByteBuffer buffer) { + return peekVersion(buffer, buffer.position()); + } + + /** + * Returns the version of the first packet in the buffer. + * This method doesn't advance the cursor. + * Returns 0 if the version is 0 (version negotiation packet), + * or if the version cannot be determined. + * + * @apiNote This method starts reading at the offset but respects + * the buffer limit. The buffer limit must allow for reading + * the header byte and version number starting at the offset. + * + * @param buffer the buffer containing the packet. + * @param offset the offset at which the packet starts. + * + * @return the version of the packet in the buffer, or 0. + * @see + * RFC 8999: Version-Independent Properties of QUIC + */ + public static int peekVersion(ByteBuffer buffer, int offset) { + int limit = buffer.limit(); + assert offset >= 0; + if (limit - offset < 5) return 0; + QuicPacket.HeadersType headersType = peekHeaderType(buffer, offset); + if (headersType == QuicPacket.HeadersType.LONG) { + assert buffer.order() == ByteOrder.BIG_ENDIAN; + return buffer.getInt(offset+1); + } + return 0; + } + + /** + * Returns true if the first packet in the buffer is a version + * negotiation packet. + * This method doesn't advance the cursor. + * + * @apiNote This method starts reading at the offset but respects + * the buffer limit. If the packet is a long header packet, + * the buffer limit must allow for reading + * the header byte and version number starting at the offset. + * + * @param buffer the buffer containing the packet. + * @param offset the offset at which the packet starts. + * + * @return true if the first packet in the buffer is a version + * negotiation packet. + * @see + * RFC 8999: Version-Independent Properties of QUIC + */ + private static boolean isVersionNegotiation(ByteBuffer buffer, int offset) { + int limit = buffer.limit(); + if (limit - offset < 5) return false; + QuicPacket.HeadersType headersType = peekHeaderType(buffer, offset); + if (headersType == QuicPacket.HeadersType.LONG) { + assert buffer.order() == ByteOrder.BIG_ENDIAN; + return buffer.getInt(offset+1) == 0; + } + return false; + } + + public abstract static class IncomingQuicPacket implements QuicPacket { + private final QuicConnectionId destinationId; + + protected IncomingQuicPacket(QuicConnectionId destinationId) { + this.destinationId = destinationId; + } + + @Override + public final QuicConnectionId destinationId() { return destinationId; } + } + + private abstract static class IncomingLongHeaderPacket + extends IncomingQuicPacket implements LongHeaderPacket { + + private final QuicConnectionId sourceId; + private final int version; + IncomingLongHeaderPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version) { + super(destinationId); + this.sourceId = sourceId; + this.version = version; + } + + @Override + public final QuicConnectionId sourceId() { return sourceId; } + + @Override + public final int version() { return version; } + } + + private abstract static class IncomingShortHeaderPacket + extends IncomingQuicPacket implements ShortHeaderPacket { + + IncomingShortHeaderPacket(QuicConnectionId destinationId) { + super(destinationId); + } + } + + private static final class IncomingRetryPacket + extends IncomingLongHeaderPacket implements RetryPacket { + final int size; + final byte[] retryToken; + + private IncomingRetryPacket(QuicConnectionId sourceId, QuicConnectionId destinationId, + int version, int size, byte[] retryToken) { + super(sourceId, destinationId, version); + this.size = size; + this.retryToken = retryToken; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] retryToken() { + return retryToken; + } + + /** + * Decode a valid {@code ByteBuffer} into an {@link IncomingRetryPacket}. + * + * @param reader A {@code PacketReader} to decode the {@code ByteBuffer} that contains + * the bytes of this packet + * @param context the decoding context + * + * @return an {@code IncomingRetryPacket} with its contents set + * according to the packets fields + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + */ + static IncomingRetryPacket decode(PacketReader reader, CodingContext context) + throws IOException, QuicTransportException { + try { + reader.verifyRetry(); + } catch (AEADBadTagException e) { + throw new IOException("Bad integrity tag", e); + } + + int size = reader.remaining(); + if (debug.on()) { + debug.log("IncomingRetryPacket.decode(%s)", reader); + } + + byte headers = reader.readHeaders(); // read headers + int version = reader.readVersion(); // read version + if (debug.on()) { + debug.log("IncomingRetryPacket.decode(headers(%x), version(%d), %s)", + headers, version, reader); + } + + // Retrieve the destination and source connections IDs + var destinationID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingRetryPacket.decode(dcid(%d), %s)", + destinationID.length(), reader); + } + var sourceID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingRetryPacket.decode(scid(%d), %s)", + sourceID.length(), reader); + } + + // Retry Token + byte[] retryToken = reader.readRetryToken(); + if (debug.on()) { + debug.log("IncomingRetryPacket.decode(retryToken(%d), %s)", + retryToken.length, reader); + } + + // Retry Integrity Tag + assert reader.remaining() == 16; + byte[] retryIntegrityTag = reader.readRetryIntegrityTag(); + if (debug.on()) { + debug.log("IncomingRetryPacket.decode(retryIntegrityTag(%d), %s)", + retryIntegrityTag.length, reader); + } + assert size == reader.bytesRead(); + + return new IncomingRetryPacket(sourceID, destinationID, version, + size, retryToken); + } + } + + private static final class IncomingHandshakePacket + extends IncomingLongHeaderPacket implements HandshakePacket { + + final int size; + final int length; + final long packetNumber; + final List frames; + + IncomingHandshakePacket(QuicConnectionId sourceId, QuicConnectionId destinationId, + int version, int length, long packetNumber, List frames, int size) { + super(sourceId, destinationId, version); + this.size = size; + this.length = length; + this.packetNumber = packetNumber; + this.frames = List.copyOf(frames); + } + + @Override + public int length() { + return length; + } + + @Override + public long packetNumber() { + return packetNumber; + } + + @Override + public int size() { + return size; + } + + @Override + public List frames() { return frames; } + + /** + * Decode a valid {@code ByteBuffer} into an {@link IncomingHandshakePacket}. + * This method removes packet protection and decrypt the packet encoded into + * the provided byte buffer, then creates an {@code IncomingHandshakePacket} + * with the decoded data. + * + * @param reader A {@code PacketReader} to decode the {@code ByteBuffer} that contains + * the bytes of this packet + * @param context the decoding context + * + * @return an {@code IncomingHandshakePacket} with its contents set + * according to the packets fields + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + * @throws QuicTransportException if packet is correctly signed but malformed + */ + static IncomingHandshakePacket decode(PacketReader reader, CodingContext context) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(%s)", reader); + } + + byte headers = reader.readHeaders(); // read headers + int version = reader.readVersion(); // read version + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(headers(%x), version(%d), %s)", + headers, version, reader); + } + + // Retrieve the destination and source connections IDs + var destinationID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(dcid(%d), %s)", + destinationID.length(), reader); + } + var sourceID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(scid(%d), %s)", + sourceID.length(), reader); + } + + // Get length of packet number and payload + var packetLength = reader.readPacketLength(); + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(length(%d), %s)", + packetLength, reader); + } + + // Remove protection before reading packet number + reader.unprotectLong(packetLength); + + // re-read headers, now that protection is removed + headers = reader.headers(); + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode([unprotected]headers(%x), %s)", + headers, reader); + } + + // Packet Number + var packetNumberLength = reader.packetNumberLength(); + var packetNumber = reader.readPacketNumber(packetNumberLength); + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(" + + "packetNumberLength(%d), packetNumber(%d), %s)", + packetNumberLength, packetNumber, reader); + } + + // Calculate payload length and retrieve payload + int payloadLen = (int) (packetLength - packetNumberLength); + if (debug.on()) { + debug.log("IncomingHandshakePacket.decode(payloadLen(%d), %s)", + payloadLen, reader); + } + ByteBuffer payload = null; + try { + payload = reader.decryptPayload(packetNumber, payloadLen, -1 /* key phase */); + } catch (AEADBadTagException e) { + Log.logError("[Quic] Failed to decrypt HANDSHAKE packet (Bad AEAD tag; discarding packet): " + e); + Log.logError(e); + throw new IOException("Bad AEAD tag", e); + } + // check reserved bits after checking integrity, see RFC 9000, section 17.2 + if ((headers & 0xc) != 0) { + throw new QuicTransportException("Nonzero reserved bits in packet header", + QuicTLSEngine.KeySpace.HANDSHAKE, 0, QuicTransportErrors.PROTOCOL_VIOLATION); + } + + List frames = reader.parsePayloadSlice(payload); + assert !payload.hasRemaining() : "remaining bytes in payload: " + payload.remaining(); + + // Finally, get the size (in bytes) of new packet + var size = reader.bytesRead(); + assert size == reader.position() - reader.offset(); + + assert packetLength == (int)packetLength; + return new IncomingHandshakePacket(sourceID, destinationID, + version, (int)packetLength, packetNumber, frames, size); + } + } + + private static final class IncomingZeroRttPacket + extends IncomingLongHeaderPacket implements ZeroRttPacket { + + final int size; + final int length; + final long packetNumber; + final List frames; + + IncomingZeroRttPacket(QuicConnectionId sourceId, QuicConnectionId destinationId, + int version, int length, long packetNumber, List frames, int size) { + super(sourceId, destinationId, version); + this.size = size; + this.length = length; + this.packetNumber = packetNumber; + this.frames = List.copyOf(frames); + } + + @Override + public int length() { + return length; + } + + @Override + public long packetNumber() { + return packetNumber; + } + + @Override + public int size() { + return size; + } + + @Override + public List frames() { return frames; } + + /** + * Decode a valid {@code ByteBuffer} into an {@link IncomingZeroRttPacket}. + * This method removes packet protection and decrypt the packet encoded into + * the provided byte buffer, then creates an {@code IncomingZeroRttPacket} + * with the decoded data. + * + * @param reader A {@code PacketReader} to decode the {@code ByteBuffer} that contains + * the bytes of this packet + * @param context the decoding context + * + * @return an {@code IncomingZeroRttPacket} with its contents set + * according to the packets fields + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + * @throws QuicTransportException if packet is correctly signed but malformed + */ + static IncomingZeroRttPacket decode(PacketReader reader, CodingContext context) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(%s)", reader); + } + + byte headers = reader.readHeaders(); // read headers + int version = reader.readVersion(); // read version + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(headers(%x), version(%d), %s)", + headers, version, reader); + } + + // Retrieve the destination and source connections IDs + var destinationID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(dcid(%d), %s)", + destinationID.length(), reader); + } + var sourceID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(scid(%d), %s)", + sourceID.length(), reader); + } + + // Get length of packet number and payload + var length = reader.readPacketLength(); + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(length(%d), %s)", + length, reader); + } + + // Remove protection before reading packet number + reader.unprotectLong(length); + + // re-read headers, now that protection is removed + headers = reader.headers(); + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode([unprotected]headers(%x), %s)", + headers, reader); + } + + // Packet Number + var packetNumberLength = reader.packetNumberLength(); + var packetNumber = reader.readPacketNumber(packetNumberLength); + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(" + + "packetNumberLength(%d), packetNumber(%d), %s)", + packetNumberLength, packetNumber, reader); + } + + // Calculate payload length and retrieve payload + int payloadLen = (int) (length - packetNumberLength); + if (debug.on()) { + debug.log("IncomingZeroRttPacket.decode(payloadLen(%d), %s)", + payloadLen, reader); + } + ByteBuffer payload = null; + try { + payload = reader.decryptPayload(packetNumber, payloadLen, -1 /* key phase */); + } catch (AEADBadTagException e) { + Log.logError("[Quic] Failed to decrypt ZERORTT packet (Bad AEAD tag; discarding packet): " + e); + Log.logError(e); + throw new IOException("Bad AEAD tag", e); + } + // check reserved bits after checking integrity, see RFC 9000, section 17.2 + if ((headers & 0xc) != 0) { + throw new QuicTransportException("Nonzero reserved bits in packet header", + QuicTLSEngine.KeySpace.ZERO_RTT, 0, QuicTransportErrors.PROTOCOL_VIOLATION); + } + List frames = reader.parsePayloadSlice(payload); + assert !payload.hasRemaining() : "remaining bytes in payload: " + payload.remaining(); + + // Finally, get the size (in bytes) of new packet + var size = reader.bytesRead(); + + assert length == (int)length; + return new IncomingZeroRttPacket(sourceID, destinationID, + version, (int)length, packetNumber, frames, size); + } + } + + private static final class IncomingOneRttPacket + extends IncomingShortHeaderPacket implements OneRttPacket { + + final int size; + final long packetNumber; + final List frames; + final int keyPhase; + final int spin; + + IncomingOneRttPacket(QuicConnectionId destinationId, + long packetNumber, List frames, + int spin, int keyPhase, int size) { + super(destinationId); + this.keyPhase = keyPhase; + this.spin = spin; + this.size = size; + this.packetNumber = packetNumber; + this.frames = frames; + } + + public long packetNumber() { + return packetNumber; + } + + @Override + public int size() { + return size; + } + + @Override + public int keyPhase() { + return keyPhase; + } + + @Override + public int spin() { + return spin; + } + + @Override + public List frames() { return frames; } + + /** + * Decode a valid {@code ByteBuffer} into an {@link IncomingOneRttPacket}. + * This method removes packet protection and decrypt the packet encoded into + * the provided byte buffer, then creates an {@code IncomingOneRttPacket} + * with the decoded data. + * + * @param reader A {@code PacketReader} to decode the {@code ByteBuffer} that contains + * the bytes of this packet + * @param context the decoding context + * + * @return an {@code IncomingOneRttPacket} with its contents set + * according to the packets fields + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + * @throws QuicTransportException if packet is correctly signed but malformed + */ + static IncomingOneRttPacket decode(PacketReader reader, CodingContext context) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + + if (debug.on()) { + debug.log("IncomingOneRttPacket.decode(%s)", reader); + } + + byte headers = reader.readHeaders(); // read headers + if (debug.on()) { + debug.log("IncomingOneRttPacket.decode(headers(%x), %s)", + headers, reader); + } + + // Retrieve the destination and source connections IDs + var destinationID = reader.readShortConnectionId(); + if (debug.on()) { + debug.log("IncomingOneRttPacket.decode(dcid(%d), %s)", + destinationID.length(), reader); + } + + // Remove protection before reading packet number + reader.unprotectShort(); + + // re-read headers, now that protection is removed + headers = reader.headers(); + if (debug.on()) { + debug.log("IncomingOneRttPacket.decode([unprotected]headers(%x), %s)", + headers, reader); + } + // Packet Number + var packetNumberLength = reader.packetNumberLength(); + var packetNumber = reader.readPacketNumber(packetNumberLength); + if (debug.on()) { + debug.log("IncomingOneRttPacket.decode(" + + "packetNumberLength(%d), packetNumber(%d), %s)", + packetNumberLength, packetNumber, reader); + } + + // Calculate payload length and retrieve payload + int payloadLen = reader.remaining(); + if (debug.on()) { + debug.log("IncomingOneRttPacket.decode(payloadLen(%d), %s)", + payloadLen, reader); + } + final int keyPhase = (headers & 0x04) >> 2; + // keyphase is a 1 bit structure, so only 0 or 1 are valid values + assert keyPhase == 0 || keyPhase == 1 : "unexpected key phase: " + keyPhase; + final int spin = (headers & 0x20) >> 5; + assert spin == 0 || spin == 1 : "unexpected spin bit: " + spin; + + ByteBuffer payload = null; + try { + payload = reader.decryptPayload(packetNumber, payloadLen, keyPhase); + } catch (AEADBadTagException e) { + Log.logError("[Quic] Failed to decrypt ONERTT packet (Bad AEAD tag; discarding packet): " + e); + Log.logError(e); + throw new IOException("Bad AEAD tag", e); + } + // check reserved bits after checking integrity, see RFC 9000, section 17.3.1 + if ((headers & 0x18) != 0) { + throw new QuicTransportException("Nonzero reserved bits in packet header", + QuicTLSEngine.KeySpace.ONE_RTT, 0, QuicTransportErrors.PROTOCOL_VIOLATION); + } + List frames = reader.parsePayloadSlice(payload); + assert !payload.hasRemaining() : "remaining bytes in payload: " + payload.remaining(); + + // Finally, get the size (in bytes) of new packet + var size = reader.bytesRead(); + + return new IncomingOneRttPacket(destinationID, packetNumber, frames, spin, keyPhase, size); + } + } + + private static final class IncomingInitialPacket + extends IncomingLongHeaderPacket implements InitialPacket { + + final int size; + final int length; + final int tokenLength; + final long packetNumber; + final byte[] token; + final List frames; + + IncomingInitialPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, int version, + int tokenLength, byte[] token, int length, + long packetNumber, List frames, int size) { + super(sourceId, destinationId, version); + this.size = size; + this.length = length; + this.tokenLength = tokenLength; + this.token = token; + this.packetNumber = packetNumber; + this.frames = List.copyOf(frames); + } + + @Override + public int tokenLength() { return tokenLength; } + + @Override + public byte[] token() { return token; } + + @Override + public int length() { return length; } + + @Override + public long packetNumber() { return packetNumber; } + + @Override + public int size() { return size; } + + @Override + public List frames() { return frames; } + + /** + * Decode a valid {@code ByteBuffer} into an {@link IncomingInitialPacket}. + * This method removes packet protection and decrypt the packet encoded into + * the provided byte buffer, then creates an {@code IncomingInitialPacket} + * with the decoded data. + * + * @param reader A {@code PacketReader} to decode the {@code ByteBuffer} that contains + * the bytes of this packet + * @param context the decoding context + * + * @return an {@code IncomingInitialPacket} with its contents set + * according to the packets fields + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + * @throws QuicTransportException if packet is correctly signed but malformed + */ + static IncomingInitialPacket decode(PacketReader reader, CodingContext context) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(%s)", reader); + } + + byte headers = reader.readHeaders(); // read headers + int version = reader.readVersion(); // read version + if (debug.on()) { + debug.log("IncomingInitialPacket.decode([protected]headers(%x), version(%d), %s)", + headers, version, reader); + } + + // Retrieve the destination and source connections IDs + var destinationID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(dcid(%d), %s)", + destinationID.length(), reader); + } + var sourceID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(scid(%d), %s)", + sourceID.length(), reader); + } + + // Get number of bytes needed to store the length of the token + var tokenLength = (int) reader.readTokenLength(); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(token-length(%d), %s)", + tokenLength, reader); + } + var token = reader.readToken(tokenLength); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(token(%d), %s)", + token == null ? 0 : token.length, reader); + } + + // Get length of packet number and payload + var packetLength = reader.readPacketLength(); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(packetLength(%d), %s)", + packetLength, reader); + } + assert packetLength == (int)packetLength; + if (packetLength > reader.remaining()) { + if (debug.on()) { + debug.log("IncomingInitialPacket rejected, invalid length(%d/%d), %s)", + packetLength, reader.remaining(), reader); + } + throw new BufferUnderflowException(); + } + + + // get the size (in bytes) of new packet + int size = reader.bytesRead() + (int)packetLength; + + if (!context.verifyToken(destinationID, token)) { + if (debug.on()) { + debug.log("IncomingInitialPacket rejected, invalid token(%s), %s)", + token == null ? "null" : HexFormat.of().formatHex(token), + reader); + } + return null; + } + + // Remove protection before reading packet number + reader.unprotectLong(packetLength); + + // re-read headers, now that protection is removed + headers = reader.headers(); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode([unprotected]headers(%x), %s)", + headers, reader); + } + + // Packet Number + int packetNumberLength = reader.packetNumberLength(); + var packetNumber = reader.readPacketNumber(packetNumberLength); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(" + + "packetNumberLength(%d), packetNumber(%d), %s)", + packetNumberLength, packetNumber, reader); + } + + // Calculate payload length and retrieve payload + int payloadLen = (int) (packetLength - packetNumberLength); + if (debug.on()) { + debug.log("IncomingInitialPacket.decode(payloadLen(%d), %s)", + payloadLen, reader); + } + ByteBuffer payload = null; + try { + payload = reader.decryptPayload(packetNumber, payloadLen, -1 /* key phase */); + } catch (AEADBadTagException e) { + Log.logError("[Quic] Failed to decrypt INITIAL packet (Bad AEAD tag; discarding packet): " + e); + Log.logError(e); + throw new IOException("Bad AEAD tag", e); + } + // check reserved bits after checking integrity, see RFC 9000, section 17.2 + if ((headers & 0xc) != 0) { + throw new QuicTransportException("Nonzero reserved bits in packet header", + QuicTLSEngine.KeySpace.INITIAL, 0, QuicTransportErrors.PROTOCOL_VIOLATION); + } + List frames = reader.parsePayloadSlice(payload); + assert !payload.hasRemaining() : "remaining bytes in payload: " + payload.remaining(); + + assert size == reader.bytesRead() : size - reader.bytesRead(); + + return new IncomingInitialPacket(sourceID, destinationID, + version, tokenLength, token, (int)packetLength, packetNumber, frames, size); + } + + } + + private static final class IncomingVersionNegotiationPacket + extends IncomingLongHeaderPacket + implements VersionNegotiationPacket { + + final int size; + final int[] versions; + + IncomingVersionNegotiationPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version, int[] versions, + int size) { + super(sourceId, destinationId, version); + this.size = size; + this.versions = Objects.requireNonNull(versions); + } + + @Override + public int size() { return size; } + + @Override + public List frames() { return List.of(); } + + @Override + public int payloadSize() { return versions.length << 2; } + + @Override + public int[] supportedVersions() { + return versions; + } + + /** + * Decode a valid {@code ByteBuffer} into an {@link IncomingVersionNegotiationPacket}. + * + * @param reader A {@code PacketReader} to decode the {@code ByteBuffer} that contains + * the bytes of this packet + * @param context the decoding context + * + * @return an {@code IncomingVersionNegotiationPacket} with its contents set + * according to the packets fields + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + */ + static IncomingVersionNegotiationPacket decode(PacketReader reader, CodingContext context) + throws IOException { + + if (debug.on()) { + debug.log("IncomingVersionNegotiationPacket.decode(%s)", reader); + } + + byte headers = reader.readHeaders(); // read headers + int version = reader.readVersion(); // read version + if (debug.on()) { + debug.log("IncomingVersionNegotiationPacket.decode(headers(%x), version(%d), %s)", + headers, version, reader); + } + // The long header bit should be set. We should ignore the other 7 bits + assert QuicPacketDecoder.headersType(headers) == HeadersType.LONG || (headers & 0x80) == 0x80; + + // Retrieve the destination and source connections IDs + var destinationID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingVersionNegotiationPacket.decode(dcid(%d), %s)", + destinationID.length(), reader); + } + var sourceID = reader.readLongConnectionId(); + if (debug.on()) { + debug.log("IncomingVersionNegotiationPacket.decode(scid(%d), %s)", + sourceID.length(), reader); + } + + // Calculate payload length and retrieve payload + final int payloadLen = reader.remaining(); + final int versionsCount = payloadLen >> 2; + if (debug.on()) { + debug.log("IncomingVersionNegotiationPacket.decode(payloadLen(%d), %s)", + payloadLen, reader); + } + int[] versions = reader.readSupportedVersions(); + + // Finally, get the size (in bytes) of new packet + var size = reader.bytesRead(); + assert !reader.hasRemaining() : "%s superfluous bytes in buffer" + .formatted(reader.remaining()); + + // sanity checks: + var msg = "Bad version negotiation packet"; + if (payloadLen != versionsCount << 2) { + throw new IOException("%s: %s bytes after %s versions" + .formatted(msg, payloadLen % 4, versionsCount)); + } + if (versionsCount == 0) { + throw new IOException("%s: no supported versions in packet" + .formatted(msg)); + } + + return new IncomingVersionNegotiationPacket(sourceID, destinationID, + version, versions, size); + } + } + + /** + * Decode the contents of the given {@code ByteBuffer} and, depending on the + * {@link PacketType}, return a {@link QuicPacket} with the corresponding type. + * This method removes packet protection and decrypt the packet encoded into + * the provided byte buffer as appropriate. + * + *

If successful, an {@code IncomingQuicPacket} instance is returned. + * The position of the buffer is moved to the first byte following the last + * decoded byte. The buffer limit is unchanged. + * + *

Otherwise, an exception is thrown. The position of the buffer is unspecified, + * but is usually set at the place where the error occurred. + * + * @apiNote If successful, and the limit was not reached, this method should be + * called again to decode the next packet contained in the buffer. Otherwise, if + * an exception occurs, the remaining bytes in the buffer should be dropped, since + * the position of the next packet in the buffer cannot be determined with + * certainty. + * + * @param buffer the buffer with the bytes to be decoded + * @param context the decoding context + * + * @throws IOException if decoding fails for any reason + * @throws BufferUnderflowException if buffer does not have enough bytes + * @throws QuicTransportException if packet is correctly signed but malformed + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9001 + * RFC 9001: Using TLS to Secure QUIC + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ + public IncomingQuicPacket decode(ByteBuffer buffer, CodingContext context) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + Objects.requireNonNull(buffer); + + assert buffer.order() == ByteOrder.BIG_ENDIAN; + PacketType type = peekPacketType(buffer); + PacketReader packetReader = new PacketReader(buffer, context, type); + + QuicTLSEngine.KeySpace keySpace = type.keySpace().orElse(null); + if (keySpace != null && !context.getTLSEngine().keysAvailable(keySpace)) { + if (debug.on()) { + debug.log("QuicPacketDecoder.decode(%s): no keys, skipping", packetReader); + } + return null; + } + + return switch (type) { + case RETRY -> IncomingRetryPacket.decode(packetReader, context); + case ONERTT -> IncomingOneRttPacket.decode(packetReader, context); + case ZERORTT -> IncomingZeroRttPacket.decode(packetReader, context); + case HANDSHAKE -> IncomingHandshakePacket.decode(packetReader, context); + case INITIAL -> IncomingInitialPacket.decode(packetReader, context); + case VERSIONS -> IncomingVersionNegotiationPacket.decode(packetReader, context); + case NONE -> throw new IOException("Unknown type: " + type); // if junk received + default -> throw new IOException("Not implemented: " + type); // if has type but not recognised + }; + } + + private static QuicConnectionId decodeConnectionID(ByteBuffer buffer) { + if (!buffer.hasRemaining()) + throw new BufferUnderflowException(); + + int len = buffer.get() & 0xFF; + if (len > buffer.remaining()) { + throw new BufferUnderflowException(); + } + byte[] destinationConnectionID = new byte[len]; + + // Save buffer position ahead of time to check after read + int pos = buffer.position(); + buffer.get(destinationConnectionID); + // Ensure all bytes have been read correctly + assert pos + len == buffer.position(); + + return new PeerConnectionId(destinationConnectionID); + } + + /** + * Peek at the size of the first packet present in the buffer. + * The position of the buffer must be at the first byte of the + * first packet. This method doesn't advance the buffer position. + * @param buffer A byte buffer containing quic packets + * @return the size of the first packet present in the buffer. + */ + public int peekPacketSize(ByteBuffer buffer) { + int pos = buffer.position(); + int limit = buffer.limit(); + int available = limit - pos; + assert available >= 0 : available; + if (available <= 0) return available; + PacketType type = peekPacketType(buffer); + return switch (type) { + case HANDSHAKE, INITIAL, ZERORTT -> { + assert peekVersion(buffer, pos) == quicVersion.versionNumber(); + int end = peekPacketEnd(type, buffer); + assert end <= limit; + yield end - pos; + } + // ONERTT, RETRY, VERSIONS, NONE: + default -> available; + }; + } + + /** + * Reads the Quic V1 packet type from the given byte. + * + * @param headerByte the first byte of a quic packet + * @return the packet type encoded in the given byte. + */ + private PacketType packetType(byte headerByte) { + int htype = headerByte & 0xC0; + int ptype = headerByte & 0xF0; + return switch (htype) { + case 0xC0 -> switch (quicVersion) { + case QUIC_V1 -> switch (ptype) { + case 0xC0 -> PacketType.INITIAL; + case 0xD0 -> PacketType.ZERORTT; + case 0xE0 -> PacketType.HANDSHAKE; + case 0xF0 -> PacketType.RETRY; + default -> PacketType.NONE; + }; + case QUIC_V2 -> switch (ptype) { + case 0xD0 -> PacketType.INITIAL; + case 0xE0 -> PacketType.ZERORTT; + case 0xF0 -> PacketType.HANDSHAKE; + case 0xC0 -> PacketType.RETRY; + default -> PacketType.NONE; + }; + }; + case 0x40 -> PacketType.ONERTT; // may be a stateless reset too + default -> PacketType.NONE; + }; + } + + public PacketType peekPacketType(ByteBuffer buffer) { + int offset = buffer.position(); + return peekPacketType(buffer, offset); + } + + public PacketType peekPacketType(ByteBuffer buffer, int offset) { + if (offset < 0 || offset >= buffer.limit()) return PacketType.NONE; + var headers = buffer.get(offset); + var headersType = headersType(headers); + if (headersType == QuicPacket.HeadersType.LONG) { + if (isVersionNegotiation(buffer, offset)) { + return PacketType.VERSIONS; + } + var version = peekVersion(buffer, offset); + if (version != quicVersion.versionNumber()) { + return PacketType.NONE; + } + } + return packetType(headers); + } + + /** + * Returns the position just after the first packet present in the buffer. + * @param type the first packet type. Must be INITIAL, HANDSHAKE, or ZERORTT. + * @param buffer the byte buffer containing the packet + * @return the position just after the first packet present in the buffer. + */ + private int peekPacketEnd(PacketType type, ByteBuffer buffer) { + // Store initial position to calculate size of packet decoded + int initialPosition = buffer.position(); + int limit = buffer.limit(); + assert buffer.order() == ByteOrder.BIG_ENDIAN; + assert type == PacketType.HANDSHAKE + || type == PacketType.INITIAL + || type == PacketType.ZERORTT : type; + // This case should have been handled by the caller + assert buffer.getInt(initialPosition + 1) != 0 : "version is 0"; + + int pos = initialPosition; // header bits + pos = pos + 4; // version + pos = pos + 1; // dcid length + if (pos <= 0 || pos >= limit) return limit; + int dcidlen = buffer.get(pos) & 0xFF; + pos = pos + dcidlen + 1; // scid length + if (pos <= 0 || pos >= limit) return limit; + int scidlen = buffer.get(pos) & 0xFF; + pos = pos + scidlen + 1; // token length or packet length + if (pos <= 0 || pos >= limit) return limit; + + if (type == PacketType.INITIAL) { + int tksize = VariableLengthEncoder.peekEncodedValueSize(buffer, pos); + if (tksize <= 0 || tksize > 8) return limit; + if (limit - tksize < pos) return limit; + long tklen = VariableLengthEncoder.peekEncodedValue(buffer, pos); + if (tklen < 0 || tklen > limit - pos) return limit; + pos = pos + tksize + (int)tklen; // packet length + if (pos <= 0 || pos >= limit) return limit; + } + + int lensize = VariableLengthEncoder.peekEncodedValueSize(buffer, pos); + if (lensize <= 0 || lensize > 8) return limit; + long len = VariableLengthEncoder.peekEncodedValue(buffer, pos); + if (len < 0 || len > limit - pos) return limit; + pos = pos + lensize + (int)len; // end of packet + if (pos <= 0 || pos >= limit) return limit; + return pos; + } + + /** + * Find the length of the next packet in the buffer, and return + * the next packet bytes as a slice of the original packet. + * Advances the original buffer position to after the returned + * packet. + * @param buffer a buffer containing coalesced packets + * @param offset the offset at which the next packet starts + * @return the next packet. + */ + public ByteBuffer nextPacketSlice(ByteBuffer buffer, int offset) { + assert offset >= 0; + assert offset <= buffer.limit(); + int pos = buffer.position(); + int limit = buffer.limit(); + buffer.position(offset); + ByteBuffer next = null; + try { + int size = peekPacketSize(buffer); + if (debug.on()) { + debug.log("next packet bytes from %d (%d/%d)", + offset, size, buffer.remaining()); + } + next = buffer.slice(offset, size); + buffer.position(offset + size); + } catch (Throwable tt) { + if (debug.on()) { + debug.log("failed to peek packet size: " + tt, tt); + debug.log("dropping all remaining bytes (%d)", limit - pos); + } + buffer.position(limit); + next = buffer; + } + return next; + } + + /** + * Advance the bytebuffer position to the end of the packet + * @param buffer A byte buffer containing quic packets + * @param offset The offset at which the packet starts + */ + public void skipPacket(ByteBuffer buffer, int offset) { + assert offset >= 0; + assert offset <= buffer.limit(); + int pos = buffer.position(); + int limit = buffer.limit(); + buffer.position(offset); + try { + int size = peekPacketSize(buffer); + if (debug.on()) { + debug.log("dropping packet bytes from %d (%d/%d)", + offset, size, buffer.remaining()); + } + buffer.position(offset + size); + } catch (Throwable tt) { + if (debug.on()) { + debug.log("failed to peek packet size: " + tt, tt); + debug.log("dropping all remaining bytes (%d)", limit - pos); + } + buffer.position(limit); + } + } + + /** + * Returns a decoder for the given Quic version. + * Returns {@code null} if no decoder for that version exists. + * + * @param quicVersion the Quic protocol version number + * @return a decoder for the given Quic version or {@code null} + */ + public static QuicPacketDecoder of(QuicVersion quicVersion) { + return switch (quicVersion) { + case QUIC_V1 -> Decoders.QUIC_V1_DECODER; + case QUIC_V2 -> Decoders.QUIC_V2_DECODER; + default -> throw new IllegalArgumentException("No packet decoder for Quic version " + quicVersion); + }; + } + + /** + * Returns a {@code QuicPacketDecoder} to decode the packet + * starting at the specified offset in the buffer. + * This method will attempt to read the quic version in the + * packet in order to return the proper decoder. + * If the version is 0, then the decoder for Quic Version 1 + * is returned. + * + * @param buffer A buffer containing a Quic packet + * @param offset The offset at which the packet starts + * @return A {@code QuicPacketDecoder} instance to decode the + * packet starting at the given offset. + */ + public static QuicPacketDecoder of(ByteBuffer buffer, int offset) { + var version = peekVersion(buffer, offset); + final QuicVersion quicVersion = version == 0 ? QuicVersion.QUIC_V1 + : QuicVersion.of(version).orElse(null); + if (quicVersion == null) { + return null; + } + return of(quicVersion); + } + + /** + * A {@code PacketReader} to read a Quic packet. + * A {@code PacketReader} may have version specific code, and therefore + * has an implicit pointer to a {@code QuicPacketDecoder} instance. + *

+ * A {@code PacketReader} offers high level helper methods to read + * data (such as Connection IDs or Packet Numbers) from a Quic packet. + * It has however no or little knowledge of the actual packet structure. + * It is driven by the {@code decode} method of the appropriate + * {@code IncomingQuicPacket} type. + *

+ * A {@code PacketReader} is stateful: it encapsulates a {@code ByteBuffer} + * (or possibly a list of byte buffers - as a future enhancement) and + * advances the position on the buffer it is reading. + * + */ + class PacketReader { + private static final int PACKET_NUMBER_MASK = 0x03; + final ByteBuffer buffer; + final int offset; + final int initialLimit; + final CodingContext context; + final PacketType packetType; + + PacketReader(ByteBuffer buffer, CodingContext context) { + this(buffer, context, peekPacketType(buffer)); + } + + PacketReader(ByteBuffer buffer, CodingContext context, PacketType packetType) { + assert buffer.order() == ByteOrder.BIG_ENDIAN; + int pos = buffer.position(); + int limit = buffer.limit(); + this.buffer = buffer; + this.offset = pos; + this.initialLimit = limit; + this.context = context; + this.packetType = packetType; + } + + public int offset() { + return offset; + } + + public int position() { + return buffer.position(); + } + + public int remaining() { + return buffer.remaining(); + } + + public boolean hasRemaining() { + return buffer.hasRemaining(); + } + + public int bytesRead() { + return position() - offset; + } + + public void reset() { + buffer.position(offset); + buffer.limit(initialLimit); + } + + public byte headers() { + return buffer.get(offset); + } + + public void headers(byte headers) { + buffer.put(offset, headers); + } + + public PacketType packetType() { + return packetType; + } + + public int packetNumberLength() { + return (headers() & PACKET_NUMBER_MASK) + 1; + } + + public byte readHeaders() { + return buffer.get(); + } + + public int readVersion() { + return buffer.getInt(); + } + + public int[] readSupportedVersions() { + // Calculate payload length and retrieve payload + final int payloadLen = buffer.remaining(); + final int versionsCount = payloadLen >> 2; + + int[] versions = new int[versionsCount]; + for (int i=0 ; i= 0 && packetLength <= VariableLengthEncoder.MAX_ENCODED_INTEGER + : packetLength; + if (packetLength > remaining()) { + throw new BufferUnderflowException(); + } + return packetLength; + } + + public long readTokenLength() { + return readVariableLength(); + } + + public byte[] readToken(int tokenLength) { + // Check to ensure that tokenLength is within valid range + if (tokenLength < 0 || tokenLength > buffer.remaining()) { + throw new BufferUnderflowException(); + } + byte[] token = tokenLength > 0 ? new byte[tokenLength] : null; + if (tokenLength > 0) { + buffer.get(token); + } + return token; + } + + public long readVariableLength() { + return VariableLengthEncoder.decode(buffer); + } + + public void maskPacketNumber(int packetNumberLength, ByteBuffer mask) { + int pos = buffer.position(); + for (int i = 0; i < packetNumberLength; i++) { + buffer.put(pos + i, (byte)(buffer.get(pos + i) ^ mask.get())); + } + } + + public long readPacketNumber(int packetNumberLength) { + var packetNumberSpace = PacketNumberSpace.of(packetType); + var largestProcessedPN = context.largestProcessedPN(packetNumberSpace); + return QuicPacketNumbers.decodePacketNumber(largestProcessedPN, buffer, packetNumberLength); + } + + public long readPacketNumber() { + return readPacketNumber(packetNumberLength()); + } + + private ByteBuffer peekPayloadSlice(int relativeOffset, int length) { + int payloadStart = buffer.position() + relativeOffset; + return buffer.slice(payloadStart, length); + } + + private ByteBuffer decryptPayload(long packetNumber, int payloadLen, int keyPhase) + throws AEADBadTagException, QuicKeyUnavailableException, QuicTransportException { + // Calculate payload length and retrieve payload + ByteBuffer output = buffer.slice(); + // output's position is on the first byte of encrypted data + output.mark(); + int payloadStart = buffer.position(); + buffer.position(offset); + buffer.limit(payloadStart + payloadLen); + // buffer's position and limit are set to the boundaries of the encrypted packet + try { + context.getTLSEngine().decryptPacket(packetType.keySpace().get(), packetNumber, keyPhase, + buffer, payloadStart - offset, output); + } catch (ShortBufferException e) { + throw new QuicTransportException(e.toString(), null, 0, + QuicTransportErrors.INTERNAL_ERROR); + } + // buffer's position and limit are both at end of the packet + output.limit(output.position()); + output.reset(); + // output's position and limit are set to the boundaries of decrypted frame data + buffer.limit(initialLimit); + return output; + } + + public List parsePayloadSlice(ByteBuffer payload) + throws QuicTransportException { + if (!payload.hasRemaining()) { + throw new QuicTransportException("Packet with no frames", + packetType().keySpace().get(), 0, QuicTransportErrors.PROTOCOL_VIOLATION); + } + try { + List frames = new ArrayList<>(); + while (payload.hasRemaining()) { + int start = payload.position(); + frames.add(QuicFrame.decode(payload)); + int end = payload.position(); + assert start < end : "bytes remaining at offset %s: %s" + .formatted(start, payload.remaining()); + } + return frames; + } catch (RuntimeException e) { + throw new QuicTransportException(e.getMessage(), + packetType().keySpace().get(), 0, QuicTransportErrors.INTERNAL_ERROR); + } + } + + byte[] readRetryToken() { + var tokenLength = buffer.limit() - buffer.position() - 16; + assert tokenLength > 0; + byte[] retryToken = new byte[tokenLength]; + buffer.get(retryToken); + return retryToken; + } + + byte[] readRetryIntegrityTag() { + // The 16 last bytes in the datagram payload + assert remaining() == 16; + byte[] retryIntegrityTag = new byte[16]; + buffer.get(retryIntegrityTag); + return retryIntegrityTag; + } + + public void verifyRetry() throws AEADBadTagException, QuicTransportException { + // assume the buffer position and limit are set to packet boundaries + QuicTLSEngine tlsEngine = context.getTLSEngine(); + tlsEngine.verifyRetryPacket(quicVersion, + context.originalServerConnId().asReadOnlyBuffer(), buffer.asReadOnlyBuffer()); + } + + public QuicConnectionId readLongConnectionId() { + return decodeConnectionID(buffer); + } + + public QuicConnectionId readShortConnectionId() { + if (!buffer.hasRemaining()) + throw new BufferUnderflowException(); + + // Retrieve connection ID length from endpoint via context + int len = context.connectionIdLength(); + if (len > buffer.remaining()) { + throw new BufferUnderflowException(); + } + byte[] destinationConnectionID = new byte[len]; + + // Save buffer position ahead of time to check after read + int pos = buffer.position(); + buffer.get(destinationConnectionID); + // Ensure all bytes have been read correctly + assert pos + len == buffer.position(); + + return new PeerConnectionId(destinationConnectionID); + } + + @Override + public String toString() { + return "PacketReader(offset=%s, pos=%s, remaining=%s)" + .formatted(offset, position(), remaining()); + } + + public void unprotectLong(long packetLength) + throws QuicKeyUnavailableException, QuicTransportException { + unprotect(packetLength, (byte) 0x0f); + } + + public void unprotectShort() + throws QuicKeyUnavailableException, QuicTransportException { + unprotect(buffer.remaining(), (byte) 0x1f); + } + + private void unprotect(long packetLength, byte headerMask) + throws QuicKeyUnavailableException, QuicTransportException { + QuicTLSEngine tlsEngine = context.getTLSEngine(); + int sampleSize = tlsEngine.getHeaderProtectionSampleSize(packetType.keySpace().get()); + if (packetLength > buffer.remaining() || packetLength < sampleSize + 4) { + throw new BufferUnderflowException(); + } + ByteBuffer sample = peekPayloadSlice(4, sampleSize); + ByteBuffer encryptedSample = tlsEngine.computeHeaderProtectionMask(packetType.keySpace().get(), true, sample); + byte headers = headers(); + headers ^= (byte) (encryptedSample.get() & headerMask); + headers(headers); + maskPacketNumber(packetNumberLength(), encryptedSample); + } + } + + + private static final class Decoders { + static final QuicPacketDecoder QUIC_V1_DECODER = new QuicPacketDecoder(QuicVersion.QUIC_V1); + static final QuicPacketDecoder QUIC_V2_DECODER = new QuicPacketDecoder(QuicVersion.QUIC_V2); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketEncoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketEncoder.java new file mode 100644 index 00000000000..890c1a63a35 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketEncoder.java @@ -0,0 +1,1746 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.nio.BufferOverflowException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.function.IntFunction; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.internal.net.http.quic.frames.PaddingFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.CodingContext; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTLSEngine.KeySpace; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +import javax.crypto.ShortBufferException; + +import static jdk.internal.net.http.quic.packets.QuicPacketNumbers.computePacketNumberLength; +import static jdk.internal.net.http.quic.packets.QuicPacketNumbers.encodePacketNumber; +import static jdk.internal.net.http.quic.QuicConnectionId.MAX_CONNECTION_ID_LENGTH; + +/** + * A {@code QuicPacketEncoder} encapsulates the logic to encode a + * quic packet. A {@code QuicPacketEncoder} is typically tied to + * a particular version of the QUIC protocol. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9001 + * RFC 9001: Using TLS to Secure QUIC + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public class QuicPacketEncoder { + + private static final Logger debug = Utils.getDebugLogger(() -> "QuicPacketEncoder"); + + private final QuicVersion quicVersion; + private QuicPacketEncoder(final QuicVersion quicVersion) { + this.quicVersion = quicVersion; + } + + /** + * Computes the packet's header byte, which also encodes + * the packetNumber length. + * + * @param packetTypeTag quic-dependent packet type encoding + * @param pnsize the number of bytes needed to encode the packet number + * @return the packet's header byte + */ + private static byte headers(byte packetTypeTag, int pnsize) { + int pnprefix = pnsize - 1; + assert pnprefix >= 0; + assert pnprefix <= 3; + return (byte)(packetTypeTag | pnprefix); + } + + /** + * Returns the headers tag for the given packet type. + * Returns 0 if the packet type is NONE or unknown. + *

+ * For version negotiations packet, this method returns 0x80. + * The other 7 bits must be ignored by a client. + * When emitting a version negotiation packet the server should + * also set the fix bit (0x40) to 1. + * What distinguishes a version negotiation packet from other + * long header packet types is not the packet type found in the + * header's byte, but the fact that a. it is a long header and + * b. the version number in the packet (the 4 bytes following + * the header) is 0. + * @param packetType the packet type + * @return the headers tag for the given packet type. + * + * @see + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @see + * RFC 9369: QUIC Version 2 + */ + private byte packetHeadersTag(PacketType packetType) { + return (byte) switch (quicVersion) { + case QUIC_V1 -> switch (packetType) { + case ONERTT -> 0x40; + case INITIAL -> 0xC0; + case ZERORTT -> 0xD0; + case HANDSHAKE -> 0xE0; + case RETRY -> 0xF0; + case VERSIONS -> 0x80; // remaining bits are ignored + case NONE -> 0x00; + }; + case QUIC_V2 -> switch (packetType) { + case ONERTT -> 0x40; + case INITIAL -> 0xD0; + case ZERORTT -> 0xE0; + case HANDSHAKE -> 0xF0; + case RETRY -> 0xC0; + case VERSIONS -> 0x80; // remaining bits are ignored + case NONE -> 0x00; + }; + }; + } + + /** + * Encode the OneRttPacket into the provided buffer. + * This method encrypts the packet into the provided byte buffer as appropriate, + * adding packet protection as appropriate. + * + * @param packet + * @param buffer A buffer to encode the packet into + * @param context + * @throws BufferOverflowException if the buffer is not large enough + */ + private void encodePacket(OutgoingOneRttPacket packet, + ByteBuffer buffer, + CodingContext context) + throws QuicKeyUnavailableException, QuicTransportException { + QuicConnectionId destination = packet.destinationId(); + + if (debug.on()) { + debug.log("OneRttPacket::encodePacket(ByteBuffer(%d,%d)," + + " dst=%s, packet=%d, encodedPacket=%s," + + " payload=QuicFrames(frames: %s, bytes: %d)," + + " size=%d", + buffer.position(), buffer.limit(), destination, + packet.packetNumber, Arrays.toString(packet.encodedPacketNumber), + packet.frames, packet.payloadSize, packet.size); + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + + int encodedLength = packet.encodedPacketNumber.length; + assert encodedLength >= 1 && encodedLength <= 4 : encodedLength; + int pnprefix = encodedLength - 1; + + byte headers = headers(packetHeadersTag(packet.packetType()), + packet.encodedPacketNumber.length); + assert (headers & 0x03) == pnprefix : "incorrect packet number prefix in headers: " + headers; + + final PacketWriter writer = new PacketWriter(buffer, context, PacketType.ONERTT); + writer.writeHeaders(headers); + writer.writeShortConnectionId(destination); + int packetNumberStart = writer.position(); + writer.writeEncodedPacketNumber(packet.encodedPacketNumber); + int payloadStart = writer.position(); + writer.writePayload(packet.frames); + writer.encryptPayload(packet.packetNumber, payloadStart); + assert writer.bytesWritten() == packet.size : writer.bytesWritten() - packet.size; + writer.protectHeaderShort(packetNumberStart, packet.encodedPacketNumber.length); + } + + /** + * Encode the ZeroRttPacket into the provided buffer. + * This method encrypts the packet into the provided byte buffer as appropriate, + * adding packet protection as appropriate. + * + * @param packet + * @param buffer A buffer to encode the packet into. + * @param context + * @throws BufferOverflowException if the buffer is not large enough + */ + private void encodePacket(OutgoingZeroRttPacket packet, + ByteBuffer buffer, + CodingContext context) + throws QuicKeyUnavailableException, QuicTransportException { + int version = packet.version(); + if (quicVersion.versionNumber() != version) { + throw new IllegalArgumentException("Encoder version %s does not match packet version %s" + .formatted(quicVersion, version)); + } + QuicConnectionId destination = packet.destinationId(); + QuicConnectionId source = packet.sourceId(); + if (packet.size > buffer.remaining()) { + throw new BufferOverflowException(); + } + + if (debug.on()) { + debug.log("ZeroRttPacket::encodePacket(ByteBuffer(%d,%d)," + + " src=%s, dst=%s, version=%d, packet=%d, " + + "encodedPacket=%s, payload=QuicFrame(frames: %s, bytes: %d), size=%d", + buffer.position(), buffer.limit(), source, destination, + version, packet.packetNumber, Arrays.toString(packet.encodedPacketNumber), + packet.frames, packet.payloadSize, packet.size); + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + + int encodedLength = packet.encodedPacketNumber.length; + assert encodedLength >= 1 && encodedLength <= 4 : encodedLength; + int pnprefix = encodedLength - 1; + + byte headers = headers(packetHeadersTag(packet.packetType()), + packet.encodedPacketNumber.length); + assert (headers & 0x03) == pnprefix : headers; + + PacketWriter writer = new PacketWriter(buffer, context, PacketType.ZERORTT); + writer.writeHeaders(headers); + writer.writeVersion(version); + writer.writeLongConnectionId(destination); + writer.writeLongConnectionId(source); + writer.writePacketLength(packet.length); + int packetNumberStart = writer.position(); + writer.writeEncodedPacketNumber(packet.encodedPacketNumber); + int payloadStart = writer.position(); + writer.writePayload(packet.frames); + writer.encryptPayload(packet.packetNumber, payloadStart); + assert writer.bytesWritten() == packet.size : writer.bytesWritten() - packet.size; + writer.protectHeaderLong(packetNumberStart, packet.encodedPacketNumber.length); + } + + /** + * Encode the VersionNegotiationPacket into the provided + * buffer. + * + * @param packet + * @param buffer A buffer to encode the packet into. + * @throws BufferOverflowException if the buffer is not large enough + */ + private static void encodePacket(OutgoingVersionNegotiationPacket packet, + ByteBuffer buffer) { + QuicConnectionId destination = packet.destinationId(); + QuicConnectionId source = packet.sourceId(); + + if (debug.on()) { + debug.log("VersionNegotiationPacket::encodePacket(ByteBuffer(%d,%d)," + + " src=%s, dst=%s, versions=%s, size=%d", + buffer.position(), buffer.limit(), source, destination, + Arrays.toString(packet.versions), packet.size); + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + + int offset = buffer.position(); + int limit = buffer.limit(); + assert buffer.capacity() >= packet.size; + assert limit - offset >= packet.size; + + int typeTag = 0x80; + int rand = Encoders.RANDOM.nextInt() & 0x7F; + int headers = typeTag | rand; + if (debug.on()) { + debug.log("VersionNegotiationPacket::encodePacket:" + + " type: 0x%02x, unused: 0x%02x, headers: 0x%02x", + typeTag, rand & ~0x80, headers); + } + assert (headers & typeTag) == typeTag : headers; + assert (headers ^ typeTag) == rand : headers; + + // headers(1 byte), version(4 bytes) + buffer.put((byte)headers); // 1 + putInt32(buffer, 0); // 4 + + // DCID: 1 byte for length, + destination id bytes + var dcidlen = destination.length(); + assert dcidlen <= MAX_CONNECTION_ID_LENGTH && dcidlen >= 0 : dcidlen; + buffer.put((byte)dcidlen); // 1 + buffer.put(destination.asReadOnlyBuffer()); + assert buffer.position() == offset + 6 + dcidlen : buffer.position(); + + // SCID: 1 byte for length, + source id bytes + var scidlen = source.length(); + assert scidlen <= MAX_CONNECTION_ID_LENGTH && scidlen >= 0 : scidlen; + buffer.put((byte) scidlen); + buffer.put(source.asReadOnlyBuffer()); + assert buffer.position() == offset + 7 + dcidlen + scidlen : buffer.position(); + + // Put payload (= supported versions) + int versionsStart = buffer.position(); + for (int i = 0; i < packet.versions.length; i++) { + putInt32(buffer, packet.versions[i]); + } + int versionsEnd = buffer.position(); + if (debug.on()) { + debug.log("VersionNegotiationPacket::encodePacket:" + + " encoded %d bytes", offset - versionsEnd); + } + + assert versionsEnd - offset == packet.size; + assert versionsEnd - versionsStart == packet.versions.length << 2; + } + + /** + * Encode the HandshakePacket into the provided buffer. + * This method encrypts the packet into the provided byte buffer as appropriate, + * adding packet protection as appropriate. + * + * @param packet + * @param buffer A buffer to encode the packet into. + * @param context + * @throws BufferOverflowException if the buffer is not large enough + */ + private void encodePacket(OutgoingHandshakePacket packet, + ByteBuffer buffer, + CodingContext context) + throws QuicKeyUnavailableException, QuicTransportException { + int version = packet.version(); + if (quicVersion.versionNumber() != version) { + throw new IllegalArgumentException("Encoder version %s does not match packet version %s" + .formatted(quicVersion, version)); + } + QuicConnectionId destination = packet.destinationId(); + QuicConnectionId source = packet.sourceId(); + if (packet.size > buffer.remaining()) { + throw new BufferOverflowException(); + } + + if (debug.on()) { + debug.log("HandshakePacket::encodePacket(ByteBuffer(%d,%d)," + + " src=%s, dst=%s, version=%d, packet=%d, " + + "encodedPacket=%s, payload=QuicFrame(frames: %s, bytes: %d)," + + " size=%d", + buffer.position(), buffer.limit(), source, destination, + version, packet.packetNumber, Arrays.toString(packet.encodedPacketNumber), + packet.frames, packet.payloadSize, packet.size); + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + + int encodedLength = packet.encodedPacketNumber.length; + assert encodedLength >= 1 && encodedLength <= 4 : encodedLength; + int pnprefix = encodedLength - 1; + + byte headers = headers(packetHeadersTag(packet.packetType()), + packet.encodedPacketNumber.length); + assert (headers & 0x03) == pnprefix : headers; + + PacketWriter writer = new PacketWriter(buffer, context, PacketType.HANDSHAKE); + writer.writeHeaders(headers); + writer.writeVersion(version); + writer.writeLongConnectionId(destination); + writer.writeLongConnectionId(source); + writer.writePacketLength(packet.length); + int packetNumberStart = writer.position(); + writer.writeEncodedPacketNumber(packet.encodedPacketNumber); + int payloadStart = writer.position(); + writer.writePayload(packet.frames); + writer.encryptPayload(packet.packetNumber, payloadStart); + assert writer.bytesWritten() == packet.size : writer.bytesWritten() - packet.size; + writer.protectHeaderLong(packetNumberStart, packet.encodedPacketNumber.length); + } + + /** + * Encode the InitialPacket into the provided buffer. + * This method encrypts the packet into the provided byte buffer as appropriate, + * adding packet protection as appropriate. + * + * @param packet + * @param buffer A buffer to encode the packet into. + * @param context coding context + * @throws BufferOverflowException if the buffer is not large enough + */ + private void encodePacket(OutgoingInitialPacket packet, + ByteBuffer buffer, + CodingContext context) + throws QuicKeyUnavailableException, QuicTransportException { + int version = packet.version(); + if (quicVersion.versionNumber() != version) { + throw new IllegalArgumentException("Encoder version %s does not match packet version %s" + .formatted(quicVersion, version)); + } + QuicConnectionId destination = packet.destinationId(); + QuicConnectionId source = packet.sourceId(); + if (packet.size > buffer.remaining()) { + throw new BufferOverflowException(); + } + + if (debug.on()) { + debug.log("InitialPacket::encodePacket(ByteBuffer(%d,%d)," + + " src=%s, dst=%s, version=%d, packet=%d, " + + "encodedPacket=%s, token=%s, " + + "payload=QuicFrame(frames: %s, bytes: %d), size=%d", + buffer.position(), buffer.limit(), source, destination, + version, packet.packetNumber, Arrays.toString(packet.encodedPacketNumber), + packet.token == null ? null : "byte[%s]".formatted(packet.token.length), + packet.frames, packet.payloadSize, packet.size); + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + + int encodedLength = packet.encodedPacketNumber.length; + assert encodedLength >= 1 && encodedLength <= 4 : encodedLength; + int pnprefix = encodedLength - 1; + + byte headers = headers(packetHeadersTag(packet.packetType()), + packet.encodedPacketNumber.length); + assert (headers & 0x03) == pnprefix : headers; + + PacketWriter writer = new PacketWriter(buffer, context, PacketType.INITIAL); + writer.writeHeaders(headers); + writer.writeVersion(version); + writer.writeLongConnectionId(destination); + writer.writeLongConnectionId(source); + writer.writeToken(packet.token); + writer.writePacketLength(packet.length); + int packetNumberStart = writer.position(); + writer.writeEncodedPacketNumber(packet.encodedPacketNumber); + int payloadStart = writer.position(); + writer.writePayload(packet.frames); + writer.encryptPayload(packet.packetNumber, payloadStart); + assert writer.bytesWritten() == packet.size : writer.bytesWritten() - packet.size; + writer.protectHeaderLong(packetNumberStart, packet.encodedPacketNumber.length); + } + + /** + * Encode the RetryPacket into the provided buffer. + * + * @param packet + * @param buffer A buffer to encode the packet into. + * @param context + * @throws BufferOverflowException if the buffer is not large enough + */ + private void encodePacket(OutgoingRetryPacket packet, + ByteBuffer buffer, + CodingContext context) throws QuicTransportException { + int version = packet.version(); + if (quicVersion.versionNumber() != version) { + throw new IllegalArgumentException("Encoder version %s does not match packet version %s" + .formatted(quicVersion, version)); + } + QuicConnectionId destination = packet.destinationId(); + QuicConnectionId source = packet.sourceId(); + + if (debug.on()) { + debug.log("RetryPacket::encodePacket(ByteBuffer(%d,%d)," + + " src=%s, dst=%s, version=%d, retryToken=%d," + + " size=%d", + buffer.position(), buffer.limit(), source, destination, + version, packet.retryToken.length, packet.size); + } + assert buffer.order() == ByteOrder.BIG_ENDIAN; + assert packet.retryToken.length > 0; + assert buffer.remaining() >= packet.size; + + PacketWriter writer = new PacketWriter(buffer, context, PacketType.RETRY); + + byte headers = packetHeadersTag(packet.packetType()); + headers |= (byte)Encoders.RANDOM.nextInt(0x10); + writer.writeHeaders(headers); + writer.writeVersion(version); + writer.writeLongConnectionId(destination); + writer.writeLongConnectionId(source); + writer.writeRetryToken(packet.retryToken); + assert writer.remaining() >= 16; // 128 bits + writer.signRetry(version); + + assert writer.bytesWritten() == packet.size : writer.bytesWritten() - packet.size; + } + + public abstract static class OutgoingQuicPacket implements QuicPacket { + private final QuicConnectionId destinationId; + + protected OutgoingQuicPacket(QuicConnectionId destinationId) { + this.destinationId = destinationId; + } + + @Override + public final QuicConnectionId destinationId() { return destinationId; } + + @Override + public String toString() { + + return this.getClass().getSimpleName() + "[pn=" + this.packetNumber() + + ", frames=" + frames() + "]"; + } + } + + private abstract static class OutgoingShortHeaderPacket + extends OutgoingQuicPacket implements ShortHeaderPacket { + + OutgoingShortHeaderPacket(QuicConnectionId destinationId) { + super(destinationId); + } + } + + private abstract static class OutgoingLongHeaderPacket + extends OutgoingQuicPacket implements LongHeaderPacket { + + private final QuicConnectionId sourceId; + private final int version; + OutgoingLongHeaderPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version) { + super(destinationId); + this.sourceId = sourceId; + this.version = version; + } + + @Override + public final QuicConnectionId sourceId() { return sourceId; } + + @Override + public final int version() { return version; } + + } + + private static final class OutgoingRetryPacket + extends OutgoingLongHeaderPacket implements RetryPacket { + + final int size; + final byte[] retryToken; + + OutgoingRetryPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version, + byte[] retryToken) { + super(sourceId, destinationId, version); + this.retryToken = retryToken; + this.size = computeSize(retryToken.length); + } + + /** + * Compute the total packet size, starting at the headers byte and + * ending at the end of the retry integrity tag. This is used to allocate a + * ByteBuffer in which to encode the packet. + * + * @return the total packet size. + */ + private int computeSize(int tokenLength) { + assert tokenLength > 0; + + // Fixed size bits: + // headers(1 byte), version(4 bytes), DCID(1 byte), SCID(1 byte), + // retryTokenIntegrity(128 bits) => 7 + 16 = 23 bytes + int size = Math.addExact(23, tokenLength); + size = Math.addExact(size, sourceId().length()); + size = Math.addExact(size, destinationId().length()); + + return size; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] retryToken() { + return retryToken; + } + } + + private static final class OutgoingHandshakePacket + extends OutgoingLongHeaderPacket implements HandshakePacket { + + final long packetNumber; + final int length; + final int size; + final byte[] encodedPacketNumber; + final List frames; + final int payloadSize; + private int tagSize; + + OutgoingHandshakePacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version, + long packetNumber, + byte[] encodedPacketNumber, + List frames, int tagSize) { + super(sourceId, destinationId, version); + this.packetNumber = packetNumber; + this.encodedPacketNumber = encodedPacketNumber; + this.frames = List.copyOf(frames); + this.payloadSize = frames.stream().mapToInt(QuicFrame::size).reduce(0, Math::addExact); + this.tagSize = tagSize; + this.length = computeLength(payloadSize, encodedPacketNumber.length, tagSize); + this.size = computeSize(length); + } + + @Override + public int length() { + return length; + } + + @Override + public long packetNumber() { + return packetNumber; + } + + public byte[] encodedPacketNumber() { + return encodedPacketNumber.clone(); + } + + @Override + public int size() { + return size; + } + + @Override + public int payloadSize() { + return payloadSize; + } + + /** + * Computes the value for the packet length field. + * This is the number of bytes needed to encode the packetNumber + * and the payload. + * + * @param payloadSize The payload size + * @param pnsize The number of bytes needed to encode the packet number + * @param tagSize The size of the authentication tag added during encryption + * @return the value for the packet length field. + */ + private int computeLength(int payloadSize, int pnsize, int tagSize) { + assert payloadSize >= 0; + assert pnsize > 0 && pnsize <= 4 : pnsize; + + return Math.addExact(Math.addExact(pnsize, payloadSize), tagSize); + } + + /** + * Compute the total packet size, starting at the headers byte and + * ending at the last payload byte. This is used to allocate a + * ByteBuffer in which to encode the packet. + * + * @param length The value of the length header + * + * @return the total packet size. + */ + private int computeSize(int length) { + assert length >= 0; + + // how many bytes are needed to encode the packet length + // the packet length is the number of bytes needed to encode + // the remainder of the packet: packet number + payload bytes + int lnsize = VariableLengthEncoder.getEncodedSize(length); + + // Fixed size bits: + // headers(1 byte), version(4 bytes), DCID(1 byte), SCID(1 byte), => 7 bytes + int size = Math.addExact(7, sourceId().length()); + size = Math.addExact(size, destinationId().length()); + + size = Math.addExact(size, lnsize); + size = Math.addExact(size, length); + return size; + } + + @Override + public List frames() { return frames; } + + } + + private static final class OutgoingZeroRttPacket + extends OutgoingLongHeaderPacket implements ZeroRttPacket { + + final long packetNumber; + final int length; + final int size; + final byte[] encodedPacketNumber; + final List frames; + private int tagSize; + final int payloadSize; + + OutgoingZeroRttPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version, + long packetNumber, + byte[] encodedPacketNumber, + List frames, int tagSize) { + super(sourceId, destinationId, version); + this.packetNumber = packetNumber; + this.encodedPacketNumber = encodedPacketNumber; + this.frames = List.copyOf(frames); + this.tagSize = tagSize; + this.payloadSize = this.frames.stream().mapToInt(QuicFrame::size) + .reduce(0, Math::addExact); + this.length = computeLength(payloadSize, encodedPacketNumber.length, tagSize); + this.size = computeSize(length); + } + + @Override + public int length() { + return length; + } + + @Override + public long packetNumber() { + return packetNumber; + } + + public byte[] encodedPacketNumber() { + return encodedPacketNumber.clone(); + } + + @Override + public int size() { + return size; + } + + /** + * Computes the value for the packet length field. + * This is the number of bytes needed to encode the packetNumber + * and the payload. + * + * @param payloadSize The payload size + * @param pnsize The number of bytes needed to encode the packet number + * @param tagSize The size of the authentication tag added during encryption + * @return the value for the packet length field. + */ + private int computeLength(int payloadSize, int pnsize, int tagSize) { + assert payloadSize >= 0; + assert pnsize > 0 && pnsize <= 4 : pnsize; + + return Math.addExact(Math.addExact(pnsize, payloadSize), tagSize); + } + + /** + * Compute the total packet size, starting at the headers byte and + * ending at the last payload byte. This is used to allocate a + * ByteBuffer in which to encode the packet. + * + * @param length The value of the length header + * + * @return the total packet size. + */ + private int computeSize(int length) { + assert length >= 0; + + // how many bytes are needed to encode the packet length + // the packet length is the number of bytes needed to encode + // the remainder of the packet: packet number + payload bytes + int lnsize = VariableLengthEncoder.getEncodedSize(length); + + // Fixed size bits: + // headers(1 byte), version(4 bytes), DCID(1 byte), SCID(1 byte), => 7 bytes + int size = Math.addExact(7, sourceId().length()); + size = Math.addExact(size, destinationId().length()); + + size = Math.addExact(size, lnsize); + size = Math.addExact(size, length); + return size; + } + + @Override + public List frames() { + return frames; + } + + @Override + public int payloadSize() { + return payloadSize; + } + + } + + private static final class OutgoingOneRttPacket + extends OutgoingShortHeaderPacket implements OneRttPacket { + + final long packetNumber; + final int size; + final byte[] encodedPacketNumber; + final List frames; + private int tagSize; + final int payloadSize; + + OutgoingOneRttPacket(QuicConnectionId destinationId, + long packetNumber, + byte[] encodedPacketNumber, + List frames, int tagSize) { + super(destinationId); + this.packetNumber = packetNumber; + this.encodedPacketNumber = encodedPacketNumber; + this.frames = List.copyOf(frames); + this.tagSize = tagSize; + this.payloadSize = this.frames.stream().mapToInt(QuicFrame::size) + .reduce(0, Math::addExact); + this.size = computeSize(payloadSize, encodedPacketNumber.length, tagSize); + } + + public long packetNumber() { + return packetNumber; + } + + public byte[] encodedPacketNumber() { + return encodedPacketNumber.clone(); + } + + @Override + public int size() { + return size; + } + + /** + * Compute the total packet size, starting at the headers byte and + * ending at the last payload byte. This is used to allocate a + * ByteBuffer in which to encode the packet. + * + * @param payloadSize The size of the packet's payload + * @param pnsize The number of bytes needed to encode the packet number + * @param tagSize The size of the authentication tag + * @return the total packet size. + */ + private int computeSize(int payloadSize, int pnsize, int tagSize) { + assert payloadSize >= 0; + assert pnsize > 0 && pnsize <= 4 : pnsize; + + // Fixed size bits: + // headers(1 byte) + int size = Math.addExact(1, destinationId().length()); + + size = Math.addExact(size, payloadSize); + size = Math.addExact(size, pnsize); + size = Math.addExact(size, tagSize); + return size; + } + + @Override + public List frames() { + return frames; + } + + @Override + public int payloadSize() { + return payloadSize; + } + + } + + private static final class OutgoingInitialPacket + extends OutgoingLongHeaderPacket implements InitialPacket { + + final byte[] token; + final long packetNumber; + final int length; + final int size; + final byte[] encodedPacketNumber; + final List frames; + private int tagSize; + final int payloadSize; + + private record InitialPacketVariableComponents(int length, byte[] token, QuicConnectionId sourceId, + QuicConnectionId destinationId) { + + } + + public OutgoingInitialPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int version, + byte[] token, + long packetNumber, + byte[] encodedPacketNumber, + List frames, int tagSize) { + super(sourceId, destinationId, version); + this.token = token; + this.packetNumber = packetNumber; + this.encodedPacketNumber = encodedPacketNumber; + this.frames = List.copyOf(frames); + this.tagSize = tagSize; + this.payloadSize = this.frames.stream() + .mapToInt(QuicFrame::size) + .reduce(0, Math::addExact); + this.length = computeLength(payloadSize, encodedPacketNumber.length, tagSize); + this.size = computePacketSize(new InitialPacketVariableComponents(length, token, sourceId, + destinationId)); + } + + @Override + public int tokenLength() { return token == null ? 0 : token.length; } + + @Override + public byte[] token() { return token; } + + @Override + public int length() { return length; } + + @Override + public long packetNumber() { return packetNumber; } + + public byte[] encodedPacketNumber() { + return encodedPacketNumber.clone(); + } + + @Override + public int size() { return size; } + + /** + * Computes the value for the packet length field. + * This is the number of bytes needed to encode the packetNumber + * and the payload. + * + * @param payloadSize The payload size + * @param pnsize The number of bytes needed to encode the packet number + * @param tagSize The size of the authentication tag added during encryption + * @return the value for the packet length field. + */ + private static int computeLength(int payloadSize, int pnsize, int tagSize) { + assert payloadSize >= 0; + assert pnsize > 0 && pnsize <= 4 : pnsize; + + return Math.addExact(Math.addExact(pnsize, payloadSize), tagSize); + } + + /** + * Compute the total packet size, starting at the headers byte and + * ending at the last payload byte. This is used to allocate a + * ByteBuffer in which to encode the packet. + * + * @param variableComponents The variable components of the packet + * + * @return the total packet size. + */ + private static int computePacketSize(InitialPacketVariableComponents variableComponents) { + assert variableComponents.length >= 0; + + // how many bytes are needed to encode the length of the token + final byte[] token = variableComponents.token; + int tkLenSpecifierSize = token == null || token.length == 0 + ? 1 : VariableLengthEncoder.getEncodedSize(token.length); + + // how many bytes are needed to encode the packet length + // the packet length is the number of bytes needed to encode + // the remainder of the packet: packet number + payload bytes + int lnsize = VariableLengthEncoder.getEncodedSize(variableComponents.length); + + // Fixed size bits: + // headers(1 byte), version(4 bytes), DCID length specifier(1 byte), + // SCID length specifier(1 byte), => 7 bytes + int size = Math.addExact(7, variableComponents.sourceId.length()); + size = Math.addExact(size, variableComponents.destinationId.length()); + size = Math.addExact(size, tkLenSpecifierSize); + if (token != null) { + size = Math.addExact(size, token.length); + } + size = Math.addExact(size, lnsize); + size = Math.addExact(size, variableComponents.length); + return size; + } + + @Override + public List frames() { + return frames; + } + + @Override + public int payloadSize() { + return payloadSize; + } + + } + + private static final class OutgoingVersionNegotiationPacket + extends OutgoingLongHeaderPacket + implements VersionNegotiationPacket { + + final int[] versions; + final int size; + final int payloadSize; + + public OutgoingVersionNegotiationPacket(QuicConnectionId sourceId, + QuicConnectionId destinationId, + int[] versions) { + super(sourceId, destinationId, 0); + this.versions = versions.clone(); + this.payloadSize = versions.length << 2; + this.size = computeSize(payloadSize); + } + + @Override + public int[] supportedVersions() { + return versions.clone(); + } + + @Override + public int size() { return size; } + + @Override + public int payloadSize() { return payloadSize; } + + /** + * Compute the total packet size, starting at the headers byte and + * ending at the last payload byte. This is used to allocate a + * ByteBuffer in which to encode the packet. + * + * @param payloadSize The size of the packet's payload + * @return the total packet size. + */ + private int computeSize(int payloadSize) { + assert payloadSize > 0; + // Fixed size bits: + // headers(1 byte), version(4 bytes), DCID(1 byte), SCID(1 byte), => 7 bytes + int size = Math.addExact(7, payloadSize); + size = Math.addExact(size, sourceId().length()); + size = Math.addExact(size, destinationId().length()); + return size; + } + + } + + /** + * Create a new unencrypted InitialPacket to be transmitted over the wire + * after encryption. + * + * @param source The source connection ID + * @param destination The destination connection ID + * @param token The token field (may be null if no token) + * @param packetNumber The packet number + * @param ackedPacketNumber The largest acknowledged packet number + * @param frames The initial packet payload + * + * @param codingContext + * @return the new initial packet + */ + public OutgoingQuicPacket newInitialPacket(QuicConnectionId source, + QuicConnectionId destination, + byte[] token, + long packetNumber, + long ackedPacketNumber, + List frames, + CodingContext codingContext) { + if (debug.on()) { + debug.log("newInitialPacket: fullPN=%d ackedPN=%d", + packetNumber, ackedPacketNumber); + } + byte[] encodedPacketNumber = encodePacketNumber(packetNumber, ackedPacketNumber); + QuicTLSEngine tlsEngine = codingContext.getTLSEngine(); + int tagSize = tlsEngine.getAuthTagSize(); + // https://www.rfc-editor.org/rfc/rfc9000#section-14.1 + // A client MUST expand the payload of all UDP datagrams carrying Initial packets + // to at least the smallest allowed maximum datagram size of 1200 bytes + // by adding PADDING frames to the Initial packet or by coalescing the Initial packet + + // first compute the packet size + final int originalPayloadSize = frames.stream() + .mapToInt(QuicFrame::size) + .reduce(0, Math::addExact); + final int originalLength = OutgoingInitialPacket.computeLength(originalPayloadSize, + encodedPacketNumber.length, tagSize); + final int originalPacketSize = OutgoingInitialPacket.computePacketSize( + new OutgoingInitialPacket.InitialPacketVariableComponents(originalLength, token, + source, destination)); + if (originalPacketSize >= 1200) { + return new OutgoingInitialPacket(source, destination, this.quicVersion.versionNumber(), + token, packetNumber, encodedPacketNumber, frames, tagSize); + } else { + // add padding + int numPaddingBytesNeeded = 1200 - originalPacketSize; + if (originalLength < 64 && originalLength + numPaddingBytesNeeded > 64) { + // if originalLength + numPaddingBytesNeeded == 64, will send + // 1201 bytes + numPaddingBytesNeeded--; + } + final List newFrames = new ArrayList<>(); + for (QuicFrame frame : frames) { + if (frame instanceof PaddingFrame) { + // a padding frame already exists, instead of including this and the new padding + // frame in the new frames, we just include 1 single padding frame whose + // combined size will be the sum of all existing padding frames and the + // additional padding bytes needed + numPaddingBytesNeeded += frame.size(); + continue; + } + // non-padding frame, include it in the new frames + newFrames.add(frame); + } + // add the padding frame as the first frame + newFrames.add(0, new PaddingFrame(numPaddingBytesNeeded)); + return new OutgoingInitialPacket( + source, destination, this.quicVersion.versionNumber(), + token, packetNumber, encodedPacketNumber, newFrames, tagSize); + } + } + + /** + * Create a new unencrypted VersionNegotiationPacket to be transmitted over the wire + * after encryption. + * + * @param source The source connection ID + * @param destination The destination connection ID + * @param versions The supported quic versions + * @return the new initial packet + */ + public static OutgoingQuicPacket newVersionNegotiationPacket(QuicConnectionId source, + QuicConnectionId destination, + int[] versions) { + return new OutgoingVersionNegotiationPacket(source, destination, versions); + } + + /** + * Create a new unencrypted RetryPacket to be transmitted over the wire + * after encryption. + * + * @param source The source connection ID + * @param destination The destination connection ID + * @param retryToken The retry token + * @return the new retry packet + */ + public OutgoingQuicPacket newRetryPacket(QuicConnectionId source, + QuicConnectionId destination, + byte[] retryToken) { + return new OutgoingRetryPacket( + source, destination, this.quicVersion.versionNumber(), retryToken); + } + + /** + * Create a new unencrypted ZeroRttPacket to be transmitted over the wire + * after encryption. + * + * @param source The source connection ID + * @param destination The destination connection ID + * @param packetNumber The packet number + * @param ackedPacketNumber The largest acknowledged packet number + * @param frames The zero RTT packet payload + * @param codingContext + * @return the new zero RTT packet + */ + public OutgoingQuicPacket newZeroRttPacket(QuicConnectionId source, + QuicConnectionId destination, + long packetNumber, + long ackedPacketNumber, + List frames, + CodingContext codingContext) { + if (debug.on()) { + debug.log("newZeroRttPacket: fullPN=%d ackedPN=%d", + packetNumber, ackedPacketNumber); + } + byte[] encodedPacketNumber = encodePacketNumber(packetNumber, ackedPacketNumber); + QuicTLSEngine tlsEngine = codingContext.getTLSEngine(); + int tagSize = tlsEngine.getAuthTagSize(); + int protectionSampleSize = tlsEngine.getHeaderProtectionSampleSize(KeySpace.ZERO_RTT); + int minLength = 4 + protectionSampleSize - encodedPacketNumber.length - tagSize; + + return new OutgoingZeroRttPacket( + source, destination, this.quicVersion.versionNumber(), packetNumber, + encodedPacketNumber, padFrames(frames, minLength), tagSize); + } + + /** + * Create a new unencrypted HandshakePacket to be transmitted over the wire + * after encryption. + * + * @param source The source connection ID + * @param destination The destination connection ID + * @param packetNumber The packet number + * @param frames The handshake packet payload + * @param codingContext + * @return the new handshake packet + */ + public OutgoingQuicPacket newHandshakePacket(QuicConnectionId source, + QuicConnectionId destination, + long packetNumber, + long largestAckedPN, + List frames, CodingContext codingContext) { + if (debug.on()) { + debug.log("newHandshakePacket: fullPN=%d ackedPN=%d", + packetNumber, largestAckedPN); + } + byte[] encodedPacketNumber = encodePacketNumber(packetNumber, largestAckedPN); + QuicTLSEngine tlsEngine = codingContext.getTLSEngine(); + int tagSize = tlsEngine.getAuthTagSize(); + int protectionSampleSize = tlsEngine.getHeaderProtectionSampleSize(KeySpace.HANDSHAKE); + int minLength = 4 + protectionSampleSize - encodedPacketNumber.length - tagSize; + + return new OutgoingHandshakePacket( + source, destination, this.quicVersion.versionNumber(), + packetNumber, encodedPacketNumber, padFrames(frames, minLength), tagSize); + } + + /** + * Create a new unencrypted OneRttPacket to be transmitted over the wire + * after encryption. + * + * @param destination The destination connection ID + * @param packetNumber The packet number + * @param ackedPacketNumber The largest acknowledged packet number + * @param frames The one RTT packet payload + * @param codingContext + * @return the new one RTT packet + */ + public OneRttPacket newOneRttPacket(QuicConnectionId destination, + long packetNumber, + long ackedPacketNumber, + List frames, + CodingContext codingContext) { + if (debug.on()) { + debug.log("newOneRttPacket: fullPN=%d ackedPN=%d", + packetNumber, ackedPacketNumber); + } + byte[] encodedPacketNumber = encodePacketNumber(packetNumber, ackedPacketNumber); + QuicTLSEngine tlsEngine = codingContext.getTLSEngine(); + int tagSize = tlsEngine.getAuthTagSize(); + int protectionSampleSize = tlsEngine.getHeaderProtectionSampleSize(KeySpace.ONE_RTT); + // packets should be at least 22 bytes longer than the local connection id length. + // we ensure that by padding the frames to the necessary size + int minPayloadSize = codingContext.minShortPacketPayloadSize(destination.length()); + assert protectionSampleSize == tagSize; + int minLength = Math.max(5, minPayloadSize) - encodedPacketNumber.length; + return new OutgoingOneRttPacket( + destination, packetNumber, + encodedPacketNumber, padFrames(frames, minLength), tagSize); + } + + /** + * Creates a packet in the given keyspace for the purpose of sending + * a CONNECTION_CLOSE, or a generic list of frames. + * The {@code initialToken} parameter is ignored if the key + * space is not INITIAL. + * + * @param keySpace the sending key space + * @param packetSpace the packet space + * @param sourceId the source connection id + * @param destinationId the destination connection id + * @param initialToken the initial token for INITIAL packets + * @param frames the list of frames + * @param codingContext the coding context + * @return a packet in the given key space + * @throws IllegalArgumentException if the packet number space is + * not one of INITIAL, HANDSHAKE, or APPLICATION + */ + public OutgoingQuicPacket newOutgoingPacket( + KeySpace keySpace, + PacketSpace packetSpace, + QuicConnectionId sourceId, + QuicConnectionId destinationId, + byte[] initialToken, + List frames, CodingContext codingContext) { + long largestAckedPN = packetSpace.getLargestPeerAckedPN(); + return switch (packetSpace.packetNumberSpace()) { + case APPLICATION -> { + long newPacketNumber = packetSpace.allocateNextPN(); + if (keySpace == KeySpace.ZERO_RTT) { + assert !frames.stream().anyMatch(f -> !f.isValidIn(PacketType.ZERORTT)) + : "%s contains frames not valid in %s" + .formatted(frames, keySpace); + yield newZeroRttPacket(sourceId, + destinationId, + newPacketNumber, + largestAckedPN, + frames, + codingContext); + } else { + assert keySpace == KeySpace.ONE_RTT; + assert !frames.stream().anyMatch(f -> !f.isValidIn(PacketType.ONERTT)) + : "%s contains frames not valid in %s" + .formatted(frames, keySpace); + final OneRttPacket oneRttPacket = newOneRttPacket(destinationId, + newPacketNumber, + largestAckedPN, + frames, + codingContext); + assert oneRttPacket instanceof OutgoingOneRttPacket : + "unexpected 1-RTT packet type: " + oneRttPacket.getClass(); + yield (OutgoingQuicPacket) oneRttPacket; + } + } + case HANDSHAKE -> { + assert keySpace == KeySpace.HANDSHAKE; + assert !frames.stream().anyMatch(f -> !f.isValidIn(PacketType.HANDSHAKE)) + : "%s contains frames not valid in %s" + .formatted(frames, keySpace); + long newPacketNumber = packetSpace.allocateNextPN(); + yield newHandshakePacket(sourceId, destinationId, + newPacketNumber, largestAckedPN, + frames, codingContext); + } + case INITIAL -> { + assert keySpace == KeySpace.INITIAL; + assert !frames.stream().anyMatch(f -> !f.isValidIn(PacketType.INITIAL)) + : "%s contains frames not valid in %s" + .formatted(frames, keySpace); + long newPacketNumber = packetSpace.allocateNextPN(); + yield newInitialPacket(sourceId, destinationId, + initialToken, newPacketNumber, + largestAckedPN, + frames, codingContext); + } + case NONE -> { + throw new IllegalArgumentException("packetSpace: %s, keySpace: %s" + .formatted(packetSpace.packetNumberSpace(), keySpace)); + } + }; + } + + /** + * Encodes the given QuicPacket. + * + * @param packet the packet to encode + * @param buffer the byte buffer to write the packet into + * @param context context for encoding + * @throws IllegalArgumentException if the packet is not an OutgoingQuicPacket, + * or if the packet version does not match the encoder version + * @throws BufferOverflowException if the buffer is not large enough + * @throws QuicKeyUnavailableException if the packet could not be encrypted + * because the required encryption key is not available + * @throws QuicTransportException if encrypting the packet resulted + * in an error that requires closing the connection + */ + public void encode(QuicPacket packet, ByteBuffer buffer, CodingContext context) + throws QuicKeyUnavailableException, QuicTransportException { + switch (packet) { + case OutgoingOneRttPacket p -> encodePacket(p, buffer, context); + case OutgoingZeroRttPacket p -> encodePacket(p, buffer, context); + case OutgoingVersionNegotiationPacket p -> encodePacket(p, buffer); + case OutgoingHandshakePacket p -> encodePacket(p, buffer, context); + case OutgoingInitialPacket p -> encodePacket(p, buffer, context); + case OutgoingRetryPacket p -> encodePacket(p, buffer, context); + default -> throw new IllegalArgumentException("packet is not an outgoing packet: " + + packet.getClass()); + } + } + + /** + * Compute the max size of the usable payload of an initial + * packet, given the max size of the datagram. + *

+     * Initial Packet {
+     *     Header (1 byte),
+     *     Version (4 bytes),
+     *     Destination Connection ID Length (1 byte),
+     *     Destination Connection ID (0..20 bytes),
+     *     Source Connection ID Length (1 byte),
+     *     Source Connection ID (0..20 bytes),
+     *     Token Length (variable int),
+     *     Token (..),
+     *     Length (variable int),
+     *     Packet Number (1..4 bytes),
+     *     Packet Payload (1 to ... bytes),
+     * }
+     * 
+ * + * @param codingContext the coding context, used to compute the + * encoded packet number + * @param pnsize packet number length + * @param tokenLength the length of the token (or {@code 0}) + * @param scidLength the length of the source connection id + * @param dstidLength the length of the destination connection id + * @param maxDatagramSize the desired total maximum size + * of the packet after encryption + * @return the maximum size of the payload that can be fit into this + * initial packet + */ + public static int computeMaxInitialPayloadSize(CodingContext codingContext, + int pnsize, + int tokenLength, + int scidLength, + int dstidLength, + int maxDatagramSize) { + // header=1, version=4, len(scidlen)+len(dstidlen)=2 + int overhead = 1 + 4 + 2 + scidLength + dstidLength + tokenLength + + VariableLengthEncoder.getEncodedSize(tokenLength); + // encryption tag, included in the payload, but not usable for frames + int tagSize = codingContext.getTLSEngine().getAuthTagSize(); + int length = maxDatagramSize - overhead - 1; // at least 1 byte for length encoding + if (length <= 0) return 0; + int lenbefore = VariableLengthEncoder.getEncodedSize(length); + length = length - lenbefore + 1; // discount length encoding + // int lenafter = VariableLengthEncoder.getEncodedSize(length); // check + // assert lenafter == lenbefore : "%s -> %s (before:%s, after:%s)" + // .formatted(maxDatagramSize - overhead -1, length, lenbefore, lenafter); + if (length <= 0) return 0; + int available = length - pnsize - tagSize; + if (available < 0) return 0; + return available; + } + + /** + * Compute the max size of the usable payload of a handshake + * packet, given the max size of the datagram. + *
+     * Initial Packet {
+     *     Header (1 byte),
+     *     Version (4 bytes),
+     *     Destination Connection ID Length (1 byte),
+     *     Destination Connection ID (0..20 bytes),
+     *     Source Connection ID Length (1 byte),
+     *     Source Connection ID (0..20 bytes),
+     *     Length (variable int),
+     *     Packet Number (1..4 bytes),
+     *     Packet Payload (1 to ... bytes),
+     * }
+     * 
+ * @param codingContext the coding context, used to compute the + * encoded packet number + * @param packetNumber the full packet number + * @param scidLength the length of the source connection id + * @param dstidLength the length of the destination connection id + * @param maxDatagramSize the desired total maximum size + * of the packet after encryption + * @return the maximum size of the payload that can be fit into this + * initial packet + */ + public static int computeMaxHandshakePayloadSize(CodingContext codingContext, + long packetNumber, + int scidLength, + int dstidLength, + int maxDatagramSize) { + // header=1, version=4, len(scidlen)+len(dstidlen)=2 + int overhead = 1 + 4 + 2 + scidLength + dstidLength; + int pnsize = computePacketNumberLength(packetNumber, + codingContext.largestAckedPN(PacketNumberSpace.HANDSHAKE)); + // encryption tag, included in the payload, but not usable for frames + int tagSize = codingContext.getTLSEngine().getAuthTagSize(); + int length = maxDatagramSize - overhead -1; // at least 1 byte for length encoding + if (length < 0) return 0; + int lenbefore = VariableLengthEncoder.getEncodedSize(length); + length = length - lenbefore + 1; // discount length encoding + int available = length - pnsize - tagSize; + return available; + } + + /** + * Computes the maximum usable payload that can be carried on in a + * {@link OneRttPacket} given the max datagram size before + * encryption. + * @param codingContext the coding context + * @param packetNumber the packet number + * @param dstidLength the peer connection id length + * @param maxDatagramSizeBeforeEncryption the maximum size of the datagram + * @return the maximum payload that can be carried on in a + * {@link OneRttPacket} given the max datagram size before + * encryption + */ + public static int computeMaxOneRTTPayloadSize(final CodingContext codingContext, + final long packetNumber, + final int dstidLength, + final int maxDatagramSizeBeforeEncryption, + final long largestPeerAckedPN) { + // header=1 + final int overhead = 1 + dstidLength; + // always reserve four bytes for packet number to avoid issues with packet + // sizes when retransmitting. This is a hack, but it avoids having to + // repack StreamFrames. + final int pnsize = 4; //computePacketNumberLength(packetNumber, largestPeerAckedPN); + // encryption tag, included in the payload, but not usable for frames + final int tagSize = codingContext.getTLSEngine().getAuthTagSize(); + final int available = maxDatagramSizeBeforeEncryption - overhead - pnsize - tagSize; + if (available < 0) return 0; + return available; + } + + private static ByteBuffer putInt32(ByteBuffer buffer, int value) { + assert buffer.order() == ByteOrder.BIG_ENDIAN; + return buffer.putInt(value); + } + + + /** + * A {@code PacketWriter} to write a Quic packet. + *

+ * A {@code PacketWriter} offers high level helper methods to write + * data (such as Connection IDs or Packet Numbers) from a Quic packet. + * It has however no or little knowledge of the actual packet structure. + * It is driven by the {@code encode} method of the appropriate + * {@code OutgoingQuicPacket} type. + *

+ * A {@code PacketWriter} is stateful: it encapsulates a {@code ByteBuffer} + * (or possibly a list of byte buffers - as a future enhancement) and + * advances the position on the buffer it is writing. + * + */ + static class PacketWriter { + final ByteBuffer buffer; + final int offset; + final int initialLimit; + final CodingContext context; + final PacketType packetType; + + PacketWriter(ByteBuffer buffer, CodingContext context, PacketType packetType) { + assert buffer.order() == ByteOrder.BIG_ENDIAN; + int pos = buffer.position(); + int limit = buffer.limit(); + this.buffer = buffer; + this.offset = pos; + this.initialLimit = limit; + this.context = context; + this.packetType = packetType; + } + + public int offset() { + return offset; + } + + public int position() { + return buffer.position(); + } + + public int remaining() { + return buffer.remaining(); + } + + public boolean hasRemaining() { + return buffer.hasRemaining(); + } + + public int bytesWritten() { + return position() - offset; + } + + public void reset() { + buffer.position(offset); + buffer.limit(initialLimit); + } + + public byte headers() { + return buffer.get(offset); + } + + public void headers(byte headers) { + buffer.put(offset, headers); + } + + public PacketType packetType() { + return packetType; + } + + public void writeHeaders(byte headers) { + buffer.put(headers); + } + + public void writeVersion(int version) { + buffer.putInt(version); + } + + public void writeSupportedVersions(int[] versions) { + for (int i=0 ; i= 0 && packetLength <= VariableLengthEncoder.MAX_ENCODED_INTEGER + : packetLength; + writeVariableLength(packetLength); + } + + private void writeTokenLength(long tokenLength) { + writeVariableLength(tokenLength); + } + + public void writeToken(byte[] token) { + if (token == null) { + buffer.put((byte)0); + } else { + writeTokenLength(token.length); + buffer.put(token); + } + } + + public void writeVariableLength(long value) { + VariableLengthEncoder.encode(buffer, value); + } + + private void maskPacketNumber(int packetNumberStart, int packetNumberLength, ByteBuffer mask) { + for (int i = 0; i < packetNumberLength; i++) { + buffer.put(packetNumberStart + i, (byte)(buffer.get(packetNumberStart + i) ^ mask.get())); + } + } + + public void writeEncodedPacketNumber(byte[] packetNumber) { + buffer.put(packetNumber); + } + + public void encryptPayload(final long packetNumber, final int payloadstart) + throws QuicTransportException, QuicKeyUnavailableException { + final int payloadend = buffer.position(); + buffer.position(payloadstart); // position the output buffer + final int payloadLength = payloadend - payloadstart; + final int headersLength = payloadstart - offset; + final ByteBuffer packetHeader = buffer.slice(offset, headersLength); + final ByteBuffer packetPayload = buffer.slice(payloadstart, payloadLength) + .asReadOnlyBuffer(); + try { + context.getTLSEngine().encryptPacket(packetType.keySpace().get(), packetNumber, + new HeaderGenerator(this.packetType, packetHeader), packetPayload, buffer); + } catch (ShortBufferException e) { + throw new QuicTransportException(e.toString(), null, 0, + QuicTransportErrors.INTERNAL_ERROR); + } + } + + public void writePayload(List frames) { + for (var frame : frames) frame.encode(buffer); + } + + public void writeLongConnectionId(QuicConnectionId connId) { + ByteBuffer src = connId.asReadOnlyBuffer(); + assert src.remaining() <= MAX_CONNECTION_ID_LENGTH; + buffer.put((byte)src.remaining()); + buffer.put(src); + } + + public void writeShortConnectionId(QuicConnectionId connId) { + ByteBuffer src = connId.asReadOnlyBuffer(); + assert src.remaining() <= MAX_CONNECTION_ID_LENGTH; + buffer.put(src); + } + + public void writeRetryToken(byte[] retryToken) { + buffer.put(retryToken); + } + + @Override + public String toString() { + return "PacketWriter(offset=%s, pos=%s, remaining=%s)" + .formatted(offset, position(), remaining()); + } + + public void protectHeaderLong(int packetNumberStart, int packetNumberLength) + throws QuicKeyUnavailableException, QuicTransportException { + protectHeader(packetNumberStart, packetNumberLength, (byte) 0x0f); + } + + public void protectHeaderShort(int packetNumberStart, int packetNumberLength) + throws QuicKeyUnavailableException, QuicTransportException { + protectHeader(packetNumberStart, packetNumberLength, (byte) 0x1f); + } + + private void protectHeader(int packetNumberStart, int packetNumberLength, byte headerMask) + throws QuicKeyUnavailableException, QuicTransportException { + // expect position at the end of packet + QuicTLSEngine tlsEngine = context.getTLSEngine(); + int sampleSize = tlsEngine.getHeaderProtectionSampleSize(packetType.keySpace().get()); + assert buffer.position() - packetNumberStart >= sampleSize + 4 : buffer.position() - packetNumberStart - sampleSize - 4; + + ByteBuffer sample = buffer.slice(packetNumberStart + 4, sampleSize); + ByteBuffer encryptedSample = tlsEngine.computeHeaderProtectionMask(packetType.keySpace().get(), false, sample); + byte headers = headers(); + headers ^= (byte) (encryptedSample.get() & headerMask); + headers(headers); + maskPacketNumber(packetNumberStart, packetNumberLength, encryptedSample); + } + + private void signRetry(final int version) throws QuicTransportException { + final QuicVersion retryVersion = QuicVersion.of(version) + .orElseThrow(() -> new IllegalArgumentException("Unknown Quic version 0x" + + Integer.toHexString(version))); + int payloadend = buffer.position(); + ByteBuffer temp = buffer.asReadOnlyBuffer(); + temp.position(offset); + temp.limit(payloadend); + try { + context.getTLSEngine().signRetryPacket(retryVersion, + context.originalServerConnId().asReadOnlyBuffer(), temp, buffer); + } catch (ShortBufferException e) { + throw new QuicTransportException("Failed to sign packet", + null, 0, QuicTransportErrors.INTERNAL_ERROR); + } + } + + // generates packet header and is capable of inserting a key phase into the header + // when appropriate + private static final class HeaderGenerator implements IntFunction { + private final PacketType packetType; + private final ByteBuffer header; + + private HeaderGenerator(final PacketType packetType, final ByteBuffer header) { + this.packetType = packetType; + this.header = header; + } + + @Override + public ByteBuffer apply(final int keyPhase) { + // we use key phase only in 1-RTT packet header + if (packetType != PacketType.ONERTT) { + assert keyPhase == 0 : "unexpected key phase " + keyPhase + + " for packet type " + packetType; + // return the packet header without setting any key phase bit + return header; + } + // update the key phase bit in the packet header + setKeyPhase(keyPhase); + return header.position(0).asReadOnlyBuffer(); + } + + private void setKeyPhase(final int kp) { + if (kp != 0 && kp != 1) { + throw new IllegalArgumentException("Invalid key phase: " + kp); + } + final byte headerFirstByte = this.header.get(); + final byte updated = (byte) (headerFirstByte | (kp << 2)); + this.header.put(0, updated); + } + } + } + + + /** + * Adds required padding frames if necessary. + * Needed to make sure there's enough bytes to apply header protection + * @param frames requested list of frames + * @param minLength requested minimum length + * @return list of frames that meets the minimum length requirement + */ + private static List padFrames(List frames, int minLength) { + if (frames.size() >= minLength) { + return frames; + } + int size = frames.stream().mapToInt(QuicFrame::size).reduce(0, Math::addExact); + if (size >= minLength) { + return frames; + } + List result = new ArrayList<>(frames.size() + 1); + // add padding frame in front - some frames extend to end of packet + result.add(new PaddingFrame(minLength - size)); + result.addAll(frames); + return result; + } + + /** + * Returns an encoder for the given Quic version. + * Returns {@code null} if no encoder for that version exists. + * + * @param quicVersion the Quic protocol version number + * @return an encoder for the given Quic version or {@code null} + */ + public static QuicPacketEncoder of(final QuicVersion quicVersion) { + return switch (quicVersion) { + case QUIC_V1 -> Encoders.QUIC_V1_ENCODER; + case QUIC_V2 -> Encoders.QUIC_V2_ENCODER; + default -> throw new IllegalArgumentException("No packet encoder for Quic version " + quicVersion); + }; + } + + private static final class Encoders { + static final Random RANDOM = new Random(); + static final QuicPacketEncoder QUIC_V1_ENCODER = new QuicPacketEncoder(QuicVersion.QUIC_V1); + static final QuicPacketEncoder QUIC_V2_ENCODER = new QuicPacketEncoder(QuicVersion.QUIC_V2); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketNumbers.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketNumbers.java new file mode 100644 index 00000000000..197d46fc0b0 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/QuicPacketNumbers.java @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.nio.ByteBuffer; + +/** + * QUIC packet number encoding/decoding routines. + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public class QuicPacketNumbers { + + /** + * Returns the number of bytes needed to encode a packet number + * given the full packet number and the largest ACK'd packet. + * + * @param fullPN the full packet number + * @param largestAcked the largest ACK'd packet, or -1 if none so far + * + * @throws IllegalArgumentException if number can't represented in 4 bytes + * @return the number of bytes required to encode the packet + */ + public static int computePacketNumberLength(long fullPN, long largestAcked) { + + long numUnAcked; + + if (largestAcked == -1) { + numUnAcked = fullPN + 1; + } else { + numUnAcked = fullPN - largestAcked; + } + + /* + * log(n, 2) + 1; ceil(minBits / 8); + * + * value will never be non-positive, so don't need to worry about the + * special cases. + */ + assert numUnAcked > 0 : "numUnAcked %s < 0 (fullPN: %s, largestAcked: %s)" + .formatted(numUnAcked, fullPN, largestAcked); + int minBits = 64 - Long.numberOfLeadingZeros(numUnAcked) + 1; + int numBytes = (minBits + 7) / 8; + + if (numBytes > 4) { + throw new IllegalArgumentException( + "Encoded packet number needs %s bytes for pn=%s, ack=%s" + .formatted(numBytes, fullPN, largestAcked)); + } + + return numBytes; + } + + /** + * Encode the full packet number against the largest ACK'd packet. + * + * Follows the algorithm outlined in + * + * RFC 9000. Appendix A.2 + * + * @param fullPN the full packet number + * @param largestAcked the largest ACK'd packet, or -1 if none so far + * + * @throws IllegalArgumentException if number can't be represented in 4 bytes + * @return byte array containing fullPN + */ + public static byte[] encodePacketNumber( + long fullPN, long largestAcked) { + + // throws IAE if more than 4 bytes are needed + int numBytes = computePacketNumberLength(fullPN, largestAcked); + assert numBytes <= 4 : numBytes; + return truncatePacketNumber(fullPN, numBytes); + } + + /** + * Truncate the full packet number to fill into {@code numBytes}. + * + * Follows the algorithm outlined in + * + * RFC 9000, Appendix A.2 + * + * @apiNote + * {@code numBytes} should have been computed using + * {@link #computePacketNumberLength(long, long)} + * + * @param fullPN the full packet number + * @param numBytes the number of bytes in which to encode + * the packet number + * + * @throws IllegalArgumentException if numBytes is out of range + * @return byte array containing fullPN + */ + public static byte[] truncatePacketNumber( + long fullPN, int numBytes) { + + if (numBytes <= 0 || numBytes > 4) { + throw new IllegalArgumentException( + "Invalid packet number length: " + numBytes); + } + + // Fill in the array. + byte[] retval = new byte[numBytes]; + for (int i = numBytes - 1; i >= 0; i--) { + retval[i] = (byte) (fullPN & 0xff); + fullPN = fullPN >>> 8; + } + + return retval; + } + + /** + * Decode the packet numbers against the largest ACK'd packet after header + * protection has been removed. + * + * Follows the algorithm outlined in + * + * RFC 9000, Appendix A.3 + * + * @param largestPN the largest packet number that has been successfully + * processed in the current packet number space + * @param buf a {@code ByteBuffer} containing the value of the + * Packet Number field + * @param pnNBytes the number of bytes indicated by the Packet + * Number Length field + * + * @throws java.nio.BufferUnderflowException if there is not enough data in the + * buffer + * @return the decoded packet number + */ + public static long decodePacketNumber( + long largestPN, ByteBuffer buf, int pnNBytes) { + + assert pnNBytes >= 1 && pnNBytes <= 4 + : "decodePacketNumber: " + pnNBytes; + + long truncatedPN = 0; + for (int i = 0; i < pnNBytes; i++) { + truncatedPN = (truncatedPN << 8) | (buf.get() & 0xffL); + } + + int pnNBits = pnNBytes * 8; + + long expectedPN = largestPN + 1L; + assert expectedPN >= 0 : "expectedPN: " + expectedPN; + long pnWin = 1L << pnNBits; + long pnHWin = pnWin / 2L; + long pnMask = pnWin - 1L; + + // The incoming packet number should be greater than + // expectedPN - pn_HWin and less than or equal to + // expectedPN + pn_HWin + // + // This means we cannot just strip the trailing bits from + // expectedPN and add the truncatedPN because that might + // yield a value outside the window. + // + // The following code calculates a candidate value and + // makes sure it's within the packet number window. + // Note the extra checks to prevent overflow and underflow. + long candidatePN = (expectedPN & ~pnMask) | truncatedPN; + + if ((candidatePN <= (expectedPN - pnHWin)) + && (candidatePN < ((1L << 62) - pnWin))) { + return candidatePN + pnWin; + } + + if ((candidatePN - pnHWin > expectedPN) + && (candidatePN >= pnWin)) { + return candidatePN - pnWin; + } + return candidatePN; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/RetryPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/RetryPacket.java new file mode 100644 index 00000000000..51638caf98c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/RetryPacket.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +/** + * This class models Quic Retry Packets, as defined by + * RFC 9000, Section 17.2.5: + * + *

{@code
+ *    A Retry packet uses a long packet header with a type value of 0x03.
+ *    It carries an address validation token created by the server.
+ *    It is used by a server that wishes to perform a retry; see Section 8.1.
+ *
+ *    Retry Packet {
+ *      Header Form (1) = 1,
+ *      Fixed Bit (1) = 1,
+ *      Long Packet Type (2) = 3,
+ *      Unused (4),
+ *      Version (32),
+ *      Destination Connection ID Length (8),
+ *      Destination Connection ID (0..160),
+ *      Source Connection ID Length (8),
+ *      Source Connection ID (0..160),
+ *      Retry Token (..),
+ *      Retry Integrity Tag (128),
+ *    }
+ * }
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * Note that Quic Version 2 uses the same Retry Packet structure than + * Quic Version 1, but uses a different long packet type than that shown above. See + * RFC 9369, Section 3.2. + * + * @see RFC 9000, Section 8.1/a> + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface RetryPacket extends LongHeaderPacket { + @Override + default PacketType packetType() { + return PacketType.RETRY; + } + + /** + * This packet type is not numbered: returns + * {@link PacketNumberSpace#NONE} always. + * @return {@link PacketNumberSpace#NONE} + */ + @Override + default PacketNumberSpace numberSpace() { + return PacketNumberSpace.NONE; + } + + /** + * This packet type is not numbered: always returns -1L. + * @return -1L + */ + @Override + default long packetNumber() { return -1L; } + + /** + * {@return the packet's retry token} + * + * As per RFC 9000, Section 17.2.5: + *

{@code
+     *    An opaque token that the server can use to validate the client's address.
+     * }
+ * + * @see + * RFC 9000, Section 8.1 + */ + byte[] retryToken(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ShortHeaderPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ShortHeaderPacket.java new file mode 100644 index 00000000000..a56689e6a0b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ShortHeaderPacket.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +/** + * This interface models Quic Short Header packets, as defined by + * RFC 8999, Section 5.2: + * + *
{@code
+ *    Short Header Packet {
+ *      Header Form (1) = 0,
+ *      Version-Specific Bits (7),
+ *      Destination Connection ID (..),
+ *      Version-Specific Data (..),
+ *    }
+ * }
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * + * @spec https://www.rfc-editor.org/info/rfc8999 + * RFC 8999: Version-Independent Properties of QUIC + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface ShortHeaderPacket extends QuicPacket { + @Override + default HeadersType headersType() { return HeadersType.SHORT; } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/VersionNegotiationPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/VersionNegotiationPacket.java new file mode 100644 index 00000000000..0ba5ba08c7e --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/VersionNegotiationPacket.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +/** + * This class models Quic Version Negotiation Packets, as defined by + * RFC 9000, Section 17.2.1: + * + *

{@code
+ *    A Version Negotiation packet is inherently not version-specific.
+ *    Upon receipt by a client, it will be identified as a Version
+ *    Negotiation packet based on the Version field having a value of 0.
+ *
+ *    The Version Negotiation packet is a response to a client packet that
+ *    contains a version that is not supported by the server, and is only
+ *    sent by servers.
+ *
+ *    The layout of a Version Negotiation packet is:
+ *
+ *    Version Negotiation Packet {
+ *      Header Form (1) = 1,
+ *      Unused (7),
+ *      Version (32) = 0,
+ *      Destination Connection ID Length (8),
+ *      Destination Connection ID (0..2040),
+ *      Source Connection ID Length (8),
+ *      Source Connection ID (0..2040),
+ *      Supported Version (32) ...,
+ *    }
+ * }
+ * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + */ +public interface VersionNegotiationPacket extends LongHeaderPacket { + @Override + default PacketType packetType() { return PacketType.VERSIONS; } + @Override + default int version() { return 0;} + /** + * This packet type is not numbered: returns + * {@link PacketNumberSpace#NONE} always. + * @return {@link PacketNumberSpace#NONE} + */ + @Override + default PacketNumberSpace numberSpace() { return PacketNumberSpace.NONE; } + /** + * This packet type is not numbered: returns -1L always. + * @return -1L + */ + @Override + default long packetNumber() { return -1L; } + int[] supportedVersions(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ZeroRttPacket.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ZeroRttPacket.java new file mode 100644 index 00000000000..c680aae1486 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/packets/ZeroRttPacket.java @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.packets; + +import java.util.List; + +import jdk.internal.net.http.quic.frames.QuicFrame; + +/** + * This class models Quic 0-RTT Packets, as defined by + * RFC 9000, Section 17.2.3: + * + *
{@code
+ *    A 0-RTT packet uses long headers with a type value of 0x01, followed
+ *    by the Length and Packet Number fields; see Section 17.2. The first
+ *    byte contains the Reserved and Packet Number Length bits; see Section 17.2.
+ *    A 0-RTT packet is used to carry "early" data from the client to the server
+ *    as part of the first flight, prior to handshake completion. As part of the
+ *    TLS handshake, the server can accept or reject this early data.
+ *
+ *    See Section 2.3 of [TLS13] for a discussion of 0-RTT data and its limitations.
+
+ *    0-RTT Packet {
+ *      Header Form (1) = 1,
+ *      Fixed Bit (1) = 1,
+ *      Long Packet Type (2) = 1,
+ *      Reserved Bits (2),
+ *      Packet Number Length (2),
+ *      Version (32),
+ *      Destination Connection ID Length (8),
+ *      Destination Connection ID (0..160),
+ *      Source Connection ID Length (8),
+ *      Source Connection ID (0..160),
+ *      Length (i),
+ *      Packet Number (8..32),
+ *      Packet Payload (..),
+ *    }
+ * } 
+ * + *

Subclasses of this class may be used to model packets exchanged with either + * Quic Version 2. + * Note that Quic Version 2 uses the same 0-RTT Packet structure than + * Quic Version 1, but uses a different long packet type than that shown above. See + * RFC 9369, Section 3.2. + * + * @see + * RFC 9000, Section 17.2 + * + * @see + * [TLS13] RFC 8446, Section 2.3 + * + * @spec https://www.rfc-editor.org/info/rfc9000 + * RFC 9000: QUIC: A UDP-Based Multiplexed and Secure Transport + * @spec https://www.rfc-editor.org/info/rfc9369 + * RFC 9369: QUIC Version 2 + */ +public interface ZeroRttPacket extends LongHeaderPacket { + @Override + default PacketType packetType() { + return PacketType.ZERORTT; + } + + @Override + default PacketNumberSpace numberSpace() { + return PacketNumberSpace.APPLICATION; + } + + @Override + default boolean hasLength() { return true; } + + /** + * This packet number. + * @return this packet number. + */ + @Override + long packetNumber(); + + @Override + List frames(); +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/AbstractQuicStream.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/AbstractQuicStream.java new file mode 100644 index 00000000000..e1a7fce1059 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/AbstractQuicStream.java @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import jdk.internal.net.http.quic.QuicConnectionImpl; + +/** + * An abstract class to model a QuicStream. + * A quic stream can be either unidirectional + * or bidirectional. A unidirectional stream can + * be opened for reading or for writing. + * Concrete subclasses of {@code AbstractQuicStream} should + * implement {@link QuicSenderStream} (unidirectional {@link + * StreamMode#WRITE_ONLY} stream), or {@link QuicReceiverStream} + * (unidirectional {@link StreamMode#READ_ONLY} stream), or + * both (bidirectional {@link StreamMode#READ_WRITE} stream). + */ +abstract sealed class AbstractQuicStream implements QuicStream + permits QuicBidiStreamImpl, QuicSenderStreamImpl, QuicReceiverStreamImpl { + + private final QuicConnectionImpl connection; + private final long streamId; + private final StreamMode mode; + + AbstractQuicStream(QuicConnectionImpl connection, long streamId) { + this.mode = mode(connection, streamId); + this.streamId = streamId; + this.connection = connection; + } + + private static StreamMode mode(QuicConnectionImpl connection, long streamId) { + if (QuicStreams.isBidirectional(streamId)) return StreamMode.READ_WRITE; + if (connection.isClientConnection()) { + return QuicStreams.isClientInitiated(streamId) + ? StreamMode.WRITE_ONLY : StreamMode.READ_ONLY; + } else { + return QuicStreams.isClientInitiated(streamId) + ? StreamMode.READ_ONLY : StreamMode.WRITE_ONLY; + } + } + + /** + * {@return the {@code QuicConnectionImpl} instance this stream + * belongs to} + */ + final QuicConnectionImpl connection() { + return connection; + } + + @Override + public final long streamId() { + return streamId; + } + + @Override + public final StreamMode mode() { + return mode; + } + + @Override + public final boolean isClientInitiated() { + return QuicStreams.isClientInitiated(type()); + } + + @Override + public final boolean isServerInitiated() { + return QuicStreams.isServerInitiated(type()); + } + + @Override + public final boolean isBidirectional() { + return QuicStreams.isBidirectional(type()); + } + + @Override + public final boolean isLocalInitiated() { + return connection().isClientConnection() == isClientInitiated(); + } + + @Override + public final boolean isRemoteInitiated() { + return connection().isClientConnection() != isClientInitiated(); + } + + @Override + public final int type() { + return QuicStreams.streamType(streamId); + } + + /** + * {@return true if this stream isn't expecting anything + * from the peer and can be removed from the streams map} + */ + public abstract boolean isDone(); + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/CryptoWriterQueue.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/CryptoWriterQueue.java new file mode 100644 index 00000000000..cbbbf6d083d --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/CryptoWriterQueue.java @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Iterator; +import java.util.Queue; + +/** + * Class that buffers crypto data received from QuicTLSEngine. + * Generates CryptoFrames of requested size. + * + * Normally the frames are produced sequentially. However, when the client + * receives a Retry packet or a Version Negotiation packet, the client hello + * needs to be replayed. In that case we need to keep the processed data + * in the queues. + */ +public class CryptoWriterQueue { + private final Queue queue = new ArrayDeque<>(); + private long position = 0; + // amount of bytes remaining across all the enqueued buffers + private int totalRemaining = 0; + private boolean keepReplayData; + + /** + * Notify the writer to start keeping processed data. Can only be called on a fresh writer. + * @throws IllegalStateException if some data was processed already + */ + public synchronized void keepReplayData() { + if (position > 0) { + throw new IllegalStateException("Some data was processed already"); + } + keepReplayData = true; + } + + /** + * Notify the writer to stop keeping processed data. + */ + public synchronized void discardReplayData() { + if (!keepReplayData) { + return; + } + keepReplayData = false; + for (Iterator iterator = queue.iterator(); iterator.hasNext(); ) { + ByteBuffer next = iterator.next(); + if (next.remaining() == 0) { + iterator.remove(); + } else { + return; + } + } + } + + /** + * Rewinds the enqueued buffer positions to allow for replaying the data + * @throws IllegalStateException if replay data is not available + */ + public synchronized void replayData() { + if (!keepReplayData) { + throw new IllegalStateException("Replay data not available"); + } + if (position == 0) { + return; + } + int rewound = 0; + for (Iterator iterator = queue.iterator(); iterator.hasNext(); ) { + ByteBuffer next = iterator.next(); + if (next.position() != 0) { + rewound += next.position(); + next.position(0); + } else { + break; + } + } + assert rewound == position : rewound - position; + position = 0; + totalRemaining += rewound; + } + + /** + * Clears the queue and resets position back to zero + */ + public synchronized void reset() { + position = 0; + totalRemaining = 0; + queue.clear(); + } + + /** + * Enqueues the provided crypto data + * @param buffer data to enqueue + */ + public synchronized void enqueue(ByteBuffer buffer) { + queue.add(buffer.slice()); + totalRemaining += buffer.remaining(); + } + + /** + * Stores the next portion of queued crypto data in a frame. + * May return null if there's no data to enqueue or if + * maxSize is too small to fit at least one byte of data. + * The produced frame may be shorter than maxSize even if there are + * remaining bytes. + * @param maxSize maximum size of the returned frame, in bytes + * @return frame with next portion of crypto data, or null + * @throws IllegalArgumentException if maxSize < 0 + */ + public synchronized CryptoFrame produceFrame(int maxSize) { + if (maxSize < 0) { + throw new IllegalArgumentException("negative maxSize"); + } + if (totalRemaining == 0) { + return null; + } + int posLength = VariableLengthEncoder.getEncodedSize(position); + // 1 (type) + posLength (position) + 1 (length) + 1 (payload) + if (maxSize < 3 + posLength) { + return null; + } + int maxPayloadPlusLen = maxSize - 1 - posLength; + int maxPayload; + if (maxPayloadPlusLen <= 64) { //63 bytes + 1 byte for length + maxPayload = maxPayloadPlusLen - 1; + } else if (maxPayloadPlusLen <= 16385) { // 16383 bytes + 2 bytes for length + maxPayload = maxPayloadPlusLen - 2; + } else { // 4 bytes for length + maxPayload = maxPayloadPlusLen - 4; + } + // the frame length that we decide upon + final int computedFrameLength = Math.min(maxPayload, totalRemaining); + assert computedFrameLength > 0 : computedFrameLength; + ByteBuffer frameData = null; + for (Iterator iterator = queue.iterator(); iterator.hasNext(); ) { + final ByteBuffer buffer = iterator.next(); + // amount of remaining bytes in the current bytebuffer being processed + final int numRemainingInBuffer = buffer.remaining(); + if (numRemainingInBuffer == 0) { + if (!keepReplayData) { + iterator.remove(); + } + continue; + } + if (frameData == null) { + frameData = ByteBuffer.allocate(computedFrameLength); + } + if (frameData.remaining() >= numRemainingInBuffer) { + // frame data can accommodate the entire buffered data, so copy it over + frameData.put(buffer); + if (!keepReplayData) { + iterator.remove(); + } + } else { + // target frameData buffer cannot accommodate the entire buffered data, + // so we copy over only that much that the target buffer can accommodate + + // amount of data available in the target buffer + final int spaceAvail = frameData.remaining(); + // copy over the buffered data into the target frameData buffer + frameData.put(frameData.position(), buffer, buffer.position(), spaceAvail); + // manually move the position of the target buffer to account for the copied data + frameData.position(frameData.position() + spaceAvail); + // manually move the position of the (input) buffered data to account for + // data that we just copied + buffer.position(buffer.position() + spaceAvail); + // target frameData buffer is fully populated, no more processing of available + // input buffer necessary in this round + break; + } + } + assert frameData != null; + assert !frameData.hasRemaining() : frameData.remaining(); + frameData.flip(); + long oldPosition = position; + position += computedFrameLength; + totalRemaining -= computedFrameLength; + assert totalRemaining >= 0 : totalRemaining; + assert totalRemaining > 0 || keepReplayData || queue.isEmpty(); + return new CryptoFrame(oldPosition, computedFrameLength, frameData); + } + + /** + * {@return the current number of buffered bytes} + */ + public synchronized int remaining() { + return totalRemaining; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStream.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStream.java new file mode 100644 index 00000000000..2185507d0db --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStream.java @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +/** + * An interface that represents a bidirectional stream. + * A bidirectional stream implements both {@link QuicSenderStream} + * and {@link QuicReceiverStream}. + */ +public non-sealed interface QuicBidiStream extends QuicStream, QuicReceiverStream, QuicSenderStream { + + /** + * The state of a bidirectional stream can be obtained by combining + * the state of its sending part and receiving part. + * + * A bidirectional stream is composed of sending and receiving + * parts. Implementations can represent states of the bidirectional + * stream as composites of sending and receiving stream states. + * The simplest model presents the stream as "open" when either + * sending or receiving parts are in a non-terminal state and + * "closed" when both sending and receiving streams are in + * terminal states. + * + * See RFC 9000, [Section 3.4] + * (https://www.rfc-editor.org/rfc/rfc9000#name-bidirectional-stream-states) + */ + enum BidiStreamState implements QuicStream.StreamState { + /** + * A bidirectional stream is considered "idle" if no + * data has been sent or received on that stream. + */ + IDLE, + /** + * A bidirectional stream is considered "open" until all data + * has been received, or all data has been sent, and no reset + * has been sent or received. + */ + OPENED, + /** + * A bidirectional stream is considered locally half closed + * if the sending part is locally closed: + * all data has been sent and acknowledged, or a reset has + * been sent, but the receiving part is still receiving. + */ + HALF_CLOSED_LOCAL, + /** + * A bidirectional stream is considered remotely half closed + * if the receiving part is closed: + * all data has been read or received on the receiving part, + * or reset has been read or received on the receiving part, but + * the sending part is still sending. + */ + HALF_CLOSED_REMOTE, + /** + * A bidirectional stream is considered closed when both parts + * have been reset or all data has been sent and acknowledged + * and all data has been received. + */ + CLOSED; + + /** + * @inheritDoc + * @apiNote + * A bidirectional stream may be considered closed (which is a terminal state), + * even if the sending or receiving part of a stream haven't reached a terminal + * state. Typically, if the sending part has sent a RESET frame, the stream + * may be considered closed even if the acknowledgement hasn't been received + * yet. + */ + @Override + public boolean isTerminal() { + return this == CLOSED; + } + } + + /** + * {@return a composed simplified state computed from the state of + * the receiving part and sending part of the stream} + *

+ * See RFC 9000, [Section 3.4] + * (https://www.rfc-editor.org/rfc/rfc9000#name-bidirectional-stream-states) + */ + default BidiStreamState getBidiStreamState() { + return switch (sendingState()) { + case READY -> switch (receivingState()) { + case RECV -> dataReceived() == 0 + ? BidiStreamState.IDLE + : BidiStreamState.OPENED; + case SIZE_KNOWN -> BidiStreamState.OPENED; + case DATA_RECVD, DATA_READ, RESET_RECVD, RESET_READ + -> BidiStreamState.HALF_CLOSED_REMOTE; + }; + case SEND, DATA_SENT -> switch (receivingState()) { + case RECV, SIZE_KNOWN -> BidiStreamState.OPENED; + case DATA_RECVD, DATA_READ, RESET_RECVD, RESET_READ + -> BidiStreamState.HALF_CLOSED_REMOTE; + }; + case DATA_RECVD, RESET_RECVD, RESET_SENT -> switch (receivingState()) { + case RECV, SIZE_KNOWN -> BidiStreamState.HALF_CLOSED_LOCAL; + case DATA_RECVD, DATA_READ, RESET_RECVD, RESET_READ + -> BidiStreamState.CLOSED; + }; + }; + } + + @Override + default StreamState state() { return getBidiStreamState(); } + + @Override + default boolean hasError() { + return rcvErrorCode() >= 0 || sndErrorCode() >= 0; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStreamImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStreamImpl.java new file mode 100644 index 00000000000..a964d96ef9c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicBidiStreamImpl.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.io.IOException; + +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.quic.QuicConnectionImpl; + +/** + * An implementation of a bidirectional stream. + * A bidirectional stream implements both {@link QuicSenderStream} + * and {@link QuicReceiverStream}. + */ +public final class QuicBidiStreamImpl extends AbstractQuicStream implements QuicBidiStream { + + // The sender part of this bidirectional stream + private final QuicSenderStreamImpl senderPart; + + // The receiver part of this bidirectional stream + private final QuicReceiverStreamImpl receiverPart; + + QuicBidiStreamImpl(QuicConnectionImpl connection, long streamId) { + this(connection, streamId, new QuicSenderStreamImpl(connection, streamId), + new QuicReceiverStreamImpl(connection, streamId)); + } + + private QuicBidiStreamImpl(QuicConnectionImpl connection, long streamId, + QuicSenderStreamImpl sender, QuicReceiverStreamImpl receiver) { + super(connection, streamId); + this.senderPart = sender; + this.receiverPart = receiver; + assert isBidirectional(); + } + + @Override + public ReceivingStreamState receivingState() { + return receiverPart.receivingState(); + } + + @Override + public QuicStreamReader connectReader(SequentialScheduler scheduler) { + return receiverPart.connectReader(scheduler); + } + + @Override + public void disconnectReader(QuicStreamReader reader) { + receiverPart.disconnectReader(reader); + } + + @Override + public void requestStopSending(long errorCode) { + receiverPart.requestStopSending(errorCode); + } + + @Override + public boolean isStopSendingRequested() { + return receiverPart.isStopSendingRequested(); + } + + @Override + public long dataReceived() { + return receiverPart.dataReceived(); + } + + @Override + public long maxStreamData() { + return receiverPart.maxStreamData(); + } + + @Override + public SendingStreamState sendingState() { + return senderPart.sendingState(); + } + + @Override + public QuicStreamWriter connectWriter(SequentialScheduler scheduler) { + return senderPart.connectWriter(scheduler); + } + + @Override + public void disconnectWriter(QuicStreamWriter writer) { + senderPart.disconnectWriter(writer); + } + + @Override + public void reset(long errorCode) throws IOException { + senderPart.reset(errorCode); + } + + @Override + public long dataSent() { + return senderPart.dataSent(); + } + + /** + * {@return the sender part implementation of this bidirectional stream} + */ + public QuicSenderStreamImpl senderPart() { + return senderPart; + } + + /** + * {@return the receiver part implementation of this bidirectional stream} + */ + public QuicReceiverStreamImpl receiverPart() { + return receiverPart; + } + + @Override + public boolean isDone() { + return receiverPart.isDone() && senderPart.isDone(); + } + + @Override + public long rcvErrorCode() { + return receiverPart.rcvErrorCode(); + } + + @Override + public long sndErrorCode() { + return senderPart.sndErrorCode(); + } + + @Override + public boolean stopSendingReceived() { + return senderPart.stopSendingReceived(); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicConnectionStreams.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicConnectionStreams.java new file mode 100644 index 00000000000..3b6198682df --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicConnectionStreams.java @@ -0,0 +1,1590 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.QuicStreamLimitException; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.frames.StreamsBlockedFrame; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTLSEngine.KeySpace; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.http.quic.frames.MaxStreamDataFrame; +import jdk.internal.net.http.quic.frames.MaxStreamsFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.ResetStreamFrame; +import jdk.internal.net.http.quic.frames.StopSendingFrame; +import jdk.internal.net.http.quic.frames.StreamDataBlockedFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder; +import jdk.internal.net.http.quic.streams.QuicReceiverStream.ReceivingStreamState; +import jdk.internal.net.http.quic.streams.QuicStream.StreamMode; +import jdk.internal.net.http.quic.streams.QuicStream.StreamState; +import jdk.internal.net.http.quic.QuicTransportParameters; +import jdk.internal.net.http.quic.QuicTransportParameters.ParameterId; +import jdk.internal.net.quic.QuicTransportErrors; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static jdk.internal.net.http.quic.streams.QuicStreams.*; + +/** + * A helper class to help manage Quic streams in a Quic connection + */ +public final class QuicConnectionStreams { + + // the (sliding) window size of MAX_STREAMS limit + private static final long MAX_BIDI_STREAMS_WINDOW_SIZE = QuicConnectionImpl.DEFAULT_MAX_BIDI_STREAMS; + private static final long MAX_UNI_STREAMS_WINDOW_SIZE = QuicConnectionImpl.DEFAULT_MAX_UNI_STREAMS; + + // These atomic long ids record the expected next stream ID that + // should be allocated for the next stream of a given type. + // The type of a stream is a number in [0..3], and is used + // as an index in this list. + private final List nextStreamID = List.of( + new AtomicLong(), // 0: client initiated bidi + new AtomicLong(SRV_MASK), // 1: server initiated bidi + new AtomicLong(UNI_MASK), // 2: client initiated uni + new AtomicLong(UNI_MASK | SRV_MASK)); // 3: server initiated uni + + // the max uni streams that the current endpoint is allowed to initiate against the peer + private final StreamCreationPermit localUniMaxStreamLimit = new StreamCreationPermit(0); + // the max bidi streams that the current endpoint is allowed to initiate against the peer + private final StreamCreationPermit localBidiMaxStreamLimit = new StreamCreationPermit(0); + // the max uni streams that the remote peer is allowed to initiate against the current endpoint + private final AtomicLong remoteUniMaxStreamLimit = new AtomicLong(0); + // the max bidi streams that the remote peer is allowed to initiate against the current endpoint + private final AtomicLong remoteBidiMaxStreamLimit = new AtomicLong(0); + + private final StreamsContainer streams = new StreamsContainer(); + + // A collection of senders which have available data ready to send, (or which possibly + // are blocked and need to send STREAM_DATA_BLOCKED). + // A stream stays in the queue until it is blocked or until it + // has no more data available to send: when a stream has no more data available for sending it is not + // put back in the queue. It will be put in the queue again when selectForSending is called. + private final ReadyStreamCollection sendersReady; + + // A map that contains streams for which sending a RESET_STREAM frame was requested + // and their corresponding error codes. + // Once the frame has been sent (or has been scheduled to be sent) the stream removed from the map. + private final ConcurrentMap sendersReset = new ConcurrentHashMap<>(); + + // A map that contains streams for which sending a MAX_STREAM_DATA frame was requested. + // Once the frame has been sent (or has been scheduled to be sent) the stream removed from the map. + private final ConcurrentMap receiversSend = new ConcurrentHashMap<>(); + + // A queue of remote initiated streams that have not been acquired yet. + // see pollNewRemoteStreams and addRemoteStreamListener + private final ConcurrentLinkedQueue newRemoteStreams = new ConcurrentLinkedQueue<>(); + // A set of listeners listening to new streams created by the peer + private final Set> streamListeners = ConcurrentHashMap.newKeySet(); + // A lock to ensure consistency between invocation of streamListeners and + // the content of the newRemoteStreams queue. + private final Lock newRemoteStreamsLock = new ReentrantLock(); + + // The connection to which the streams managed by this + // instance of QuicConnectionStreams belong to. + private final QuicConnectionImpl connection; + + // will hold the highest limit from a STREAMS_BLOCKED frame that was sent by a peer for uni + // streams. this indicates the peer isn't able to create any more uni streams, past this limit + private final AtomicLong peerUniStreamsBlocked = new AtomicLong(-1); + // will hold the highest limit from a STREAMS_BLOCKED frame that was sent by a peer for bidi + // streams. this indicates the peer isn't able to create any more bidi streams, past this limit + private final AtomicLong peerBidiStreamsBlocked = new AtomicLong(-1); + // will hold the highest limit at which the local endpoint couldn't create a uni stream + // and a STREAMS_BLOCKED was required to be sent. -1 indicates the local endpoint hasn't yet + // been blocked for stream creation + private final AtomicLong uniStreamsBlocked = new AtomicLong(-1); + // will hold the highest limit at which the local endpoint couldn't create a bidi stream + // and a STREAMS_BLOCKED was required to be sent. -1 indicates the local endpoint hasn't yet + // been blocked for stream creation + private final AtomicLong bidiStreamsBlocked = new AtomicLong(-1); + // will hold the limit with which the local endpoint last sent a STREAMS_BLOCKED frame to the + // peer for uni streams. -1 indicates no STREAMS_BLOCKED frame has been sent yet. A new + // STREAMS_BLOCKED will be sent only if "uniStreamsBlocked" exceeds this + // "lastUniStreamsBlockedSent" + private final AtomicLong lastUniStreamsBlockedSent = new AtomicLong(-1); + // will hold the limit with which the local endpoint last sent a STREAMS_BLOCKED frame to the + // peer for bidi streams. -1 indicates no STREAMS_BLOCKED frame has been sent yet. A new + // STREAMS_BLOCKED will be sent only if "bidiStreamsBlocked" exceeds this + // "lastBidiStreamsBlockedSent" + private final AtomicLong lastBidiStreamsBlockedSent = new AtomicLong(-1); + // streams that have been blocked and aren't able to send data to the peer, + // due to reaching flow control limit imposed on those streams by the peer. + private final Set flowControlBlockedStreams = Collections.synchronizedSet(new HashSet<>()); + + private final Logger debug; + + // A QuicConnectionStream instance can be tied to a client connection + // or a server connection. + // If the connection is a client connection, then localFlag=0x00, + // localBidi=0x00, remoteBidi=0x01, localUni=0x02, remoteUni=0x03 + // If the connection is a server connection, then localFlag=0x01, + // localBidi=0x01, remoteBidi=0x00, localUni=0x03, remoteUni=0x02 + private final int localFlag, localBidi, remoteBidi, localUni, remoteUni; + + /** + * Creates a new instance of {@code QuicConnectionStreams} for the + * given connection. There is a 1-1 relationship between a + * {@code QuicConnectionImpl} instance and a {@code QuicConnectionStreams} + * instance. + * @param connection the connection to which the streams managed by this + * instance of {@code QuicConnectionStreams} belong. + */ + public QuicConnectionStreams(QuicConnectionImpl connection, Logger debug) { + this.connection = connection; + this.debug = Objects.requireNonNull(debug); + // implicit null check for connection + boolean isClient = connection.isClientConnection(); + localFlag = isClient ? 0 : SRV_MASK; + localBidi = isClient ? 0 : SRV_MASK; + remoteBidi = isClient ? SRV_MASK : 0; + localUni = isClient ? UNI_MASK : UNI_MASK | SRV_MASK; + remoteUni = isClient ? UNI_MASK | SRV_MASK : UNI_MASK; + sendersReady = isClient ? + new ReadyStreamQueue() : // faster stream opening + new ReadyStreamSortedQueue(); // faster stream closing + } + + /** + * {@return the next unallocated stream ID that would be expected + * for a stream of the given type} + * This method expects {@code streamType} to be a number in [0..3] but + * does not check it. An assert may be fired if an invalid type is passed. + * @param streamType The stream type, a number in [0..3] + */ + public long peekNextStreamId(int streamType) { + assert streamType >= 0 && streamType < 4; + var id = nextStreamID.get(streamType & 0x03); + return id.get(); + } + + /** + * Creates a new locally initiated unidirectional stream. + *

+ * If the stream cannot be created due to stream creation limit being reached, then this method + * will return a {@code CompletableFuture} which will complete either when the {@code timeout} + * has reached or the stream limit has been increased and the stream creation was successful. + * If the stream creation doesn't complete within the specified timeout then the returned + * {@code CompletableFuture} will complete exceptionally with a {@link QuicStreamLimitException} + * + * @param timeout the maximum duration to wait to acquire a permit for stream creation + * @return a CompletableFuture whose result on successful completion will return the newly + * created {@code QuicSenderStream} + */ + public CompletableFuture createNewLocalUniStream(final Duration timeout) { + @SuppressWarnings("unchecked") + final var streamCF = (CompletableFuture) createNewLocalStream(localUni, + StreamMode.WRITE_ONLY, timeout); + return streamCF; + } + + /** + * Creates a new locally initiated bidirectional stream. + *

+ * If the stream cannot be created due to stream creation limit being reached, then this method + * will return a {@code CompletableFuture} which will complete either when the {@code timeout} + * has reached or the stream limit has been increased and the stream creation was successful. + * If the stream creation doesn't complete within the specified timeout then the returned + * {@code CompletableFuture} will complete exceptionally with a {@link QuicStreamLimitException} + * + * @param timeout the maximum duration to wait to acquire a permit for stream creation + * @return a CompletableFuture whose result on successful completion will return the newly + * created {@code QuicBidiStream} + */ + public CompletableFuture createNewLocalBidiStream(final Duration timeout) { + @SuppressWarnings("unchecked") + final var streamCF = (CompletableFuture) createNewLocalStream(localBidi, + StreamMode.READ_WRITE, timeout); + return streamCF; + } + + private void register(long streamId, AbstractQuicStream stream) { + var previous = streams.put(streamId, stream); + assert previous == null : "stream " + streamId + " is already registered!"; + QuicTransportParameters peerParameters = connection.peerTransportParameters(); + if (peerParameters != null) { + if (debug.on()) { + debug.log("setting initial peer parameters on stream " + streamId); + } + newInitialPeerParameters(stream, peerParameters); + } + QuicTransportParameters localParameters = connection.localTransportParameters(); + if (localParameters != null) { + if (debug.on()) { + debug.log("setting initial local parameters on stream " + streamId); + } + newInitialLocalParameters(stream, localParameters); + } + if (stream instanceof QuicReceiverStream receiver && stream.isRemoteInitiated()) { + if (debug.on()) { + debug.log("accepting remote stream " + streamId); + } + newRemoteStreams.add(receiver); + acceptRemoteStreams(); + } + if (debug.on()) { + debug.log("new stream %s %s registered", streamId, stream.mode()); + } + } + + private void acceptRemoteStreams() { + newRemoteStreamsLock.lock(); + try { + for (var listener : streamListeners) { + var iterator = newRemoteStreams.iterator(); + while (iterator.hasNext()) { + var stream = iterator.next(); + if (debug.on()) { + debug.log("invoking remote stream listener for stream %s", + stream.streamId()); + } + if (listener.test(stream)) iterator.remove(); + } + } + } finally { + newRemoteStreamsLock.unlock(); + } + } + + private CompletableFuture createNewLocalStream( + final int localType, final StreamMode mode, final Duration timeout) { + assert localType >= 0 && localType < 4 : "bad local stream type " + localType; + assert (localType & SRV_MASK) == localFlag : "bad local stream type " + localType; + assert (localType & UNI_MASK) == 0 || mode == StreamMode.WRITE_ONLY + : "bad combination of local stream type (%s) and mode %s" + .formatted(localType, mode); + assert (localType & UNI_MASK) == UNI_MASK || mode == StreamMode.READ_WRITE + : "bad combination of local stream type (%s) and mode %s" + .formatted(localType, mode); + final boolean bidi = isBidirectional(localType); + final StreamCreationPermit permit = bidi ? this.localBidiMaxStreamLimit + : this.localUniMaxStreamLimit; + final CompletableFuture permitAcquisitionCF; + final long currentLimit = permit.currentLimit(); + final boolean acquired = permit.tryAcquire(); + if (acquired) { + permitAcquisitionCF = MinimalFuture.completedFuture(true); + } else { + // stream limit reached, request sending a STREAMS_BLOCKED frame + announceStreamsBlocked(bidi, currentLimit); + if (timeout.isPositive()) { + final Executor executor = this.connection.quicInstance().executor(); + if (debug.on()) { + debug.log("stream creation limit = " + permit.currentLimit() + + " reached; waiting for it to increase, timeout=" + timeout); + } + permitAcquisitionCF = permit.tryAcquire(timeout.toNanos(), NANOSECONDS, executor); + } else { + permitAcquisitionCF = MinimalFuture.completedFuture(false); + } + } + final CompletableFuture streamCF = + permitAcquisitionCF.thenCompose((acq) -> { + if (!acq) { + final String msg = "Stream limit = " + permit.currentLimit() + + " reached for locally initiated " + + (bidi ? "bidi" : "uni") + " streams"; + return MinimalFuture.failedFuture(new QuicStreamLimitException(msg)); + } + // stream limit hasn't been reached, we are allowed to create new one + final long streamId = nextStreamID.get(localType).getAndAdd(4); + final AbstractQuicStream stream = QuicStreams.createStream(connection, streamId); + assert stream.mode() == mode; + assert stream.type() == localType; + if (debug.on()) { + var strtype = (localType & UNI_MASK) == UNI_MASK ? "uni" : "bidi"; + debug.log("created new local %s stream type:%s, mode:%s, id:%s", + strtype, localType, mode, streamId); + } + register(streamId, stream); + return MinimalFuture.completedFuture(stream); + }); + return streamCF; + } + + /** + * Runs the APPLICATION space packet transmitter, if necessary, + * to potentially trigger sending a STREAMS_BLOCKED frame to the peer + * @param bidi true if the local endpoint is blocked for bidi streams, false for uni streams + * @param blockedOnLimit the stream creation limit due to which the local endpoint is + * currently blocked + */ + private void announceStreamsBlocked(final boolean bidi, final long blockedOnLimit) { + boolean runTransmitter = false; + if (bidi) { + long prevBlockedLimit = this.bidiStreamsBlocked.get(); + while (blockedOnLimit > prevBlockedLimit) { + if (this.bidiStreamsBlocked.compareAndSet(prevBlockedLimit, blockedOnLimit)) { + runTransmitter = true; + break; + } + prevBlockedLimit = this.bidiStreamsBlocked.get(); + } + } else { + long prevBlockedLimit = this.uniStreamsBlocked.get(); + while (blockedOnLimit > prevBlockedLimit) { + if (this.uniStreamsBlocked.compareAndSet(prevBlockedLimit, blockedOnLimit)) { + runTransmitter = true; + break; + } + prevBlockedLimit = this.uniStreamsBlocked.get(); + } + } + if (runTransmitter) { + if (debug.on()) { + debug.log("requesting packet transmission to send " + (bidi ? "bidi" : "uni") + + " STREAMS_BLOCKED with limit " + blockedOnLimit); + } + this.connection.runAppPacketSpaceTransmitter(); + } + } + + /** + * Runs the APPLICATION space packet transmitter, if necessary, to potentially trigger + * sending a MAX_STREAMS frame to the peer, upon receiving the STREAMS_BLOCKED {@code frame} + * from that peer + * @param frame the STREAMS_BLOCKED frame that was received from the peer + */ + public void peerStreamsBlocked(final StreamsBlockedFrame frame) { + final boolean bidi = frame.isBidi(); + final long blockedOnLimit = frame.maxStreams(); + boolean runTransmitter = false; + if (bidi) { + long prevBlockedLimit = this.peerBidiStreamsBlocked.get(); + while (blockedOnLimit > prevBlockedLimit) { + if (this.peerBidiStreamsBlocked.compareAndSet(prevBlockedLimit, blockedOnLimit)) { + runTransmitter = true; + break; + } + prevBlockedLimit = this.peerBidiStreamsBlocked.get(); + } + } else { + long prevBlockedLimit = this.peerUniStreamsBlocked.get(); + while (blockedOnLimit > prevBlockedLimit) { + if (this.peerUniStreamsBlocked.compareAndSet(prevBlockedLimit, blockedOnLimit)) { + runTransmitter = true; + break; + } + prevBlockedLimit = this.peerUniStreamsBlocked.get(); + } + } + if (runTransmitter) { + if (debug.on()) { + debug.log("requesting packet transmission in response to receiving " + + (bidi ? "bidi" : "uni") + " STREAMS_BLOCKED from peer," + + " blocked with limit " + blockedOnLimit); + } + this.connection.runAppPacketSpaceTransmitter(); + } + } + + /** + * Gets or opens a remotely initiated stream with the given stream ID. + * Creates all streams with lower IDs if needed. + * @param streamId the stream ID + * @param frameType type of the frame received, used in exceptions + * @return a remotely initiated stream with the given stream ID. + * May return null if the stream was already closed. + * @throws IllegalArgumentException if the streamID is of the wrong type for + * a remote stream. + * @throws QuicTransportException if the streamID is higher than allowed + */ + public QuicStream getOrCreateRemoteStream(long streamId, long frameType) + throws QuicTransportException { + final int streamType = streamType(streamId); + if ((streamId & SRV_MASK) == localFlag) { + throw new IllegalArgumentException("bad remote stream type %s for stream %s" + .formatted(streamType, streamId)); + } + final boolean bidi = isBidirectional(streamId); + final long maxStreamLimit = bidi ? this.remoteBidiMaxStreamLimit.get() + : this.remoteUniMaxStreamLimit.get(); + if (maxStreamLimit <= (streamId >> 2)) { + throw new QuicTransportException("stream ID %s exceeds the number of allowed streams(%s)" + .formatted(streamId, maxStreamLimit), QuicTLSEngine.KeySpace.ONE_RTT, frameType, + QuicTransportErrors.STREAM_LIMIT_ERROR); + } + + newRemoteStreamsLock.lock(); + try { + var id = nextStreamID.get(streamType); + long nextId = id.get(); + if (nextId > streamId) { + // already created + return streams.get(streamId); + } + // id must not be modified outside newRemoteStreamsLock + long altId = id.getAndSet(streamId + 4); + assert altId == nextId : "next ID concurrently modified"; + + AbstractQuicStream stream = null; + for (long i = nextId; i <= streamId; i += 4) { + stream = QuicStreams.createStream(connection, i); + assert stream.isRemoteInitiated(); + register(i, stream); + } + assert stream != null; + assert stream.streamId() == streamId : stream.streamId(); + return stream; + } finally { + newRemoteStreamsLock.unlock(); + } + } + + + /** + * Finds a stream with the given stream ID. Returns {@code null} if no + * stream with that ID is found. + * @param streamId a stream ID + * @return the stream with the given stream ID if found, {@code null} + * otherwise. + */ + public QuicStream findStream(long streamId) { + return streams.get(streamId); + } + + /** + * Adds a listener that will be invoked when a remote stream is + * created. + * + * @apiNote The listener will be invoked with any remote streams + * already opened, and not yet acquired by another listener. + * Any stream passed to the listener is either a {@link QuicBidiStream} + * or a {@link QuicReceiverStream} depending on the + * {@linkplain QuicStreams#streamType(long) + * stream type} of the given streamId. + * The listener should return true if it wishes to acquire + * the stream. + * + * @param streamConsumer the listener + * + */ + public void addRemoteStreamListener(Predicate streamConsumer) { + newRemoteStreamsLock.lock(); + try { + streamListeners.add(streamConsumer); + acceptRemoteStreams(); + } finally { + newRemoteStreamsLock.unlock(); + } + } + + /** + * Removes a listener previously added with {@link #addRemoteStreamListener(Predicate)} + * @return {@code true} if the listener was found and removed, {@code false} otherwise + */ + public boolean removeRemoteStreamListener(Predicate streamConsumer) { + newRemoteStreamsLock.lock(); + try { + return streamListeners.remove(streamConsumer); + } finally { + newRemoteStreamsLock.unlock(); + } + } + + /** + * {@return a stream of all currently active {@link QuicStream} in the connection} + */ + public Stream quicStreams() { + return streams.all(); + } + + /** + * {@return {@code true} if there is some data to send} + * @apiNote + * This method may return true in the case where a + * STREAM_DATA_BLOCKED frame needs to be sent, even if no + * other data is available. + */ + public boolean hasAvailableData() { + return !sendersReady.isEmpty(); + } + + /** + * {@return true if there are control frames to send} + * Typically, these are STREAMS_BLOCKED, MAX_STREAMS, RESET_STREAM, STOP_SENDING, and + * MAX_STREAM_DATA. + */ + public boolean hasControlFrames() { + return !sendersReset.isEmpty() || !receiversSend.isEmpty() + // either of these imply we may send a MAX_STREAMS frame + || peerUniStreamsBlocked.get() != -1 || peerBidiStreamsBlocked.get() != -1 + // either of these imply we should send a STREAMS_BLOCKED frame + || uniStreamsBlocked.get() > lastUniStreamsBlockedSent.get() + || bidiStreamsBlocked.get() > lastBidiStreamsBlockedSent.get(); + } + + /** + * {@return {@code true} if the given {@code streamId} indicates a stream + * that has a receiving part} + * In other words, returns {@code true} if the given stream is either + * bidirectional or peer-initiated. + * @param streamId a stream ID + */ + public boolean isReceivingStream(long streamId) { + return !isLocalUni(streamId); + } + + /** + * {@return {@code true} if the given {@code streamId} indicates a stream + * that has a sending part} + * In other words, returns {@code true} if the given stream is either + * bidirectional or local-initiated. + * @param streamId a stream ID + */ + public boolean isSendingStream(long streamId) { + return !isRemoteUni(streamId); + } + + /** + * {@return {@code true} if the given {@code streamId} indicates a local + * unidirectional stream} + * @param streamId a stream ID + */ + public boolean isLocalUni(long streamId) { + return streamType(streamId) == localUni; + } + + /** + * {@return {@code true} if the given {@code streamId} indicates a local + * bidirectional stream} + * @param streamId a stream ID + */ + public boolean isLocalBidi(long streamId) { + return streamType(streamId) == localBidi; + } + + /** + * {@return {@code true} if the given {@code streamId} indicates a + * peer initiated unidirectional stream} + * @param streamId a stream ID + */ + public boolean isRemoteUni(long streamId) { + return streamType(streamId) == remoteUni; + } + + /** + * {@return {@code true} if the given {@code streamId} indicates a + * peer initiated bidirectional stream} + * @param streamId a stream ID + */ + public boolean isRemoteBidi(long streamId) { + return streamType(streamId) == remoteBidi; + } + + /** + * Mark the stream whose ID is encoded in the given + * {@code ResetStreamFrame} as needing a RESET_STREAM frame to be sent. + * It will put the stream and the frame in the {@code sendersReset} map. + * @param streamId the id of the stream that should be reset + * @param errorCode the application error code + */ + public void requestResetStream(long streamId, long errorCode) { + assert isSendingStream(streamId); + var stream = senderImpl(streams.get(streamId)); + if (stream == null) { + if (debug.on()) { + debug.log("Can't reset stream %d: no such stream", streamId); + } + return; + } + sendersReset.putIfAbsent(stream, errorCode); + if (debug.on()) { + debug.log("Reset stream scheduled"); + } + } + + /** + * Mark the stream whose ID is encoded in the given + * {@code MaxStreamDataFrame} as needing a MAX_STREAM_DATA frame to be sent. + * It will put the stream and the frame in the {@code receiversSend} map. + * @param maxStreamDataFrame the MAX_STREAM_DATA frame to send + */ + public void requestSendMaxStreamData(MaxStreamDataFrame maxStreamDataFrame) { + Objects.requireNonNull(maxStreamDataFrame, "maxStreamDataFrame"); + long streamId = maxStreamDataFrame.streamID(); + assert isReceivingStream(streamId); + var stream = streams.get(streamId); + if (stream == null) { + if (debug.on()) { + debug.log("Can't send MaxStreamDataFrame %d: no such stream", streamId); + } + return; + } + if (stream instanceof QuicReceiverStream receiver) { + // don't replace a stop sending frame, and don't replace + // a max stream data frame if it has a bigger max stream data + receiversSend.compute(receiver, (s, frame) -> { + if (frame instanceof StopSendingFrame stopSendingFrame) { + assert s.streamId() == stopSendingFrame.streamID(); + // no need to send max data frame if we are requesting + // stop sending + return frame; + } + if (frame instanceof MaxStreamDataFrame maxFrame) { + assert s.streamId() == maxFrame.streamID(); + if (maxFrame.maxStreamData() > maxStreamDataFrame.maxStreamData()) { + // send the frame that has the greater max data + return maxFrame; + } else return maxStreamDataFrame; + } + assert frame == null; + return maxStreamDataFrame; + }); + } else { + if (debug.on()) { + debug.log("Can't send %s stream %d: not a receiver stream", + maxStreamDataFrame.getClass(), streamId); + } + } + } + + + /** + * Mark the stream whose ID is encoded in the given + * {@code StopSendingFrame} as needing a STOP_SENDING frame to be sent. + * It will put the stream and the frame in the {@code receiversSend} map. + * @param stopSendingFrame the STOP_SENDING frame to send + */ + public void scheduleStopSendingFrame(StopSendingFrame stopSendingFrame) { + Objects.requireNonNull(stopSendingFrame, "stopSendingFrame"); + long streamId = stopSendingFrame.streamID(); + assert isReceivingStream(streamId); + var stream = streams.get(streamId); + if (stream == null) { + if (debug.on()) { + debug.log("Can't send STOP_SENDING to stream %d: no such stream", streamId); + } + return; + } + if (stream instanceof QuicReceiverStream receiver) { + // don't need to check if we already have a frame registered: + // stop sending takes precedence. + receiversSend.put(receiver, stopSendingFrame); + } else { + if (debug.on()) { + debug.log("Can't send %s stream %d: not a receiver stream", + stopSendingFrame.getClass(), streamId); + } + } + } + + /** + * Called when the RESET_STREAM frame is acknowledged by the peer. + * @param reset the RESET_STREAM frame + */ + public void streamResetAcknowledged(ResetStreamFrame reset) { + Objects.requireNonNull(reset, "reset"); + long streamId = reset.streamId(); + assert isSendingStream(streamId) : + "stream %s is not a sending stream".formatted(streamId); + final var stream = streams.get(streamId); + if (stream == null) { + return; + } + var sender = senderImpl(stream); + if (sender != null) { + sender.resetAcknowledged(reset.finalSize()); + assert !stream.isDone() || !streams.streams.containsKey(streamId) + : "resetAcknowledged() should have removed the stream"; + if (debug.on()) { + debug.log("acknowledged reset for stream %d", streamId); + } + } + } + + /** + * Called when the final STREAM frame is acknowledged by the peer. + * @param streamFrame the final STREAM frame + */ + public void streamDataSentAcknowledged(StreamFrame streamFrame) { + long streamId = streamFrame.streamId(); + assert isSendingStream(streamId) : + "stream %s is not a sending stream".formatted(streamId); + assert streamFrame.isLast(); + final var stream = streams.get(streamId); + if (stream == null) { + return; + } + var sender = senderImpl(stream); + if (sender != null) { + sender.dataAcknowledged(streamFrame.offset() + streamFrame.dataLength()); + assert !stream.isDone() || !streams.streams.containsKey(streamId) + : "dataAcknowledged() should have removed the stream"; + if (debug.on()) { + debug.log("acknowledged data for stream %d", streamId); + } + } + } + + /** + * Tracks a stream, belonging to this connection, as being blocked from sending data + * due to flow control limit. + * + * @param streamId the stream id + */ + final void trackBlockedStream(final long streamId) { + this.flowControlBlockedStreams.add(streamId); + } + + /** + * Stops tracking a stream, belonging to this connection, that may have been previously + * tracked as being blocked due to flow control limit. + * + * @param streamId the stream id + */ + final void untrackBlockedStream(final long streamId) { + this.flowControlBlockedStreams.remove(streamId); + } + + + /** + * Removes a stream from the stream map after its state has been + * switched to DATA_RECVD or RESET_RECVD + * @param streamId the stream id + * @param stream the stream instance + */ + private void removeStream(long streamId, QuicStream stream) { + // if we were tracking this stream as blocked due to flow control, then + // stop tracking the stream. + untrackBlockedStream(streamId); + if (stream instanceof AbstractQuicStream astream) { + if (astream.isDone()) { + if (debug.on()) { + debug.log("Removing stream %d (%s)", + stream.streamId(), stream.getClass().getSimpleName()); + } + streams.remove(streamId, astream); + if (stream.isRemoteInitiated()) { + // the queue is not expected to contain many elements. + newRemoteStreams.remove(stream); + if (shouldSendMaxStreams(stream.isBidirectional())) { + this.connection.runAppPacketSpaceTransmitter(); + } + } + } else { + if (debug.on()) { + debug.log("Can't remove stream yet: %d (%s) is %s", + stream.streamId(), stream.getClass().getSimpleName(), + stream.state()); + } + } + } + assert stream instanceof AbstractQuicStream + : "stream %s: unexpected stream class: %s" + .formatted(streamId, stream.getClass()); + } + + /** + * Called when new local transport parameters are available + * @param params the new local transport parameters + */ + public void newLocalTransportParameters(final QuicTransportParameters params) { + // the limit imposed on the remote peer by the local endpoint + final long newRemoteUniMax = params.getIntParameter(ParameterId.initial_max_streams_uni); + tryIncreaseLimitTo(this.remoteUniMaxStreamLimit, newRemoteUniMax); + final long newRemoteBidiMax = params.getIntParameter(ParameterId.initial_max_streams_bidi); + tryIncreaseLimitTo(this.remoteBidiMaxStreamLimit, newRemoteBidiMax); + streams.all().forEach(s -> newInitialLocalParameters(s, params)); + } + + /** + * Called when new peer transport parameters are available + * @param params the new local transport parameters + */ + public void newPeerTransportParameters(final QuicTransportParameters params) { + // the limit imposed on the local endpoint by the remote peer + final long localUniMaxStreams = params.getIntParameter(ParameterId.initial_max_streams_uni); + if (debug.on()) { + debug.log("increasing localUniMaxStreamLimit to initial_max_streams_uni: " + + localUniMaxStreams); + } + this.localUniMaxStreamLimit.tryIncreaseLimitTo(localUniMaxStreams); + final long localBidiMaxStreams = params.getIntParameter(ParameterId.initial_max_streams_bidi); + if (debug.on()) { + debug.log("increasing localBidiMaxStreamLimit to initial_max_streams_bidi: " + + localBidiMaxStreams); + } + this.localBidiMaxStreamLimit.tryIncreaseLimitTo(localBidiMaxStreams); + // set initial parameters on streams + streams.all().forEach(s -> newInitialPeerParameters(s, params)); + if (debug.on()) { + debug.log("all streams updated (%s)", streams.streams.size()); + } + } + + /** + * Called to set initial peer parameters on a stream + * @param stream the stream on which parameters might be set + * @param params the peer transport parameters + */ + private void newInitialPeerParameters(QuicStream stream, QuicTransportParameters params) { + long streamId = stream.streamId(); + if (isLocalUni(stream.streamId())) { + if (params.isPresent(ParameterId.initial_max_stream_data_uni)) { + long maxData = params.getIntParameter(ParameterId.initial_max_stream_data_uni); + senderImpl(stream).setMaxStreamData(maxData); + } + } else if (isLocalBidi(streamId)) { + // remote for the peer is local for us + if (params.isPresent(ParameterId.initial_max_stream_data_bidi_remote)) { + long maxData = params.getIntParameter(ParameterId.initial_max_stream_data_bidi_remote); + senderImpl(stream).setMaxStreamData(maxData); + } + } else if (isRemoteBidi(streamId)) { + // local for the peer is remote for us + if (params.isPresent(ParameterId.initial_max_stream_data_bidi_local)) { + long maxData = params.getIntParameter(ParameterId.initial_max_stream_data_bidi_local); + senderImpl(stream).setMaxStreamData(maxData); + } + } + } + + private static boolean tryIncreaseLimitTo(final AtomicLong limit, final long newLimit) { + long currentLimit = limit.get(); + while (currentLimit < newLimit) { + if (limit.compareAndSet(currentLimit, newLimit)) { + return true; + } + currentLimit = limit.get(); + } + return false; + } + + /** + * Called to set initial peer parameters on a stream + * @param stream the stream on which parameters might be set + * @param params the peer transport parameters + */ + private void newInitialLocalParameters(QuicStream stream, QuicTransportParameters params) { + long streamId = stream.streamId(); + if (isRemoteUni(stream.streamId())) { + if (params.isPresent(ParameterId.initial_max_stream_data_uni)) { + long maxData = params.getIntParameter(ParameterId.initial_max_stream_data_uni); + receiverImpl(stream).updateMaxStreamData(maxData); + } + } else if (isLocalBidi(streamId)) { + if (params.isPresent(ParameterId.initial_max_stream_data_bidi_local)) { + long maxData = params.getIntParameter(ParameterId.initial_max_stream_data_bidi_local); + receiverImpl(stream).updateMaxStreamData(maxData); + } + } else if (isRemoteBidi(streamId)) { + if (params.isPresent(ParameterId.initial_max_stream_data_bidi_remote)) { + long maxData = params.getIntParameter(ParameterId.initial_max_stream_data_bidi_remote); + receiverImpl(stream).updateMaxStreamData(maxData); + } + } + } + + /** + * Set max stream data for a stream. + * Called when a {@link jdk.internal.net.http.quic.frames.MaxStreamDataFrame + * MaxStreamDataFrame} is received. + * @param stream the stream + * @param maxStreamData the max data that the peer is willing to accept on this stream + */ + public void setMaxStreamData(QuicSenderStream stream, long maxStreamData) { + var sender = senderImpl(stream); + if (sender != null) { + final long newFinalizedLimit = sender.setMaxStreamData(maxStreamData); + // if the connection was tracking this stream as blocked due to flow control + // and if this new MAX_STREAM_DATA limit unblocked this stream, then + // stop tracking the stream. + if (newFinalizedLimit == maxStreamData) { // the proposed limit was accepted + if (!sender.isBlocked()) { + untrackBlockedStream(stream.streamId()); + } + } + } + } + + /** + * This method is called when a {@link + * jdk.internal.net.http.quic.frames.StopSendingFrame} is received + * from the peer. + * @param stream the stream for which stop sending was requested + * by the peer + * @param errorCode the error code + */ + public void stopSendingReceived(QuicSenderStream stream, long errorCode) { + var sender = senderImpl(stream); + if (sender != null) { + // if the stream was being tracked as blocked from sending data, + // due to flow control limits imposed by the peer, then we now + // stop tracking it since the peer no longer wants us to send data + // on this stream. + untrackBlockedStream(stream.streamId()); + sender.stopSendingReceived(errorCode); + } + } + + /** + * Called when the receiving part or the sending part of a stream + * reaches a terminal state. + * @param streamId the id of the stream + * @param state the terminal state + */ + public void notifyTerminalState(long streamId, StreamState state) { + assert state.isTerminal() : state; + var stream = streams.get(streamId); + if (stream != null) { + removeStream(streamId, stream); + } + } + + /** + * Called when the connection is closed by the higher level + * protocol + * @param terminationCause the termination cause + */ + public void terminate(final TerminationCause terminationCause) { + assert terminationCause != null : "termination cause is null"; + // make sure all active streams are woken up when we close a connection + streams.all().forEach((stream) -> { + if (stream instanceof QuicSenderStream) { + var sender = senderImpl(stream); + try { + sender.terminate(terminationCause); + } catch (Throwable t) { + if (debug.on()) { + debug.log("failed to close sender stream %s: %s", sender.streamId(), t); + } + } + } + if (stream instanceof QuicReceiverStream) { + var receiver = receiverImpl(stream); + try { + receiver.terminate(terminationCause); + } catch (Throwable t) { + // log and ignore + if (debug.on()) { + debug.log("failed to close receiver stream %s: %s", receiver.streamId(), t); + } + } + } + }); + } + + /** + * This method is called by when a stream has data available for sending. + * + * @param streamId the stream id of the stream which is ready + * @see QuicConnectionImpl#streamDataAvailableForSending + */ + public void enqueueForSending(long streamId) { + var stream = streams.get(streamId); + if (stream == null) { + if (debug.on()) + debug.log("WARNING: stream %d not found", streamId); + return; + } + if (stream instanceof QuicSenderStream sender) { + // No need to check/assert the presence of this sender in the queue. + // In fact there is no guarantee that the sender isn't already in the + // queue, since the scheduler loop can also put it back into the queue, + // if for example, not everything that the sender wanted to send could + // fit in the quic packet. + sendersReady.add(sender); + } else { + String msg = String.format("Stream %s not a sending or bidi stream: %s", + streamId, stream.getClass().getName()); + if (debug.on()) { + debug.log("WARNING: " + msg); + } + throw new AssertionError(msg); + } + } + + /** + * If there are any streams in this connection that have been blocked from sending + * data due to flow control limit on that stream, then this method enqueues a + * {@code STREAM_DATA_BLOCKED} frame to be sent for each such stream. + */ + public final void enqueueStreamDataBlocked() { + connection.streamDataAvailableForSending(this.flowControlBlockedStreams); + } + + /** + * {@return the sender part implementation of the given stream, or {@code null}} + * This method returns null if the given stream doesn't have a sending part + * (that is, if it is a unidirectional peer initiated stream). + * @param stream a sending or bidirectional stream + */ + QuicSenderStreamImpl senderImpl(QuicStream stream) { + if (stream instanceof QuicSenderStreamImpl sender) { + return sender; + } else if (stream instanceof QuicBidiStreamImpl bidi) { + return bidi.senderPart(); + } + return null; + } + + /** + * {@return the receiver part implementation of the given stream, or {@code null}} + * This method returns null if the given stream doesn't have a receiver part + * (that is, if it is a unidirectional local initiated stream). + * @param stream a receiving or bidirectional stream + */ + QuicReceiverStreamImpl receiverImpl(QuicStream stream) { + if (stream instanceof QuicReceiverStreamImpl receiver) { + return receiver; + } else if (stream instanceof QuicBidiStreamImpl bidi) { + return bidi.receiverPart(); + } + return null; + } + + /** + * Called when a StreamFrame is received. + * @param stream the stream for which the StreamFrame was received + * @param frame the stream frame + * @throws QuicTransportException if an error occurred processing the frame + */ + public void processIncomingFrame(QuicStream stream, StreamFrame frame) throws QuicTransportException { + var receiver = receiverImpl(stream); + assert receiver != null; + receiver.processIncomingFrame(frame); + } + + /** + * Called when a ResetStreamFrame is received. + * @param stream the stream for which the ResetStreamFrame was received + * @param frame the reset stream frame + * @throws QuicTransportException if an error occurred processing the frame + */ + public void processIncomingFrame(QuicStream stream, ResetStreamFrame frame) throws QuicTransportException { + var receiver = receiverImpl(stream); + assert receiver != null; + receiver.processIncomingResetFrame(frame); + } + + public void processIncomingFrame(final QuicStream stream, final StreamDataBlockedFrame frame) { + assert stream.streamId() == frame.streamId() : "unexpected stream id " + frame.streamId() + + " in frame, expected " + stream.streamId(); + final QuicReceiverStreamImpl rcvrStream = receiverImpl(stream); + assert rcvrStream != null : "missing receiver stream for stream " + stream.streamId(); + rcvrStream.processIncomingFrame(frame); + } + + public boolean tryIncreaseStreamLimit(final MaxStreamsFrame maxStreamsFrame) { + final StreamCreationPermit permit = maxStreamsFrame.isBidi() + ? localBidiMaxStreamLimit : localUniMaxStreamLimit; + final long newLimit = maxStreamsFrame.maxStreams(); + if (debug.on()) { + if (maxStreamsFrame.isBidi()) { + debug.log("increasing localBidiMaxStreamLimit limit to: " + newLimit); + } else { + debug.log("increasing localUniMaxStreamLimit limit to: " + newLimit); + } + } + return permit.tryIncreaseLimitTo(newLimit); + } + + /** + * Checks whether any stream needs to have a STOP_SENDING, RESET_STREAM or any connection + * control frames like STREAMS_BLOCKED, MAX_STREAMS sent and adds the frame to the list + * if there's room. + * @param frames list of frames + * @param remaining maximum number of bytes that can be added by this method + * @return number of bytes actually added + */ + private long checkResetAndOtherControls(List frames, long remaining) { + if (debug.on()) + debug.log("checking reset and other control frames..."); + long added = 0; + // check STREAMS_BLOCKED, only send it if the local endpoint is blocked on a limit + // for which we haven't yet sent a STREAMS_BLOCKED + final long uniStreamsBlockedLimit = this.uniStreamsBlocked.get(); + final long lastUniStreamsBlockedSent = this.lastUniStreamsBlockedSent.get(); + if (uniStreamsBlockedLimit != -1 && uniStreamsBlockedLimit > lastUniStreamsBlockedSent) { + final StreamsBlockedFrame frame = new StreamsBlockedFrame(false, uniStreamsBlockedLimit); + final int size = frame.size(); + if (size > remaining - added) { + if (debug.on()) { + debug.log("Not enough space to add a STREAMS_BLOCKED frame for uni streams"); + } + } else { + frames.add(frame); + added += size; + // now that we are sending a STREAMS_BLOCKED frame, keep track of the limit + // that we sent it with + this.lastUniStreamsBlockedSent.set(frame.maxStreams()); + } + } + final long bidiStreamsBlockedLimit = this.bidiStreamsBlocked.get(); + final long lastBidiStreamsBlockedSent = this.lastBidiStreamsBlockedSent.get(); + if (bidiStreamsBlockedLimit != -1 && bidiStreamsBlockedLimit > lastBidiStreamsBlockedSent) { + final StreamsBlockedFrame frame = new StreamsBlockedFrame(true, bidiStreamsBlockedLimit); + final int size = frame.size(); + if (size > remaining - added) { + if (debug.on()) { + debug.log("Not enough space to add a STREAMS_BLOCKED frame for bidi streams"); + } + } else { + frames.add(frame); + added += size; + // now that we are sending a STREAMS_BLOCKED frame, keep track of the limit + // that we sent it with + this.lastBidiStreamsBlockedSent.set(frame.maxStreams()); + } + } + // check STOP_SENDING and MAX_STREAM_DATA + var rcvIterator = receiversSend.entrySet().iterator(); + while (rcvIterator.hasNext()) { + var entry = rcvIterator.next(); + var frame = entry.getValue(); + if (frame.size() > remaining - added) { + if (debug.on()) { + debug.log("Stream %s: not enough space for %s", + entry.getKey().streamId(), frame); + } + break; + } + var receiver = receiverImpl(entry.getKey()); + var size = checkSendControlFrame(receiver, frame, frames); + if (size > 0) { + added += size; + } + rcvIterator.remove(); + } + + // check RESET_STREAM + var sndIterator = sendersReset.entrySet().iterator(); + while (sndIterator.hasNext()) { + Map.Entry entry = sndIterator.next(); + var sender = senderImpl(entry.getKey()); + assert sender != null; + long finalSize = sender.dataSent(); + ResetStreamFrame frame = new ResetStreamFrame(sender.streamId(), entry.getValue(), finalSize); + final int size = frame.size(); + if (size > remaining - added) { + if (debug.on()) { + debug.log("Stream %s: not enough space for ResetFrame", + sender.streamId()); + } + break; + } + if (debug.on()) + debug.log("Stream %s: Adding ResetFrame", sender.streamId()); + frames.add(frame); + added += size; + sender.resetSent(); + sndIterator.remove(); + } + + if (remaining - added > 18) { + // add MAX_STREAMS if necessary + added += addMaxStreamsFrame(frames, false); + added += addMaxStreamsFrame(frames, true); + } + return added; + } + + private boolean shouldSendMaxStreams(final boolean bidi) { + final boolean rcvdStreamsBlocked = bidi + ? this.peerBidiStreamsBlocked.get() != -1 + : this.peerUniStreamsBlocked.get() != -1; + // if we either received a STREAMS_BLOCKED from the peer for that stream type + // or if our internal algorithm decides that the peer is about to reach the stream + // creation limit + return rcvdStreamsBlocked || nextMaxStreamsLimit(bidi) > 0; + } + + private long addMaxStreamsFrame(final List frames, final boolean bidi) { + final long newMaxStreamsLimit = connection.nextMaxStreamsLimit(bidi); + if (newMaxStreamsLimit == 0) { + return 0; + } + final boolean limitIncreased; + if (bidi) { + limitIncreased = tryIncreaseLimitTo(remoteBidiMaxStreamLimit, newMaxStreamsLimit); + } else { + limitIncreased = tryIncreaseLimitTo(remoteUniMaxStreamLimit, newMaxStreamsLimit); + } + if (!limitIncreased) { + return 0; + } + final MaxStreamsFrame frame = new MaxStreamsFrame(bidi, newMaxStreamsLimit); + frames.add(frame); + // now that we are sending MAX_STREAMS frame to the peer, reset the relevant + // STREAMS_BLOCKED flag that we might have set when/if we had received a STREAMS_BLOCKED + // from the peer + if (bidi) { + this.peerBidiStreamsBlocked.set(-1); + } else { + this.peerUniStreamsBlocked.set(-1); + } + if (debug.on()) { + debug.log("Increasing max remote %s streams to %s", + bidi ? "bidi" : "uni", newMaxStreamsLimit); + } + return frame.size(); + } + + public long nextMaxStreamsLimit(final boolean bidi) { + return bidi ? streams.remoteBidiNextMaxStreams : streams.remoteUniNextMaxStreams; + } + + /** + * {@return true if there are any streams on this connection which are blocked from + * sending data due to flow control limit, false otherwise} + */ + public final boolean hasBlockedStreams() { + return !this.flowControlBlockedStreams.isEmpty(); + } + + /** + * Checks whether the given stream is recorded as needing a control + * frame to be sent, and if so, add that frame to the list + * + * @param receiver the receiver part of the stream + * @param frame the frame to send + * @param frames list of frames + * @return size of the added frame, or zero if no frame was added + * @apiNote Typically, the control frame that is sent is either a MAX_STREAM_DATA + * or a STOP_SENDING frame + */ + private long checkSendControlFrame(QuicReceiverStreamImpl receiver, + QuicFrame frame, + List frames) { + if (frame == null) { + if (debug.on()) + debug.log("Stream %s: no receiver frame to send", receiver.streamId()); + return 0; + } + if (frame instanceof MaxStreamDataFrame maxStreamDataFrame) { + if (receiver.receivingState() == ReceivingStreamState.RECV) { + // if we know the final size, no point in increasing max data + if (debug.on()) + debug.log("Stream %s: Adding MaxStreamDataFrame", receiver.streamId()); + frames.add(frame); + receiver.updateMaxStreamData(maxStreamDataFrame.maxStreamData()); + return frame.size(); + } + return 0; + } else if (frame instanceof StopSendingFrame) { + if (debug.on()) + debug.log("Stream %s: Adding StopSendingFrame", receiver.streamId()); + frames.add(frame); + return frame.size(); + } else { + throw new InternalError("Should not reach here - not a control frame: " + frame); + } + } + + /** + * Package available data in {@link StreamFrame} instances and add them + * to the provided frames list. Additional frames, like connection control frames + * {@code STREAMS_BLOCKED}, {@code MAX_STREAMS} or stream flow control frames like + * {@code STREAM_DATA_BLOCKED} may also be added if space allows. The {@link StreamDataBlockedFrame} + * is added only once for a given stream, until the stream becomes ready again. + * @implSpec + * The total cumulated size of the returned frames must not exceed {@code maxSize}. + * The total cumulated lengths of the returned frames must not exceed {@code maxConnectionData}. + * + * @param encoder the {@link QuicPacketEncoder}, used if anything is quic version + * dependent. + * @param maxSize the cumulated maximum size of all the frames + * @param maxConnectionData the maximum number of stream data bytes that can + * be packaged to respect connection flow control + * constraints + * @param frames a list of frames in which to add the packaged data + * @return the total number of stream data bytes packaged in the created + * frames. This will not exceed the given {@code maxConnectionData}. + */ + public long produceFramesToSend(QuicPacketEncoder encoder, long maxSize, + long maxConnectionData, List frames) + throws QuicTransportException { + long remaining = maxSize; + long produced = 0; + try { + remaining -= checkResetAndOtherControls(frames, remaining); + // scan the streams and compose a list of frames - possibly including + // stream data blocked frames, + QuicSenderStreamImpl sender; + NEXT_STREAM: while ((sender = senderImpl(sendersReady.poll())) != null) { + long streamId = sender.streamId(); + boolean stillReady = true; + try { + do { + if (remaining == 0 || maxConnectionData == 0) break; + var state = sender.sendingState(); + switch (state) { + case SEND -> { + long offset = sender.dataSent(); + int headerSize = StreamFrame.headerSize(encoder, streamId, offset, remaining); + if (headerSize >= remaining) { + break NEXT_STREAM; + } + long maxControlled = Math.min(maxConnectionData, remaining - headerSize); + int maxData = (int) Math.min(Integer.MAX_VALUE, maxControlled); + if (maxData <= 0) { + break NEXT_STREAM; + } + ByteBuffer buffer = sender.poll(maxData); + if (buffer != null) { + int length = buffer.remaining(); + assert length <= remaining; + assert length <= maxData; + long streamSize = sender.streamSize(); + boolean fin = streamSize >= 0 && streamSize == offset + length; + if (fin) { + stillReady = false; + } + if (length > 0 || fin) { + StreamFrame frame = new StreamFrame(streamId, offset, length, fin, buffer); + int size = frame.size(); + assert size <= remaining : "stream:%s: size %s > remaining %s" + .formatted(streamId, size, remaining); + if (debug.on()) { + debug.log("stream:%s Adding StreamFrame: %s", + streamId, frame); + } + frames.add(frame); + remaining -= size; + produced += length; + maxConnectionData -= length; + } + } + var blocked = sender.isBlocked(); + if (blocked) { + // track this stream as blocked due to flow control + trackBlockedStream(streamId); + final var dataBlocked = new StreamDataBlockedFrame(streamId, sender.dataSent()); + // This might produce multiple StreamDataBlocked frames + // if the stream was added to sendersReady multiple times, so + // we check before actually sending a STREAM_DATA_BLOCKED frame + if (!frames.contains(dataBlocked)) { + var fdbSize = dataBlocked.size(); + if (dataBlocked.size() > remaining) { + // keep the stream in the ready list if we haven't been + // able to generate the StreamDataBlockedFrame + break NEXT_STREAM; + } + if (debug.on()) { + debug.log("stream:" + streamId + " sender is blocked: " + dataBlocked); + } + frames.add(dataBlocked); + remaining -= fdbSize; + } + stillReady = false; + continue NEXT_STREAM; + } + if (buffer == null) { + stillReady = sender.available() != 0; + continue NEXT_STREAM; + } + } + case DATA_SENT, DATA_RECVD, RESET_SENT, RESET_RECVD -> { + stillReady = false; + continue NEXT_STREAM; + } + case READY -> { + String msg = "stream:%s: illegal state %s".formatted(streamId, state); + throw new IllegalStateException(msg); + } + } + if (debug.on()) { + debug.log("packageStreamData: stream:%s, remaining:%s, " + + "maxConnectionData: %s, produced:%s", + streamId, remaining, maxConnectionData, produced); + } + } while (remaining > 0 && maxConnectionData > 0); + } catch (RuntimeException | AssertionError x) { + stillReady = false; + throw new QuicTransportException("Failed to compose frames for stream " + streamId, + KeySpace.ONE_RTT, 0, QuicTransportErrors.INTERNAL_ERROR.code(), x); + } finally { + if (stillReady) { + if (debug.on()) + debug.log("stream:%s is still ready", streamId); + enqueueForSending(streamId); + } else { + if (debug.on()) + debug.log("stream:%s is no longer ready", streamId); + } + } + assert maxConnectionData >= 0 : "produced " + produced + " max is " + maxConnectionData; + if (remaining == 0 || maxConnectionData == 0) break; + } + } catch (RuntimeException | AssertionError x) { + if (debug.on()) debug.log("Failed to compose frames", x); + if (Log.errors()) { + Log.logError(connection.logTag() + + ": Failed to compose frames", x); + } + throw new QuicTransportException("Failed to compose frames", + KeySpace.ONE_RTT, 0, QuicTransportErrors.INTERNAL_ERROR.code(), x); + } + return produced; + } + + private interface ReadyStreamCollection { + boolean isEmpty(); + + void add(QuicSenderStream sender); + + QuicStream poll(); + } + //This queue is used to ensure fair sending of stream data: the packageStreamData method + // will pop and push streams from/to this queue in a round-robin fashion so that one stream + // doesn't starve all the others. + private static class ReadyStreamQueue implements ReadyStreamCollection { + private ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); + + public boolean isEmpty() { + return queue.isEmpty(); + } + + public void add(QuicSenderStream sender) { + queue.add(sender); + } + + public QuicStream poll() { + return queue.poll(); + } + } + // This queue is used to ensure fast closing of streams: it always returns + // the ready stream with the lowest ID. + private static class ReadyStreamSortedQueue implements ReadyStreamCollection { + private ConcurrentSkipListMap queue = new ConcurrentSkipListMap<>(); + + public boolean isEmpty() { + return queue.isEmpty(); + } + + public void add(QuicSenderStream sender) { + queue.put(sender.streamId(), sender); + } + + public QuicStream poll() { + Map.Entry entry = queue.pollFirstEntry(); + if (entry == null) return null; + return entry.getValue(); + } + } + + // provides a limited view/operations over a ConcurrentHashMap(). we compute additional + // state in the remove() and put() APIs. providing only a limited set of APIs allows us + // to keep the places where we do that additional state computation, to minimal. + private final class StreamsContainer { + // A map of + private final ConcurrentMap streams = new ConcurrentHashMap<>(); + // active remote bidi stream count + private final AtomicLong remoteBidiActiveStreams = new AtomicLong(); + // active remote uni stream count + private final AtomicLong remoteUniActiveStreams = new AtomicLong(); + + private volatile long remoteBidiNextMaxStreams; + private volatile long remoteUniNextMaxStreams; + + AbstractQuicStream get(final long streamId) { + return streams.get(streamId); + } + + boolean remove(final long streamId, final AbstractQuicStream stream) { + if (!streams.remove(streamId, stream)) { + return false; + } + final int streamType = (int) (stream.streamId() & TYPE_MASK); + if (streamType == remoteBidi) { + final long currentActive = remoteBidiActiveStreams.decrementAndGet(); + remoteBidiNextMaxStreams = computeNextMaxStreamsLimit(streamType, currentActive, + remoteBidiMaxStreamLimit.get()); + } else if (streamType == remoteUni) { + final long currentActive = remoteUniActiveStreams.decrementAndGet(); + remoteUniNextMaxStreams = computeNextMaxStreamsLimit(streamType, currentActive, + remoteUniMaxStreamLimit.get()); + } + return true; + } + + AbstractQuicStream put(final long streamId, final AbstractQuicStream stream) { + final AbstractQuicStream previous = streams.put(streamId, stream); + final int streamType = (int) (stream.streamId() & TYPE_MASK); + if (streamType == remoteBidi) { + final long currentActive = remoteBidiActiveStreams.incrementAndGet(); + remoteBidiNextMaxStreams = computeNextMaxStreamsLimit(streamType, currentActive, + remoteBidiMaxStreamLimit.get()); + } else if (streamType == remoteUni) { + final long currentActive = remoteUniActiveStreams.incrementAndGet(); + remoteUniNextMaxStreams = computeNextMaxStreamsLimit(streamType, currentActive, + remoteUniMaxStreamLimit.get()); + } + return previous; + } + + Stream all() { + return streams.values().stream(); + } + + /** + * Returns the next (higher) max streams limit that can be advertised to the remote peer. + * Returns {@code 0} if the limit should not be increased. + */ + private long computeNextMaxStreamsLimit( + final int streamType, final long currentActiveCount, + final long currentMaxStreamsLimit) { + // we only deal with remote bidi or remote uni + assert (streamType == remoteBidi || streamType == remoteUni) + : "stream type is neither remote bidi nor remote uni: " + streamType; + final long usedRemoteStreams = peekNextStreamId(streamType) >> 2; + final boolean bidi = streamType == remoteBidi; + final var desiredStreamCount = bidi ? MAX_BIDI_STREAMS_WINDOW_SIZE + : MAX_UNI_STREAMS_WINDOW_SIZE; + final long desiredMaxStreams = usedRemoteStreams - currentActiveCount + desiredStreamCount; + // we compute a new limit after we consumed 25% (arbitrary decision) of the desired window + if (desiredMaxStreams - currentMaxStreamsLimit > desiredStreamCount >> 2) { + return desiredMaxStreams; + } + return 0; + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStream.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStream.java new file mode 100644 index 00000000000..1a20cbe5211 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStream.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import jdk.internal.net.http.common.SequentialScheduler; + +/** + * An interface that represents the receiving part of a stream. + *

From RFC 9000: + * + * On the receiving part of a stream, an application protocol can: + *

    + *
  • read data; and
  • + *
  • abort reading of the stream and request closure, possibly + * resulting in a STOP_SENDING frame (Section 19.5).
  • + *
+ * + */ +public non-sealed interface QuicReceiverStream extends QuicStream { + + /** + * An enum that models the state of the receiving part of a stream. + */ + enum ReceivingStreamState implements QuicStream.StreamState { + /** + * The initial state for the receiving part of a + * stream is "Recv". + *

+ * In the "Recv" state, the endpoint receives STREAM + * and STREAM_DATA_BLOCKED frames. Incoming data is buffered + * and can be reassembled into the correct order for delivery + * to the application. As data is consumed by the application + * and buffer space becomes available, the endpoint sends + * MAX_STREAM_DATA frames to allow the peer to send more data. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + RECV, + /** + * When a STREAM frame with a FIN bit is received, the final size of + * the stream is known; see Section 4.5. The receiving part of the + * stream then enters the "Size Known" state. In this state, the + * endpoint no longer needs to send MAX_STREAM_DATA frames; it only + * receives any retransmissions of stream data. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + SIZE_KNOWN, + /** + * Once all data for the stream has been received, the receiving part + * enters the "Data Recvd" state. This might happen as a result of + * receiving the same STREAM frame that causes the transition to + * "Size Known". After all data has been received, any STREAM or + * STREAM_DATA_BLOCKED frames for the stream can be discarded. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + DATA_RECVD, + /** + * The "Data Recvd" state persists until stream data has been delivered + * to the application. Once stream data has been delivered, the stream + * enters the "Data Read" state, which is a terminal state. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + DATA_READ, + /** + * Receiving a RESET_STREAM frame in the "Recv" or "Size Known" state + * causes the stream to enter the "Reset Recvd" state. This might + * cause the delivery of stream data to the application to be + * interrupted. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + RESET_RECVD, + /** + * Once the application receives the signal indicating that the + * stream was reset, the receiving part of the stream transitions to + * the "Reset Read" state, which is a terminal state. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + RESET_READ; + + @Override + public boolean isTerminal() { + return this == DATA_READ || this == RESET_READ; + } + + /** + * {@return true if this state indicates that the stream has been reset by the sender} + */ + public boolean isReset() { return this == RESET_RECVD || this == RESET_READ; } + } + + /** + * {@return the receiving state of the stream} + */ + ReceivingStreamState receivingState(); + + /** + * Connects an {@linkplain QuicStreamReader#started() unstarted} reader + * to the receiver end of this stream. + * @param scheduler A sequential scheduler that will be invoked + * when the reader is started and new data becomes available for reading + * @return a {@code QuicStreamReader} to read data from this + * stream. + * @throws IllegalStateException if a reader is already connected. + */ + QuicStreamReader connectReader(SequentialScheduler scheduler); + + /** + * Disconnect the reader, so that a new reader can be connected. + * + * @apiNote + * This can be useful for handing the stream over after having read + * or peeked at some bytes. + * + * @param reader the reader to be disconnected + * @throws IllegalStateException if the given reader is not currently + * connected to the stream + */ + void disconnectReader(QuicStreamReader reader); + + /** + * Cancels the reading side of this stream by sending + * a STOP_SENDING frame. + * + * @param errorCode the application error code + * + */ + void requestStopSending(long errorCode); + + /** + * {@return the amount of data that has been received so far} + * @apiNote This may include data that has not been read by the + * application yet, but does not count any data that may have + * been received twice. + */ + long dataReceived(); + + /** + * {@return the maximum amount of data that can be received on + * this stream} + * + * @apiNote This corresponds to the maximum amount of data that + * the peer has been allowed to send. + */ + long maxStreamData(); + + /** + * {@return the error code for this stream, or {@code -1}} + */ + long rcvErrorCode(); + + default boolean isStopSendingRequested() { return false; } + + @Override + default boolean hasError() { + return rcvErrorCode() >= 0; + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStreamImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStreamImpl.java new file mode 100644 index 00000000000..120a9bffd5b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicReceiverStreamImpl.java @@ -0,0 +1,942 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.OrderedFlow.StreamDataFlow; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.frames.ConnectionCloseFrame; +import jdk.internal.net.http.quic.frames.ResetStreamFrame; +import jdk.internal.net.http.quic.frames.StreamDataBlockedFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static jdk.internal.net.http.quic.QuicConnectionImpl.DEFAULT_INITIAL_STREAM_MAX_DATA; +import static jdk.internal.net.http.quic.frames.QuicFrame.MAX_VL_INTEGER; +import static jdk.internal.net.http.quic.streams.QuicReceiverStream.ReceivingStreamState.*; + +/** + * A class that implements the receiver part of a quic stream. + */ +final class QuicReceiverStreamImpl extends AbstractQuicStream implements QuicReceiverStream { + + private static final int MAX_SMALL_FRAGMENTS = + Utils.getIntegerProperty("jdk.httpclient.quic.maxSmallFragments", 100); + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + private final String dbgTag; + + // The dataFlow reorders incoming stream frames and removes duplicates. + // It contains frames that cannot be delivered yet because they are not + // at the expected offset. + private final StreamDataFlow dataFlow = new StreamDataFlow(); + // The orderedQueue contains frames that can be delivered to the application now. + // They are inserted in the queue in order. + // The QuicStreamReader's scheduler loop consumes this queue. + private final ConcurrentLinkedQueue orderedQueue = new ConcurrentLinkedQueue<>(); + // Desired buffer size; used when updating maxStreamData + private final long desiredBufferSize; + // Maximum stream data + private volatile long maxStreamData; + // how much data has been processed on this stream. + // This is data that was poll'ed from orderedQueue or dropped after stream reset. + private volatile long processed; + // how much data has been delivered to orderedQueue. This doesn't take into account + // frames that may be stored in the dataFlow. + private volatile long received; + // maximum of offset+length across all received frames + private volatile long maxReceivedData; + // the size of the stream, when known. Defaults to 0 when unknown. + private volatile long knownSize; + // the connected reader + private volatile QuicStreamReaderImpl reader; + // eof when the last payload has been polled by the application + private volatile boolean eof; + // the state of the receiving stream + private volatile ReceivingStreamState receivingState; + private volatile boolean requestedStopSending; + private volatile long errorCode; + + private final static long MIN_BUFFER_SIZE = 16L << 10; + QuicReceiverStreamImpl(QuicConnectionImpl connection, long streamId) { + super(connection, validateStreamId(connection, streamId)); + errorCode = -1; + receivingState = ReceivingStreamState.RECV; + dbgTag = connection.streamDbgTag(streamId, "R"); + long bufsize = DEFAULT_INITIAL_STREAM_MAX_DATA; + desiredBufferSize = Math.clamp(bufsize, MIN_BUFFER_SIZE, MAX_VL_INTEGER); + } + + private static long validateStreamId(QuicConnectionImpl connection, long streamId) { + if (QuicStreams.isBidirectional(streamId)) return streamId; + if (connection.isClientConnection() == QuicStreams.isClientInitiated(streamId)) { + throw new IllegalArgumentException("A locally initiated stream can't be read-only"); + } + return streamId; + } + + /** + * Sends a {@link ConnectionCloseFrame} due to MAX_STREAM_DATA exceeded + * for the stream. + * @param streamFrame the stream frame that caused the excess + * @param maxData the value of MAX_STREAM_DATA which was exceeded + */ + private static QuicTransportException streamControlOverflow(StreamFrame streamFrame, long maxData) throws QuicTransportException { + String reason = "Stream max data exceeded: offset=%s, length=%s, max stream data=%s" + .formatted(streamFrame.offset(), streamFrame.dataLength(), maxData); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, streamFrame.getTypeField(), QuicTransportErrors.FLOW_CONTROL_ERROR); + } + + // debug tag for debug logger + String dbgTag() { + return dbgTag; + } + + @Override + public StreamState state() { + return receivingState(); + } + + @Override + public ReceivingStreamState receivingState() { + return receivingState; + } + + @Override + public QuicStreamReader connectReader(SequentialScheduler scheduler) { + var reader = this.reader; + if (reader == null) { + reader = new QuicStreamReaderImpl(scheduler); + if (Handles.READER.compareAndSet(this, null, reader)) { + if (debug.on()) debug.log("reader connected"); + return reader; + } + } + throw new IllegalStateException("reader already connected"); + } + + @Override + public void disconnectReader(QuicStreamReader reader) { + var previous = this.reader; + if (reader == previous) { + if (Handles.READER.compareAndSet(this, reader, null)) { + if (debug.on()) debug.log("reader disconnected"); + return; + } + } + throw new IllegalStateException("reader not connected"); + } + + @Override + public boolean isStopSendingRequested() { + return requestedStopSending; + } + + @Override + public void requestStopSending(final long errorCode) { + if (Handles.STOP_SENDING.compareAndSet(this, false, true)) { + assert requestedStopSending : "requestedStopSending should be true!"; + if (debug.on()) debug.log("requestedStopSending: true"); + var state = receivingState; + try { + setErrorCode(errorCode); + switch(state) { + case RECV, SIZE_KNOWN -> { + connection().scheduleStopSendingFrame(streamId(), errorCode); + } + // otherwise do nothing + } + } finally { + // RFC-9000, section 3.5: "If an application is no longer interested in the data it is + // receiving on a stream, it can abort reading the stream and specify an application + // error code." + // So it implies that the application isn't anymore interested in receiving the data + // that has been buffered in the stream, so we drop all buffered data on this stream + if (state != RECV && state != DATA_READ) { + // we know the final size; we can remove the stream + increaseProcessedData(knownSize); + if (switchReceivingState(RESET_READ)) { + eof = false; + } + } + dataFlow.clear(); + orderedQueue.clear(); + if (debug.on()) { + debug.log("Dropped all buffered frames on stream %d after STOP_SENDING was requested" + + " with error code 0x%s", streamId(), Long.toHexString(errorCode)); + } + } + } + } + + @Override + public long dataReceived() { + return received; + } + + @Override + public long maxStreamData() { + return maxStreamData; + } + + @Override + public boolean isDone() { + return switch (receivingState()) { + case DATA_READ, DATA_RECVD, RESET_READ, RESET_RECVD -> + // everything received from peer + true; + default -> + // the stream is only half closed + false; + }; + } + + /** + * Receives a QuicFrame from the remote peer. + * + * @param resetStreamFrame the frame received + */ + void processIncomingResetFrame(final ResetStreamFrame resetStreamFrame) + throws QuicTransportException { + try { + checkUpdateState(resetStreamFrame); + if (requestedStopSending) { + increaseProcessedData(knownSize); + switchReceivingState(RESET_READ); + } + } finally { + // make sure the state is switched to reset received. + // even if we're closing the connection + switchReceivingState(RESET_RECVD); + // wakeup reader, then throw exception. + QuicStreamReaderImpl reader = this.reader; + if (reader != null) reader.wakeup(); + } + } + + void processIncomingFrame(final StreamDataBlockedFrame streamDataBlocked) { + assert streamDataBlocked.streamId() == streamId() : "unexpected stream id"; + final long peerBlockedOn = streamDataBlocked.maxStreamData(); + final long currentLimit = this.maxStreamData; + if (peerBlockedOn > currentLimit) { + // shouldn't have happened. ignore and don't increase the limit. + return; + } + // the peer has stated that the stream is blocked due to flow control limit that we have + // imposed and has requested for increasing the limit. we approve that request + // and increase the limit only if the amount of received data that we have received and + // processed on this stream is more than 1/4 of the credit window. + if (!requestedStopSending + && currentLimit - processed < (desiredBufferSize - desiredBufferSize / 4)) { + demand(desiredBufferSize); + } else { + if (debug.on()) { + debug.log("ignoring STREAM_DATA_BLOCKED frame %s," + + " since current limit %d is large enough", streamDataBlocked, currentLimit); + } + } + } + + private void demand(final long additional) { + assert additional > 0 && additional < MAX_VL_INTEGER : "invalid demand: " + additional; + var received = dataReceived(); + var maxStreamData = maxStreamData(); + + final long newMax = Math.clamp(received + additional, maxStreamData, MAX_VL_INTEGER); + if (newMax > maxStreamData) { + connection().requestSendMaxStreamData(streamId(), newMax); + updateMaxStreamData(newMax); + } + } + + /** + * Called when the connection is closed + * @param terminationCause the termination cause + */ + void terminate(final TerminationCause terminationCause) { + setErrorCode(terminationCause.getCloseCode()); + final QuicStreamReaderImpl reader = this.reader; + if (reader != null) { + reader.wakeup(); + } + } + + @Override + public long rcvErrorCode() { + return errorCode; + } + + /** + * Receives a QuicFrame from the remote peer. + * + * @param streamFrame the frame received + */ + public void processIncomingFrame(final StreamFrame streamFrame) + throws QuicTransportException { + // RFC-9000, section 3.5: "STREAM frames received after sending a STOP_SENDING frame + // are still counted toward connection and stream flow control, even though these + // frames can be discarded upon receipt." + // so we do the necessary data size checks before checking if we sent a "STOP_SENDING" + // frame + checkUpdateState(streamFrame); + final ReceivingStreamState state = receivingState; + if (debug.on()) debug.log("receivingState: " + state); + long knownSize = this.knownSize; + // RESET was read or received: drop the frame. + if (state == RESET_READ || state == RESET_RECVD) { + if (debug.on()) { + debug.log("Dropping frame on stream %d since state is %s", + streamId(), state); + } + return; + } + if (requestedStopSending) { + // drop the frame + if (debug.on()) { + debug.log("Dropping frame that was received after a STOP_SENDING" + + " frame was sent on stream %d", streamId()); + } + increaseProcessedData(maxReceivedData); + if (state != RECV) { + // we know the final size; we can remove the stream + switchReceivingState(RESET_READ); + } + return; + } + + var readyFrame = dataFlow.receive(streamFrame); + var received = this.received; + boolean needWakeup = false; + while (readyFrame != null) { + // check again - this avoids a race condition where a frame + // would be considered ready if requestStopSending had been + // called concurrently, and `receive` was called after the + // state had been switched + if (requestedStopSending) { + return; + } + assert received == readyFrame.offset() + : "data received (%s) doesn't match offset (%s)" + .formatted(received, readyFrame.offset()); + this.received = received = received + readyFrame.dataLength(); + offer(readyFrame); + needWakeup = true; + readyFrame = dataFlow.poll(); + } + if (state == SIZE_KNOWN && received == knownSize) { + if (switchReceivingState(DATA_RECVD)) { + offerEof(); + needWakeup = true; + } + } + if (needWakeup) { + var reader = this.reader; + if (reader != null) reader.wakeup(); + } else { + int numFrames = dataFlow.size(); + long numBytes = dataFlow.buffered(); + if (numFrames > MAX_SMALL_FRAGMENTS && numBytes / numFrames < 400) { + // The peer sent a large number of small fragments + // that follow a gap and can't be immediately released to the reader; + // we need to buffer them, and the memory overhead is unreasonably high. + throw new QuicTransportException("Excessive stream fragmentation", + QuicTLSEngine.KeySpace.ONE_RTT, streamFrame.frameType(), + QuicTransportErrors.INTERNAL_ERROR); + } + } + } + + /** + * Checks for error conditions: + * - max stream data errors + * - max data errors + * - final size errors + * If everything checks OK, updates counters and returns, otherwise throws. + * + * @implNote + * This method may update counters before throwing. This is OK + * because we do not expect to use them again in this case. + * @param streamFrame received stream frame + * @throws QuicTransportException if frame is invalid + */ + private void checkUpdateState(StreamFrame streamFrame) throws QuicTransportException { + long offset = streamFrame.offset(); + long length = streamFrame.dataLength(); + assert offset >= 0; + assert length >= 0; + + // check maxStreamData + long maxData = maxStreamData; + assert maxData >= 0; + long size; + try { + size = Math.addExact(offset, length); + } catch (ArithmeticException x) { + // should not happen + if (debug.on()) { + debug.log("offset + length exceeds max value", x); + } + throw streamControlOverflow(streamFrame, Long.MAX_VALUE); + } + if (size > maxData) { + throw streamControlOverflow(streamFrame, maxData); + } + ReceivingStreamState state = receivingState; + // check finalSize if known + long knownSize = this.knownSize; + assert knownSize >= 0; + if (state != RECV && size > knownSize) { + String reason = "Stream final size exceeded: offset=%s, length=%s, final size=%s" + .formatted(streamFrame.offset(), streamFrame.dataLength(), knownSize); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, streamFrame.getTypeField(), QuicTransportErrors.FINAL_SIZE_ERROR); + } + // check maxData + updateMaxReceivedData(size, streamFrame.getTypeField()); + if (streamFrame.isLast()) { + // check max received data, throw if we have data beyond the (new) EOF + if (size < maxReceivedData) { + String reason = "Stream truncated: offset=%s, length=%s, max received=%s" + .formatted(streamFrame.offset(), streamFrame.dataLength(), maxReceivedData); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, streamFrame.getTypeField(), QuicTransportErrors.FINAL_SIZE_ERROR); + } + if (state == RECV && switchReceivingState(SIZE_KNOWN)) { + this.knownSize = size; + } else { + if (size != knownSize) { + String reason = "Stream final size changed: offset=%s, length=%s, final size=%s" + .formatted(streamFrame.offset(), streamFrame.dataLength(), knownSize); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, streamFrame.getTypeField(), QuicTransportErrors.FINAL_SIZE_ERROR); + } + } + } + } + + /** + * Checks for error conditions: + * - max stream data errors + * - max data errors + * - final size errors + * If everything checks OK, updates counters and returns, otherwise throws. + * + * @implNote + * This method may update counters before throwing. This is OK + * because we do not expect to use them again in this case. + * @param resetStreamFrame received reset stream frame + * @throws QuicTransportException if frame is invalid + */ + private void checkUpdateState(ResetStreamFrame resetStreamFrame) throws QuicTransportException { + // check maxStreamData + long maxData = maxStreamData; + assert maxData >= 0; + long size = resetStreamFrame.finalSize(); + long errorCode = resetStreamFrame.errorCode(); + setErrorCode(errorCode); + if (size > maxData) { + String reason = "Stream max data exceeded: finalSize=%s, max stream data=%s" + .formatted(size, maxData); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, resetStreamFrame.getTypeField(), QuicTransportErrors.FLOW_CONTROL_ERROR); + } + ReceivingStreamState state = receivingState; + updateMaxReceivedData(size, resetStreamFrame.getTypeField()); + // check max received data, throw if we have data beyond the (new) EOF + if (size < maxReceivedData) { + String reason = "Stream truncated: finalSize=%s, max received=%s" + .formatted(size, maxReceivedData); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, resetStreamFrame.getTypeField(), QuicTransportErrors.FINAL_SIZE_ERROR); + } + if (state == RECV && switchReceivingState(RESET_RECVD)) { + this.knownSize = size; + } else { + if (state == SIZE_KNOWN) { + switchReceivingState(RESET_RECVD); + } + if (size != knownSize) { + String reason = "Stream final size changed: new finalSize=%s, old final size=%s" + .formatted(size, knownSize); + throw new QuicTransportException(reason, + QuicTLSEngine.KeySpace.ONE_RTT, resetStreamFrame.getTypeField(), QuicTransportErrors.FINAL_SIZE_ERROR); + } + } + } + + void checkOpened() throws IOException { + final TerminationCause terminationCause = connection().terminationCause(); + if (terminationCause == null) { + return; + } + throw terminationCause.getCloseCause(); + } + + private void offer(StreamFrame frame) { + var payload = frame.payload(); + if (payload.hasRemaining()) { + orderedQueue.add(payload.slice()); + } + } + + private void offerEof() { + orderedQueue.add(QuicStreamReader.EOF); + } + + /** + * Update the value of MAX_STREAM_DATA for this stream + * @param newMaxStreamData + */ + public void updateMaxStreamData(long newMaxStreamData) { + long maxStreamData = this.maxStreamData; + boolean updated = false; + while (maxStreamData < newMaxStreamData) { + if (updated = Handles.MAX_STREAM_DATA.compareAndSet(this, maxStreamData, newMaxStreamData)) break; + maxStreamData = this.maxStreamData; + } + if (updated) { + if (debug.on()) { + debug.log("updateMaxStreamData: max stream data updated from %s to %s", + maxStreamData, newMaxStreamData); + } + } + } + + /** + * Update the {@code maxReceivedData} value, and return the amount + * by which {@code maxReceivedData} was increased. This method is a + * no-op and returns 0 if {@code maxReceivedData >= newMax}. + * + * @param newMax the new max offset - typically obtained + * by adding the length of a frame to its + * offset + * @param frameType type of frame received + * @throws QuicTransportException if flow control was violated + */ + private void updateMaxReceivedData(long newMax, long frameType) throws QuicTransportException { + assert newMax >= 0; + var max = this.maxReceivedData; + while (max < newMax) { + if (Handles.MAX_RECEIVED_DATA.compareAndSet(this, max, newMax)) { + // report accepted data to connection flow control, + // and update the amount of data received in the + // connection. This will also check whether connection + // flow control is exceeded, and throw in + // this case + connection().increaseReceivedData(newMax - max, frameType); + return; + } + max = this.maxReceivedData; + } + } + + /** + * Notifies the connection about received data that is no longer buffered. + */ + private void increaseProcessedDataBy(int diff) { + assert diff >= 0; + if (diff <= 0) return; + synchronized (this) { + if (requestedStopSending) { + // once we request stop sending, updates are handled by increaseProcessedData + return; + } + assert processed + diff <= received : processed+"+"+diff+">"+received+"("+maxReceivedData+")"; + processed += diff; + } + connection().increaseProcessedData(diff); + } + + /** + * Notifies the connection about received data that is no longer buffered. + */ + private void increaseProcessedData(long newProcessed) { + long diff; + synchronized (this) { + if (newProcessed > processed) { + diff = newProcessed - processed; + processed = newProcessed; + } else { + diff = 0; + } + } + if (diff > 0) { + connection().increaseProcessedData(diff); + } + } + + // private implementation of a QuicStreamReader for this stream + private final class QuicStreamReaderImpl extends QuicStreamReader { + + static final int STARTED = 1; + static final int PENDING = 2; + // should not need volatile here, as long as we + // switch to using synchronize whenever state & STARTED == 0 + // Once state & STARTED != 0 the state should no longer change + private int state; + + QuicStreamReaderImpl(SequentialScheduler scheduler) { + super(scheduler); + } + + @Override + public ReceivingStreamState receivingState() { + checkConnected(); + return QuicReceiverStreamImpl.this.receivingState(); + } + + @Override + public ByteBuffer poll() throws IOException { + checkConnected(); + var buffer = orderedQueue.poll(); + if (buffer == null) { + if (eof) return EOF; + var state = receivingState; + if (state == RESET_RECVD) { + increaseProcessedData(knownSize); + } + checkReset(); + // unfulfilled = maxStreamData - received; + // if we have received more than 1/4 of the buffer, update maxStreamData + if (!requestedStopSending && unfulfilled() < desiredBufferSize - desiredBufferSize / 4) { + demand(desiredBufferSize); + } + return null; + } + + if (requestedStopSending) { + // check reset again + checkReset(); + return null; + } + increaseProcessedDataBy(buffer.capacity()); + if (buffer == EOF) { + eof = true; + assert processed == received : processed + "!=" + received; + switchReceivingState(DATA_READ); + return EOF; + } + // if the amount of received data that has been processed on this stream is + // more than 1/4 of the credit window then send a MaxStreamData frame. + if (!requestedStopSending && maxStreamData - processed < desiredBufferSize - desiredBufferSize / 4) { + demand(desiredBufferSize); + } + return buffer; + } + + /** + * Checks whether the stream was reset and throws an exception if + * yes. + * + * @throws IOException if the stream is reset + */ + private void checkReset() throws IOException { + var state = receivingState; + if (state == RESET_READ || state == RESET_RECVD) { + if (state == RESET_RECVD) { + switchReceivingState(RESET_READ); + } + if (requestedStopSending) { + throw new IOException("Stream %s closed".formatted(streamId())); + } else { + throw new IOException("Stream %s reset by peer".formatted(streamId())); + } + } + checkOpened(); + } + + @Override + public ByteBuffer peek() throws IOException { + checkConnected(); + var buffer = orderedQueue.peek(); + if (buffer == null) { + checkReset(); + return eof ? EOF : null; + } + return buffer; + } + + private long unfulfilled() { + // TODO: should we synchronize to ensure consistency? + var max = maxStreamData; + var rcved = received; + return max - rcved; + } + + @Override + public QuicReceiverStream stream() { + var stream = QuicReceiverStreamImpl.this; + var reader = stream.reader; + return reader == this ? stream : null; + } + + @Override + public boolean connected() { + var reader = QuicReceiverStreamImpl.this.reader; + return reader == this; + } + + @Override + public boolean started() { + int state = this.state; + if ((state & STARTED) == STARTED) return true; + synchronized (this) { + state = this.state; + return (state & STARTED) == STARTED; + } + } + + private boolean wakeupOnStart(int state) { + assert Thread.holdsLock(this); + return (state & PENDING) != 0 + || !orderedQueue.isEmpty() + || receivingState != RECV; + } + + @Override + public void start() { + // Run the scheduler if woken up before starting + int state = this.state; + if ((state & STARTED) == 0) { + boolean wakeup = false; + synchronized (this) { + state = this.state; + if ((state & STARTED) == 0) { + wakeup = wakeupOnStart(state); + state = this.state = STARTED; + } + } + assert started(); + if (debug.on()) { + debug.log("reader started (wakeup: %s)", wakeup); + } + if (wakeup || !orderedQueue.isEmpty() || receivingState != RECV) wakeup(); + } + assert started(); + } + + private void checkConnected() { + if (!connected()) throw new IllegalStateException("reader not connected"); + } + + void wakeup() { + // Only run the scheduler after the reader is started. + int state = this.state; + boolean notstarted, pending = false; + if (notstarted = ((state & STARTED) == 0)) { + synchronized (this) { + state = this.state; + if (notstarted = ((state & STARTED) == 0)) { + state = this.state = (state | PENDING); + pending = (state & PENDING) == PENDING; + assert !started(); + } + } + } + if (notstarted) { + if (debug.on()) { + debug.log("reader not started (pending: %s)", pending); + } + return; + } + assert started(); + scheduler.runOrSchedule(connection().quicInstance().executor()); + } + } + + /** + * Called when a state change is needed + * @param newState the new state. + */ + private boolean switchReceivingState(ReceivingStreamState newState) { + ReceivingStreamState oldState = receivingState; + if (debug.on()) { + debug.log("switchReceivingState %s -> %s", + oldState, newState); + } + boolean switched = switch(newState) { + case SIZE_KNOWN -> markSizeKnown(); + case DATA_RECVD -> markDataRecvd(); + case RESET_RECVD -> markResetRecvd(); + case RESET_READ -> markResetRead(); + case DATA_READ -> markDataRead(); + default -> throw new UnsupportedOperationException("switch state to " + newState); + }; + if (debug.on()) { + if (switched) { + debug.log("switched receiving state from %s to %s", oldState, newState); + } else { + debug.log("receiving state not switched; state is %s", receivingState); + } + } + + if (switched && newState.isTerminal()) { + notifyTerminalState(newState); + } + + return switched; + } + + private void notifyTerminalState(ReceivingStreamState state) { + assert state == DATA_READ || state == RESET_READ : state; + connection().notifyTerminalState(streamId(), state); + } + + // DATA_RECV is reached when the last frame is received, + // and there's no gap + private boolean markDataRecvd() { + boolean done, switched = false; + ReceivingStreamState oldState; + do { + oldState = receivingState; + done = switch (oldState) { + // CAS: Compare And Set + case RECV, SIZE_KNOWN -> switched = + Handles.RECEIVING_STATE.compareAndSet(this, + oldState, DATA_RECVD); + case DATA_RECVD, DATA_READ, RESET_RECVD, RESET_READ -> true; + }; + } while (!done); + return switched; + } + + // SIZE_KNOWN is reached when a stream frame with the FIN bit is received + private boolean markSizeKnown() { + boolean done, switched = false; + ReceivingStreamState oldState; + do { + oldState = receivingState; + done = switch (oldState) { + // CAS: Compare And Set + case RECV -> switched = + Handles.RECEIVING_STATE.compareAndSet(this, + oldState, SIZE_KNOWN); + case DATA_RECVD, DATA_READ, SIZE_KNOWN, RESET_RECVD, RESET_READ -> true; + }; + } while(!done); + return switched; + } + + // RESET_RECV is reached when a RESET_STREAM frame is received + private boolean markResetRecvd() { + boolean done, switched = false; + ReceivingStreamState oldState; + do { + oldState = receivingState; + done = switch (oldState) { + // CAS: Compare And Set + case RECV, SIZE_KNOWN -> switched = + Handles.RECEIVING_STATE.compareAndSet(this, + oldState, RESET_RECVD); + case DATA_RECVD, DATA_READ, RESET_RECVD, RESET_READ -> true; + }; + } while(!done); + return switched; + } + + // Called when the consumer has polled the last data + // DATA_READ is a terminal state + private boolean markDataRead() { + boolean done, switched = false; + ReceivingStreamState oldState; + do { + oldState = receivingState; + done = switch (oldState) { + // CAS: Compare And Set + case SIZE_KNOWN, DATA_RECVD, RESET_RECVD -> switched = + Handles.RECEIVING_STATE.compareAndSet(this, + oldState, DATA_READ); + case RESET_READ, DATA_READ -> true; + default -> throw new IllegalStateException("%s: %s -> %s" + .formatted(streamId(), oldState, DATA_READ)); + }; + } while(!done); + return switched; + } + + // Called when the consumer has read the reset + // RESET_READ is a terminal state + private boolean markResetRead() { + boolean done, switched = false; + ReceivingStreamState oldState; + do { + oldState = receivingState; + done = switch (oldState) { + // CAS: Compare And Set + case SIZE_KNOWN, DATA_RECVD, RESET_RECVD -> switched = + Handles.RECEIVING_STATE.compareAndSet(this, + oldState, RESET_READ); + case RESET_READ, DATA_READ -> true; + default -> throw new IllegalStateException("%s: %s -> %s" + .formatted(streamId(), oldState, RESET_READ)); + }; + } while(!done); + return switched; + } + + private void setErrorCode(long code) { + Handles.ERROR_CODE.compareAndSet(this, -1, code); + } + + private static final class Handles { + static final VarHandle READER; + static final VarHandle RECEIVING_STATE; + static final VarHandle MAX_STREAM_DATA; + static final VarHandle MAX_RECEIVED_DATA; + static final VarHandle STOP_SENDING; + static final VarHandle ERROR_CODE; + static { + try { + var lookup = MethodHandles.lookup(); + RECEIVING_STATE = lookup.findVarHandle(QuicReceiverStreamImpl.class, + "receivingState", ReceivingStreamState.class); + READER = lookup.findVarHandle(QuicReceiverStreamImpl.class, + "reader", QuicStreamReaderImpl.class); + MAX_STREAM_DATA = lookup.findVarHandle(QuicReceiverStreamImpl.class, + "maxStreamData", long.class); + MAX_RECEIVED_DATA = lookup.findVarHandle(QuicReceiverStreamImpl.class, + "maxReceivedData", long.class); + STOP_SENDING = lookup.findVarHandle(QuicReceiverStreamImpl.class, + "requestedStopSending", boolean.class); + ERROR_CODE = lookup.findVarHandle(QuicReceiverStreamImpl.class, + "errorCode", long.class); + } catch (Exception x) { + throw new ExceptionInInitializerError(x); + } + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStream.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStream.java new file mode 100644 index 00000000000..bdd5b55ee0b --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStream.java @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.io.IOException; + +import jdk.internal.net.http.common.SequentialScheduler; + +/** + * An interface that represents the sending part of a stream. + *

From RFC 9000: + * + * On the sending part of a stream, an application protocol can: + *

    + *
  • write data, understanding when stream flow control credit + * (Section 4.1) has successfully been reserved to send the + * written data;
  • + *
  • end the stream (clean termination), resulting in a STREAM frame + * (Section 19.8) with the FIN bit set; and
  • + *
  • reset the stream (abrupt termination), resulting in a RESET_STREAM + * frame (Section 19.4) if the stream was not already in a terminal + * state.
  • + *
+ * + */ +public non-sealed interface QuicSenderStream extends QuicStream { + + /** + * An enum that models the state of the sending part of a stream. + */ + enum SendingStreamState implements QuicStream.StreamState { + /** + * The "Ready" state represents a newly created stream that is able + * to accept data from the application. Stream data might be + * buffered in this state in preparation for sending. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + READY, + /** + * In the "Send" state, an endpoint transmits -- and retransmits as + * necessary -- stream data in STREAM frames. The endpoint respects + * the flow control limits set by its peer and continues to accept + * and process MAX_STREAM_DATA frames. An endpoint in the "Send" state + * generates STREAM_DATA_BLOCKED frames if it is blocked from sending + * by stream flow control limits (Section 4.1). + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + SEND, + /** + * After the application indicates that all stream data has been sent + * and a STREAM frame containing the FIN bit is sent, the sending part + * of the stream enters the "Data Sent" state. From this state, the + * endpoint only retransmits stream data as necessary. The endpoint + * does not need to check flow control limits or send STREAM_DATA_BLOCKED + * frames for a stream in this state. MAX_STREAM_DATA frames might be received + * until the peer receives the final stream offset. The endpoint can safely + * ignore any MAX_STREAM_DATA frames it receives from its peer for a + * stream in this state. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + DATA_SENT, + /** + * From any state that is one of "Ready", "Send", or "Data Sent", an + * application can signal that it wishes to abandon transmission of + * stream data. Alternatively, an endpoint might receive a STOP_SENDING + * frame from its peer. In either case, the endpoint sends a RESET_STREAM + * frame, which causes the stream to enter the "Reset Sent" state. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + RESET_SENT, + /** + * Once all stream data has been successfully acknowledged, the sending + * part of the stream enters the "Data Recvd" state, which is a + * terminal state. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + DATA_RECVD, + /** + * Once a packet containing a RESET_STREAM has been acknowledged, the + * sending part of the stream enters the "Reset Recvd" state, which + * is a terminal state. + *

+ * [RFC 9000, Section 3.1] + * (https://www.rfc-editor.org/rfc/rfc9000#name-sending-stream-states) + */ + RESET_RECVD; + + @Override + public boolean isTerminal() { + return this == DATA_RECVD || this == RESET_RECVD; + } + + /** + * {@return true if a stream in this state can be used for sending, that is, + * if this state is either {@link #READY} or {@link #SEND}}. + */ + public boolean isSending() { return this == READY || this == SEND; } + + /** + * {@return true if this state indicates that the stream has been reset by the sender} + */ + public boolean isReset() { return this == RESET_SENT || this == RESET_RECVD; } + } + + /** + * {@return the sending state of the stream} + */ + SendingStreamState sendingState(); + + /** + * Connects a writer to the sending end of this stream. + * @param scheduler A sequential scheduler that will + * push data on the returned {@linkplain + * QuicStreamWriter#QuicStreamWriter(SequentialScheduler) + * writer}. + * @return a {@code QuicStreamWriter} to write data to this + * stream. + * @throws IllegalStateException if a writer is already connected. + */ + QuicStreamWriter connectWriter(SequentialScheduler scheduler); + + /** + * Disconnect the writer, so that a new writer can be connected. + * + * @apiNote + * This can be useful for handing the stream over after having written + * some bytes. + * + * @param writer the writer to be disconnected + * @throws IllegalStateException if the given writer is not currently + * connected to the stream + */ + public void disconnectWriter(QuicStreamWriter writer); + + /** + * Abruptly closes the writing side of a stream by sending + * a RESET_STREAM frame. + * @param errorCode the application error code + */ + void reset(long errorCode) throws IOException; + + /** + * {@return the amount of data that has been sent} + * @apiNote + * This may include data that has not been acknowledged. + */ + long dataSent(); + + /** + * {@return the error code for this stream, or {@code -1}} + */ + long sndErrorCode(); + + /** + * {@return true if STOP_SENDING was received} + */ + boolean stopSendingReceived(); + + @Override + default boolean hasError() { + return sndErrorCode() >= 0; + } + + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStreamImpl.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStreamImpl.java new file mode 100644 index 00000000000..292face444c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicSenderStreamImpl.java @@ -0,0 +1,662 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodHandles.Lookup; +import java.lang.invoke.VarHandle; +import java.nio.ByteBuffer; +import java.util.Set; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.TerminationCause; + +/** + * A class that implements the sending part of a quic stream. + */ +public final class QuicSenderStreamImpl extends AbstractQuicStream implements QuicSenderStream { + private volatile SendingStreamState sendingState; + private volatile QuicStreamWriterImpl writer; + private final Logger debug = Utils.getDebugLogger(this::dbgTag); + private final String dbgTag; + private final StreamWriterQueueImpl queue = new StreamWriterQueueImpl(); + private volatile long errorCode; + private volatile boolean stopSendingReceived; + + QuicSenderStreamImpl(QuicConnectionImpl connection, long streamId) { + super(connection, validateStreamId(connection, streamId)); + errorCode = -1; + sendingState = SendingStreamState.READY; + dbgTag = connection.streamDbgTag(streamId, "W"); + } + + private static long validateStreamId(QuicConnectionImpl connection, long streamId) { + if (QuicStreams.isBidirectional(streamId)) return streamId; + if (connection.isClientConnection() != QuicStreams.isClientInitiated(streamId)) { + throw new IllegalArgumentException("A remotely initiated stream can't be write-only"); + } + return streamId; + } + + String dbgTag() { + return dbgTag; + } + + @Override + public SendingStreamState sendingState() { + return sendingState; + } + + @Override + public QuicStreamWriter connectWriter(SequentialScheduler scheduler) { + var writer = this.writer; + if (writer == null) { + writer = new QuicStreamWriterImpl(scheduler); + if (Handles.WRITER.compareAndSet(this, null, writer)) { + if (debug.on()) debug.log("writer connected"); + return writer; + } + } + throw new IllegalStateException("writer already connected"); + } + + @Override + public void disconnectWriter(QuicStreamWriter writer) { + var previous = this.writer; + if (writer == previous) { + if (Handles.WRITER.compareAndSet(this, writer, null)) { + if (debug.on()) debug.log("writer disconnected"); + return; + } + } + throw new IllegalStateException("reader not connected"); + } + + + @Override + public void reset(long errorCode) throws IOException { + if (debug.on()) { + debug.log("Resetting stream %s due to %s", streamId(), + connection().quicInstance().appErrorToString(errorCode)); + } + setErrorCode(errorCode); + if (switchSendingState(SendingStreamState.RESET_SENT)) { + long streamId = streamId(); + if (debug.on()) { + debug.log("Requesting to send RESET_STREAM(%d, %d)", + streamId, errorCode); + } + queue.markReset(); + if (connection().isOpen()) { + connection().requestResetStream(streamId, errorCode); + } + } + } + + @Override + public long sndErrorCode() { + return errorCode; + } + + @Override + public boolean stopSendingReceived() { + return stopSendingReceived; + } + + @Override + public long dataSent() { + // returns the amount of data that has been submitted for + // sending downstream. This will be the amount of data that + // has been consumed by the downstream consumer. + return queue.bytesConsumed(); + } + + /** + * Called to set the max stream data for this stream. + * @apiNote as per RFC 9000, any value less than the current + * max stream data is ignored + * @param newLimit the proposed new max stream data + * @return the new limit that has been finalized for max stream data. + * This new limit may or may not have been increased to the proposed {@code newLimit}. + */ + public long setMaxStreamData(final long newLimit) { + return queue.setMaxStreamData(newLimit); + } + + /** + * Called by {@link QuicConnectionStreams} after a RESET_STREAM frame + * has been sent + */ + public void resetSent() { + queue.markReset(); + queue.close(); + } + + /** + * Called when the packet containing the RESET_STREAM frame for this + * stream has been acknowledged. + * @param finalSize the final size acknowledged + * @return true if the state was switched to RESET_RECVD as a result + * of this method invocation + */ + public boolean resetAcknowledged(long finalSize) { + long queueSize = queue.bytesConsumed(); + if (switchSendingState(SendingStreamState.RESET_RECVD)) { + if (debug.on()) { + debug.log("Reset received: final: %d, processed: %d", + finalSize, queueSize); + } + if (finalSize != queueSize) { + if (Log.errors()) { + Log.logError("Stream %d: Acknowledged reset has wrong size: acked: %d, expected: %d", + streamId(), finalSize, queueSize); + } + } + return true; + } + return false; + } + + /** + * Called when the packet containing the final STREAM frame for this + * stream has been acknowledged. + * @param finalSize the final size acknowledged + * @return true if the state was switched to DATA_RECVD as a result + * of this method invocation + */ + public boolean dataAcknowledged(long finalSize) { + long queueSize = queue.bytesConsumed(); + if (switchSendingState(SendingStreamState.DATA_RECVD)) { + if (debug.on()) { + debug.log("Last data received: final: %d, processed: %d", + finalSize, queueSize); + } + if (finalSize != queueSize) { + if (Log.errors()) { + Log.logError("Stream %d: Acknowledged data has wrong size: acked: %d, expected: %d", + streamId(), finalSize, queueSize); + } + } + } + return false; + } + + /** + * Called when a STOP_SENDING frame is received from the peer + * @param errorCode the error code + */ + public void stopSendingReceived(long errorCode) { + if (queue.stopSending(errorCode)) { + stopSendingReceived = true; + setErrorCode(errorCode); + try { + if (connection().isOpen()) { + reset(errorCode); + } + } catch (IOException io) { + if (debug.on()) debug.log("Reset failed: " + io); + } finally { + QuicStreamWriterImpl writer = this.writer; + if (writer != null) writer.wakeupWriter(); + } + } + } + + /** + * Called when the connection is closed locally + * @param terminationCause the termination cause + */ + void terminate(final TerminationCause terminationCause) { + setErrorCode(terminationCause.getCloseCode()); + queue.close(); + final QuicStreamWriterImpl writer = this.writer; + if (writer != null) { + writer.wakeupWriter(); + } + } + + /** + * A concrete implementation of the {@link StreamWriterQueue} for this + * stream. + */ + private final class StreamWriterQueueImpl extends StreamWriterQueue { + @Override + protected void wakeupProducer() { + // The scheduler is provided by the producer + // to wakeup and run the producer's write loop. + var writer = QuicSenderStreamImpl.this.writer; + if (writer != null) { + writer.wakeupWriter(); + } + } + + @Override + protected Logger debug() { + return debug; + } + + @Override + protected void wakeupConsumer() { + // Notify the connection impl that either the data is available + // for writing or the stream is blocked and the peer needs to be + // made aware. The connection should + // eventually call QuicSenderStreamImpl::poll to + // get the data available for writing and package it + // in a StreamFrame or notice that the stream is blocked and send a + // STREAM_DATA_BLOCKED frame. + connection().streamDataAvailableForSending(Set.of(streamId())); + } + + @Override + protected void switchState(SendingStreamState state) { + // called to indicate a change in the stream state. + // at the moment the only expected value is DATA_SENT + assert state == SendingStreamState.DATA_SENT; + switchSendingState(state); + } + + @Override + protected long streamId() { + return QuicSenderStreamImpl.this.streamId(); + } + } + + + /** + * The stream internal implementation of a QuicStreamWriter. + * Most of the logic is implemented in the StreamWriterQueue, + * which is subclassed here to provide an implementation of its + * few abstract methods. + */ + private class QuicStreamWriterImpl extends QuicStreamWriter { + QuicStreamWriterImpl(SequentialScheduler scheduler) { + super(scheduler); + } + + void wakeupWriter() { + scheduler.runOrSchedule(connection().quicInstance().executor()); + } + + @Override + public SendingStreamState sendingState() { + checkConnected(); + return QuicSenderStreamImpl.this.sendingState(); + } + + @Override + public void scheduleForWriting(ByteBuffer buffer, boolean last) throws IOException { + checkConnected(); + SendingStreamState state = sending(last); + switch (state) { + // this isn't atomic but it doesn't really matter since reset + // will be handled by the same thread that polls. + case READY, SEND -> { + // allow a last empty buffer to be submitted even + // if the connection is closed. That can help + // unblock the consumer side. + if (buffer != QuicStreamReader.EOF || !last) { + checkOpened(); + } + queue.submit(buffer, last); + } + case RESET_SENT, RESET_RECVD -> throw streamResetException(); + case DATA_SENT, DATA_RECVD -> throw streamClosedException(); + } + } + + @Override + public void queueForWriting(ByteBuffer buffer) throws IOException { + checkConnected(); + SendingStreamState state = sending(false); + switch (state) { + // this isn't atomic but it doesn't really matter since reset + // will be handled by the same thread that polls. + case READY, SEND -> { + checkOpened(); + queue.queue(buffer); + } + case RESET_SENT, RESET_RECVD -> throw streamResetException(); + case DATA_SENT, DATA_RECVD -> throw streamClosedException(); + } + } + + /** + * Compose an exception to throw if data is submitted after the + * stream was reset + * @return a new IOException + */ + IOException streamResetException() { + long resetByPeer = queue.resetByPeer(); + if (resetByPeer < 0) { + return new IOException("stream %s reset by peer: errorCode %s" + .formatted(streamId(), - resetByPeer - 1)); + } else { + return new IOException("stream %s has been reset".formatted(streamId())); + } + } + + /** + * Compose an exception to throw if data is submitted after the + * the final data has been sent + * @return a new IOException + */ + IOException streamClosedException() { + return new IOException("stream %s is closed - all data has been sent" + .formatted(streamId())); + } + + @Override + public long credit() { + checkConnected(); + // how much data the producer can send before + // reaching the flow control limit. Could be + // negative if the limit has been reached already. + return queue.producerCredit(); + } + + @Override + public void reset(long errorCode) throws IOException { + setErrorCode(errorCode); + checkConnected(); + QuicSenderStreamImpl.this.reset(errorCode); + } + + @Override + public QuicSenderStream stream() { + var stream = QuicSenderStreamImpl.this; + var writer = stream.writer; + return writer == this ? stream : null; + } + + @Override + public boolean connected() { + var writer = QuicSenderStreamImpl.this.writer; + return writer == this; + } + + private void checkConnected() { + if (!connected()) { + throw new IllegalStateException("writer not connected"); + } + } + } + + void checkOpened() throws IOException { + final TerminationCause terminationCause = connection().terminationCause(); + if (terminationCause == null) { + return; + } + throw terminationCause.getCloseCause(); + } + + /** + * {@return the number of bytes that are available for sending, subject + * to flow control} + * @implSpec + * This method does not return more than what flow control for this + * stream would allow at the time the method is called. + * @implNote + * If the sender part is not finished initializing the default + * implementation of this method will return 0. + */ + public long available() { + return queue.readyToSend(); + } + + /** + * Whether the sending is blocked due to flow control. + * @return {@code true} if sending is blocked due to flow control + */ + public boolean isBlocked() { + return queue.consumerBlocked(); + } + + /** + * {@return the size of this stream, if known} + * @implSpec + * This method returns {@code -1} if the size of the stream is not + * known. + */ + public long streamSize() { + return queue.streamSize(); + } + + /** + * Polls at most {@code maxBytes} from the {@link StreamWriterQueue} of + * this stream. The semantics are equivalent to that of {@link + * StreamWriterQueue#poll(int)} + * @param maxBytes the maximum number of bytes to poll for sending + * @return a ByteBuffer containing at most {@code maxBytes} remaining + * bytes. + */ + public ByteBuffer poll(int maxBytes) { + return queue.poll(maxBytes); + } + + @Override + public boolean isDone() { + return switch (sendingState()) { + case DATA_RECVD, RESET_RECVD -> + // everything acknowledged + true; + default -> + // the stream is only half closed + false; + }; + } + + @Override + public StreamState state() { + return sendingState(); + } + + /** + * Called when some data is submitted (or offered) by the + * producer. If the stream is in the READY state, this will + * switch the sending state to SEND. + * @implNote + * The parameter {@code last} is ignored at this stage. + * {@link #switchSendingState(SendingStreamState) + * switchSendingState(SendingStreamState.DATA_SENT)} will be called + * later on when the last piece of data has been pushed downstream. + * + * @param last whether there will be no further data submitted + * by the producer. + * + * @return the state before switching to SEND. + */ + private SendingStreamState sending(boolean last) { + SendingStreamState state = sendingState; + if (state == SendingStreamState.READY) { + switchSendingState(SendingStreamState.SEND); + } + return state; + } + + /** + * Called when the StreamWriterQueue implementation notifies of + * a state change. + * @param newState the new state, according to the StreamWriterQueue. + */ + private boolean switchSendingState(SendingStreamState newState) { + SendingStreamState oldState = sendingState; + if (debug.on()) { + debug.log("switchSendingState %s -> %s", + oldState, newState); + } + boolean switched = switch(newState) { + case SEND -> markSending(); + case DATA_SENT -> markDataSent(); + case DATA_RECVD -> markDataRecvd(); + case RESET_SENT -> markResetSent(); + case RESET_RECVD -> markResetRecvd(); + default -> throw new UnsupportedOperationException("switch state to " + newState); + }; + if (debug.on()) { + if (switched) { + debug.log("switched sending state from %s to %s", oldState, newState); + } else { + debug.log("sending state not switched; state is %s", sendingState); + } + } + + if (switched && newState.isTerminal()) { + notifyTerminalState(newState); + } + + return switched; + } + + private void notifyTerminalState(SendingStreamState state) { + assert state.isTerminal() : state; + connection().notifyTerminalState(streamId(), state); + } + + // SEND can only be set from the READY state + private boolean markSending() { + boolean done, switched = false; + SendingStreamState oldState; + do { + oldState = sendingState; + done = switch (oldState) { + // CAS: Compare And Set + case READY -> switched = + Handles.SENDING_STATE.compareAndSet(this, + oldState, SendingStreamState.SEND); + case SEND, RESET_RECVD, RESET_SENT -> true; + // there should be no further submission of data after DATA_SENT + case DATA_SENT, DATA_RECVD -> + throw new IllegalStateException(String.valueOf(oldState)); + }; + } while(!done); + return switched; + } + + // DATA_SENT can only be set from the SEND state + private boolean markDataSent() { + boolean done, switched = false; + SendingStreamState oldState; + do { + oldState = sendingState; + done = switch (oldState) { + // CAS: Compare And Set + case SEND -> switched = + Handles.SENDING_STATE.compareAndSet(this, + oldState, SendingStreamState.DATA_SENT); + case DATA_SENT, RESET_RECVD, RESET_SENT, DATA_RECVD -> true; + case READY -> throw new IllegalStateException(String.valueOf(oldState)); + }; + } while (!done); + return switched; + } + + // Reset can only be set in the READY, SEND, or DATA_SENT state + private boolean markResetSent() { + boolean done, switched = false; + SendingStreamState oldState; + do { + oldState = sendingState; + done = switch (oldState) { + // CAS: Compare And Set + case READY, SEND, DATA_SENT -> switched = + Handles.SENDING_STATE.compareAndSet(this, + oldState, SendingStreamState.RESET_SENT); + case RESET_RECVD, RESET_SENT, DATA_RECVD -> true; + }; + } while(!done); + return switched; + } + + // Called when the packet containing the last frame is acknowledged + // DATA_RECVD is a terminal state + private boolean markDataRecvd() { + boolean done, switched = false; + SendingStreamState oldState; + do { + oldState = sendingState; + done = switch (oldState) { + // CAS: Compare And Set + case DATA_SENT, RESET_SENT -> switched = + Handles.SENDING_STATE.compareAndSet(this, + oldState, SendingStreamState.DATA_RECVD); + case RESET_RECVD, DATA_RECVD -> true; + default -> throw new IllegalStateException("%s: %s -> %s" + .formatted(streamId(), oldState, SendingStreamState.RESET_RECVD)); + }; + } while(!done); + return switched; + } + + // Called when the packet containing the reset frame is acknowledged + // RESET_RECVD is a terminal state + private boolean markResetRecvd() { + boolean done, switched = false; + SendingStreamState oldState; + do { + oldState = sendingState; + done = switch (oldState) { + // CAS: Compare And Set + case DATA_SENT, RESET_SENT -> switched = + Handles.SENDING_STATE.compareAndSet(this, + oldState, SendingStreamState.RESET_RECVD); + case RESET_RECVD, DATA_RECVD -> true; + default -> throw new IllegalStateException("%s: %s -> %s" + .formatted(streamId(), oldState, SendingStreamState.RESET_RECVD)); + }; + } while(!done); + return switched; + } + + private void setErrorCode(long code) { + Handles.ERROR_CODE.compareAndSet(this, -1, code); + } + + // Some VarHandles to implement CAS semantics on top of plain + // volatile fields in this class. + private static class Handles { + static final VarHandle SENDING_STATE; + static final VarHandle WRITER; + static final VarHandle ERROR_CODE; + static { + Lookup lookup = MethodHandles.lookup(); + try { + SENDING_STATE = lookup.findVarHandle(QuicSenderStreamImpl.class, + "sendingState", SendingStreamState.class); + WRITER = lookup.findVarHandle(QuicSenderStreamImpl.class, + "writer", QuicStreamWriterImpl.class); + ERROR_CODE = lookup.findVarHandle(QuicSenderStreamImpl.class, + "errorCode", long.class); + } catch (Exception e) { + throw new ExceptionInInitializerError(e); + } + } + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStream.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStream.java new file mode 100644 index 00000000000..4d99784299f --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStream.java @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +/** + * An interface to model a QuicStream. + * A quic stream can be either unidirectional + * or bidirectional. A unidirectional stream can + * be opened for reading or for writing. + * Concrete subclasses of {@code QuicStream} should + * implement {@link QuicSenderStream} (unidirectional {@link + * StreamMode#WRITE_ONLY} stream), or {@link QuicReceiverStream} + * (unidirectional {@link StreamMode#READ_ONLY} stream), or + * {@link QuicBidiStream} (bidirectional {@link StreamMode#READ_WRITE} stream). + */ +public sealed interface QuicStream + permits QuicSenderStream, QuicReceiverStream, QuicBidiStream, AbstractQuicStream { + + /** + * An interface that unifies the three different stream states. + * @apiNote + * This is mostly used for logging purposes, to log the + * combined state of a stream. + */ + sealed interface StreamState permits + QuicReceiverStream.ReceivingStreamState, + QuicSenderStream.SendingStreamState, + QuicBidiStream.BidiStreamState { + String name(); + + /** + * {@return true if this is a terminal state} + */ + boolean isTerminal(); + } + + /** + * The stream operation mode. + * One of {@link #READ_ONLY}, {@link #WRITE_ONLY}, or {@link #READ_WRITE}. + */ + enum StreamMode { + READ_ONLY, WRITE_ONLY, READ_WRITE; + + /** + * {@return true if this operation mode allows reading data from the stream} + */ + public boolean isReadable() { + return this != WRITE_ONLY; + } + + /** + * {@return true if this operation mode allows writing data to the stream} + */ + public boolean isWritable() { + return this != READ_ONLY; + } + } + + /** + * {@return the stream ID of this stream} + */ + long streamId(); + + /** + * {@return this stream operation mode} + * One of {@link StreamMode#READ_ONLY}, {@link StreamMode#WRITE_ONLY}, + * or {@link StreamMode#READ_WRITE}. + */ + StreamMode mode(); + + /** + * {@return whether this stream is client initiated} + */ + boolean isClientInitiated(); + + /** + * {@return whether this stream is server initiated} + */ + boolean isServerInitiated(); + + /** + * {@return whether this stream is bidirectional} + */ + boolean isBidirectional(); + + /** + * {@return true if this stream is local initiated} + */ + boolean isLocalInitiated(); + + /** + * {@return true if this stream is remote initiated} + */ + boolean isRemoteInitiated(); + + /** + * The type of this stream, as an int. This is a number between + * 0 and 3 inclusive, and corresponds to the last two lowest bits + * of the stream ID. + *

    + *
  • 0x00: bidirectional, client initiated
  • + *
  • 0x01: bidirectional, server initiated
  • + *
  • 0x02: unidirectional, client initiated
  • + *
  • 0x03: unidirectional, server initiated
  • + *
+ * @return the type of this stream, as an int + */ + int type(); + + /** + * {@return the combined stream state} + * + * @apiNote + * This is mostly used for logging purposes, to log the + * combined state of a stream. + */ + StreamState state(); + + /** + * {@return true if the stream has errors} + * For a {@linkplain QuicBidiStream bidirectional} stream, + * this method returns true if either its sending part or + * its receiving part was closed with a non-zero error code. + */ + boolean hasError(); + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamReader.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamReader.java new file mode 100644 index 00000000000..f38a12d3d7c --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamReader.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Queue; + +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.quic.streams.QuicReceiverStream.ReceivingStreamState; + +/** + * An abstract class to model a reader plugged into + * a QuicStream from which data can be read + */ +public abstract class QuicStreamReader { + + /** + * A sentinel inserted into the queue after the FIN it has been received. + */ + public static final ByteBuffer EOF = ByteBuffer.wrap(new byte[0]).asReadOnlyBuffer(); + + // The scheduler to invoke when data becomes + // available. + final SequentialScheduler scheduler; + + /** + * Creates a new instance of a QuicStreamReader. + * The given scheduler will not be invoked until the reader + * is {@linkplain #start() started}. + * + * @param scheduler A sequential scheduler that will + * poll data out of this reader. + */ + public QuicStreamReader(SequentialScheduler scheduler) { + this.scheduler = scheduler; + } + + /** + * {@return the receiving state of the stream} + * + * @apiNote + * This method returns the state of the {@link QuicReceiverStream} + * to which this writer is {@linkplain + * QuicReceiverStream#connectReader(SequentialScheduler) connected}. + * + * @throws IllegalStateException if this reader is {@linkplain + * QuicReceiverStream#disconnectReader(QuicStreamReader) no longer connected} + * to its stream + * + */ + public abstract ReceivingStreamState receivingState(); + + /** + * {@return the ByteBuffer at the head of the queue, + * or null if no data is available}. If the end of the stream is + * reached then {@link #EOF} is returned. + * + * @implSpec + * This method behave just like {@link Queue#poll()}. + * + * @throws IOException if the stream was closed locally or + * reset by the peer + * @throws IllegalStateException if this reader is {@linkplain + * QuicReceiverStream#disconnectReader(QuicStreamReader) no longer connected} + * to its stream + */ + public abstract ByteBuffer poll() throws IOException; + + /** + * {@return the ByteBuffer at the head of the queue, + * or null if no data is available} + * + * @implSpec + * This method behave just like {@link Queue#peek()}. + * + * @throws IOException if the stream was reset by the peer + * @throws IllegalStateException if this reader is {@linkplain + * QuicReceiverStream#disconnectReader(QuicStreamReader) no longer connected} + * to its stream + */ + public abstract ByteBuffer peek() throws IOException; + + /** + * {@return the stream this reader is connected to, or {@code null} + * if this reader is not currently {@linkplain #connected() connected}} + */ + public abstract QuicReceiverStream stream(); + + + /** + * {@return true if this reader is connected to its stream} + * @see QuicReceiverStream#connectReader(SequentialScheduler) + * @see QuicReceiverStream#disconnectReader(QuicStreamReader) + */ + public abstract boolean connected(); + + /** + * {@return true if this reader has been {@linkplain #start() started}} + */ + public abstract boolean started(); + + /** + * Starts the reader. The {@linkplain + * QuicReceiverStream#connectReader(SequentialScheduler) scheduler} + * will not be invoked until the reader is {@linkplain #start() started}. + */ + public abstract void start(); + + /** + * {@return whether reset was received or read by this reader} + */ + public boolean isReset() { + return stream().receivingState().isReset(); + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamWriter.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamWriter.java new file mode 100644 index 00000000000..ef1de558e10 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreamWriter.java @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.quic.frames.StreamFrame; +import jdk.internal.net.http.quic.streams.QuicSenderStream.SendingStreamState; + +/** + * An abstract class to model a writer plugged into + * a QuicStream to which data can be written. The data + * is wrapped in {@link StreamFrame} + * before being written. + */ +public abstract class QuicStreamWriter { + + // The scheduler to invoke when flow credit + // become available. + final SequentialScheduler scheduler; + + /** + * Creates a new instance of a QuicStreamWriter. + * @param scheduler A sequential scheduler that will + * push data into this writer. + */ + public QuicStreamWriter(SequentialScheduler scheduler) { + this.scheduler = scheduler; + } + + /** + * {@return the sending state of the stream} + * + * @apiNote + * This method returns the state of the {@link QuicSenderStream} + * to which this writer is {@linkplain + * QuicSenderStream#connectWriter(SequentialScheduler) connected}. + * + * @throws IllegalStateException if this writer is {@linkplain + * QuicSenderStream#disconnectWriter(QuicStreamWriter) no longer connected} + * to its stream + */ + public abstract SendingStreamState sendingState(); + + /** + * Pushes a ByteBuffer to be scheduled for writing on the stream. + * The ByteBuffer will be wrapped in a StreamFrame before being + * sent. Data that cannot be sent due to a lack of flow + * credit will be buffered. + * + * @param buffer A byte buffer to schedule for writing + * @param last Whether that's the last data that will be sent + * through this stream. + * + * @throws IOException if the state of the stream isn't + * {@link SendingStreamState#READY} or {@link SendingStreamState#SEND} + * @throws IllegalStateException if this writer is {@linkplain + * QuicSenderStream#disconnectWriter(QuicStreamWriter) no longer connected} + * to its stream + */ + public abstract void scheduleForWriting(ByteBuffer buffer, boolean last) + throws IOException; + + /** + * Queues a {@code ByteBuffer} on the writing queue for this stream. + * The consumer will not be woken up. More data should be submitted + * using {@link #scheduleForWriting(ByteBuffer, boolean)} in order + * to wake the consumer. + * + * @apiNote + * Use this method as a hint that more data will be + * upcoming shortly that might be aggregated with + * the data being queued in order to reduce the number + * of packets that will be sent to the peer. + * This is useful when a small number of bytes + * need to be written to the stream before actual stream + * data. Typically, this can be used for writing the + * HTTP/3 stream type for a unidirectional HTTP/3 stream + * before starting to send stream data. + * + * @param buffer A byte buffer to schedule for writing + * + * @throws IOException if the state of the stream isn't + * {@link SendingStreamState#READY} or {@link SendingStreamState#SEND} + * @throws IllegalStateException if this writer is {@linkplain + * QuicSenderStream#disconnectWriter(QuicStreamWriter) no longer connected} + * to its stream + */ + public abstract void queueForWriting(ByteBuffer buffer) + throws IOException; + + /** + * Indicates how many bytes the writer is + * prepared to received for sending. + * When that value grows from 0, and if the queue has + * no pending data, the {@code scheduler} + * is triggered to elicit more calls to + * {@link #scheduleForWriting(ByteBuffer,boolean)}. + * + * @apiNote This information is used to avoid + * buffering too much data while waiting for flow + * credit on the underlying stream. When flow credit + * is available, the {@code scheduler} loop is + * invoked to resume writing. The scheduler can then + * call this method to figure out how much data to + * request from upstream. + * + * @throws IllegalStateException if this writer is {@linkplain + * QuicSenderStream#disconnectWriter(QuicStreamWriter) no longer connected} + * to its stream + */ + public abstract long credit(); + + /** + * Abruptly resets the stream. + * + * @param errorCode the application error code + * + * @throws IllegalStateException if this writer is {@linkplain + * QuicSenderStream#disconnectWriter(QuicStreamWriter) no longer connected} + * to its stream + */ + public abstract void reset(long errorCode) throws IOException; + + /** + * {@return the stream this writer is connected to, or {@code null} + * if this writer isn't currently {@linkplain #connected() connected}} + */ + public abstract QuicSenderStream stream(); + + /** + * {@return true if this writer is connected to its stream} + * @see QuicSenderStream#connectWriter(SequentialScheduler) + * @see QuicSenderStream#disconnectWriter(QuicStreamWriter) + */ + public abstract boolean connected(); + + /** + * {@return true if STOP_SENDING was received} + */ + public boolean stopSendingReceived() { + return connected() ? stream().stopSendingReceived() : false; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreams.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreams.java new file mode 100644 index 00000000000..8706cdcf887 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/QuicStreams.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import jdk.internal.net.http.quic.QuicConnectionImpl; + +/** + * A collection of utilities methods to analyze and work with + * quic streams. + */ +public final class QuicStreams { + private QuicStreams() { throw new InternalError("should not come here"); } + + public static final int TYPE_MASK = 0x03; + public static final int UNI_MASK = 0x02; + public static final int SRV_MASK = 0x01; + + public static int streamType(long streamId) { + return (int) streamId & TYPE_MASK; + } + + public static boolean isBidirectional(long streamId) { + return ((int) streamId & UNI_MASK) == 0; + } + + public static boolean isUnidirectional(long streamId) { + return ((int) streamId & UNI_MASK) == UNI_MASK; + } + + public static boolean isBidirectional(int streamType) { + return (streamType & UNI_MASK) == 0; + } + + public static boolean isUnidirectional(int streamType) { + return (streamType & UNI_MASK) == UNI_MASK; + } + + public static boolean isClientInitiated(long streamId) { + return ((int) streamId & SRV_MASK) == 0; + } + + public static boolean isServerInitiated(long streamId) { + return ((int) streamId & SRV_MASK) == SRV_MASK; + } + + public static boolean isClientInitiated(int streamType) { + return (streamType & SRV_MASK) == 0; + } + + public static boolean isServerInitiated(int streamType) { + return (streamType & SRV_MASK) == SRV_MASK; + } + + public static AbstractQuicStream createStream(QuicConnectionImpl connection, long streamId) { + int type = streamType(streamId); + boolean isClient = connection.isClientConnection(); + return switch (type) { + case 0x00, 0x01 -> new QuicBidiStreamImpl(connection, streamId); + case 0x02 -> isClient ? new QuicSenderStreamImpl(connection, streamId) + : new QuicReceiverStreamImpl(connection, streamId); + case 0x03 -> isClient ? new QuicReceiverStreamImpl(connection, streamId) + : new QuicSenderStreamImpl(connection, streamId); + default -> throw new IllegalArgumentException("bad stream type %s for stream %s" + .formatted(type, streamId)); + }; + } + +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamCreationPermit.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamCreationPermit.java new file mode 100644 index 00000000000..5495681b12a --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamCreationPermit.java @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.AbstractQueuedLongSynchronizer; +import java.util.function.Function; + +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; + +/** + * Quic specifies limits on the number of uni and bidi streams that an endpoint can create. + * This {@code StreamCreationPermit} is used to keep track of that limit and is expected to be + * used before attempting to open a Quic stream. Either of {@link #tryAcquire()} or + * {@link #tryAcquire(long, TimeUnit, Executor)} must be used before attempting to open a new stream. Stream + * must only be opened if that method returns {@code true} which implies the stream creation limit + * hasn't yet reached. + *

+ * It is expected that for each of the stream types (remote uni, remote bidi, local uni + * and local bidi) a separate instance of {@code StreamCreationPermit} will be used. + *

+ * An instance of {@code StreamCreationPermit} starts with an initial limit and that limit can be + * increased to newer higher values whenever necessary. The limit however cannot be reduced to a + * lower value. + *

+ * None of the methods, including {@link #tryAcquire(long, TimeUnit, Executor)} and {@link #tryAcquire()} + * block the caller thread. + */ +final class StreamCreationPermit { + + private final InternalSemaphore semaphore; + private final SequentialScheduler permitAcquisitionScheduler = + SequentialScheduler.lockingScheduler(new TryAcquireTask()); + + private final ConcurrentLinkedQueue acquirers = new ConcurrentLinkedQueue<>(); + + /** + * @param initialMaxStreams the initial max streams limit + * @throws IllegalArgumentException if {@code initialMaxStreams} is less than 0 + * @throws NullPointerException if executor is null + */ + StreamCreationPermit(final long initialMaxStreams) { + if (initialMaxStreams < 0) { + throw new IllegalArgumentException("Invalid max streams limit: " + initialMaxStreams); + } + this.semaphore = new InternalSemaphore(initialMaxStreams); + } + + /** + * Attempts to increase the limit to {@code newLimit}. The limit will be atomically increased + * to the {@code newLimit}. If the {@linkplain #currentLimit() current limit} is higher than + * the {@code newLimit}, then the limit isn't changed and this method returns {@code false}. + * + * @param newLimit the new limit + * @return true if the limit was successfully increased to {@code newLimit}, false otherwise. + */ + boolean tryIncreaseLimitTo(final long newLimit) { + final boolean increased = this.semaphore.tryIncreaseLimitTo(newLimit); + if (increased) { + // let any waiting acquirers attempt acquiring a permit + permitAcquisitionScheduler.runOrSchedule(); + } + return increased; + } + + /** + * Attempts to acquire a permit to open a new stream. This method does not block and returns + * immediately. A stream should only be opened if the permit was successfully acquired. + * + * @return true if the permit was acquired and a new stream is allowed to be opened. + * false otherwise. + */ + boolean tryAcquire() { + return this.semaphore.tryAcquireShared(1) >= 0; + } + + /** + * Attempts to acquire a permit to open a new stream. If the permit is available then this method + * returns immediately with a {@link CompletableFuture} whose result is {@code true}. If the + * permit isn't currently available then this method returns a {@code CompletableFuture} which + * completes with a result of {@code false} if no permits were available for the duration + * represented by the {@code timeout}. If during this {@code timeout} period, a permit is + * acquired, because of an increase in the stream limit, then the returned + * {@code CompletableFuture} completes with a result of {@code true}. + * + * @param timeout the maximum amount of time to attempt acquiring a permit, after which the + * {@code CompletableFuture} will complete with a result of {@code false} + * @param unit the timeout unit + * @param executor the executor that will be used to asynchronously complete the + * returned {@code CompletableFuture} if a permit is acquired after this + * method has returned + * @return a {@code CompletableFuture} whose result will be {@code true} if the permit was + * acquired and {@code false} otherwise + * @throws IllegalArgumentException if {@code timeout} is negative + * @throws NullPointerException if the {@code executor} is null + */ + CompletableFuture tryAcquire(final long timeout, final TimeUnit unit, + final Executor executor) { + Objects.requireNonNull(executor); + if (timeout < 0) { + throw new IllegalArgumentException("invalid timeout: " + timeout); + } + if (tryAcquire()) { + return MinimalFuture.completedFuture(true); + } + final CompletableFuture future = new MinimalFuture() + .orTimeout(timeout, unit) + .handle((acquired, t) -> { + if (t instanceof TimeoutException te) { + // timed out + return MinimalFuture.completedFuture(false); + } + if (t == null) { + // completed normally + return MinimalFuture.completedFuture(acquired); + } + return MinimalFuture.failedFuture(t); + }).thenComposeAsync(Function.identity(), executor); + var waiter = new Waiter(future, executor); + this.acquirers.add(waiter); + // if the future completes in timeout the Waiter should be removed from the list. + // because this is a queue it might not be too efficient... + future.whenComplete((r,t) -> { if (r != null && !r) acquirers.remove(waiter);}); + // if stream limit might have increased in the meantime, + // trigger the task to have this newly registered waiter notified + // TODO: should we call runOrSchedule(executor) here instead? + permitAcquisitionScheduler.runOrSchedule(); + return future; + } + + /** + * {@return the current limit for stream creation} + */ + long currentLimit() { + return this.semaphore.currentLimit(); + } + + private final record Waiter(CompletableFuture acquirer, + Executor executor) { + Waiter { + assert acquirer != null : "Acquirer cannot be null"; + assert executor != null : "Executor cannot be null"; + } + } + + /** + * A task which iterates over the waiting acquirers and attempt + * to acquire a permit. If successful, the waiting acquirer(s) (i.e. the CompletableFuture(s)) + * are completed successfully. If not, the waiting acquirers continue to stay in the wait list + */ + private final class TryAcquireTask implements Runnable { + + @Override + public void run() { + Waiter waiter = null; + while ((waiter = acquirers.peek()) != null) { + final CompletableFuture acquirer = waiter.acquirer; + if (acquirer.isCancelled() || acquirer.isDone()) { + // no longer interested, or already completed, remove it + acquirers.remove(waiter); + continue; + } + if (!tryAcquire()) { + // limit reached, no permits available yet + break; + } + // compose a step which rolls back the acquired permit if the + // CompletableFuture completed in some other thread, after the permit was acquired. + acquirer.whenComplete((acquired, t) -> { + final boolean shouldRollback = acquirer.isCancelled() + || t != null + || !acquired; + if (shouldRollback) { + final boolean released = StreamCreationPermit.this.semaphore.releaseShared(1); + assert released : "acquired permit wasn't released"; + // an additional permit is now available due to the release, let any waiters + // acquire it if needed + permitAcquisitionScheduler.runOrSchedule(); + } + }); + // got a permit, complete the waiting acquirer + acquirers.remove(waiter); + acquirer.completeAsync(() -> true, waiter.executor); + } + } + } + + /** + * A {@link AbstractQueuedLongSynchronizer} whose {@linkplain #getState() state} represents + * the number of permits that have currently been acquired. This {@code Semaphore} only + * supports "shared" mode; i.e. exclusive mode isn't supported. + *

+ * The {@code Semaphore} maintains a {@linkplain #limit limit} which represents + * the maximum number of permits that can be acquired through an instance of this class. + * The {@code limit} can be {@linkplain #tryIncreaseLimitTo(long) increased} but cannot be + * reduced from the previous set limit. + */ + private static final class InternalSemaphore extends AbstractQueuedLongSynchronizer { + private static final long serialVersionUID = 4280985311770761500L; + + private final AtomicLong limit; + + /** + * @param initialLimit the initial limit, must be >=0 + */ + private InternalSemaphore(final long initialLimit) { + assert initialLimit >= 0 : "not a positive initial limit: " + initialLimit; + this.limit = new AtomicLong(initialLimit); + setState(0 /* num acquired */); + } + + /** + * Attempts to acquire additional permits. If no permits can be acquired, + * then this method returns -1. Upon successfully acquiring the + * {@code additionalAcquisitions} this method returns a value {@code >=0} which represents + * the additional number of permits that are available for acquisition. + * + * @param additionalAcquisitions the additional permits that are requested + * @return -1 If no permits can be acquired. Value >=0, representing the permits that are + * still available for acquisition. + */ + @Override + protected long tryAcquireShared(final long additionalAcquisitions) { + while (true) { + final long alreadyAcquired = getState(); + final long totalOnAcquisition = alreadyAcquired + additionalAcquisitions; + final long currentLimit = limit.get(); + if (totalOnAcquisition > currentLimit) { + return -1; // exceeds limit, so cannot acquire + } + final long numAvailableUponAcquisition = currentLimit - totalOnAcquisition; + if (compareAndSetState(alreadyAcquired, totalOnAcquisition)) { + return numAvailableUponAcquisition; + } + } + } + + /** + * Attempts to release permits + * + * @param releases the number of permits to release + * @return true if the permits were released, false otherwise + * @throws IllegalArgumentException if the number of {@code releases} exceeds the total + * number of permits that have been acquired + */ + @Override + protected boolean tryReleaseShared(final long releases) { + while (true) { + final long currentAcquisitions = getState(); + final long totalAfterRelease = currentAcquisitions - releases; + if (totalAfterRelease < 0) { + // we attempted to release more permits than what was acquired + throw new IllegalArgumentException("cannot release " + releases + + " permits from " + currentAcquisitions + " acquisitions"); + } + if (compareAndSetState(currentAcquisitions, totalAfterRelease)) { + return true; + } + } + } + + /** + * Tries to increase the limit to the {@code newLimit}. If the {@code newLimit} is lesser + * than the current limit, then this method returns false. Otherwise, this method will attempt + * to atomically increase the limit to {@code newLimit}. + * + * @param newLimit The new limit to set + * @return true if the limit was increased to {@code newLimit}. false otherwise + */ + private boolean tryIncreaseLimitTo(final long newLimit) { + long currentLimit = this.limit.get(); + while (currentLimit < newLimit) { + if (this.limit.compareAndSet(currentLimit, newLimit)) { + return true; + } + currentLimit = this.limit.get(); + } + return false; + } + + /** + * {@return the current limit} + */ + private long currentLimit() { + return this.limit.get(); + } + } +} diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamWriterQueue.java b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamWriterQueue.java new file mode 100644 index 00000000000..ebf205479f6 --- /dev/null +++ b/src/java.net.http/share/classes/jdk/internal/net/http/quic/streams/StreamWriterQueue.java @@ -0,0 +1,550 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.quic.streams; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.streams.QuicSenderStream.SendingStreamState; + +/** + * A class to handle the writing queue of a {@link QuicSenderStream}. + * This class maintains a queue of byte buffer containing stream data + * that has not yet been packaged for sending. It also keeps track of + * the max stream data value. + * It acts as a mailbox between a producer (typically a {@link QuicStreamWriter}), + * and a consumer (typically a {@link jdk.internal.net.http.quic.QuicConnectionImpl}). + * This class is abstract: a concrete implementation of this class must only + * implement {@link #wakeupProducer()} and {@link #wakeupConsumer()} which should + * wake up the producer and consumer respectively, when data can be polled or + * submitted from the queue. + */ +abstract class StreamWriterQueue { + /** + * The amount of data that a StreamWriterQueue is willing to buffer. + * The queue will buffer excess data, but will not wake up the producer + * until the excess is consumed. + */ + private static final int BUFFER_SIZE = + Utils.getIntegerProperty("jdk.httpclient.quic.streamBufferSize", 1 << 16); + + // The current buffer containing data to send. + private ByteBuffer current; + // The offset of the data that has been consumed + private volatile long bytesConsumed; + // The offset of the data that has been supplied by the + // producer. + // bytesProduced >= bytesConsumed at all times. + private volatile long bytesProduced; + // The stream size, when known, -1 otherwise. + // The stream size may be known at the creation of the stream, + // or at the latest when the last ByteBuffer is provided by + // the producer. + private volatile long streamSize = -1; + // true if reset was requested, false otherwise + private volatile boolean resetRequested; + // The maximum offset that will be accepted by the peer at this + // time. bytesConsumed <= maxStreamData at all times. + private volatile long maxStreamData; + // negative if stop sending was received; contains -(errorCode + 1) + private volatile long stopSending; + // The queue to buffer data before it's polled by the consumer + private final ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); + private final ReentrantLock lock = new ReentrantLock(); + + protected final void lock() { + lock.lock(); + } + + protected final void unlock() { + lock.unlock(); + } + + protected abstract Logger debug(); + + /** + * This method is called by the consumer to poll data from the stream + * queue. This method will return a {@code ByteBuffer} with at most + * {@code maxbytes} remaining bytes. The {@code ByteBuffer} may contain + * less bytes if not enough bytes are available, or if there is not + * enough {@linkplain #consumerCredit() credit} to send {@code maxbytes} + * to the peer. Only stream credit is taken into account. Taking into + * account connection credit is the responsibility of the caller. + * If there is no credit, or if there is no data available, {@code null} + * is returned. When credit and data are available again, {@link #wakeupConsumer()} + * is called to wake up the consumer. + * + * @apiNote + * This method increases the consumer offset. It must not be called concurrently + * by two different threads. + * + * @implNote + * If the producer was blocked due to full buffer before this method was called + * and the method removes enough of buffered data, + * {@link #wakeupProducer()} is called. + * + * @param maxbytes the maximum number of bytes the consumer is prepared + * to consume. + * @return a {@code ByteBuffer} containing at most {@code maxbytes}, or {@code null} + * if no data is available or data is blocked by flow control. + */ + public final ByteBuffer poll(int maxbytes) { + boolean producerWasBlocked, producerUnblocked; + long produced, consumed; + ByteBuffer buffer; + long credit = consumerCredit(); + assert credit >= 0 : credit; + if (credit < maxbytes) { + maxbytes = (int)credit; + } + if (maxbytes <= 0) return null; + lock(); + try { + producerWasBlocked = producerBlocked(); + buffer = current; + if (buffer == null) { + buffer = current = queue.poll(); + } + if (buffer == null) { + return null; + } + int remaining = buffer.remaining(); + int position = buffer.position(); + consumed = bytesConsumed; + if (remaining <= maxbytes) { + current = queue.poll(); + bytesConsumed = consumed = Math.addExact(consumed, remaining); + } else { + buffer = buffer.slice(position, maxbytes); + current.position(position + maxbytes); + bytesConsumed = consumed = Math.addExact(consumed, maxbytes); + } + long size = streamSize; + produced = bytesProduced; + producerUnblocked = producerWasBlocked && !producerBlocked(); + if (StreamWriterQueue.class.desiredAssertionStatus()) { + assert consumed <= produced + : "consumed: " + consumed + ", produced: " + produced + ", size: " + size; + assert size == -1 || consumed <= size + : "consumed: " + consumed + ", produced: " + produced + ", size: " + size; + assert size == -1 || produced <= size + : "consumed: " + consumed + ", produced: " + produced + ", size: " + size; + } + if (size >= 0 && consumed == size) { + switchState(SendingStreamState.DATA_SENT); + } + } finally { + unlock(); + } + if (producerUnblocked) { + debug().log("producer unblocked produced:%s, consumed:%s", + produced, consumed); + wakeupProducer(); + } + return buffer; + } + + /** + * Updates the flow control credit for this queue. + * The maximum offset that will be accepted by the consumer + * can only increase. Value that are less or equal to the + * current value of the max stream data are ignored. + * + * @implSpec + * If the consumer was blocked due to flow control before + * this method was called, and the new value of the max + * stream data allows to unblock the consumer, and data + * is available, {@link #wakeupConsumer()} is called. + * + * @param data the maximum offset that will be accepted by + * the consumer + * @return the maximum offset that will be accepted by the + * consumer. + */ + public final long setMaxStreamData(long data) { + assert data >= 0 : "maxStreamData: " + data; + long max, produced, consumed; + boolean consumerWasBlocked, consumerUnblocked; + lock(); + try { + max = maxStreamData; + consumed = bytesConsumed; + produced = bytesProduced; + consumerWasBlocked = consumerBlocked(); + if (data <= max) return max; + maxStreamData = max = data; + consumerUnblocked = consumerWasBlocked && !consumerBlocked(); + if (StreamWriterQueue.class.desiredAssertionStatus()) { + long size = streamSize; + assert consumed <= produced; + assert size == -1 || consumed <= size; + assert size == -1 || produced <= size; + } + } finally { + unlock(); + } + debug().log("set max stream data: %s", max); + if (consumerUnblocked && produced > 0) { + debug().log("consumer unblocked produced:%s, consumed:%s, max stream data:%s", + produced, consumed, max); + wakeupConsumer(); + } + return max; + } + + /** + * Whether the producer is blocked due to flow control. + * + * @return whether the producer is blocked due to full buffers + */ + public final boolean producerBlocked() { + return producerCredit() <= 0; + } + + /** + * Whether the consumer is blocked due to flow control. + * + * @return whether the producer is blocked due to flow control + */ + public final boolean consumerBlocked() { + return consumerCredit() <= 0; + } + + /** + * {@return the offset of the data consumed by the consumer} + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #poll(int)} is called concurrently + * by another thread. + */ + public final long bytesConsumed() { + return bytesConsumed; + } + + /** + * {@return the offset of the data provided by the producer} + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #submit(ByteBuffer, boolean)} + * or {@link #queue(ByteBuffer)} are called concurrently + * by another thread. + */ + public final long bytesProduced() { + return bytesProduced; + } + + /** + * {@return the amount of produced data which has not been consumed yet} + * This is independent of flow control. + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #submit(ByteBuffer, boolean)} + * or {@link #queue(ByteBuffer)} or + * {@link #poll(int)} are called concurrently + * by another thread. + */ + public final long available() { + return bytesProduced - bytesConsumed; + } + + /** + * {@return the stream size if known, {@code -1} otherwise} + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #submit(ByteBuffer, boolean)} + * is called concurrently by another thread. + */ + public final long streamSize() { + return streamSize; + } + + /** + * {@return the maximum offset that the peer is prepared to accept} + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #setMaxStreamData(long)} is called + * concurrently by another thread. + */ + public final long maxStreamData() { + return maxStreamData; + } + + /** + * {@return {@code true} if the consumer has reached the end of + * this stream (equivalent to EOF)} + * This is independent of flow control. + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #submit(ByteBuffer, boolean)} + * or {@link #poll(int)} are called concurrently + * by another thread. + */ + public final boolean isConsumerDone() { + long size = streamSize; + long consumed = bytesConsumed; + assert size == -1 || size >= consumed; + return size >= 0 && size <= consumed; + } + + /** + * {@return {@code true} if the producer has reached the end of + * this stream (equivalent to EOF)} + * This is independent of flow control. + * + * @apiNote + * The returned value is only weakly consistent: it is subject + * to race conditions if {@link #submit(ByteBuffer, boolean)} + * is called concurrently by another thread. + */ + public final boolean isProducerDone() { + return streamSize >= 0; + } + + /** + * This method is called by the producer to submit data to this + * stream. The producer should not modify the provided buffer + * after this point. The provided buffer will be queued even if + * the produced data exceeds the maximum offset that the peer + * is prepared to accept. + * + * @apiNote + * If sufficient credit is available, this method will wake + * up the consumer. + * + * @param buffer a buffer containing data for the stream + * @param last whether this is the last buffer that will ever be + * provided by the provided + * @throws IOException if the stream was reset by peer + * @throws IllegalStateException if the last data was submitted already + */ + public final void submit(ByteBuffer buffer, boolean last) throws IOException { + offer(buffer, last, false); + } + + /** + * This method is called by the producer to queue data to this + * stream. The producer should not modify the provided buffer + * after this point. The provided buffer will be queued even if + * the produced data exceeds the maximum offset that the peer + * is prepared to accept. + * + * @apiNote + * The consumer will not be woken, even if enough credit is + * available. More data should be submitted using + * {@link #submit(ByteBuffer, boolean)} in order to wake up the consumer. + * + * @param buffer a buffer containing data for the stream + * @throws IOException if the stream was reset by peer + * @throws IllegalStateException if the last data was submitted already + */ + public final void queue(ByteBuffer buffer) throws IOException { + offer(buffer, false, true); + } + + /** + * Queues a buffer in the writing queue. + * + * @param buffer the buffer to queue + * @param last whether this is the last data for the stream + * @param waitForMore whether we should wait for the next submission before + * waking up the consumer + * @throws IOException if the stream was reset by peer + * @throws IllegalStateException if the last data was submitted already + */ + private void offer(ByteBuffer buffer, boolean last, boolean waitForMore) + throws IOException { + long length = buffer.remaining(); + long consumed, produced, max; + boolean wakeupConsumer; + lock(); + try { + long stopSending = this.stopSending; + if (stopSending < 0) { + throw new IOException("Stream %s reset by peer: errorCode %s" + .formatted(streamId(), 1 - stopSending)); + } + if (resetRequested) return; + if (streamSize >= 0) { + throw new IllegalStateException("Too many bytes provided"); + } + consumed = bytesConsumed; + max = maxStreamData; + produced = Math.addExact(bytesProduced, length); + bytesProduced = produced; + if (length > 0 || last) { + // allow to queue a zero-length buffer if it's the last. + queue.offer(buffer); + } + if (last) { + streamSize = produced; + } + assert consumed <= produced; + wakeupConsumer = consumed < max && consumed < produced + || consumed == produced && last; + } finally { + unlock(); + } + if (wakeupConsumer && !waitForMore) { + debug().log("consumer unblocked produced:%s, consumed:%s, max stream data:%s", + produced, consumed, max); + wakeupConsumer(); + } + } + + /** + * {@return the credit of the producer} + * @implSpec + * this is the desired buffer size minus the amount of data already buffered. + */ + public final long producerCredit() { + lock(); + try { + return BUFFER_SIZE - available(); + } finally { + unlock(); + } + } + + /** + * {@return the credit of the consumer} + * @implSpec + * This is equivalent to {@link #maxStreamData()} - {@link #bytesConsumed()}. + */ + public final long consumerCredit() { + lock(); + try { + return maxStreamData - bytesConsumed; + } finally { + unlock(); + } + } + + /** + * {@return the amount of available data that can be sent + * with respect to flow control in this stream}. + * This does not take into account the global connection + * flow control. + */ + public final long readyToSend() { + long consumed, produced, max; + lock(); + try { + consumed = bytesConsumed; + max = maxStreamData; + produced = bytesProduced; + } finally { + unlock(); + } + assert max >= consumed; + assert produced >= consumed; + return Math.min(max - consumed, produced - consumed); + } + + public final void markReset() { + lock(); + try { + resetRequested = true; + } finally { + unlock(); + } + } + + final void close() { + lock(); + try { + bytesProduced = bytesConsumed; + queue.clear(); + current = null; + } finally { + unlock(); + } + } + + /** + * Called when a stop sending frame is received for this stream + * @param errorCode the error code + */ + protected final boolean stopSending(long errorCode) { + long stopSending; + lock(); + try { + if (resetRequested) return false; + if (streamSize >= 0 && bytesConsumed == streamSize) return false; + if ((stopSending = this.stopSending) < 0) return false; + this.stopSending = stopSending = - (errorCode + 1); + } finally { + unlock(); + } + assert stopSending < 0 && stopSending == - (errorCode + 1); + return true; + } + + /** + * {@return -1 minus the error code that was supplied by the peer + * when requesting for stop sending} + * @apiNote a strictly negative value indicates that the stream was + * reset by the peer. The error code supplied by the peer + * can be obtained with the formula:

{@code
+     *    long errorCode = - (resetByPeer() + 1);
+     *    }
+ */ + final long resetByPeer() { + return stopSending; + } + + /** + * This method is called to wake up the consumer when there is + * credit and data available for the consumer. + */ + protected abstract void wakeupConsumer(); + + /** + * This method is called to wake up the producer when there is + * credit available for the producer. + */ + protected abstract void wakeupProducer(); + + /** + * Called to switch the sending state when data has been sent. + * @param dataSent the new state - typically {@link SendingStreamState#DATA_SENT} + */ + protected abstract void switchState(SendingStreamState dataSent); + + /** + * {@return the stream id this queue was created for} + */ + protected abstract long streamId(); + +} diff --git a/src/java.net.http/share/classes/module-info.java b/src/java.net.http/share/classes/module-info.java index 14cbb85291d..392385136b0 100644 --- a/src/java.net.http/share/classes/module-info.java +++ b/src/java.net.http/share/classes/module-info.java @@ -75,6 +75,18 @@ *
  • {@systemProperty jdk.httpclient.hpack.maxheadertablesize} (default: 16384 or * 16 kB)
    The HTTP/2 client maximum HPACK header table size in bytes. *

  • + *
  • {@systemProperty jdk.httpclient.qpack.decoderMaxTableCapacity} (default: 0) + *
    The HTTP/3 client maximum QPACK decoder dynamic header table size in bytes. + *
    Setting this value to a positive number will allow HTTP/3 servers to add entries + * to the QPack decoder's dynamic table. When set to 0, servers are not permitted to add + * entries to the client's QPack encoder's dynamic table. + *

  • + *
  • {@systemProperty jdk.httpclient.qpack.encoderTableCapacityLimit} (default: 4096, + * or 4 kB) + *
    The HTTP/3 client maximum QPACK encoder dynamic header table size in bytes. + *
    Setting this value to a positive number allows the HTTP/3 client's QPack encoder to + * add entries to the server's QPack decoder's dynamic table, if the server permits it. + *

  • *
  • {@systemProperty jdk.httpclient.HttpClient.log} (default: none)
    * Enables high-level logging of various events through the {@linkplain java.lang.System.Logger * Platform Logging API}. The value contains a comma-separated list of any of the @@ -88,6 +100,8 @@ *

  • ssl
  • *
  • trace
  • *
  • channel
  • + *
  • http3
  • + *
  • quic
  • *
    * You can append the frames item with a colon-separated list of any of the following items: *
      @@ -96,31 +110,57 @@ *
    • window
    • *
    • all
    • *

    + * You can append the quic item with a colon-separated list of any of the following items; + * packets are logged in an abridged form that only shows frames offset and length, + * but not content: + *
      + *
    • ack: packets containing ack frames will be logged
    • + *
    • cc: information on congestion control will be logged
    • + *
    • control: packets containing quic controls (such as frames affecting + * flow control, or frames opening or closing streams) + * will be logged
    • + *
    • crypto: packets containing crypto frames will be logged
    • + *
    • data: packets containing stream frames will be logged
    • + *
    • dbb: information on direct byte buffer usage will be logged
    • + *
    • ping: packets containing ping frames will be logged
    • + *
    • processed: information on flow control (processed bytes) will be logged
    • + *
    • retransmit: information on packet loss and recovery will be logged
    • + *
    • timer: information on send task scheduling will be logged
    • + *
    • all
    • + *

    * Specifying an item adds it to the HTTP client's log. For example, if you specify the * following value, then the Platform Logging API logs all possible HTTP Client events:
    * "errors,requests,headers,frames:control:data:window,ssl,trace,channel"
    * Note that you can replace control:data:window with all. The name of the logger is * "jdk.httpclient.HttpClient", and all logging is at level INFO. + * To debug issues with the quic protocol a good starting point is to specify + * {@code quic:control:retransmit}. * *
  • {@systemProperty jdk.httpclient.keepalive.timeout} (default: 30)
    - * The number of seconds to keep idle HTTP connections alive in the keep alive cache. This - * property applies to both HTTP/1.1 and HTTP/2. The value for HTTP/2 can be overridden - * with the {@code jdk.httpclient.keepalive.timeout.h2 property}. + * The number of seconds to keep idle HTTP connections alive in the keep alive cache. + * By default this property applies to HTTP/1.1, HTTP/2 and HTTP/3. + * The value for HTTP/2 and HTTP/3 can be overridden with the + * {@code jdk.httpclient.keepalive.timeout.h2} and {@code jdk.httpclient.keepalive.timeout.h3} + * properties respectively. The value specified for HTTP/2 acts as default value for HTTP/3. *

  • *
  • {@systemProperty jdk.httpclient.keepalive.timeout.h2} (default: see * below)
    The number of seconds to keep idle HTTP/2 connections alive. If not set, then the * {@code jdk.httpclient.keepalive.timeout} setting is used. *

  • + *
  • {@systemProperty jdk.httpclient.keepalive.timeout.h3} (default: see + * below)
    The number of seconds to keep idle HTTP/3 connections alive. If not set, then the + * {@code jdk.httpclient.keepalive.timeout.h2} setting is used. + *

  • *
  • {@systemProperty jdk.httpclient.maxframesize} (default: 16384 or 16kB)
    * The HTTP/2 client maximum frame size in bytes. The server is not permitted to send a frame * larger than this. *

  • *
  • {@systemProperty jdk.httpclient.maxLiteralWithIndexing} (default: 512)
    * The maximum number of header field lines (header name and value pairs) that a - * client is willing to add to the HPack Decoder dynamic table during the decoding + * client is willing to add to the HPack or QPACK Decoder dynamic table during the decoding * of an entire header field section. * This is purely an implementation limit. - * If a peer sends a field section with encoding that + * If a peer sends a field section or a set of QPACK instructions with encoding that * exceeds this limit a {@link java.net.ProtocolException ProtocolException} will be raised. * A value of zero or a negative value means no limit. *

  • @@ -135,7 +175,7 @@ * A value of zero or a negative value means no limit. * *
  • {@systemProperty jdk.httpclient.maxstreams} (default: 100)
    - * The maximum number of HTTP/2 push streams that the client will permit servers to open + * The maximum number of HTTP/2 or HTTP/3 push streams that the client will permit servers to open * simultaneously. *

  • *
  • {@systemProperty jdk.httpclient.receiveBufferSize} (default: operating system @@ -187,6 +227,61 @@ * value means no limit. *

  • * + *

    + * The following system properties can be used to configure some aspects of the + * QUIC Protocol + * implementation used for HTTP/3: + *

      + *
    • {@systemProperty jdk.httpclient.quic.receiveBufferSize} (default: operating system + * default)
      The QUIC {@linkplain java.nio.channels.DatagramChannel UDP client socket} + * {@linkplain java.net.StandardSocketOptions#SO_RCVBUF receive buffer size} in bytes. + * Values less than or equal to zero are ignored. + *

    • + *
    • {@systemProperty jdk.httpclient.quic.sendBufferSize} (default: operating system + * default)
      The QUIC {@linkplain java.nio.channels.DatagramChannel UDP client socket} + * {@linkplain java.net.StandardSocketOptions#SO_SNDBUF send buffer size} in bytes. + * Values less than or equal to zero are ignored. + *

    • + *
    • {@systemProperty jdk.httpclient.quic.defaultMTU} (default: 1200 bytes)
      + * The default Maximum Transmission Unit (MTU) size that will be used on quic connections. + * The default implementation of the HTTP/3 client does not implement Path MTU Detection, + * but will attempt to send 1-RTT packets up to the size defined by this property. + * Specifying a higher value may give better upload performance when the client and + * servers are located on the same machine, but is likely to result in irrecoverable + * packet loss if used over the network. Allowed values are in the range [1200, 65527]. + * If an out-of-range value is specified, the minimum default value will be used. + *

    • + *
    • {@systemProperty jdk.httpclient.quic.maxBytesInFlight} (default: + * 16777216 bytes or 16MB)
      + * This is the maximum number of unacknowledged bytes that the quic congestion + * controller allows to be in flight. When this amount is reached, no new + * data is sent until some of the packets in flight are acknowledged. + *
      + * Allowed values are in the range [2^14, 2^24] (or [16kB, 16MB]). + * If an out-of-range value is specified, it will be clamped to the closest + * value in range. + *

    • + *
    • {@systemProperty jdk.httpclient.quic.maxInitialData} (default: 15728640 + * bytes, or 15MB)
      + * The initial flow control limit for quic connections in bytes. Valid values are in + * the range [0, 2^60]. The initial limit is also used to initialize the receive window + * size. If less than 16kB, the window size will be set to 16kB. + *

    • + *
    • {@systemProperty jdk.httpclient.quic.maxStreamInitialData} (default: 6291456 + * bytes, or 6MB)
      + * The initial flow control limit for quic streams in bytes. Valid values are in + * the range [0, 2^60]. The initial limit is also used to initialize the receive window + * size. If less than 16kB, the window size will be set to 16kB. + *

    • + *
    • {@systemProperty jdk.httpclient.quic.maxInitialTimeout} (default: 30 + * seconds)
      + * This is the maximum time, in seconds, during which the client will wait for a + * response from the server, and continue retransmitting the first Quic INITIAL packet, + * before raising a {@link java.net.ConnectException}. The first INITIAL packet received + * from the target server will disarm this timeout. + *

    • + * + *
    * @moduleGraph * @since 11 */ diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11SecretKeyFactory.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11SecretKeyFactory.java index a2c7a072769..9f42d610c4d 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11SecretKeyFactory.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11SecretKeyFactory.java @@ -285,6 +285,10 @@ final class P11SecretKeyFactory extends SecretKeyFactorySpi { putKeyInfo(new TLSKeyInfo("TlsServerAppTrafficSecret")); putKeyInfo(new TLSKeyInfo("TlsServerHandshakeTrafficSecret")); putKeyInfo(new TLSKeyInfo("TlsUpdateNplus1")); + // QUIC-specific + putKeyInfo(new TLSKeyInfo("TlsInitialSecret")); + putKeyInfo(new TLSKeyInfo("TlsClientInitialTrafficSecret")); + putKeyInfo(new TLSKeyInfo("TlsServerInitialTrafficSecret")); putKeyInfo(new KeyInfo("Generic", CKK_GENERIC_SECRET, CKM_GENERIC_SECRET_KEY_GEN)); diff --git a/test/jdk/com/sun/net/httpserver/SANTest.java b/test/jdk/com/sun/net/httpserver/SANTest.java index dc7f5e7bdd5..d35a590931d 100644 --- a/test/jdk/com/sun/net/httpserver/SANTest.java +++ b/test/jdk/com/sun/net/httpserver/SANTest.java @@ -36,6 +36,18 @@ * java.base/sun.net.www.http * java.base/sun.net.www * java.base/sun.net + * java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.packets + * java.net.http/jdk.internal.net.http.quic.frames + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.base/jdk.internal.util * * @run main/othervm SANTest * @summary Update SimpleSSLContext keystore to use SANs for localhost IP addresses diff --git a/test/jdk/java/net/httpclient/AbstractNoBody.java b/test/jdk/java/net/httpclient/AbstractNoBody.java index 603fe16d216..9cc26704bb5 100644 --- a/test/jdk/java/net/httpclient/AbstractNoBody.java +++ b/test/jdk/java/net/httpclient/AbstractNoBody.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -29,9 +29,6 @@ import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.net.http.HttpClient; @@ -39,6 +36,7 @@ import java.util.concurrent.atomic.AtomicLong; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -49,6 +47,8 @@ import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static org.testng.Assert.assertEquals; public abstract class AbstractNoBody implements HttpServerAdapters { @@ -58,6 +58,7 @@ public abstract class AbstractNoBody implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -66,12 +67,16 @@ public abstract class AbstractNoBody implements HttpServerAdapters { String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; + String http3URI_head; static final String SIMPLE_STRING = "Hello world. Goodbye world"; static final int ITERATION_COUNT = 3; // a shared executor helps reduce the amount of threads created by the test static final ExecutorService executor = Executors.newFixedThreadPool(ITERATION_COUNT * 2); static final ExecutorService serverExecutor = Executors.newFixedThreadPool(ITERATION_COUNT * 4); + static final AtomicLong serverCount = new AtomicLong(); static final AtomicLong clientCount = new AtomicLong(); static final long start = System.nanoTime(); public static String now() { @@ -85,6 +90,11 @@ public abstract class AbstractNoBody implements HttpServerAdapters { @DataProvider(name = "variants") public Object[][] variants() { return new Object[][]{ + { http3URI_fixed, false,}, + { http3URI_chunk, false }, + { http3URI_fixed, true,}, + { http3URI_chunk, true }, + { httpURI_fixed, false }, { httpURI_chunk, false }, { httpsURI_fixed, false }, @@ -112,17 +122,39 @@ public abstract class AbstractNoBody implements HttpServerAdapters { return HTTP_1_1; if (uri.contains("/http2/") || uri.contains("/https2/")) return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; return null; } HttpRequest.Builder newRequestBuilder(String uri) { var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } return builder; } + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + out.println("\n" + now() + "--- Sending HEAD request ----\n"); + err.println("\n" + now() + "--- Sending HEAD request ----\n"); + + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + out.println("\n" + now() + "--- HEAD request succeeded ----\n"); + err.println("\n" + now() + "--- HEAD request succeeded ----\n"); + return response; + } + private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return HttpClient.newBuilder() + return newClientBuilderForH3() .executor(executor) .proxy(NO_PROXY) .sslContext(sslContext) @@ -190,10 +222,22 @@ public abstract class AbstractNoBody implements HttpServerAdapters { https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/noBodyFixed"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/noBodyChunk"; + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new FixedLengthNoBodyHandler(); + HttpTestHandler h3_chunkedHandler = new ChunkedNoBodyHandler(); + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/noBodyFixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/noBodyChunk"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/noBodyHead"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/noBodyFixed"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/noBodyChunk"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/noBodyHead"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); var shared = newHttpClient(true); @@ -201,9 +245,13 @@ public abstract class AbstractNoBody implements HttpServerAdapters { out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); - + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); out.println("Shared client is: " + shared); + headRequest(shared); + printStamp(END,"setup"); } @@ -215,6 +263,7 @@ public abstract class AbstractNoBody implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); executor.close(); serverExecutor.close(); printStamp(END, "teardown"); @@ -270,26 +319,4 @@ public abstract class AbstractNoBody implements HttpServerAdapters { } } } - - /* - * Converts a ByteBuffer containing bytes encoded using - * the given charset into a string. - * This method does not throw but will replace - * unrecognized sequences with the replacement character. - */ - public static String asString(ByteBuffer buffer, Charset charset) { - var decoded = charset.decode(buffer); - char[] chars = new char[decoded.length()]; - decoded.get(chars); - return new String(chars); - } - - /* - * Converts a ByteBuffer containing UTF-8 bytes into a - * string. This method does not throw but will replace - * unrecognized sequences with the replacement character. - */ - public static String asString(ByteBuffer buffer) { - return asString(buffer, StandardCharsets.UTF_8); - } } diff --git a/test/jdk/java/net/httpclient/AbstractThrowingPublishers.java b/test/jdk/java/net/httpclient/AbstractThrowingPublishers.java index 859169dcaae..9a9c2b44cd7 100644 --- a/test/jdk/java/net/httpclient/AbstractThrowingPublishers.java +++ b/test/jdk/java/net/httpclient/AbstractThrowingPublishers.java @@ -21,6 +21,7 @@ * questions. */ +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.ITestContext; import org.testng.ITestResult; @@ -30,7 +31,6 @@ import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeMethod; import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; import javax.net.ssl.SSLContext; import java.io.IOException; @@ -39,6 +39,7 @@ import java.io.OutputStream; import java.io.UncheckedIOException; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublisher; import java.net.http.HttpRequest.BodyPublishers; @@ -70,9 +71,12 @@ import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; import static java.lang.String.format; +import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -84,6 +88,7 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -92,6 +97,9 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; + String http3URI_head; static final int ITERATION_COUNT = 1; // a shared executor helps reduce the amount of threads created by the test @@ -151,6 +159,41 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { + (params == null ? "()" : Arrays.toString(result.getParameters())); } + static Version version(String uri) { + if (uri.contains("/http1/") || uri.contains("/https1/")) + return HTTP_1_1; + if (uri.contains("/http2/") || uri.contains("/https2/")) + return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; + return null; + } + + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + System.out.println("\n" + now() + "--- Sending HEAD request ----\n"); + System.err.println("\n" + now() + "--- Sending HEAD request ----\n"); + + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + System.out.println("\n" + now() + "--- HEAD request succeeded ----\n"); + System.err.println("\n" + now() + "--- HEAD request succeeded ----\n"); + return response; + } + @BeforeMethod void beforeMethod(ITestContext context) { if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { @@ -189,6 +232,8 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { private String[] uris() { return new String[] { + http3URI_fixed, + http3URI_chunk, httpURI_fixed, httpURI_chunk, httpsURI_fixed, @@ -326,7 +371,7 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return TRACKER.track(HttpClient.newBuilder() + return TRACKER.track(newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -354,8 +399,12 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { HttpClient client = null; out.printf("%n%s testSanity(%s, %b)%n", now(), uri, sameClient); for (int i=0; i< ITERATION_COUNT; i++) { - if (!sameClient || client == null) + if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } + } SubmissionPublisher publisher = new SubmissionPublisher<>(executor,10); @@ -374,7 +423,7 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { }, executor); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(uri) .POST(bodyPublisher) .build(); BodyHandler handler = BodyHandlers.ofString(); @@ -445,18 +494,21 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { for (Where where : whereValues) { //if (where == Where.ON_SUBSCRIBE) continue; //if (where == Where.ON_ERROR) continue; - if (!sameClient || client == null) + if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } + } ThrowingBodyPublisher bodyPublisher = new ThrowingBodyPublisher(where.select(thrower), publishers.get()); - HttpRequest req = HttpRequest. - newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(uri) .header("X-expect-exception", "true") .POST(bodyPublisher) .build(); BodyHandler handler = BodyHandlers.ofString(); - System.out.println("try throwing in " + where); + System.out.println(now() + " try throwing in " + where); HttpResponse response = null; if (async) { try { @@ -564,8 +616,12 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { static final class UncheckedCustomExceptionThrower implements Thrower { @Override public void accept(Where where) { - out.println(now() + "Throwing in " + where); - throw new UncheckedCustomException(where.name()); + var thread = Thread.currentThread().getName(); + var thrown = new UncheckedCustomException("[" + thread + "] " + where.name()); + out.println(now() + "Throwing in " + where + ": " + thrown); + err.println(now() + "Throwing in " + where + ": " + thrown); + thrown.printStackTrace(); + throw thrown; } @Override @@ -591,8 +647,13 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { static final class UncheckedIOExceptionThrower implements Thrower { @Override public void accept(Where where) { - out.println(now() + "Throwing in " + where); - throw new UncheckedIOException(new CustomIOException(where.name())); + var thread = Thread.currentThread().getName(); + var cause = new CustomIOException("[" + thread + "] " + where.name()); + var thrown = new UncheckedIOException(cause); + out.println(now() + "Throwing in " + where + ": " + thrown); + err.println(now() + "Throwing in " + where + ": " + thrown); + cause.printStackTrace(); + throw thrown; } @Override @@ -719,6 +780,9 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { @BeforeTest public void setup() throws Exception { + System.out.println(now() + "setup"); + System.err.println(now() + "setup"); + sslContext = new SimpleSSLContext().get(); if (sslContext == null) throw new AssertionError("Unexpected null sslContext"); @@ -732,12 +796,16 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { httpURI_fixed = "http://" + httpTestServer.serverAuthority() + "/http1/fixed/x"; httpURI_chunk = "http://" + httpTestServer.serverAuthority() + "/http1/chunk/x"; + System.out.println(now() + "HTTP/1.1 server created"); + httpsTestServer = HttpTestServer.create(HTTP_1_1, sslContext); httpsTestServer.addHandler(h1_fixedLengthHandler, "/https1/fixed"); httpsTestServer.addHandler(h1_chunkHandler, "/https1/chunk"); httpsURI_fixed = "https://" + httpsTestServer.serverAuthority() + "/https1/fixed/x"; httpsURI_chunk = "https://" + httpsTestServer.serverAuthority() + "/https1/chunk/x"; + System.out.println(now() + "TLS HTTP/1.1 server created"); + // HTTP/2 HttpTestHandler h2_fixedLengthHandler = new HTTP_FixedLengthHandler(); HttpTestHandler h2_chunkedHandler = new HTTP_ChunkedHandler(); @@ -748,31 +816,67 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { http2URI_fixed = "http://" + http2TestServer.serverAuthority() + "/http2/fixed/x"; http2URI_chunk = "http://" + http2TestServer.serverAuthority() + "/http2/chunk/x"; + System.out.println(now() + "HTTP/2 server created"); + https2TestServer = HttpTestServer.create(HTTP_2, sslContext); https2TestServer.addHandler(h2_fixedLengthHandler, "/https2/fixed"); https2TestServer.addHandler(h2_chunkedHandler, "/https2/chunk"); https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed/x"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk/x"; - serverCount.addAndGet(4); + System.out.println(now() + "TLS HTTP/2 server created"); + + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_ChunkedHandler(); + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed/x"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk/x"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/head/x"; + + System.out.println(now() + "HTTP/3 server created"); + System.err.println(now() + "Starting servers"); + + serverCount.addAndGet(5); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); + + out.println("HTTP/1.1 server (http) listening at: " + httpTestServer.serverAuthority()); + out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); + out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); + out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); + + headRequest(newHttpClient(true)); + + System.out.println(now() + "setup done"); + System.err.println(now() + "setup done"); } @AfterTest public void teardown() throws Exception { + System.out.println(now() + "teardown"); + System.err.println(now() + "teardown"); + String sharedClientName = sharedClient == null ? null : sharedClient.toString(); sharedClient = null; Thread.sleep(100); - AssertionError fail = TRACKER.check(500); + AssertionError fail = TRACKER.check(10000); try { httpTestServer.stop(); httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { @@ -781,6 +885,8 @@ public abstract class AbstractThrowingPublishers implements HttpServerAdapters { throw fail; } } + System.out.println(now() + "teardown done"); + System.err.println(now() + "teardown done"); } static class HTTP_FixedLengthHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/AbstractThrowingPushPromises.java b/test/jdk/java/net/httpclient/AbstractThrowingPushPromises.java index b5230d48d10..a7aea89c379 100644 --- a/test/jdk/java/net/httpclient/AbstractThrowingPushPromises.java +++ b/test/jdk/java/net/httpclient/AbstractThrowingPushPromises.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -50,6 +50,7 @@ import org.testng.annotations.DataProvider; import javax.net.ssl.SSLContext; import java.io.BufferedReader; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -58,8 +59,10 @@ import java.io.UncheckedIOException; import java.net.URI; import java.net.URISyntaxException; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpHeaders; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; @@ -88,12 +91,14 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.lang.System.err; import static java.lang.String.format; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -103,10 +108,13 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters SSLContext sslContext; HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String http2URI_fixed; String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; static final int ITERATION_COUNT = 1; // a shared executor helps reduce the amount of threads created by the test @@ -205,6 +213,8 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters private String[] uris() { return new String[] { + http3URI_fixed, + http3URI_chunk, http2URI_fixed, http2URI_chunk, https2URI_fixed, @@ -282,7 +292,8 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return TRACKER.track(HttpClient.newBuilder() + return TRACKER.track(newClientBuilderForH3() + .version(HTTP_3) .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -302,6 +313,22 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters } } + Http3DiscoveryMode config(String uri) { + return uri.contains("/http3/") ? HTTP_3_URI_ONLY : null; + } + + Version version(String uri) { + return uri.contains("/http3/") ? HTTP_3 : HTTP_2; + } + + HttpRequest request(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)) + .version(version(uri)); + var config = config(uri); + if (config != null) builder.setOption(H3_DISCOVERY, config); + return builder.build(); + } + // @Test(dataProvider = "sanity") protected void testSanityImpl(String uri, boolean sameClient) throws Exception { @@ -311,8 +338,8 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters if (!sameClient || client == null) client = newHttpClient(sameClient); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) - .build(); + HttpRequest req = request(uri); + BodyHandler> handler = new ThrowingBodyHandler((w) -> {}, BodyHandlers.ofLines()); @@ -339,7 +366,7 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters .collect(Collectors.joining("|")); assertEquals(promisedBody, promised.uri().toASCIIString()); } - assertEquals(3, pushPromises.size()); + assertEquals(pushPromises.size(), 3); if (!sameClient) { // Wait for the client to be garbage collected. // we use the ReferenceTracker API rather than HttpClient::close here, @@ -423,9 +450,8 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters if (!sameClient || client == null) client = newHttpClient(sameClient); - HttpRequest req = HttpRequest. - newBuilder(URI.create(uri)) - .build(); + HttpRequest req = request(uri); + ConcurrentMap>> promiseMap = new ConcurrentHashMap<>(); Supplier> throwing = () -> @@ -739,24 +765,31 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters throw new AssertionError("Unexpected null sslContext"); // HTTP/2 - HttpTestHandler h2_fixedLengthHandler = new HTTP_FixedLengthHandler(); - HttpTestHandler h2_chunkedHandler = new HTTP_ChunkedHandler(); + HttpTestHandler fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler chunkedHandler = new HTTP_ChunkedHandler(); http2TestServer = HttpTestServer.create(HTTP_2); - http2TestServer.addHandler(h2_fixedLengthHandler, "/http2/fixed"); - http2TestServer.addHandler(h2_chunkedHandler, "/http2/chunk"); + http2TestServer.addHandler(fixedLengthHandler, "/http2/fixed"); + http2TestServer.addHandler(chunkedHandler, "/http2/chunk"); http2URI_fixed = "http://" + http2TestServer.serverAuthority() + "/http2/fixed/x"; http2URI_chunk = "http://" + http2TestServer.serverAuthority() + "/http2/chunk/x"; https2TestServer = HttpTestServer.create(HTTP_2, sslContext); - https2TestServer.addHandler(h2_fixedLengthHandler, "/https2/fixed"); - https2TestServer.addHandler(h2_chunkedHandler, "/https2/chunk"); + https2TestServer.addHandler(fixedLengthHandler, "/https2/fixed"); + https2TestServer.addHandler(chunkedHandler, "/https2/chunk"); https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed/x"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk/x"; - serverCount.addAndGet(2); + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(chunkedHandler, "/http3/chunk"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed/x"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk/x"; + + serverCount.addAndGet(3); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -769,6 +802,7 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters try { http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { @@ -794,15 +828,69 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); out.printf("TestServer: %s Pushing promise: %s%n", now(), promise); err.printf("TestServer: %s Pushing promise: %s%n", now(), promise); - HttpHeaders headers; + HttpHeaders reqHaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + HttpHeaders rspHeaders; if (fixed) { String length = String.valueOf(promiseBytes.length); - headers = HttpHeaders.of(Map.of("Content-Length", List.of(length)), + rspHeaders = HttpHeaders.of(Map.of("Content-Length", List.of(length)), ACCEPT_ALL); } else { - headers = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + rspHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty } - t.serverPush(promise, headers, promiseBytes); + t.serverPush(promise, reqHaders, rspHeaders, promiseBytes); + } catch (URISyntaxException x) { + throw new IOException(x.getMessage(), x); + } + } + + private static long sendHttp3PushPromiseFrame(HttpTestExchange t, + URI requestURI, + String pushPath, + boolean fixed) + throws IOException + { + try { + URI promise = new URI(requestURI.getScheme(), + requestURI.getAuthority(), + pushPath, null, null); + byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); + out.printf("TestServer: %s sending PushPromiseFrame: %s%n", now(), promise); + err.printf("TestServer: %s Pushing PushPromiseFrame: %s%n", now(), promise); + // headers are added to the request headers sent in the push promise + HttpHeaders headers = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + long pushId = t.sendHttp3PushPromiseFrame(-1, promise, headers); + out.printf("TestServer: %s PushPromiseFrame pushId=%s sent%n", now(), pushId); + err.printf("TestServer: %s PushPromiseFrame pushId=%s sent%n", now(), pushId); + return pushId; + } catch (URISyntaxException x) { + throw new IOException(x.getMessage(), x); + } + } + + private static void sendHttp3PushResponse(HttpTestExchange t, + long pushId, + URI requestURI, + String pushPath, + boolean fixed) + throws IOException + { + try { + URI promise = new URI(requestURI.getScheme(), + requestURI.getAuthority(), + pushPath, null, null); + byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); + out.printf("TestServer: %s sending push response pushId=%s: %s%n", now(), pushId, promise); + err.printf("TestServer: %s Pushing push response pushId=%s: %s%n", now(), pushId, promise); + HttpHeaders reqHaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + HttpHeaders rspHeaders; + if (fixed) { + String length = String.valueOf(promiseBytes.length); + rspHeaders = HttpHeaders.of(Map.of("Content-Length", List.of(length)), + ACCEPT_ALL); + } else { + rspHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + } + t.sendHttp3PushResponse(pushId, promise, reqHaders, rspHeaders, new ByteArrayInputStream(promiseBytes)); } catch (URISyntaxException x) { throw new IOException(x.getMessage(), x); } @@ -822,13 +910,31 @@ public abstract class AbstractThrowingPushPromises implements HttpServerAdapters } byte[] resp = t.getRequestURI().toString().getBytes(StandardCharsets.UTF_8); t.sendResponseHeaders(200, resp.length); //fixed content length + + // With HTTP/3 fixed length we send a single DataFrame, + // therefore we can't interleave a PushPromiseFrame in + // the middle of the DataFrame, so we're going to send + // the PushPromiseFrame before the DataFrame, and then + // fulfill the promise later while sending the response + // body. + long[] pushIds = new long[2]; + if (t.getExchangeVersion() == HTTP_3) { + for (int i = 0; i < 2; i++) { + String path = requestURI.getPath() + "/after/promise-" + (i + 2); + pushIds[i] = sendHttp3PushPromiseFrame(t, requestURI, path, true); + } + } try (OutputStream os = t.getResponseBody()) { int bytes = resp.length/3; for (int i = 0; i<2; i++) { - String path = requestURI.getPath() + "/after/promise-" + (i + 2); os.write(resp, i * bytes, bytes); os.flush(); - pushPromiseFor(t, requestURI, path, true); + String path = requestURI.getPath() + "/after/promise-" + (i + 2); + if (t.getExchangeVersion() == HTTP_2) { + pushPromiseFor(t, requestURI, path, true); + } else if (t.getExchangeVersion() == HTTP_3) { + sendHttp3PushResponse(t, pushIds[i], requestURI, path, true); + } } os.write(resp, 2*bytes, resp.length - 2*bytes); } diff --git a/test/jdk/java/net/httpclient/AbstractThrowingSubscribers.java b/test/jdk/java/net/httpclient/AbstractThrowingSubscribers.java index 7362ada9772..0dc808b8bb2 100644 --- a/test/jdk/java/net/httpclient/AbstractThrowingSubscribers.java +++ b/test/jdk/java/net/httpclient/AbstractThrowingSubscribers.java @@ -21,6 +21,7 @@ * questions. */ +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.ITestContext; import org.testng.ITestResult; @@ -40,6 +41,7 @@ import java.io.OutputStream; import java.io.UncheckedIOException; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; @@ -66,10 +68,13 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; +import static java.lang.System.err; import static java.lang.System.out; import static java.lang.String.format; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -81,6 +86,7 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -89,6 +95,9 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; + String http3URI_head; static final int ITERATION_COUNT = 1; static final int REPEAT_RESPONSE = 3; @@ -149,6 +158,41 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters + (params == null ? "()" : Arrays.toString(result.getParameters())); } + static Version version(String uri) { + if (uri.contains("/http1/") || uri.contains("/https1/")) + return HTTP_1_1; + if (uri.contains("/http2/") || uri.contains("/https2/")) + return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; + return null; + } + + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + System.out.println("\n" + now() + "--- Sending HEAD request ----\n"); + System.err.println("\n" + now() + "--- Sending HEAD request ----\n"); + + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + System.out.println("\n" + now() + "--- HEAD request succeeded ----\n"); + System.err.println("\n" + now() + "--- HEAD request succeeded ----\n"); + return response; + } + @BeforeMethod void beforeMethod(ITestContext context) { if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { @@ -188,6 +232,8 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters private String[] uris() { return new String[] { + http3URI_fixed, + http3URI_chunk, httpURI_fixed, httpURI_chunk, httpsURI_fixed, @@ -238,7 +284,7 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters private HttpClient makeNewClient() { clientCount.incrementAndGet(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -296,10 +342,14 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters String uri2 = uri + "-" + URICOUNT.incrementAndGet() + "/sanity"; out.printf("%ntestSanity(%s, %b)%n", uri2, sameClient); for (int i=0; i< ITERATION_COUNT; i++) { - if (!sameClient || client == null) + if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } + } - HttpRequest req = HttpRequest.newBuilder(URI.create(uri2)) + HttpRequest req = newRequestBuilder(uri2) .build(); BodyHandler handler = new ThrowingBodyHandler((w) -> {}, @@ -439,12 +489,15 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters HttpClient client = null; var throwing = thrower; for (Where where : EnumSet.complementOf(excludes)) { - - if (!sameClient || client == null) + if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } + } + String uri2 = uri + "-" + where; - HttpRequest req = HttpRequest. - newBuilder(URI.create(uri2)) + HttpRequest req = newRequestBuilder(uri2) .build(); thrower = thrower(where, throwing); @@ -593,8 +646,12 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters static final class UncheckedCustomExceptionThrower implements Thrower { @Override public void accept(Where where) { - out.println(now() + "Throwing in " + where); - throw new UncheckedCustomException(where.name()); + var thread = Thread.currentThread().getName(); + var thrown = new UncheckedCustomException("[" + thread + "] " + where.name()); + out.println(now() + "Throwing in " + where + ": " + thrown); + err.println(now() + "Throwing in " + where + ": " + thrown); + thrown.printStackTrace(); + throw thrown; } @Override @@ -611,8 +668,13 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters static final class UncheckedIOExceptionThrower implements Thrower { @Override public void accept(Where where) { - out.println(now() + "Throwing in " + where); - throw new UncheckedIOException(new CustomIOException(where.name())); + var thread = Thread.currentThread().getName(); + var cause = new CustomIOException("[" + thread + "] " + where.name()); + var thrown = new UncheckedIOException(cause); + out.println(now() + "Throwing in " + where + ": " + thrown); + err.println(now() + "Throwing in " + where + ": " + thrown); + cause.printStackTrace(); + throw thrown; } @Override @@ -759,10 +821,15 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters @BeforeTest public void setup() throws Exception { + System.out.println(now() + "setup"); + System.err.println(now() + "setup"); + sslContext = new SimpleSSLContext().get(); if (sslContext == null) throw new AssertionError("Unexpected null sslContext"); + System.out.println(now() + "HTTP/1.1 server created"); + // HTTP/1.1 HttpTestHandler h1_fixedLengthHandler = new HTTP_FixedLengthHandler(); HttpTestHandler h1_chunkHandler = new HTTP_ChunkedHandler(); @@ -772,6 +839,8 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters httpURI_fixed = "http://" + httpTestServer.serverAuthority() + "/http1/fixed/x"; httpURI_chunk = "http://" + httpTestServer.serverAuthority() + "/http1/chunk/x"; + System.out.println(now() + "TLS HTTP/1.1 server created"); + httpsTestServer = HttpTestServer.create(HTTP_1_1, sslContext); httpsTestServer.addHandler(h1_fixedLengthHandler, "/https1/fixed"); httpsTestServer.addHandler(h1_chunkHandler, "/https1/chunk"); @@ -788,21 +857,56 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters http2URI_fixed = "http://" + http2TestServer.serverAuthority() + "/http2/fixed/x"; http2URI_chunk = "http://" + http2TestServer.serverAuthority() + "/http2/chunk/x"; + System.out.println(now() + "HTTP/2 server created"); + https2TestServer = HttpTestServer.create(HTTP_2, sslContext); https2TestServer.addHandler(h2_fixedLengthHandler, "/https2/fixed"); https2TestServer.addHandler(h2_chunkedHandler, "/https2/chunk"); https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed/x"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk/x"; - serverCount.addAndGet(4); + System.out.println(now() + "TLS HTTP/2 server created"); + + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_ChunkedHandler(); + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed/x"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk/x"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/head/x"; + + System.out.println(now() + "HTTP/3 server created"); + System.err.println(now() + "Starting servers"); + + serverCount.addAndGet(5); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); + + out.println("HTTP/1.1 server (http) listening at: " + httpTestServer.serverAuthority()); + out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); + out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); + out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); + + headRequest(newHttpClient(true)); + + System.out.println(now() + "setup done"); + System.err.println(now() + "setup done"); } @AfterTest public void teardown() throws Exception { + System.out.println(now() + "teardown"); + System.err.println(now() + "teardown"); + String sharedClientName = sharedClient == null ? null : sharedClient.toString(); sharedClient = null; @@ -813,6 +917,7 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { @@ -821,6 +926,8 @@ public abstract class AbstractThrowingSubscribers implements HttpServerAdapters throw fail; } } + System.out.println(now() + "teardown done"); + System.err.println(now() + "teardown done"); } static class HTTP_FixedLengthHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/AggregateRequestBodyTest.java b/test/jdk/java/net/httpclient/AggregateRequestBodyTest.java index a879525b4b4..8ec3b256e62 100644 --- a/test/jdk/java/net/httpclient/AggregateRequestBodyTest.java +++ b/test/jdk/java/net/httpclient/AggregateRequestBodyTest.java @@ -28,7 +28,7 @@ * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.HttpServerAdapters * ReferenceTracker AggregateRequestBodyTest * @run testng/othervm -Djdk.internal.httpclient.debug=true - * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * -Djdk.httpclient.HttpClient.log=requests,responses,errors,headers,frames * AggregateRequestBodyTest * @summary Tests HttpRequest.BodyPublishers::concat */ @@ -69,6 +69,7 @@ import jdk.httpclient.test.lib.common.HttpServerAdapters; import javax.net.ssl.SSLContext; import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; import org.testng.Assert; import org.testng.ITestContext; import org.testng.ITestResult; @@ -83,6 +84,8 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -95,14 +98,15 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { HttpTestServer https1TestServer; // HTTPS/1.1 ( https ) HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) - String http1URI; - String https1URI; - String http2URI; - String https2URI; + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) + URI http1URI; + URI https1URI; + URI http2URI; + URI https2URI; + URI http3URI; static final int RESPONSE_CODE = 200; static final int ITERATION_COUNT = 4; - static final Class IAE = IllegalArgumentException.class; static final Class CE = CompletionException.class; // a shared executor helps reduce the amount of threads created by the test static final Executor executor = new TestExecutor(Executors.newCachedThreadPool()); @@ -197,12 +201,13 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { } } - private String[] uris() { - return new String[] { + private URI[] uris() { + return new URI[] { http1URI, https1URI, http2URI, https2URI, + http3URI, }; } @@ -213,12 +218,18 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { return new Object[0][]; } - String[] uris = uris(); + URI[] uris = uris(); Object[][] result = new Object[uris.length * 2][]; int i = 0; for (boolean sameClient : List.of(false, true)) { - for (String uri : uris()) { - result[i++] = new Object[]{uri, sameClient}; + for (URI uri : uris()) { + HttpClient.Version version = null; + if (uri.equals(http1URI) || uri.equals(https1URI)) version = HttpClient.Version.HTTP_1_1; + else if (uri.equals(http2URI) || uri.equals(https2URI)) version = HttpClient.Version.HTTP_2; + else if (uri.equals(http3URI)) version = HTTP_3; + else throw new AssertionError("Unexpected URI: " + uri); + + result[i++] = new Object[]{uri, version, sameClient}; } } assert i == uris.length * 2; @@ -227,7 +238,7 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { private HttpClient makeNewClient() { clientCount.incrementAndGet(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -802,7 +813,7 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { } @Test(dataProvider = "variants") - public void test(String uri, boolean sameClient) throws Exception { + public void test(URI uri, HttpClient.Version version, boolean sameClient) throws Exception { checkSkip(); System.out.printf("Request to %s (sameClient: %s)%n", uri, sameClient); System.err.printf("Request to %s (sameClient: %s)%n", uri, sameClient); @@ -814,9 +825,12 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { .map(BodyPublishers::ofString) .toArray(HttpRequest.BodyPublisher[]::new) ); - HttpRequest request = HttpRequest.newBuilder(URI.create(uri)) + + HttpRequest request = HttpRequest.newBuilder(uri) + .version(version) .POST(publisher) .build(); + for (int i = 0; i < ITERATION_COUNT; i++) { System.out.println(uri + ": Iteration: " + i); System.err.println(uri + ": Iteration: " + i); @@ -826,9 +840,19 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { throw new RuntimeException("wrong response code " + Integer.toString(response.statusCode())); assertEquals(response.body(), BODIES.stream().collect(Collectors.joining())); } + if (!sameClient) client.close(); System.out.println("test: DONE"); } + private URI buildURI(String scheme, String path, int port) { + return URIBuilder.newBuilder() + .scheme(scheme) + .loopback() + .port(port) + .path(path) + .buildUnchecked(); + } + @BeforeTest public void setup() throws Exception { sslContext = new SimpleSSLContext().get(); @@ -838,32 +862,37 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { HttpTestHandler handler = new HttpTestEchoHandler(); http1TestServer = HttpTestServer.create(HTTP_1_1); http1TestServer.addHandler(handler, "/http1/echo/"); - http1URI = "http://" + http1TestServer.serverAuthority() + "/http1/echo/x"; + http1URI = buildURI("http", "/http1/echo/x", http1TestServer.getAddress().getPort()); https1TestServer = HttpTestServer.create(HTTP_1_1, sslContext); https1TestServer.addHandler(handler, "/https1/echo/"); - https1URI = "https://" + https1TestServer.serverAuthority() + "/https1/echo/x"; + https1URI = buildURI("https", "/https1/echo/x", https1TestServer.getAddress().getPort()); - // HTTP/2 http2TestServer = HttpTestServer.create(HTTP_2); http2TestServer.addHandler(handler, "/http2/echo/"); - http2URI = "http://" + http2TestServer.serverAuthority() + "/http2/echo/x"; + http2URI = buildURI("http", "/http2/echo/x", http2TestServer.getAddress().getPort()); https2TestServer = HttpTestServer.create(HTTP_2, sslContext); https2TestServer.addHandler(handler, "/https2/echo/"); - https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo/x"; + https2URI = buildURI("https", "/https2/echo/x", https2TestServer.getAddress().getPort()); - serverCount.addAndGet(4); + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(handler, "/http3/echo/"); + http3URI = buildURI("https", "/http3/echo/x", http3TestServer.getAddress().getPort()); + + serverCount.addAndGet(5); http1TestServer.start(); https1TestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest public void teardown() throws Exception { String sharedClientName = sharedClient == null ? null : sharedClient.toString(); + sharedClient.close(); sharedClient = null; Thread.sleep(100); AssertionError fail = TRACKER.check(500); @@ -872,6 +901,7 @@ public class AggregateRequestBodyTest implements HttpServerAdapters { https1TestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { diff --git a/test/jdk/java/net/httpclient/AltServiceUsageTest.java b/test/jdk/java/net/httpclient/AltServiceUsageTest.java new file mode 100644 index 00000000000..2322b36918b --- /dev/null +++ b/test/jdk/java/net/httpclient/AltServiceUsageTest.java @@ -0,0 +1,454 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.OutputStream; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.channels.DatagramChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.Optional; +import java.util.OptionalLong; + +import javax.net.ssl.SSLContext; + +import jdk.test.lib.net.SimpleSSLContext; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; + +/* + * @test + * @summary Verifies that the HTTP client correctly handles various alt-svc usages + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.HttpServerAdapters + * + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * AltServiceUsageTest + */ +public class AltServiceUsageTest implements HttpServerAdapters { + + private SSLContext sslContext; + private HttpTestServer originServer; + private HttpTestServer altServer; + + private DatagramChannel udpNotResponding; + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + + // attempt to create an HTTP/3 server, an HTTP/2 server, and a + // DatagramChannel bound to the same port as the HTTP/2 server + int count = 0; + InetSocketAddress altServerAddress = null, originServerAddress = null; + while (count++ < 10) { + + createServers(); + altServerAddress = altServer.getAddress(); + originServerAddress = originServer.getAddress(); + + if (originServerAddress.equals(altServerAddress)) break; + udpNotResponding = DatagramChannel.open(); + try { + udpNotResponding.bind(originServerAddress); + break; + } catch (IOException x) { + System.out.printf("Failed to bind udpNotResponding to %s: %s%n", + originServerAddress, x); + safeStop(altServer); + safeStop(originServer); + udpNotResponding.close(); + } + } + if (count > 10) { + throw new AssertionError("Couldn't reserve UDP port at " + originServerAddress); + } + + System.out.println("HTTP/3 service started at " + altServerAddress); + System.out.println("HTTP/2 service started at " + originServerAddress); + System.err.println("**** All servers started. Test will start shortly ****"); + } + + private void createServers() throws IOException { + altServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + altServer.addHandler(new All200OKHandler(), "/foo/"); + altServer.addHandler(new RequireAltUsedHeader(), "/bar/"); + altServer.addHandler(new All200OKHandler(), "/maxAgeAltSvc/"); + altServer.start(); + + originServer = HttpTestServer.create(HTTP_2, sslContext); + originServer.addHandler(new H3AltServicePublisher(altServer.getAddress()), "/foo/"); + originServer.addHandler(new H3AltSvcPublisherWith421(altServer.getAddress()), "/foo421/"); + originServer.addHandler(new H3AltServicePublisher(altServer.getAddress()), "/bar/"); + originServer.addHandler(new H3AltSvcPublisherWithMaxAge(altServer.getAddress()), "/maxAgeAltSvc/"); + originServer.start(); + } + + @AfterClass + public void afterClass() throws Exception { + safeStop(originServer); + safeStop(altServer); + udpNotResponding.close(); + } + + private static void safeStop(final HttpTestServer server) { + if (server == null) { + return; + } + final InetSocketAddress serverAddr = server.getAddress(); + try { + System.out.println("Stopping server " + serverAddr); + server.stop(); + } catch (Exception e) { + System.err.println("Ignoring exception: " + e.getMessage() + " that occurred " + + "during stop of server: " + serverAddr); + } + } + + private static class H3AltServicePublisher implements HttpTestHandler { + private static final String RESPONSE_CONTENT = "apple"; + + private final String altSvcHeader; + + /** + * @param altServerAddr Address of the alt service which will be advertised by this handler + */ + private H3AltServicePublisher(final InetSocketAddress altServerAddr) { + this.altSvcHeader = "h3=\"" + altServerAddr.getHostName() + ":" + altServerAddr.getPort() + + "\"; persist=1; intentional-unknown-param-which-is-expected-to-be-ignored=foo;"; + } + + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + String maxAgeParam = ""; + if (getMaxAge().isPresent()) { + maxAgeParam = "; ma=" + getMaxAge().getAsLong(); + } + exchange.getResponseHeaders().addHeader("alt-svc", altSvcHeader + maxAgeParam); + final int statusCode = getResponseCode(); + System.out.println("Sending response with status code " + statusCode + " and " + + "alt-svc header " + altSvcHeader); + final byte[] content = RESPONSE_CONTENT.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(statusCode, content.length); + try (final OutputStream os = exchange.getResponseBody()) { + os.write(content); + } + } + + protected OptionalLong getMaxAge() { + return OptionalLong.empty(); + } + + protected int getResponseCode() { + return 200; + } + } + + private static final class H3AltSvcPublisherWith421 extends H3AltServicePublisher { + + private H3AltSvcPublisherWith421(InetSocketAddress h3ServerAddr) { + super(h3ServerAddr); + } + + @Override + protected int getResponseCode() { + return 421; + } + } + + private static final class H3AltSvcPublisherWithMaxAge extends H3AltServicePublisher { + + private H3AltSvcPublisherWithMaxAge(InetSocketAddress h3ServerAddr) { + super(h3ServerAddr); + } + + @Override + protected OptionalLong getMaxAge() { + return OptionalLong.of(2); + } + } + + private static final class All200OKHandler implements HttpTestHandler { + private static final String RESPONSE_CONTENT = "orange"; + + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + final byte[] content = RESPONSE_CONTENT.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, content.length); + try (final OutputStream os = exchange.getResponseBody()) { + os.write(content); + } + } + } + + private static final class RequireAltUsedHeader implements HttpTestHandler { + private static final String RESPONSE_CONTENT = "tomato"; + + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + final Optional altUsed = exchange.getRequestHeaders().firstValue("alt-used"); + if (altUsed.isEmpty()) { + System.out.println("Error - Missing alt-used header in request"); + exchange.sendResponseHeaders(400, 0); + return; + } + if (altUsed.get().isBlank()) { + System.out.println("Error - Blank value for alt-used header in request"); + exchange.sendResponseHeaders(400, 0); + return; + } + System.out.println("Found alt-used request header: " + altUsed.get()); + final byte[] content = RESPONSE_CONTENT.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, content.length); + try (final OutputStream os = exchange.getResponseBody()) { + os.write(content); + } + } + } + + /** + * This test sends a HTTP3 request to a server which responds back with an alt-svc header pointing + * to a different server. The test then issues the exact same request again and this time it + * expects the alt-service host/server to handle that request. + */ + @Test + public void testAltSvcHeaderUsage() throws Exception { + HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + // send a HTTP3 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI reqURI = URI.create("https://" + toHostPort(originServer) + "/foo/"); + final HttpRequest request = HttpRequest.newBuilder() + .GET() + .uri(reqURI).build(); + System.out.println("Issuing request " + reqURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(response.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(response.body(), H3AltServicePublisher.RESPONSE_CONTENT, + "Unexpected response body"); + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + Assert.assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + System.out.println("Received alt-svc header value: " + altSvcHeader.get()); + final String expectedHeader = "h3=\"" + toHostPort(altServer) + "\""; + Assert.assertTrue(altSvcHeader.get().contains(expectedHeader), + "Unexpected alt-svc header value: " + altSvcHeader.get() + + ", was expected to contain: " + expectedHeader); + + // now issue the same request again and this time expect it to be handled by the alt-service + System.out.println("Again issuing request " + reqURI); + final HttpResponse secondResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(secondResponse.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(secondResponse.body(), All200OKHandler.RESPONSE_CONTENT, + "Unexpected response body"); + var TRACKER = ReferenceTracker.INSTANCE; + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + + /** + * This test sends a HTTP3 request to a server which responds back with an alt-svc header pointing + * to a different server and a response code of 421. The spec states that when this response code + * is sent, any alt-svc headers should be ignored by the HTTP client. This test then issues the same + * request again and expects that the alt-service wasn't used. + */ + @Test + public void testDontUseAltServiceFor421() throws Exception { + HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + final URI reqURI = URI.create("https://" + toHostPort(originServer) + "/foo421/"); + final HttpRequest request = HttpRequest.newBuilder().GET().uri(reqURI).build(); + System.out.println("Issuing request " + reqURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(response.statusCode(), 421, "Unexpected response code"); + Assert.assertEquals(response.body(), H3AltServicePublisher.RESPONSE_CONTENT, + "Unexpected response body"); + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + Assert.assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + System.out.println("Received alt-svc header value: " + altSvcHeader.get()); + final String expectedHeader = "h3=\"" + toHostPort(altServer) + "\""; + Assert.assertTrue(altSvcHeader.get().contains(expectedHeader), + "Unexpected alt-svc header value: " + altSvcHeader.get() + + ", was expected to contain: " + expectedHeader); + + // now issue the same request again and this time expect it to be handled by the alt-service + System.out.println("Again issuing request " + reqURI); + final HttpResponse secondResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(secondResponse.statusCode(), 421, "Unexpected response code"); + Assert.assertEquals(response.body(), H3AltServicePublisher.RESPONSE_CONTENT, + "Unexpected response body"); + var TRACKER = ReferenceTracker.INSTANCE; + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + + /** + * This test sends a HTTP3 request to a server which responds back with an alt-svc header pointing + * to a different server. The test then issues the exact same request again and this time it + * expects the alt-service host/server to handle that request. The alt-service host/server which + * handles this request will assert that the request came in with an "alt-used" header (which + * is expected to be set by the HTTP client). If such a header is missing then that server + * responds back with an erroneous response code of 4xx. + */ + @Test + public void testAltUsedHeader() throws Exception { + HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + // send a HTTP3 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI reqURI = URI.create("https://" + toHostPort(originServer) + "/bar/"); + final HttpRequest request = HttpRequest.newBuilder().GET().uri(reqURI).build(); + System.out.println("Issuing request " + reqURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(response.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(response.body(), H3AltServicePublisher.RESPONSE_CONTENT, + "Unexpected response body"); + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + Assert.assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + System.out.println("Received alt-svc header value: " + altSvcHeader.get()); + final String expectedHeader = "h3=\"" + toHostPort(altServer) + "\""; + Assert.assertTrue(altSvcHeader.get().contains(expectedHeader), + "Unexpected alt-svc header value: " + altSvcHeader.get() + + ", was expected to contain: " + expectedHeader); + + // now issue the same request again and this time expect it to be handled by the alt-service + // (which on the server side will assert the presence of the alt-used header, set by the + // HTTP client) + System.out.println("Again issuing request " + reqURI); + final HttpResponse secondResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(secondResponse.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(secondResponse.body(), RequireAltUsedHeader.RESPONSE_CONTENT, + "Unexpected response body"); + var TRACKER = ReferenceTracker.INSTANCE; + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + + + /** + * This test sends a HTTP3 request to a server which responds back with an alt-svc header pointing + * to a different server. The advertised alt-svc is expected to have a max age of some seconds. + * The test then immediately issues the exact same request again and this time it + * expects the alt-service host/server to handle that request. Once this is done, the test waits + * for a while (duration is greater than the max age of the advertised alt-service) and then + * issues the exact same request again. This time the request is expected to be handled by the + * origin server and not the alt-service (since the alt-service is expected to have expired by + * now) + */ + @Test + public void testAltSvcMaxAgeExpiry() throws Exception { + HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + // send a HTTP3 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI reqURI = URI.create("https://" + toHostPort(originServer) + "/maxAgeAltSvc/"); + final HttpRequest request = HttpRequest.newBuilder().GET().uri(reqURI).build(); + System.out.println("Issuing request " + reqURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(response.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(response.body(), H3AltServicePublisher.RESPONSE_CONTENT, + "Unexpected response body"); + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + Assert.assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + System.out.println("Received alt-svc header value: " + altSvcHeader.get()); + final String expectedHeader = "h3=\"" + toHostPort(altServer) + "\""; + Assert.assertTrue(altSvcHeader.get().contains(expectedHeader), + "Unexpected alt-svc header value: " + altSvcHeader.get() + + ", was expected to contain: " + expectedHeader); + + // now issue the same request again and this time expect it to be handled by the alt-service + System.out.println("Again issuing request " + reqURI); + final HttpResponse secondResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(secondResponse.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(secondResponse.body(), All200OKHandler.RESPONSE_CONTENT, + "Unexpected response body"); + + // wait for alt-service to expire + final long sleepDuration = 4000; + System.out.println("Sleeping for " + sleepDuration + " milli seconds for alt-service to expire"); + Thread.sleep(sleepDuration); + // now issue the same request again and this time expect it to be handled by the origin server + // since the alt-service is expected to be expired + System.out.println("Issuing request for a third time " + reqURI); + final HttpResponse thirdResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(thirdResponse.statusCode(), 200, "Unexpected response code"); + Assert.assertEquals(thirdResponse.body(), H3AltServicePublisher.RESPONSE_CONTENT, + "Unexpected response body"); + var TRACKER = ReferenceTracker.INSTANCE; + var tracker = TRACKER.getTracker(client); + System.gc(); + client = null; + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + + private static String toHostPort(final HttpTestServer server) { + final InetSocketAddress addr = server.getAddress(); + return addr.getHostName() + ":" + addr.getPort(); + } +} diff --git a/test/jdk/java/net/httpclient/AsFileDownloadTest.java b/test/jdk/java/net/httpclient/AsFileDownloadTest.java index 31d919230a4..42567602dff 100644 --- a/test/jdk/java/net/httpclient/AsFileDownloadTest.java +++ b/test/jdk/java/net/httpclient/AsFileDownloadTest.java @@ -35,6 +35,8 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublishers; @@ -48,11 +50,16 @@ import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; + import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; +import jdk.httpclient.test.lib.common.TestServerConfigurator; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.util.FileUtils; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.common.TestServerConfigurator; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.httpclient.test.lib.http2.Http2TestExchange; import jdk.httpclient.test.lib.http2.Http2Handler; @@ -61,6 +68,8 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import static java.lang.System.out; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.net.http.HttpResponse.BodyHandlers.ofFileDownload; import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.file.StandardOpenOption.*; @@ -75,6 +84,7 @@ import static org.testng.Assert.fail; * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext * jdk.test.lib.Platform jdk.test.lib.util.FileUtils + * jdk.httpclient.test.lib.common.HttpServerAdapters * jdk.httpclient.test.lib.common.TestServerConfigurator * @run testng/othervm/timeout=480 AsFileDownloadTest */ @@ -85,10 +95,12 @@ public class AsFileDownloadTest { HttpsServer httpsTestServer; // HTTPS/1.1 Http2TestServer http2TestServer; // HTTP/2 ( h2c ) Http2TestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h3TestServer; // HTTP/3 String httpURI; String httpsURI; String http2URI; String https2URI; + String h3URI; final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; Path tempDir; @@ -144,38 +156,50 @@ public class AsFileDownloadTest { List list = new ArrayList<>(); Arrays.asList(contentDispositionValues).stream() - .map(e -> new Object[] {httpURI + "?" + e[0], e[1], e[2]}) + .map(e -> new Object[] {httpURI + "?" + e[0], e[1], e[2], Optional.empty()}) .forEach(list::add); Arrays.asList(contentDispositionValues).stream() - .map(e -> new Object[] {httpsURI + "?" + e[0], e[1], e[2]}) + .map(e -> new Object[] {httpsURI + "?" + e[0], e[1], e[2], Optional.empty()}) .forEach(list::add); Arrays.asList(contentDispositionValues).stream() - .map(e -> new Object[] {http2URI + "?" + e[0], e[1], e[2]}) + .map(e -> new Object[] {http2URI + "?" + e[0], e[1], e[2], Optional.empty()}) .forEach(list::add); Arrays.asList(contentDispositionValues).stream() - .map(e -> new Object[] {https2URI + "?" + e[0], e[1], e[2]}) + .map(e -> new Object[] {https2URI + "?" + e[0], e[1], e[2], Optional.empty()}) + .forEach(list::add); + Arrays.asList(contentDispositionValues).stream() + .map(e -> new Object[] {h3URI + "?" + e[0], e[1], e[2], Optional.of(Version.HTTP_3)}) .forEach(list::add); return list.stream().toArray(Object[][]::new); } + HttpClient newHttpClient(Optional version) { + var builder = version.isEmpty() || version.get() != Version.HTTP_3 + ? HttpClient.newBuilder() + : HttpServerAdapters.createClientBuilderForH3().version(Version.HTTP_3); + return builder.sslContext(sslContext).proxy(Builder.NO_PROXY).build(); + } + @Test(dataProvider = "positive") - void test(String uriString, String contentDispositionValue, String expectedFilename) - throws Exception - { + void test(String uriString, String contentDispositionValue, String expectedFilename, + Optional requestVersion) throws Exception { out.printf("test(%s, %s, %s): starting", uriString, contentDispositionValue, expectedFilename); - HttpClient client = HttpClient.newBuilder().sslContext(sslContext).build(); - TRACKER.track(client); - ReferenceQueue queue = new ReferenceQueue<>(); - WeakReference ref = new WeakReference<>(client, queue); - try { + try (HttpClient client = newHttpClient(requestVersion)) { + TRACKER.track(client); + ReferenceQueue queue = new ReferenceQueue<>(); + WeakReference ref = new WeakReference<>(client, queue); URI uri = URI.create(uriString); - HttpRequest request = HttpRequest.newBuilder(uri) - .POST(BodyPublishers.ofString("May the luck of the Irish be with you!")) - .build(); + HttpRequest.Builder requestBuilder = newRequestBuilder(uriString); + if (requestVersion.isPresent()) { + requestBuilder.version(requestVersion.get()); + } + HttpRequest request = requestBuilder.POST( + BodyPublishers.ofString("May the luck of the Irish be with you!")).build(); BodyHandler bh = ofFileDownload(tempDir.resolve(uri.getPath().substring(1)), CREATE, TRUNCATE_EXISTING, WRITE); + out.println("Issuing request " + request); HttpResponse response = client.send(request, bh); Path body = response.body(); out.println("Got response: " + response); @@ -189,6 +213,10 @@ public class AsFileDownloadTest { assertEquals(response.headers().firstValue("Content-Disposition").get(), contentDispositionValue); assertEquals(fileContents, "May the luck of the Irish be with you!"); + if (requestVersion.isPresent()) { + assertEquals(response.version(), requestVersion.get(), "unexpected HTTP version" + + " in response"); + } if (!body.toAbsolutePath().startsWith(tempDir.toAbsolutePath())) { System.out.println("Tempdir = " + tempDir.toAbsolutePath()); @@ -198,16 +226,9 @@ public class AsFileDownloadTest { // additional checks unrelated to file download caseInsensitivityOfHeaders(request.headers()); caseInsensitivityOfHeaders(response.headers()); - } finally { - client = null; - System.gc(); - while (!ref.refersTo(null)) { - System.gc(); - if (queue.remove(100) == ref) break; - } - AssertionError failed = TRACKER.checkShutdown(1000); - if (failed != null) throw failed; } + AssertionError failed = TRACKER.checkClosed(1000); + if (failed != null) throw failed; } // --- Negative @@ -238,54 +259,49 @@ public class AsFileDownloadTest { List list = new ArrayList<>(); Arrays.asList(contentDispositionBADValues).stream() - .map(e -> new Object[] {httpURI + "?" + e[0], e[1]}) + .map(e -> new Object[] {httpURI + "?" + e[0], e[1], Optional.empty()}) .forEach(list::add); Arrays.asList(contentDispositionBADValues).stream() - .map(e -> new Object[] {httpsURI + "?" + e[0], e[1]}) + .map(e -> new Object[] {httpsURI + "?" + e[0], e[1], Optional.empty()}) .forEach(list::add); Arrays.asList(contentDispositionBADValues).stream() - .map(e -> new Object[] {http2URI + "?" + e[0], e[1]}) + .map(e -> new Object[] {http2URI + "?" + e[0], e[1], Optional.empty()}) .forEach(list::add); Arrays.asList(contentDispositionBADValues).stream() - .map(e -> new Object[] {https2URI + "?" + e[0], e[1]}) + .map(e -> new Object[] {https2URI + "?" + e[0], e[1], Optional.empty()}) + .forEach(list::add); + Arrays.asList(contentDispositionBADValues).stream() + .map(e -> new Object[] {h3URI + "?" + e[0], e[1], Optional.of(Version.HTTP_3)}) .forEach(list::add); - return list.stream().toArray(Object[][]::new); } @Test(dataProvider = "negative") - void negativeTest(String uriString, String contentDispositionValue) - throws Exception - { + void negativeTest(String uriString, String contentDispositionValue, + Optional requestVersion) throws Exception { out.printf("negativeTest(%s, %s): starting", uriString, contentDispositionValue); - HttpClient client = HttpClient.newBuilder().sslContext(sslContext).build(); - TRACKER.track(client); - ReferenceQueue queue = new ReferenceQueue<>(); - WeakReference ref = new WeakReference<>(client, queue); + try (HttpClient client = newHttpClient(requestVersion)) { + TRACKER.track(client); + ReferenceQueue queue = new ReferenceQueue<>(); + WeakReference ref = new WeakReference<>(client, queue); - try { - URI uri = URI.create(uriString); - HttpRequest request = HttpRequest.newBuilder(uri) - .POST(BodyPublishers.ofString("Does not matter")) + HttpRequest.Builder reqBuilder = newRequestBuilder(uriString); + if (requestVersion.isPresent()) { + reqBuilder.version(requestVersion.get()); + } + HttpRequest request = reqBuilder.POST(BodyPublishers.ofString("Does not matter")) .build(); - BodyHandler bh = ofFileDownload(tempDir, CREATE, TRUNCATE_EXISTING, WRITE); try { + out.println("Issuing request " + request); HttpResponse response = client.send(request, bh); fail("UNEXPECTED response: " + response + ", path:" + response.body()); } catch (UncheckedIOException | IOException ioe) { System.out.println("Caught expected: " + ioe); } - } finally { - client = null; - System.gc(); - while (!ref.refersTo(null)) { - System.gc(); - if (queue.remove(100) == ref) break; - } - AssertionError failed = TRACKER.checkShutdown(1000); - if (failed != null) throw failed; } + AssertionError failed = TRACKER.checkClosed(1000); + if (failed != null) throw failed; } // -- Infrastructure @@ -297,6 +313,23 @@ public class AsFileDownloadTest { return h + ":" + server.getAddress().getPort(); } + Version version(String uri) { + if (uri.contains("/http1/")) return Version.HTTP_1_1; + if (uri.contains("/https1/")) return Version.HTTP_1_1; + if (uri.contains("/http2/")) return Version.HTTP_2; + if (uri.contains("/https2/")) return Version.HTTP_2; + if (uri.contains("/h3/")) return Version.HTTP_3; + return null; + } + + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == Version.HTTP_3) { + builder.setOption(H3_DISCOVERY, h3TestServer.h3DiscoveryConfig()); + } + return builder; + } + @BeforeTest public void setup() throws Exception { tempDir = Paths.get("asFileDownloadTest.tmp.dir"); @@ -309,6 +342,7 @@ public class AsFileDownloadTest { Files.createDirectories(tempDir.resolve("https1/afdt/")); Files.createDirectories(tempDir.resolve("http2/afdt/")); Files.createDirectories(tempDir.resolve("https2/afdt/")); + Files.createDirectories(tempDir.resolve("h3/afdt/")); // HTTP/1.1 server logging in case of security exceptions ( uncomment if needed ) //Logger logger = Logger.getLogger("com.sun.net.httpserver"); @@ -339,10 +373,15 @@ public class AsFileDownloadTest { https2TestServer.addHandler(new Http2FileDispoHandler(), "/https2/afdt"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/afdt"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new Http2FileDispoHandler(), "/h3/afdt"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3/afdt"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -351,6 +390,7 @@ public class AsFileDownloadTest { httpsTestServer.stop(0); http2TestServer.stop(); https2TestServer.stop(); + h3TestServer.stop(); if (Files.exists(tempDir)) { // clean up @@ -392,7 +432,8 @@ public class AsFileDownloadTest { } } - static class Http2FileDispoHandler implements Http2Handler { + static class Http2FileDispoHandler implements Http2Handler, HttpServerAdapters.HttpTestHandler { + @Override public void handle(Http2TestExchange t) throws IOException { try (InputStream is = t.getRequestBody(); @@ -407,6 +448,21 @@ public class AsFileDownloadTest { os.write(bytes); } } + + @Override + public void handle(HttpTestExchange t) throws IOException { + try (InputStream is = t.getRequestBody(); + OutputStream os = t.getResponseBody()) { + byte[] bytes = is.readAllBytes(); + + String value = contentDispositionValueFromURI(t.getRequestURI()); + if (!value.equals("<>")) + t.getResponseHeaders().addHeader("Content-Disposition", value); + + t.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } + } } // --- diff --git a/test/jdk/java/net/httpclient/AsyncExecutorShutdown.java b/test/jdk/java/net/httpclient/AsyncExecutorShutdown.java index 7e7709c033c..5338d64892b 100644 --- a/test/jdk/java/net/httpclient/AsyncExecutorShutdown.java +++ b/test/jdk/java/net/httpclient/AsyncExecutorShutdown.java @@ -40,12 +40,12 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.UncheckedIOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.channels.ClosedChannelException; @@ -62,15 +62,10 @@ import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; -import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLHandshakeException; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import javax.net.ssl.SSLContext; + import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; @@ -82,6 +77,9 @@ import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -94,15 +92,21 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { } static final Random RANDOM = RandomFactory.getRandom(); + ExecutorService readerService; SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 6 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 - HttpTestServer http2TestServer; // HTTP/2 ( h2c ) - HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/2 ( h2+h3 ) + HttpTestServer h3TestServer; // HTTP/2 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h3URI; + String h2h3Head; static final String MESSAGE = "AsyncExecutorShutdown message body"; static final int ITERATIONS = 3; @@ -110,10 +114,12 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, }, - { httpsURI, }, - { http2URI, }, - { https2URI, }, + { h2h3URI, HTTP_3, h2h3TestServer.h3DiscoveryConfig() }, + { h3URI, HTTP_3, h3TestServer.h3DiscoveryConfig() }, + { httpURI, HTTP_1_1, null }, + { httpsURI, HTTP_1_1, null }, + { http2URI, HTTP_2, null }, + { https2URI, HTTP_2, null }, }; } @@ -127,8 +133,9 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { return t; } - static String readBody(InputStream in) { - try { + static String readBody(InputStream body) { + out.printf("[%s] reading body%n", Thread.currentThread()); + try (var in = body) { return new String(in.readAllBytes(), StandardCharsets.UTF_8); } catch (IOException io) { throw new UncheckedIOException(io); @@ -159,26 +166,37 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testConcurrent(String uriString) throws Exception { - out.printf("%n---- starting (%s) ----%n", uriString); + void testConcurrent(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting (%s, %s, %s) ----%n%n", uriString, version, config); ExecutorService executorService = Executors.newCachedThreadPool(); - ExecutorService readerService = Executors.newCachedThreadPool(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) .executor(executorService) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); assert client.executor().isPresent(); + Throwable failed = null; int step = RANDOM.nextInt(ITERATIONS); + int head = Math.min(1, step); + List> bodies = new ArrayList<>(); try { - List> bodies = new ArrayList<>(); for (int i = 0; i < ITERATIONS; i++) { + if (i == head && version == HTTP_3 && config != HTTP_3_URI_ONLY) { + // let's the first request go through whatever version, + // but ensure that the second will find an AltService + // record + out.printf("%d: sending head request%n", i); + headRequest(client); + out.printf("%d: head request sent%n", i); + } URI uri = URI.create(uriString + "/concurrent/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; @@ -189,11 +207,15 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { .thenApply((response) -> { out.println(si + ": Got response: " + response); assertEquals(response.statusCode(), 200); + if (si >= head) assertEquals(response.version(), version); return response; }); bodyCF = responseCF.thenApplyAsync(HttpResponse::body, readerService) .thenApply(AsyncExecutorShutdown::readBody) - .thenApply((s) -> { assertEquals(s, MESSAGE); return s;}); + .thenApply((s) -> { + assertEquals(s, MESSAGE); + return s; + }); } catch (RejectedExecutionException x) { out.println(i + ": Got expected exception: " + x); continue; @@ -205,7 +227,8 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { } if (i == step) { out.printf("%d: shutting down executor now%n", i, sleep); - executorService.shutdownNow(); + var list = executorService.shutdownNow(); + out.printf("%d: executor shut down: %s%n", i, list); } var cf = bodyCF.exceptionally((t) -> { Throwable cause = getCause(t); @@ -219,39 +242,72 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { checkCause(String.valueOf(si), cause); return null; }); + out.printf("%d: adding body to bodies list%n", i); bodies.add(cf); } - CompletableFuture.allOf(bodies.toArray(new CompletableFuture[0])).get(); + } catch (Throwable t) { + failed = t; } finally { client = null; - executorService.awaitTermination(2000, TimeUnit.MILLISECONDS); - readerService.shutdown(); - readerService.awaitTermination(2000, TimeUnit.MILLISECONDS); + System.gc(); + try { + out.println("Awaiting executorService termination"); + executorService.awaitTermination(2000, TimeUnit.MILLISECONDS); + out.println("Done"); + } catch (Throwable t) { + if (failed != null) { + failed.addSuppressed(t); + } else failed = t; + } + var error = TRACKER.checkShutdown(2000); + if (error != null) { + out.println("Client hasn't shutdown properly: " + error); + if (failed != null) failed.addSuppressed(error); + else failed = error; + } } + if (failed instanceof Exception fe) { + throw fe; + } else if (failed instanceof Error e) { + throw e; + } + out.println("Awaiting all bodies..."); + CompletableFuture.allOf(bodies.toArray(new CompletableFuture[0])).get(); } @Test(dataProvider = "positive") - void testSequential(String uriString) throws Exception { - out.printf("%n---- starting (%s) ----%n", uriString); + void testSequential(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting (%s, %s, %s) ----%n%n", uriString, version, config); ExecutorService executorService = Executors.newCachedThreadPool(); - ExecutorService readerService = Executors.newCachedThreadPool(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .executor(executorService) .sslContext(sslContext) .build(); TRACKER.track(client); assert client.executor().isPresent(); + Throwable failed = null; int step = RANDOM.nextInt(ITERATIONS); + int head = Math.min(1, step); out.printf("will shutdown executor in step %d%n", step); try { for (int i = 0; i < ITERATIONS; i++) { + if (i == head && version == HTTP_3 && config != HTTP_3_URI_ONLY) { + // let's the first request go through whatever version, + // but ensure that the second will find an AltService + // record + out.printf("%d: sending head request%n", i); + headRequest(client); + out.printf("%d: head request sent%n", i); + } URI uri = URI.create(uriString + "/sequential/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) - .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) - .build(); + .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) + .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); final int si = i; CompletableFuture> responseCF; @@ -261,6 +317,7 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { .thenApply((response) -> { out.println(si + ": Got response: " + response); assertEquals(response.statusCode(), 200); + if (si > 0) assertEquals(response.version(), version); return response; }); bodyCF = responseCF.thenApplyAsync(HttpResponse::body, readerService) @@ -278,9 +335,10 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { } if (i == step) { out.printf("%d: shutting down executor now%n", i, sleep); - executorService.shutdownNow(); + var list = executorService.shutdownNow(); + out.printf("%d: executor shut down: %s%n", i, list); } - bodyCF.handle((r,t) -> { + bodyCF.handle((r, t) -> { if (t != null) { try { Throwable cause = getCause(t); @@ -301,22 +359,63 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { } }).thenCompose((c) -> c).get(); } - } finally { + } catch (Throwable t) { + t.printStackTrace(); + failed = t; + } finally { client = null; - executorService.awaitTermination(2000, TimeUnit.MILLISECONDS); - readerService.shutdown(); - readerService.awaitTermination(2000, TimeUnit.MILLISECONDS); + System.gc(); + try { + out.println("Awaiting executorService termination"); + executorService.awaitTermination(2000, TimeUnit.MILLISECONDS); + out.println("Done"); + } catch (Throwable t) { + t.printStackTrace(); + if (failed != null) { + failed.addSuppressed(t); + } else failed = t; + } + var error = TRACKER.checkShutdown(2000); + if (error != null) { + out.println("Client hasn't shutdown properly: " + error); + if (failed != null) failed.addSuppressed(error); + else failed = error; + } + } + if (failed instanceof Exception fe) { + throw fe; + } else if (failed instanceof Error e) { + throw e; } } // -- Infrastructure + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + + static void shutdown(ExecutorService executorService) { + try { + executorService.shutdown(); + executorService.awaitTermination(2000, TimeUnit.MILLISECONDS); + } catch (InterruptedException ie) { + executorService.shutdownNow(); + } + } + @BeforeTest public void setup() throws Exception { out.println("\n**** Setup ****\n"); sslContext = new SimpleSSLContext().get(); if (sslContext == null) throw new AssertionError("Unexpected null sslContext"); + readerService = Executors.newCachedThreadPool(); httpTestServer = HttpTestServer.create(HTTP_1_1); httpTestServer.addHandler(new ServerRequestHandler(), "/http1/exec/"); @@ -332,10 +431,21 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { https2TestServer.addHandler(new ServerRequestHandler(), "/https2/exec/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/exec/retry"; + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new ServerRequestHandler(), "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestHandler(), "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -343,10 +453,13 @@ public class AsyncExecutorShutdown implements HttpServerAdapters { Thread.sleep(100); AssertionError fail = TRACKER.checkShutdown(5000); try { + shutdown(readerService); httpTestServer.stop(); httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/AsyncShutdownNow.java b/test/jdk/java/net/httpclient/AsyncShutdownNow.java index 39c82735248..2617b60ee1c 100644 --- a/test/jdk/java/net/httpclient/AsyncShutdownNow.java +++ b/test/jdk/java/net/httpclient/AsyncShutdownNow.java @@ -45,7 +45,9 @@ import java.io.UncheckedIOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.channels.ClosedChannelException; @@ -62,6 +64,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; + import jdk.httpclient.test.lib.common.HttpServerAdapters; import javax.net.ssl.SSLContext; @@ -72,11 +75,14 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -96,10 +102,15 @@ public class AsyncShutdownNow implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/3 ( h2 + h3 ) + HttpTestServer h3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h2h3Head; + String h3URI; static final String MESSAGE = "AsyncShutdownNow message body"; static final int ITERATIONS = 3; @@ -107,10 +118,12 @@ public class AsyncShutdownNow implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, }, - { httpsURI, }, - { http2URI, }, - { https2URI, }, + { h2h3URI, HTTP_3, h2h3TestServer.h3DiscoveryConfig()}, + { h3URI, HTTP_3, h3TestServer.h3DiscoveryConfig()}, + { httpURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { httpsURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { http2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 + { https2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 }; } @@ -124,14 +137,57 @@ public class AsyncShutdownNow implements HttpServerAdapters { return t; } - static String readBody(InputStream in) { - try { + static String readBody(InputStream body) { + try (InputStream in = body) { return new String(in.readAllBytes(), StandardCharsets.UTF_8); } catch (IOException io) { throw new UncheckedIOException(io); } } + record ExchangeResult(int step, + Version version, + Http3DiscoveryMode config, + HttpResponse response, + boolean firstVersionMayNotMatch) { + + static ExchangeResult afterHead(int step, Version version, Http3DiscoveryMode config) { + return new ExchangeResult(step, version, config, null, false); + } + + static ExchangeResult ofSequential(int step, Version version, Http3DiscoveryMode config) { + return new ExchangeResult(step, version, config, null, true); + } + + ExchangeResult withResponse(HttpResponse response) { + return new ExchangeResult(step(), version(), config(), response, firstVersionMayNotMatch()); + } + + // Ensures that the input stream gets closed in case of assertion + ExchangeResult assertResponseState() { + out.println(step + ": Got response: " + response); + out.printf("%s: expect status 200 and version %s (%s) for %s%n", step, version, config, + response.request().uri()); + assertEquals(response.statusCode(), 200); + if (step == 0 && version == HTTP_3 && firstVersionMayNotMatch) { + out.printf("%s: version not checked%n", step); + } else { + assertEquals(response.version(), version); + out.printf("%s: got expected version %s%n", step, response.version()); + } + return this; + } + } + + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + static boolean hasExpectedMessage(IOException io) { String message = io.getMessage(); if (message == null) return false; @@ -168,11 +224,12 @@ public class AsyncShutdownNow implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testConcurrent(String uriString) throws Exception { - out.printf("%n---- starting concurrent (%s) ----%n%n", uriString); - HttpClient client = HttpClient.newBuilder() + void testConcurrent(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting concurrent (%s, %s, %s) ----%n%n", uriString, version, config); + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); @@ -181,21 +238,25 @@ public class AsyncShutdownNow implements HttpServerAdapters { Throwable failed = null; List> bodies = new ArrayList<>(); try { + if (version == HTTP_3 && config != HTTP_3_URI_ONLY) { + headRequest(client); + } + for (int i = 0; i < ITERATIONS; i++) { URI uri = URI.create(uriString + "/concurrent/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; CompletableFuture bodyCF; final int si = i; + ExchangeResult result = ExchangeResult.afterHead(si, version, config); responseCF = client.sendAsync(request, BodyHandlers.ofInputStream()) - .thenApply((response) -> { - out.println(si + ": Got response: " + response); - assertEquals(response.statusCode(), 200); - return response; - }); + .thenApply(result::withResponse) + .thenApplyAsync(ExchangeResult::assertResponseState, readerService) + .thenApply(ExchangeResult::response); bodyCF = responseCF.thenApplyAsync(HttpResponse::body, readerService) .thenApply(AsyncShutdownNow::readBody) .thenApply((s) -> { @@ -260,11 +321,13 @@ public class AsyncShutdownNow implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testSequential(String uriString) throws Exception { - out.printf("%n---- starting sequential (%s) ----%n%n", uriString); - HttpClient client = HttpClient.newBuilder() + void testSequential(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting sequential (%s, %s, %s) ----%n%n", + uriString, version, config); + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); @@ -277,17 +340,17 @@ public class AsyncShutdownNow implements HttpServerAdapters { URI uri = URI.create(uriString + "/sequential/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); final int si = i; + ExchangeResult result = ExchangeResult.ofSequential(si, version, config); CompletableFuture> responseCF; CompletableFuture bodyCF; responseCF = client.sendAsync(request, BodyHandlers.ofInputStream()) - .thenApply((response) -> { - out.println(si + ": Got response: " + response); - assertEquals(response.statusCode(), 200); - return response; - }); + .thenApply(result::withResponse) + .thenApplyAsync(ExchangeResult::assertResponseState, readerService) + .thenApply(ExchangeResult::response); bodyCF = responseCF.thenApplyAsync(HttpResponse::body, readerService) .thenApply(AsyncShutdownNow::readBody) .thenApply((s) -> { @@ -362,10 +425,21 @@ public class AsyncShutdownNow implements HttpServerAdapters { https2TestServer.addHandler(new ServerRequestHandler(), "/https2/exec/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/exec/retry"; + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new ServerRequestHandler(), "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestHandler(), "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -378,6 +452,8 @@ public class AsyncShutdownNow implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/AuthFilterCacheTest.java b/test/jdk/java/net/httpclient/AuthFilterCacheTest.java index e819f96d947..32026db57fb 100644 --- a/test/jdk/java/net/httpclient/AuthFilterCacheTest.java +++ b/test/jdk/java/net/httpclient/AuthFilterCacheTest.java @@ -26,14 +26,18 @@ import java.net.*; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Collectors; + import jdk.httpclient.test.lib.common.HttpServerAdapters; -import com.sun.net.httpserver.HttpsConfigurator; import com.sun.net.httpserver.HttpsServer; import jdk.httpclient.test.lib.common.TestServerConfigurator; import org.testng.annotations.AfterClass; @@ -45,6 +49,12 @@ import javax.net.ssl.SSLContext; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; /* * @test @@ -52,9 +62,9 @@ import static java.net.http.HttpClient.Version.HTTP_2; * @summary AuthenticationFilter.Cache::remove may throw ConcurrentModificationException * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.httpclient.test.lib.common.HttpServerAdapters jdk.test.lib.net.SimpleSSLContext - * DigestEchoServer jdk.httpclient.test.lib.common.TestServerConfigurator + * DigestEchoServer ReferenceTracker jdk.httpclient.test.lib.common.TestServerConfigurator * @run testng/othervm -Dtest.requiresHost=true - * -Djdk.httpclient.HttpClient.log=headers + * -Djdk.httpclient.HttpClient.log=requests,headers,errors,quic * -Djdk.internal.httpclient.debug=false * AuthFilterCacheTest */ @@ -63,7 +73,7 @@ public class AuthFilterCacheTest implements HttpServerAdapters { static final String RESPONSE_BODY = "Hello World!"; static final int REQUEST_COUNT = 5; - static final int URI_COUNT = 6; + static final int URI_COUNT = 8; static final CyclicBarrier barrier = new CyclicBarrier(REQUEST_COUNT * URI_COUNT); static final SSLContext context; @@ -80,11 +90,15 @@ public class AuthFilterCacheTest implements HttpServerAdapters { HttpTestServer http2Server; HttpTestServer https1Server; HttpTestServer https2Server; + HttpTestServer h3onlyServer; + HttpTestServer h3altSvcServer; DigestEchoServer.TunnelingProxy proxy; URI http1URI; URI https1URI; URI http2URI; URI https2URI; + URI h3onlyURI; + URI h3altSvcURI; InetSocketAddress proxyAddress; ProxySelector proxySelector; MyAuthenticator auth; @@ -101,14 +115,15 @@ public class AuthFilterCacheTest implements HttpServerAdapters { https1URI.resolve("proxy/orig/"), http2URI.resolve("direct/orig/"), https2URI.resolve("direct/orig/"), - https2URI.resolve("proxy/orig/"))} + https2URI.resolve("proxy/orig/"), + h3onlyURI.resolve("direct/orig/"), + h3altSvcURI.resolve("direct/orig/"))} }; return uris; } public HttpClient newHttpClient(ProxySelector ps, Authenticator auth) { - HttpClient.Builder builder = HttpClient - .newBuilder() + HttpClient.Builder builder = newClientBuilderForH3() .executor(virtualExecutor) .sslContext(context) .authenticator(auth) @@ -154,12 +169,34 @@ public class AuthFilterCacheTest implements HttpServerAdapters { https2URI = new URI("https://" + https2Server.serverAuthority() + "/AuthFilterCacheTest/https2/"); + h3onlyServer = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault()); + h3onlyServer.addHandler(new TestHandler(), "/AuthFilterCacheTest/h3-only/"); + h3onlyURI = new URI("https://" + h3onlyServer.serverAuthority() + + "/AuthFilterCacheTest/h3-only/"); + h3onlyServer.start(); + + h3altSvcServer = HttpTestServer.create(ANY, SSLContext.getDefault()); + h3altSvcServer.addHandler(new TestHandler(), "/AuthFilterCacheTest/h3-alt-svc/"); + h3altSvcServer.addHandler(new HttpHeadOrGetHandler(RESPONSE_BODY), + "/AuthFilterCacheTest/h3-alt-svc/direct/head/"); + h3altSvcURI = new URI("https://" + h3altSvcServer.serverAuthority() + + "/AuthFilterCacheTest/h3-alt-svc/"); + h3altSvcServer.start(); + proxy = DigestEchoServer.createHttpsProxyTunnel( DigestEchoServer.HttpAuthSchemeType.NONE); proxyAddress = proxy.getProxyAddress(); proxySelector = new HttpProxySelector(proxyAddress); client = newHttpClient(proxySelector, auth); + HttpRequest headRequest = HttpRequest.newBuilder(h3altSvcURI.resolve("direct/head/h2")) + .HEAD() + .version(HTTP_2).build(); + System.out.println("Sending head request: " + headRequest); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), HTTP_2); + System.out.println("Setup: done"); } catch (Exception x) { tearDown(); @@ -177,6 +214,8 @@ public class AuthFilterCacheTest implements HttpServerAdapters { https1Server = stop(https1Server, HttpTestServer::stop); http2Server = stop(http2Server, HttpTestServer::stop); https2Server = stop(https2Server, HttpTestServer::stop); + h3onlyServer = stop(h3onlyServer, HttpTestServer::stop); + h3altSvcServer = stop(h3altSvcServer, HttpTestServer::stop); client.close(); virtualExecutor.close(); @@ -229,6 +268,7 @@ public class AuthFilterCacheTest implements HttpServerAdapters { @Override public void handle(HttpTestExchange t) throws IOException { var count = respCounter.incrementAndGet(); + System.out.println("Server got request: " + t.getRequestURI()); System.out.println("Responses handled: " + count); t.getRequestBody().readAllBytes(); @@ -237,15 +277,19 @@ public class AuthFilterCacheTest implements HttpServerAdapters { t.getResponseHeaders() .addHeader("WWW-Authenticate", "Basic realm=\"Earth\""); t.sendResponseHeaders(401, 0); + System.out.println("Server sent 401 for " + t.getRequestURI()); } else { byte[] resp = RESPONSE_BODY.getBytes(StandardCharsets.UTF_8); t.sendResponseHeaders(200, resp.length); + System.out.println("Server sent 200 for " + t.getRequestURI() + "; awaiting barrier"); try { barrier.await(); } catch (Exception e) { + e.printStackTrace(); throw new IOException(e); } t.getResponseBody().write(resp); + System.out.println("Server sent body for " + t.getRequestURI()); } } t.close(); @@ -255,23 +299,54 @@ public class AuthFilterCacheTest implements HttpServerAdapters { void doClient(List uris) { assert uris.size() == URI_COUNT; barrier.reset(); - System.out.println("Client opening connection to: " + uris.toString()); + System.out.println("Client will connect " + REQUEST_COUNT + " times to: " + + uris.stream().map(URI::toString) + .collect(Collectors.joining("\n\t", "\n\t", "\n"))); List>> cfs = new ArrayList<>(); + int count = 0; for (int i = 0; i < REQUEST_COUNT; i++) { for (URI uri : uris) { - HttpRequest req = HttpRequest.newBuilder() - .uri(uri) - .build(); - cfs.add(client.sendAsync(req, HttpResponse.BodyHandlers.ofString())); + String uriStr = uri.toString() + (++count); + var builder = HttpRequest.newBuilder() + .uri(URI.create(uriStr)); + var config = uriStr.contains("h3-only") ? HTTP_3_URI_ONLY + : uriStr.contains("h3-alt-svc") ? ALT_SVC + : null; + if (config != null) { + builder = builder.setOption(H3_DISCOVERY, config).version(HTTP_3); + } else { + builder = builder.version(HTTP_2); + } + HttpRequest req = builder.build(); + System.out.printf("Sending request %s (version=%s, config=%s)%n", + req, req.version(), config); + cfs.add(client.sendAsync(req, HttpResponse.BodyHandlers.ofString()) + .handleAsync((r, t) -> logResponse(req, r, t)) + .thenCompose(Function.identity())); } } CompletableFuture.allOf(cfs.toArray(new CompletableFuture[0])).join(); } + CompletableFuture> logResponse(HttpRequest req, + HttpResponse resp, + Throwable t) { + if (t != null) { + System.out.printf("Request failed: %s (version=%s, config=%s): %s%n", + req, req.version(), req.getOption(H3_DISCOVERY).orElse(null), t); + t.printStackTrace(System.out); + return CompletableFuture.failedFuture(t); + } else { + System.out.printf("Request succeeded: %s (version=%s, config=%s): %s%n", + req, req.version(), req.getOption(H3_DISCOVERY).orElse(null), resp); + return CompletableFuture.completedFuture(resp); + } + } + static final class MyAuthenticator extends Authenticator { - private int count = 0; + private final AtomicInteger count = new AtomicInteger(); MyAuthenticator() { super(); @@ -294,7 +369,7 @@ public class AuthFilterCacheTest implements HttpServerAdapters { PasswordAuthentication passwordAuthentication; int count; synchronized (this) { - count = ++this.count; + count = this.count.incrementAndGet(); passwordAuthentication = super.requestPasswordAuthenticationInstance( host, addr, port, protocol, prompt, scheme, url, reqType); } @@ -304,13 +379,17 @@ public class AuthFilterCacheTest implements HttpServerAdapters { } public int getCount() { - return count; + return count.get(); } } @Test(dataProvider = "uris") public void test(List uris) throws Exception { - System.out.println("Server listening at " + uris.toString()); + System.out.println("Servers listening at " + + uris.stream().map(URI::toString) + .collect(Collectors.joining("\n\t", "\n\t", "\n"))); + System.out.println("h3-alt-svc server listening for h3 at: " + + h3altSvcServer.getH3AltService().map(s -> s.getAddress()).orElse(null)); doClient(uris); } } diff --git a/test/jdk/java/net/httpclient/BasicAuthTest.java b/test/jdk/java/net/httpclient/BasicAuthTest.java index e0aec3b19b7..35d5a7803d6 100644 --- a/test/jdk/java/net/httpclient/BasicAuthTest.java +++ b/test/jdk/java/net/httpclient/BasicAuthTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -49,6 +49,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import javax.net.ssl.SSLContext; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.US_ASCII; public class BasicAuthTest implements HttpServerAdapters { @@ -62,13 +63,15 @@ public class BasicAuthTest implements HttpServerAdapters { test(Version.HTTP_2, false); test(Version.HTTP_1_1, true); test(Version.HTTP_2, true); + test(Version.HTTP_3, true); } public static void test(Version version, boolean secure) throws Exception { ExecutorService e = Executors.newCachedThreadPool(); Handler h = new Handler(); - SSLContext sslContext = secure ? new SimpleSSLContext().get() : null; + SSLContext sslContext = secure || version == Version.HTTP_3 + ? new SimpleSSLContext().get() : null; HttpTestServer server = HttpTestServer.create(version, sslContext, e); HttpTestContext serverContext = server.addHandler(h,"/test/"); ServerAuth sa = new ServerAuth("foo realm"); @@ -77,7 +80,7 @@ public class BasicAuthTest implements HttpServerAdapters { System.out.println("Server auth = " + server.serverAuthority()); ClientAuth ca = new ClientAuth(); - var clientBuilder = HttpClient.newBuilder(); + var clientBuilder = HttpServerAdapters.createClientBuilderForH3(); if (sslContext != null) clientBuilder.sslContext(sslContext); HttpClient client = clientBuilder.authenticator(ca).build(); @@ -85,6 +88,11 @@ public class BasicAuthTest implements HttpServerAdapters { String scheme = sslContext == null ? "http" : "https"; URI uri = new URI(scheme + "://" + server.serverAuthority() + "/test/foo/"+version); var builder = HttpRequest.newBuilder(uri); + if (version == Version.HTTP_3) { + builder.version(version); + var config = server.h3DiscoveryConfig(); + builder.setOption(H3_DISCOVERY, server.h3DiscoveryConfig()); + } HttpRequest req = builder.copy().GET().build(); System.out.println("\n\nSending request: " + req); diff --git a/test/jdk/java/net/httpclient/BasicHTTP2Test.java b/test/jdk/java/net/httpclient/BasicHTTP2Test.java new file mode 100644 index 00000000000..65bcf7631ea --- /dev/null +++ b/test/jdk/java/net/httpclient/BasicHTTP2Test.java @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * ReferenceTracker + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * BasicHTTP2Test + * @summary Basic HTTP/2 test when HTTP/3 is requested + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.DatagramSocket; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.Builder; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import javax.net.ssl.SSLContext; + +import jdk.test.lib.net.SimpleSSLContext; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import org.testng.ITestContext; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import static java.lang.System.out; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class BasicHTTP2Test implements HttpServerAdapters { + + SSLContext sslContext; + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + String https2URI; + DatagramSocket udp; + + // a shared executor helps reduce the amount of threads created by the test + static final Executor executor = new TestExecutor(Executors.newCachedThreadPool()); + static final ConcurrentMap FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong serverCount = new AtomicLong(); + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + private volatile HttpClient sharedClient; + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + System.out.printf(now() + "Task %s failed: %s%n", id, t); + System.err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + protected boolean stopAfterFirstFailure() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); + } + + @BeforeMethod + void beforeMethod(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + var x = new SkipException("Skipping: some test failed"); + x.setStackTrace(new StackTraceElement[0]); + throw x; + } + } + + @AfterClass + static final void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d servers and %d clients%n", + now(), serverCount.get(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.entrySet().forEach((e) -> { + out.printf("\t%s: %s%n", e.getKey(), e.getValue()); + e.getValue().printStackTrace(out); + e.getValue().printStackTrace(); + }); + if (tasksFailed) { + System.out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + private String[] uris() { + return new String[] { + https2URI, + }; + } + + static AtomicLong URICOUNT = new AtomicLong(); + + private HttpClient makeNewClient() { + clientCount.incrementAndGet(); + HttpClient client = HttpClient.newBuilder() + .proxy(HttpClient.Builder.NO_PROXY) + .executor(executor) + .sslContext(sslContext) + .connectTimeout(Duration.ofSeconds(10)) + .build(); + return TRACKER.track(client); + } + + HttpClient newHttpClient(boolean share) { + if (!share) return makeNewClient(); + HttpClient shared = sharedClient; + if (shared != null) return shared; + synchronized (this) { + shared = sharedClient; + if (shared == null) { + shared = sharedClient = makeNewClient(); + } + return shared; + } + } + + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf ("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf ("Test failed: wrong string %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + + @Test + public void testH2() throws Exception { + + System.err.println("XXXXX ====== xxxxx first xxxxx ====== XXXXX"); + HttpClient client = makeNewClient(); + URI uri = URI.create(https2URI); + Builder builder = HttpRequest.newBuilder(uri) + .version(Version.HTTP_3) + .GET(); + if (udp == null) { + out.println("Using config " + ALT_SVC); + builder.setOption(H3_DISCOVERY, ALT_SVC); + } + HttpRequest request = builder.build(); + + // send first request: that will go through regular HTTP/2. + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + assertEquals(response.statusCode(), 200, "first response status"); + assertEquals(response.version(), HTTP_2, "first response version"); + + Thread.sleep(1000); + + // send second request: we still will not find an endpoint in the + // AltServicesRegistry and will send everything to an + // HTTP/2 connection + System.err.println("XXXXX ====== xxxxx second xxxxx ====== XXXXX"); + response = client.send(request, BodyHandlers.ofString()); + out.println("Response #2: " + response); + out.println("Version #2: " + response.version()); + assertEquals(response.statusCode(), 200, "second response status"); + assertEquals(response.version(), HTTP_2, "second response version"); + + } + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + + // HTTP/2 + HttpTestHandler handler = new Handler(); + HttpTestHandler h3Handler = new Handler(); + + https2TestServer = HttpTestServer.create(HTTP_2, sslContext); + https2TestServer.addHandler(handler, "/https2/test204/"); + https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/test204/x"; + try { + // attempt to prevent any other server/process to use this port + // for UDP + udp = new DatagramSocket(https2TestServer.getAddress()); + } catch (Exception x) { + System.out.println("Failed to allocate UDP socket at: " + https2TestServer.getAddress()); + udp = null; + } + + serverCount.addAndGet(1); + https2TestServer.start(); + } + + @AfterTest + public void teardown() throws Exception { + String sharedClientName = + sharedClient == null ? null : sharedClient.toString(); + sharedClient = null; + Thread.sleep(100); + AssertionError fail = TRACKER.check(500); + try { + if (udp != null) udp.close(); + https2TestServer.stop(); + } finally { + if (fail != null) { + if (sharedClientName != null) { + System.err.println("Shared client name is: " + sharedClientName); + } + throw fail; + } + } + } + + static class Handler implements HttpTestHandler { + + public Handler() {} + + volatile int invocation = 0; + + @Override + public void handle(HttpTestExchange t) + throws IOException { + try { + URI uri = t.getRequestURI(); + System.err.printf("Handler received request for %s\n", uri); + String type = uri.getScheme().toLowerCase(); + InputStream is = t.getRequestBody(); + while (is.read() != -1); + is.close(); + + + if ((invocation++ % 2) == 1) { + System.err.printf("Server sending %d - chunked\n", 200); + t.sendResponseHeaders(200, -1); + OutputStream os = t.getResponseBody(); + os.close(); + } else { + System.err.printf("Server sending %d - 0 length\n", 200); + t.sendResponseHeaders(200, 0); + } + } catch (Throwable e) { + e.printStackTrace(System.err); + throw new IOException(e); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/BasicHTTP3Test.java b/test/jdk/java/net/httpclient/BasicHTTP3Test.java new file mode 100644 index 00000000000..bd31b932b80 --- /dev/null +++ b/test/jdk/java/net/httpclient/BasicHTTP3Test.java @@ -0,0 +1,482 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.Builder; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.ITestContext; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; + +import static java.lang.System.out; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; + + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * ReferenceTracker + * jdk.httpclient.test.lib.quic.QuicStandaloneServer + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * -Djavax.net.debug=all + * BasicHTTP3Test + * @summary Basic HTTP/3 test + */ +public class BasicHTTP3Test implements HttpServerAdapters { + + SSLContext sslContext; + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + String https2URI; + HttpTestServer h3TestServer; // HTTP/2 ( h2 + h3) + String h3URI; + HttpTestServer h3qv2TestServer; // HTTP/2 ( h2 + h3 on Quic v2, incompatible nego) + String h3URIQv2; + HttpTestServer h3qv2CTestServer; // HTTP/2 ( h2 + h3 on Quic v2, compatible nego) + String h3URIQv2C; + HttpTestServer h3mtlsTestServer; // HTTP/2 ( h2 + h3), h3 requires client cert + String h3mtlsURI; + HttpTestServer h3TestServerWithRetry; // h3 + String h3URIRetry; + HttpTestServer h3TestServerWithTLSHelloRetry; // h3 + String h3URITLSHelloRetry; + + static final int ITERATION_COUNT = 4; + // a shared executor helps reduce the amount of threads created by the test + static final Executor executor = new TestExecutor(Executors.newCachedThreadPool()); + static final ConcurrentMap FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + private volatile HttpClient sharedClient; + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @java.lang.Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + System.out.printf(now() + "Task %s failed: %s%n", id, t); + System.err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + protected boolean stopAfterFirstFailure() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); + } + + @BeforeMethod + void beforeMethod(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + var x = new SkipException("Skipping: some test failed"); + x.setStackTrace(new StackTraceElement[0]); + throw x; + } + } + + @AfterClass + static final void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d clients%n", + now(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.entrySet().forEach((e) -> { + out.printf("\t%s: %s%n", e.getKey(), e.getValue()); + e.getValue().printStackTrace(out); + e.getValue().printStackTrace(); + }); + if (tasksFailed) { + System.out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + private String[] uris() { + return new String[] { + https2URI, + h3URI + }; + } + + @DataProvider(name = "variants") + public Object[][] variants(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + String[] uris = uris(); + Object[][] result = new Object[uris.length * 2 * 2][]; + int i = 0; + for (var version : List.of(Optional.empty(), Optional.of(Version.HTTP_3))) { + for (boolean sameClient : List.of(false, true)) { + for (String uri : uris()) { + result[i++] = new Object[]{uri, sameClient, version}; + } + } + } + assert i == uris.length * 2 * 2; + return result; + } + + @DataProvider(name = "h3URIs") + public Object[][] versions(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + Object[][] result = { + {h3URI}, {h3URIRetry}, + {h3URIQv2}, {h3URIQv2C}, + {h3mtlsURI}, {h3URITLSHelloRetry}, + }; + return result; + } + + private HttpClient makeNewClient() { + clientCount.incrementAndGet(); + HttpClient client = newClientBuilderForH3() + .version(Version.HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY) + .executor(executor) + .sslContext(sslContext) + .connectTimeout(Duration.ofSeconds(10)) + .build(); + return TRACKER.track(client); + } + + HttpClient newHttpClient(boolean share) { + if (!share) return makeNewClient(); + HttpClient shared = sharedClient; + if (shared != null) return shared; + synchronized (this) { + shared = sharedClient; + if (shared == null) { + shared = sharedClient = makeNewClient(); + } + return shared; + } + } + + @Test(dataProvider = "variants") + public void test(String uri, boolean sameClient, Optional version) throws Exception { + System.out.println("Request to " + uri); + + HttpClient client = newHttpClient(sameClient); + + Builder builder = HttpRequest.newBuilder(URI.create(uri)) + .GET(); + version.ifPresent(builder::version); + for (int i = 0; i < ITERATION_COUNT; i++) { + // don't want to attempt direct connection as there could be another + // HTTP/3 endpoint listening at the URI port. + // sameClient should be fine because version.empty() should + // have come first and populated alt-services. + builder.setOption(H3_DISCOVERY, ALT_SVC); + HttpRequest request = builder.build(); + System.out.println("Iteration: " + i); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response: " + response); + out.println("Version: " + response.version()); + int expectedResponse = 200; + if (response.statusCode() != expectedResponse) + throw new RuntimeException("wrong response code " + response.statusCode()); + } + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + System.out.println("test: DONE"); + } + + @Test(dataProvider = "h3URIs") + public void testH3(final String h3URI) throws Exception { + HttpClient client = makeNewClient(); + URI uri = URI.create(h3URI); + Builder builder = HttpRequest.newBuilder(uri) + .version(HTTP_2) + .GET(); + HttpRequest request = builder.build(); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + assertEquals(response.statusCode(), 200, "first response status"); + assertEquals(response.version(), HTTP_2, "first response version"); + + request = builder.version(Version.HTTP_3).build(); + response = client.send(request, BodyHandlers.ofString()); + out.println("Response #2: " + response); + out.println("Version #2: " + response.version()); + assertEquals(response.statusCode(), 200, "second response status"); + assertEquals(response.version(), Version.HTTP_3, "second response version"); + + if (h3URI == h3mtlsURI) { + assertNotNull(response.sslSession().get().getLocalCertificates()); + } else { + assertNull(response.sslSession().get().getLocalCertificates()); + } + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + + // verify that the client handles HTTP/3 reset stream correctly + @Test + public void testH3Reset() throws Exception { + HttpClient client = makeNewClient(); + URI uri = URI.create(h3URI); + Builder builder = HttpRequest.newBuilder(uri) + .version(HTTP_2) + .GET(); + HttpRequest request = builder.build(); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + assertEquals(response.statusCode(), 200, "first response status"); + assertEquals(response.version(), HTTP_2, "first response version"); + + // instruct the server side handler to throw an exception + // that then causes the test server to reset the stream + final String resetCausingURI = h3URI + "?handlerShouldThrow=true"; + builder = HttpRequest.newBuilder(URI.create(resetCausingURI)) + .GET(); + request = builder.version(Version.HTTP_3) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + try { + response = client.send(request, BodyHandlers.ofString()); + throw new RuntimeException("Unexpectedly received a response instead of an exception," + + " response: " + response); + } catch (IOException e) { + final String msg = e.getMessage(); + if (msg == null || !msg.contains("reset by peer")) { + // unexpected message in the exception, propagate the exception + throw e; + } + } + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + https2TestServer = HttpTestServer.create(HTTP_2, sslContext); + https2TestServer.addHandler(new Handler(), "/https2/test/"); + https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/test/x"; + + // A HTTP2 server with H3 enabled on a different host:port than the HTTP2 server + h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + final HttpTestHandler h3Handler = new Handler(); + h3TestServer.addHandler(h3Handler, "/h3/testH3/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3/testH3/h3"; + assertTrue(h3TestServer.canHandle(HTTP_2, Version.HTTP_3), "Server was expected" + + " to handle both HTTP2 and HTTP3, but doesn't"); + + // A HTTP2 server with H3 QUICv2 enabled on a different host:port than the HTTP2 server + final Http2TestServer h2q2Server = new Http2TestServer("localhost", true, sslContext) + .enableH3AltServiceOnEphemeralPortWithVersion(QuicVersion.QUIC_V2, false); + h3qv2TestServer = HttpTestServer.of(h2q2Server); + h3qv2TestServer.addHandler(h3Handler, "/h3/testH3/"); + h3URIQv2 = "https://" + h3qv2TestServer.serverAuthority() + "/h3/testH3/h3qv2";; + assertTrue(h3qv2TestServer.canHandle(HTTP_2, Version.HTTP_3), "Server was expected" + + " to handle both HTTP2 and HTTP3, but doesn't"); + + // A HTTP2 server with H3 QUICv2 compatible negotiation enabled on a different host:port than the HTTP2 server + final Http2TestServer h2q2CServer = new Http2TestServer("localhost", true, sslContext) + .enableH3AltServiceOnEphemeralPortWithVersion(QuicVersion.QUIC_V2, true); + h3qv2CTestServer = HttpTestServer.of(h2q2CServer); + h3qv2CTestServer.addHandler(h3Handler, "/h3/testH3/"); + h3URIQv2C = "https://" + h3qv2CTestServer.serverAuthority() + "/h3/testH3/h3qv2c";; + assertTrue(h3qv2CTestServer.canHandle(HTTP_2, Version.HTTP_3), "Server was expected" + + " to handle both HTTP2 and HTTP3, but doesn't"); + + // A HTTP2 server with H3 enabled on a different host:port than the HTTP2 server + // H3 server requires the client to authenticate with a certificate + h3mtlsTestServer = HttpTestServer.create(HTTP_3, sslContext); + h3mtlsTestServer.addHandler(h3Handler, "/h3/testH3/"); + h3mtlsTestServer.getH3AltService().get().getQuicServer().setNeedClientAuth(true); + h3mtlsURI = "https://" + h3mtlsTestServer.serverAuthority() + "/h3/testH3/h3mtls"; + assertTrue(h3mtlsTestServer.canHandle(HTTP_2, Version.HTTP_3), "Server was expected" + + " to handle both HTTP2 and HTTP3, but doesn't"); + + // A HTTP2 test server with H3 alt service listening on different host:port + // and the underlying quic server for H3 is configured to send a RETRY packet + final Http2TestServer h2Server = new Http2TestServer("localhost", true, sslContext) + .enableH3AltServiceOnEphemeralPort(); + // configure send retry on QUIC server + h2Server.getH3AltService().get().getQuicServer().sendRetry(true); + h3TestServerWithRetry = HttpTestServer.of(h2Server); + h3TestServerWithRetry.addHandler(h3Handler, "/h3/testH3Retry/"); + h3URIRetry = "https://" + h3TestServerWithRetry.serverAuthority() + "/h3/testH3Retry/x"; + + // A HTTP2 server with H3 enabled on a different host:port than the HTTP2 server + // TLS server rejects X25519 and secp256r1 key shares, + // which forces a hello retry at the moment of writing this test. + h3TestServerWithTLSHelloRetry = HttpTestServer.create(HTTP_3, sslContext); + h3TestServerWithTLSHelloRetry.addHandler(h3Handler, "/h3/testH3tlsretry/"); + h3TestServerWithTLSHelloRetry.getH3AltService().get().getQuicServer().setRejectKeyAgreement(Set.of("x25519", "secp256r1")); + h3URITLSHelloRetry = "https://" + h3TestServerWithTLSHelloRetry.serverAuthority() + "/h3/testH3tlsretry/x"; + assertTrue(h3TestServerWithTLSHelloRetry.canHandle(HTTP_2, Version.HTTP_3), "Server was expected" + + " to handle both HTTP2 and HTTP3, but doesn't"); + + https2TestServer.start(); + h3TestServer.start(); + h3qv2TestServer.start(); + h3qv2CTestServer.start(); + h3mtlsTestServer.start(); + h3TestServerWithRetry.start(); + h3TestServerWithTLSHelloRetry.start(); + } + + @AfterTest + public void teardown() throws Exception { + System.err.println("======================================================="); + System.err.println(" Tearing down test"); + System.err.println("======================================================="); + String sharedClientName = + sharedClient == null ? null : sharedClient.toString(); + sharedClient = null; + Thread.sleep(100); + AssertionError fail = TRACKER.check(500); + try { + https2TestServer.stop(); + h3TestServer.stop(); + h3qv2CTestServer.stop(); + h3qv2TestServer.stop(); + h3mtlsTestServer.stop(); + h3TestServerWithRetry.stop(); + h3TestServerWithTLSHelloRetry.stop(); + } finally { + if (fail != null) { + if (sharedClientName != null) { + System.err.println("Shared client name is: " + sharedClientName); + } + throw fail; + } + } + } + + static class Handler implements HttpTestHandler { + + public Handler() {} + + volatile int invocation = 0; + + @java.lang.Override + public void handle(HttpTestExchange t) + throws IOException { + try { + URI uri = t.getRequestURI(); + System.err.printf("Handler received request for %s\n", uri); + final String query = uri.getQuery(); + if (query != null && query.contains("handlerShouldThrow=true")) { + System.err.printf("intentionally throwing an exception for request %s\n", uri); + throw new RuntimeException("intentionally thrown by handler for request " + uri); + } + try (InputStream is = t.getRequestBody()) { + is.readAllBytes(); + } + if ((invocation++ % 2) == 1) { + System.err.printf("Server sending %d - chunked\n", 200); + t.sendResponseHeaders(200, -1); + OutputStream os = t.getResponseBody(); + os.close(); + } else { + System.err.printf("Server sending %d - 0 length\n", 200); + t.sendResponseHeaders(200, 0); + } + } catch (Throwable e) { + e.printStackTrace(System.err); + throw new IOException(e); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/BasicRedirectTest.java b/test/jdk/java/net/httpclient/BasicRedirectTest.java index 8ea1653b4d0..a19b1444ac6 100644 --- a/test/jdk/java/net/httpclient/BasicRedirectTest.java +++ b/test/jdk/java/net/httpclient/BasicRedirectTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,25 +28,25 @@ * @build jdk.httpclient.test.lib.common.HttpServerAdapters jdk.test.lib.net.SimpleSSLContext * @run testng/othervm * -Djdk.httpclient.HttpClient.log=trace,headers,requests + * -Djdk.internal.httpclient.debug=true * BasicRedirectTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import java.net.http.HttpResponse.BodyHandlers; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; import javax.net.ssl.SSLContext; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; @@ -56,11 +56,16 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; + public class BasicRedirectTest implements HttpServerAdapters { SSLContext sslContext; @@ -68,14 +73,21 @@ public class BasicRedirectTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; - String httpURIToMoreSecure; // redirects HTTP to HTTPS + String httpURIToMoreSecure; // redirects HTTP to HTTPS + String httpURIToH3MoreSecure; // redirects HTTP to HTTPS/3 String httpsURI; String httpsURIToLessSecure; // redirects HTTPS to HTTP String http2URI; String http2URIToMoreSecure; // redirects HTTP to HTTPS + String http2URIToH3MoreSecure; // redirects HTTP to HTTPS/3 String https2URI; String https2URIToLessSecure; // redirects HTTPS to HTTP + String https3URI; + String https3HeadURI; + String http3URIToLessSecure; // redirects HTTP3 to HTTP + String http3URIToH2cLessSecure; // redirects HTTP3 to h2c static final String MESSAGE = "Is fearr Gaeilge briste, na Bearla cliste"; static final int ITERATIONS = 3; @@ -83,31 +95,67 @@ public class BasicRedirectTest implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, Redirect.ALWAYS }, - { httpsURI, Redirect.ALWAYS }, - { http2URI, Redirect.ALWAYS }, - { https2URI, Redirect.ALWAYS }, - { httpURIToMoreSecure, Redirect.ALWAYS }, - { http2URIToMoreSecure, Redirect.ALWAYS }, - { httpsURIToLessSecure, Redirect.ALWAYS }, - { https2URIToLessSecure, Redirect.ALWAYS }, + { httpURI, Redirect.ALWAYS, Optional.empty() }, + { httpsURI, Redirect.ALWAYS, Optional.empty() }, + { http2URI, Redirect.ALWAYS, Optional.empty() }, + { https2URI, Redirect.ALWAYS, Optional.empty() }, + { https3URI, Redirect.ALWAYS, Optional.of(HTTP_3) }, + { httpURIToMoreSecure, Redirect.ALWAYS, Optional.empty() }, + { httpURIToH3MoreSecure, Redirect.ALWAYS, Optional.of(HTTP_3) }, + { http2URIToMoreSecure, Redirect.ALWAYS, Optional.empty() }, + { http2URIToH3MoreSecure, Redirect.ALWAYS, Optional.of(HTTP_3) }, + { httpsURIToLessSecure, Redirect.ALWAYS, Optional.empty() }, + { https2URIToLessSecure, Redirect.ALWAYS, Optional.empty() }, + { http3URIToLessSecure, Redirect.ALWAYS, Optional.of(HTTP_3) }, + { http3URIToH2cLessSecure, Redirect.ALWAYS, Optional.of(HTTP_3) }, - { httpURI, Redirect.NORMAL }, - { httpsURI, Redirect.NORMAL }, - { http2URI, Redirect.NORMAL }, - { https2URI, Redirect.NORMAL }, - { httpURIToMoreSecure, Redirect.NORMAL }, - { http2URIToMoreSecure, Redirect.NORMAL }, + { httpURI, Redirect.NORMAL, Optional.empty() }, + { httpsURI, Redirect.NORMAL, Optional.empty() }, + { http2URI, Redirect.NORMAL, Optional.empty() }, + { https2URI, Redirect.NORMAL, Optional.empty() }, + { https3URI, Redirect.NORMAL, Optional.of(HTTP_3) }, + { httpURIToMoreSecure, Redirect.NORMAL, Optional.empty() }, + { http2URIToMoreSecure, Redirect.NORMAL, Optional.empty() }, + { httpURIToH3MoreSecure, Redirect.NORMAL, Optional.of(HTTP_3) }, + { http2URIToH3MoreSecure, Redirect.NORMAL, Optional.of(HTTP_3) }, }; } - @Test(dataProvider = "positive") - void test(String uriString, Redirect redirectPolicy) throws Exception { - out.printf("%n---- starting positive (%s, %s) ----%n", uriString, redirectPolicy); - HttpClient client = HttpClient.newBuilder() + HttpClient createClient(Redirect redirectPolicy, Optional version) throws Exception { + var clientBuilder = newClientBuilderForH3() .followRedirects(redirectPolicy) - .sslContext(sslContext) + .sslContext(sslContext); + HttpClient client = version.map(clientBuilder::version) + .orElse(clientBuilder) .build(); + if (version.stream().anyMatch(HTTP_3::equals)) { + var builder = HttpRequest.newBuilder(URI.create(https3HeadURI)) + .setOption(H3_DISCOVERY, ALT_SVC); + var head = builder.copy().HEAD().version(HTTP_2).build(); + var get = builder.copy().GET().build(); + out.printf("%n---- sending initial head request (%s) -----%n", head.uri()); + var resp = client.send(head, BodyHandlers.ofString()); + assertEquals(resp.statusCode(), 200); + assertEquals(resp.version(), HTTP_2); + out.println("HEADERS: " + resp.headers()); + var length = resp.headers().firstValueAsLong("Content-Length") + .orElseThrow(AssertionError::new); + if (length < 0) throw new AssertionError("negative length " + length); + out.printf("%n---- sending initial HTTP/3 GET request (%s) -----%n", get.uri()); + resp = client.send(get, BodyHandlers.ofString()); + assertEquals(resp.statusCode(), 200); + assertEquals(resp.version(), HTTP_3); + assertEquals(resp.body().getBytes(UTF_8).length, length, + "body \"" + resp.body() + "\": "); + } + return client; + } + + @Test(dataProvider = "positive") + void test(String uriString, Redirect redirectPolicy, Optional clientVersion) throws Exception { + out.printf("%n---- starting positive (%s, %s, %s) ----%n", uriString, redirectPolicy, + clientVersion.map(Version::name).orElse("empty")); + HttpClient client = createClient(redirectPolicy, clientVersion); URI uri = URI.create(uriString); HttpRequest request = HttpRequest.newBuilder(uri).build(); @@ -125,20 +173,24 @@ public class BasicRedirectTest implements HttpServerAdapters { assertEquals(response.body(), MESSAGE); // asserts redirected URI in response.request().uri() assertTrue(response.uri().getPath().endsWith("message")); - assertPreviousRedirectResponses(request, response); + assertPreviousRedirectResponses(request, response, clientVersion); } } static void assertPreviousRedirectResponses(HttpRequest initialRequest, - HttpResponse finalResponse) { + HttpResponse finalResponse, + Optional clientVersion) { // there must be at least one previous response finalResponse.previousResponse() .orElseThrow(() -> new RuntimeException("no previous response")); HttpResponse response = finalResponse; + List versions = new ArrayList<>(); + versions.add(response.version()); do { URI uri = response.uri(); response = response.previousResponse().get(); + versions.add(response.version()); assertTrue(300 <= response.statusCode() && response.statusCode() <= 309, "Expected 300 <= code <= 309, got:" + response.statusCode()); assertEquals(response.body(), null, "Unexpected body: " + response.body()); @@ -153,6 +205,11 @@ public class BasicRedirectTest implements HttpServerAdapters { assertEquals(initialRequest, response.request(), String.format("Expected initial request [%s] to equal last prev req [%s]", initialRequest, response.request())); + if (clientVersion.stream().anyMatch(HTTP_3::equals)) { + out.println(versions.stream().map(Version::name) + .collect(Collectors.joining(" <-- ", "Redirects: ", ";"))); + assertTrue(versions.stream().anyMatch(HTTP_3::equals), "at least one version should be HTTP/3"); + } } // -- negatives @@ -160,27 +217,33 @@ public class BasicRedirectTest implements HttpServerAdapters { @DataProvider(name = "negative") public Object[][] negative() { return new Object[][] { - { httpURI, Redirect.NEVER }, - { httpsURI, Redirect.NEVER }, - { http2URI, Redirect.NEVER }, - { https2URI, Redirect.NEVER }, - { httpURIToMoreSecure, Redirect.NEVER }, - { http2URIToMoreSecure, Redirect.NEVER }, - { httpsURIToLessSecure, Redirect.NEVER }, - { https2URIToLessSecure, Redirect.NEVER }, + { httpURI, Redirect.NEVER, Optional.empty() }, + { httpsURI, Redirect.NEVER, Optional.empty() }, + { http2URI, Redirect.NEVER, Optional.empty() }, + { https2URI, Redirect.NEVER, Optional.empty() }, + { https3URI, Redirect.NEVER, Optional.of(HTTP_3) }, + { httpURIToMoreSecure, Redirect.NEVER, Optional.empty() }, + { http2URIToMoreSecure, Redirect.NEVER, Optional.empty() }, + { httpURIToH3MoreSecure, Redirect.NEVER, Optional.of(HTTP_3) }, + { http2URIToH3MoreSecure, Redirect.NEVER, Optional.of(HTTP_3) }, + { httpsURIToLessSecure, Redirect.NEVER, Optional.empty() }, + { https2URIToLessSecure, Redirect.NEVER, Optional.empty() }, + { http3URIToLessSecure, Redirect.NEVER, Optional.of(HTTP_3) }, + { http3URIToH2cLessSecure, Redirect.NEVER, Optional.of(HTTP_3) }, - { httpsURIToLessSecure, Redirect.NORMAL }, - { https2URIToLessSecure, Redirect.NORMAL }, + { httpsURIToLessSecure, Redirect.NORMAL, Optional.empty() }, + { https2URIToLessSecure, Redirect.NORMAL, Optional.empty() }, + { http3URIToLessSecure, Redirect.NORMAL, Optional.of(HTTP_3) }, + { http3URIToH2cLessSecure, Redirect.NORMAL, Optional.of(HTTP_3) }, }; } @Test(dataProvider = "negative") - void testNegatives(String uriString,Redirect redirectPolicy) throws Exception { - out.printf("%n---- starting negative (%s, %s) ----%n", uriString, redirectPolicy); - HttpClient client = HttpClient.newBuilder() - .followRedirects(redirectPolicy) - .sslContext(sslContext) - .build(); + void testNegatives(String uriString, Redirect redirectPolicy, Optional clientVersion) + throws Exception { + out.printf("%n---- starting negative (%s, %s, %s) ----%n", uriString, redirectPolicy, + clientVersion.map(Version::name).orElse("empty")); + HttpClient client = createClient(redirectPolicy, clientVersion); URI uri = URI.create(uriString); HttpRequest request = HttpRequest.newBuilder(uri).build(); @@ -225,13 +288,25 @@ public class BasicRedirectTest implements HttpServerAdapters { https2TestServer.addHandler(new BasicHttpRedirectHandler(), "/https2/same/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/same/redirect"; + http3TestServer = HttpTestServer.create(ANY, sslContext); + http3TestServer.addHandler(new BasicHttpRedirectHandler(), "/http3/same/"); + https3URI = "https://" + http3TestServer.serverAuthority() + "/http3/same/redirect"; + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + https3HeadURI = "https://" + http3TestServer.serverAuthority() + "/http3/head"; + // HTTP to HTTPS redirect handler httpTestServer.addHandler(new ToSecureHttpRedirectHandler(httpsURI), "/http1/toSecure/"); httpURIToMoreSecure = "http://" + httpTestServer.serverAuthority()+ "/http1/toSecure/redirect"; + // HTTP to HTTP/3 redirect handler + httpTestServer.addHandler(new ToSecureHttpRedirectHandler(https3URI), "/http1/toSecureH3/"); + httpURIToH3MoreSecure = "http://" + httpTestServer.serverAuthority()+ "/http1/toSecureH3/redirect"; // HTTP2 to HTTP2S redirect handler http2TestServer.addHandler(new ToSecureHttpRedirectHandler(https2URI), "/http2/toSecure/"); http2URIToMoreSecure = "http://" + http2TestServer.serverAuthority() + "/http2/toSecure/redirect"; + // HTTP2 to HTTP2S redirect handler + http2TestServer.addHandler(new ToSecureHttpRedirectHandler(https3URI), "/http2/toSecureH3/"); + http2URIToH3MoreSecure = "http://" + http2TestServer.serverAuthority() + "/http2/toSecureH3/redirect"; // HTTPS to HTTP redirect handler httpsTestServer.addHandler(new ToLessSecureRedirectHandler(httpURI), "/https1/toLessSecure/"); @@ -239,11 +314,19 @@ public class BasicRedirectTest implements HttpServerAdapters { // HTTPS2 to HTTP2 redirect handler https2TestServer.addHandler(new ToLessSecureRedirectHandler(http2URI), "/https2/toLessSecure/"); https2URIToLessSecure = "https://" + https2TestServer.serverAuthority() + "/https2/toLessSecure/redirect"; + // HTTP3 to HTTP redirect handler + http3TestServer.addHandler(new ToLessSecureRedirectHandler(httpURI), "/http3/toLessSecure/"); + http3URIToLessSecure = "https://" + http3TestServer.serverAuthority() + "/http3/toLessSecure/redirect"; + // HTTP3 to HTTP2 redirect handler + http3TestServer.addHandler(new ToLessSecureRedirectHandler(http2URI), "/http3/toLessSecureH2/"); + http3URIToH2cLessSecure = "https://" + http3TestServer.serverAuthority() + "/http3/toLessSecureH2/redirect"; httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); + createClient(Redirect.NEVER, Optional.of(HTTP_3)); } @AfterTest diff --git a/test/jdk/java/net/httpclient/CancelRequestTest.java b/test/jdk/java/net/httpclient/CancelRequestTest.java index 7851b112498..bfc1eff9cf9 100644 --- a/test/jdk/java/net/httpclient/CancelRequestTest.java +++ b/test/jdk/java/net/httpclient/CancelRequestTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -49,6 +49,7 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLHandshakeException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -57,6 +58,7 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpConnectTimeoutException; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; @@ -79,8 +81,9 @@ import jdk.httpclient.test.lib.common.HttpServerAdapters; import static java.lang.System.out; import static java.lang.System.err; -import static java.net.http.HttpClient.Version.HTTP_1_1; -import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.*; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -95,10 +98,15 @@ public class CancelRequestTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/3 ( h2 + h3 ) + HttpTestServer h3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h2h3Head; + String h3URI; static final long SERVER_LATENCY = 75; static final int MAX_CLIENT_DELAY = 75; @@ -200,6 +208,8 @@ public class CancelRequestTest implements HttpServerAdapters { httpsURI, http2URI, https2URI, + h2h3URI, + h3URI, }; } @@ -245,7 +255,7 @@ public class CancelRequestTest implements HttpServerAdapters { private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return TRACKER.track(HttpClient.newBuilder() + return TRACKER.track(newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -265,6 +275,17 @@ public class CancelRequestTest implements HttpServerAdapters { } } + // set HTTP/3 version on the request when targeting + // an HTTP/3 server + private HttpRequest.Builder requestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (uri.contains("h3")) { + builder.version(HTTP_3); + } + return builder; + } + + final static String BODY = "Some string | that ? can | be split ? several | ways."; // should accept SSLHandshakeException because of the connectionAborter @@ -273,8 +294,12 @@ public class CancelRequestTest implements HttpServerAdapters { // rewrap in "Request Cancelled" when the multi exchange was aborted... private static boolean isCancelled(Throwable t) { while (t instanceof ExecutionException) t = t.getCause(); - if (t instanceof CancellationException) return true; - if (t instanceof IOException) return String.valueOf(t).contains("Request cancelled"); + Throwable cause = t; + while (cause != null) { + if (cause instanceof CancellationException) return true; + if (cause instanceof IOException && String.valueOf(cause).contains("Request cancelled")) return true; + cause = cause.getCause(); + } out.println("Not a cancellation exception: " + t); t.printStackTrace(out); return false; @@ -290,6 +315,15 @@ public class CancelRequestTest implements HttpServerAdapters { } } + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + @Test(dataProvider = "asyncurls") public void testGetSendAsync(String uri, boolean sameClient, boolean mayInterruptIfRunning) throws Exception { @@ -302,14 +336,19 @@ public class CancelRequestTest implements HttpServerAdapters { client = newHttpClient(sameClient); Tracker tracker = TRACKER.getTracker(client); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + // Populate alt-svc registry with h3 service + if (uri.contains("h2h3")) headRequest(client); + Http3DiscoveryMode config = uri.contains("h3-only") ? h3TestServer.h3DiscoveryConfig() : null; + HttpRequest req = requestBuilder(uri) .GET() + .setOption(H3_DISCOVERY, config) .build(); BodyHandler handler = BodyHandlers.ofString(); CountDownLatch latch = new CountDownLatch(1); CompletableFuture> response = client.sendAsync(req, handler); var cf1 = response.whenComplete((r,t) -> System.out.println(t)); CompletableFuture> cf2 = cf1.whenComplete((r,t) -> latch.countDown()); + out.println("iteration: " + i + ", req: " + req.uri()); out.println("response: " + response); out.println("cf1: " + cf1); out.println("cf2: " + cf2); @@ -352,10 +391,12 @@ public class CancelRequestTest implements HttpServerAdapters { Throwable wrapped = x.getCause(); Throwable cause = wrapped; if (mayInterruptIfRunning) { - assertTrue(CancellationException.class.isAssignableFrom(wrapped.getClass()), - "Unexpected exception: " + wrapped); - cause = wrapped.getCause(); - out.println("CancellationException cause: " + x); + if (CancellationException.class.isAssignableFrom(wrapped.getClass())) { + cause = wrapped.getCause(); + out.println("CancellationException cause: " + x); + } else if (!isCancelled(cause)) { + throw new RuntimeException("Unexpected cause: " + cause); + } if (cause instanceof HttpConnectTimeoutException) { cause.printStackTrace(out); throw new RuntimeException("Unexpected timeout exception", cause); @@ -426,8 +467,12 @@ public class CancelRequestTest implements HttpServerAdapters { } }; - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + // Populate alt-svc registry with h3 service + if (uri.contains("h2h3")) headRequest(client); + Http3DiscoveryMode config = uri.contains("h3-only") ? h3TestServer.h3DiscoveryConfig() : null; + HttpRequest req = requestBuilder(uri) .POST(HttpRequest.BodyPublishers.ofByteArrays(iterable)) + .setOption(H3_DISCOVERY, config) .build(); BodyHandler handler = BodyHandlers.ofString(); CountDownLatch latch = new CountDownLatch(1); @@ -473,8 +518,13 @@ public class CancelRequestTest implements HttpServerAdapters { } catch (ExecutionException x) { assertTrue(response.isDone()); Throwable wrapped = x.getCause(); - assertTrue(CancellationException.class.isAssignableFrom(wrapped.getClass())); - Throwable cause = wrapped.getCause(); + Throwable cause = wrapped; + if (CancellationException.class.isAssignableFrom(wrapped.getClass())) { + cause = wrapped.getCause(); + out.println("CancellationException cause: " + x); + } else if (!isCancelled(cause)) { + throw new RuntimeException("Unexpected cause: " + cause); + } assertTrue(IOException.class.isAssignableFrom(cause.getClass())); if (cause instanceof HttpConnectTimeoutException) { cause.printStackTrace(out); @@ -536,8 +586,12 @@ public class CancelRequestTest implements HttpServerAdapters { return List.of(BODY.getBytes(UTF_8)).iterator(); }; - HttpRequest req = HttpRequest.newBuilder(URI.create(uriStr)) + // Populate alt-svc registry with h3 service + if (uri.contains("h2h3")) headRequest(client); + Http3DiscoveryMode config = uri.contains("h3-only") ? h3TestServer.h3DiscoveryConfig() : null; + HttpRequest req = requestBuilder(uriStr) .POST(HttpRequest.BodyPublishers.ofByteArrays(iterable)) + .setOption(H3_DISCOVERY, config) .build(); String body = null; Exception failed = null; @@ -613,11 +667,24 @@ public class CancelRequestTest implements HttpServerAdapters { https2TestServer.addHandler(h2_chunkedHandler, "/https2/x/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/x/"; - serverCount.addAndGet(4); + HttpTestHandler h3_chunkedHandler = new HTTPSlowHandler(); + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(h3_chunkedHandler, "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(h3_chunkedHandler, "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + + serverCount.addAndGet(6); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -632,6 +699,8 @@ public class CancelRequestTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { diff --git a/test/jdk/java/net/httpclient/CancelStreamedBodyTest.java b/test/jdk/java/net/httpclient/CancelStreamedBodyTest.java index e0199b18e68..e304168ba33 100644 --- a/test/jdk/java/net/httpclient/CancelStreamedBodyTest.java +++ b/test/jdk/java/net/httpclient/CancelStreamedBodyTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -32,11 +32,6 @@ * @run testng/othervm -Djdk.internal.httpclient.debug=true * CancelStreamedBodyTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; -import jdk.internal.net.http.common.OperationTrackers.Tracker; -import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; import org.testng.ITestContext; import org.testng.ITestResult; @@ -53,25 +48,15 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.lang.ref.Reference; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; -import java.net.http.HttpConnectTimeoutException; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; import java.util.Arrays; -import java.util.Iterator; import java.util.List; -import java.util.Random; -import java.util.concurrent.CancellationException; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicLong; @@ -79,12 +64,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import static java.lang.System.arraycopy; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -97,10 +83,12 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String https3URI; static final long SERVER_LATENCY = 75; static final int ITERATION_COUNT = 3; @@ -198,6 +186,7 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { private String[] uris() { return new String[] { + https3URI, httpURI, httpsURI, http2URI, @@ -221,9 +210,9 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { return result; } - private HttpClient makeNewClient() { + private HttpClient makeNewClient(HttpClient.Builder builder) { clientCount.incrementAndGet(); - var client = HttpClient.newBuilder() + var client = builder .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -236,21 +225,45 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { return TRACKER.track(client); } - HttpClient newHttpClient(boolean share) { - if (!share) return makeNewClient(); + private Version version(String uri) { + if (uri == null) return null; + if (uri.contains("/http3/")) return HTTP_3; + if (uri.contains("/http2/")) return HTTP_2; + if (uri.contains("/https2/")) return HTTP_2; + if (uri.contains("/http1/")) return HTTP_1_1; + if (uri.contains("/https1/")) return HTTP_1_1; + return null; + } + + HttpClient makeNewClient(Version version) { + var builder = (version == HTTP_3) + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return makeNewClient(builder); + } + + HttpClient newHttpClient(boolean share, String uri) { + if (!share) return makeNewClient(version(uri)); HttpClient shared = sharedClient; if (shared != null) return shared; synchronized (this) { shared = sharedClient; if (shared == null) { - shared = sharedClient = makeNewClient(); + shared = sharedClient = makeNewClient(HTTP_3); } return shared; } } - final static String BODY = "Some string |\n that ?\n can |\n be split ?\n several |\n ways."; + HttpRequest.Builder requestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + var version = version(uri); + return version == HTTP_3 + ? builder.version(HTTP_3).setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + : builder; + } + final static String BODY = "Some string |\n that ?\n can |\n be split ?\n several |\n ways."; @Test(dataProvider = "urls") public void testAsLines(String uri, boolean sameClient) @@ -261,10 +274,10 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { out.printf("%n%s testAsLines(%s, %b)%n", now(), uri, sameClient); for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(sameClient); + client = newHttpClient(sameClient, uri); var tracker = TRACKER.getTracker(client); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = requestBuilder(uri) .GET() .build(); List lines; @@ -302,10 +315,10 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { out.printf("%n%s testInputStream(%s, %b)%n", now(), uri, sameClient); for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(sameClient); + client = newHttpClient(sameClient, uri); var tracker = TRACKER.getTracker(client); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = requestBuilder(uri) .GET() .build(); int read = -1; @@ -318,7 +331,7 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { } // Only check our still alive client for outstanding operations // and outstanding subscribers here: it should have none. - var error = TRACKER.check(tracker, 1, + var error = TRACKER.check(tracker, 500, (t) -> t.getOutstandingOperations() > 0 || t.getOutstandingSubscribers() > 0, "subscribers for testInputStream(%s)\n\t step [%s,%s]".formatted(req.uri(), i,j), false); @@ -364,11 +377,16 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { https2TestServer.addHandler(h2_chunkedHandler, "/https2/x/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/x/"; - serverCount.addAndGet(4); + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(h2_chunkedHandler, "/http3/x/"); + https3URI = "https://" + http3TestServer.serverAuthority() + "/http3/x/"; + + serverCount.addAndGet(5); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -386,6 +404,7 @@ public class CancelStreamedBodyTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { diff --git a/test/jdk/java/net/httpclient/http2/ExpectContinueResetTest.java b/test/jdk/java/net/httpclient/CancelledPartialResponseTest.java similarity index 52% rename from test/jdk/java/net/httpclient/http2/ExpectContinueResetTest.java rename to test/jdk/java/net/httpclient/CancelledPartialResponseTest.java index 5d09c01b96f..41493bbdaff 100644 --- a/test/jdk/java/net/httpclient/http2/ExpectContinueResetTest.java +++ b/test/jdk/java/net/httpclient/CancelledPartialResponseTest.java @@ -23,16 +23,19 @@ /* * @test - * @summary Verifies that the client reacts correctly to receiving RST_STREAM at various stages of - * a Partial Response. + * @summary Verifies that the client reacts correctly to receiving RST_STREAM or StopSendingFrame at various stages of + * a Partial/Expect-continue type Response for HTTP/2 and HTTP/3. * @bug 8309118 * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.httpclient.test.lib.common.HttpServerAdapters - * @run testng/othervm/timeout=40 -Djdk.internal.httpclient.debug=true -Djdk.httpclient.HttpClient.log=trace,errors,headers - * ExpectContinueResetTest + * @run testng/othervm/timeout=40 -Djdk.internal.httpclient.debug=false -Djdk.httpclient.HttpClient.log=trace,errors,headers + * CancelledPartialResponseTest */ import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; import jdk.httpclient.test.lib.http2.BodyOutputStream; import jdk.httpclient.test.lib.http2.Http2Handler; import jdk.httpclient.test.lib.http2.Http2TestExchange; @@ -42,12 +45,15 @@ import jdk.httpclient.test.lib.http2.Http2TestServerConnection; import jdk.internal.net.http.common.HttpHeadersBuilder; import jdk.internal.net.http.frame.ResetFrame; +import jdk.internal.net.http.http3.Http3Error; +import jdk.test.lib.net.SimpleSSLContext; import org.testng.TestException; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSession; import java.io.IOException; import java.io.InputStream; @@ -55,47 +61,58 @@ import java.io.OutputStream; import java.io.PrintStream; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpHeaders; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.util.Iterator; import java.util.concurrent.ExecutionException; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; -import static org.testng.Assert.*; -public class ExpectContinueResetTest { +public class CancelledPartialResponseTest { Http2TestServer http2TestServer; + + HttpTestServer http3TestServer; + // "NoError" urls complete with an exception. "NoError" or "Error" here refers to the error code in the RST_STREAM frame // and not the outcome of the test. - URI warmup, partialResponseResetNoError, partialResponseResetError, fullResponseResetNoError, fullResponseResetError; + URI warmup, h2PartialResponseResetNoError, h2PartialResponseResetError, h2FullResponseResetNoError, h2FullResponseResetError; + URI h3PartialResponseStopSending, h3FullResponseStopSending; + + SSLContext sslContext; static PrintStream err = new PrintStream(System.err); static PrintStream out = System.out; + // TODO: Investigate further if checking against HTTP/3 Full Response is necessary @DataProvider(name = "testData") public Object[][] testData() { - // Not consuming the InputStream in the server's handler results in different handling of RST_STREAM client-side return new Object[][] { - { partialResponseResetNoError }, - { partialResponseResetError }, // Checks RST_STREAM is processed if client sees no END_STREAM - { fullResponseResetNoError }, - { fullResponseResetError } + { HTTP_2, h2PartialResponseResetNoError }, + { HTTP_2, h2PartialResponseResetError }, // Checks RST_STREAM is processed if client sees no END_STREAM + { HTTP_2, h2FullResponseResetNoError }, + { HTTP_2, h2FullResponseResetError }, + { HTTP_3, h3PartialResponseStopSending }, // All StopSending frames received by client throw exception regardless of code + { HTTP_3, h3FullResponseStopSending } }; } @Test(dataProvider = "testData") - public void test(URI uri) { - out.printf("\nTesting with Version: %s, URI: %s\n", HTTP_2, uri.toASCIIString()); - err.printf("\nTesting with Version: %s, URI: %s\n", HTTP_2, uri.toASCIIString()); + public void test(Version version, URI uri) { + out.printf("\nTesting with Version: %s, URI: %s\n", version, uri.toASCIIString()); + err.printf("\nTesting with Version: %s, URI: %s\n", version, uri.toASCIIString()); Iterable iterable = EndlessDataChunks::new; HttpRequest.BodyPublisher testPub = HttpRequest.BodyPublishers.ofByteArrays(iterable); Exception expectedException = null; try { - performRequest(testPub, uri); + performRequest(version, testPub, uri); throw new AssertionError("Expected exception not raised for " + uri); } catch (Exception e) { expectedException = e; @@ -105,20 +122,37 @@ public class ExpectContinueResetTest { throw new AssertionError("Unexpected null cause for " + expectedException, expectedException); } - assertEquals(testThrowable.getClass(), IOException.class, - "Test should have closed with an IOException"); - testThrowable.printStackTrace(); + if (!(testThrowable instanceof IOException)) { + throw new AssertionError( + "Test should have closed with an IOException, got: " + testThrowable, + testThrowable); + } + if (version == HTTP_3) { + if (testThrowable.getMessage().contains(Http3Error.H3_EXCESSIVE_LOAD.name())) { + System.out.println("Got expected message: " + testThrowable.getMessage()); + } else { + throw new AssertionError("Expected " + Http3Error.H3_EXCESSIVE_LOAD.name() + + ", got " + testThrowable, testThrowable); + } + } } static public class EndlessDataChunks implements Iterator { - byte[] data = new byte[16]; + byte[] data = new byte[32]; + boolean hasNext = true; @Override public boolean hasNext() { - return true; + return hasNext; } @Override public byte[] next() { + try { + Thread.sleep(2500); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + hasNext = false; return data; } @Override @@ -129,20 +163,32 @@ public class ExpectContinueResetTest { @BeforeTest public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + http2TestServer = new Http2TestServer(false, 0); + http3TestServer = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext); + http2TestServer.setExchangeSupplier(ExpectContinueResetTestExchangeImpl::new); http2TestServer.addHandler(new GetHandler().toHttp2Handler(), "/warmup"); - http2TestServer.addHandler(new NoEndStreamOnPartialResponse(), "/partialResponse/codeNoError"); - http2TestServer.addHandler(new NoEndStreamOnPartialResponse(), "/partialResponse/codeError"); - http2TestServer.addHandler(new NoEndStreamOnFullResponse(), "/fullResponse/codeNoError"); - http2TestServer.addHandler(new NoEndStreamOnFullResponse(), "/fullResponse/codeError"); + http2TestServer.addHandler(new PartialResponseResetStreamH2(), "/partialResponse/codeNoError"); + http2TestServer.addHandler(new PartialResponseResetStreamH2(), "/partialResponse/codeError"); + http2TestServer.addHandler(new FullResponseResetStreamH2(), "/fullResponse/codeNoError"); + http2TestServer.addHandler(new FullResponseResetStreamH2(), "/fullResponse/codeError"); + http3TestServer.addHandler(new PartialResponseStopSendingH3(), "/partialResponse/codeNoError"); + http3TestServer.addHandler(new FullResponseStopSendingH3(), "/fullResponse/codeNoError"); warmup = URI.create("http://" + http2TestServer.serverAuthority() + "/warmup"); - partialResponseResetNoError = URI.create("http://" + http2TestServer.serverAuthority() + "/partialResponse/codeNoError"); - partialResponseResetError = URI.create("http://" + http2TestServer.serverAuthority() + "/partialResponse/codeError"); - fullResponseResetNoError = URI.create("http://" + http2TestServer.serverAuthority() + "/fullResponse/codeNoError"); - fullResponseResetError = URI.create("http://" + http2TestServer.serverAuthority() + "/fullResponse/codeError"); + h2PartialResponseResetNoError = URI.create("http://" + http2TestServer.serverAuthority() + "/partialResponse/codeNoError"); + h2PartialResponseResetError = URI.create("http://" + http2TestServer.serverAuthority() + "/partialResponse/codeError"); + h2FullResponseResetNoError = URI.create("http://" + http2TestServer.serverAuthority() + "/fullResponse/codeNoError"); + h2FullResponseResetError = URI.create("http://" + http2TestServer.serverAuthority() + "/fullResponse/codeError"); + h3PartialResponseStopSending = URI.create("https://" + http3TestServer.serverAuthority() + "/partialResponse/codeNoError"); + h3FullResponseStopSending = URI.create("https://" + http3TestServer.serverAuthority() + "/fullResponse/codeNoError"); + http2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -150,26 +196,38 @@ public class ExpectContinueResetTest { http2TestServer.stop(); } - private void performRequest(HttpRequest.BodyPublisher bodyPublisher, URI uri) + private void performRequest(Version version, HttpRequest.BodyPublisher bodyPublisher, URI uri) throws IOException, InterruptedException, ExecutionException { - try (HttpClient client = HttpClient.newBuilder().proxy(HttpClient.Builder.NO_PROXY).version(HTTP_2).build()) { + + HttpClient.Builder builder = HttpServerAdapters.createClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(version) + .sslContext(sslContext); + Http3DiscoveryMode requestConfig = null; + if (version == HTTP_3) + requestConfig = Http3DiscoveryMode.HTTP_3_URI_ONLY; + + try (HttpClient client = builder.build()) { err.printf("Performing warmup request to %s", warmup); - client.send(HttpRequest.newBuilder(warmup).GET().version(HTTP_2).build(), HttpResponse.BodyHandlers.discarding()); + if (version == HTTP_2) + client.send(HttpRequest.newBuilder(warmup).GET().version(HTTP_2).build(), + HttpResponse.BodyHandlers.discarding()); + HttpRequest postRequest = HttpRequest.newBuilder(uri) - .version(HTTP_2) + .version(version) .POST(bodyPublisher) + .setOption(H3_DISCOVERY, requestConfig) .expectContinue(true) .build(); - err.printf("Sending request (%s): %s%n", HTTP_2, postRequest); - // TODO: when test is stable and complete, see then if fromSubscriber makes our subscriber non null + err.printf("Sending request (%s): %s%n", version, postRequest); client.sendAsync(postRequest, HttpResponse.BodyHandlers.ofString()).get(); } } - static class GetHandler implements HttpServerAdapters.HttpTestHandler { + static class GetHandler implements HttpTestHandler { @Override - public void handle(HttpServerAdapters.HttpTestExchange exchange) throws IOException { + public void handle(HttpTestExchange exchange) throws IOException { try (OutputStream os = exchange.getResponseBody()) { byte[] bytes = "Response Body".getBytes(UTF_8); err.printf("Server sending 200 (length=%s)", bytes.length); @@ -180,7 +238,7 @@ public class ExpectContinueResetTest { } } - static class NoEndStreamOnPartialResponse implements Http2Handler { + static class PartialResponseResetStreamH2 implements Http2Handler { @Override public void handle(Http2TestExchange exchange) throws IOException { @@ -200,14 +258,14 @@ public class ExpectContinueResetTest { } } - static class NoEndStreamOnFullResponse implements Http2Handler { + static class FullResponseResetStreamH2 implements Http2Handler { @Override public void handle(Http2TestExchange exchange) throws IOException { err.println("Sending 100"); exchange.sendResponseHeaders(100, -1); - err.println("Sending 200"); + err.println("Sending 200"); exchange.sendResponseHeaders(200, 0); if (exchange instanceof ExpectContinueResetTestExchangeImpl testExchange) { err.println("Sending Reset"); @@ -222,6 +280,38 @@ public class ExpectContinueResetTest { } } + static class PartialResponseStopSendingH3 implements HttpTestHandler { + + @Override + public void handle(HttpTestExchange exchange) throws IOException { + err.println("Sending 100"); + exchange.sendResponseHeaders(100, 0); + // sending StopSending(NO_ERROR) before or after sending 100 with no data + // should not fail. + System.err.println("Sending StopSendingFrame"); + exchange.requestStopSending(Http3Error.H3_REQUEST_REJECTED.code()); + // Not resetting the stream would cause the client to wait forever + exchange.resetStream(Http3Error.H3_EXCESSIVE_LOAD.code()); + } + } + + static class FullResponseStopSendingH3 implements HttpTestHandler { + + @Override + public void handle(HttpTestExchange exchange) throws IOException { + err.println("Sending 100"); + exchange.sendResponseHeaders(100, 0); + err.println("Sending 200"); + + // sending StopSending before or after sending 200 with no data + // should not fail. + err.println("Sending StopSendingFrame"); + exchange.requestStopSending(Http3Error.H3_REQUEST_REJECTED.code()); + exchange.sendResponseHeaders(200, 10); + exchange.resetStream(Http3Error.H3_EXCESSIVE_LOAD.code()); + } + } + static class ExpectContinueResetTestExchangeImpl extends Http2TestExchangeImpl { public ExpectContinueResetTestExchangeImpl(int streamid, String method, HttpHeaders reqheaders, HttpHeadersBuilder rspheadersBuilder, URI uri, InputStream is, SSLSession sslSession, BodyOutputStream os, Http2TestServerConnection conn, boolean pushAllowed) { diff --git a/test/jdk/java/net/httpclient/CancelledResponse.java b/test/jdk/java/net/httpclient/CancelledResponse.java index 44be38317dc..1f28ce87221 100644 --- a/test/jdk/java/net/httpclient/CancelledResponse.java +++ b/test/jdk/java/net/httpclient/CancelledResponse.java @@ -23,7 +23,6 @@ import java.net.http.HttpClient; import java.net.http.HttpClient.Version; -import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import jdk.test.lib.net.SimpleSSLContext; @@ -50,7 +49,10 @@ import java.net.http.HttpResponse.BodySubscriber; import static java.lang.String.format; import static java.lang.System.out; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.ISO_8859_1; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; /** * @test @@ -164,7 +166,10 @@ public class CancelledResponse { server.start(); HttpClient client = newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri).version(version).build(); + HttpRequest request = HttpRequest.newBuilder(uri) + .setOption(H3_DISCOVERY, version == HTTP_3 ? ALT_SVC : null) + .version(version) + .build(); try { for (int i = 0; i < responses.length; i++) { HttpResponse r = null; diff --git a/test/jdk/java/net/httpclient/CancelledResponse2.java b/test/jdk/java/net/httpclient/CancelledResponse2.java index 263620acf81..dfcdb03f06e 100644 --- a/test/jdk/java/net/httpclient/CancelledResponse2.java +++ b/test/jdk/java/net/httpclient/CancelledResponse2.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -21,8 +21,8 @@ * questions. */ -import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; -import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.OperationTrackers.Tracker; import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; @@ -38,6 +38,7 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.nio.ByteBuffer; @@ -51,8 +52,11 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static java.lang.System.out; -import static java.net.http.HttpClient.Version.*; -import static jdk.httpclient.test.lib.common.HttpServerAdapters.*; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -60,27 +64,36 @@ import static org.testng.Assert.assertTrue; * @test * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.test.lib.net.SimpleSSLContext + * @compile ReferenceTracker.java * @run testng/othervm -Djdk.internal.httpclient.debug=true CancelledResponse2 */ +// -Djdk.internal.httpclient.debug=true +public class CancelledResponse2 implements HttpServerAdapters { -public class CancelledResponse2 { - + private static final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + private static final Random RANDOM = RandomFactory.getRandom(); + private static final int MAX_CLIENT_DELAY = 160; HttpTestServer h2TestServer; URI h2TestServerURI; - private SSLContext sslContext; - private static final Random random = RandomFactory.getRandom(); - private static final int MAX_CLIENT_DELAY = 160; + URI h2h3TestServerURI; + URI h2h3HeadTestServerURI; + URI h3TestServerURI; + HttpTestServer h2h3TestServer; + HttpTestServer h3TestServer; + SSLContext sslContext; @DataProvider(name = "versions") public Object[][] positive() { return new Object[][]{ - { HTTP_2, h2TestServerURI }, + { HTTP_2, null, h2TestServerURI }, + { HTTP_3, null, h2h3TestServerURI }, + { HTTP_3, HTTP_3_URI_ONLY, h3TestServerURI }, }; } private static void delay() { - int delay = random.nextInt(MAX_CLIENT_DELAY); + int delay = RANDOM.nextInt(MAX_CLIENT_DELAY); try { System.out.println("client delay: " + delay); Thread.sleep(delay); @@ -88,13 +101,25 @@ public class CancelledResponse2 { out.println("Unexpected exception: " + x); } } - @Test(dataProvider = "versions") - public void test(Version version, URI uri) throws Exception { + public void test(Version version, Http3DiscoveryMode config, URI uri) throws Exception { for (int i = 0; i < 5; i++) { - HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).version(version).build(); + HttpClient httpClient = newClientBuilderForH3().sslContext(sslContext).version(version).build(); + Http3DiscoveryMode reqConfig = null; + if (version.equals(HTTP_3)) { + if (config != null) { + reqConfig = (config.equals(HTTP_3_URI_ONLY)) ? HTTP_3_URI_ONLY : ALT_SVC; + } + // if config is null, we are talking to the H2H3 server, which may + // not support direct connection, in which case we should send a headRequest + if ((config == null && !h2h3TestServer.supportsH3DirectConnection()) + || (reqConfig != null && reqConfig.equals(ALT_SVC))) { + headRequest(httpClient); + } + } HttpRequest httpRequest = HttpRequest.newBuilder(uri) .version(version) + .setOption(H3_DISCOVERY, reqConfig) .GET() .build(); AtomicBoolean cancelled = new AtomicBoolean(); @@ -104,26 +129,56 @@ public class CancelledResponse2 { cf.get(); } catch (Exception e) { e.printStackTrace(); - assertTrue(e.getCause() instanceof IOException, "HTTP/2 should cancel with an IOException when the Subscription is cancelled."); + assertTrue(e.getCause() instanceof IOException, "HTTP/2 & HTTP/3 should cancel with an IOException when the Subscription is cancelled."); } assertTrue(cf.isCompletedExceptionally()); assertTrue(cancelled.get()); + + Tracker tracker = TRACKER.getTracker(httpClient); + httpClient = null; + var error = TRACKER.check(tracker, 5000); + if (error != null) throw error; } } + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(h2h3HeadTestServerURI) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, HttpResponse.BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + @BeforeTest public void setup() throws IOException { sslContext = new SimpleSSLContext().get(); + h2TestServer = HttpTestServer.create(HTTP_2, sslContext); h2TestServer.addHandler(new CancelledResponseHandler(), "/h2"); h2TestServerURI = URI.create("https://" + h2TestServer.serverAuthority() + "/h2"); + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new CancelledResponseHandler(), "/h2h3"); + h2h3TestServerURI = URI.create("https://" + h2h3TestServer.serverAuthority() + "/h2h3"); + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head"); + h2h3HeadTestServerURI = URI.create("https://" + h2h3TestServer.serverAuthority() + "/h2h3/head"); + + + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new CancelledResponseHandler(), "/h3"); + h3TestServerURI = URI.create("https://" + h3TestServer.serverAuthority() + "/h3"); + h2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest public void teardown() { h2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } BodyHandler ofString(String expected, AtomicBoolean cancelled) { @@ -234,6 +289,7 @@ public class CancelledResponse2 { } } + @Override public void onError(Throwable throwable) { result.completeExceptionally(throwable); diff --git a/test/jdk/java/net/httpclient/ConcurrentResponses.java b/test/jdk/java/net/httpclient/ConcurrentResponses.java index c68ebd0975c..1cabe461e80 100644 --- a/test/jdk/java/net/httpclient/ConcurrentResponses.java +++ b/test/jdk/java/net/httpclient/ConcurrentResponses.java @@ -30,16 +30,19 @@ * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext * jdk.httpclient.test.lib.common.TestServerConfigurator * @run testng/othervm - * -Djdk.httpclient.HttpClient.log=headers,errors,channel + * -Djdk.internal.httpclient.debug=true * ConcurrentResponses */ +//* -Djdk.internal.httpclient.HttpClient.log=all import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; +import java.net.http.HttpClient.Version; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; @@ -64,6 +67,9 @@ import java.net.http.HttpResponse.BodyHandlers; import java.net.http.HttpResponse.BodySubscriber; import java.net.http.HttpResponse.BodySubscribers; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; import jdk.httpclient.test.lib.common.TestServerConfigurator; import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.httpclient.test.lib.http2.Http2TestExchange; @@ -73,6 +79,8 @@ import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; + +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static java.net.http.HttpResponse.BodyHandlers.discarding; import static org.testng.Assert.assertEquals; @@ -86,8 +94,10 @@ public class ConcurrentResponses { HttpsServer httpsTestServer; // HTTPS/1.1 Http2TestServer http2TestServer; // HTTP/2 ( h2c ) Http2TestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer https3TestServer; String httpFixedURI, httpsFixedURI, httpChunkedURI, httpsChunkedURI; String http2FixedURI, https2FixedURI, http2VariableURI, https2VariableURI; + String https3FixedURI, https3VariableURI; static final int CONCURRENT_REQUESTS = 13; static final AtomicInteger IDS = new AtomicInteger(); @@ -145,7 +155,9 @@ public class ConcurrentResponses { { http2FixedURI }, { https2FixedURI }, { http2VariableURI }, - { https2VariableURI } + { https2VariableURI }, + { https3FixedURI }, + { https3VariableURI } }; } @@ -157,20 +169,25 @@ public class ConcurrentResponses { int id = IDS.getAndIncrement(); ExecutorService virtualExecutor = Executors.newThreadPerTaskExecutor(Thread.ofVirtual() .name("HttpClient-" + id + "-Worker", 0).factory()); - HttpClient client = HttpClient.newBuilder() - .sslContext(sslContext) + var http3 = uri.contains("/https3/"); + Http3DiscoveryMode config = http3 ? Http3DiscoveryMode.HTTP_3_URI_ONLY : null; + var builder = http3 ? HttpServerAdapters.createClientBuilderForH3() : HttpClient.newBuilder(); + if (http3) builder.version(Version.HTTP_3); + HttpClient client = builder .executor(virtualExecutor) - .build(); + .sslContext(sslContext).build(); try { Map requests = new HashMap<>(); for (int i = 0; i < CONCURRENT_REQUESTS; i++) { HttpRequest request = HttpRequest.newBuilder(URI.create(uri + "?" + i)) + .setOption(H3_DISCOVERY, config) .build(); requests.put(request, BODIES[i]); } // initial connection to seed the cache so next parallel connections reuse it - client.sendAsync(HttpRequest.newBuilder(URI.create(uri)).build(), discarding()).join(); + client.sendAsync(HttpRequest.newBuilder(URI.create(uri)) + .setOption(H3_DISCOVERY, config).build(), discarding()).join(); // will reuse connection cached from the previous request ( when HTTP/2 ) CompletableFuture.allOf(requests.keySet().parallelStream() @@ -192,19 +209,25 @@ public class ConcurrentResponses { int id = IDS.getAndIncrement(); ExecutorService virtualExecutor = Executors.newThreadPerTaskExecutor(Thread.ofVirtual() .name("HttpClient-" + id + "-Worker", 0).factory()); - HttpClient client = HttpClient.newBuilder() + var http3 = uri.contains("/https3/"); + Http3DiscoveryMode config = http3 ? Http3DiscoveryMode.HTTP_3_URI_ONLY : null; + var builder = http3 ? HttpServerAdapters.createClientBuilderForH3() : HttpClient.newBuilder(); + if (http3) builder.version(Version.HTTP_3); + HttpClient client = builder .executor(virtualExecutor) .sslContext(sslContext).build(); try { Map requests = new HashMap<>(); for (int i = 0; i < CONCURRENT_REQUESTS; i++) { HttpRequest request = HttpRequest.newBuilder(URI.create(uri + "?" + i)) + .setOption(H3_DISCOVERY, config) .build(); requests.put(request, BODIES[i]); } // initial connection to seed the cache so next parallel connections reuse it - client.sendAsync(HttpRequest.newBuilder(URI.create(uri)).build(), discarding()).join(); + client.sendAsync(HttpRequest.newBuilder(URI.create(uri)) + .setOption(H3_DISCOVERY, config).build(), discarding()).join(); // will reuse connection cached from the previous request ( when HTTP/2 ) CompletableFuture.allOf(requests.keySet().parallelStream() @@ -310,10 +333,17 @@ public class ConcurrentResponses { https2TestServer.addHandler(new Http2VariableHandler(), "/https2/variable"); https2VariableURI = "https://" + https2TestServer.serverAuthority() + "/https2/variable"; + https3TestServer = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext); + https3TestServer.addHandler(new Http3FixedHandler(), "/https3/fixed"); + https3FixedURI = "https://" + https3TestServer.serverAuthority() + "/https3/fixed"; + https3TestServer.addHandler(new Http3VariableHandler(), "/https3/variable"); + https3VariableURI = "https://" + https3TestServer.serverAuthority() + "/https3/variable"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + https3TestServer.start(); } @AfterTest @@ -322,6 +352,7 @@ public class ConcurrentResponses { httpsTestServer.stop(0); http2TestServer.stop(); https2TestServer.stop(); + https3TestServer.stop(); } interface SendResponseHeadersFunction { @@ -407,4 +438,26 @@ public class ConcurrentResponses { (rcode, ignored) -> t.sendResponseHeaders(rcode, 0 /* no Content-Length */)); } } + + static class Http3FixedHandler implements HttpTestHandler { + + @Override + public void handle(HttpServerAdapters.HttpTestExchange t) throws IOException { + serverHandlerImpl(t.getRequestBody(), + t.getResponseBody(), + t.getRequestURI(), + (rcode, length) -> t.sendResponseHeaders(rcode, length)); + } + } + + static class Http3VariableHandler implements HttpTestHandler { + + @Override + public void handle(HttpServerAdapters.HttpTestExchange t) throws IOException { + serverHandlerImpl(t.getRequestBody(), + t.getResponseBody(), + t.getRequestURI(), + (rcode, ignored) -> t.sendResponseHeaders(rcode, -1/* no Content-Length */)); + } + } } diff --git a/test/jdk/java/net/httpclient/ContentLengthHeaderTest.java b/test/jdk/java/net/httpclient/ContentLengthHeaderTest.java index 0ff21c23a8f..f302de4ee48 100644 --- a/test/jdk/java/net/httpclient/ContentLengthHeaderTest.java +++ b/test/jdk/java/net/httpclient/ContentLengthHeaderTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,7 +30,9 @@ * @build jdk.test.lib.net.SimpleSSLContext * jdk.httpclient.test.lib.common.HttpServerAdapters * @bug 8283544 - * @run testng/othervm ContentLengthHeaderTest + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * ContentLengthHeaderTest */ @@ -51,6 +53,7 @@ import java.net.http.HttpClient; import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; import java.util.Optional; import javax.net.ssl.SSLContext; import jdk.test.lib.net.URIBuilder; @@ -58,6 +61,8 @@ import jdk.test.lib.net.URIBuilder; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; @@ -69,24 +74,29 @@ public class ContentLengthHeaderTest implements HttpServerAdapters { static HttpTestServer testContentLengthServerH1; static HttpTestServer testContentLengthServerH2; + static HttpTestServer testContentLengthServerH3; static PrintStream testLog = System.err; static SSLContext sslContext; HttpClient hc; URI testContentLengthURIH1; URI testContentLengthURIH2; + URI testContentLengthURIH3; @BeforeTest - public void setup() throws IOException, URISyntaxException { + public void setup() throws IOException, URISyntaxException, InterruptedException { sslContext = new SimpleSSLContext().get(); testContentLengthServerH1 = HttpTestServer.create(HTTP_1_1); testContentLengthServerH2 = HttpTestServer.create(HTTP_2, sslContext); + testContentLengthServerH3 = HttpTestServer.create(HTTP_3, sslContext); // Create handlers for tests that check for the presence of a Content-length header testContentLengthServerH1.addHandler(new NoContentLengthHandler(), NO_BODY_PATH); testContentLengthServerH2.addHandler(new NoContentLengthHandler(), NO_BODY_PATH); + testContentLengthServerH3.addHandler(new NoContentLengthHandler(), NO_BODY_PATH); testContentLengthServerH1.addHandler(new ContentLengthHandler(), BODY_PATH); testContentLengthServerH2.addHandler(new OptionalContentLengthHandler(), BODY_PATH); + testContentLengthServerH3.addHandler(new OptionalContentLengthHandler(), BODY_PATH); testContentLengthURIH1 = URIBuilder.newBuilder() .scheme("http") .loopback() @@ -97,6 +107,11 @@ public class ContentLengthHeaderTest implements HttpServerAdapters { .loopback() .port(testContentLengthServerH2.getAddress().getPort()) .build(); + testContentLengthURIH3 = URIBuilder.newBuilder() + .scheme("https") + .loopback() + .port(testContentLengthServerH3.getAddress().getPort()) + .build(); testContentLengthServerH1.start(); testLog.println("HTTP/1.1 Server up at address: " + testContentLengthServerH1.getAddress()); @@ -106,25 +121,45 @@ public class ContentLengthHeaderTest implements HttpServerAdapters { testLog.println("HTTP/2 Server up at address: " + testContentLengthServerH2.getAddress()); testLog.println("Request URI for Client: " + testContentLengthURIH2); - hc = HttpClient.newBuilder() + testContentLengthServerH3.start(); + testLog.println("HTTP/3 Server up at address: " + + testContentLengthServerH3.getAddress()); + testLog.println("HTTP/3 Quic Endpoint up at address: " + + testContentLengthServerH3.getH3AltService().get().getAddress()); + testLog.println("Request URI for Client: " + testContentLengthURIH3); + + hc = newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .sslContext(sslContext) .build(); + var firstReq = HttpRequest.newBuilder(URI.create(testContentLengthURIH3 + NO_BODY_PATH)) + .setOption(H3_DISCOVERY, testContentLengthServerH3.h3DiscoveryConfig()) + .HEAD() + .version(HTTP_2) + .build(); + // populate alt-service registry + var resp = hc.send(firstReq, BodyHandlers.ofString()); + assertEquals(resp.statusCode(), 200); + testLog.println("**** setup done ****"); } @AfterTest public void teardown() { + testLog.println("**** tearing down ****"); if (testContentLengthServerH1 != null) testContentLengthServerH1.stop(); if (testContentLengthServerH2 != null) testContentLengthServerH2.stop(); + if (testContentLengthServerH3 != null) + testContentLengthServerH3.stop(); } @DataProvider(name = "bodies") Object[][] bodies() { return new Object[][]{ {HTTP_1_1, URI.create(testContentLengthURIH1 + BODY_PATH)}, - {HTTP_2, URI.create(testContentLengthURIH2 + BODY_PATH)} + {HTTP_2, URI.create(testContentLengthURIH2 + BODY_PATH)}, + {HTTP_3, URI.create(testContentLengthURIH3 + BODY_PATH)} }; } @@ -132,7 +167,8 @@ public class ContentLengthHeaderTest implements HttpServerAdapters { Object[][] nobodies() { return new Object[][]{ {HTTP_1_1, URI.create(testContentLengthURIH1 + NO_BODY_PATH)}, - {HTTP_2, URI.create(testContentLengthURIH2 + NO_BODY_PATH)} + {HTTP_2, URI.create(testContentLengthURIH2 + NO_BODY_PATH)}, + {HTTP_3, URI.create(testContentLengthURIH3 + NO_BODY_PATH)} }; } diff --git a/test/jdk/java/net/httpclient/CookieHeaderTest.java b/test/jdk/java/net/httpclient/CookieHeaderTest.java index b39a23371ab..455b1048b06 100644 --- a/test/jdk/java/net/httpclient/CookieHeaderTest.java +++ b/test/jdk/java/net/httpclient/CookieHeaderTest.java @@ -75,6 +75,8 @@ import jdk.httpclient.test.lib.common.HttpServerAdapters; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -86,12 +88,14 @@ public class CookieHeaderTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) DummyServer httpDummyServer; DummyServer httpsDummyServer; String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; String httpDummy; String httpsDummy; @@ -115,6 +119,7 @@ public class CookieHeaderTest implements HttpServerAdapters { { httpsDummy, HTTP_1_1 }, { httpURI, HttpClient.Version.HTTP_2 }, { httpsURI, HttpClient.Version.HTTP_2 }, + { http3URI, HttpClient.Version.HTTP_3 }, { httpDummy, HttpClient.Version.HTTP_2 }, { httpsDummy, HttpClient.Version.HTTP_2 }, { http2URI, null }, @@ -130,7 +135,7 @@ public class CookieHeaderTest implements HttpServerAdapters { ConcurrentHashMap> cookieHeaders = new ConcurrentHashMap<>(); CookieHandler cookieManager = new TestCookieHandler(cookieHeaders); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .followRedirects(Redirect.ALWAYS) .cookieHandler(cookieManager) .sslContext(sslContext) @@ -150,6 +155,9 @@ public class CookieHeaderTest implements HttpServerAdapters { if (version != null) { requestBuilder.version(version); } + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } HttpRequest request = requestBuilder.build(); out.println("Initial request: " + request.uri()); @@ -157,7 +165,8 @@ public class CookieHeaderTest implements HttpServerAdapters { out.println("iteration: " + i); HttpResponse response = client.send(request, BodyHandlers.ofString()); - out.println(" Got response: " + response); + out.println(" Got response: " + response + ", config=" + request.getOption(H3_DISCOVERY) + + ", version=" + response.version()); out.println(" Got body Path: " + response.body()); assertEquals(response.statusCode(), 200); @@ -166,11 +175,17 @@ public class CookieHeaderTest implements HttpServerAdapters { cookies.stream() .filter(s -> !s.startsWith("LOC")) .collect(Collectors.toList())); + if (version == HTTP_3 && i > 0) { + assertEquals(response.version(), HTTP_3); + } requestBuilder = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()); if (version != null) { requestBuilder.version(version); } + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } request = requestBuilder.build(); } } @@ -196,6 +211,9 @@ public class CookieHeaderTest implements HttpServerAdapters { https2TestServer = HttpTestServer.create(HTTP_2, sslContext); https2TestServer.addHandler(new CookieValidationHandler(), "/https2/cookie/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/cookie/retry"; + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(new CookieValidationHandler(), "/http3/cookie/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/cookie/retry"; // DummyServer @@ -209,6 +227,7 @@ public class CookieHeaderTest implements HttpServerAdapters { httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); httpDummyServer.start(); httpsDummyServer.start(); } diff --git a/test/jdk/java/net/httpclient/CustomRequestPublisher.java b/test/jdk/java/net/httpclient/CustomRequestPublisher.java index 140a3a98ec4..c8d33030a06 100644 --- a/test/jdk/java/net/httpclient/CustomRequestPublisher.java +++ b/test/jdk/java/net/httpclient/CustomRequestPublisher.java @@ -25,13 +25,14 @@ * @test * @summary Checks correct handling of Publishers that call onComplete without demand * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.httpclient.test.lib.common.HttpServerAdapters jdk.test.lib.net.SimpleSSLContext + * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext * @run testng/othervm CustomRequestPublisher */ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; +import java.net.http.HttpClient.Builder; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Optional; @@ -57,9 +58,12 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.net.http.HttpResponse.BodyHandlers.ofString; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -70,10 +74,12 @@ public class CustomRequestPublisher implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; @DataProvider(name = "variants") public Object[][] variants() { @@ -98,19 +104,26 @@ public class CustomRequestPublisher implements HttpServerAdapters { { http2URI, unknownSupplier, true }, { https2URI, fixedSupplier, true,}, { https2URI, unknownSupplier, true }, + { http3URI, fixedSupplier, true,}, // always use same client with h3 + { http3URI, unknownSupplier, true }, // always use same client with h3 }; } static final int ITERATION_COUNT = 10; /** Asserts HTTP Version, and SSLSession presence when applicable. */ - static void assertVersionAndSession(HttpResponse response, String uri) { - if (uri.contains("http2") || uri.contains("https2")) + static void assertVersionAndSession(int step, HttpResponse response, String uri) { + if (uri.contains("http2") || uri.contains("https2")) { assertEquals(response.version(), HTTP_2); - else if (uri.contains("http1") || uri.contains("https1")) + } else if (uri.contains("http1") || uri.contains("https1")) { assertEquals(response.version(), HTTP_1_1); - else + } else if (uri.contains("http3")) { + if (step == 0) assertNotEquals(response.version(), HTTP_1_1); + else assertEquals(response.version(), HTTP_3, + "unexpected response version on step " + step); + } else { fail("Unknown HTTP version in test for: " + uri); + } Optional ssl = response.sslSession(); if (uri.contains("https")) { @@ -125,6 +138,28 @@ public class CustomRequestPublisher implements HttpServerAdapters { } } + HttpClient.Builder newHttpClientBuilder(String uri) { + HttpClient.Builder builder; + if (uri.contains("/http3/")) { + builder = newClientBuilderForH3(); + // ensure that the preferred version for the client + // is HTTP/3 + builder.version(HTTP_3); + } else builder = HttpClient.newBuilder(); + return builder.proxy(Builder.NO_PROXY); + } + + HttpRequest.Builder newHttpRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (uri.contains("/http3/") && !http3TestServer.supportsH3DirectConnection()) { + // Ensure we don't attempt to connect to a + // potentially different server if HTTP/3 endpoint and + // HTTP/2 endpoint are not on the same port + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + @Test(dataProvider = "variants") void test(String uri, Supplier bpSupplier, boolean sameClient) throws Exception @@ -132,10 +167,10 @@ public class CustomRequestPublisher implements HttpServerAdapters { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = HttpClient.newBuilder().sslContext(sslContext).build(); + client = newHttpClientBuilder(uri).sslContext(sslContext).build(); BodyPublisher bodyPublisher = bpSupplier.get(); - HttpRequest request = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest request = newHttpRequestBuilder(uri) .POST(bodyPublisher) .build(); @@ -147,7 +182,7 @@ public class CustomRequestPublisher implements HttpServerAdapters { "Expected 200, got:" + resp.statusCode()); assertEquals(resp.body(), bodyPublisher.bodyAsString()); - assertVersionAndSession(resp, uri); + assertVersionAndSession(i, resp, uri); } } @@ -158,10 +193,10 @@ public class CustomRequestPublisher implements HttpServerAdapters { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = HttpClient.newBuilder().sslContext(sslContext).build(); + client = newHttpClientBuilder(uri).sslContext(sslContext).build(); BodyPublisher bodyPublisher = bpSupplier.get(); - HttpRequest request = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest request = newHttpRequestBuilder(uri) .POST(bodyPublisher) .build(); @@ -174,7 +209,7 @@ public class CustomRequestPublisher implements HttpServerAdapters { "Expected 200, got:" + resp.statusCode()); assertEquals(resp.body(), bodyPublisher.bodyAsString()); - assertVersionAndSession(resp, uri); + assertVersionAndSession(0, resp, uri); } } @@ -325,10 +360,15 @@ public class CustomRequestPublisher implements HttpServerAdapters { https2TestServer.addHandler(new HttpTestEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(new HttpTestEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -337,6 +377,7 @@ public class CustomRequestPublisher implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } } diff --git a/test/jdk/java/net/httpclient/CustomResponseSubscriber.java b/test/jdk/java/net/httpclient/CustomResponseSubscriber.java index be71278a450..e8b5073d9a2 100644 --- a/test/jdk/java/net/httpclient/CustomResponseSubscriber.java +++ b/test/jdk/java/net/httpclient/CustomResponseSubscriber.java @@ -46,7 +46,6 @@ import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsServer; import java.net.http.HttpClient; -import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; @@ -58,6 +57,7 @@ import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.httpclient.test.lib.http2.Http2TestExchange; import jdk.httpclient.test.lib.http2.Http2Handler; import javax.net.ssl.SSLContext; + import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; diff --git a/test/jdk/java/net/httpclient/DependentActionsTest.java b/test/jdk/java/net/httpclient/DependentActionsTest.java index c1e9025fe4b..7d903157de8 100644 --- a/test/jdk/java/net/httpclient/DependentActionsTest.java +++ b/test/jdk/java/net/httpclient/DependentActionsTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -29,16 +29,17 @@ * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.httpclient.test.lib.common.HttpServerAdapters jdk.test.lib.net.SimpleSSLContext * DependentActionsTest - * @run testng/othervm -Djdk.internal.httpclient.debug=true DependentActionsTest - */ + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.quic.maxPtoBackoff=9 + * DependentActionsTest + */ import java.io.BufferedReader; import java.io.InputStreamReader; import java.lang.StackWalker.StackFrame; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; +import org.testng.SkipException; import org.testng.annotations.AfterTest; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeTest; @@ -49,12 +50,12 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; -import java.net.http.HttpHeaders; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; @@ -73,21 +74,20 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.Flow; import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.function.Supplier; -import java.util.stream.Collectors; import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.lang.String.format; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -99,6 +99,7 @@ public class DependentActionsTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -107,6 +108,9 @@ public class DependentActionsTest implements HttpServerAdapters { String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; + String http3URI_head; static final StackWalker WALKER = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); @@ -118,6 +122,7 @@ public class DependentActionsTest implements HttpServerAdapters { static volatile boolean tasksFailed; static final AtomicLong serverCount = new AtomicLong(); static final AtomicLong clientCount = new AtomicLong(); + static final AtomicReference errorRef = new AtomicReference<>(); static final long start = System.nanoTime(); public static String now() { long now = System.nanoTime() - start; @@ -184,6 +189,8 @@ public class DependentActionsTest implements HttpServerAdapters { http2URI_chunk, https2URI_fixed, https2URI_chunk, + http3URI_fixed, + http3URI_chunk }; } @@ -232,7 +239,8 @@ public class DependentActionsTest implements HttpServerAdapters { private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return HttpClient.newBuilder() + return newClientBuilderForH3() + .proxy(Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) .build(); @@ -257,10 +265,14 @@ public class DependentActionsTest implements HttpServerAdapters { HttpClient client = null; out.printf("%ntestNoStalls(%s, %b)%n", uri, sameClient); for (int i=0; i< ITERATION_COUNT; i++) { - if (!sameClient || client == null) + if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } + } - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(uri) .build(); BodyHandler handler = new StallingBodyHandler((w) -> {}, @@ -317,10 +329,16 @@ public class DependentActionsTest implements HttpServerAdapters { Staller staller) throws Exception { + if (errorRef.get() != null) { + SkipException sk = new SkipException("skipping due to previous failure: " + name); + sk.setStackTrace(new StackTraceElement[0]); + throw sk; + } out.printf("%n%s%s%n", now(), name); try { testDependent(uri, sameClient, handlers, finisher, extractor, staller); } catch (Error | Exception x) { + errorRef.compareAndSet(null, x); FAILURES.putIfAbsent(name, x); throw x; } @@ -335,20 +353,36 @@ public class DependentActionsTest implements HttpServerAdapters { { HttpClient client = null; for (Where where : EnumSet.of(Where.BODY_HANDLER)) { - if (!sameClient || client == null) + if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } + } - HttpRequest req = HttpRequest. - newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(uri) .build(); BodyHandler handler = new StallingBodyHandler(where.select(staller), handlers.get()); - System.out.println("try stalling in " + where); - staller.acquire(); - assert staller.willStall(); - CompletableFuture> responseCF = client.sendAsync(req, handler); - assert !responseCF.isDone(); - finisher.finish(where, responseCF, staller, extractor); + for (int i = 0; i < 2; i++) { + System.out.println("try stalling in " + where); + staller.acquire(); + assert staller.willStall(); + CompletableFuture> responseCF = client.sendAsync(req, handler); + assert !responseCF.isDone(); + var resp = finisher.finish(where, responseCF, staller, extractor); + if (version(uri) == HTTP_3 && resp.version() != HTTP_3) { + if (i == 0) continue; + // it's possible that the first request still went through HTTP/2 + // if the config was HTTP3_ANY. Retry it - the next time we should + // have HTTP/3 + assertEquals(resp.version(), HTTP_3, + "expected second request to go through HTTP/3 (serverConfig=" + + http3TestServer.h3DiscoveryConfig() + ")"); + } + break; + } + } } @@ -388,7 +422,7 @@ public class DependentActionsTest implements HttpServerAdapters { } interface Finisher { - public void finish(Where w, + public HttpResponse finish(Where w, CompletableFuture> cf, Staller staller, Extractor extractor); @@ -424,7 +458,7 @@ public class DependentActionsTest implements HttpServerAdapters { } } - void finish(Where w, CompletableFuture> cf, + HttpResponse finish(Where w, CompletableFuture> cf, Staller staller, Extractor extractor) { Thread thread = Thread.currentThread(); @@ -446,6 +480,11 @@ public class DependentActionsTest implements HttpServerAdapters { + w + ": " + response, error); } assertEquals(result, List.of(response.request().uri().getPath())); + var uriStr = response.request().uri().toString(); + if (HTTP_3 != version(uriStr) || http3TestServer.h3DiscoveryConfig() != Http3DiscoveryMode.ANY) { + assertEquals(response.version(), version(uriStr), uriStr); + } + return response; } finally { staller.reset(); } @@ -570,6 +609,37 @@ public class DependentActionsTest implements HttpServerAdapters { } } + static Version version(String uri) { + if (uri.contains("/http1/") || uri.contains("/https1/")) + return HTTP_1_1; + if (uri.contains("/http2/") || uri.contains("/https2/")) + return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; + return null; + } + + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + System.out.println("\n--- HEAD request succeeded ----\n"); + System.err.println("\n--- HEAD request succeeded ----\n"); + return response; + } @BeforeTest public void setup() throws Exception { @@ -608,11 +678,33 @@ public class DependentActionsTest implements HttpServerAdapters { https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed/x"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk/x"; - serverCount.addAndGet(4); + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_ChunkedHandler(); + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed/x"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk/x"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/head/x"; + + serverCount.addAndGet(5); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); + + out.println("HTTP/1.1 server (http) listening at: " + httpTestServer.serverAuthority()); + out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); + out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); + out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); + + headRequest(newHttpClient(true)); } @AfterTest @@ -622,6 +714,7 @@ public class DependentActionsTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } static class HTTP_FixedLengthHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/DependentPromiseActionsTest.java b/test/jdk/java/net/httpclient/DependentPromiseActionsTest.java index e47f4be46f7..12903294315 100644 --- a/test/jdk/java/net/httpclient/DependentPromiseActionsTest.java +++ b/test/jdk/java/net/httpclient/DependentPromiseActionsTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -33,6 +33,7 @@ */ import java.io.BufferedReader; +import java.io.ByteArrayInputStream; import java.io.InputStreamReader; import java.lang.StackWalker.StackFrame; import jdk.test.lib.net.SimpleSSLContext; @@ -49,8 +50,10 @@ import java.io.OutputStream; import java.net.URI; import java.net.URISyntaxException; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpHeaders; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; @@ -81,12 +84,14 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.err; import static java.lang.System.out; import static java.lang.String.format; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -96,10 +101,13 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { SSLContext sslContext; HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String http2URI_fixed; String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; static final StackWalker WALKER = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); @@ -170,6 +178,8 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { private String[] uris() { return new String[] { + http3URI_fixed, + http3URI_chunk, http2URI_fixed, http2URI_chunk, https2URI_fixed, @@ -224,9 +234,11 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return HttpClient.newBuilder() + return newClientBuilderForH3() .executor(executor) .sslContext(sslContext) + .proxy(HttpClient.Builder.NO_PROXY) + .version(HTTP_3) .build(); } @@ -243,47 +255,69 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { } } + Http3DiscoveryMode config(String uri) { + return uri.contains("/http3/") ? HTTP_3_URI_ONLY : null; + } + + Version version(String uri) { + return uri.contains("/http3/") ? HTTP_3 : HTTP_2; + } + + HttpRequest request(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)) + .version(version(uri)); + var config = config(uri); + if (config != null) builder.setOption(H3_DISCOVERY, config); + return builder.build(); + } + @Test(dataProvider = "noStalls") public void testNoStalls(String rootUri, boolean sameClient) throws Exception { if (!FAILURES.isEmpty()) return; HttpClient client = null; out.printf("%ntestNoStalls(%s, %b)%n", rootUri, sameClient); - for (int i=0; i< ITERATION_COUNT; i++) { - if (!sameClient || client == null) - client = newHttpClient(sameClient); + try { + for (int i = 0; i < ITERATION_COUNT; i++) { + if (!sameClient || client == null) + client = newHttpClient(sameClient); - String uri = rootUri + "/" + requestCount.incrementAndGet(); - out.printf("\tsending request %s%n", uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) - .build(); - BodyHandler> handler = - new StallingBodyHandler((w) -> {}, - BodyHandlers.ofLines()); - Map>>> pushPromises = - new ConcurrentHashMap<>(); - PushPromiseHandler> pushHandler = new PushPromiseHandler<>() { - @Override - public void applyPushPromise(HttpRequest initiatingRequest, - HttpRequest pushPromiseRequest, - Function>, - CompletableFuture>>> - acceptor) { - pushPromises.putIfAbsent(pushPromiseRequest, acceptor.apply(handler)); + String uri = rootUri + "/" + requestCount.incrementAndGet(); + out.printf("\tsending request %s%n", uri); + HttpRequest req = request(uri); + + BodyHandler> handler = + new StallingBodyHandler((w) -> {}, + BodyHandlers.ofLines()); + Map>>> pushPromises = + new ConcurrentHashMap<>(); + PushPromiseHandler> pushHandler = new PushPromiseHandler<>() { + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function>, + CompletableFuture>>> + acceptor) { + pushPromises.putIfAbsent(pushPromiseRequest, acceptor.apply(handler)); + } + }; + HttpResponse> response = + client.sendAsync(req, BodyHandlers.ofLines(), pushHandler).get(); + String body = response.body().collect(Collectors.joining("|")); + assertEquals(URI.create(body).getPath(), URI.create(uri).getPath()); + for (HttpRequest promised : pushPromises.keySet()) { + out.printf("%s Received promise: %s%n\tresponse: %s%n", + now(), promised, pushPromises.get(promised).get()); + String promisedBody = pushPromises.get(promised).get().body() + .collect(Collectors.joining("|")); + assertEquals(promisedBody, promised.uri().toASCIIString()); } - }; - HttpResponse> response = - client.sendAsync(req, BodyHandlers.ofLines(), pushHandler).get(); - String body = response.body().collect(Collectors.joining("|")); - assertEquals(URI.create(body).getPath(), URI.create(uri).getPath()); - for (HttpRequest promised : pushPromises.keySet()) { - out.printf("%s Received promise: %s%n\tresponse: %s%n", - now(), promised, pushPromises.get(promised).get()); - String promisedBody = pushPromises.get(promised).get().body() - .collect(Collectors.joining("|")); - assertEquals(promisedBody, promised.uri().toASCIIString()); + assertEquals(3, pushPromises.size()); + } + } finally { + if (!sameClient && client != null) { + client.close(); } - assertEquals(3, pushPromises.size()); } } @@ -357,29 +391,33 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { throws Exception { HttpClient client = null; - for (Where where : EnumSet.of(Where.BODY_HANDLER)) { - if (!sameClient || client == null) - client = newHttpClient(sameClient); + try { + for (Where where : EnumSet.of(Where.BODY_HANDLER)) { + if (!sameClient || client == null) + client = newHttpClient(sameClient); + String uri = rootUri + "/" + requestCount.incrementAndGet(); + out.printf("\tsending request %s%n", uri); + HttpRequest req = request(uri); - String uri = rootUri + "/" + requestCount.incrementAndGet(); - out.printf("\tsending request %s%n", uri); - HttpRequest req = HttpRequest. - newBuilder(URI.create(uri)) - .build(); - StallingPushPromiseHandler promiseHandler = - new StallingPushPromiseHandler<>(where, handlers, stallers); - BodyHandler handler = handlers.get(); - System.out.println("try stalling in " + where); - CompletableFuture> responseCF = - client.sendAsync(req, handler, promiseHandler); - // The body of the main response can be received before the body - // of the push promise handlers are received. - // The body of the main response doesn't stall, so the cf of - // the main response may be done here even for EAGER subscribers. - // We cannot make any assumption on the state of the main response - // cf here, so the only thing we can do is to call the finisher - // which will wait for them all. - finisher.finish(where, responseCF, promiseHandler, extractor); + StallingPushPromiseHandler promiseHandler = + new StallingPushPromiseHandler<>(where, handlers, stallers); + BodyHandler handler = handlers.get(); + System.out.println("try stalling in " + where); + CompletableFuture> responseCF = + client.sendAsync(req, handler, promiseHandler); + // The body of the main response can be received before the body + // of the push promise handlers are received. + // The body of the main response doesn't stall, so the cf of + // the main response may be done here even for EAGER subscribers. + // We cannot make any assumption on the state of the main response + // cf here, so the only thing we can do is to call the finisher + // which will wait for them all. + finisher.finish(where, responseCF, promiseHandler, extractor); + } + } finally { + if (client != null && !sameClient) { + client.close(); + } } } @@ -499,9 +537,7 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { httpStack.forEach(f -> System.out.printf("\t%s%n", f)); failed.set(new RuntimeException("Dependant action has unexpected frame in " + Thread.currentThread() + ": " + httpStack.get(0))); - } - return; } else { List httpStack = WALKER.walk(s -> s.filter(f -> f.getDeclaringClass() .getModule().equals(HttpClient.class.getModule())) @@ -670,31 +706,42 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { throw new AssertionError("Unexpected null sslContext"); // HTTP/2 - HttpTestHandler h2_fixedLengthHandler = new HTTP_FixedLengthHandler(); - HttpTestHandler h2_chunkedHandler = new HTTP_ChunkedHandler(); + HttpTestHandler fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler chunkedHandler = new HTTP_ChunkedHandler(); http2TestServer = HttpTestServer.create(HTTP_2); - http2TestServer.addHandler(h2_fixedLengthHandler, "/http2/fixed"); - http2TestServer.addHandler(h2_chunkedHandler, "/http2/chunk"); + http2TestServer.addHandler(fixedLengthHandler, "/http2/fixed"); + http2TestServer.addHandler(chunkedHandler, "/http2/chunk"); http2URI_fixed = "http://" + http2TestServer.serverAuthority() + "/http2/fixed/y"; http2URI_chunk = "http://" + http2TestServer.serverAuthority() + "/http2/chunk/y"; https2TestServer = HttpTestServer.create(HTTP_2, sslContext); - https2TestServer.addHandler(h2_fixedLengthHandler, "/https2/fixed"); - https2TestServer.addHandler(h2_chunkedHandler, "/https2/chunk"); + https2TestServer.addHandler(fixedLengthHandler, "/https2/fixed"); + https2TestServer.addHandler(chunkedHandler, "/https2/chunk"); https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed/y"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk/y"; - serverCount.addAndGet(4); + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(chunkedHandler, "/http3/chunk"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed/x"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk/x"; + + serverCount.addAndGet(3); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest public void teardown() throws Exception { + if (sharedClient != null) { + sharedClient.close(); + } sharedClient = null; http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } static final BiPredicate ACCEPT_ALL = (x, y) -> true; @@ -712,15 +759,68 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); out.printf("TestServer: %s Pushing promise: %s%n", now(), promise); err.printf("TestServer: %s Pushing promise: %s%n", now(), promise); - HttpHeaders headers; + HttpHeaders reqHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + HttpHeaders rspHeaders; if (fixed) { String length = String.valueOf(promiseBytes.length); - headers = HttpHeaders.of(Map.of("Content-Length", List.of(length)), + rspHeaders = HttpHeaders.of(Map.of("Content-Length", List.of(length)), ACCEPT_ALL); } else { - headers = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + rspHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty } - t.serverPush(promise, headers, promiseBytes); + t.serverPush(promise, reqHeaders, rspHeaders, promiseBytes); + } catch (URISyntaxException x) { + throw new IOException(x.getMessage(), x); + } + } + + private static long sendHttp3PushPromiseFrame(HttpTestExchange t, + URI requestURI, + String pushPath, + boolean fixed) + throws IOException + { + try { + URI promise = new URI(requestURI.getScheme(), + requestURI.getAuthority(), + pushPath, null, null); + byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); + out.printf("TestServer: %s sending PushPromiseFrame: %s%n", now(), promise); + err.printf("TestServer: %s Pushing PushPromiseFrame: %s%n", now(), promise); + HttpHeaders reqHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); + long pushId = t.sendHttp3PushPromiseFrame(-1, promise, reqHeaders); + out.printf("TestServer: %s PushPromiseFrame pushId=%s sent%n", now(), pushId); + err.printf("TestServer: %s PushPromiseFrame pushId=%s sent%n", now(), pushId); + return pushId; + } catch (URISyntaxException x) { + throw new IOException(x.getMessage(), x); + } + } + + private static void sendHttp3PushResponse(HttpTestExchange t, + long pushId, + URI requestURI, + String pushPath, + boolean fixed) + throws IOException + { + try { + URI promise = new URI(requestURI.getScheme(), + requestURI.getAuthority(), + pushPath, null, null); + byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); + out.printf("TestServer: %s sending push response pushId=%s: %s%n", now(), pushId, promise); + err.printf("TestServer: %s Pushing push response pushId=%s: %s%n", now(), pushId, promise); + HttpHeaders reqHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + HttpHeaders rspHeaders; + if (fixed) { + String length = String.valueOf(promiseBytes.length); + rspHeaders = HttpHeaders.of(Map.of("Content-Length", List.of(length)), + ACCEPT_ALL); + } else { + rspHeaders = HttpHeaders.of(Map.of(), ACCEPT_ALL); // empty + } + t.sendHttp3PushResponse(pushId, promise, reqHeaders, rspHeaders, new ByteArrayInputStream(promiseBytes)); } catch (URISyntaxException x) { throw new IOException(x.getMessage(), x); } @@ -739,14 +839,33 @@ public class DependentPromiseActionsTest implements HttpServerAdapters { pushPromiseFor(t, requestURI, path, true); } byte[] resp = t.getRequestURI().toString().getBytes(StandardCharsets.UTF_8); - t.sendResponseHeaders(200, resp.length); //fixed content length + t.sendResponseHeaders(200, resp.length); + + //fixed content length + // With HTTP/3 fixed length we send a single DataFrame, + // therefore we can't interleave a PushPromiseFrame in + // the middle of the DataFrame, so we're going to send + // the PushPromiseFrame before the DataFrame, and then + // fulfill the promise later while sending the response + // body. + long[] pushIds = new long[2]; + if (t.getExchangeVersion() == HTTP_3) { + for (int i = 0; i < 2; i++) { + String path = requestURI.getPath() + "/after/promise-" + (i + 2); + pushIds[i] = sendHttp3PushPromiseFrame(t, requestURI, path, true); + } + } try (OutputStream os = t.getResponseBody()) { int bytes = resp.length/3; for (int i = 0; i<2; i++) { String path = requestURI.getPath() + "/after/promise-" + (i + 2); os.write(resp, i * bytes, bytes); os.flush(); - pushPromiseFor(t, requestURI, path, true); + if (t.getExchangeVersion() == HTTP_2) { + pushPromiseFor(t, requestURI, path, true); + } else if (t.getExchangeVersion() == HTTP_3) { + sendHttp3PushResponse(t, pushIds[i], requestURI, path, true); + } } os.write(resp, 2*bytes, resp.length - 2*bytes); } diff --git a/test/jdk/java/net/httpclient/DigestEchoClient.java b/test/jdk/java/net/httpclient/DigestEchoClient.java index 0aa70a26586..10add19cb94 100644 --- a/test/jdk/java/net/httpclient/DigestEchoClient.java +++ b/test/jdk/java/net/httpclient/DigestEchoClient.java @@ -33,6 +33,7 @@ import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublisher; import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.charset.StandardCharsets; @@ -44,7 +45,6 @@ import java.util.List; import java.util.Optional; import java.util.Random; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; @@ -52,7 +52,9 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.net.ssl.SSLContext; + import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.Utils; import jdk.test.lib.net.SimpleSSLContext; import sun.net.NetProperties; import sun.net.www.HeaderParser; @@ -60,8 +62,9 @@ import sun.net.www.HeaderParser; import static java.lang.System.out; import static java.lang.System.err; import static java.lang.String.format; +import static java.net.http.HttpOption.H3_DISCOVERY; -/** +/* * @test * @summary this test verifies that a client may provides authorization * headers directly when connecting with a server. @@ -189,7 +192,11 @@ public class DigestEchoClient { static final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; public HttpClient newHttpClient(DigestEchoServer server) { clientCount.incrementAndGet(); - HttpClient.Builder builder = HttpClient.newBuilder(); + HttpClient.Builder builder = switch (server.version()) { + case HTTP_3 -> HttpServerAdapters.createClientBuilderForH3() + .version(Version.HTTP_3); + default -> HttpClient.newBuilder(); + }; builder = builder.proxy(ProxySelector.of(null)); if (useSSL) { builder.sslContext(context); @@ -223,11 +230,11 @@ public class DigestEchoClient { } public static List serverVersions(Version clientVersion) { - if (clientVersion == Version.HTTP_1_1) { - return List.of(clientVersion); - } else { - return List.of(Version.values()); - } + return switch (clientVersion) { + case HTTP_1_1 -> List.of(clientVersion); + case HTTP_2 -> List.of(Version.HTTP_1_1, Version.HTTP_2); + case HTTP_3 -> List.of(Version.HTTP_2, Version.HTTP_3); + }; } public static List clientVersions() { @@ -273,6 +280,13 @@ public class DigestEchoClient { authType); for (Version clientVersion : clientVersions()) { for (Version serverVersion : serverVersions(clientVersion)) { + if (serverVersion == Version.HTTP_3) { + if (!useSSL) continue; + switch (authType) { + case PROXY, PROXY305: continue; + default: break; + } + } for (boolean expectContinue : expectContinue(serverVersion)) { for (boolean async : BOOLEANS) { for (boolean preemptive : BOOLEANS) { @@ -293,6 +307,13 @@ public class DigestEchoClient { authType); for (Version clientVersion : clientVersions()) { for (Version serverVersion : serverVersions(clientVersion)) { + if (serverVersion == Version.HTTP_3) { + if (!useSSL) continue; + switch (authType) { + case PROXY, PROXY305: continue; + default: break; + } + } for (boolean expectContinue : expectContinue(serverVersion)) { for (boolean async : BOOLEANS) { dec.testDigest(clientVersion, serverVersion, @@ -314,7 +335,7 @@ public class DigestEchoClient { throw t; } finally { Thread.sleep(100); - AssertionError trackFailed = TRACKER.check(500); + AssertionError trackFailed = TRACKER.check(Utils.adjustTimeout(1000)); EchoServers.stop(); System.out.println(" ---------------------------------------------------------- "); System.out.println(String.format("DigestEchoClient %s %s", useSSL ? "SSL" : "CLEAR", types)); @@ -370,6 +391,14 @@ public class DigestEchoClient { .isPresent(); } + static Http3DiscoveryMode serverConfig(int step, DigestEchoServer server) { + var config = server.serverConfig(); + return switch (config) { + case HTTP_3_URI_ONLY -> config; + default -> Http3DiscoveryMode.ALT_SVC; + }; + } + final static AtomicLong basics = new AtomicLong(); final static AtomicLong basicCount = new AtomicLong(); // @Test @@ -413,6 +442,7 @@ public class DigestEchoClient { + ",expectContinue=" + expectContinue + ",version=" + clientVersion); reqURI = URI.create(baseReq + ",basicCount=" + basicCount.get()); HttpRequest.Builder builder = HttpRequest.newBuilder(reqURI).version(clientVersion) + .setOption(H3_DISCOVERY, serverConfig(i, server)) .POST(reqBody).expectContinue(expectContinue); boolean isTunnel = isProxy(authType) && useSSL; if (addHeaders) { @@ -450,6 +480,16 @@ public class DigestEchoClient { out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } + if (serverVersion == Version.HTTP_3 && clientVersion == Version.HTTP_3) { + out.println("Response version [" + i + "]: " + resp.version()); + int required = isRedirecting(authType) ? 1 : 0; + if (i > required) { + if (resp.version() != serverVersion) { + throw new AssertionError("Expected HTTP/3, but got: " + + resp.version()); + } + } + } } catch (Throwable t) { long stop = System.nanoTime(); synchronized (basicCount) { @@ -471,6 +511,7 @@ public class DigestEchoClient { reqURI = URI.create(baseReq + ",withAuthorization=" + authType + ",basicCount=" + basicCount.get()); request = HttpRequest.newBuilder(reqURI).version(clientVersion) + .setOption(H3_DISCOVERY, server.serverConfig()) .POST(reqBody).header(authorizationKey(authType), auth).build(); if (async) { out.printf("%s client.sendAsync(%s)%n", DigestEchoServer.now(), request); @@ -479,6 +520,16 @@ public class DigestEchoClient { out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } + if (serverVersion == Version.HTTP_3 && clientVersion == Version.HTTP_3) { + out.println("Response version [" + i + "]: " + resp.version()); + int required = isRedirecting(authType) ? 1 : 0; + if (i > required) { + if (resp.version() != serverVersion) { + throw new AssertionError("Expected HTTP/3, but got: " + + resp.version()); + } + } + } } final List respLines; try { @@ -520,16 +571,19 @@ public class DigestEchoClient { failed = decorated; throw decorated; } finally { - client = null; - System.gc(); - while (!ref.refersTo(null)) { + if (client != null) { + var tracker = TRACKER.getTracker(client); + client = null; System.gc(); - if (queue.remove(100) == ref) break; - } - var error = TRACKER.checkShutdown(900); - if (error != null) { - if (failed != null) error.addSuppressed(failed); - throw error; + while (!ref.refersTo(null)) { + System.gc(); + if (queue.remove(100) == ref) break; + } + var error = TRACKER.checkShutdown(tracker, Utils.adjustTimeout(900), false); + if (error != null) { + if (failed != null) error.addSuppressed(failed); + throw error; + } } } System.out.println("OK"); @@ -563,13 +617,13 @@ public class DigestEchoClient { server.getServerAddress(), "/foo/"); HttpClient client = newHttpClient(server); + ReferenceQueue queue = new ReferenceQueue<>(); + WeakReference ref = new WeakReference<>(client, queue); HttpResponse r; CompletableFuture> cf1; byte[] cnonce = new byte[16]; String cnonceStr = null; DigestEchoServer.DigestResponse challenge = null; - ReferenceQueue queue = new ReferenceQueue<>(); - WeakReference ref = new WeakReference<>(client, queue); URI reqURI = null; Throwable failed = null; try { @@ -584,6 +638,7 @@ public class DigestEchoClient { reqURI = URI.create(baseReq + ",digestCount=" + digestCount.get()); HttpRequest.Builder reqBuilder = HttpRequest .newBuilder(reqURI).version(clientVersion).POST(reqBody) + .setOption(H3_DISCOVERY, serverConfig(i, server)) .expectContinue(expectContinue); boolean isTunnel = isProxy(authType) && useSSL; @@ -612,6 +667,16 @@ public class DigestEchoClient { out.printf("%s client.send(%s)%n", DigestEchoServer.now(), request); resp = client.send(request, BodyHandlers.ofLines()); } + if (serverVersion == Version.HTTP_3 && clientVersion == Version.HTTP_3) { + out.println("Response version [" + i + "]: " + resp.version()); + int required = isRedirecting(authType) ? 1 : 0; + if (i > required) { + if (resp.version() != serverVersion) { + throw new AssertionError("Expected HTTP/3, but got: " + + resp.version()); + } + } + } System.out.println(resp); assert challenge != null || resp.statusCode() == 401 || resp.statusCode() == 407 : "challenge=" + challenge + ", resp=" + resp + ", test=[" + test + "]"; @@ -642,6 +707,7 @@ public class DigestEchoClient { reqURI = URI.create(baseReq + ",withAuth=" + authType + ",digestCount=" + digestCount.get()); try { request = HttpRequest.newBuilder(reqURI).version(clientVersion) + .setOption(H3_DISCOVERY, serverConfig(i, server)) .POST(reqBody).header(authorizationKey(authType), auth).build(); } catch (IllegalArgumentException x) { throw x; @@ -654,6 +720,16 @@ public class DigestEchoClient { resp = client.send(request, BodyHandlers.ofLines()); } System.out.println(resp); + if (serverVersion == Version.HTTP_3 && clientVersion == Version.HTTP_3) { + out.println("Response version [" + i + "]: " + resp.version()); + int required = isRedirecting(authType) ? 1 : 0; + if (i > required) { + if (resp.version() != serverVersion) { + throw new AssertionError("Expected HTTP/3, but got: " + + resp.version()); + } + } + } } final List respLines; try { @@ -691,24 +767,27 @@ public class DigestEchoClient { failed = decorated; throw decorated; } finally { - client = null; - System.gc(); - while (!ref.refersTo(null)) { + if (client != null) { + var tracker = TRACKER.getTracker(client); + client = null; System.gc(); - if (queue.remove(100) == ref) break; - } - var error = TRACKER.checkShutdown(900); - if (error != null) { - if (failed != null) { - error.addSuppressed(failed); + while (!ref.refersTo(null)) { + System.gc(); + if (queue.remove(100) == ref) break; + } + var error = TRACKER.checkShutdown(tracker, Utils.adjustTimeout(900), false); + if (error != null) { + if (failed != null) { + error.addSuppressed(failed); + } + throw error; } - throw error; } } System.out.println("OK"); } - // WARNING: This is not a full fledged implementation of DIGEST. + // WARNING: This is not a full-fledged implementation of DIGEST. // It does contain bugs and inaccuracy. static String digestResponse(URI uri, String method, DigestEchoServer.DigestResponse challenge, String cnonce) throws NoSuchAlgorithmException { @@ -729,32 +808,34 @@ public class DigestEchoClient { } static String authenticateKey(DigestEchoServer.HttpAuthType authType) { - switch (authType) { - case SERVER: return "www-authenticate"; - case SERVER307: return "www-authenticate"; - case PROXY: return "proxy-authenticate"; - case PROXY305: return "proxy-authenticate"; - default: throw new InternalError("authType: " + authType); - } + return switch (authType) { + case SERVER -> "www-authenticate"; + case SERVER307 -> "www-authenticate"; + case PROXY -> "proxy-authenticate"; + case PROXY305 -> "proxy-authenticate"; + }; } static String authorizationKey(DigestEchoServer.HttpAuthType authType) { - switch (authType) { - case SERVER: return "authorization"; - case SERVER307: return "Authorization"; - case PROXY: return "Proxy-Authorization"; - case PROXY305: return "proxy-Authorization"; - default: throw new InternalError("authType: " + authType); - } + return switch (authType) { + case SERVER -> "authorization"; + case SERVER307 -> "Authorization"; + case PROXY -> "Proxy-Authorization"; + case PROXY305 -> "proxy-Authorization"; + }; } static boolean isProxy(DigestEchoServer.HttpAuthType authType) { - switch (authType) { - case SERVER: return false; - case SERVER307: return false; - case PROXY: return true; - case PROXY305: return true; - default: throw new InternalError("authType: " + authType); - } + return switch (authType) { + case SERVER, SERVER307 -> false; + case PROXY, PROXY305 -> true; + }; + } + + static boolean isRedirecting(DigestEchoServer.HttpAuthType authType) { + return switch (authType) { + case SERVER307, PROXY305 -> true; + case SERVER, PROXY -> false; + }; } } diff --git a/test/jdk/java/net/httpclient/DigestEchoClientSSL.java b/test/jdk/java/net/httpclient/DigestEchoClientSSL.java index 9a7f5acb88a..ecf043f58c8 100644 --- a/test/jdk/java/net/httpclient/DigestEchoClientSSL.java +++ b/test/jdk/java/net/httpclient/DigestEchoClientSSL.java @@ -31,10 +31,18 @@ * DigestEchoClient ReferenceTracker DigestEchoClientSSL * jdk.httpclient.test.lib.common.HttpServerAdapters * @run main/othervm/timeout=300 - * DigestEchoClientSSL SSL + * -Djdk.internal.httpclient.debug=err + * -Djdk.httpclient.HttpClient.log=headers + * DigestEchoClientSSL SSL SERVER307 + * @run main/othervm/timeout=300 + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=100 + * -Djdk.httpclient.HttpClient.log=headers + * DigestEchoClientSSL SSL SERVER PROXY * @run main/othervm/timeout=300 * -Djdk.http.auth.proxying.disabledSchemes= * -Djdk.http.auth.tunneling.disabledSchemes= + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=100 + * -Djdk.httpclient.HttpClient.log=headers * DigestEchoClientSSL SSL PROXY * */ diff --git a/test/jdk/java/net/httpclient/DigestEchoServer.java b/test/jdk/java/net/httpclient/DigestEchoServer.java index 187951405c2..f9a2ec1e017 100644 --- a/test/jdk/java/net/httpclient/DigestEchoServer.java +++ b/test/jdk/java/net/httpclient/DigestEchoServer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -47,6 +47,8 @@ import java.net.StandardSocketOptions; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.UnsupportedProtocolVersionException; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -67,8 +69,10 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import jdk.httpclient.test.lib.common.HttpServerAdapters.AbstractHttpAuthFilter.HttpAuthMode; +import jdk.httpclient.test.lib.common.TestServerConfigurator; import sun.net.www.HeaderParser; import java.net.http.HttpClient.Version; import jdk.httpclient.test.lib.common.HttpServerAdapters; @@ -159,15 +163,21 @@ public abstract class DigestEchoServer implements HttpServerAdapters { final String key; DigestEchoServer(String key, - HttpTestServer server, - DigestEchoServer target, - HttpTestHandler delegate) { + HttpTestServer server, + DigestEchoServer target, + HttpTestHandler delegate) { this.key = key; this.serverImpl = server; this.redirect = target; this.delegate = delegate; } + public Version version() { + if (serverImpl.canHandle(Version.HTTP_3)) + return Version.HTTP_3; + return serverImpl.getVersion(); + } + public static void main(String[] args) throws IOException { @@ -186,6 +196,17 @@ public abstract class DigestEchoServer implements HttpServerAdapters { } } + public Http3DiscoveryMode serverConfig() { + // If the client request is HTTP_3, but the server + // doesn't support HTTP/3, we don't want the client + // to attempt a direct HTTP/3 connection - so use + // ALT_SVC to prevent that + var config = serverImpl.h3DiscoveryConfig(); + return config == null + ? Http3DiscoveryMode.ALT_SVC + : config; + } + private static String toString(HttpTestRequestHeaders headers) { return headers.entrySet().stream() .map((e) -> e.getKey() + ": " + e.getValue()) @@ -249,6 +270,13 @@ public abstract class DigestEchoServer implements HttpServerAdapters { } } + private static String authority(InetSocketAddress address) { + String addressStr = address.getAddress().getHostAddress(); + if (addressStr.indexOf(':') >= 0) { + addressStr = "[" + addressStr + "]"; + } + return addressStr + ":" + address.getPort(); + } /** * The SocketBindableFactory ensures that the local port used by an HttpServer @@ -267,7 +295,8 @@ public abstract class DigestEchoServer implements HttpServerAdapters { for (int i = 1; i <= max; i++) { B bindable = createBindable(); InetSocketAddress address = getAddress(bindable); - String key = "localhost:" + address.getPort(); + + String key = authority(address); if (addresses.addIfAbsent(key)) { System.out.println("Socket bound to: " + key + " after " + i + " attempt(s)"); @@ -461,6 +490,15 @@ public abstract class DigestEchoServer implements HttpServerAdapters { return HttpTestServer.of(createHttp1Server(protocol)); case HTTP_2: return HttpTestServer.of(createHttp2Server(protocol)); + case HTTP_3: + try { + if (!"https".equalsIgnoreCase(protocol)) { + throw new UnsupportedProtocolVersionException("HTTP/3 requires https"); + } + return HttpTestServer.create(Version.HTTP_3, SSLContext.getDefault()); + } catch (NoSuchAlgorithmException e) { + throw new IOException(e); + } default: throw new InternalError("Unexpected version: " + version); } @@ -481,7 +519,7 @@ public abstract class DigestEchoServer implements HttpServerAdapters { static HttpsServer configure(HttpsServer server) throws IOException { try { SSLContext ctx = SSLContext.getDefault(); - server.setHttpsConfigurator(new Configurator(ctx)); + server.setHttpsConfigurator(new Configurator(server.getAddress().getAddress(), ctx)); } catch (NoSuchAlgorithmException ex) { throw new IOException(ex); } @@ -516,6 +554,16 @@ public abstract class DigestEchoServer implements HttpServerAdapters { Objects.requireNonNull(authType); Objects.requireNonNull(auth); + if (version == Version.HTTP_3) { + if (!"https".equalsIgnoreCase(protocol)) { + throw new UnsupportedProtocolVersionException("HTTP/3 requires TLS"); + } + switch (authType) { + case PROXY, PROXY305 -> + throw new UnsupportedProtocolVersionException("proxying not supported for HTTP/3"); + case SERVER, SERVER307 -> {} + } + } HttpTestServer impl = createHttpServer(version, protocol); String key = String.format("DigestEchoServer[PID=%s,PORT=%s]:%s:%s:%s:%s", ProcessHandle.current().pid(), @@ -545,6 +593,10 @@ public abstract class DigestEchoServer implements HttpServerAdapters { System.out.println("WARNING: can't use HTTP/1.1 proxy with unsecure HTTP/2 server"); version = Version.HTTP_1_1; } + if (version == Version.HTTP_3) { + System.out.println("WARNING: can't use HTTP/1.1 proxy with HTTP/3 server"); + version = Version.HTTP_2; + } HttpTestServer impl = createHttpServer(version, protocol); String key = String.format("DigestEchoServer[PID=%s,PORT=%s]:%s:%s:%s:%s", ProcessHandle.current().pid(), @@ -932,15 +984,11 @@ public abstract class DigestEchoServer implements HttpServerAdapters { @Override protected void requestAuthentication(HttpTestExchange he) throws IOException { - String separator; Version v = he.getExchangeVersion(); - if (v == Version.HTTP_1_1) { - separator = "\r\n "; - } else if (v == Version.HTTP_2) { - separator = " "; - } else { - throw new InternalError(String.valueOf(v)); - } + String separator = switch (v) { + case HTTP_1_1 -> "\r\n "; + case HTTP_2, HTTP_3 -> " "; + }; String headerName = getAuthenticate(); String headerValue = "Digest realm=\"" + auth.getRealm() + "\"," + separator + "qop=\"auth\"," @@ -1136,13 +1184,18 @@ public abstract class DigestEchoServer implements HttpServerAdapters { } static class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) { + private final InetAddress serverAddr; + + public Configurator(InetAddress serverAddr, SSLContext ctx) { super(ctx); + this.serverAddr = serverAddr; } @Override public void configure (HttpsParameters params) { - params.setSSLParameters (getSSLContext().getSupportedSSLParameters()); + final SSLParameters parameters = getSSLContext().getSupportedSSLParameters(); + TestServerConfigurator.addSNIMatcher(this.serverAddr, parameters); + params.setSSLParameters(parameters); } } @@ -1553,8 +1606,8 @@ public abstract class DigestEchoServer implements HttpServerAdapters { port = port == -1 ? 443 : port; targetAddress = new InetSocketAddress(uri.getHost(), port); if (serverImpl != null) { - assert targetAddress.getHostString() - .equalsIgnoreCase(serverImpl.getAddress().getHostString()); + assert targetAddress.getAddress().getHostAddress() + .equalsIgnoreCase(serverImpl.getAddress().getAddress().getHostAddress()); assert targetAddress.getPort() == serverImpl.getAddress().getPort(); } } catch (Throwable x) { @@ -1716,15 +1769,16 @@ public abstract class DigestEchoServer implements HttpServerAdapters { public static URL url(String protocol, InetSocketAddress address, String path) throws MalformedURLException { - return new URL(protocol(protocol), - address.getHostString(), - address.getPort(), path); + try { + return uri(protocol, address, path).toURL(); + } catch (URISyntaxException e) { + throw new MalformedURLException(e.getMessage()); + } } public static URI uri(String protocol, InetSocketAddress address, String path) throws URISyntaxException { return new URI(protocol(protocol) + "://" + - address.getHostString() + ":" + - address.getPort() + path); + authority(address) + path); } } diff --git a/test/jdk/java/net/httpclient/EmptyAuthenticate.java b/test/jdk/java/net/httpclient/EmptyAuthenticate.java index ca5c41594e8..e445a974ad0 100644 --- a/test/jdk/java/net/httpclient/EmptyAuthenticate.java +++ b/test/jdk/java/net/httpclient/EmptyAuthenticate.java @@ -89,11 +89,13 @@ class EmptyAuthenticate { } static Stream args() { - return Stream - .of(Version.HTTP_1_1, Version.HTTP_2) - .flatMap(version -> Stream - .of(true, false) - .map(secure -> Arguments.of(version, secure))); + return Stream.concat( + Stream + .of(Version.HTTP_1_1, Version.HTTP_2) + .flatMap(version -> Stream + .of(true, false) + .map(secure -> Arguments.of(version, secure))), + Stream.of(Arguments.of(Version.HTTP_3, true))); } private static HttpTestServer createServer(Version version, boolean secure, String uriPath) @@ -128,7 +130,7 @@ class EmptyAuthenticate { } private static HttpClient createClient(Version version, boolean secure) { - HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(version).proxy(NO_PROXY); + HttpClient.Builder clientBuilder = HttpServerAdapters.createClientBuilderFor(version).proxy(NO_PROXY); if (secure) { clientBuilder.sslContext(SSL_CONTEXT); } diff --git a/test/jdk/java/net/httpclient/EncodedCharsInURI.java b/test/jdk/java/net/httpclient/EncodedCharsInURI.java index bcda1f32539..91bace42699 100644 --- a/test/jdk/java/net/httpclient/EncodedCharsInURI.java +++ b/test/jdk/java/net/httpclient/EncodedCharsInURI.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -36,6 +36,7 @@ */ //* -Djdk.internal.httpclient.debug=true +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterClass; import org.testng.annotations.AfterTest; @@ -78,6 +79,8 @@ import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static java.net.http.HttpClient.Builder.NO_PROXY; import static org.testng.Assert.assertEquals; @@ -92,6 +95,7 @@ public class EncodedCharsInURI implements HttpServerAdapters { HttpTestServer https2TestServer; // HTTP/2 ( h2 ) DummyServer httpDummyServer; // HTTP/1.1 [ 2 servers ] DummyServer httpsDummyServer; // HTTPS/1.1 + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -100,6 +104,9 @@ public class EncodedCharsInURI implements HttpServerAdapters { String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; + String http3URI_head; String httpDummy; String httpsDummy; @@ -169,6 +176,8 @@ public class EncodedCharsInURI implements HttpServerAdapters { return new String[] { httpDummy, httpsDummy, + http3URI_fixed, + http3URI_chunk, httpURI_fixed, httpURI_chunk, httpsURI_fixed, @@ -202,12 +211,39 @@ public class EncodedCharsInURI implements HttpServerAdapters { return HTTP_1_1; if (uri.contains("/http2/") || uri.contains("/https2/")) return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; return null; } + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + out.println("\n" + now() + "--- Sending HEAD request ----\n"); + err.println("\n" + now() + "--- Sending HEAD request ----\n"); + + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + out.println("\n" + now() + "--- HEAD request succeeded ----\n"); + err.println("\n" + now() + "--- HEAD request succeeded ----\n"); + return response; + } + private HttpClient makeNewClient() { clientCount.incrementAndGet(); - return HttpClient.newBuilder() + return newClientBuilderForH3() .executor(executor) .proxy(NO_PROXY) .sslContext(sslContext) @@ -246,11 +282,14 @@ public class EncodedCharsInURI implements HttpServerAdapters { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { BodyPublisher bodyPublisher = BodyPublishers.ofString(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(uri) .POST(bodyPublisher) .build(); BodyHandler handler = BodyHandlers.ofString(); @@ -314,15 +353,27 @@ public class EncodedCharsInURI implements HttpServerAdapters { httpDummy = "http://" + httpDummyServer.serverAuthority() + "/http1/dummy/x"; httpsDummy = "https://" + httpsDummyServer.serverAuthority() + "/https1/dummy/x"; + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_ChunkedHandler(); + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed/x"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk/x"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/head/x"; + err.println(now() + "Starting servers"); - serverCount.addAndGet(6); + serverCount.addAndGet(7); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); httpDummyServer.start(); httpsDummyServer.start(); + http3TestServer.start(); out.println("HTTP/1.1 dummy server (http) listening at: " + httpDummyServer.serverAuthority()); out.println("HTTP/1.1 dummy server (TLS) listening at: " + httpsDummyServer.serverAuthority()); @@ -330,6 +381,11 @@ public class EncodedCharsInURI implements HttpServerAdapters { out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); + + headRequest(newHttpClient(true)); out.println(now() + "setup done"); err.println(now() + "setup done"); @@ -342,6 +398,7 @@ public class EncodedCharsInURI implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); httpDummyServer.stopServer(); httpsDummyServer.stopServer(); } diff --git a/test/jdk/java/net/httpclient/EscapedOctetsInURI.java b/test/jdk/java/net/httpclient/EscapedOctetsInURI.java index 8a17cea78c4..d9d8ba1ddd7 100644 --- a/test/jdk/java/net/httpclient/EscapedOctetsInURI.java +++ b/test/jdk/java/net/httpclient/EscapedOctetsInURI.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -50,6 +50,7 @@ import java.util.Arrays; import java.util.List; import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -60,6 +61,8 @@ import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.net.http.HttpClient.Builder.NO_PROXY; import static org.testng.Assert.assertEquals; @@ -67,14 +70,17 @@ import static org.testng.Assert.assertEquals; public class EscapedOctetsInURI implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; + String http3URI_head; private volatile HttpClient sharedClient; @@ -106,6 +112,9 @@ public class EscapedOctetsInURI implements HttpServerAdapters { Arrays.asList(pathsAndQueryStrings).stream() .map(e -> new Object[] {https2URI + e[0] + e[1], sameClient}) .forEach(list::add); + Arrays.asList(pathsAndQueryStrings).stream() + .map(e -> new Object[] {http3URI + e[0] + e[1], sameClient}) + .forEach(list::add); } return list.stream().toArray(Object[][]::new); } @@ -126,12 +135,38 @@ public class EscapedOctetsInURI implements HttpServerAdapters { return HTTP_1_1; if (uri.contains("/http2/") || uri.contains("/https2/")) return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; return null; } + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + out.println("\n" + now() + "--- Sending HEAD request ----\n"); + err.println("\n" + now() + "--- Sending HEAD request ----\n"); + + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + out.println("\n" + now() + "--- HEAD request succeeded ----\n"); + err.println("\n" + now() + "--- HEAD request succeeded ----\n"); + return response; + } private HttpClient makeNewClient() { - return HttpClient.newBuilder() + return newClientBuilderForH3() .proxy(NO_PROXY) .sslContext(sslContext) .build(); @@ -171,10 +206,13 @@ public class EscapedOctetsInURI implements HttpServerAdapters { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uriString) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uriString).build(); HttpResponse resp = client.send(request, BodyHandlers.ofString()); out.println("Got response: " + resp); @@ -200,10 +238,13 @@ public class EscapedOctetsInURI implements HttpServerAdapters { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uriString) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uriString).build(); client.sendAsync(request, BodyHandlers.ofString()) .thenApply(response -> { out.println("Got response: " + response); @@ -247,16 +288,28 @@ public class EscapedOctetsInURI implements HttpServerAdapters { https2TestServer.addHandler(new HttpASCIIUriStringHandler(), "/https2/get"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/get"; + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(new HttpASCIIUriStringHandler(), "/http3/get"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/get"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/head/x"; + err.println(now() + "Starting servers"); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); out.println("HTTP/1.1 server (http) listening at: " + httpTestServer.serverAuthority()); out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); + + headRequest(newHttpClient(true)); out.println(now() + "setup done"); err.println(now() + "setup done"); @@ -269,6 +322,7 @@ public class EscapedOctetsInURI implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } /** A handler that returns as its body the exact escaped request URI. */ diff --git a/test/jdk/java/net/httpclient/ExecutorShutdown.java b/test/jdk/java/net/httpclient/ExecutorShutdown.java index e829daa556b..7d4cbec6d92 100644 --- a/test/jdk/java/net/httpclient/ExecutorShutdown.java +++ b/test/jdk/java/net/httpclient/ExecutorShutdown.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -39,12 +39,12 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.channels.ClosedChannelException; @@ -61,13 +61,8 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLHandshakeException; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; @@ -79,6 +74,9 @@ import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; @@ -91,14 +89,19 @@ public class ExecutorShutdown implements HttpServerAdapters { static final Random RANDOM = RandomFactory.getRandom(); SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 6 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 - HttpTestServer http2TestServer; // HTTP/2 ( h2c ) - HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/2 ( h2+h3 ) + HttpTestServer h3TestServer; // HTTP/2 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h3URI; + String h2h3Head; static final String MESSAGE = "ExecutorShutdown message body"; static final int ITERATIONS = 3; @@ -106,10 +109,12 @@ public class ExecutorShutdown implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, }, - { httpsURI, }, - { http2URI, }, - { https2URI, }, + { h2h3URI, HTTP_3, h2h3TestServer.h3DiscoveryConfig() }, + { h3URI, HTTP_3, h3TestServer.h3DiscoveryConfig() }, + { httpURI, HTTP_1_1, null }, + { httpsURI, HTTP_1_1, null }, + { http2URI, HTTP_2, null }, + { https2URI, HTTP_2, null }, }; } @@ -134,6 +139,13 @@ public class ExecutorShutdown implements HttpServerAdapters { } else if (t instanceof ClosedChannelException) { out.println(what + ": Accepting ClosedChannelException as a valid cause: " + t); accepted = t; + } else if (t instanceof IOException io) { + var msg = io.getMessage(); + // Stream 0 cancelled should also be accepted + if (msg != null && msg.matches("Stream (0|([1-9][0-9]*)) cancelled")) { + out.println(what + ": Accepting Stream cancelled as a valid cause: " + io); + accepted = t; + } } t = t.getCause(); } @@ -147,12 +159,13 @@ public class ExecutorShutdown implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testConcurrent(String uriString) throws Exception { + void testConcurrent(String uriString, Version version, Http3DiscoveryMode config) throws Exception { out.printf("%n---- starting (%s) ----%n", uriString); ExecutorService executorService = Executors.newCachedThreadPool(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .executor(executorService) .sslContext(sslContext) .build(); @@ -160,11 +173,19 @@ public class ExecutorShutdown implements HttpServerAdapters { assert client.executor().isPresent(); int step = RANDOM.nextInt(ITERATIONS); + int head = Math.min(1, step); + List>> responses = new ArrayList<>(); try { - List>> responses = new ArrayList<>(); for (int i = 0; i < ITERATIONS; i++) { + if (i == head && version == HTTP_3 && config != HTTP_3_URI_ONLY) { + // let's the first request go through whatever version, + // but ensure that the second will find an AltService + // record + headRequest(client); + } URI uri = URI.create(uriString + "/concurrent/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) + .setOption(H3_DISCOVERY, config) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); @@ -189,6 +210,7 @@ public class ExecutorShutdown implements HttpServerAdapters { out.println(si + ": Got response: " + response); out.println(si + ": Got body Path: " + response.body()); assertEquals(response.statusCode(), 200); + if (si >= head) assertEquals(response.version(), version); assertEquals(response.body(), MESSAGE); return response; }).exceptionally((t) -> { @@ -207,12 +229,13 @@ public class ExecutorShutdown implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testSequential(String uriString) throws Exception { - out.printf("%n---- starting (%s) ----%n", uriString); + void testSequential(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting (%s, %s, %s) ----%n%n", uriString, version, config); ExecutorService executorService = Executors.newCachedThreadPool(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .executor(executorService) .sslContext(sslContext) .build(); @@ -225,8 +248,9 @@ public class ExecutorShutdown implements HttpServerAdapters { for (int i = 0; i < ITERATIONS; i++) { URI uri = URI.create(uriString + "/sequential/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) - .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) - .build(); + .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) + .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; try { @@ -249,6 +273,7 @@ public class ExecutorShutdown implements HttpServerAdapters { out.println(si + ": Got response: " + response); out.println(si + ": Got body Path: " + response.body()); assertEquals(response.statusCode(), 200); + if (si > 0) assertEquals(response.version(), version); assertEquals(response.body(), MESSAGE); return response; }).handle((r,t) -> { @@ -274,6 +299,15 @@ public class ExecutorShutdown implements HttpServerAdapters { // -- Infrastructure + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + @BeforeTest public void setup() throws Exception { out.println("\n**** Setup ****\n"); @@ -295,10 +329,21 @@ public class ExecutorShutdown implements HttpServerAdapters { https2TestServer.addHandler(new ServerRequestHandler(), "/https2/exec/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/exec/retry"; + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new ServerRequestHandler(), "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestHandler(), "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -310,6 +355,8 @@ public class ExecutorShutdown implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/ExpectContinueTest.java b/test/jdk/java/net/httpclient/ExpectContinueTest.java index 50e43099255..a703fbfaacf 100644 --- a/test/jdk/java/net/httpclient/ExpectContinueTest.java +++ b/test/jdk/java/net/httpclient/ExpectContinueTest.java @@ -37,6 +37,7 @@ import jdk.httpclient.test.lib.http2.Http2TestExchange; import jdk.httpclient.test.lib.http2.Http2TestExchangeImpl; import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.httpclient.test.lib.http2.Http2TestServerConnection; +import jdk.httpclient.test.lib.http2.Http2TestServerConnection.ResponseHeaders; import jdk.internal.net.http.common.HttpHeadersBuilder; import jdk.internal.net.http.frame.HeaderFrame; import org.testng.TestException; @@ -252,16 +253,25 @@ public class ExpectContinueTest implements HttpServerAdapters { static class ExpectContinueTestExchangeImpl extends Http2TestExchangeImpl { - public ExpectContinueTestExchangeImpl(int streamid, String method, HttpHeaders reqheaders, HttpHeadersBuilder rspheadersBuilder, URI uri, InputStream is, SSLSession sslSession, BodyOutputStream os, Http2TestServerConnection conn, boolean pushAllowed) { + public ExpectContinueTestExchangeImpl(int streamid, + String method, + HttpHeaders reqheaders, + HttpHeadersBuilder rspheadersBuilder, + URI uri, InputStream is, + SSLSession sslSession, + BodyOutputStream os, + Http2TestServerConnection conn, + boolean pushAllowed) { super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession, os, conn, pushAllowed); } private void sendEndStreamHeaders() throws IOException { this.responseLength = 0; - rspheadersBuilder.setHeader(":status", Integer.toString(100)); + HttpHeadersBuilder pseudoHeadersBuilder = new HttpHeadersBuilder(); + pseudoHeadersBuilder.setHeader(":status", Integer.toString(100)); + HttpHeaders pseudoHeaders = pseudoHeadersBuilder.build(); HttpHeaders headers = rspheadersBuilder.build(); - Http2TestServerConnection.ResponseHeaders response - = new Http2TestServerConnection.ResponseHeaders(headers); + ResponseHeaders response = new ResponseHeaders(pseudoHeaders, headers); response.streamid(streamid); response.setFlag(HeaderFrame.END_HEADERS); response.setFlag(HeaderFrame.END_STREAM); diff --git a/test/jdk/java/net/httpclient/FlowAdapterPublisherTest.java b/test/jdk/java/net/httpclient/FlowAdapterPublisherTest.java index 5d0935d7216..ddefd2a9aa7 100644 --- a/test/jdk/java/net/httpclient/FlowAdapterPublisherTest.java +++ b/test/jdk/java/net/httpclient/FlowAdapterPublisherTest.java @@ -27,9 +27,11 @@ import java.io.OutputStream; import java.net.URI; import java.net.http.HttpClient.Builder; import java.net.http.HttpClient.Version; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.nio.ByteBuffer; import java.nio.MappedByteBuffer; import java.util.Arrays; +import java.util.Optional; import java.util.concurrent.Flow; import java.util.concurrent.Flow.Publisher; import java.util.concurrent.atomic.AtomicBoolean; @@ -46,6 +48,8 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import javax.net.ssl.SSLContext; + +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.util.stream.Collectors.joining; import static java.nio.charset.StandardCharsets.UTF_8; import static java.net.http.HttpRequest.BodyPublishers.fromPublisher; @@ -61,20 +65,22 @@ import static org.testng.Assert.fail; * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.httpclient.test.lib.common.HttpServerAdapters * jdk.test.lib.net.SimpleSSLContext - * @run testng/othervm FlowAdapterPublisherTest + * @run testng/othervm -Djdk.internal.httpclient.debug=err FlowAdapterPublisherTest */ public class FlowAdapterPublisherTest implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; @DataProvider(name = "uris") public Object[][] variants() { @@ -83,6 +89,7 @@ public class FlowAdapterPublisherTest implements HttpServerAdapters { { httpsURI }, { http2URI }, { https2URI }, + { http3URI }, }; } @@ -94,16 +101,26 @@ public class FlowAdapterPublisherTest implements HttpServerAdapters { if (uri.contains("/https1/")) return Version.HTTP_1_1; if (uri.contains("/http2/")) return Version.HTTP_2; if (uri.contains("/https2/")) return Version.HTTP_2; + if (uri.contains("/http3/")) return Version.HTTP_3; return null; } private HttpClient newHttpClient(String uri) { - var builder = HttpClient.newBuilder(); + var version = Optional.ofNullable(version(uri)); + var builder = version.isEmpty() || version.get() != Version.HTTP_3 + ? HttpClient.newBuilder() + : HttpServerAdapters.createClientBuilderForH3().version(Version.HTTP_3); return builder.sslContext(sslContext).proxy(Builder.NO_PROXY).build(); } private HttpRequest.Builder newRequestBuilder(String uri) { - return HttpRequest.newBuilder(URI.create(uri)); + var version = Optional.ofNullable(version(uri)); + var builder = version.isEmpty() || version.get() != Version.HTTP_3 + ? HttpRequest.newBuilder(URI.create(uri)) + : HttpRequest.newBuilder(URI.create(uri)) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + return builder; } @Test @@ -378,10 +395,15 @@ public class FlowAdapterPublisherTest implements HttpServerAdapters { https2TestServer.addHandler(new HttpEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new HttpEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -390,6 +412,7 @@ public class FlowAdapterPublisherTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } static class HttpEchoHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/FlowAdapterSubscriberTest.java b/test/jdk/java/net/httpclient/FlowAdapterSubscriberTest.java index 8319204f5c2..5de943278eb 100644 --- a/test/jdk/java/net/httpclient/FlowAdapterSubscriberTest.java +++ b/test/jdk/java/net/httpclient/FlowAdapterSubscriberTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2017, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,9 +30,11 @@ import java.lang.StackWalker.StackFrame; import java.net.URI; import java.net.http.HttpClient.Builder; import java.net.http.HttpClient.Version; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.nio.ByteBuffer; import java.util.Collection; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; @@ -56,6 +58,8 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import javax.net.ssl.SSLContext; + +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; @@ -74,14 +78,16 @@ import static org.testng.Assert.assertTrue; public class FlowAdapterSubscriberTest implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; static final StackWalker WALKER = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); @@ -102,6 +108,7 @@ public class FlowAdapterSubscriberTest implements HttpServerAdapters { { httpsURI }, { http2URI }, { https2URI }, + { http3URI }, }; } @@ -112,16 +119,26 @@ public class FlowAdapterSubscriberTest implements HttpServerAdapters { if (uri.contains("/https1/")) return Version.HTTP_1_1; if (uri.contains("/http2/")) return Version.HTTP_2; if (uri.contains("/https2/")) return Version.HTTP_2; + if (uri.contains("/http3/")) return Version.HTTP_3; return null; } private HttpClient newHttpClient(String uri) { - var builder = HttpClient.newBuilder(); + var version = Optional.ofNullable(version(uri)); + var builder = version.isEmpty() || version.get() != Version.HTTP_3 + ? HttpClient.newBuilder() + : HttpServerAdapters.createClientBuilderForH3().version(Version.HTTP_3); return builder.sslContext(sslContext).proxy(Builder.NO_PROXY).build(); } private HttpRequest.Builder newRequestBuilder(String uri) { - return HttpRequest.newBuilder(URI.create(uri)); + var version = Optional.ofNullable(version(uri)); + var builder = version.isEmpty() || version.get() != Version.HTTP_3 + ? HttpRequest.newBuilder(URI.create(uri)) + : HttpRequest.newBuilder(URI.create(uri)) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + return builder; } @Test @@ -636,10 +653,15 @@ public class FlowAdapterSubscriberTest implements HttpServerAdapters { https2TestServer.addHandler(new HttpEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new HttpEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -648,6 +670,7 @@ public class FlowAdapterSubscriberTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } static class HttpEchoHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/ForbiddenHeadTest.java b/test/jdk/java/net/httpclient/ForbiddenHeadTest.java index 1498aa118b3..2a50d03b365 100644 --- a/test/jdk/java/net/httpclient/ForbiddenHeadTest.java +++ b/test/jdk/java/net/httpclient/ForbiddenHeadTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -35,9 +35,6 @@ * ForbiddenHeadTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.ITestContext; import org.testng.ITestResult; @@ -53,7 +50,6 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.net.Authenticator; -import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.PasswordAuthentication; import java.net.Proxy; @@ -76,12 +72,14 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -93,12 +91,14 @@ public class ForbiddenHeadTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) DigestEchoServer.TunnelingProxy proxy; DigestEchoServer.TunnelingProxy authproxy; String httpURI; String httpsURI; String http2URI; String https2URI; + String https3URI; HttpClient authClient; HttpClient noAuthClient; @@ -218,6 +218,8 @@ public class ForbiddenHeadTest implements HttpServerAdapters { for (var uri : List.of(httpURI, httpsURI, http2URI, https2URI)) { result.add(new Object[]{uri + srv + auth, pcode, async, useAuth}); } + if (code == PROXY_UNAUTHORIZED) continue; // no HTTP/3 with proxy + result.add(new Object[] {https3URI + srv + auth, pcode, async, useAuth}); } } } @@ -275,6 +277,10 @@ public class ForbiddenHeadTest implements HttpServerAdapters { .newBuilder(uri) .method("HEAD", HttpRequest.BodyPublishers.noBody()); + if (uriString.contains("/http3/")) { + requestBuilder.version(HTTP_3).setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + HttpRequest request = requestBuilder.build(); out.println("Initial request: " + request.uri()); @@ -349,10 +355,14 @@ public class ForbiddenHeadTest implements HttpServerAdapters { https2TestServer.addHandler(new UnauthorizedHandler(), "/https2/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new UnauthorizedHandler(), "/http3/"); + https3URI = "https://" + http3TestServer.serverAuthority() + "/http3"; + proxy = DigestEchoServer.createHttpsProxyTunnel(DigestEchoServer.HttpAuthSchemeType.NONE); authproxy = DigestEchoServer.createHttpsProxyTunnel(DigestEchoServer.HttpAuthSchemeType.BASIC); - authClient = TRACKER.track(HttpClient.newBuilder() + authClient = TRACKER.track(newClientBuilderForH3() .proxy(TestProxySelector.of(proxy, authproxy, httpTestServer)) .sslContext(sslContext) .executor(executor) @@ -360,7 +370,7 @@ public class ForbiddenHeadTest implements HttpServerAdapters { .build()); clientCount.incrementAndGet(); - noAuthClient = TRACKER.track(HttpClient.newBuilder() + noAuthClient = TRACKER.track(newClientBuilderForH3() .proxy(TestProxySelector.of(proxy, authproxy, httpTestServer)) .sslContext(sslContext) .executor(executor) @@ -375,6 +385,8 @@ public class ForbiddenHeadTest implements HttpServerAdapters { serverCount.incrementAndGet(); https2TestServer.start(); serverCount.incrementAndGet(); + http3TestServer.start(); + serverCount.incrementAndGet(); } @AfterTest @@ -389,6 +401,7 @@ public class ForbiddenHeadTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/GZIPInputStreamTest.java b/test/jdk/java/net/httpclient/GZIPInputStreamTest.java index 1215fc6c0a0..a8ce8fb2de5 100644 --- a/test/jdk/java/net/httpclient/GZIPInputStreamTest.java +++ b/test/jdk/java/net/httpclient/GZIPInputStreamTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -31,8 +31,6 @@ */ import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -45,10 +43,10 @@ import java.io.InputStream; import java.io.OutputStream; import java.io.UncheckedIOException; import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; @@ -62,11 +60,13 @@ import java.util.function.Supplier; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; @@ -77,10 +77,12 @@ public class GZIPInputStreamTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer https3TestServer; // HTTP/3 String httpURI; String httpsURI; String http2URI; String https2URI; + String https3URI; static final int ITERATION_COUNT = 3; // a shared executor helps reduce the amount of threads created by the test @@ -163,26 +165,28 @@ public class GZIPInputStreamTest implements HttpServerAdapters { { http2URI, true }, { https2URI, false }, { https2URI, true }, + { https3URI, false }, + { https3URI, true } }; } final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; HttpClient newHttpClient() { - return TRACKER.track(HttpClient.newBuilder() + return TRACKER.track(newClientBuilderForH3() .executor(executor) .sslContext(sslContext) .build()); } HttpClient newSingleThreadClient() { - return TRACKER.track(HttpClient.newBuilder() + return TRACKER.track(newClientBuilderForH3() .executor(singleThreadExecutor) .sslContext(sslContext) .build()); } HttpClient newInLineClient() { - return TRACKER.track(HttpClient.newBuilder() + return TRACKER.track(newClientBuilderForH3() .executor((r) -> r.run() ) .sslContext(sslContext) .build()); @@ -198,18 +202,10 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newSingleThreadClient(); // should work with 1 single thread - HttpRequest req = HttpRequest.newBuilder(URI.create(uri +"/txt/LoremIpsum.txt")) - .build(); + HttpRequest req = buildRequest(URI.create(uri +"/txt/LoremIpsum.txt")); BodyHandler handler = BodyHandlers.ofString(UTF_8); HttpResponse response = client.send(req, handler); - String lorem = response.body(); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(response.body()); } } @@ -222,18 +218,10 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newSingleThreadClient(); // should work with 1 single thread - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/txt/LoremIpsum.txt")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/txt/LoremIpsum.txt")); BodyHandler handler = BodyHandlers.ofInputStream(); HttpResponse response = client.send(req, handler); - String lorem = new String(response.body().readAllBytes(), UTF_8); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(new String(response.body().readAllBytes(), UTF_8)); } } @@ -247,19 +235,11 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newSingleThreadClient(); // should work with 1 single thread - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/gz/LoremIpsum.txt.gz")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/gz/LoremIpsum.txt.gz")); BodyHandler handler = BodyHandlers.ofInputStream(); HttpResponse response = client.send(req, handler); GZIPInputStream gz = new GZIPInputStream(response.body()); - String lorem = new String(gz.readAllBytes(), UTF_8); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(new String(gz.readAllBytes(), UTF_8)); } } @@ -273,20 +253,12 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newHttpClient(); // needs at least 2 threads - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/gz/LoremIpsum.txt.gz")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/gz/LoremIpsum.txt.gz")); // This is dangerous, because the finisher will block. // We support this, but the executor must have enough threads. BodyHandler handler = new GZIPBodyHandler(); HttpResponse response = client.send(req, handler); - String lorem = new String(response.body().readAllBytes(), UTF_8); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(new String(response.body().readAllBytes(), UTF_8)); } } @@ -301,8 +273,7 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newSingleThreadClient(); // should work with 1 single thread - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/gz/LoremIpsum.txt.gz")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/gz/LoremIpsum.txt.gz")); // This is dangerous, because the finisher will block. // We support this, but the executor must have enough threads. BodyHandler> handler = new BodyHandler>() { @@ -329,14 +300,7 @@ public class GZIPInputStreamTest implements HttpServerAdapters { } }; HttpResponse> response = client.send(req, handler); - String lorem = new String(response.body().get().readAllBytes(), UTF_8); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(new String(response.body().get().readAllBytes(), UTF_8)); } } @@ -350,27 +314,20 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newHttpClient(); // needs at least 2 threads - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/txt/LoremIpsum.txt")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/txt/LoremIpsum.txt")); BodyHandler handler = BodyHandlers.ofInputStream(); CompletableFuture> responseCF = client.sendAsync(req, handler); // This is dangerous. Blocking in the mapping function can wedge the // response. We do support it provided that there enough threads in // the executor. - String lorem = responseCF.thenApply((r) -> { + String responseBody = responseCF.thenApply((r) -> { try { return new String(r.body().readAllBytes(), UTF_8); } catch (IOException io) { throw new UncheckedIOException(io); } }).join(); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(responseBody); } } @@ -384,27 +341,20 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newHttpClient(); // needs at least 2 threads - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/gz/LoremIpsum.txt.gz")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/gz/LoremIpsum.txt.gz")); BodyHandler handler = new GZIPBodyHandler(); CompletableFuture> responseCF = client.sendAsync(req, handler); // This is dangerous - we support this, but it can block // if there are not enough threads available. // Correct custom code should use thenApplyAsync instead. - String lorem = responseCF.thenApply((r) -> { + String responseBody = responseCF.thenApply((r) -> { try { return new String(r.body().readAllBytes(), UTF_8); } catch (IOException io) { throw new UncheckedIOException(io); } }).join(); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(responseBody); } } @@ -419,8 +369,7 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newHttpClient(); // needs at least 2 threads - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/gz/LoremIpsum.txt.gz")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/gz/LoremIpsum.txt.gz")); // This is dangerous. Blocking in the mapping function can wedge the // response. We do support it provided that there enough thread in // the executor. @@ -433,14 +382,7 @@ public class GZIPInputStreamTest implements HttpServerAdapters { } }); HttpResponse response = client.send(req, handler); - String lorem = response.body(); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(response.body()); } } @@ -455,8 +397,7 @@ public class GZIPInputStreamTest implements HttpServerAdapters { if (!sameClient || client == null) client = newInLineClient(); // should even work with no threads - HttpRequest req = HttpRequest.newBuilder(URI.create(uri + "/gz/LoremIpsum.txt.gz")) - .build(); + HttpRequest req = buildRequest(URI.create(uri + "/gz/LoremIpsum.txt.gz")); // This is dangerous, because the finisher will block. // We support this, but the executor must have enough threads. BodyHandler> handler = new BodyHandler>() { @@ -483,17 +424,29 @@ public class GZIPInputStreamTest implements HttpServerAdapters { } }; HttpResponse> response = client.send(req, handler); - String lorem = new String(response.body().get().readAllBytes(), UTF_8); - if (!LOREM_IPSUM.equals(lorem)) { - out.println("Response doesn't match"); - out.println("[" + LOREM_IPSUM + "] != [" + lorem + "]"); - assertEquals(LOREM_IPSUM, lorem); - } else { - out.println("Received expected response."); - } + verifyResponse(new String(response.body().get().readAllBytes(), UTF_8)); } } + private void verifyResponse(String responseBody) { + if (!LOREM_IPSUM.equals(responseBody)) { + out.println("Response doesn't match"); + out.println("[" + LOREM_IPSUM + "] != [" + responseBody + "]"); + assertEquals(LOREM_IPSUM, responseBody); + } else { + out.println("Received expected response."); + } + } + + private HttpRequest buildRequest(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getPath().contains("/https3/")) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder.build(); + } + static final class GZIPBodyHandler implements BodyHandler { @Override public HttpResponse.BodySubscriber apply(HttpResponse.ResponseInfo responseInfo) { @@ -565,10 +518,16 @@ public class GZIPInputStreamTest implements HttpServerAdapters { https2TestServer.addHandler(gzipHandler, "/https2/chunk/gz"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/chunk"; + https3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + https3TestServer.addHandler(plainHandler, "/https3/chunk/txt"); + https3TestServer.addHandler(gzipHandler, "/https3/chunk/gz"); + https3URI = "https://" + https3TestServer.serverAuthority() + "/https3/chunk"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + https3TestServer.start(); } @AfterTest @@ -580,6 +539,7 @@ public class GZIPInputStreamTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + https3TestServer.stop(); } finally { if (fail != null) { throw fail; diff --git a/test/jdk/java/net/httpclient/HandshakeFailureTest.java b/test/jdk/java/net/httpclient/HandshakeFailureTest.java index a6db6be817a..e07dfacbd85 100644 --- a/test/jdk/java/net/httpclient/HandshakeFailureTest.java +++ b/test/jdk/java/net/httpclient/HandshakeFailureTest.java @@ -142,6 +142,7 @@ public class HandshakeFailureTest { SSLParameters params = new SSLParameters(); params.setProtocols(new String[] { tlsProtocol }); return HttpClient.newBuilder() + .version(Version.HTTP_2) .sslParameters(params) .build(); } diff --git a/test/jdk/java/net/httpclient/HeadTest.java b/test/jdk/java/net/httpclient/HeadTest.java index 5b3b1671d43..e14339a1121 100644 --- a/test/jdk/java/net/httpclient/HeadTest.java +++ b/test/jdk/java/net/httpclient/HeadTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -24,17 +24,12 @@ /* * @test * @bug 8203433 8276559 - * @summary (httpclient) Add tests for HEAD and 304 responses. + * @summary Tests Client handles HEAD and 304 responses correctly. * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext - * @run testng/othervm - * -Djdk.httpclient.HttpClient.log=trace,headers,requests - * HeadTest + * @build jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.httpclient.HttpClient.log=trace,headers,requests HeadTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -44,20 +39,22 @@ import org.testng.annotations.Test; import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; +import java.io.PrintStream; import java.net.URI; -import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import static java.lang.System.out; -import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_3; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static jdk.httpclient.test.lib.common.HttpServerAdapters.createClientBuilderForH3; import static org.testng.Assert.assertEquals; public class HeadTest implements HttpServerAdapters { @@ -67,10 +64,10 @@ public class HeadTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) - String httpURI; - String httpsURI; - String http2URI; - String https2URI; + HttpTestServer https3TestServer; // HTTP/3 + String httpURI, httpsURI; + String http2URI, https2URI; + String https3URI; static final String CONTENT_LEN = "300"; @@ -81,7 +78,7 @@ public class HeadTest implements HttpServerAdapters { */ static final int HTTP_NOT_MODIFIED = 304; static final int HTTP_OK = 200; - + static final PrintStream out = System.out; @DataProvider(name = "positive") public Object[][] positive() { @@ -96,28 +93,33 @@ public class HeadTest implements HttpServerAdapters { { httpURI + "transfer/", "HEAD", HTTP_OK, HTTP_1_1 }, { httpsURI + "transfer/", "HEAD", HTTP_OK, HTTP_1_1 }, // HTTP/2 - { http2URI, "GET", HTTP_NOT_MODIFIED, HttpClient.Version.HTTP_2 }, - { https2URI, "GET", HTTP_NOT_MODIFIED, HttpClient.Version.HTTP_2 }, - { http2URI, "HEAD", HTTP_OK, HttpClient.Version.HTTP_2 }, - { https2URI, "HEAD", HTTP_OK, HttpClient.Version.HTTP_2 }, - // HTTP2 forbids transfer-encoding + { http2URI, "GET", HTTP_NOT_MODIFIED, HTTP_2 }, + { https2URI, "GET", HTTP_NOT_MODIFIED, HTTP_2 }, + { http2URI, "HEAD", HTTP_OK, HTTP_2 }, + { https2URI, "HEAD", HTTP_OK, HTTP_2 }, + // HTTP/3 + { https3URI, "GET", HTTP_NOT_MODIFIED, HTTP_3 }, + { https3URI, "HEAD", HTTP_OK, HTTP_3 }, }; } @Test(dataProvider = "positive") void test(String uriString, String method, - int expResp, HttpClient.Version version) throws Exception { + int expResp, Version version) throws Exception { out.printf("%n---- starting (%s) ----%n", uriString); URI uri = URI.create(uriString); + Http3DiscoveryMode config = version.equals(HTTP_3) ? HTTP_3_URI_ONLY : null; HttpRequest.Builder requestBuilder = HttpRequest .newBuilder(uri) .version(version) + .setOption(H3_DISCOVERY, config) .method(method, HttpRequest.BodyPublishers.noBody()); doTest(requestBuilder.build(), expResp); // repeat the test this time by building the request using convenience // GET and HEAD methods requestBuilder = HttpRequest.newBuilder(uri) - .version(version); + .version(version) + .setOption(H3_DISCOVERY, config); switch (method) { case "GET" -> requestBuilder.GET(); case "HEAD" -> requestBuilder.HEAD(); @@ -128,32 +130,27 @@ public class HeadTest implements HttpServerAdapters { // issue a request with no body and verify the response code is the expected response code private void doTest(HttpRequest request, int expResp) throws Exception { - HttpClient client = HttpClient.newBuilder() - .followRedirects(Redirect.ALWAYS) - .sslContext(sslContext) - .build(); - out.println("Initial request: " + request.uri()); + try (var client = createClientBuilderForH3().followRedirects(Redirect.ALWAYS).sslContext(sslContext).build()) { + out.println("Initial request: " + request.uri()); + HttpResponse response = client.send(request, BodyHandlers.ofString()); - HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println(" Got response: " + response); - out.println(" Got response: " + response); - - assertEquals(response.statusCode(), expResp); - assertEquals(response.body(), ""); - assertEquals(response.headers().firstValue("Content-length").get(), CONTENT_LEN); - assertEquals(response.version(), request.version().get()); + assertEquals(response.statusCode(), expResp); + assertEquals(response.body(), ""); + assertEquals(response.headers().firstValue("Content-length").get(), CONTENT_LEN); + assertEquals(response.version(), request.version().get()); + } } // -- Infrastructure - + // TODO: See if test performs better with Vthreads, see H3SimplePost and H3SimpleGet @BeforeTest public void setup() throws Exception { sslContext = new SimpleSSLContext().get(); if (sslContext == null) throw new AssertionError("Unexpected null sslContext"); - InetSocketAddress sa = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); - httpTestServer = HttpTestServer.create(HTTP_1_1); httpTestServer.addHandler(new HeadHandler(), "/"); httpURI = "http://" + httpTestServer.serverAuthority() + "/"; @@ -168,11 +165,16 @@ public class HeadTest implements HttpServerAdapters { https2TestServer.addHandler(new HeadHandler(), "/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/"; + https3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + https3TestServer.addHandler(new HeadHandler(), "/"); + https3URI = "https://" + https3TestServer.serverAuthority() + "/"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + https3TestServer.start(); } @AfterTest @@ -181,6 +183,7 @@ public class HeadTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + https3TestServer.stop(); } static class HeadHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/HeadersLowerCaseTest.java b/test/jdk/java/net/httpclient/HeadersLowerCaseTest.java new file mode 100644 index 00000000000..26c791eaca8 --- /dev/null +++ b/test/jdk/java/net/httpclient/HeadersLowerCaseTest.java @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/* + * @test + * @summary Verify that the request/response headers of HTTP/2 and HTTP/3 + * are sent and received in lower case + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @run junit/othervm -Djdk.internal.httpclient.debug=true HeadersLowerCaseTest + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class HeadersLowerCaseTest implements HttpServerAdapters { + + private static Set REQUEST_HEADERS; + + private HttpTestServer h2server; + private HttpTestServer h3server; + private String h2ReqURIBase; + private String h3ReqURIBase; + private SSLContext sslContext; + + @BeforeAll + public void beforeAll() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + h2server = HttpTestServer.create(HTTP_2, sslContext); + h2server.start(); + h2ReqURIBase = "https://" + h2server.serverAuthority(); + h2server.addHandler(new ReqHeadersVerifier(), "/h2verifyReqHeaders"); + System.out.println("HTTP/2 server listening on " + h2server.getAddress()); + + + h3server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3server.start(); + h3ReqURIBase = "https://" + h3server.serverAuthority(); + h3server.addHandler(new ReqHeadersVerifier(), "/h3verifyReqHeaders"); + System.out.println("HTTP/3 server listening on " + h3server.getAddress()); + + REQUEST_HEADERS = new HashSet<>(); + REQUEST_HEADERS.add("AbCdeFgh"); + REQUEST_HEADERS.add("PQRSTU"); + REQUEST_HEADERS.add("xyz"); + REQUEST_HEADERS.add("A1243Bde2"); + REQUEST_HEADERS.add("123243"); + REQUEST_HEADERS.add("&1bacd*^d"); + REQUEST_HEADERS.add("~!#$%^&*_+"); + } + + @AfterAll + public void afterAll() throws Exception { + if (h2server != null) { + h2server.stop(); + } + if (h3server != null) { + h3server.stop(); + } + } + + /** + * Handler which verifies that the request header names are lowercase (as mandated by the spec) + */ + private static final class ReqHeadersVerifier implements HttpTestHandler { + + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + System.out.println("Verifying request headers for " + exchange.getRequestURI()); + final Set missing = new HashSet<>(REQUEST_HEADERS); + final HttpTestRequestHeaders headers = exchange.getRequestHeaders(); + for (final Map.Entry> e : headers.entrySet()) { + final String header = e.getKey(); + // check validity of (non-pseudo) header names + if (!header.startsWith(":") && !Utils.isValidLowerCaseName(header)) { + System.err.println("Header name " + header + " is not valid"); + sendResponse(exchange, 500); + return; + } + final List headerVals = e.getValue(); + if (headerVals.isEmpty()) { + System.err.println("Header " + header + " is missing value"); + sendResponse(exchange, 500); + return; + } + // the header value represents the original form of the header key held in the + // REQUEST_HEADERS set + final String originalForm = headerVals.get(0); + missing.remove(originalForm); + } + if (!missing.isEmpty()) { + System.err.println("Missing headers in request: " + missing); + sendResponse(exchange, 500); + return; + } + System.out.println("All expected headers received in lower case for " + exchange.getRequestURI()); + sendResponse(exchange, 200); + } + + private static void sendResponse(final HttpTestExchange exchange, final int statusCode) throws IOException { + final HttpTestResponseHeaders respHeaders = exchange.getResponseHeaders(); + // we just send the pre-defined (request) headers back as the response headers + for (final String k : REQUEST_HEADERS) { + respHeaders.addHeader(k, k); + } + exchange.sendResponseHeaders(statusCode, 0); + } + } + + private Stream params() throws Exception { + return Stream.of( + Arguments.of(HTTP_2, new URI(h2ReqURIBase + "/h2verifyReqHeaders")), + Arguments.of(Version.HTTP_3, new URI(h3ReqURIBase + "/h3verifyReqHeaders"))); + } + + /** + * Issues a HTTP/2 or HTTP/3 request with header names of varying case (some in lower, + * some mixed, some upper case) and expects that the client internally converts them + * to lower case before encoding and sending to the server. The server side handler verifies + * that it receives the header names in lower case and if it doesn't then it returns a + * non-200 response + */ + @ParameterizedTest + @MethodSource("params") + public void testRequestHeaders(final Version version, final URI requestURI) throws Exception { + try (final HttpClient client = newClientBuilderForH3() + .version(version) + .sslContext(sslContext) + .proxy(HttpClient.Builder.NO_PROXY).build()) { + Http3DiscoveryMode config = switch (version) { + case HTTP_3 -> HTTP_3_URI_ONLY; + default -> ALT_SVC; + }; + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(requestURI) + .setOption(H3_DISCOVERY, config) + .version(version); + for (final String k : REQUEST_HEADERS) { + reqBuilder.header(k, k); + } + final HttpRequest req = reqBuilder.build(); + System.out.println("Issuing " + version + " request to " + requestURI); + final HttpResponse resp = client.send(req, BodyHandlers.discarding()); + assertEquals(resp.version(), version, "Unexpected HTTP version in response"); + assertEquals(resp.statusCode(), 200, "Unexpected response code"); + // now try with async + System.out.println("Issuing (async) request to " + requestURI); + final CompletableFuture> futureResp = client.sendAsync(req, + BodyHandlers.discarding()); + final HttpResponse asyncResp = futureResp.get(); + assertEquals(asyncResp.version(), version, "Unexpected HTTP version in response"); + assertEquals(asyncResp.statusCode(), 200, "Unexpected response code"); + } + } + + /** + * Verifies that when a HTTP/2 or HTTP/3 request is being built using + * {@link HttpRequest.Builder}, only valid header names are allowed to be added to the request + */ + @ParameterizedTest + @MethodSource("params") + public void testInvalidHeaderName(final Version version, final URI requestURI) throws Exception { + Http3DiscoveryMode config = switch (version) { + case HTTP_3 -> HTTP_3_URI_ONLY; + default -> ALT_SVC; + }; + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(requestURI) + .setOption(H3_DISCOVERY, config) + .version(version); + final String copyrightSign = new String(Character.toChars(0x00A9)); // copyright sign + final String invalidHeaderName = "abcd" + copyrightSign; + System.out.println("Adding header name " + invalidHeaderName + " to " + version + " request"); + // Field names are strings containing a subset of ASCII characters. + // This header name contains a unicode character, so it should fail + assertThrows(IllegalArgumentException.class, + () -> reqBuilder.header(invalidHeaderName, "something")); + } +} diff --git a/test/jdk/java/net/httpclient/HttpClientBuilderTest.java b/test/jdk/java/net/httpclient/HttpClientBuilderTest.java index 6074a3a855b..596320d49e6 100644 --- a/test/jdk/java/net/httpclient/HttpClientBuilderTest.java +++ b/test/jdk/java/net/httpclient/HttpClientBuilderTest.java @@ -357,6 +357,13 @@ public class HttpClientBuilderTest { @Test public void testVersion() { HttpClient.Builder builder = HttpClient.newBuilder(); + try (var closer = closeable(builder)) { + assertTrue(closer.build().version() == Version.HTTP_2); + } + builder.version(Version.HTTP_3); + try (var closer = closeable(builder)) { + assertTrue(closer.build().version() == Version.HTTP_3); + } builder.version(Version.HTTP_2); try (var closer = closeable(builder)) { assertTrue(closer.build().version() == Version.HTTP_2); @@ -366,6 +373,10 @@ public class HttpClientBuilderTest { assertTrue(closer.build().version() == Version.HTTP_1_1); } assertThrows(NPE, () -> builder.version(null)); + builder.version(Version.HTTP_3); + try (var closer = closeable(builder)) { + assertTrue(closer.build().version() == Version.HTTP_3); + } builder.version(Version.HTTP_2); try (var closer = closeable(builder)) { assertTrue(closer.build().version() == Version.HTTP_2); diff --git a/test/jdk/java/net/httpclient/HttpClientClose.java b/test/jdk/java/net/httpclient/HttpClientClose.java index 0425d27a25e..f761396487c 100644 --- a/test/jdk/java/net/httpclient/HttpClientClose.java +++ b/test/jdk/java/net/httpclient/HttpClientClose.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -46,7 +46,9 @@ import java.io.UncheckedIOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.charset.StandardCharsets; @@ -57,13 +59,11 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Flow; import java.util.concurrent.Flow.Publisher; import java.util.concurrent.Flow.Subscriber; import java.util.concurrent.Flow.Subscription; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.Stream; import jdk.httpclient.test.lib.common.HttpServerAdapters; import javax.net.ssl.SSLContext; @@ -75,12 +75,14 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import static java.lang.System.err; -import static java.lang.System.in; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -100,10 +102,15 @@ public class HttpClientClose implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/3 ( h2 + h3 ) + HttpTestServer h3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h2h3Head; + String h3URI; static final String MESSAGE = "HttpClientClose message body"; static final int ITERATIONS = 3; @@ -111,10 +118,12 @@ public class HttpClientClose implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, }, - { httpsURI, }, - { http2URI, }, - { https2URI, }, + { h2h3URI, HTTP_3, h2h3TestServer.h3DiscoveryConfig()}, + { h3URI, HTTP_3, h3TestServer.h3DiscoveryConfig()}, + { httpURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { httpsURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { http2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 + { https2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 }; } @@ -160,19 +169,39 @@ public class HttpClientClose implements HttpServerAdapters { } } - private record ExchangeResult(int step, HttpResponse response) { - public static ExchangeResult ofStep(int step) { - return new ExchangeResult(step, null); + record ExchangeResult(int step, + Version version, + Http3DiscoveryMode config, + HttpResponse response, + boolean firstVersionMayNotMatch) { + + static ExchangeResult afterHead(int step, Version version, Http3DiscoveryMode config) { + return new ExchangeResult(step, version, config, null, false); } + + static ExchangeResult ofSequential(int step, Version version, Http3DiscoveryMode config) { + return new ExchangeResult(step, version, config, null, true); + } + ExchangeResult withResponse(HttpResponse response) { - return new ExchangeResult(step, response); + return new ExchangeResult(step(), version(), config(), response, firstVersionMayNotMatch()); } + + // Ensures that the input stream gets closed in case of assertion ExchangeResult assertResponseState() { + out.println(step + ": Got response: " + response); try { - out.println(step + ": Got response: " + response); + out.printf("%s: expect status 200 and version %s (%s) for %s%n", step, version, config, + response.request().uri()); assertEquals(response.statusCode(), 200); + if (step == 0 && version == HTTP_3 && firstVersionMayNotMatch) { + out.printf("%s: version not checked%n", step); + } else { + assertEquals(response.version(), version); + out.printf("%s: got expected version %s%n", step, response.version()); + } } catch (AssertionError error) { - out.printf("%s: Closing body due to assertion - %s", error); + out.printf("%s: Closing body due to assertion - %s", step, error); ensureClosed(this); throw error; } @@ -180,35 +209,62 @@ public class HttpClientClose implements HttpServerAdapters { } } + static String readBody(int i, HttpResponse resp) { + try (var in = resp.body()) { + out.println(i + ": reading body for " + resp.request().uri()); + var body = new String(in.readAllBytes(), StandardCharsets.UTF_8); + out.println(i + ": got body " + body); + return body; + } catch (IOException io) { + out.println(i + ": failed to read body"); + throw new UncheckedIOException(io); + } + } + + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + @Test(dataProvider = "positive") - void testConcurrent(String uriString) throws Exception { - out.printf("%n---- starting concurrent (%s) ----%n%n", uriString); + void testConcurrent(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting concurrent (%s, %s, %s) ----%n%n", uriString, version, config); Throwable failed = null; HttpClient toCheck = null; List> bodies = new ArrayList<>(); - try (HttpClient client = toCheck = HttpClient.newBuilder() + try (HttpClient client = toCheck = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build()) { TRACKER.track(client); + if (version == HTTP_3 && config != HTTP_3_URI_ONLY) { + headRequest(client); + } + for (int i = 0; i < ITERATIONS; i++) { URI uri = URI.create(uriString + "/concurrent/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; CompletableFuture bodyCF; final int si = i; - ExchangeResult result = ExchangeResult.ofStep(si); + ExchangeResult result = ExchangeResult.afterHead(i, version, config); responseCF = client.sendAsync(request, BodyHandlers.ofInputStream()) .thenApply(result::withResponse) .thenApplyAsync(ExchangeResult::assertResponseState, readerService) .thenApply(ExchangeResult::response); - bodyCF = responseCF.thenApplyAsync(HttpResponse::body, readerService) - .thenApply(HttpClientClose::readBody) + bodyCF = responseCF + .thenApplyAsync((resp) -> readBody(si, resp), readerService) .thenApply((s) -> { assertEquals(s, MESSAGE); return s; @@ -223,18 +279,25 @@ public class HttpClientClose implements HttpServerAdapters { } } assertTrue(toCheck.isTerminated()); - // assert all operations eventually terminate + + // Ensure all CF are eventually completed + out.printf("waiting for requests to complete%n"); CompletableFuture.allOf(bodies.toArray(new CompletableFuture[0])).get(); + out.printf("all requests completed%n"); + out.printf("%n---- end concurrent (%s, %s, %s): %s ----%n", + uriString, version, config, + failed == null ? "done" : failed.toString()); } @Test(dataProvider = "positive") - void testSequential(String uriString) throws Exception { - out.printf("%n---- starting sequential (%s) ----%n%n", uriString); + void testSequential(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting sequential (%s, %s, %s) ----%n%n", uriString, version, config); Throwable failed = null; HttpClient toCheck = null; - try (HttpClient client = toCheck = HttpClient.newBuilder() + try (HttpClient client = toCheck = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build()) { TRACKER.track(client); @@ -243,10 +306,11 @@ public class HttpClientClose implements HttpServerAdapters { URI uri = URI.create(uriString + "/sequential/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); final int si = i; - ExchangeResult result = ExchangeResult.ofStep(si); + ExchangeResult result = ExchangeResult.ofSequential(si, version, config); CompletableFuture> responseCF; CompletableFuture bodyCF; responseCF = client.sendAsync(request, BodyHandlers.ofInputStream()) @@ -283,6 +347,7 @@ public class HttpClientClose implements HttpServerAdapters { if (sslContext == null) throw new AssertionError("Unexpected null sslContext"); readerService = Executors.newCachedThreadPool(); + httpTestServer = HttpTestServer.create(HTTP_1_1); httpTestServer.addHandler(new ServerRequestHandler(), "/http1/exec/"); httpURI = "http://" + httpTestServer.serverAuthority() + "/http1/exec/retry"; @@ -297,10 +362,21 @@ public class HttpClientClose implements HttpServerAdapters { https2TestServer.addHandler(new ServerRequestHandler(), "/https2/exec/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/exec/retry"; + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new ServerRequestHandler(), "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestHandler(), "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -313,6 +389,8 @@ public class HttpClientClose implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/HttpClientShutdown.java b/test/jdk/java/net/httpclient/HttpClientShutdown.java index 3c77881c8aa..192abcba39d 100644 --- a/test/jdk/java/net/httpclient/HttpClientShutdown.java +++ b/test/jdk/java/net/httpclient/HttpClientShutdown.java @@ -47,7 +47,9 @@ import java.io.UncheckedIOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.channels.ClosedChannelException; @@ -77,11 +79,14 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -101,10 +106,15 @@ public class HttpClientShutdown implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/3 ( h2 + h3 ) + HttpTestServer h3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h2h3Head; + String h3URI; static final String MESSAGE = "HttpClientShutdown message body"; static final int ITERATIONS = 3; @@ -112,10 +122,12 @@ public class HttpClientShutdown implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, }, - { httpsURI, }, - { http2URI, }, - { https2URI, }, + { h2h3URI, HTTP_3, h2h3TestServer.h3DiscoveryConfig()}, + { h3URI, HTTP_3, h3TestServer.h3DiscoveryConfig()}, + { httpURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { httpsURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { http2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 + { https2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 }; } @@ -180,17 +192,37 @@ public class HttpClientShutdown implements HttpServerAdapters { } } - private record ExchangeResult(int step, HttpResponse response) { - public static ExchangeResult ofStep(int step) { - return new ExchangeResult(step, null); + record ExchangeResult(int step, + Version version, + Http3DiscoveryMode config, + HttpResponse response, + boolean firstVersionMayNotMatch) { + + static ExchangeResult afterHead(int step, Version version, Http3DiscoveryMode config) { + return new ExchangeResult(step, version, config, null, false); } + + static ExchangeResult ofSequential(int step, Version version, Http3DiscoveryMode config) { + return new ExchangeResult(step, version, config, null, true); + } + ExchangeResult withResponse(HttpResponse response) { - return new ExchangeResult(step, response); + return new ExchangeResult(step(), version(), config(), response, firstVersionMayNotMatch()); } + + // Ensures that the input stream gets closed in case of assertion ExchangeResult assertResponseState() { + out.println(now() + step + ": Got response: " + response); try { - out.println(now() + step + ": Got response: " + response); + out.printf(now() + "%s: expect status 200 and version %s (%s) for %s%n", step, version, config, + response.request().uri()); assertEquals(response.statusCode(), 200); + if (step == 0 && version == HTTP_3 && firstVersionMayNotMatch) { + out.printf(now() + "%s: version not checked%n", step); + } else { + assertEquals(response.version(), version); + out.printf(now() + "%s: got expected version %s%n", step, response.version()); + } } catch (AssertionError error) { out.printf(now() + "%s: Closing body due to assertion - %s", step, error); ensureClosed(this); @@ -200,6 +232,15 @@ public class HttpClientShutdown implements HttpServerAdapters { } } + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + static boolean hasExpectedMessage(IOException io) { String message = io.getMessage(); if (message == null) return false; @@ -232,11 +273,13 @@ public class HttpClientShutdown implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testConcurrent(String uriString) throws Exception { - out.printf("%n---- %sstarting concurrent (%s) ----%n%n", now(), uriString); - HttpClient client = HttpClient.newBuilder() + void testConcurrent(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- %sstarting concurrent (%s, %s, %s) ----%n%n", + now(), uriString, version, config); + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); @@ -245,19 +288,24 @@ public class HttpClientShutdown implements HttpServerAdapters { Throwable failed = null; List> bodies = new ArrayList<>(); try { + if (version == HTTP_3 && config != HTTP_3_URI_ONLY) { + headRequest(client); + } + for (int i = 0; i < ITERATIONS; i++) { URI uri = URI.create(uriString + "/concurrent/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf(now() + "Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; CompletableFuture bodyCF; final int si = i; - ExchangeResult result = ExchangeResult.ofStep(si); + ExchangeResult result = ExchangeResult.afterHead(si, version, config); responseCF = client.sendAsync(request, BodyHandlers.ofInputStream()) .thenApply(result::withResponse) - .thenApplyAsync(ExchangeResult::assertResponseState, readerService) + .thenApplyAsync(ExchangeResult::assertResponseState) .thenApply(ExchangeResult::response); bodyCF = responseCF.thenApplyAsync(HttpResponse::body, readerService) .thenApply(HttpClientShutdown::readBody) @@ -328,11 +376,13 @@ public class HttpClientShutdown implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testSequential(String uriString) throws Exception { - out.printf("%n---- %sstarting sequential (%s) ----%n%n", now(), uriString); - HttpClient client = HttpClient.newBuilder() + void testSequential(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- %sstarting sequential (%s, %s, %s) ----%n%n", + now(), uriString, version, config); + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); @@ -345,12 +395,13 @@ public class HttpClientShutdown implements HttpServerAdapters { URI uri = URI.create(uriString + "/sequential/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf(now() + "Iteration %d request: %s%n", i, request.uri()); final int si = i; CompletableFuture> responseCF; CompletableFuture bodyCF; - ExchangeResult result = ExchangeResult.ofStep(si); + ExchangeResult result = ExchangeResult.ofSequential(si, version, config); responseCF = client.sendAsync(request, BodyHandlers.ofInputStream()) .thenApply(result::withResponse) .thenApplyAsync(ExchangeResult::assertResponseState, readerService) @@ -432,10 +483,21 @@ public class HttpClientShutdown implements HttpServerAdapters { https2TestServer.addHandler(new ServerRequestHandler(), "/https2/exec/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/exec/retry"; + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new ServerRequestHandler(), "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestHandler(), "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); start = System.nanoTime(); } @@ -449,6 +511,8 @@ public class HttpClientShutdown implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/HttpGetInCancelledFuture.java b/test/jdk/java/net/httpclient/HttpGetInCancelledFuture.java index baa64356fa7..4ddc0bdfdc2 100644 --- a/test/jdk/java/net/httpclient/HttpGetInCancelledFuture.java +++ b/test/jdk/java/net/httpclient/HttpGetInCancelledFuture.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.net.DatagramSocket; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -30,6 +31,8 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpOption; import java.net.http.HttpResponse; import java.time.Duration; import java.util.List; @@ -47,20 +50,29 @@ import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import jdk.httpclient.test.lib.common.HttpServerAdapters; import jdk.internal.net.http.common.OperationTrackers.Tracker; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.net.URIBuilder; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; + +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; /* * @test * @bug 8316580 - * @library /test/lib + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build HttpGetInCancelledFuture ReferenceTracker * @run junit/othervm -DuseReferenceTracker=false * HttpGetInCancelledFuture * @run junit/othervm -DuseReferenceTracker=true @@ -82,7 +94,9 @@ public class HttpGetInCancelledFuture { static ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; HttpClient makeClient(URI uri, Version version, Executor executor) { - var builder = HttpClient.newBuilder(); + var builder = version == HTTP_3 + ? HttpServerAdapters.createClientBuilderForH3() + : HttpClient.newBuilder(); if (uri.getScheme().equalsIgnoreCase("https")) { try { builder.sslContext(new SimpleSSLContext().get()); @@ -96,9 +110,19 @@ public class HttpGetInCancelledFuture { .build(); } - record TestCase(String url, int reqCount, Version version) {} + record TestCase(String url, int reqCount, Version version, Http3DiscoveryMode config) { + TestCase(String url, int reqCount, Version version) { + this(url, reqCount, version, null); + } + TestCase(String url, int reqCount, Http3DiscoveryMode config) { + this(url, reqCount, HTTP_3, null); + } + } + + // A server that doesn't accept - static volatile ServerSocket NOT_ACCEPTING; + static volatile ServerSocket NOT_ACCEPTING; + static volatile DatagramSocket NOT_RESPONDING; static List parameters() { ServerSocket ss = NOT_ACCEPTING; @@ -116,6 +140,28 @@ public class HttpGetInCancelledFuture { } } } + + DatagramSocket ds = NOT_RESPONDING; + boolean sameport = false; + if (ds == null) { + synchronized (HttpGetInCancelledFuture.class) { + if ((ds = NOT_RESPONDING) == null) { + try { + var loopback = InetAddress.getLoopbackAddress(); + try { + ds = new DatagramSocket(new InetSocketAddress(loopback, ss.getLocalPort())); + sameport = true; + } catch (IOException io) { + ds = new DatagramSocket(new InetSocketAddress(loopback,0)); + } + NOT_RESPONDING = ds; + } catch (IOException io) { + throw new UncheckedIOException(io); + } + } + } + } + URI http = URIBuilder.newBuilder() .loopback() .scheme("http") @@ -128,13 +174,25 @@ public class HttpGetInCancelledFuture { .port(ss.getLocalPort()) .path("/not-accepting/") .buildUnchecked(); + URI https3 = URIBuilder.newBuilder() + .loopback() + .scheme("https") + .port(ds.getLocalPort()) + .path("/not-responding/") + .buildUnchecked(); // use all HTTP versions, without and with TLS - return List.of( - new TestCase(http.toString(), 200, Version.HTTP_2), - new TestCase(http.toString(), 200, Version.HTTP_1_1), - new TestCase(https.toString(), 200, Version.HTTP_2), - new TestCase(https.toString(), 200, Version.HTTP_1_1) + var def = Stream.of( + new TestCase(https3.toString(), 200, HTTP_3_URI_ONLY), + new TestCase(http.toString(), 200, HTTP_2), + new TestCase(http.toString(), 200, HTTP_1_1), + new TestCase(https.toString(), 200, HTTP_2), + new TestCase(https.toString(), 200, HTTP_1_1) ); + var first = sameport + ? Stream.of(new TestCase(https3.toString(), 200, ANY)) + : Stream.empty(); + var cases= Stream.concat(first, def); + return cases.toList(); } @ParameterizedTest @@ -251,7 +309,7 @@ public class HttpGetInCancelledFuture { Throwable failed = null; try { try (final var scope = new TestTaskScope.ShutdownOnFailure()) { - launchAndProcessRequests(scope, httpClient, reqCount, dest); + launchAndProcessRequests(scope, httpClient, reqCount, version, dest); } finally { System.out.printf("StructuredTaskScope closed: STARTED=%s, SUCCESS=%s, INTERRUPT=%s, FAILED=%s%n", STARTED.get(), SUCCESS.get(), INTERRUPT.get(), FAILED.get()); @@ -311,10 +369,11 @@ public class HttpGetInCancelledFuture { TestTaskScope.ShutdownOnFailure scope, HttpClient httpClient, int reqCount, + Version version, URI dest) { for (int counter = 0; counter < reqCount; counter++) { scope.fork(() -> - getAndCheck(httpClient, dest) + getAndCheck(httpClient, dest, version) ); } try { @@ -335,19 +394,21 @@ public class HttpGetInCancelledFuture { final AtomicLong FAILED = new AtomicLong(); final AtomicLong STARTED = new AtomicLong(); final CopyOnWriteArrayList EXCEPTIONS = new CopyOnWriteArrayList<>(); - private String getAndCheck(HttpClient httpClient, URI url) { + private String getAndCheck(HttpClient httpClient, URI url, Version version) { STARTED.incrementAndGet(); - final var response = sendRequest(httpClient, url); + final var response = sendRequest(httpClient, url, version); String res = response.body(); int statusCode = response.statusCode(); assertEquals(200, statusCode); return res; } - private HttpResponse sendRequest(HttpClient httpClient, URI url) { + private HttpResponse sendRequest(HttpClient httpClient, URI url, Version version) { var id = ID.incrementAndGet(); try { - var request = HttpRequest.newBuilder(url).GET().build(); + var builder = HttpRequest.newBuilder(url).version(version).GET(); + if (version == HTTP_3) builder.setOption(HttpOption.H3_DISCOVERY, HTTP_3_URI_ONLY); + var request = builder.build(); var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); // System.out.println("Got response for " + id + ": " + response); SUCCESS.incrementAndGet(); @@ -372,16 +433,18 @@ public class HttpGetInCancelledFuture { if (error != null) throw error; } finally { ServerSocket ss; + DatagramSocket ds; synchronized (HttpGetInCancelledFuture.class) { ss = NOT_ACCEPTING; NOT_ACCEPTING = null; + ds = NOT_RESPONDING; + NOT_RESPONDING = null; } - if (ss != null) { - try { - ss.close(); - } catch (IOException io) { - throw new UncheckedIOException(io); - } + try (var ss1 = ss; var ds1 = ds;) { + System.out.printf("Cleaning up: ss=%s, ds=%s%n", + ss1.getLocalSocketAddress(), ds1.getLocalSocketAddress()); + } catch (IOException io) { + throw new UncheckedIOException(io); } } } diff --git a/test/jdk/java/net/httpclient/HttpRedirectTest.java b/test/jdk/java/net/httpclient/HttpRedirectTest.java index dedcc36dda7..358b908a03c 100644 --- a/test/jdk/java/net/httpclient/HttpRedirectTest.java +++ b/test/jdk/java/net/httpclient/HttpRedirectTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -29,6 +29,9 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static org.testng.Assert.*; import javax.net.ssl.SSLContext; @@ -93,12 +96,14 @@ public class HttpRedirectTest implements HttpServerAdapters { HttpTestServer http2Server; HttpTestServer https1Server; HttpTestServer https2Server; + HttpTestServer http3Server; DigestEchoServer.TunnelingProxy proxy; URI http1URI; URI https1URI; URI http2URI; URI https2URI; + URI http3URI; InetSocketAddress proxyAddress; ProxySelector proxySelector; HttpClient client; @@ -111,8 +116,7 @@ public class HttpRedirectTest implements HttpServerAdapters { TimeUnit.SECONDS, new LinkedBlockingQueue<>()); // Used by the client public HttpClient newHttpClient(ProxySelector ps) { - HttpClient.Builder builder = HttpClient - .newBuilder() + HttpClient.Builder builder = newClientBuilderForH3() .sslContext(context) .executor(clientexec) .followRedirects(HttpClient.Redirect.ALWAYS) @@ -123,6 +127,7 @@ public class HttpRedirectTest implements HttpServerAdapters { @DataProvider(name="uris") Object[][] testURIs() throws URISyntaxException { List uris = List.of( + http3URI.resolve("direct/orig/"), http1URI.resolve("direct/orig/"), https1URI.resolve("direct/orig/"), https1URI.resolve("proxy/orig/"), @@ -195,6 +200,13 @@ public class HttpRedirectTest implements HttpServerAdapters { https2Server.start(); https2URI = new URI("https://" + https2Server.serverAuthority() + "/HttpRedirectTest/https2/"); + // HTTPS/3 + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault()); + http3Server.addHandler(new HttpTestRedirectHandler("https", http3Server), + "/HttpRedirectTest/http3/"); + http3Server.start(); + http3URI = new URI("https://" + http3Server.serverAuthority() + "/HttpRedirectTest/http3/"); + proxy = DigestEchoServer.createHttpsProxyTunnel( DigestEchoServer.HttpAuthSchemeType.NONE); proxyAddress = proxy.getProxyAddress(); @@ -255,10 +267,18 @@ public class HttpRedirectTest implements HttpServerAdapters { } } + private HttpRequest.Builder newRequestBuilder(URI u) { + var builder = HttpRequest.newBuilder(u); + if (u.getRawPath().contains("/http3/")) { + builder.version(HTTP_3).setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "uris") public void testPOST(URI uri, int code, String method) throws Exception { URI u = uri.resolve("foo?n=" + requestCounter.incrementAndGet()); - HttpRequest request = HttpRequest.newBuilder(u) + HttpRequest request = newRequestBuilder(u) .POST(HttpRequest.BodyPublishers.ofString(REQUEST_BODY)).build(); // POST is not considered idempotent. testNonIdempotent(u, request, code, method); @@ -268,7 +288,7 @@ public class HttpRedirectTest implements HttpServerAdapters { public void testPUT(URI uri, int code, String method) throws Exception { URI u = uri.resolve("foo?n=" + requestCounter.incrementAndGet()); System.out.println("Testing with " + u); - HttpRequest request = HttpRequest.newBuilder(u) + HttpRequest request = newRequestBuilder(u) .PUT(HttpRequest.BodyPublishers.ofString(REQUEST_BODY)).build(); // PUT is considered idempotent. testIdempotent(u, request, code, method); @@ -278,7 +298,7 @@ public class HttpRedirectTest implements HttpServerAdapters { public void testFoo(URI uri, int code, String method) throws Exception { URI u = uri.resolve("foo?n=" + requestCounter.incrementAndGet()); System.out.println("Testing with " + u); - HttpRequest request = HttpRequest.newBuilder(u) + HttpRequest request = newRequestBuilder(u) .method("FOO", HttpRequest.BodyPublishers.ofString(REQUEST_BODY)).build(); // FOO is considered idempotent. @@ -289,7 +309,7 @@ public class HttpRedirectTest implements HttpServerAdapters { public void testGet(URI uri, int code, String method) throws Exception { URI u = uri.resolve("foo?n=" + requestCounter.incrementAndGet()); System.out.println("Testing with " + u); - HttpRequest request = HttpRequest.newBuilder(u) + HttpRequest request = newRequestBuilder(u) .method("GET", HttpRequest.BodyPublishers.ofString(REQUEST_BODY)).build(); CompletableFuture> respCf = @@ -320,6 +340,7 @@ public class HttpRedirectTest implements HttpServerAdapters { https1Server = stop(https1Server, HttpTestServer::stop); http2Server = stop(http2Server, HttpTestServer::stop); https2Server = stop(https2Server, HttpTestServer::stop); + http3Server = stop(http3Server, HttpTestServer::stop); client = null; try { executor.awaitTermination(2000, TimeUnit.MILLISECONDS); diff --git a/test/jdk/java/net/httpclient/HttpRequestBuilderTest.java b/test/jdk/java/net/httpclient/HttpRequestBuilderTest.java index 9401c10cdf2..4544c85c5e8 100644 --- a/test/jdk/java/net/httpclient/HttpRequestBuilderTest.java +++ b/test/jdk/java/net/httpclient/HttpRequestBuilderTest.java @@ -23,11 +23,14 @@ import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpOption; import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -36,6 +39,8 @@ import java.util.stream.Stream; import java.net.http.HttpRequest; import static java.net.http.HttpRequest.BodyPublishers.ofString; import static java.net.http.HttpRequest.BodyPublishers.noBody; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; /* * @test @@ -206,6 +211,15 @@ public class HttpRequestBuilderTest { builder = test2("method", builder, builder::method, null, ofString("foo"), NullPointerException.class); + + builder = test2("setOption", builder, builder::setOption, + (HttpOption)null, (Http3DiscoveryMode) null, + NullPointerException.class); + + builder = test2("setOption", builder, builder::setOption, + (HttpOption)null, HTTP_3_URI_ONLY, + NullPointerException.class); + // see JDK-8170093 // // builder = test2("method", builder, builder::method, "foo", @@ -268,6 +282,20 @@ public class HttpRequestBuilderTest { HttpRequest defaultHeadReq = new NotOverriddenHEADImpl().HEAD().uri(TEST_URI).build(); assertEquals("HEAD", defaultHeadReq.method(), "Method"); assertEquals(false, defaultHeadReq.bodyPublisher().isEmpty(), "Body publisher absence"); + HttpRequest defaultReqWithoutOption = new NotOverriddenHEADImpl().HEAD().uri(TEST_URI).build(); + assertEquals(Optional.empty(), defaultReqWithoutOption.getOption(H3_DISCOVERY), "default without options"); + HttpRequest defaultReqWithOption = new NotOverriddenHEADImpl().HEAD().uri(TEST_URI) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY).build(); + assertEquals(Optional.empty(), defaultReqWithOption.getOption(H3_DISCOVERY), "default with options"); + HttpRequest reqWithoutOption = HttpRequest.newBuilder().HEAD().uri(TEST_URI).build(); + assertEquals(Optional.empty(), reqWithoutOption.getOption(H3_DISCOVERY), "req without options"); + HttpRequest reqWithOption = HttpRequest.newBuilder().HEAD().uri(TEST_URI) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY).build(); + assertEquals(Optional.of(HTTP_3_URI_ONLY), reqWithOption.getOption(H3_DISCOVERY), "req with options"); + HttpRequest resetReqWithOption = HttpRequest.newBuilder().HEAD().uri(TEST_URI) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .setOption(H3_DISCOVERY, null).build(); + assertEquals(Optional.empty(), resetReqWithOption.getOption(H3_DISCOVERY), "req with option reset"); verifyCopy(); @@ -383,6 +411,7 @@ public class HttpRequestBuilderTest { .method("GET", noBody()) .expectContinue(true) .timeout(Duration.ofSeconds(0xBEEF)) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) .version(HttpClient.Version.HTTP_2); // Create the original and the _copy_ requests @@ -399,6 +428,8 @@ public class HttpRequestBuilderTest { assertEquals(request.expectContinue(), copiedRequest.expectContinue(), "Expect continue setting"); assertEquals(request.timeout(), copiedRequest.timeout(), "Timeout"); assertEquals(request.version(), copiedRequest.version(), "Version"); + assertEquals(request.getOption(H3_DISCOVERY), copiedRequest.getOption(H3_DISCOVERY), "H3_DISCOVERY option"); + assertEquals(Optional.of(HTTP_3_URI_ONLY), copiedRequest.getOption(H3_DISCOVERY), "copied H3_DISCOVERY option"); // Verify headers assertEquals(request.headers().map(), Map.of("X-Foo", List.of("1")), "Request headers"); @@ -415,66 +446,85 @@ public class HttpRequestBuilderTest { // doesn't override the default HEAD() method private static final class NotOverriddenHEADImpl implements HttpRequest.Builder { - private final HttpRequest.Builder underlying = HttpRequest.newBuilder(); + private final HttpRequest.Builder underlying; + + NotOverriddenHEADImpl() { + this(HttpRequest.newBuilder()); + } + + NotOverriddenHEADImpl(HttpRequest.Builder underlying) { + this.underlying = underlying; + } @Override public HttpRequest.Builder uri(URI uri) { - return this.underlying.uri(uri); + underlying.uri(uri); + return this; } @Override public HttpRequest.Builder expectContinue(boolean enable) { - return this.underlying.expectContinue(enable); + underlying.expectContinue(enable); return this; } @Override public HttpRequest.Builder version(HttpClient.Version version) { - return this.underlying.version(version); + this.underlying.version(version); + return this; } @Override public HttpRequest.Builder header(String name, String value) { - return this.underlying.header(name, value); + this.underlying.header(name, value); + return this; } @Override public HttpRequest.Builder headers(String... headers) { - return this.underlying.headers(headers); + underlying.headers(headers); + return this; } @Override public HttpRequest.Builder timeout(Duration duration) { - return this.underlying.timeout(duration); + this.underlying.timeout(duration); + return this; } @Override public HttpRequest.Builder setHeader(String name, String value) { - return this.underlying.setHeader(name, value); + underlying.setHeader(name, value); + return this; } @Override public HttpRequest.Builder GET() { - return this.underlying.GET(); + this.underlying.GET(); + return this; } @Override public HttpRequest.Builder POST(HttpRequest.BodyPublisher bodyPublisher) { - return this.underlying.POST(bodyPublisher); + this.underlying.POST(bodyPublisher); + return this; } @Override public HttpRequest.Builder PUT(HttpRequest.BodyPublisher bodyPublisher) { - return this.underlying.PUT(bodyPublisher); + this.underlying.PUT(bodyPublisher); + return this; } @Override public HttpRequest.Builder DELETE() { - return this.underlying.DELETE(); + this.underlying.DELETE(); + return this; } @Override public HttpRequest.Builder method(String method, HttpRequest.BodyPublisher bodyPublisher) { - return this.underlying.method(method, bodyPublisher); + this.underlying.method(method, bodyPublisher); + return this; } @Override @@ -484,7 +534,7 @@ public class HttpRequestBuilderTest { @Override public HttpRequest.Builder copy() { - return this.underlying.copy(); + return new NotOverriddenHEADImpl(underlying.copy()); } } } diff --git a/test/jdk/java/net/httpclient/HttpRequestNewBuilderTest.java b/test/jdk/java/net/httpclient/HttpRequestNewBuilderTest.java index eab9ecfc2c7..d7598ede3be 100644 --- a/test/jdk/java/net/httpclient/HttpRequestNewBuilderTest.java +++ b/test/jdk/java/net/httpclient/HttpRequestNewBuilderTest.java @@ -1,5 +1,5 @@ /* -* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. +* Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -39,6 +39,9 @@ import java.util.function.BiConsumer; import java.util.function.BiPredicate; import static java.net.http.HttpClient.Version.HTTP_2; import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import org.testng.annotations.Test; @@ -66,6 +69,7 @@ public class HttpRequestNewBuilderTest { new NamedAssertion("timeout", (r1, r2) -> assertEquals(r1.timeout(), r2.timeout())), new NamedAssertion("version", (r1, r2) -> assertEquals(r1.version(), r2.version())), new NamedAssertion("headers", (r1, r2) -> assertEquals(r1.headers(), r2.headers())), + new NamedAssertion("options", (r1, r2) -> assertEquals(r1.getOption(H3_DISCOVERY), r2.getOption(H3_DISCOVERY))), new NamedAssertion("expectContinue", (r1, r2) -> assertEquals(r1.expectContinue(), r2.expectContinue())), new NamedAssertion("method", (r1, r2) -> { assertEquals(r1.method(), r2.method()); @@ -141,6 +145,9 @@ public class HttpRequestNewBuilderTest { { HttpRequest.newBuilder(URI.create("https://all-fields-1/")).GET().expectContinue(true).version(HTTP_2) .timeout(Duration.ofSeconds(1)).header("testName1", "testValue1").build() }, + { HttpRequest.newBuilder(URI.create("https://all-fields-2/")).GET().expectContinue(true).version(HTTP_2) + .timeout(Duration.ofSeconds(1)).header("testName1", "testValue1") + .setOption(H3_DISCOVERY, ANY).build() }, }; } @@ -313,6 +320,15 @@ public class HttpRequestNewBuilderTest { assertAllOtherElementsEqual(r, request, "headers"); } + @Test(dataProvider = "testRequests") + public void testSetOption(HttpRequest request) { + BiPredicate filter = (n, v) -> true; + + var r = HttpRequest.newBuilder(request, filter).setOption(H3_DISCOVERY, ALT_SVC).build(); + assertEquals(r.getOption(H3_DISCOVERY).get(), ALT_SVC); + assertAllOtherElementsEqual(r, request, "options"); + } + @Test(dataProvider = "testRequests") public void testRemoveHeader(HttpRequest request) { if(!request.headers().map().isEmpty()) { @@ -325,6 +341,18 @@ public class HttpRequestNewBuilderTest { assertEquals(r.headers().map(), HttpHeaders.of(request.headers().map(), filter).map()); } + @Test(dataProvider = "testRequests") + public void testRemoveOption(HttpRequest request) { + if(!request.getOption(H3_DISCOVERY).isEmpty()) { + assertEquals(request.getOption(H3_DISCOVERY).get(), ANY); + } + + var r = HttpRequest.newBuilder(request, (a, b) -> true) + .setOption(H3_DISCOVERY, null).build(); + assertTrue(r.getOption(H3_DISCOVERY).isEmpty()); + assertAllOtherElementsEqual(r, request, "options"); + } + @Test(dataProvider = "testRequests") public void testRemoveSingleHeaderValue(HttpRequest request) { if(!request.headers().map().isEmpty()) { diff --git a/test/jdk/java/net/httpclient/HttpResponseConnectionLabelTest.java b/test/jdk/java/net/httpclient/HttpResponseConnectionLabelTest.java index b6c13c51ee1..d1dc8b0a5f5 100644 --- a/test/jdk/java/net/httpclient/HttpResponseConnectionLabelTest.java +++ b/test/jdk/java/net/httpclient/HttpResponseConnectionLabelTest.java @@ -29,10 +29,13 @@ * @build jdk.httpclient.test.lib.common.HttpServerAdapters * jdk.test.lib.net.SimpleSSLContext * - * @comment Use a higher idle timeout to increase the chances of the same connection being used for sequential HTTP requests - * @run junit/othervm -Djdk.httpclient.keepalive.timeout=120 HttpResponseConnectionLabelTest + * @comment Use a higher idle timeout to increase the chances of + * the same connection being used for sequential HTTP requests + * @run junit/othervm -Djdk.httpclient.keepalive.timeout=120 + * HttpResponseConnectionLabelTest */ +import jdk.httpclient.test.lib.common.HttpServerAdapters; import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; import jdk.internal.net.http.common.Logger; @@ -55,6 +58,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.charset.Charset; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -64,6 +68,8 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.US_ASCII; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -83,8 +89,12 @@ class HttpResponseConnectionLabelTest { private static final SSLContext SSL_CONTEXT = createSslContext(); - // Start with a fresh client having no connections in the pool - private final HttpClient client = HttpClient.newBuilder().sslContext(SSL_CONTEXT).proxy(NO_PROXY).build(); + // For each test instance, start with a fresh client having no connections in the pool + private final HttpClient client = HttpServerAdapters + .createClientBuilderForH3() + .sslContext(SSL_CONTEXT) + .proxy(NO_PROXY) + .build(); // Primary server-client pairs @@ -96,6 +106,8 @@ class HttpResponseConnectionLabelTest { private static final ServerRequestPair PRI_HTTPS2 = ServerRequestPair.of(Version.HTTP_2, true); + private static final ServerRequestPair PRI_HTTP3 = ServerRequestPair.of(Version.HTTP_3, true); + // Secondary server-client pairs private static final ServerRequestPair SEC_HTTP1 = ServerRequestPair.of(Version.HTTP_1_1, false); @@ -106,6 +118,8 @@ class HttpResponseConnectionLabelTest { private static final ServerRequestPair SEC_HTTPS2 = ServerRequestPair.of(Version.HTTP_2, true); + private static final ServerRequestPair SEC_HTTP3 = ServerRequestPair.of(Version.HTTP_3, true); + private static SSLContext createSslContext() { try { return new SimpleSSLContext().get(); @@ -140,8 +154,8 @@ class HttpResponseConnectionLabelTest { AtomicReference serverResponseLatchRef = new AtomicReference<>(); server.addHandler(createServerHandler(serverId, serverResponseLatchRef), handlerPath); - // Create the client and the request - HttpRequest request = HttpRequest.newBuilder(requestUri).version(version).build(); + // Create the request + HttpRequest request = createRequest(version, requestUri); // Create the pair ServerRequestPair pair = new ServerRequestPair( @@ -168,13 +182,15 @@ class HttpResponseConnectionLabelTest { // - Only the HTTP/1.1 test server gets wedged when running // tests involving parallel request handling. // - // - The HTTP/2 test server creates its own sufficiently sized - // executor, and the thread names used there makes it easy to - // find which server they belong to. + // - The HTTP/2 and HTTP/3 test servers create their own + // sufficiently sized executor, and the thread names used + // there makes it easy to find which server they belong to. executorRef[0] = Version.HTTP_1_1.equals(version) ? createExecutor(version, secure, serverId) : null; - return HttpTestServer.create(version, sslContext, executorRef[0]); + return Version.HTTP_3.equals(version) + ? HttpTestServer.create(HTTP_3_URI_ONLY, sslContext, executorRef[0]) + : HttpTestServer.create(version, sslContext, executorRef[0]); } catch (IOException exception) { throw new UncheckedIOException(exception); } @@ -196,7 +212,7 @@ class HttpResponseConnectionLabelTest { return (exchange) -> { String responseBody = "" + SERVER_RESPONSE_COUNTER.getAndIncrement(); String connectionKey = exchange.getConnectionKey(); - LOGGER.log("Server[%d] has received request (connectionKey=%s)", serverId, connectionKey); + LOGGER.log("Server[%s] has received request (connectionKey=%s)", serverId, connectionKey); try (exchange) { // Participate in the latch count down @@ -232,6 +248,14 @@ class HttpResponseConnectionLabelTest { }; } + private static HttpRequest createRequest(Version version, URI requestUri) { + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(requestUri).version(version); + if (Version.HTTP_3.equals(version)) { + requestBuilder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return requestBuilder.build(); + } + @Override public String toString() { String version = server.getVersion().toString(); @@ -244,7 +268,9 @@ class HttpResponseConnectionLabelTest { static void closeServers() { Exception[] exceptionRef = {null}; Stream - .of(PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, PRI_HTTPS2, SEC_HTTP1, SEC_HTTPS1, SEC_HTTP2, SEC_HTTPS2) + .of( + PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, PRI_HTTPS2, PRI_HTTP3, + SEC_HTTP1, SEC_HTTPS1, SEC_HTTP2, SEC_HTTPS2, SEC_HTTP3) .flatMap(pair -> Stream.of( pair.server::stop, () -> { if (pair.executor != null) { pair.executor.shutdownNow(); } })) @@ -274,7 +300,8 @@ class HttpResponseConnectionLabelTest { PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, - PRI_HTTPS2 + PRI_HTTPS2, + PRI_HTTP3 }; } @@ -283,8 +310,9 @@ class HttpResponseConnectionLabelTest { void testParallelRequestsToSameServer(ServerRequestPair pair) throws Exception { // There is no implementation-agnostic reliable way to force admission - // of multiple connections targeting the same server to an HTTP/2 pool. - if (Version.HTTP_2.equals(pair.server.getVersion())) { + // of multiple connections targeting the same server to an HTTP/2 or + // HTTP/3 client connection pool. + if (Set.of(Version.HTTP_2, Version.HTTP_3).contains(pair.server.getVersion())) { return; } @@ -359,9 +387,9 @@ class HttpResponseConnectionLabelTest { static Stream testParallelRequestsToDifferentServers() { return Stream - .of(PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, PRI_HTTPS2) + .of(PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, PRI_HTTPS2, PRI_HTTP3) .flatMap(source -> Stream - .of(SEC_HTTP1, SEC_HTTPS1, SEC_HTTP2, SEC_HTTPS2) + .of(SEC_HTTP1, SEC_HTTPS1, SEC_HTTP2, SEC_HTTPS2, SEC_HTTP3) .map(target -> Arguments.of(source, target))); } @@ -440,7 +468,7 @@ class HttpResponseConnectionLabelTest { } static Stream testSerialRequestsToSameServer() { - return Stream.of(PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, PRI_HTTPS2); + return Stream.of(PRI_HTTP1, PRI_HTTPS1, PRI_HTTP2, PRI_HTTPS2, PRI_HTTP3); } @ParameterizedTest diff --git a/test/jdk/java/net/httpclient/HttpResponseLimitingTest.java b/test/jdk/java/net/httpclient/HttpResponseLimitingTest.java index b87e7ea8e49..e0bda0b0071 100644 --- a/test/jdk/java/net/httpclient/HttpResponseLimitingTest.java +++ b/test/jdk/java/net/httpclient/HttpResponseLimitingTest.java @@ -34,6 +34,7 @@ * @run junit HttpResponseLimitingTest */ +import jdk.httpclient.test.lib.common.HttpServerAdapters; import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; @@ -66,6 +67,11 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.copyOfRange; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @@ -89,15 +95,17 @@ class HttpResponseLimitingTest { */ private static final String RESPONSE_HEADER_VALUE = "!".repeat(RESPONSE_BODY.length + 1); - private static final ServerClientPair HTTP1 = ServerClientPair.of(HttpClient.Version.HTTP_1_1, false); + private static final ServerClientPair H1 = ServerClientPair.of(HTTP_1_1, false); - private static final ServerClientPair HTTPS1 = ServerClientPair.of(HttpClient.Version.HTTP_1_1, true); + private static final ServerClientPair H1S = ServerClientPair.of(HTTP_1_1, true); - private static final ServerClientPair HTTP2 = ServerClientPair.of(HttpClient.Version.HTTP_2, false); + private static final ServerClientPair H2 = ServerClientPair.of(HTTP_2, false); - private static final ServerClientPair HTTPS2 = ServerClientPair.of(HttpClient.Version.HTTP_2, true); + private static final ServerClientPair H2S = ServerClientPair.of(HTTP_2, true); - private record ServerClientPair(HttpTestServer server, HttpClient client, HttpRequest request) { + private static final ServerClientPair H3 = ServerClientPair.of(HTTP_3, true); + + private record ServerClientPair(HttpTestServer server, HttpClient client, HttpRequest request, boolean secure) { private static final SSLContext SSL_CONTEXT = createSslContext(); @@ -128,7 +136,7 @@ class HttpResponseLimitingTest { // Create the server and the request URI SSLContext sslContext = secure ? SSL_CONTEXT : null; HttpTestServer server = createServer(version, sslContext); - String handlerPath = "/"; + String handlerPath = "/" + /* salting the path: */ HttpResponseLimitingTest.class.getSimpleName(); String requestUriScheme = secure ? "https" : "http"; URI requestUri = URI.create(requestUriScheme + "://" + server.serverAuthority() + handlerPath); @@ -146,29 +154,40 @@ class HttpResponseLimitingTest { // Create the client and the request HttpClient client = createClient(version, sslContext); - HttpRequest request = HttpRequest.newBuilder(requestUri).version(version).build(); + HttpRequest request = createRequest(version, requestUri); // Create the pair - return new ServerClientPair(server, client, request); + return new ServerClientPair(server, client, request, secure); } private static HttpTestServer createServer(HttpClient.Version version, SSLContext sslContext) { try { - return HttpTestServer.create(version, sslContext); + return HTTP_3.equals(version) + ? HttpTestServer.create(HTTP_3_URI_ONLY, sslContext) + : HttpTestServer.create(version, sslContext); } catch (IOException exception) { throw new UncheckedIOException(exception); } } private static HttpClient createClient(HttpClient.Version version, SSLContext sslContext) { - HttpClient.Builder builder = HttpClient.newBuilder().version(version).proxy(NO_PROXY); + HttpClient.Builder builder = HttpServerAdapters.createClientBuilderFor(version) + .version(version).proxy(NO_PROXY); if (sslContext != null) { builder.sslContext(sslContext); } return builder.build(); } + private static HttpRequest createRequest(HttpClient.Version version, URI requestUri) { + HttpRequest.Builder builder = HttpRequest.newBuilder(requestUri).version(version); + if (HTTP_3.equals(version)) { + builder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder.build(); + } + private HttpResponse request(BodyHandler downstreamHandler, long capacity) throws Exception { var handler = BodyHandlers.limiting(downstreamHandler, capacity); return client.send(request, handler); @@ -176,8 +195,9 @@ class HttpResponseLimitingTest { @Override public String toString() { - String version = client.version().toString(); - return client.sslContext() != null ? version.replaceFirst("_", "S_") : version; + HttpClient.Version version = client.version(); + String versionString = version.toString(); + return secure && !HTTP_3.equals(version) ? versionString.replaceFirst("_", "S_") : versionString; } } @@ -186,7 +206,7 @@ class HttpResponseLimitingTest { static void closeServerClientPairs() { Exception[] exceptionRef = {null}; Stream - .of(HTTP1, HTTPS1, HTTP2, HTTPS2) + .of(H1, H1S, H2, H2S, H3) .flatMap(pair -> Stream.of( pair.client::close, pair.server::stop)) @@ -306,7 +326,7 @@ class HttpResponseLimitingTest { private static Arguments[] capacityArgs(long... capacities) { return Stream - .of(HTTP1, HTTPS1, HTTP2, HTTPS2) + .of(H1, H1S, H2, H2S, H3) .flatMap(pair -> Arrays .stream(capacities) .mapToObj(capacity -> Arguments.of(pair, capacity))) diff --git a/test/jdk/java/net/httpclient/HttpSlowServerTest.java b/test/jdk/java/net/httpclient/HttpSlowServerTest.java index a71eaf746ad..ca841cd8ae9 100644 --- a/test/jdk/java/net/httpclient/HttpSlowServerTest.java +++ b/test/jdk/java/net/httpclient/HttpSlowServerTest.java @@ -52,6 +52,9 @@ import java.util.concurrent.atomic.AtomicLong; import jdk.httpclient.test.lib.common.HttpServerAdapters; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; /** * @test @@ -62,8 +65,9 @@ import static java.net.http.HttpClient.Version.HTTP_2; * DigestEchoServer HttpSlowServerTest * jdk.httpclient.test.lib.common.TestServerConfigurator * @run main/othervm/timeout=480 -Dtest.requiresHost=true - * -Djdk.httpclient.HttpClient.log=headers + * -Djdk.httpclient.HttpClient.log=errors,headers,quic:hs * -Djdk.internal.httpclient.debug=false + * -Djdk.httpclient.quic.maxInitialTimeout=60 * HttpSlowServerTest * */ @@ -97,12 +101,14 @@ public class HttpSlowServerTest implements HttpServerAdapters { HttpTestServer http2Server; HttpTestServer https1Server; HttpTestServer https2Server; + HttpTestServer http3Server; DigestEchoServer.TunnelingProxy proxy; URI http1URI; URI https1URI; URI http2URI; URI https2URI; + URI http3URI; InetSocketAddress proxyAddress; ProxySelector proxySelector; HttpClient client; @@ -115,8 +121,7 @@ public class HttpSlowServerTest implements HttpServerAdapters { TimeUnit.SECONDS, new LinkedBlockingQueue<>()); // Used by the client public HttpClient newHttpClient(ProxySelector ps) { - HttpClient.Builder builder = HttpClient - .newBuilder() + HttpClient.Builder builder = newClientBuilderForH3() .sslContext(context) .executor(clientexec) .proxy(ps); @@ -155,6 +160,12 @@ public class HttpSlowServerTest implements HttpServerAdapters { https2Server.start(); https2URI = new URI("https://" + https2Server.serverAuthority() + "/HttpSlowServerTest/https2/"); + // HTTP/3 + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault()); + http3Server.addHandler(new HttpTestSlowHandler(), "/HttpSlowServerTest/http3/"); + http3Server.start(); + http3URI = new URI("https://" + http3Server.serverAuthority() + "/HttpSlowServerTest/http3/"); + proxy = DigestEchoServer.createHttpsProxyTunnel( DigestEchoServer.HttpAuthSchemeType.NONE); proxyAddress = proxy.getProxyAddress(); @@ -185,9 +196,10 @@ public class HttpSlowServerTest implements HttpServerAdapters { } public void run(String... args) throws Exception { - List serverURIs = List.of(http1URI, http2URI, https1URI, https2URI); + List serverURIs = List.of(http3URI, http1URI, http2URI, https1URI, https2URI); for (int i=0; i<20; i++) { for (URI base : serverURIs) { + if (base.getRawPath().contains("/http3/")) continue; // proxy not supported if (base.getScheme().equalsIgnoreCase("https")) { URI proxy = i % 1 == 0 ? base.resolve(URI.create("proxy/foo?n="+requestCounter.incrementAndGet())) : base.resolve(URI.create("direct/foo?n="+requestCounter.incrementAndGet())); @@ -205,7 +217,11 @@ public class HttpSlowServerTest implements HttpServerAdapters { public void test(URI uri) throws Exception { System.out.println("Testing with " + uri); pending.add(uri); - HttpRequest request = HttpRequest.newBuilder(uri).build(); + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder.version(HTTP_3).setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + HttpRequest request = builder.build(); CompletableFuture> resp = client.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .whenComplete((r, t) -> this.requestCompleted(request, r, t)); @@ -228,6 +244,7 @@ public class HttpSlowServerTest implements HttpServerAdapters { https1Server = stop(https1Server, HttpTestServer::stop); http2Server = stop(http2Server, HttpTestServer::stop); https2Server = stop(https2Server, HttpTestServer::stop); + http3Server = stop(http3Server, HttpTestServer::stop); client = null; try { executor.awaitTermination(2000, TimeUnit.MILLISECONDS); diff --git a/test/jdk/java/net/httpclient/ISO_8859_1_Test.java b/test/jdk/java/net/httpclient/ISO_8859_1_Test.java index e81b0130418..a5465ad5103 100644 --- a/test/jdk/java/net/httpclient/ISO_8859_1_Test.java +++ b/test/jdk/java/net/httpclient/ISO_8859_1_Test.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -41,50 +41,31 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.URI; -import java.net.URL; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; -import java.net.http.HttpRequest.BodyPublisher; -import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; -import java.util.concurrent.Flow.Subscriber; -import java.util.concurrent.Flow.Subscription; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Supplier; import java.util.stream.Collectors; -import java.util.stream.LongStream; -import java.util.stream.Stream; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; -import org.testng.Assert; import org.testng.ITestContext; import org.testng.ITestResult; import org.testng.SkipException; @@ -98,24 +79,28 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; -import static org.testng.Assert.expectThrows; public class ISO_8859_1_Test implements HttpServerAdapters { SSLContext sslContext; DummyServer http1DummyServer; - HttpServerAdapters.HttpTestServer http1TestServer; // HTTP/1.1 ( http ) - HttpServerAdapters.HttpTestServer https1TestServer; // HTTPS/1.1 ( https ) - HttpServerAdapters.HttpTestServer http2TestServer; // HTTP/2 ( h2c ) - HttpServerAdapters.HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http1TestServer; // HTTP/1.1 ( http ) + HttpTestServer https1TestServer; // HTTPS/1.1 ( https ) + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String http1Dummy; String http1URI; String https1URI; String http2URI; String https2URI; + String http3URI; static final int RESPONSE_CODE = 200; static final int ITERATION_COUNT = 4; @@ -216,6 +201,7 @@ public class ISO_8859_1_Test implements HttpServerAdapters { private String[] uris() { return new String[] { + http3URI, http1Dummy, http1URI, https1URI, @@ -243,9 +229,12 @@ public class ISO_8859_1_Test implements HttpServerAdapters { return result; } - private HttpClient makeNewClient() { + private HttpClient makeNewClient(Version version) { clientCount.incrementAndGet(); - HttpClient client = HttpClient.newBuilder() + var builder = version == HTTP_3 + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + HttpClient client = builder .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -253,14 +242,24 @@ public class ISO_8859_1_Test implements HttpServerAdapters { return TRACKER.track(client); } - HttpClient newHttpClient(boolean share) { - if (!share) return makeNewClient(); + Version version(String uri) { + if (uri == null) return null; + if (uri.contains("/http1/")) return HTTP_1_1; + if (uri.contains("/https1/")) return HTTP_1_1; + if (uri.contains("/http2/")) return HTTP_2; + if (uri.contains("/https2/")) return HTTP_2; + if (uri.contains("/http3/")) return HTTP_3; + return null; + } + + HttpClient newHttpClient(String uri, boolean share) { + if (!share) return makeNewClient(version(uri)); HttpClient shared = sharedClient; if (shared != null) return shared; synchronized (this) { shared = sharedClient; if (shared == null) { - shared = sharedClient = makeNewClient(); + shared = sharedClient = makeNewClient(HTTP_3); } return shared; } @@ -277,16 +276,25 @@ public class ISO_8859_1_Test implements HttpServerAdapters { return (Exception)c; } + private static HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "variants") public void test(String uri, boolean sameClient) throws Exception { checkSkip(); System.out.println("Request to " + uri); - HttpClient client = newHttpClient(sameClient); + HttpClient client = newHttpClient(uri, sameClient); List>> cfs = new ArrayList<>(); for (int i = 0; i < ITERATION_COUNT; i++) { - HttpRequest request = HttpRequest.newBuilder(URI.create(uri + "/" + i)) + HttpRequest request = newRequestBuilder(URI.create(uri + "/" + i)) .build(); cfs.add(client.sendAsync(request, BodyHandlers.ofString())); } @@ -431,12 +439,17 @@ public class ISO_8859_1_Test implements HttpServerAdapters { https2TestServer.addHandler(handler, "/https2/server/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/server/x"; - serverCount.addAndGet(5); + http3TestServer = HttpServerAdapters.HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(handler, "/http3/server/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/server/x"; + + serverCount.addAndGet(6); http1TestServer.start(); https1TestServer.start(); http2TestServer.start(); https2TestServer.start(); http1DummyServer.start(); + http3TestServer.start(); } @AfterTest @@ -452,6 +465,7 @@ public class ISO_8859_1_Test implements HttpServerAdapters { http2TestServer.stop(); https2TestServer.stop(); http1DummyServer.close(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { diff --git a/test/jdk/java/net/httpclient/IdleConnectionTimeoutTest.java b/test/jdk/java/net/httpclient/IdleConnectionTimeoutTest.java new file mode 100644 index 00000000000..07bb0be455c --- /dev/null +++ b/test/jdk/java/net/httpclient/IdleConnectionTimeoutTest.java @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.BodyOutputStream; +import jdk.httpclient.test.lib.http2.Http2TestExchangeImpl; +import jdk.httpclient.test.lib.http2.Http2TestServerConnection; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintStream; +import java.net.SocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.concurrent.CompletableFuture; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.http2.Http2Handler; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.net.http.HttpClient.Version.HTTP_2; +import static org.testng.Assert.assertEquals; + +/* + * @test + * @bug 8288717 + * @summary Tests that when the idle connection timeout is configured for a HTTP connection, + * then the connection is closed if it has been idle for that long + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.httpclient.test.lib.http3.Http3TestServer + * + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout=1 + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout=20 + * IdleConnectionTimeoutTest + * + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h2=1 + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h2=20 + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h2=abc + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h2=-1 + * IdleConnectionTimeoutTest + * + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h3=1 + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h3=20 + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h3=abc + * IdleConnectionTimeoutTest + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all -Djdk.httpclient.keepalive.timeout.h3=-1 + * IdleConnectionTimeoutTest + */ +public class IdleConnectionTimeoutTest { + + URI timeoutUriH2, noTimeoutUriH2, timeoutUriH3, noTimeoutUriH3, getH3; + SSLContext sslContext; + static volatile QuicServerConnection latestServerConn; + final String KEEP_ALIVE_PROPERTY = "jdk.httpclient.keepalive.timeout"; + final String IDLE_CONN_PROPERTY_H2 = "jdk.httpclient.keepalive.timeout.h2"; + final String IDLE_CONN_PROPERTY_H3 = "jdk.httpclient.keepalive.timeout.h3"; + final String TIMEOUT_PATH = "/serverTimeoutHandler"; + final String NO_TIMEOUT_PATH = "/noServerTimeoutHandler"; + static Http2TestServer http2TestServer; + static Http3TestServer http3TestServer; + static final PrintStream testLog = System.err; + + @BeforeTest + public void setup() throws Exception { + http2TestServer = new Http2TestServer(false, 0); + http2TestServer.addHandler(new ServerTimeoutHandlerH2(), TIMEOUT_PATH); + http2TestServer.addHandler(new ServerNoTimeoutHandlerH2(), NO_TIMEOUT_PATH); + http2TestServer.setExchangeSupplier(TestExchange::new); + + sslContext = new SimpleSSLContext().get(); + http3TestServer = new Http3TestServer(sslContext) { + @Override + public boolean acceptIncoming(SocketAddress source, QuicServerConnection quicConn) { + final boolean accepted = super.acceptIncoming(source, quicConn); + if (accepted) { + // Quic Connection maps to Http3Connection, can use this to verify h3 timeouts + latestServerConn = quicConn; + } + return accepted; + } + }; + http3TestServer.addHandler(TIMEOUT_PATH, new ServerTimeoutHandlerH3()); + http3TestServer.addHandler(NO_TIMEOUT_PATH, new ServerNoTimeoutHandlerH3()); + + http2TestServer.start(); + http3TestServer.start(); + int port = http2TestServer.getAddress().getPort(); + timeoutUriH2 = URIBuilder.newBuilder() + .scheme("http") + .loopback() + .port(port) + .path(TIMEOUT_PATH) + .build(); + noTimeoutUriH2 = URIBuilder.newBuilder() + .scheme("http") + .loopback() + .port(port) + .path(NO_TIMEOUT_PATH) + .build(); + + port = http3TestServer.getAddress().getPort(); + getH3 = URIBuilder.newBuilder() + .scheme("https") + .loopback() + .port(port) + .path("/get") + .build(); + timeoutUriH3 = URIBuilder.newBuilder() + .scheme("https") + .loopback() + .port(port) + .path(TIMEOUT_PATH) + .build(); + noTimeoutUriH3 = URIBuilder.newBuilder() + .scheme("https") + .loopback() + .port(port) + .path(NO_TIMEOUT_PATH) + .build(); + } + + @Test + public void testRoot() { + String keepAliveVal = System.getProperty(KEEP_ALIVE_PROPERTY); + String idleConnectionH2Val = System.getProperty(IDLE_CONN_PROPERTY_H2); + String idleConnectionH3Val = System.getProperty(IDLE_CONN_PROPERTY_H3); + + if (keepAliveVal != null) { + try (HttpClient hc = HttpClient.newBuilder().version(HTTP_2).build()) { + // test H2 inherits value + testLog.println("Testing HTTP/2 connections set idleConnectionTimeout value to keep alive value"); + test(hc, keepAliveVal, HTTP_2, timeoutUriH2, noTimeoutUriH2); + } + try (HttpClient hc = HttpServerAdapters.createClientBuilderForH3().sslContext(sslContext).build()) { + // test H3 inherits value + testLog.println("Testing HTTP/3 connections set idleConnectionTimeout value to keep alive value"); + test(hc, keepAliveVal, HTTP_3, timeoutUriH3, noTimeoutUriH3); + } + } else if (idleConnectionH2Val != null) { + try (HttpClient hc = HttpClient.newBuilder().version(HTTP_2).build()) { + testLog.println("Testing HTTP/2 idleConnectionTimeout"); + test(hc, idleConnectionH2Val, HTTP_2, timeoutUriH2, noTimeoutUriH2); + } + } else if (idleConnectionH3Val != null) { + try (HttpClient hc = HttpServerAdapters.createClientBuilderForH3().sslContext(sslContext).build()) { + testLog.println("Testing HTTP/3 idleConnectionTimeout"); + test(hc, idleConnectionH3Val, HTTP_3, timeoutUriH3, noTimeoutUriH3); + } + } + + } + + private void test(HttpClient hc, String propVal, Version version, URI timeoutUri, URI noTimeoutUri) { + if (propVal.equals("1")) { + testTimeout(hc, timeoutUri, version); + } else if (propVal.equals("20")) { + testNoTimeout(hc, noTimeoutUri, version); + } else if (propVal.equals("abc") || propVal.equals("-1")) { + testNoTimeout(hc, noTimeoutUri, version); + } else { + throw new RuntimeException("Unexpected timeout value"); + } + } + + private void testTimeout(HttpClient hc, URI uri, Version version) { + // Timeout should occur + var config = version == HTTP_3 ? HTTP_3_URI_ONLY : null; + HttpRequest hreq = HttpRequest.newBuilder(uri).version(version).GET() + .setOption(H3_DISCOVERY, config).build(); + HttpResponse hresp = runRequest(hc, hreq, 2750); + assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent was not expected but occurred"); + } + + private void testNoTimeout(HttpClient hc, URI uri, Version version) { + // Timeout should not occur + var config = version == HTTP_3 ? HTTP_3_URI_ONLY : null; + HttpRequest hreq = HttpRequest.newBuilder(uri).version(version).GET() + .setOption(H3_DISCOVERY, config).build(); + HttpResponse hresp = runRequest(hc, hreq, 0); + assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent was not expected but occurred"); + } + + private HttpResponse runRequest(HttpClient hc, HttpRequest req, int sleepTime) { + CompletableFuture> request = hc.sendAsync(req, HttpResponse.BodyHandlers.ofString(UTF_8)); + HttpResponse hresp = request.join(); + assertEquals(hresp.statusCode(), 200); + try { + Thread.sleep(sleepTime); + } catch (InterruptedException e) { + e.printStackTrace(); + } + request = hc.sendAsync(req, HttpResponse.BodyHandlers.ofString(UTF_8)); + return request.join(); + } + + static class ServerTimeoutHandlerH2 implements Http2Handler { + + volatile Object firstConnection = null; + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (exchange instanceof TestExchange exch) { + if (firstConnection == null) { + firstConnection = exch.getServerConnection(); + exch.sendResponseHeaders(200, 0); + } else { + var secondConnection = exch.getServerConnection(); + + if (firstConnection != secondConnection) { + testLog.println("ServerTimeoutHandlerH2: New Connection was used, idleConnectionTimeoutEvent fired." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exch.sendResponseHeaders(200, 0); + } else { + testLog.println("ServerTimeoutHandlerH2: Same Connection was used, idleConnectionTimeoutEvent did not fire." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exch.sendResponseHeaders(400, 0); + } + } + } + } + } + + static class ServerNoTimeoutHandlerH2 implements Http2Handler { + + volatile Object firstConnection; + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (exchange instanceof TestExchange exch) { + if (firstConnection == null) { + firstConnection = exch.getServerConnection(); + exch.sendResponseHeaders(200, 0); + } else { + var secondConnection = exch.getServerConnection(); + + if (firstConnection == secondConnection) { + testLog.println("ServerTimeoutHandlerH2: Same Connection was used, idleConnectionTimeoutEvent did not fire." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exch.sendResponseHeaders(200, 0); + } else { + testLog.println("ServerTimeoutHandlerH2: Different Connection was used, idleConnectionTimeoutEvent fired." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exch.sendResponseHeaders(400, 0); + } + } + } + } + } + + static class ServerTimeoutHandlerH3 implements Http2Handler { + + volatile Object firstConnection; + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (firstConnection == null) { + firstConnection = latestServerConn; + exchange.sendResponseHeaders(200, 0); + } else { + var secondConnection = latestServerConn; + if (firstConnection != secondConnection) { + testLog.println("ServerTimeoutHandlerH3: New Connection was used, idleConnectionTimeoutEvent fired." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exchange.sendResponseHeaders(200, 0); + } else { + testLog.println("ServerTimeoutHandlerH3: Same Connection was used, idleConnectionTimeoutEvent did not fire." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exchange.sendResponseHeaders(400, 0); + } + } + exchange.close(); + } + } + + static class ServerNoTimeoutHandlerH3 implements Http2Handler { + + volatile Object firstConnection = null; + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (firstConnection == null) { + firstConnection = latestServerConn; + exchange.sendResponseHeaders(200, 0); + } else { + var secondConnection = latestServerConn; + + if (firstConnection == secondConnection) { + testLog.println("ServerTimeoutHandlerH3: Same Connection was used, idleConnectionTimeoutEvent did not fire." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exchange.sendResponseHeaders(200, 0); + } else { + testLog.println("ServerTimeoutHandlerH3: New Connection was used, idleConnectionTimeoutEvent fired." + + " First Connection: " + firstConnection + ", Second Connection Hash: " + secondConnection); + exchange.sendResponseHeaders(400, 0); + } + } + exchange.close(); + } + } + + static class TestExchange extends Http2TestExchangeImpl { + + public TestExchange(int streamid, String method, + HttpHeaders reqheaders, HttpHeadersBuilder rspheadersBuilder, + URI uri, InputStream is, SSLSession sslSession, BodyOutputStream os, + Http2TestServerConnection conn, boolean pushAllowed) { + super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession, os, + conn, pushAllowed); + } + + public Http2TestServerConnection getServerConnection() { + return this.conn; + } + } +} diff --git a/test/jdk/java/net/httpclient/ImmutableFlowItems.java b/test/jdk/java/net/httpclient/ImmutableFlowItems.java index b440ad646e8..186d3bbd9f8 100644 --- a/test/jdk/java/net/httpclient/ImmutableFlowItems.java +++ b/test/jdk/java/net/httpclient/ImmutableFlowItems.java @@ -46,7 +46,6 @@ import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsServer; import java.net.http.HttpClient; -import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; diff --git a/test/jdk/java/net/httpclient/ImmutableSSLSessionTest.java b/test/jdk/java/net/httpclient/ImmutableSSLSessionTest.java new file mode 100644 index 00000000000..179900479bc --- /dev/null +++ b/test/jdk/java/net/httpclient/ImmutableSSLSessionTest.java @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.security.Principal; +import java.security.cert.Certificate; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.ImmutableExtendedSSLSession; +import jdk.internal.net.http.common.ImmutableSSLSession; +import jdk.internal.net.http.common.ImmutableSSLSessionAccess; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/* + * @test + * @summary Verify that the request/response headers of HTTP/2 and HTTP/3 + * are sent and received in lower case + * @library /test/lib /test/jdk/java/net/httpclient/lib /test/jdk/java/net/httpclient/access + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * java.net.http/jdk.internal.net.http.common.ImmutableSSLSessionAccess + * @run junit/othervm -Djdk.httpclient.HttpClient.log=request,response,headers,errors + * ImmutableSSLSessionTest + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class ImmutableSSLSessionTest implements HttpServerAdapters { + + private HttpTestServer h1server; + private HttpTestServer h2server; + private HttpTestServer h3server; + private String h1ReqURIBase; + private String h2ReqURIBase; + private String h3ReqURIBase; + private static SSLContext sslContext; + private final AtomicInteger counter = new AtomicInteger(); + + @BeforeAll + public void beforeAll() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + h1server = HttpTestServer.create(HTTP_1_1, sslContext); + h1server.start(); + h1ReqURIBase = "https://" + h1server.serverAuthority() + "/h1ImmutableSSLSessionTest/"; + h1server.addHandler(new HttpHeadOrGetHandler(), "/h1ImmutableSSLSessionTest/"); + System.out.println("HTTP/1.1 server listening on " + h1server.getAddress()); + + h2server = HttpTestServer.create(HTTP_2, sslContext); + h2server.start(); + h2ReqURIBase = "https://" + h2server.serverAuthority() + "/h2ImmutableSSLSessionTest/"; + h2server.addHandler(new HttpHeadOrGetHandler(), "/h2ImmutableSSLSessionTest/"); + System.out.println("HTTP/2 server listening on " + h2server.getAddress()); + + + h3server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3server.start(); + h3ReqURIBase = "https://" + h3server.serverAuthority() + "/h3ImmutableSSLSessionTest/"; + h3server.addHandler(new HttpHeadOrGetHandler(), "/h3ImmutableSSLSessionTest/"); + System.out.println("HTTP/3 server listening on " + h3server.getAddress()); + + } + + @AfterAll + public void afterAll() throws Exception { + if (h2server != null) { + h2server.stop(); + } + if (h3server != null) { + h3server.stop(); + } + } + + private Stream params() throws Exception { + return Stream.of( + Arguments.of(HTTP_1_1, new URI(h1ReqURIBase)), + Arguments.of(HTTP_2, new URI(h2ReqURIBase)), + Arguments.of(HTTP_3, new URI(h3ReqURIBase))); + } + + private Stream sessions() throws Exception { + return Stream.of( + Arguments.of(ImmutableSSLSessionAccess.immutableSSLSession(new DummySession())), + Arguments.of(ImmutableSSLSessionAccess.immutableExtendedSSLSession(new DummySession()))); + } + + /** + * Issues an HTTPS request and verifies that the SSLSession + * is immutable. + */ + @ParameterizedTest + @MethodSource("params") + public void testImmutableSSLSession(final Version version, final URI requestURI) throws Exception { + Http3DiscoveryMode config = switch (version) { + case HTTP_3 -> HTTP_3_URI_ONLY; + default -> null; + }; + + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder() + .setOption(H3_DISCOVERY, config) + .version(version); + final HttpClient.Builder clientBuilder = (version == HTTP_3 + ? newClientBuilderForH3() + : HttpClient.newBuilder()) + .version(version) + .sslContext(sslContext) + .proxy(HttpClient.Builder.NO_PROXY); + + try (HttpClient client = clientBuilder.build()) { + final URI uriSync = URI.create(requestURI.toString() + "?sync=true,req=" + counter.incrementAndGet()); + final HttpRequest reqSync = reqBuilder.uri(uriSync).build(); + System.out.println("Issuing " + version + " request to " + uriSync); + final HttpResponse resp = client.send(reqSync, BodyHandlers.discarding()); + final Optional syncSession = resp.sslSession(); + assertEquals(resp.version(), version, "Unexpected HTTP version in response"); + assertEquals(resp.statusCode(), 200, "Unexpected response code"); + checkImmutableSession(resp.sslSession()); + } + + // now try with async + try (HttpClient client = clientBuilder.build()) { + final URI uriAsync = URI.create(requestURI.toString() + "?sync=false,req=" + counter.incrementAndGet()); + final HttpRequest reqAsync = reqBuilder.copy().uri(uriAsync).build(); + System.out.println("Issuing (async) request to " + uriAsync); + final CompletableFuture> futureResp = client.sendAsync(reqAsync, + BodyHandlers.discarding()); + final HttpResponse asyncResp = futureResp.get(); + assertEquals(asyncResp.version(), version, "Unexpected HTTP version in response"); + assertEquals(asyncResp.statusCode(), 200, "Unexpected response code"); + checkImmutableSession(asyncResp.sslSession()); + } + } + + @ParameterizedTest + @MethodSource("sessions") + public void testImmutableSSLSessionClass(SSLSession session) throws Exception { + System.out.println("Checking session class: " + session.getClass()); + checkDummySession(session); + } + + + private void checkImmutableSession(Optional session) { + assertNotNull(session); + assertTrue(session.isPresent()); + SSLSession sess = session.get(); + assertNotNull(sess); + checkImmutableSession(sess); + } + + private void checkImmutableSession(SSLSession session) { + if (session instanceof ExtendedSSLSession) { + assertEquals(ImmutableExtendedSSLSession.class, session.getClass()); + } else { + assertEquals(ImmutableSSLSession.class, session.getClass()); + } + assertThrows(UnsupportedOperationException.class, session::invalidate); + assertThrows(UnsupportedOperationException.class, + () -> session.putValue("foo", "bar")); + for (String name : session.getValueNames()) { + assertThrows(UnsupportedOperationException.class, + () -> session.removeValue(name)); + } + } + + private void checkDummySession(SSLSession session) throws Exception { + checkImmutableSession(session); + assertEquals("abcd", new String(session.getId(), US_ASCII)); + assertEquals(sslContext.getClientSessionContext(), session.getSessionContext()); + assertEquals(42, session.getCreationTime()); + assertEquals(4242, session.getLastAccessedTime()); + assertFalse(session.isValid()); + assertEquals("bar", session.getValue("foo")); + assertEquals(List.of("foo"), Arrays.asList(session.getValueNames())); + assertEquals(0, session.getPeerCertificates().length); + assertEquals(0, session.getLocalCertificates().length); + assertNull(session.getPeerPrincipal()); + assertNull(session.getLocalPrincipal()); + assertEquals("MyCipherSuite", session.getCipherSuite()); + assertEquals("TLSv1.3", session.getProtocol()); + assertEquals("dummy", session.getPeerHost()); + assertEquals(42, session.getPeerPort()); + assertEquals(42, session.getPacketBufferSize()); + assertEquals(42, session.getApplicationBufferSize()); + if (session instanceof ExtendedSSLSession ext) { + assertEquals(List.of("bar", "foo"), + Arrays.asList(ext.getPeerSupportedSignatureAlgorithms())); + assertEquals(List.of("foo", "bar"), + Arrays.asList(ext.getLocalSupportedSignatureAlgorithms())); + assertEquals(List.of(new SNIHostName("localhost")), ((ExtendedSSLSession) session).getRequestedServerNames()); + List status = ext.getStatusResponses(); + assertEquals(1, status.size()); + assertEquals("42", new String(status.get(0), US_ASCII)); + assertThrows(UnsupportedOperationException.class, + () -> ext.exportKeyingMaterialData("foo", new byte[] {1,2,3,4}, 4)); + assertThrows(UnsupportedOperationException.class, + () -> ext.exportKeyingMaterialKey("foo", "foo", new byte[] {1,2,3,4}, 4)); + } + } + + static class DummySession extends ExtendedSSLSession { + + @Override + public byte[] getId() { + return new byte[] {'a', 'b', 'c', 'd'}; + } + + @Override + public SSLSessionContext getSessionContext() { + return sslContext.getClientSessionContext(); + } + + @Override + public long getCreationTime() { + return 42; + } + + @Override + public long getLastAccessedTime() { + return 4242; + } + + @Override + public void invalidate() {} + + @Override + public boolean isValid() { + return false; + } + + @Override + public void putValue(String name, Object value) { + + } + + @Override + public Object getValue(String name) { + if (name.equals("foo")) return "bar"; + return null; + } + + @Override + public void removeValue(String name) { + + } + + @Override + public String[] getValueNames() { + return new String[] {"foo"}; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return new Certificate[0]; + } + + @Override + public Certificate[] getLocalCertificates() { + return new Certificate[0]; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return "MyCipherSuite"; + } + + @Override + public String getProtocol() { + return "TLSv1.3"; + } + + @Override + public String getPeerHost() { + return "dummy"; + } + + @Override + public int getPeerPort() { + return 42; + } + + @Override + public int getPacketBufferSize() { + return 42; + } + + @Override + public int getApplicationBufferSize() { + return 42; + } + + @Override + public String[] getLocalSupportedSignatureAlgorithms() { + return new String[] {"foo", "bar"}; + } + + @Override + public String[] getPeerSupportedSignatureAlgorithms() { + return new String[] {"bar", "foo"}; + } + + @Override + public List getRequestedServerNames() { + return List.of(new SNIHostName("localhost")); + } + + @Override + public List getStatusResponses() { + return List.of(new byte[] {'4', '2'}); + } + } + +} diff --git a/test/jdk/java/net/httpclient/InvalidInputStreamSubscriptionRequest.java b/test/jdk/java/net/httpclient/InvalidInputStreamSubscriptionRequest.java index 2f4f96e8b9b..d1e3cc026d3 100644 --- a/test/jdk/java/net/httpclient/InvalidInputStreamSubscriptionRequest.java +++ b/test/jdk/java/net/httpclient/InvalidInputStreamSubscriptionRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -22,19 +22,39 @@ */ /* - * @test + * @test id=http3 * @summary Tests an asynchronous BodySubscriber that completes * immediately with an InputStream which issues bad * requests * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.test.lib.net.SimpleSSLContext ReferenceTracker * jdk.httpclient.test.lib.common.HttpServerAdapters - * @run testng/othervm InvalidInputStreamSubscriptionRequest + * @run testng/othervm -Dtest.http.version=http3 + * -Djdk.internal.httpclient.debug=true + * InvalidInputStreamSubscriptionRequest + */ +/* + * @test id=http2 + * @summary Tests an asynchronous BodySubscriber that completes + * immediately with an InputStream which issues bad + * requests + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext ReferenceTracker + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @run testng/othervm -Dtest.http.version=http2 InvalidInputStreamSubscriptionRequest + */ +/* + * @test id=http1 + * @summary Tests an asynchronous BodySubscriber that completes + * immediately with an InputStream which issues bad + * requests + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext ReferenceTracker + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @run testng/othervm -Dtest.http.version=http1 InvalidInputStreamSubscriptionRequest */ import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterClass; import org.testng.annotations.AfterTest; @@ -47,7 +67,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -55,7 +74,6 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; import java.net.http.HttpResponse.BodySubscriber; -import java.net.http.HttpResponse.BodySubscribers; import java.nio.ByteBuffer; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -67,15 +85,16 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.Flow; -import java.util.concurrent.Flow.Publisher; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; @@ -86,6 +105,7 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -94,6 +114,8 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; static final int ITERATION_COUNT = 3; // a shared executor helps reduce the amount of threads created by the test @@ -180,46 +202,81 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters @DataProvider(name = "variants") public Object[][] variants() { - return new Object[][]{ - { httpURI_fixed, false, OF_INPUTSTREAM }, - { httpURI_chunk, false, OF_INPUTSTREAM }, - { httpsURI_fixed, false, OF_INPUTSTREAM }, - { httpsURI_chunk, false, OF_INPUTSTREAM }, + Object[][] http3 = new Object[][]{ + {http3URI_fixed, false, OF_INPUTSTREAM}, + {http3URI_chunk, false, OF_INPUTSTREAM}, + {http3URI_fixed, true, OF_INPUTSTREAM}, + {http3URI_chunk, true, OF_INPUTSTREAM}, + }; + Object[][] http1 = new Object[][] { + {httpURI_fixed, false, OF_INPUTSTREAM}, + {httpURI_chunk, false, OF_INPUTSTREAM}, + {httpsURI_fixed, false, OF_INPUTSTREAM}, + {httpsURI_chunk, false, OF_INPUTSTREAM}, + {httpURI_fixed, true, OF_INPUTSTREAM}, + {httpURI_chunk, true, OF_INPUTSTREAM}, + {httpsURI_fixed, true, OF_INPUTSTREAM}, + {httpsURI_chunk, true, OF_INPUTSTREAM}, + }; + Object[][] http2 = new Object[][] { { http2URI_fixed, false, OF_INPUTSTREAM }, { http2URI_chunk, false, OF_INPUTSTREAM }, { https2URI_fixed, false, OF_INPUTSTREAM }, { https2URI_chunk, false, OF_INPUTSTREAM }, - - { httpURI_fixed, true, OF_INPUTSTREAM }, - { httpURI_chunk, true, OF_INPUTSTREAM }, - { httpsURI_fixed, true, OF_INPUTSTREAM }, - { httpsURI_chunk, true, OF_INPUTSTREAM }, { http2URI_fixed, true, OF_INPUTSTREAM }, { http2URI_chunk, true, OF_INPUTSTREAM }, { https2URI_fixed, true, OF_INPUTSTREAM }, { https2URI_chunk, true, OF_INPUTSTREAM }, }; + String version = System.getProperty("test.http.version"); + if ("http3".equals(version)) { + return http3; + } + if ("http2".equals(version)) { + return http2; + } + if ("http1".equals(version)) { + return http1; + } + if (version == null) throw new AssertionError("test.http.version not set"); + throw new AssertionError("test.http.version should be set to http3|http2|http1. Found " + version); } + + final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; - HttpClient newHttpClient() { - return TRACKER.track(HttpClient.newBuilder() + HttpClient newHttpClient(String uri) { + HttpClient.Builder builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return TRACKER.track(builder .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) .build()); } + HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "variants") public void testNoBody(String uri, boolean sameClient, BHS handlers) throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { - if (!sameClient || client == null) - client = newHttpClient(); + if (!sameClient || client == null) { + client = newHttpClient(uri); + } - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler handler = handlers.get(); BodyHandler badHandler = (rspinfo) -> @@ -246,7 +303,23 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 1500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -256,11 +329,12 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler handler = handlers.get(); BodyHandler badHandler = (rspinfo) -> @@ -295,7 +369,23 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 1500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -305,11 +395,12 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri+"/withBody")) + HttpRequest req = newRequestBuilder(URI.create(uri+"/withBody")) .build(); BodyHandler handler = handlers.get(); BodyHandler badHandler = (rspinfo) -> @@ -331,7 +422,23 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 1500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -341,11 +448,12 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri+"/withBody")) + HttpRequest req = newRequestBuilder(URI.create(uri+"/withBody")) .build(); BodyHandler handler = handlers.get(); BodyHandler badHandler = (rspinfo) -> @@ -374,7 +482,23 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 1500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -476,20 +600,32 @@ public class InvalidInputStreamSubscriptionRequest implements HttpServerAdapters https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk"; + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_VariableLengthHandler(); + + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest public void teardown() throws Exception { - AssertionError fail = TRACKER.check(500); + AssertionError fail = TRACKER.check(1500); try { httpTestServer.stop(); httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { throw fail; diff --git a/test/jdk/java/net/httpclient/InvalidSubscriptionRequest.java b/test/jdk/java/net/httpclient/InvalidSubscriptionRequest.java index 3cd7cb629ea..327dc57d6df 100644 --- a/test/jdk/java/net/httpclient/InvalidSubscriptionRequest.java +++ b/test/jdk/java/net/httpclient/InvalidSubscriptionRequest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -34,8 +34,6 @@ */ import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -47,7 +45,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -69,11 +66,13 @@ import java.util.concurrent.Flow.Publisher; import java.util.function.Supplier; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; @@ -84,6 +83,7 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -92,6 +92,8 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; static final int ITERATION_COUNT = 3; // a shared executor helps reduce the amount of threads created by the test @@ -127,6 +129,11 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { @DataProvider(name = "variants") public Object[][] variants() { return new Object[][]{ + { http3URI_fixed, false, OF_PUBLISHER_API }, + { http3URI_chunk, false, OF_PUBLISHER_API }, + { http3URI_fixed, true, OF_PUBLISHER_API }, + { http3URI_chunk, true, OF_PUBLISHER_API }, + { httpURI_fixed, false, OF_PUBLISHER_API }, { httpURI_chunk, false, OF_PUBLISHER_API }, { httpsURI_fixed, false, OF_PUBLISHER_API }, @@ -148,22 +155,35 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { } final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; - HttpClient newHttpClient() { - return TRACKER.track(HttpClient.newBuilder() + HttpClient newHttpClient(String uri) { + HttpClient.Builder builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return TRACKER.track(builder .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) .build()); } + HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "variants") public void testNoBody(String uri, boolean sameClient, BHS handlers) throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler>> handler = handlers.get(); HttpResponse>> response = client.send(req, handler); @@ -190,7 +210,23 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -198,11 +234,12 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { @Test(dataProvider = "variants") public void testNoBodyAsync(String uri, boolean sameClient, BHS handlers) throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler>> handler = handlers.get(); // We can reuse our BodySubscribers implementations to subscribe to the @@ -234,7 +271,23 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -242,11 +295,12 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { @Test(dataProvider = "variants") public void testAsString(String uri, boolean sameClient, BHS handlers) throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri+"/withBody")) + HttpRequest req = newRequestBuilder(URI.create(uri+"/withBody")) .build(); BodyHandler>> handler = handlers.get(); HttpResponse>> response = client.send(req, handler); @@ -269,7 +323,23 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -277,11 +347,12 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { @Test(dataProvider = "variants") public void testAsStringAsync(String uri, boolean sameClient, BHS handlers) throws Exception { HttpClient client = null; + Throwable failed = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri+"/withBody")) + HttpRequest req = newRequestBuilder(URI.create(uri+"/withBody")) .build(); BodyHandler>> handler = handlers.get(); // We can reuse our BodySubscribers implementations to subscribe to the @@ -307,7 +378,23 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { } if (cause instanceof IllegalArgumentException) { System.out.println("Got expected exception: " + cause); - } else throw x; + } else { + failed = x; + } + } finally { + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 500); + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + } else throw error; + } + } + } + if (failed != null) { + throw new AssertionError("Unexpected exception: " + failed, failed); } } } @@ -409,10 +496,21 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk"; + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_VariableLengthHandler(); + + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -423,6 +521,7 @@ public class InvalidSubscriptionRequest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { throw fail; diff --git a/test/jdk/java/net/httpclient/LargeHandshakeTest.java b/test/jdk/java/net/httpclient/LargeHandshakeTest.java index 1f4c85b361e..ff3981e8b3d 100644 --- a/test/jdk/java/net/httpclient/LargeHandshakeTest.java +++ b/test/jdk/java/net/httpclient/LargeHandshakeTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -62,6 +62,9 @@ import jdk.httpclient.test.lib.common.HttpServerAdapters; import jdk.httpclient.test.lib.common.TestServerConfigurator; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; /** * @test @@ -964,12 +967,14 @@ public class LargeHandshakeTest implements HttpServerAdapters { HttpTestServer http2Server; HttpTestServer https1Server; HttpTestServer https2Server; + HttpTestServer http3Server; DigestEchoServer.TunnelingProxy proxy; URI http1URI; URI https1URI; URI http2URI; URI https2URI; + URI http3URI; InetSocketAddress proxyAddress; ProxySelector proxySelector; HttpClient client; @@ -988,8 +993,7 @@ public class LargeHandshakeTest implements HttpServerAdapters { } public HttpClient newHttpClient(ProxySelector ps) { - HttpClient.Builder builder = HttpClient - .newBuilder() + HttpClient.Builder builder = newClientBuilderForH3() .sslContext(context) .executor(clientexec) .proxy(ps); @@ -1028,6 +1032,12 @@ public class LargeHandshakeTest implements HttpServerAdapters { https2Server.start(); https2URI = new URI("https://" + https2Server.serverAuthority() + "/LargeHandshakeTest/https2/"); + // HTTP/3 + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault()); + http3Server.addHandler(new HttpTestLargeHandler(), "/LargeHandshakeTest/http3/"); + http3Server.start(); + http3URI = new URI("https://" + http3Server.serverAuthority() + "/LargeHandshakeTest/http3/"); + proxy = DigestEchoServer.createHttpsProxyTunnel( DigestEchoServer.HttpAuthSchemeType.NONE); proxyAddress = proxy.getProxyAddress(); @@ -1073,9 +1083,11 @@ public class LargeHandshakeTest implements HttpServerAdapters { } public void run(String... args) throws Exception { - List serverURIs = List.of(http1URI, http2URI, https1URI, https2URI); + List serverURIs = List.of(http3URI, http1URI, http2URI, https1URI, https2URI); for (int i = 0; i < 5; i++) { for (URI base : serverURIs) { + // skip HTTP/3 if proxy + if (base.getRawPath().contains("/http3/")) continue; if (base.getScheme().equalsIgnoreCase("https")) { URI proxy = i % 1 == 0 ? base.resolve(URI.create("proxy/foo?n=" + requestCounter.incrementAndGet())) : base.resolve(URI.create("direct/foo?n=" + requestCounter.incrementAndGet())); @@ -1090,10 +1102,19 @@ public class LargeHandshakeTest implements HttpServerAdapters { CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); } + HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + public void test(URI uri) throws Exception { System.out.println("Testing with " + uri); pending.add(uri); - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uri).build(); CompletableFuture> resp = client.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .whenComplete((r, t) -> this.requestCompleted(request, r, t)); @@ -1111,11 +1132,13 @@ public class LargeHandshakeTest implements HttpServerAdapters { } public void tearDown() { + client.close(); proxy = stop(proxy, DigestEchoServer.TunnelingProxy::stop); http1Server = stop(http1Server, HttpTestServer::stop); https1Server = stop(https1Server, HttpTestServer::stop); http2Server = stop(http2Server, HttpTestServer::stop); https2Server = stop(https2Server, HttpTestServer::stop); + http3Server = stop(http3Server, HttpTestServer::stop); client = null; try { executor.awaitTermination(2000, TimeUnit.MILLISECONDS); diff --git a/test/jdk/java/net/httpclient/LargeResponseTest.java b/test/jdk/java/net/httpclient/LargeResponseTest.java index bfc7c9ca5db..02cf5ddcc8d 100644 --- a/test/jdk/java/net/httpclient/LargeResponseTest.java +++ b/test/jdk/java/net/httpclient/LargeResponseTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -51,6 +51,9 @@ import java.util.concurrent.atomic.AtomicLong; import jdk.httpclient.test.lib.common.HttpServerAdapters; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; /** * @test @@ -65,6 +68,7 @@ import static java.net.http.HttpClient.Version.HTTP_2; * @run main/othervm -Dtest.requiresHost=true * -Djdk.httpclient.HttpClient.log=headers * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.quic.maxInitialTimeout=60 * LargeResponseTest * */ @@ -94,12 +98,14 @@ public class LargeResponseTest implements HttpServerAdapters { HttpTestServer http2Server; HttpTestServer https1Server; HttpTestServer https2Server; + HttpTestServer http3Server; DigestEchoServer.TunnelingProxy proxy; URI http1URI; URI https1URI; URI http2URI; URI https2URI; + URI http3URI; InetSocketAddress proxyAddress; ProxySelector proxySelector; HttpClient client; @@ -112,8 +118,7 @@ public class LargeResponseTest implements HttpServerAdapters { TimeUnit.SECONDS, new LinkedBlockingQueue<>()); public HttpClient newHttpClient(ProxySelector ps) { - HttpClient.Builder builder = HttpClient - .newBuilder() + HttpClient.Builder builder = newClientBuilderForH3() .sslContext(context) .executor(clientexec) .proxy(ps); @@ -152,6 +157,12 @@ public class LargeResponseTest implements HttpServerAdapters { https2Server.start(); https2URI = new URI("https://" + https2Server.serverAuthority() + "/LargeResponseTest/https2/"); + // HTTP/3 + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault()); + http3Server.addHandler(new HttpTestLargeHandler(), "/LargeResponseTest/http3/"); + http3Server.start(); + http3URI = new URI("https://" + http3Server.serverAuthority() + "/LargeResponseTest/http3/"); + proxy = DigestEchoServer.createHttpsProxyTunnel( DigestEchoServer.HttpAuthSchemeType.NONE); proxyAddress = proxy.getProxyAddress(); @@ -182,9 +193,11 @@ public class LargeResponseTest implements HttpServerAdapters { } public void run(String... args) throws Exception { - List serverURIs = List.of(http1URI, http2URI, https1URI, https2URI); + List serverURIs = List.of(http3URI, http1URI, http2URI, https1URI, https2URI); for (int i=0; i<5; i++) { for (URI base : serverURIs) { + // no proxy with HTTP/3 + if (base.getRawPath().contains("/http3/")) continue; if (base.getScheme().equalsIgnoreCase("https")) { URI proxy = i % 1 == 0 ? base.resolve(URI.create("proxy/foo?n="+requestCounter.incrementAndGet())) : base.resolve(URI.create("direct/foo?n="+requestCounter.incrementAndGet())); @@ -199,10 +212,19 @@ public class LargeResponseTest implements HttpServerAdapters { CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); } + HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + public void test(URI uri) throws Exception { System.out.println("Testing with " + uri); pending.add(uri); - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uri).build(); CompletableFuture> resp = client.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .whenComplete((r, t) -> this.requestCompleted(request, r, t)); @@ -220,11 +242,13 @@ public class LargeResponseTest implements HttpServerAdapters { } public void tearDown() { + client.close(); proxy = stop(proxy, DigestEchoServer.TunnelingProxy::stop); http1Server = stop(http1Server, HttpTestServer::stop); https1Server = stop(https1Server, HttpTestServer::stop); http2Server = stop(http2Server, HttpTestServer::stop); https2Server = stop(https2Server, HttpTestServer::stop); + http3Server = stop(http3Server, HttpTestServer::stop); client = null; try { executor.awaitTermination(2000, TimeUnit.MILLISECONDS); diff --git a/test/jdk/java/net/httpclient/LineBodyHandlerTest.java b/test/jdk/java/net/httpclient/LineBodyHandlerTest.java index 3271f21a130..62afcab9ee2 100644 --- a/test/jdk/java/net/httpclient/LineBodyHandlerTest.java +++ b/test/jdk/java/net/httpclient/LineBodyHandlerTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,8 +28,6 @@ import java.io.PrintStream; import java.io.StringReader; import java.io.UncheckedIOException; import java.math.BigInteger; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Builder; @@ -55,10 +53,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -67,6 +61,9 @@ import org.testng.annotations.Test; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_16; import static java.nio.charset.StandardCharsets.UTF_8; import static java.net.http.HttpRequest.BodyPublishers.ofString; @@ -90,14 +87,16 @@ import static org.testng.Assert.assertTrue; public class LineBodyHandlerTest implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; final AtomicInteger clientCount = new AtomicInteger(); @@ -106,6 +105,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { @DataProvider(name = "uris") public Object[][] variants() { return new Object[][]{ + { http3URI }, { httpURI }, { httpsURI }, { http2URI }, @@ -195,17 +195,26 @@ public class LineBodyHandlerTest implements HttpServerAdapters { return sharedClient; } clientCount.incrementAndGet(); - return sharedClient = TRACKER.track(HttpClient.newBuilder() + return sharedClient = TRACKER.track(newClientBuilderForH3() .sslContext(sslContext) .proxy(Builder.NO_PROXY) .build()); } + HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "uris") void testStringWithFinisher(String url) { String body = "May the luck of the Irish be with you!"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -226,7 +235,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testAsStream(String url) { String body = "May the luck of the Irish be with you!"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -249,7 +258,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { String body = "May the luck\r\n\r\n of the Irish be with you!"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -270,7 +279,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testAsStreamWithCRLF(String url) { String body = "May the luck\r\n\r\n of the Irish be with you!"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -294,7 +303,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testStringWithFinisherBlocking(String url) throws Exception { String body = "May the luck of the Irish be with you!"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)).build(); StringSubscriber subscriber = new StringSubscriber(); @@ -311,7 +320,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testStringWithoutFinisherBlocking(String url) throws Exception { String body = "May the luck of the Irish be with you!"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)).build(); StringSubscriber subscriber = new StringSubscriber(); @@ -330,7 +339,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testAsStreamWithMixedCRLF(String url) { String body = "May\r\n the wind\r\n always be\rat your back.\r\r"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -357,7 +366,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testAsStreamWithMixedCRLF_UTF8(String url) { String body = "May\r\n the wind\r\n always be\rat your back.\r\r"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .header("Content-type", "text/text; charset=UTF-8") .POST(BodyPublishers.ofString(body, UTF_8)).build(); @@ -383,7 +392,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testAsStreamWithMixedCRLF_UTF16(String url) { String body = "May\r\n the wind\r\n always be\rat your back.\r\r"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .header("Content-type", "text/text; charset=UTF-16") .POST(BodyPublishers.ofString(body, UTF_16)).build(); @@ -410,7 +419,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testObjectWithFinisher(String url) { String body = "May\r\n the wind\r\n always be\rat your back."; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -435,7 +444,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testObjectWithFinisher_UTF16(String url) { String body = "May\r\n the wind\r\n always be\rat your back.\r\r"; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .header("Content-type", "text/text; charset=UTF-16") .POST(BodyPublishers.ofString(body, UTF_16)).build(); ObjectSubscriber subscriber = new ObjectSubscriber(); @@ -461,7 +470,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testObjectWithoutFinisher(String url) { String body = "May\r\n the wind\r\n always be\rat your back."; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -487,7 +496,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testObjectWithFinisherBlocking(String url) throws Exception { String body = "May\r\n the wind\r\n always be\nat your back."; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -511,7 +520,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testObjectWithoutFinisherBlocking(String url) throws Exception { String body = "May\r\n the wind\r\n always be\nat your back."; HttpClient client = newClient(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(body)) .build(); @@ -546,7 +555,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testBigTextFromLineSubscriber(String url) { HttpClient client = newClient(); String bigtext = bigtext(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(bigtext)) .build(); @@ -567,7 +576,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { void testBigTextAsStream(String url) { HttpClient client = newClient(); String bigtext = bigtext(); - HttpRequest request = HttpRequest.newBuilder(URI.create(url)) + HttpRequest request = newRequestBuilder(URI.create(url)) .POST(BodyPublishers.ofString(bigtext)) .build(); @@ -690,10 +699,15 @@ public class LineBodyHandlerTest implements HttpServerAdapters { https2TestServer.addHandler(new HttpTestEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new HttpTestEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -712,6 +726,7 @@ public class LineBodyHandlerTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/ManyRequests.java b/test/jdk/java/net/httpclient/ManyRequests.java index c196ef76466..190205a9ef5 100644 --- a/test/jdk/java/net/httpclient/ManyRequests.java +++ b/test/jdk/java/net/httpclient/ManyRequests.java @@ -25,11 +25,11 @@ * @test * @bug 8087112 8180044 8256459 * @key intermittent - * @modules java.net.http + * @modules java.net.http/jdk.internal.net.http.common * java.logging * jdk.httpserver - * @library /test/lib - * @build jdk.test.lib.net.SimpleSSLContext + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestServerConfigurator * @compile ../../../com/sun/net/httpserver/LogFilter.java * @compile ../../../com/sun/net/httpserver/EchoHandler.java * @compile ../../../com/sun/net/httpserver/FileServerHandler.java @@ -77,7 +77,9 @@ import java.util.logging.Level; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import jdk.httpclient.test.lib.common.TestServerConfigurator; import jdk.test.lib.Platform; import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; @@ -107,7 +109,7 @@ public class ManyRequests { InetSocketAddress addr = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); HttpsServer server = HttpsServer.create(addr, 0); ExecutorService executor = executorFor("HTTPS/1.1 Server Thread"); - server.setHttpsConfigurator(new Configurator(ctx)); + server.setHttpsConfigurator(new Configurator(addr.getAddress(), ctx)); server.setExecutor(executor); ExecutorService virtualExecutor = Executors.newThreadPerTaskExecutor(Thread.ofVirtual() .name("HttpClient-Worker", 0).factory()); @@ -366,12 +368,17 @@ public class ManyRequests { } static class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) { + private final InetAddress serverAddr; + public Configurator(InetAddress serverAddr, SSLContext ctx) { super(ctx); + this.serverAddr = serverAddr; } + @Override public void configure(HttpsParameters params) { - params.setSSLParameters(getSSLContext().getSupportedSSLParameters()); + final SSLParameters parameters = getSSLContext().getSupportedSSLParameters(); + TestServerConfigurator.addSNIMatcher(this.serverAddr, parameters); + params.setSSLParameters(parameters); } } diff --git a/test/jdk/java/net/httpclient/ManyRequests2.java b/test/jdk/java/net/httpclient/ManyRequests2.java index e6aa5ffa83f..a1e2307b820 100644 --- a/test/jdk/java/net/httpclient/ManyRequests2.java +++ b/test/jdk/java/net/httpclient/ManyRequests2.java @@ -24,11 +24,11 @@ /* * @test * @bug 8087112 8180044 8256459 - * @modules java.net.http + * @modules java.net.http/jdk.internal.net.http.common * java.logging * jdk.httpserver - * @library /test/lib - * @build jdk.test.lib.net.SimpleSSLContext + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestServerConfigurator * @compile ../../../com/sun/net/httpserver/LogFilter.java * @compile ../../../com/sun/net/httpserver/EchoHandler.java * @compile ../../../com/sun/net/httpserver/FileServerHandler.java diff --git a/test/jdk/java/net/httpclient/ManyRequestsLegacy.java b/test/jdk/java/net/httpclient/ManyRequestsLegacy.java index 010c04cc51d..1d8091d3085 100644 --- a/test/jdk/java/net/httpclient/ManyRequestsLegacy.java +++ b/test/jdk/java/net/httpclient/ManyRequestsLegacy.java @@ -23,11 +23,11 @@ /* * @test - * @modules java.net.http + * @modules java.net.http/jdk.internal.net.http.common * java.logging * jdk.httpserver - * @library /test/lib - * @build jdk.test.lib.net.SimpleSSLContext + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestServerConfigurator * @compile ../../../com/sun/net/httpserver/LogFilter.java * @compile ../../../com/sun/net/httpserver/EchoHandler.java * @compile ../../../com/sun/net/httpserver/FileServerHandler.java @@ -65,6 +65,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import java.net.http.HttpClient; import java.net.http.HttpClient.Version; @@ -81,6 +82,7 @@ import java.util.Random; import java.util.logging.Logger; import java.util.logging.Level; +import jdk.httpclient.test.lib.common.TestServerConfigurator; import jdk.test.lib.Platform; import jdk.test.lib.RandomFactory; import jdk.test.lib.net.SimpleSSLContext; @@ -112,7 +114,7 @@ public class ManyRequestsLegacy { }); InetSocketAddress addr = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); HttpsServer server = HttpsServer.create(addr, 0); - server.setHttpsConfigurator(new Configurator(ctx)); + server.setHttpsConfigurator(new Configurator(addr.getAddress(), ctx)); LegacyHttpClient client = new LegacyHttpClient(); @@ -431,12 +433,18 @@ public class ManyRequestsLegacy { } static class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) { + private final InetAddress serverAddr; + + public Configurator(InetAddress serverAddr, SSLContext ctx) { super(ctx); + this.serverAddr = serverAddr; } + @Override public void configure(HttpsParameters params) { - params.setSSLParameters(getSSLContext().getSupportedSSLParameters()); + final SSLParameters parameters = getSSLContext().getSupportedSSLParameters(); + TestServerConfigurator.addSNIMatcher(this.serverAddr, parameters); + params.setSSLParameters(parameters); } } } diff --git a/test/jdk/java/net/httpclient/MappingResponseSubscriber.java b/test/jdk/java/net/httpclient/MappingResponseSubscriber.java index 6f5964970cc..5e5497cf0e5 100644 --- a/test/jdk/java/net/httpclient/MappingResponseSubscriber.java +++ b/test/jdk/java/net/httpclient/MappingResponseSubscriber.java @@ -49,11 +49,9 @@ import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsServer; import java.net.http.HttpClient; -import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; -import java.net.http.HttpResponse.BodyHandlers; import java.net.http.HttpResponse.BodySubscribers; import java.net.http.HttpResponse.BodySubscriber; import java.util.function.Function; @@ -64,6 +62,7 @@ import jdk.internal.net.http.common.OperationTrackers.Tracker; import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.httpclient.test.lib.http2.Http2TestExchange; import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.test.lib.Utils; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -149,14 +148,14 @@ public class MappingResponseSubscriber { Tracker tracker = TRACKER.getTracker(client); client = null; System.gc(); - AssertionError error = TRACKER.check(tracker, 1500); + AssertionError error = TRACKER.check(tracker, Utils.adjustTimeout(1500)); if (error != null) throw error; // the client didn't shut down properly } if (sameClient) { Tracker tracker = TRACKER.getTracker(client); client = null; System.gc(); - AssertionError error = TRACKER.check(tracker,1500); + AssertionError error = TRACKER.check(tracker, Utils.adjustTimeout(1500)); if (error != null) throw error; // the client didn't shut down properly } } diff --git a/test/jdk/java/net/httpclient/NoBodyPartOne.java b/test/jdk/java/net/httpclient/NoBodyPartOne.java index d05d3e9c5e1..7c7a51c92e7 100644 --- a/test/jdk/java/net/httpclient/NoBodyPartOne.java +++ b/test/jdk/java/net/httpclient/NoBodyPartOne.java @@ -44,6 +44,7 @@ import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_3; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -57,6 +58,9 @@ public class NoBodyPartOne extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { HttpRequest req = newRequestBuilder(uri) @@ -78,6 +82,9 @@ public class NoBodyPartOne extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { @@ -101,6 +108,9 @@ public class NoBodyPartOne extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { diff --git a/test/jdk/java/net/httpclient/NoBodyPartThree.java b/test/jdk/java/net/httpclient/NoBodyPartThree.java index d0f71130421..d5e310d1914 100644 --- a/test/jdk/java/net/httpclient/NoBodyPartThree.java +++ b/test/jdk/java/net/httpclient/NoBodyPartThree.java @@ -28,6 +28,7 @@ * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer * @run testng/othervm + * -Djdk.httpclient.HttpClient.log=quic,errors * -Djdk.httpclient.HttpClient.log=all * NoBodyPartThree */ @@ -44,8 +45,10 @@ import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; +import jdk.internal.net.http.common.Utils; import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_3; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -62,6 +65,9 @@ public class NoBodyPartThree extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { var u = uri + "/testAsByteArrayPublisher/first/" + REQID.getAndIncrement(); @@ -72,7 +78,7 @@ public class NoBodyPartThree extends AbstractNoBody { Consumer> consumer = oba -> { consumerHasBeenCalled = true; oba.ifPresent(ba -> fail("Unexpected non-empty optional:" - + asString(ByteBuffer.wrap(ba)))); + + Utils.asString(ByteBuffer.wrap(ba)))); }; consumerHasBeenCalled = false; var response = client.send(req, BodyHandlers.ofByteArrayConsumer(consumer)); @@ -99,6 +105,9 @@ public class NoBodyPartThree extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { var u = uri + "/testStringPublisher/" + REQID.getAndIncrement(); @@ -121,6 +130,9 @@ public class NoBodyPartThree extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { var u = uri + "/testInputStreamPublisherBuffering/" + REQID.getAndIncrement(); @@ -144,6 +156,9 @@ public class NoBodyPartThree extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { var u = uri + "/testEmptyArrayPublisher/" + REQID.getAndIncrement(); diff --git a/test/jdk/java/net/httpclient/NoBodyPartTwo.java b/test/jdk/java/net/httpclient/NoBodyPartTwo.java index 563516b11bf..f7d331cb526 100644 --- a/test/jdk/java/net/httpclient/NoBodyPartTwo.java +++ b/test/jdk/java/net/httpclient/NoBodyPartTwo.java @@ -43,8 +43,10 @@ import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; +import jdk.internal.net.http.common.Utils; import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_3; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -59,6 +61,9 @@ public class NoBodyPartTwo extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { HttpRequest req = newRequestBuilder(uri) @@ -67,7 +72,7 @@ public class NoBodyPartTwo extends AbstractNoBody { Consumer> consumer = oba -> { consumerHasBeenCalled = true; oba.ifPresent(ba -> fail("Unexpected non-empty optional: " - + asString(ByteBuffer.wrap(ba)))); + + Utils.asString(ByteBuffer.wrap(ba)))); }; consumerHasBeenCalled = false; client.send(req, BodyHandlers.ofByteArrayConsumer(consumer)); @@ -83,6 +88,9 @@ public class NoBodyPartTwo extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { HttpRequest req = newRequestBuilder(uri) @@ -102,6 +110,9 @@ public class NoBodyPartTwo extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { HttpRequest req = newRequestBuilder(uri) @@ -122,6 +133,9 @@ public class NoBodyPartTwo extends AbstractNoBody { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uri) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { HttpRequest req = newRequestBuilder(uri) diff --git a/test/jdk/java/net/httpclient/NonAsciiCharsInURI.java b/test/jdk/java/net/httpclient/NonAsciiCharsInURI.java index 56a35119ae2..ba93645aef4 100644 --- a/test/jdk/java/net/httpclient/NonAsciiCharsInURI.java +++ b/test/jdk/java/net/httpclient/NonAsciiCharsInURI.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,7 +30,7 @@ * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext * @compile -encoding utf-8 NonAsciiCharsInURI.java * @run testng/othervm - * -Djdk.httpclient.HttpClient.log=reqeusts,headers + * -Djdk.httpclient.HttpClient.log=requests,headers,errors,quic * NonAsciiCharsInURI */ @@ -49,6 +49,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -58,6 +59,8 @@ import static java.lang.System.err; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.net.http.HttpClient.Builder.NO_PROXY; import static org.testng.Assert.assertEquals; @@ -65,14 +68,17 @@ import static org.testng.Assert.assertEquals; public class NonAsciiCharsInURI implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; + String http3URI_head; private volatile HttpClient sharedClient; @@ -104,6 +110,9 @@ public class NonAsciiCharsInURI implements HttpServerAdapters { Arrays.asList(pathsAndQueryStrings).stream() .map(e -> new Object[] {https2URI + e[0], sameClient}) .forEach(list::add); + Arrays.asList(pathsAndQueryStrings).stream() + .map(e -> new Object[] {http3URI + e[0], sameClient}) + .forEach(list::add); } return list.stream().toArray(Object[][]::new); } @@ -122,11 +131,38 @@ public class NonAsciiCharsInURI implements HttpServerAdapters { return HTTP_1_1; if (uri.contains("/http2/") || uri.contains("/https2/")) return HTTP_2; + if (uri.contains("/http3/")) + return HTTP_3; return null; } + HttpRequest.Builder newRequestBuilder(String uri) { + var builder = HttpRequest.newBuilder(URI.create(uri)); + if (version(uri) == HTTP_3) { + builder.version(HTTP_3); + builder.setOption(H3_DISCOVERY, http3TestServer.h3DiscoveryConfig()); + } + return builder; + } + + HttpResponse headRequest(HttpClient client) + throws IOException, InterruptedException + { + out.println("\n" + now() + "--- Sending HEAD request ----\n"); + err.println("\n" + now() + "--- Sending HEAD request ----\n"); + + var request = newRequestBuilder(http3URI_head) + .HEAD().version(HTTP_2).build(); + var response = client.send(request, BodyHandlers.ofString()); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_2); + out.println("\n" + now() + "--- HEAD request succeeded ----\n"); + err.println("\n" + now() + "--- HEAD request succeeded ----\n"); + return response; + } + private HttpClient makeNewClient() { - return HttpClient.newBuilder() + return newClientBuilderForH3() .proxy(NO_PROXY) .sslContext(sslContext) .build(); @@ -167,11 +203,14 @@ public class NonAsciiCharsInURI implements HttpServerAdapters { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uriString) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uriString).build(); HttpResponse resp = client.send(request, BodyHandlers.ofString()); out.println("Got response: " + resp); @@ -203,10 +242,13 @@ public class NonAsciiCharsInURI implements HttpServerAdapters { for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { client = newHttpClient(sameClient); + if (!sameClient && version(uriString) == HTTP_3) { + headRequest(client); + } } try (var cl = new CloseableClient(client, sameClient)) { - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uriString).build(); client.sendAsync(request, BodyHandlers.ofString()) .thenApply(response -> { @@ -259,16 +301,28 @@ public class NonAsciiCharsInURI implements HttpServerAdapters { https2TestServer.addHandler(handler, "/https2/get"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/get"; + http3TestServer = HttpTestServer.create(HTTP_3, sslContext); + http3TestServer.addHandler(new HttpUriStringHandler(), "/http3/get"); + http3TestServer.addHandler(new HttpHeadOrGetHandler(), "/http3/head"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/get"; + http3URI_head = "https://" + http3TestServer.serverAuthority() + "/http3/head/x"; + err.println(now() + "Starting servers"); httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); out.println("HTTP/1.1 server (http) listening at: " + httpTestServer.serverAuthority()); out.println("HTTP/1.1 server (TLS) listening at: " + httpsTestServer.serverAuthority()); out.println("HTTP/2 server (h2c) listening at: " + http2TestServer.serverAuthority()); out.println("HTTP/2 server (h2) listening at: " + https2TestServer.serverAuthority()); + out.println("HTTP/3 server (h2) listening at: " + http3TestServer.serverAuthority()); + out.println(" + alt endpoint (h3) listening at: " + http3TestServer.getH3AltService() + .map(Http3TestServer::getAddress)); + + headRequest(newHttpClient(true)); out.println(now() + "setup done"); err.println(now() + "setup done"); @@ -281,6 +335,7 @@ public class NonAsciiCharsInURI implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } /** A handler that returns, as its body, the exact received request URI. */ diff --git a/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileDownloadTest.java b/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileDownloadTest.java index f3bf5521f4b..b5ea5da13e4 100644 --- a/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileDownloadTest.java +++ b/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileDownloadTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -40,9 +40,6 @@ * @run testng/othervm BodyHandlerOfFileDownloadTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.util.FileUtils; import org.testng.annotations.AfterTest; @@ -54,8 +51,6 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -68,16 +63,13 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import jdk.httpclient.test.lib.http2.Http2TestServerConnection; -import jdk.httpclient.test.lib.http2.Http2TestExchange; -import jdk.httpclient.test.lib.http2.Http2Handler; -import jdk.httpclient.test.lib.http2.OutgoingPushPromise; -import jdk.httpclient.test.lib.http2.Queue; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.file.StandardOpenOption.CREATE; import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import static java.nio.file.StandardOpenOption.WRITE; @@ -89,14 +81,16 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { static final String contentDispositionValue = "attachment; filename=example.html"; SSLContext sslContext; - HttpServerAdapters.HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] - HttpServerAdapters.HttpTestServer httpsTestServer; // HTTPS/1.1 - HttpServerAdapters.HttpTestServer http2TestServer; // HTTP/2 ( h2c ) - HttpServerAdapters.HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] + HttpTestServer httpsTestServer; // HTTPS/1.1 + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; FileSystem zipFs; Path defaultFsPath; @@ -115,6 +109,9 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { @DataProvider(name = "defaultFsData") public Object[][] defaultFsData() { return new Object[][]{ + { http3URI, defaultFsPath, MSG, true }, + { http3URI, defaultFsPath, MSG, false }, + { httpURI, defaultFsPath, MSG, true }, { httpsURI, defaultFsPath, MSG, true }, { http2URI, defaultFsPath, MSG, true }, @@ -138,6 +135,24 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { private static final int ITERATION_COUNT = 3; + private HttpClient newHttpClient(String uri) { + var builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return builder.proxy(NO_PROXY) + .sslContext(sslContext) + .build(); + } + + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + private void receive(String uriString, Path path, String expectedMsg, @@ -146,12 +161,9 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { for (int i = 0; i < ITERATION_COUNT; i++) { if (!sameClient || client == null) { - client = HttpClient.newBuilder() - .proxy(NO_PROXY) - .sslContext(sslContext) - .build(); + client = newHttpClient(uriString); } - var req = HttpRequest.newBuilder(URI.create(uriString)) + var req = newRequestBuilder(URI.create(uriString)) .POST(BodyPublishers.noBody()) .build(); var resp = client.send(req, BodyHandlers.ofFileDownload(path, CREATE, TRUNCATE_EXISTING, WRITE)); @@ -163,6 +175,12 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { assertEquals(msg, expectedMsg); assertTrue(resp.headers().firstValue("Content-Disposition").isPresent()); assertEquals(resp.headers().firstValue("Content-Disposition").get(), contentDispositionValue); + if (!sameClient) { + client.close(); + } + } + if (sameClient && client != null) { + client.close(); } } @@ -198,26 +216,31 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { zipFs = newZipFs(); zipFsPath = zipFsDir(zipFs); - httpTestServer = HttpServerAdapters.HttpTestServer.create(HTTP_1_1); + httpTestServer = HttpTestServer.create(HTTP_1_1); httpTestServer.addHandler(new HttpEchoHandler(), "/http1/echo"); httpURI = "http://" + httpTestServer.serverAuthority() + "/http1/echo"; - httpsTestServer = HttpServerAdapters.HttpTestServer.create(HTTP_1_1, sslContext); + httpsTestServer = HttpTestServer.create(HTTP_1_1, sslContext); httpsTestServer.addHandler(new HttpEchoHandler(), "/https1/echo"); httpsURI = "https://" + httpsTestServer.serverAuthority() + "/https1/echo"; - http2TestServer = HttpServerAdapters.HttpTestServer.create(HTTP_2); + http2TestServer = HttpTestServer.create(HTTP_2); http2TestServer.addHandler(new HttpEchoHandler(), "/http2/echo"); http2URI = "http://" + http2TestServer.serverAuthority() + "/http2/echo"; - https2TestServer = HttpServerAdapters.HttpTestServer.create(HTTP_2, sslContext); + https2TestServer = HttpTestServer.create(HTTP_2, sslContext); https2TestServer.addHandler(new HttpEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new HttpEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -231,6 +254,7 @@ public class BodyHandlerOfFileDownloadTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); zipFs.close(); } diff --git a/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileTest.java b/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileTest.java index ee8a9c7dcbb..42d5c2596c7 100644 --- a/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileTest.java +++ b/test/jdk/java/net/httpclient/PathSubscriber/BodyHandlerOfFileTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -39,9 +39,6 @@ * @run testng/othervm BodyHandlerOfFileTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.util.FileUtils; import org.testng.annotations.AfterTest; @@ -53,8 +50,6 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -64,30 +59,29 @@ import java.nio.charset.StandardCharsets; import java.nio.file.*; import java.util.Map; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import jdk.httpclient.test.lib.http2.Http2TestServerConnection; -import jdk.httpclient.test.lib.http2.Http2TestExchange; -import jdk.httpclient.test.lib.http2.Http2Handler; -import jdk.httpclient.test.lib.http2.OutgoingPushPromise; -import jdk.httpclient.test.lib.http2.Queue; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static org.testng.Assert.assertEquals; public class BodyHandlerOfFileTest implements HttpServerAdapters { static final String MSG = "msg"; SSLContext sslContext; - HttpServerAdapters.HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] - HttpServerAdapters.HttpTestServer httpsTestServer; // HTTPS/1.1 - HttpServerAdapters.HttpTestServer http2TestServer; // HTTP/2 ( h2c ) - HttpServerAdapters.HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] + HttpTestServer httpsTestServer; // HTTPS/1.1 + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; FileSystem zipFs; Path defaultFsPath; @@ -106,6 +100,9 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { @DataProvider(name = "defaultFsData") public Object[][] defaultFsData() { return new Object[][]{ + { http3URI, defaultFsPath, MSG, true }, + { http3URI, defaultFsPath, MSG, false }, + { httpURI, defaultFsPath, MSG, true }, { httpsURI, defaultFsPath, MSG, true }, { http2URI, defaultFsPath, MSG, true }, @@ -145,6 +142,9 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { @DataProvider(name = "zipFsData") public Object[][] zipFsData() { return new Object[][]{ + { http3URI, zipFsPath, MSG, true }, + { http3URI, zipFsPath, MSG, false }, + { httpURI, zipFsPath, MSG, true }, { httpsURI, zipFsPath, MSG, true }, { http2URI, zipFsPath, MSG, true }, @@ -168,6 +168,24 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { private static final int ITERATION_COUNT = 3; + private HttpClient newHttpClient(String uri) { + var builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return builder.proxy(NO_PROXY) + .sslContext(sslContext) + .build(); + } + + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + private void receive(String uriString, Path path, String expectedMsg, @@ -176,12 +194,9 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { for (int i = 0; i < ITERATION_COUNT; i++) { if (!sameClient || client == null) { - client = HttpClient.newBuilder() - .proxy(NO_PROXY) - .sslContext(sslContext) - .build(); + client = newHttpClient(uriString); } - var req = HttpRequest.newBuilder(URI.create(uriString)) + var req = newRequestBuilder(URI.create(uriString)) .POST(BodyPublishers.noBody()) .build(); var resp = client.send(req, HttpResponse.BodyHandlers.ofFile(path)); @@ -190,6 +205,12 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { out.printf("Msg written to %s: %s\n", resp.body(), msg); assertEquals(resp.statusCode(), 200); assertEquals(msg, expectedMsg); + if (!sameClient) { + client.close(); + } + } + if (sameClient && client != null) { + client.close(); } } @@ -219,10 +240,15 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { https2TestServer.addHandler(new HttpEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new HttpEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -236,6 +262,7 @@ public class BodyHandlerOfFileTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); zipFs.close(); } diff --git a/test/jdk/java/net/httpclient/PathSubscriber/BodySubscriberOfFileTest.java b/test/jdk/java/net/httpclient/PathSubscriber/BodySubscriberOfFileTest.java index f2adc89ec17..2a7ba90c7b2 100644 --- a/test/jdk/java/net/httpclient/PathSubscriber/BodySubscriberOfFileTest.java +++ b/test/jdk/java/net/httpclient/PathSubscriber/BodySubscriberOfFileTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -38,9 +38,6 @@ * @run testng/othervm BodySubscriberOfFileTest */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.util.FileUtils; import org.testng.annotations.AfterTest; @@ -52,14 +49,11 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse.BodyHandler; -import java.net.http.HttpResponse.BodySubscriber; import java.net.http.HttpResponse.BodySubscribers; import java.nio.Buffer; import java.nio.ByteBuffer; @@ -69,30 +63,29 @@ import java.util.Map; import java.util.concurrent.Flow; import java.util.stream.IntStream; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import jdk.httpclient.test.lib.http2.Http2TestServerConnection; -import jdk.httpclient.test.lib.http2.Http2TestExchange; -import jdk.httpclient.test.lib.http2.Http2Handler; -import jdk.httpclient.test.lib.http2.OutgoingPushPromise; -import jdk.httpclient.test.lib.http2.Queue; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static org.testng.Assert.assertEquals; public class BodySubscriberOfFileTest implements HttpServerAdapters { static final String MSG = "msg"; SSLContext sslContext; - HttpServerAdapters.HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] - HttpServerAdapters.HttpTestServer httpsTestServer; // HTTPS/1.1 - HttpServerAdapters.HttpTestServer http2TestServer; // HTTP/2 ( h2c ) - HttpServerAdapters.HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] + HttpTestServer httpsTestServer; // HTTPS/1.1 + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; FileSystem zipFs; Path defaultFsPath; @@ -111,6 +104,9 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { @DataProvider(name = "defaultFsData") public Object[][] defaultFsData() { return new Object[][]{ + { http3URI, defaultFsPath, MSG, true }, + { http3URI, defaultFsPath, MSG, false }, + { httpURI, defaultFsPath, MSG, true }, { httpsURI, defaultFsPath, MSG, true }, { http2URI, defaultFsPath, MSG, true }, @@ -150,6 +146,9 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { @DataProvider(name = "zipFsData") public Object[][] zipFsData() { return new Object[][]{ + { http3URI, zipFsPath, MSG, true }, + { http3URI, zipFsPath, MSG, false }, + { httpURI, zipFsPath, MSG, true }, { httpsURI, zipFsPath, MSG, true }, { http2URI, zipFsPath, MSG, true }, @@ -173,6 +172,24 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { private static final int ITERATION_COUNT = 3; + private HttpClient newHttpClient(String uri) { + var builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return builder.proxy(NO_PROXY) + .sslContext(sslContext) + .build(); + } + + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + private void receive(String uriString, Path path, String expectedMsg, @@ -181,12 +198,9 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { for (int i = 0; i < ITERATION_COUNT; i++) { if (!sameClient || client == null) { - client = HttpClient.newBuilder() - .proxy(NO_PROXY) - .sslContext(sslContext) - .build(); + client = newHttpClient(uriString); } - var req = HttpRequest.newBuilder(URI.create(uriString)) + var req = newRequestBuilder(URI.create(uriString)) .POST(BodyPublishers.noBody()) .build(); @@ -197,6 +211,12 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { out.printf("Msg written to %s: %s\n", resp.body(), msg); assertEquals(resp.statusCode(), 200); assertEquals(msg, expectedMsg); + if (!sameClient) { + client.close(); + } + } + if (sameClient && client != null) { + client.close(); } } @@ -250,10 +270,15 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { https2TestServer.addHandler(new HttpEchoHandler(), "/https2/echo"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new HttpEchoHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -267,6 +292,7 @@ public class BodySubscriberOfFileTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); zipFs.close(); } diff --git a/test/jdk/java/net/httpclient/ProxyAuthDisabledSchemesSSL.java b/test/jdk/java/net/httpclient/ProxyAuthDisabledSchemesSSL.java index f0eb63f2a40..6d9d7e4f11b 100644 --- a/test/jdk/java/net/httpclient/ProxyAuthDisabledSchemesSSL.java +++ b/test/jdk/java/net/httpclient/ProxyAuthDisabledSchemesSSL.java @@ -35,16 +35,29 @@ * @run main/othervm/timeout=300 * -Djdk.http.auth.proxying.disabledSchemes=Basic,Digest * -Djdk.http.auth.tunneling.disabledSchemes=Digest,Basic - * ProxyAuthDisabledSchemesSSL SSL + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=100 + * -Djdk.internal.httpclient.debug=err + * -Djdk.httpclient.HttpClient.log=headers + * ProxyAuthDisabledSchemesSSL SSL SERVER307 + * @run main/othervm/timeout=300 + * -Djdk.http.auth.proxying.disabledSchemes=Basic,Digest + * -Djdk.http.auth.tunneling.disabledSchemes=Digest,Basic + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=100 + * -Djdk.httpclient.HttpClient.log=headers + * ProxyAuthDisabledSchemesSSL SSL SERVER PROXY * @run main/othervm/timeout=300 * -Djdk.http.auth.proxying.disabledSchemes=Basic * -Djdk.http.auth.tunneling.disabledSchemes=Basic * -Dtest.requiresHost=true + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=100 + * -Djdk.httpclient.HttpClient.log=headers * ProxyAuthDisabledSchemesSSL SSL PROXY * @run main/othervm/timeout=300 * -Djdk.http.auth.proxying.disabledSchemes=Digest * -Djdk.http.auth.tunneling.disabledSchemes=Digest * -Dtest.requiresHost=true + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=100 + * -Djdk.httpclient.HttpClient.log=headers * ProxyAuthDisabledSchemesSSL SSL PROXY */ diff --git a/test/jdk/java/net/httpclient/ProxyTest.java b/test/jdk/java/net/httpclient/ProxyTest.java index 8763e168a06..45c21e7ef4a 100644 --- a/test/jdk/java/net/httpclient/ProxyTest.java +++ b/test/jdk/java/net/httpclient/ProxyTest.java @@ -52,10 +52,13 @@ import java.util.concurrent.CopyOnWriteArrayList; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; + +import jdk.httpclient.test.lib.common.TestServerConfigurator; import jdk.test.lib.net.SimpleSSLContext; import static java.net.Proxy.NO_PROXY; @@ -67,9 +70,10 @@ import static java.net.Proxy.NO_PROXY; * Verifies that downgrading from HTTP/2 to HTTP/1.1 works through * an SSL Tunnel connection when the client is HTTP/2 and the server * and proxy are HTTP/1.1 - * @modules java.net.http - * @library /test/lib + * @modules java.net.http/jdk.internal.net.http.common + * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.test.lib.net.SimpleSSLContext ProxyTest + * jdk.httpclient.test.lib.common.TestServerConfigurator * @run main/othervm ProxyTest * @author danielfuchs */ @@ -103,9 +107,8 @@ public class ProxyTest { he.close(); } }); - - server.setHttpsConfigurator(new Configurator(SSLContext.getDefault())); InetSocketAddress addr = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); + server.setHttpsConfigurator(new Configurator(addr.getAddress(), SSLContext.getDefault())); server.bind(addr, 0); return server; } @@ -407,13 +410,18 @@ public class ProxyTest { } static class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) { + private final InetAddress serverAddr; + + public Configurator(InetAddress serverAddr, SSLContext ctx) { super(ctx); + this.serverAddr = serverAddr; } @Override - public void configure (HttpsParameters params) { - params.setSSLParameters (getSSLContext().getSupportedSSLParameters()); + public void configure (final HttpsParameters params) { + final SSLParameters parameters = getSSLContext().getSupportedSSLParameters(); + TestServerConfigurator.addSNIMatcher(this.serverAddr, parameters); + params.setSSLParameters(parameters); } } diff --git a/test/jdk/java/net/httpclient/RedirectMethodChange.java b/test/jdk/java/net/httpclient/RedirectMethodChange.java index ed1afb384e1..f713c0dc5d1 100644 --- a/test/jdk/java/net/httpclient/RedirectMethodChange.java +++ b/test/jdk/java/net/httpclient/RedirectMethodChange.java @@ -33,19 +33,13 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -53,6 +47,9 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.US_ASCII; import static org.testng.Assert.assertEquals; @@ -61,14 +58,16 @@ public class RedirectMethodChange implements HttpServerAdapters { SSLContext sslContext; HttpClient client; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; static final String RESPONSE = "Hello world"; static final String POST_BODY = "This is the POST body 123909090909090"; @@ -90,6 +89,22 @@ public class RedirectMethodChange implements HttpServerAdapters { @DataProvider(name = "variants") public Object[][] variants() { return new Object[][] { + { http3URI, "GET", 301, "GET" }, + { http3URI, "GET", 302, "GET" }, + { http3URI, "GET", 303, "GET" }, + { http3URI, "GET", 307, "GET" }, + { http3URI, "GET", 308, "GET" }, + { http3URI, "POST", 301, "GET" }, + { http3URI, "POST", 302, "GET" }, + { http3URI, "POST", 303, "GET" }, + { http3URI, "POST", 307, "POST" }, + { http3URI, "POST", 308, "POST" }, + { http3URI, "PUT", 301, "PUT" }, + { http3URI, "PUT", 302, "PUT" }, + { http3URI, "PUT", 303, "GET" }, + { http3URI, "PUT", 307, "PUT" }, + { http3URI, "PUT", 308, "PUT" }, + { httpURI, "GET", 301, "GET" }, { httpURI, "GET", 302, "GET" }, { httpURI, "GET", 303, "GET" }, @@ -156,6 +171,15 @@ public class RedirectMethodChange implements HttpServerAdapters { }; } + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "variants") public void test(String uriString, String method, @@ -163,7 +187,7 @@ public class RedirectMethodChange implements HttpServerAdapters { String expectedMethod) throws Exception { - HttpRequest req = HttpRequest.newBuilder(URI.create(uriString)) + HttpRequest req = newRequestBuilder(URI.create(uriString)) .method(method, getRequestBodyFor(method)) .header("X-Redirect-Code", Integer.toString(redirectCode)) .header("X-Expect-Method", expectedMethod) @@ -183,7 +207,7 @@ public class RedirectMethodChange implements HttpServerAdapters { if (sslContext == null) throw new AssertionError("Unexpected null sslContext"); - client = HttpClient.newBuilder() + client = newClientBuilderForH3() .followRedirects(HttpClient.Redirect.NORMAL) .sslContext(sslContext) .build(); @@ -212,10 +236,17 @@ public class RedirectMethodChange implements HttpServerAdapters { https2TestServer.addHandler(handler, "/https2/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/test/rmt"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + targetURI = "https://" + http3TestServer.serverAuthority() + "/http3/redirect/rmt"; + handler = new RedirMethodChgeHandler(targetURI); + http3TestServer.addHandler(handler, "/http3/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/test/rmt"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -225,6 +256,7 @@ public class RedirectMethodChange implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } /** diff --git a/test/jdk/java/net/httpclient/RedirectTimeoutTest.java b/test/jdk/java/net/httpclient/RedirectTimeoutTest.java index 88b8fd964b0..be634398ba2 100644 --- a/test/jdk/java/net/httpclient/RedirectTimeoutTest.java +++ b/test/jdk/java/net/httpclient/RedirectTimeoutTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -32,10 +32,8 @@ * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors,trace -Djdk.internal.httpclient.debug=false RedirectTimeoutTest */ -import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; -import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; -import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestResponseHeaders; -import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; import org.testng.TestException; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -54,16 +52,24 @@ import java.net.http.HttpTimeoutException; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; +import java.util.Optional; + +import javax.net.ssl.SSLContext; import static java.net.http.HttpClient.Redirect.ALWAYS; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static jdk.test.lib.Utils.adjustTimeout; -public class RedirectTimeoutTest { +public class RedirectTimeoutTest implements HttpServerAdapters { - static HttpTestServer h1TestServer, h2TestServer; - static URI h1Uri, h1RedirectUri, h2Uri, h2RedirectUri, h2WarmupUri, testRedirectURI; + static SSLContext sslContext; + static HttpTestServer h1TestServer, h2TestServer, h3TestServer; + static URI h1Uri, h1RedirectUri, h2Uri, h2RedirectUri, + h3Uri, h3RedirectUri, h2WarmupUri, h3WarmupUri, testRedirectURI; private static final long TIMEOUT_MILLIS = 3000L; // 3s private static final long SLEEP_TIME = 1500L; // 1.5s public static final int ITERATIONS = 4; @@ -71,33 +77,47 @@ public class RedirectTimeoutTest { @BeforeTest public void setup() throws IOException { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + h1TestServer = HttpTestServer.create(HTTP_1_1); h2TestServer = HttpTestServer.create(HTTP_2); + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); h1Uri = URI.create("http://" + h1TestServer.serverAuthority() + "/h1_test"); h1RedirectUri = URI.create("http://" + h1TestServer.serverAuthority() + "/h1_redirect"); h2Uri = URI.create("http://" + h2TestServer.serverAuthority() + "/h2_test"); h2RedirectUri = URI.create("http://" + h2TestServer.serverAuthority() + "/h2_redirect"); + h3Uri = URI.create("https://" + h3TestServer.serverAuthority() + "/h3_test"); + h3RedirectUri = URI.create("https://" + h3TestServer.serverAuthority() + "/h3_redirect"); h2WarmupUri = URI.create("http://" + h2TestServer.serverAuthority() + "/h2_warmup"); + h3WarmupUri = URI.create("https://" + h3TestServer.serverAuthority() + "/h3_warmup"); h1TestServer.addHandler(new GetHandler(), "/h1_test"); h1TestServer.addHandler(new RedirectHandler(), "/h1_redirect"); h2TestServer.addHandler(new GetHandler(), "/h2_test"); h2TestServer.addHandler(new RedirectHandler(), "/h2_redirect"); - h2TestServer.addHandler(new Http2Warmup(), "/h2_warmup"); + h3TestServer.addHandler(new GetHandler(), "/h3_test"); + h3TestServer.addHandler(new RedirectHandler(), "/h3_redirect"); + h2TestServer.addHandler(new HttpWarmup(), "/h2_warmup"); + h3TestServer.addHandler(new HttpWarmup(), "/h3_warmup"); h1TestServer.start(); h2TestServer.start(); + h3TestServer.start(); } @AfterTest public void teardown() { h1TestServer.stop(); h2TestServer.stop(); + h3TestServer.stop(); } @DataProvider(name = "testData") public Object[][] testData() { return new Object[][] { + { HTTP_3, h3Uri, h3RedirectUri }, { HTTP_1_1, h1Uri, h1RedirectUri }, - { HTTP_2, h2Uri, h2RedirectUri } + { HTTP_2, h2Uri, h2RedirectUri }, }; } @@ -105,16 +125,31 @@ public class RedirectTimeoutTest { public void test(Version version, URI uri, URI redirectURI) throws InterruptedException { out.println("Testing for " + version); testRedirectURI = redirectURI; - HttpClient.Builder clientBuilder = HttpClient.newBuilder().followRedirects(ALWAYS); - HttpRequest request = HttpRequest.newBuilder().uri(uri) + HttpClient.Builder clientBuilder = version == HTTP_3 + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + clientBuilder = clientBuilder.followRedirects(ALWAYS).sslContext(sslContext); + HttpRequest.Builder reqBuilder = HttpRequest.newBuilder().uri(uri); + if (version == HTTP_3) { + reqBuilder = reqBuilder.version(HTTP_3).setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + HttpRequest request = reqBuilder .GET() .version(version) .timeout(Duration.ofMillis(adjustTimeout(TIMEOUT_MILLIS))) .build(); try (HttpClient client = clientBuilder.build()) { - if (version.equals(HTTP_2)) - client.send(HttpRequest.newBuilder(h2WarmupUri).HEAD().build(), HttpResponse.BodyHandlers.discarding()); + Optional warmupUri = switch (version) { + case HTTP_1_1 -> Optional.empty(); + case HTTP_2 -> Optional.of(h2WarmupUri); + case HTTP_3 -> Optional.of(h3WarmupUri); + }; + if (warmupUri.isPresent()) { + HttpRequest head = reqBuilder.copy().uri(warmupUri.get()) + .version(version).HEAD().build(); + client.send(head, HttpResponse.BodyHandlers.discarding()); + } /* With TIMEOUT_MILLIS set to 1500ms and the server's RedirectHandler sleeping for 750ms before responding to each request, 4 iterations will take a guaranteed minimum time of 3000ms which will ensure that any @@ -135,7 +170,7 @@ public class RedirectTimeoutTest { } } - public static class Http2Warmup implements HttpTestHandler { + public static class HttpWarmup implements HttpTestHandler { @Override public void handle(HttpTestExchange t) throws IOException { diff --git a/test/jdk/java/net/httpclient/RedirectWithCookie.java b/test/jdk/java/net/httpclient/RedirectWithCookie.java index 13dfe766647..afc8618835b 100644 --- a/test/jdk/java/net/httpclient/RedirectWithCookie.java +++ b/test/jdk/java/net/httpclient/RedirectWithCookie.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -31,15 +31,10 @@ * RedirectWithCookie */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.CookieManager; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; @@ -49,7 +44,6 @@ import java.net.http.HttpResponse.BodyHandlers; import java.util.List; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -58,6 +52,9 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -69,10 +66,12 @@ public class RedirectWithCookie implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; static final String MESSAGE = "BasicRedirectTest message body"; static final int ITERATIONS = 3; @@ -80,6 +79,7 @@ public class RedirectWithCookie implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { + { http3URI, }, { httpURI, }, { httpsURI, }, { http2URI, }, @@ -87,10 +87,22 @@ public class RedirectWithCookie implements HttpServerAdapters { }; } + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "positive") void test(String uriString) throws Exception { out.printf("%n---- starting (%s) ----%n", uriString); - HttpClient client = HttpClient.newBuilder() + var builder = uriString.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + HttpClient client = builder .followRedirects(Redirect.ALWAYS) .cookieHandler(new CookieManager()) .sslContext(sslContext) @@ -98,7 +110,7 @@ public class RedirectWithCookie implements HttpServerAdapters { assert client.cookieHandler().isPresent(); URI uri = URI.create(uriString); - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uri).build(); out.println("Initial request: " + request.uri()); for (int i=0; i< ITERATIONS; i++) { @@ -114,6 +126,8 @@ public class RedirectWithCookie implements HttpServerAdapters { assertTrue(response.uri().getPath().endsWith("message")); assertPreviousRedirectResponses(request, response); } + + client.close(); } static void assertPreviousRedirectResponses(HttpRequest initialRequest, @@ -164,10 +178,15 @@ public class RedirectWithCookie implements HttpServerAdapters { https2TestServer.addHandler(new CookieRedirectHandler(), "/https2/cookie/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/cookie/redirect"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new CookieRedirectHandler(), "/http3/cookie/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/cookie/redirect"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -176,6 +195,7 @@ public class RedirectWithCookie implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } static class CookieRedirectHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/ReferenceTracker.java b/test/jdk/java/net/httpclient/ReferenceTracker.java index d7e16d01201..6ea97ab6ac2 100644 --- a/test/jdk/java/net/httpclient/ReferenceTracker.java +++ b/test/jdk/java/net/httpclient/ReferenceTracker.java @@ -133,6 +133,15 @@ public class ReferenceTracker { "outstanding operations or unreleased resources", true); } + public AssertionError checkClosed(long graceDelayMs) { + Predicate hasOperations = (t) -> t.getOutstandingOperations() > 0; + Predicate hasSubscribers = (t) -> t.getOutstandingSubscribers() > 0; + return check(graceDelayMs, + hasOperations.or(hasSubscribers) + .or(Tracker::isSelectorAlive), + "outstanding operations or unreleased resources", true); + } + // This method is copied from ThreadInfo::toString, but removes the // limit on the stack trace depth (8 frames max) that ThreadInfo::toString // forcefully implement. We want to print all frames for better diagnosis. @@ -369,6 +378,7 @@ public class ReferenceTracker { warning.append("\n\tPending HTTP Requests: " + tracker.getOutstandingHttpRequests()); warning.append("\n\tPending HTTP/1.1 operations: " + tracker.getOutstandingHttpOperations()); warning.append("\n\tPending HTTP/2 streams: " + tracker.getOutstandingHttp2Streams()); + warning.append("\n\tPending HTTP/3 streams: " + tracker.getOutstandingHttp3Streams()); warning.append("\n\tPending WebSocket operations: " + tracker.getOutstandingWebSocketOperations()); warning.append("\n\tPending TCP connections: " + tracker.getOutstandingTcpConnections()); warning.append("\n\tPending Subscribers: " + tracker.getOutstandingSubscribers()); @@ -404,4 +414,23 @@ public class ReferenceTracker { "outstanding unclosed resources", true); return failed; } + + // This is a slightly more permissive check than the default checks, + // it only verifies that all CFs returned by send/sendAsync have been + // completed, and that all opened channels have been closed, and that + // the selector manager thread has exited. + // It doesn't check that all refcounts have reached 0. + // This is typically useful to only check that resources have been released. + public AssertionError checkShutdown(Tracker tracker, long graceDelayMs, boolean dumpThreads) { + Predicate isAlive = Tracker::isSelectorAlive; + Predicate hasPendingRequests = (t) -> t.getOutstandingHttpRequests() > 0; + Predicate hasPendingConnections = (t) -> t.getOutstandingTcpConnections() > 0; + Predicate hasPendingSubscribers = (t) -> t.getOutstandingSubscribers() > 0; + AssertionError failed = check(tracker, graceDelayMs, + isAlive.or(hasPendingRequests) + .or(hasPendingConnections) + .or(hasPendingSubscribers), + "outstanding unclosed resources", dumpThreads); + return failed; + } } diff --git a/test/jdk/java/net/httpclient/RequestBuilderTest.java b/test/jdk/java/net/httpclient/RequestBuilderTest.java index 991b6f9d2d9..97f4793d63c 100644 --- a/test/jdk/java/net/httpclient/RequestBuilderTest.java +++ b/test/jdk/java/net/httpclient/RequestBuilderTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2017, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,19 +30,25 @@ import java.net.URI; import java.net.URISyntaxException; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpOption; import java.util.List; import java.util.Map; import java.util.Set; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublishers; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpRequest.newBuilder; import static java.time.Duration.ofNanos; import static java.time.Duration.ofMinutes; import static java.time.Duration.ofSeconds; import static java.time.Duration.ZERO; -import static java.net.http.HttpClient.Version.HTTP_1_1; -import static java.net.http.HttpClient.Version.HTTP_2; -import static java.net.http.HttpRequest.newBuilder; import static org.testng.Assert.*; import org.testng.annotations.Test; @@ -96,6 +102,8 @@ public class RequestBuilderTest { assertThrows(NPE, () -> builder.setHeader(null, null)); assertThrows(NPE, () -> builder.setHeader("name", null)); assertThrows(NPE, () -> builder.setHeader(null, "value")); + assertThrows(NPE, () -> builder.setOption(null, null)); + assertThrows(NPE, () -> builder.setOption((HttpOption) null, ANY)); assertThrows(NPE, () -> builder.timeout(null)); assertThrows(NPE, () -> builder.POST(null)); assertThrows(NPE, () -> builder.PUT(null)); @@ -402,6 +410,7 @@ public class RequestBuilderTest { .header("A", "B") .POST(BodyPublishers.ofString("")) .timeout(ofSeconds(30)) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) .version(HTTP_1_1); HttpRequest.Builder copy = builder.copy(); assertTrue(builder != copy); @@ -418,6 +427,8 @@ public class RequestBuilderTest { assertEquals(copyRequest.timeout().get(), ofSeconds(30)); assertTrue(copyRequest.version().isPresent()); assertEquals(copyRequest.version().get(), HTTP_1_1); + assertTrue(copyRequest.getOption(H3_DISCOVERY).isPresent()); + assertEquals(copyRequest.getOption(H3_DISCOVERY).get(), HTTP_3_URI_ONLY); // lazy set URI ( maybe builder as a template ) copyRequest = newBuilder().copy().uri(uri).build(); diff --git a/test/jdk/java/net/httpclient/Response1xxTest.java b/test/jdk/java/net/httpclient/Response1xxTest.java index 57aed9407c2..d0148f0d684 100644 --- a/test/jdk/java/net/httpclient/Response1xxTest.java +++ b/test/jdk/java/net/httpclient/Response1xxTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,22 +30,29 @@ import java.net.ServerSocket; import java.net.Socket; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; import java.net.http.HttpTimeoutException; import java.nio.charset.StandardCharsets; import java.time.Duration; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.net.URIBuilder; import org.testng.Assert; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpClient.Version.valueOf; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; /** * @test @@ -55,7 +62,7 @@ import static java.net.http.HttpClient.Version.HTTP_2; * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.HttpServerAdapters * jdk.httpclient.test.lib.http2.Http2TestServer * @run testng/othervm -Djdk.internal.httpclient.debug=true - * * -Djdk.httpclient.HttpClient.log=headers,requests,responses,errors Response1xxTest + * -Djdk.httpclient.HttpClient.log=headers,requests,responses,errors Response1xxTest */ public class Response1xxTest implements HttpServerAdapters { private static final String EXPECTED_RSP_BODY = "Hello World"; @@ -73,6 +80,9 @@ public class Response1xxTest implements HttpServerAdapters { private HttpTestServer https2Server; // h2 private String https2RequestURIBase; + private HttpTestServer http3Server; // h3 + private String http3RequestURIBase; + private final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; @BeforeClass @@ -81,7 +91,7 @@ public class Response1xxTest implements HttpServerAdapters { server = new Http11Server(serverSocket); new Thread(server).start(); http1RequestURIBase = URIBuilder.newBuilder().scheme("http").loopback() - .port(serverSocket.getLocalPort()).build().toString(); + .port(serverSocket.getLocalPort()).path("/http1").build().toString(); http2Server = HttpTestServer.create(HTTP_2); http2Server.addHandler(new Http2Handler(), "/http2/102"); @@ -109,6 +119,19 @@ public class Response1xxTest implements HttpServerAdapters { https2Server.start(); System.out.println("Started (https) HTTP2 server at " + https2Server.getAddress()); + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3Server.addHandler(new Http3Handler(), "/http3/102"); + http3Server.addHandler(new Http3Handler(), "/http3/103"); + http3Server.addHandler(new Http3Handler(), "/http3/100"); + http3Server.addHandler(new Http3Handler(), "/http3/101"); + http3Server.addHandler(new OKHandler(), "/http3/200"); + http3Server.addHandler(new OnlyInformationalHandler(), "/http3/only-informational"); + http3RequestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(http3Server.getAddress().getPort()) + .path("/http3").build().toString(); + http3Server.start(); + System.out.println("Started (https) HTTP3 server at " + http3Server.getAddress()); + } @AfterClass @@ -132,6 +155,10 @@ public class Response1xxTest implements HttpServerAdapters { https2Server.stop(); System.out.println("Stopped (https) HTTP2 server"); } + if (http3Server != null) { + http3Server.stop(); + System.out.println("Stopped (https) HTTP3 server"); + } } } @@ -142,10 +169,11 @@ public class Response1xxTest implements HttpServerAdapters { "Content-Length: " + CONTENT_LENGTH + "\r\n\r\n" + EXPECTED_RSP_BODY; - private static final String REQ_LINE_FOO = "GET /test/foo HTTP/1.1\r\n"; - private static final String REQ_LINE_BAR = "GET /test/bar HTTP/1.1\r\n"; - private static final String REQ_LINE_HELLO = "GET /test/hello HTTP/1.1\r\n"; - private static final String REQ_LINE_BYE = "GET /test/bye HTTP/1.1\r\n"; + private static final String REQ_LINE_102 = "GET /http1/102 HTTP/1.1\r\n"; + private static final String REQ_LINE_103 = "GET /http1/103 HTTP/1.1\r\n"; + private static final String REQ_LINE_100 = "GET /http1/100 HTTP/1.1\r\n"; + private static final String REQ_LINE_101 = "GET /http1/101 HTTP/1.1\r\n"; + private static final String REQ_LINE_ONLY_INFO = "GET /http1/only-informational HTTP/1.1\r\n"; private final ServerSocket serverSocket; @@ -160,6 +188,7 @@ public class Response1xxTest implements HttpServerAdapters { System.out.println("Server running at " + serverSocket); while (!stop) { Socket socket = null; + boolean onlyInfo = false; try { // accept a connection socket = serverSocket.accept(); @@ -180,18 +209,22 @@ public class Response1xxTest implements HttpServerAdapters { System.out.println("Received following request line from client " + socket + " :\n" + requestLine); final int informationalResponseCode; - if (requestLine.startsWith(REQ_LINE_FOO)) { + if (requestLine.startsWith(REQ_LINE_102)) { // we will send intermediate/informational 102 response informationalResponseCode = 102; - } else if (requestLine.startsWith(REQ_LINE_BAR)) { + } else if (requestLine.startsWith(REQ_LINE_103)) { // we will send intermediate/informational 103 response informationalResponseCode = 103; - } else if (requestLine.startsWith(REQ_LINE_HELLO)) { + } else if (requestLine.startsWith(REQ_LINE_100)) { // we will send intermediate/informational 100 response informationalResponseCode = 100; - } else if (requestLine.startsWith(REQ_LINE_BYE)) { + } else if (requestLine.startsWith(REQ_LINE_101)) { // we will send intermediate/informational 101 response informationalResponseCode = 101; + } else if (requestLine.startsWith(REQ_LINE_ONLY_INFO)) { + // we will send intermediate/informational 102 response + informationalResponseCode = 102; + onlyInfo = true; } else { // unexpected client. ignore and close the client System.err.println("Ignoring unexpected request from client " + socket); @@ -215,6 +248,10 @@ public class Response1xxTest implements HttpServerAdapters { os.flush(); System.out.println("Sent response code " + informationalResponseCode + " to client " + socket); + if (onlyInfo) { + Thread.sleep(2000); + i = 1; + } } // now send a final response System.out.println("Now sending 200 response code to client " + socket); @@ -226,8 +263,10 @@ public class Response1xxTest implements HttpServerAdapters { // close the client connection safeClose(socket); // continue accepting any other client connections until we are asked to stop - System.err.println("Ignoring exception in server:"); - t.printStackTrace(); + if (!onlyInfo) { + System.err.println("Ignoring exception in server:"); + t.printStackTrace(); + } } } } @@ -261,6 +300,8 @@ public class Response1xxTest implements HttpServerAdapters { @Override public void handle(final HttpTestExchange exchange) throws IOException { final URI requestURI = exchange.getRequestURI(); + final Version version = exchange.getServerVersion(); + final int informationResponseCode; if (requestURI.getPath().endsWith("/102")) { informationResponseCode = 102; @@ -281,25 +322,29 @@ public class Response1xxTest implements HttpServerAdapters { // be sent multiple times) for (int i = 0; i < 3; i++) { exchange.sendResponseHeaders(informationResponseCode, -1); - System.out.println("Sent " + informationResponseCode + " response code from H2 server"); + System.out.println("Sent " + informationResponseCode + " response code from " + version + " server"); } // now send 200 response try { final byte[] body = EXPECTED_RSP_BODY.getBytes(StandardCharsets.UTF_8); exchange.sendResponseHeaders(200, body.length); - System.out.println("Sent 200 response from H2 server"); + System.out.println("Sent 200 response from " + version + " server"); try (OutputStream os = exchange.getResponseBody()) { os.write(body); } - System.out.println("Sent response body from H2 server"); + System.out.println("Sent response body from " + version + " server"); } catch (Throwable e) { - System.err.println("Failed to send response from HTTP2 handler:"); + System.err.println("Failed to send response from " + version + " handler:"); e.printStackTrace(); throw e; } } } + private static final class Http3Handler extends Http2Handler { + + } + private static class OnlyInformationalHandler implements HttpTestHandler { @Override @@ -338,20 +383,7 @@ public class Response1xxTest implements HttpServerAdapters { .version(HttpClient.Version.HTTP_1_1) .proxy(HttpClient.Builder.NO_PROXY).build(); TRACKER.track(client); - final URI[] requestURIs = new URI[]{ - new URI(http1RequestURIBase + "/test/foo"), - new URI(http1RequestURIBase + "/test/bar"), - new URI(http1RequestURIBase + "/test/hello")}; - for (final URI requestURI : requestURIs) { - final HttpRequest request = HttpRequest.newBuilder(requestURI).build(); - System.out.println("Issuing request to " + requestURI); - final HttpResponse response = client.send(request, - HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); - Assert.assertEquals(response.version(), HttpClient.Version.HTTP_1_1, - "Unexpected HTTP version in response"); - Assert.assertEquals(response.statusCode(), 200, "Unexpected response code"); - Assert.assertEquals(response.body(), EXPECTED_RSP_BODY, "Unexpected response body"); - } + test1xxFor(client, HTTP_1_1, http1RequestURIBase); } /** @@ -364,23 +396,58 @@ public class Response1xxTest implements HttpServerAdapters { final HttpClient client = HttpClient.newBuilder() .version(HTTP_2) .proxy(HttpClient.Builder.NO_PROXY).build(); + test1xxFor(client, HTTP_2, http2RequestURIBase); + } + + /** + * Tests that when a HTTP3 server sends intermediate 1xx response codes and then the final + * response, the client (internally) will ignore those intermediate informational response codes + * and only return the final response to the application + */ + @Test + public void test1xxForHTTP3() throws Exception { + final HttpClient client = newClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY).build(); + test1xxFor(client, HTTP_3, http3RequestURIBase); + } + + private void test1xxFor(HttpClient client, Version version, String baseURI) throws Exception { TRACKER.track(client); final URI[] requestURIs = new URI[]{ - new URI(http2RequestURIBase + "/102"), - new URI(http2RequestURIBase + "/103"), - new URI(http2RequestURIBase + "/100")}; + new URI(baseURI + "/102"), + new URI(baseURI + "/103"), + new URI(baseURI + "/100")}; + var requestBuilder = HttpRequest.newBuilder(); + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } for (final URI requestURI : requestURIs) { - final HttpRequest request = HttpRequest.newBuilder(requestURI).build(); + final HttpRequest request = requestBuilder.copy().uri(requestURI).build(); System.out.println("Issuing request to " + requestURI); final HttpResponse response = client.send(request, - HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); - Assert.assertEquals(response.version(), HTTP_2, + BodyHandlers.ofString(StandardCharsets.UTF_8)); + Assert.assertEquals(response.version(), version, "Unexpected HTTP version in response"); Assert.assertEquals(response.statusCode(), 200, "Unexpected response code"); Assert.assertEquals(response.body(), EXPECTED_RSP_BODY, "Unexpected response body"); } } + /** + * Tests that when a request is issued with a specific request timeout and the server + * responds with intermediate 1xx response code but doesn't respond with a final response within + * the timeout duration, then the application fails with a request timeout + */ + @Test + public void test1xxRequestTimeoutH1() throws Exception { + final HttpClient client = HttpClient.newBuilder() + .version(HTTP_1_1) + .proxy(HttpClient.Builder.NO_PROXY).build(); + test1xxRequestTimeout(client, HTTP_1_1, http1RequestURIBase); + } + /** * Tests that when a request is issued with a specific request timeout and the server @@ -388,22 +455,45 @@ public class Response1xxTest implements HttpServerAdapters { * the timeout duration, then the application fails with a request timeout */ @Test - public void test1xxRequestTimeout() throws Exception { + public void test1xxRequestTimeoutH2() throws Exception { final HttpClient client = HttpClient.newBuilder() .version(HTTP_2) .proxy(HttpClient.Builder.NO_PROXY).build(); + test1xxRequestTimeout(client, HTTP_2, http2RequestURIBase); + } + + /** + * Tests that when a request is issued with a specific request timeout and the server + * responds with intermediate 1xx response code but doesn't respond with a final response within + * the timeout duration, then the application fails with a request timeout + */ + @Test + public void test1xxRequestTimeoutH3() throws Exception { + final HttpClient client = newClientBuilderForH3() + .version(HTTP_3) + .sslContext(sslContext) + .proxy(HttpClient.Builder.NO_PROXY).build(); + test1xxRequestTimeout(client, HTTP_3, http3RequestURIBase); + } + + private void test1xxRequestTimeout(HttpClient client, Version version, String uriBase) throws Exception { TRACKER.track(client); - final URI requestURI = new URI(http2RequestURIBase + "/only-informational"); + final URI requestURI = new URI(uriBase + "/only-informational"); final Duration requestTimeout = Duration.ofSeconds(2); - final HttpRequest request = HttpRequest.newBuilder(requestURI).timeout(requestTimeout) + var requestBuilder = HttpRequest.newBuilder(requestURI); + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + final HttpRequest request = requestBuilder.timeout(requestTimeout) .build(); System.out.println("Issuing request to " + requestURI); // we expect the request to timeout Assert.assertThrows(HttpTimeoutException.class, () -> { - client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + client.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); }); } + /** * Tests that when the HTTP/1.1 server sends a 101 response when the request hasn't asked * for an "Upgrade" then the request fails. @@ -413,13 +503,7 @@ public class Response1xxTest implements HttpServerAdapters { final HttpClient client = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_1_1) .proxy(HttpClient.Builder.NO_PROXY).build(); - TRACKER.track(client); - final URI requestURI = new URI(http1RequestURIBase + "/test/bye"); - final HttpRequest request = HttpRequest.newBuilder(requestURI).build(); - System.out.println("Issuing request to " + requestURI); - // we expect the request to fail because the server sent an unexpected 101 - Assert.assertThrows(ProtocolException.class, - () -> client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8))); + testUnexpected101(client, HTTP_1_1, http1RequestURIBase); } @@ -433,13 +517,20 @@ public class Response1xxTest implements HttpServerAdapters { .version(HTTP_2) .sslContext(sslContext) .proxy(HttpClient.Builder.NO_PROXY).build(); - TRACKER.track(client); - final URI requestURI = new URI(https2RequestURIBase + "/101"); - final HttpRequest request = HttpRequest.newBuilder(requestURI).build(); - System.out.println("Issuing request to " + requestURI); - // we expect the request to fail because the server sent an unexpected 101 - Assert.assertThrows(ProtocolException.class, - () -> client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8))); + testUnexpected101(client, HTTP_2, https2RequestURIBase); + } + + /** + * Tests that when the HTTP2 server (over HTTPS) sends a 101 response when the request + * hasn't asked for an "Upgrade" then the request fails. + */ + @Test + public void testHTT3Unexpected101() throws Exception { + final HttpClient client = newClientBuilderForH3() + .version(HTTP_3) + .sslContext(sslContext) + .proxy(HttpClient.Builder.NO_PROXY).build(); + testUnexpected101(client, HTTP_3, http3RequestURIBase); } /** @@ -451,7 +542,6 @@ public class Response1xxTest implements HttpServerAdapters { final HttpClient client = HttpClient.newBuilder() .version(HTTP_2) .proxy(HttpClient.Builder.NO_PROXY).build(); - TRACKER.track(client); // when using HTTP2 version against a "http://" (non-secure) URI // the HTTP client (implementation) internally initiates a HTTP/1.1 connection // and then does an "Upgrade:" to "h2c". This it does when there isn't already a @@ -462,12 +552,21 @@ public class Response1xxTest implements HttpServerAdapters { // start our testing warmupH2Client(client); // start the actual testing - final URI requestURI = new URI(http2RequestURIBase + "/101"); - final HttpRequest request = HttpRequest.newBuilder(requestURI).build(); + testUnexpected101(client, HTTP_2, http2RequestURIBase); + } + + private void testUnexpected101(HttpClient client, Version version, String baseUri) throws Exception { + TRACKER.track(client); + final URI requestURI = new URI(baseUri + "/101"); + var requestBuilder = HttpRequest.newBuilder(requestURI); + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + final HttpRequest request = requestBuilder.build(); System.out.println("Issuing request to " + requestURI); // we expect the request to fail because the server sent an unexpected 101 Assert.assertThrows(ProtocolException.class, - () -> client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8))); + () -> client.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8))); } // sends a request and expects a 200 response back diff --git a/test/jdk/java/net/httpclient/Response204V2Test.java b/test/jdk/java/net/httpclient/Response204V2Test.java index 610c312b667..9f867d0fb17 100644 --- a/test/jdk/java/net/httpclient/Response204V2Test.java +++ b/test/jdk/java/net/httpclient/Response204V2Test.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -54,7 +54,6 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.ITestContext; @@ -71,14 +70,19 @@ import javax.net.ssl.SSLContext; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; public class Response204V2Test implements HttpServerAdapters { SSLContext sslContext; HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String http2URI; String https2URI; + String http3URI; static final int RESPONSE_CODE = 204; static final int ITERATION_COUNT = 4; @@ -177,6 +181,7 @@ public class Response204V2Test implements HttpServerAdapters { private String[] uris() { return new String[] { + http3URI, http2URI, https2URI, }; @@ -201,9 +206,9 @@ public class Response204V2Test implements HttpServerAdapters { return result; } - private HttpClient makeNewClient() { + private HttpClient makeNewClient(HttpClient.Builder builder) { clientCount.incrementAndGet(); - HttpClient client = HttpClient.newBuilder() + HttpClient client = builder .proxy(HttpClient.Builder.NO_PROXY) .executor(executor) .sslContext(sslContext) @@ -211,14 +216,17 @@ public class Response204V2Test implements HttpServerAdapters { return TRACKER.track(client); } - HttpClient newHttpClient(boolean share) { - if (!share) return makeNewClient(); + HttpClient newHttpClient(String uri, boolean share) { + if (!share) return makeNewClient(newClientBuilderForH3()); HttpClient shared = sharedClient; if (shared != null) return shared; synchronized (this) { shared = sharedClient; if (shared == null) { - shared = sharedClient = makeNewClient(); + var builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + shared = sharedClient = makeNewClient(builder); } return shared; } @@ -241,25 +249,37 @@ public class Response204V2Test implements HttpServerAdapters { } } + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } @Test(dataProvider = "variants") public void test(String uri, boolean sameClient) throws Exception { checkSkip(); - System.out.println("Request to " + uri); + out.println("Request to " + uri); - HttpClient client = newHttpClient(sameClient); + HttpClient client = newHttpClient(uri, sameClient); - HttpRequest request = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest request = newRequestBuilder(URI.create(uri)) .GET() .build(); for (int i = 0; i < ITERATION_COUNT; i++) { - System.out.println("Iteration: " + i); + out.println("Iteration: " + i); HttpResponse response = client.send(request, BodyHandlers.ofString()); int expectedResponse = RESPONSE_CODE; if (response.statusCode() != expectedResponse) - throw new RuntimeException("wrong response code " + Integer.toString(response.statusCode())); + throw new RuntimeException("wrong response code " + response.statusCode()); } - System.out.println("test: DONE"); + if (!sameClient) { + out.println("test: closing test client"); + client.close(); + } + out.println("test: DONE"); } @BeforeTest @@ -279,9 +299,14 @@ public class Response204V2Test implements HttpServerAdapters { https2TestServer.addHandler(handler204, "/https2/test204/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/test204/x"; - serverCount.addAndGet(4); + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(handler204, "/http3/test204/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/test204/x"; + + serverCount.addAndGet(3); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -294,6 +319,7 @@ public class Response204V2Test implements HttpServerAdapters { try { http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { if (sharedClientName != null) { diff --git a/test/jdk/java/net/httpclient/ResponseBodyBeforeError.java b/test/jdk/java/net/httpclient/ResponseBodyBeforeError.java index beed8f86fa6..04ce3d531af 100644 --- a/test/jdk/java/net/httpclient/ResponseBodyBeforeError.java +++ b/test/jdk/java/net/httpclient/ResponseBodyBeforeError.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it diff --git a/test/jdk/java/net/httpclient/ResponsePublisher.java b/test/jdk/java/net/httpclient/ResponsePublisher.java index 3d6cf601b2f..5d90e20c626 100644 --- a/test/jdk/java/net/httpclient/ResponsePublisher.java +++ b/test/jdk/java/net/httpclient/ResponsePublisher.java @@ -31,11 +31,7 @@ * @run testng/othervm/timeout=480 ResponsePublisher */ -import com.sun.net.httpserver.HttpExchange; -import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.internal.net.http.common.OperationTrackers; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; @@ -48,10 +44,8 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; -import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; @@ -71,11 +65,13 @@ import java.util.concurrent.Flow.Publisher; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -84,10 +80,11 @@ import static org.testng.Assert.assertTrue; public class ResponsePublisher implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI_fixed; String httpURI_chunk; String httpsURI_fixed; @@ -96,6 +93,8 @@ public class ResponsePublisher implements HttpServerAdapters { String http2URI_chunk; String https2URI_fixed; String https2URI_chunk; + String http3URI_fixed; + String http3URI_chunk; static final int ITERATION_COUNT = 3; // a shared executor helps reduce the amount of threads created by the test @@ -143,6 +142,16 @@ public class ResponsePublisher implements HttpServerAdapters { @DataProvider(name = "variants") public Object[][] variants() { return new Object[][]{ + { http3URI_fixed, false, OF_PUBLISHER_API }, + { http3URI_chunk, false, OF_PUBLISHER_API }, + { http3URI_fixed, true, OF_PUBLISHER_API }, + { http3URI_chunk, true, OF_PUBLISHER_API }, + + { http3URI_fixed, false, OF_PUBLISHER_TEST }, + { http3URI_chunk, false, OF_PUBLISHER_TEST }, + { http3URI_fixed, true, OF_PUBLISHER_TEST }, + { http3URI_chunk, true, OF_PUBLISHER_TEST }, + { httpURI_fixed, false, OF_PUBLISHER_API }, { httpURI_chunk, false, OF_PUBLISHER_API }, { httpsURI_fixed, false, OF_PUBLISHER_API }, @@ -182,21 +191,33 @@ public class ResponsePublisher implements HttpServerAdapters { } final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; - HttpClient newHttpClient() { - return TRACKER.track(HttpClient.newBuilder() + HttpClient newHttpClient(String uri) { + var builder = uri.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + return TRACKER.track(builder .executor(executor) .sslContext(sslContext) .build()); } + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "variants") public void testExceptions(String uri, boolean sameClient, BHS handlers) throws Exception { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler>> handler = handlers.get(); HttpResponse>> response = client.send(req, handler); @@ -241,9 +262,9 @@ public class ResponsePublisher implements HttpServerAdapters { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler>> handler = handlers.get(); HttpResponse>> response = client.send(req, handler); @@ -270,9 +291,9 @@ public class ResponsePublisher implements HttpServerAdapters { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + HttpRequest req = newRequestBuilder(URI.create(uri)) .build(); BodyHandler>> handler = handlers.get(); // We can reuse our BodySubscribers implementations to subscribe to the @@ -302,9 +323,9 @@ public class ResponsePublisher implements HttpServerAdapters { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri+"/withBody")) + HttpRequest req = newRequestBuilder(URI.create(uri+"/withBody")) .build(); BodyHandler>> handler = handlers.get(); HttpResponse>> response = client.send(req, handler); @@ -331,9 +352,9 @@ public class ResponsePublisher implements HttpServerAdapters { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) - client = newHttpClient(); + client = newHttpClient(uri); - HttpRequest req = HttpRequest.newBuilder(URI.create(uri+"/withBody")) + HttpRequest req = newRequestBuilder(URI.create(uri+"/withBody")) .build(); BodyHandler>> handler = handlers.get(); // We can reuse our BodySubscribers implementations to subscribe to the @@ -470,10 +491,21 @@ public class ResponsePublisher implements HttpServerAdapters { https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed"; https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk"; + // HTTP/3 + HttpTestHandler h3_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h3_chunkedHandler = new HTTP_VariableLengthHandler(); + + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(h3_fixedLengthHandler, "/http3/fixed"); + http3TestServer.addHandler(h3_chunkedHandler, "/http3/chunk"); + http3URI_fixed = "https://" + http3TestServer.serverAuthority() + "/http3/fixed"; + http3URI_chunk = "https://" + http3TestServer.serverAuthority() + "/http3/chunk"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -485,6 +517,7 @@ public class ResponsePublisher implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) { throw fail; diff --git a/test/jdk/java/net/httpclient/RestrictedHeadersTest.java b/test/jdk/java/net/httpclient/RestrictedHeadersTest.java index 7e5881d3cee..f2ddf85b3f5 100644 --- a/test/jdk/java/net/httpclient/RestrictedHeadersTest.java +++ b/test/jdk/java/net/httpclient/RestrictedHeadersTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -47,7 +47,7 @@ public class RestrictedHeadersTest { // This list must be same as impl static Set defaultRestrictedHeaders = - Set.of("connection", "content-length", "expect", "host", "upgrade"); + Set.of("connection", "content-length", "expect", "host", "upgrade", "alt-used"); private static void runDefaultTest() { System.out.println("DEFAULT TEST: no property set"); diff --git a/test/jdk/java/net/httpclient/RetryWithCookie.java b/test/jdk/java/net/httpclient/RetryWithCookie.java index ec83fa5df89..dc046031bd3 100644 --- a/test/jdk/java/net/httpclient/RetryWithCookie.java +++ b/test/jdk/java/net/httpclient/RetryWithCookie.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -34,9 +34,6 @@ * RetryWithCookie */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -48,8 +45,6 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.CookieManager; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; @@ -63,12 +58,14 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -76,14 +73,16 @@ import static org.testng.Assert.assertTrue; public class RetryWithCookie implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 4 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 5 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; static final String MESSAGE = "BasicRedirectTest message body"; static final int ITERATIONS = 3; @@ -91,6 +90,7 @@ public class RetryWithCookie implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { + { http3URI, }, { httpURI, }, { httpsURI, }, { http2URI, }, @@ -101,11 +101,23 @@ public class RetryWithCookie implements HttpServerAdapters { static final AtomicLong requestCounter = new AtomicLong(); final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "positive") void test(String uriString) throws Exception { out.printf("%n---- starting (%s) ----%n", uriString); CookieManager cookieManager = new CookieManager(); - HttpClient client = HttpClient.newBuilder() + var builder = uriString.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + HttpClient client = builder .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) .cookieHandler(cookieManager) @@ -121,7 +133,7 @@ public class RetryWithCookie implements HttpServerAdapters { cookieHeaders.put("Set-Cookie", cookies); cookieManager.put(uri, cookieHeaders); - HttpRequest request = HttpRequest.newBuilder(uri) + HttpRequest request = newRequestBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) .build(); out.println("Initial request: " + request.uri()); @@ -136,7 +148,7 @@ public class RetryWithCookie implements HttpServerAdapters { assertEquals(response.statusCode(), 200); assertEquals(response.body(), MESSAGE); assertEquals(response.headers().allValues("X-Request-Cookie"), cookies); - request = HttpRequest.newBuilder(uri) + request = newRequestBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) .build(); } @@ -164,10 +176,15 @@ public class RetryWithCookie implements HttpServerAdapters { https2TestServer.addHandler(new CookieRetryHandler(), "/https2/cookie/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/cookie/retry"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new CookieRetryHandler(), "/http3/cookie/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/cookie/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -179,6 +196,7 @@ public class RetryWithCookie implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/ShutdownNow.java b/test/jdk/java/net/httpclient/ShutdownNow.java index 045876597a2..a5850b4af61 100644 --- a/test/jdk/java/net/httpclient/ShutdownNow.java +++ b/test/jdk/java/net/httpclient/ShutdownNow.java @@ -45,7 +45,9 @@ import java.io.OutputStream; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpClient.Redirect; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.nio.channels.ClosedChannelException; @@ -73,6 +75,10 @@ import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -90,10 +96,15 @@ public class ShutdownNow implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer h2h3TestServer; // HTTP/3 ( h2 + h3 ) + HttpTestServer h3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String h2h3URI; + String h2h3Head; + String h3URI; static final String MESSAGE = "ShutdownNow message body"; static final int ITERATIONS = 3; @@ -101,10 +112,12 @@ public class ShutdownNow implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { - { httpURI, }, - { httpsURI, }, - { http2URI, }, - { https2URI, }, + { h2h3URI, HTTP_3, h2h3TestServer.h3DiscoveryConfig()}, + { h3URI, HTTP_3, h3TestServer.h3DiscoveryConfig()}, + { httpURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { httpsURI, HTTP_1_1, ALT_SVC}, // do not attempt HTTP/3 + { http2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 + { https2URI, HTTP_2, ALT_SVC}, // do not attempt HTTP/3 }; } @@ -118,6 +131,15 @@ public class ShutdownNow implements HttpServerAdapters { return t; } + void headRequest(HttpClient client) throws Exception { + HttpRequest request = HttpRequest.newBuilder(URI.create(h2h3Head)) + .version(HTTP_2) + .HEAD() + .build(); + var resp = client.send(request, BodyHandlers.discarding()); + assertEquals(resp.statusCode(), 200); + } + static boolean hasExpectedMessage(IOException io) { String message = io.getMessage(); if (message == null) return false; @@ -155,22 +177,28 @@ public class ShutdownNow implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testConcurrent(String uriString) throws Exception { - out.printf("%n---- starting (%s) ----%n", uriString); - HttpClient client = HttpClient.newBuilder() + void testConcurrent(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting concurrent (%s, %s, %s) ----%n%n", uriString, version, config); + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); int step = RANDOM.nextInt(ITERATIONS); try { + if (version == HTTP_3 && config != HTTP_3_URI_ONLY) { + headRequest(client); + } + List>> responses = new ArrayList<>(); for (int i = 0; i < ITERATIONS; i++) { URI uri = URI.create(uriString + "/concurrent/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; @@ -216,11 +244,13 @@ public class ShutdownNow implements HttpServerAdapters { } @Test(dataProvider = "positive") - void testSequential(String uriString) throws Exception { - out.printf("%n---- starting (%s) ----%n", uriString); - HttpClient client = HttpClient.newBuilder() + void testSequential(String uriString, Version version, Http3DiscoveryMode config) throws Exception { + out.printf("%n---- starting sequential (%s, %s, %s) ----%n%n", + uriString, version, config); + HttpClient client = newClientBuilderForH3() .proxy(NO_PROXY) .followRedirects(Redirect.ALWAYS) + .version(version == HTTP_1_1 ? HTTP_2 : version) .sslContext(sslContext) .build(); TRACKER.track(client); @@ -228,10 +258,15 @@ public class ShutdownNow implements HttpServerAdapters { int step = RANDOM.nextInt(ITERATIONS); out.printf("will shutdown client in step %d%n", step); try { + if (version == HTTP_3 && config != HTTP_3_URI_ONLY) { + headRequest(client); + } + for (int i = 0; i < ITERATIONS; i++) { URI uri = URI.create(uriString + "/sequential/iteration-" + i); HttpRequest request = HttpRequest.newBuilder(uri) .header("X-uuid", "uuid-" + requestCounter.incrementAndGet()) + .setOption(H3_DISCOVERY, config) .build(); out.printf("Iteration %d request: %s%n", i, request.uri()); CompletableFuture> responseCF; @@ -304,10 +339,21 @@ public class ShutdownNow implements HttpServerAdapters { https2TestServer.addHandler(new ServerRequestHandler(), "/https2/exec/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/exec/retry"; + h2h3TestServer = HttpTestServer.create(HTTP_3, sslContext); + h2h3TestServer.addHandler(new ServerRequestHandler(), "/h2h3/exec/"); + h2h3URI = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/exec/retry"; + h2h3TestServer.addHandler(new HttpHeadOrGetHandler(), "/h2h3/head/"); + h2h3Head = "https://" + h2h3TestServer.serverAuthority() + "/h2h3/head/"; + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestHandler(), "/h3-only/exec/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3-only/exec/retry"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + h2h3TestServer.start(); + h3TestServer.start(); } @AfterTest @@ -319,6 +365,8 @@ public class ShutdownNow implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + h2h3TestServer.stop(); + h3TestServer.stop(); } finally { if (fail != null) throw fail; } diff --git a/test/jdk/java/net/httpclient/SmokeTest.java b/test/jdk/java/net/httpclient/SmokeTest.java index 56294e8f02f..a10afef562f 100644 --- a/test/jdk/java/net/httpclient/SmokeTest.java +++ b/test/jdk/java/net/httpclient/SmokeTest.java @@ -24,11 +24,12 @@ /* * @test * @bug 8087112 8178699 8338569 - * @modules java.net.http + * @modules java.net.http/jdk.internal.net.http.common * java.logging * jdk.httpserver - * @library /test/lib / + * @library /test/lib /test/jdk/java/net/httpclient/lib / * @build jdk.test.lib.net.SimpleSSLContext ProxyServer + * jdk.httpclient.test.lib.common.TestServerConfigurator * @compile ../../../com/sun/net/httpserver/LogFilter.java * @compile ../../../com/sun/net/httpserver/FileServerHandler.java * @run main/othervm @@ -92,6 +93,8 @@ import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Random; + +import jdk.httpclient.test.lib.common.TestServerConfigurator; import jdk.test.lib.net.SimpleSSLContext; import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import static java.nio.file.StandardOpenOption.WRITE; @@ -808,7 +811,7 @@ public class SmokeTest { ctx = new SimpleSSLContext().get(); sslparams = ctx.getDefaultSSLParameters(); //sslparams.setProtocols(new String[]{"TLSv1.2"}); - s2.setHttpsConfigurator(new Configurator(ctx)); + s2.setHttpsConfigurator(new Configurator(addr.getAddress(), ctx)); s1.start(); s2.start(); @@ -935,14 +938,19 @@ public class SmokeTest { } static class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) { + private final InetAddress serverAddr; + + public Configurator(InetAddress serverAddr, SSLContext ctx) { super(ctx); + this.serverAddr = serverAddr; } - public void configure (HttpsParameters params) { - SSLParameters p = getSSLContext().getDefaultSSLParameters(); + @Override + public void configure(final HttpsParameters params) { + final SSLParameters p = getSSLContext().getDefaultSSLParameters(); + TestServerConfigurator.addSNIMatcher(this.serverAddr, p); //p.setProtocols(new String[]{"TLSv1.2"}); - params.setSSLParameters (p); + params.setSSLParameters(p); } } diff --git a/test/jdk/java/net/httpclient/SpecialHeadersTest.java b/test/jdk/java/net/httpclient/SpecialHeadersTest.java index d3e1f37b592..9ad3114da2a 100644 --- a/test/jdk/java/net/httpclient/SpecialHeadersTest.java +++ b/test/jdk/java/net/httpclient/SpecialHeadersTest.java @@ -92,6 +92,7 @@ import static java.lang.System.out; import static java.net.http.HttpClient.Builder.NO_PROXY; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; import static java.nio.charset.StandardCharsets.US_ASCII; import org.testng.Assert; import static org.testng.Assert.assertEquals; @@ -347,7 +348,7 @@ public class SpecialHeadersTest implements HttpServerAdapters { boolean isInitialRequest = i == 0; boolean isSecure = uri.getScheme().equalsIgnoreCase("https"); - boolean isHTTP2 = resp.version() == HTTP_2; + boolean isHTTP1 = resp.version() == HTTP_1_1; boolean isNotH2CUpgrade = isSecure || (sameClient == true && !isInitialRequest); boolean isDefaultHostHeader = name.equalsIgnoreCase("host") && useDefault; @@ -356,13 +357,13 @@ public class SpecialHeadersTest implements HttpServerAdapters { // header in the response, except the response to the h2c Upgrade // request which will have been sent through HTTP/1.1. - if (isDefaultHostHeader && isHTTP2 && isNotH2CUpgrade) { + if (isDefaultHostHeader && !isHTTP1 && isNotH2CUpgrade) { assertTrue(resp.headers().firstValue("X-" + key).isEmpty()); assertTrue(resp.headers().allValues("X-" + key).isEmpty()); out.println("No X-" + key + " header received, as expected"); } else { String receivedHeaderString = value == null ? null - : resp.headers().firstValue("X-" + key).orElse(null); + : resp.headers().firstValue("X-" + key).get(); out.println("Got X-" + key + ": " + resp.headers().allValues("X-" + key)); if (value != null) { assertEquals(receivedHeaderString, value); @@ -512,7 +513,7 @@ public class SpecialHeadersTest implements HttpServerAdapters { // header in the response, except the response to the h2c Upgrade // request which will have been sent through HTTP/1.1. - if (isDefaultHostHeader && resp.version() == HTTP_2 && isNotH2CUpgrade) { + if (isDefaultHostHeader && resp.version() != HTTP_1_1 && isNotH2CUpgrade) { assertTrue(resp.headers().firstValue("X-" + key).isEmpty()); assertTrue(resp.headers().allValues("X-" + key).isEmpty()); out.println("No X-" + key + " header received, as expected"); diff --git a/test/jdk/java/net/httpclient/SplitResponse.java b/test/jdk/java/net/httpclient/SplitResponse.java index ad04e8a1627..9213901df4b 100644 --- a/test/jdk/java/net/httpclient/SplitResponse.java +++ b/test/jdk/java/net/httpclient/SplitResponse.java @@ -24,11 +24,8 @@ import java.io.IOException; import java.net.SocketException; import java.net.URI; -import java.util.ArrayList; import java.util.EnumSet; -import java.util.HashSet; import java.util.LinkedHashSet; -import java.util.List; import java.util.concurrent.CompletableFuture; import javax.net.ssl.SSLContext; import javax.net.ServerSocketFactory; @@ -43,6 +40,8 @@ import java.util.stream.Stream; import jdk.test.lib.net.SimpleSSLContext; import static java.lang.System.out; import static java.lang.String.format; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.ISO_8859_1; import static java.net.http.HttpResponse.BodyHandlers.ofString; @@ -212,7 +211,10 @@ public class SplitResponse { HttpClient client = newHttpClient(); - HttpRequest request = HttpRequest.newBuilder(uri).version(version).build(); + HttpRequest request = HttpRequest.newBuilder(uri) + .version(version) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); HttpResponse r; CompletableFuture> cf1; diff --git a/test/jdk/java/net/httpclient/StreamCloseTest.java b/test/jdk/java/net/httpclient/StreamCloseTest.java index 2d00a539e80..fcb00188581 100644 --- a/test/jdk/java/net/httpclient/StreamCloseTest.java +++ b/test/jdk/java/net/httpclient/StreamCloseTest.java @@ -27,13 +27,12 @@ * @test * @bug 8257736 * @library /test/jdk/java/net/httpclient/lib + * @library /test/lib * @build jdk.httpclient.test.lib.common.HttpServerAdapters * jdk.httpclient.test.lib.http2.Http2TestServer * @run testng/othervm StreamCloseTest */ -import com.sun.net.httpserver.HttpServer; - import java.io.InputStream; import java.io.IOException; import java.net.http.HttpClient; @@ -42,11 +41,8 @@ import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse.BodyHandlers; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; diff --git a/test/jdk/java/net/httpclient/StreamingBody.java b/test/jdk/java/net/httpclient/StreamingBody.java index 7943968d239..bff6a2b39c4 100644 --- a/test/jdk/java/net/httpclient/StreamingBody.java +++ b/test/jdk/java/net/httpclient/StreamingBody.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -32,14 +32,9 @@ * StreamingBody */ -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -47,7 +42,6 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -56,6 +50,9 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static java.net.http.HttpClient.Builder.NO_PROXY; import static org.testng.Assert.assertEquals; @@ -67,10 +64,12 @@ public class StreamingBody implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; static final String MESSAGE = "StreamingBody message body"; static final int ITERATIONS = 100; @@ -78,6 +77,7 @@ public class StreamingBody implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { + { http3URI, }, { httpURI, }, { httpsURI, }, { http2URI, }, @@ -85,15 +85,27 @@ public class StreamingBody implements HttpServerAdapters { }; } + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "positive") void test(String uriString) throws Exception { out.printf("%n---- starting (%s) ----%n", uriString); URI uri = URI.create(uriString); - HttpRequest request = HttpRequest.newBuilder(uri).build(); + HttpRequest request = newRequestBuilder(uri).build(); for (int i=0; i< ITERATIONS; i++) { out.println("iteration: " + i); - HttpResponse response = HttpClient.newBuilder() + var builder = uriString.contains("/http3/") + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + HttpResponse response = builder .sslContext(sslContext) .proxy(NO_PROXY) .build() @@ -129,14 +141,20 @@ public class StreamingBody implements HttpServerAdapters { http2TestServer = HttpTestServer.create(HTTP_2); http2TestServer.addHandler(new MessageHandler(), "/http2/streamingbody/"); http2URI = "http://" + http2TestServer.serverAuthority() + "/http2/streamingbody/y"; + https2TestServer = HttpTestServer.create(HTTP_2, sslContext); https2TestServer.addHandler(new MessageHandler(), "/https2/streamingbody/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/streamingbody/z"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new MessageHandler(), "/http3/streamingbody/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/streamingbody/z"; + httpTestServer.start(); httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -145,6 +163,7 @@ public class StreamingBody implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); } static class MessageHandler implements HttpTestHandler { diff --git a/test/jdk/java/net/httpclient/TEST.properties b/test/jdk/java/net/httpclient/TEST.properties index b4ca833b5e5..b79f7113952 100644 --- a/test/jdk/java/net/httpclient/TEST.properties +++ b/test/jdk/java/net/httpclient/TEST.properties @@ -1,9 +1,22 @@ -modules=java.base/sun.net.www.http \ +modules=java.base/jdk.internal.util \ + java.base/sun.net.www.http \ java.base/sun.net.www \ java.base/sun.net \ + java.net.http/jdk.internal.net.http \ java.net.http/jdk.internal.net.http.common \ java.net.http/jdk.internal.net.http.frame \ java.net.http/jdk.internal.net.http.hpack \ + java.base/jdk.internal.net.quic \ + java.net.http/jdk.internal.net.http.quic \ + java.net.http/jdk.internal.net.http.quic.packets \ + java.net.http/jdk.internal.net.http.quic.frames \ + java.net.http/jdk.internal.net.http.quic.streams \ + java.net.http/jdk.internal.net.http.http3.streams \ + java.net.http/jdk.internal.net.http.http3.frames \ + java.net.http/jdk.internal.net.http.http3 \ + java.net.http/jdk.internal.net.http.qpack \ + java.net.http/jdk.internal.net.http.qpack.readers \ + java.net.http/jdk.internal.net.http.qpack.writers \ + java.security.jgss \ java.logging \ jdk.httpserver -maxOutputSize = 2500000 diff --git a/test/jdk/java/net/httpclient/TimeoutBasic.java b/test/jdk/java/net/httpclient/TimeoutBasic.java index cdaedf06219..47210209408 100644 --- a/test/jdk/java/net/httpclient/TimeoutBasic.java +++ b/test/jdk/java/net/httpclient/TimeoutBasic.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -32,10 +32,13 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import java.net.http.HttpTimeoutException; import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; import javax.net.ServerSocketFactory; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLServerSocketFactory; +import java.nio.channels.DatagramChannel; import java.time.Duration; import java.util.Arrays; import java.util.List; @@ -43,6 +46,13 @@ import java.util.concurrent.CompletionException; import java.util.function.Function; import static java.lang.System.out; +import static java.net.StandardSocketOptions.SO_REUSEADDR; +import static java.net.StandardSocketOptions.SO_REUSEPORT; +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.H3_DISCOVERY; /** * @test @@ -67,7 +77,7 @@ public class TimeoutBasic { null); static final List VERSIONS = - Arrays.asList(HttpClient.Version.HTTP_2, HttpClient.Version.HTTP_1_1, null); + Arrays.asList(HTTP_2, HTTP_1_1, HTTP_3, null); static final List SCHEMES = List.of("https", "http"); @@ -81,7 +91,7 @@ public class TimeoutBasic { public static void main(String[] args) throws Exception { for (Function m : METHODS) { - for (HttpClient.Version version : List.of(HttpClient.Version.HTTP_1_1)) { + for (HttpClient.Version version : List.of(HTTP_1_1)) { for (HttpClient.Version reqVersion : VERSIONS) { for (String scheme : SCHEMES) { ServerSocketFactory ssf; @@ -141,11 +151,46 @@ public class TimeoutBasic { if (version != null) builder.version(version); HttpClient client = builder.build(); out.printf("%ntest(version=%s, reqVersion=%s, scheme=%s)%n", version, reqVersion, scheme); + DatagramChannel dc = null; try (ServerSocket ss = ssf.createServerSocket()) { ss.setReuseAddress(false); ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); int port = ss.getLocalPort(); - URI uri = new URI(scheme +"://localhost:" + port + "/"); + boolean useAltSvc = false; + if (reqVersion == HTTP_3 && "https".equalsIgnoreCase(scheme)) { + // Prevent the client to connecting to any random server + // opened by other tests on the machine, by opening a + // datagram channel on the same port than the server socket + dc = DatagramChannel.open(); + try { + if (dc.supportedOptions().contains(SO_REUSEADDR)) { + dc.setOption(SO_REUSEADDR, false); + } + if (dc.supportedOptions().contains(SO_REUSEPORT)) { + dc.setOption(SO_REUSEPORT, false); + } + dc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), port)); + } catch (IOException io) { + // failed to bind - presumably the port was already taken + // we will configure the request to use ALT_SVC instead, which + // means no HTTP/3 connection will be attempted + useAltSvc = true; + // cleanup channel + dc.close(); + dc = null; + out.println("HTTP/3 direct connection cannot be tested: " + io); + } + } + + // can only reach here if dc port == port + assert dc == null || ((InetSocketAddress)dc.getLocalAddress()).getPort() == port; + + URI uri = URIBuilder.newBuilder() + .scheme(scheme) + .loopback() + .port(port) + .path("/") + .build(); out.println("--- TESTING Async"); int count = 0; @@ -153,6 +198,13 @@ public class TimeoutBasic { out.println(" with duration of " + duration); HttpRequest request = newRequest(uri, duration, reqVersion, method); if (request == null) continue; + if (useAltSvc) { + // make sure request will be downgraded to HTTP/2 if we + // have not been able to create `dc`. + request = HttpRequest.newBuilder(request, (n,v) -> true) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + } count++; try { HttpResponse resp = client.sendAsync(request, BodyHandlers.discarding()).join(); @@ -163,11 +215,17 @@ public class TimeoutBasic { out.println("Body (should be null): " + resp.body()); throw new RuntimeException("Unexpected response: " + resp.statusCode()); } catch (CompletionException e) { - if (!(e.getCause() instanceof HttpTimeoutException)) { + Throwable x = e; + if (x.getCause() instanceof SSLHandshakeException s) { + if (s.getCause() instanceof HttpTimeoutException) { + x = s; + } + } + if (!(x.getCause() instanceof HttpTimeoutException)) { e.printStackTrace(out); throw new RuntimeException("Unexpected exception: " + e.getCause()); } else { - out.println("Caught expected timeout: " + e.getCause()); + out.println("Caught expected timeout: " + x.getCause()); } } } @@ -179,15 +237,25 @@ public class TimeoutBasic { out.println(" with duration of " + duration); HttpRequest request = newRequest(uri, duration, reqVersion, method); if (request == null) continue; + if (useAltSvc) { + // make sure request will be downgraded to HTTP/2 if we + // have not been able to create `dc`. + request = HttpRequest.newBuilder(request, (n,v) -> true) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + } count++; try { - client.send(request, BodyHandlers.discarding()); + HttpResponse resp = client.send(request, BodyHandlers.discarding()); + throw new RuntimeException("Unexpected response: " + resp.statusCode()); } catch (HttpTimeoutException e) { out.println("Caught expected timeout: " + e); } } assert count >= TIMEOUTS.size() -1; + } finally { + if (dc != null) dc.close(); } } } diff --git a/test/jdk/java/net/httpclient/TlsContextTest.java b/test/jdk/java/net/httpclient/TlsContextTest.java index 707a2e2d7c7..f24007f6d90 100644 --- a/test/jdk/java/net/httpclient/TlsContextTest.java +++ b/test/jdk/java/net/httpclient/TlsContextTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -42,6 +42,7 @@ import static java.lang.System.out; import static java.net.http.HttpClient.Version; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; import static java.net.http.HttpResponse.BodyHandlers.ofString; import static org.testng.Assert.assertEquals; import jdk.test.lib.security.SecurityUtils; @@ -75,7 +76,8 @@ public class TlsContextTest implements HttpServerAdapters { server = SimpleSSLContext.getContext("TLS"); final ExecutorService executor = Executors.newCachedThreadPool(); https2Server = HttpTestServer.of( - new Http2TestServer("localhost", true, 0, executor, 50, null, server, true)); + new Http2TestServer("localhost", true, 0, executor, 50, null, server, true) + .enableH3AltServiceOnSamePort()); https2Server.addHandler(new TlsVersionTestHandler("https", https2Server), "/server/"); https2Server.start(); @@ -89,6 +91,9 @@ public class TlsContextTest implements HttpServerAdapters { { SimpleSSLContext.getContext("TLSv1.2"), HTTP_2, "TLSv1.2" }, { SimpleSSLContext.getContext("TLSv1.1"), HTTP_1_1, "TLSv1.1" }, { SimpleSSLContext.getContext("TLSv1.1"), HTTP_2, "TLSv1.1" }, + { SimpleSSLContext.getContext("TLSv1.3"), HTTP_3, "TLSv1.3" }, + { SimpleSSLContext.getContext("TLSv1.2"), HTTP_3, "TLSv1.2" }, + { SimpleSSLContext.getContext("TLSv1.1"), HTTP_3, "TLSv1.1" }, }; } @@ -99,21 +104,32 @@ public class TlsContextTest implements HttpServerAdapters { public void testVersionProtocols(SSLContext context, Version version, String expectedProtocol) throws Exception { - HttpClient client = HttpClient.newBuilder() - .sslContext(context) - .version(version) + // for HTTP/3 we won't accept to set the version to HTTP/3 on the + // client if we don't have TLSv1.3; We will set the version + // on the request instead in that case. + var builder = version == HTTP_3 ? newClientBuilderForH3() + : HttpClient.newBuilder().version(version); + var reqBuilder = HttpRequest.newBuilder(new URI(https2URI)); + + HttpClient client = builder.sslContext(context) .build(); - HttpRequest request = HttpRequest.newBuilder(new URI(https2URI)) - .GET() - .build(); + if (version == HTTP_3) { + // warmup to obtain AltService + client.send(reqBuilder.version(HTTP_2).GET().build(), ofString()); + reqBuilder = reqBuilder.version(HTTP_3); + } + + HttpRequest request = reqBuilder.GET().build(); for (int i = 0; i < ITERATIONS; i++) { HttpResponse response = client.send(request, ofString()); - testAllProtocols(response, expectedProtocol); + testAllProtocols(response, expectedProtocol, version); } + client.close(); } private void testAllProtocols(HttpResponse response, - String expectedProtocol) throws Exception { + String expectedProtocol, + Version clientVersion) throws Exception { String protocol = response.sslSession().get().getProtocol(); int statusCode = response.statusCode(); Version version = response.version(); @@ -121,7 +137,12 @@ public class TlsContextTest implements HttpServerAdapters { out.println("The protocol negotiated is :" + protocol); assertEquals(statusCode, 200); assertEquals(protocol, expectedProtocol); - assertEquals(version, expectedProtocol.equals("TLSv1.1") ? HTTP_1_1 : HTTP_2); + if (clientVersion == HTTP_3) { + assertEquals(version, expectedProtocol.equals("TLSv1.1") ? HTTP_1_1 : + expectedProtocol.equals("TLSv1.2") ? HTTP_2 : HTTP_3); + } else { + assertEquals(version, expectedProtocol.equals("TLSv1.1") ? HTTP_1_1 : HTTP_2); + } } @AfterTest @@ -143,8 +164,10 @@ public class TlsContextTest implements HttpServerAdapters { try (InputStream is = t.getRequestBody(); OutputStream os = t.getResponseBody()) { byte[] bytes = is.readAllBytes(); - t.sendResponseHeaders(200, 10); - os.write(bytes); + t.sendResponseHeaders(200, bytes.length); + if (bytes.length > 0) { + os.write(bytes); + } } } } diff --git a/test/jdk/java/net/httpclient/UnauthorizedTest.java b/test/jdk/java/net/httpclient/UnauthorizedTest.java index e2124f2a7ed..0c8c2b1d2cd 100644 --- a/test/jdk/java/net/httpclient/UnauthorizedTest.java +++ b/test/jdk/java/net/httpclient/UnauthorizedTest.java @@ -30,7 +30,8 @@ * for the client. If no authenticator is configured the client * should simply let the caller deal with the unauthorized response. * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.httpclient.test.lib.common.HttpServerAdapters jdk.test.lib.net.SimpleSSLContext + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.net.SimpleSSLContext ReferenceTracker * @run testng/othervm * -Djdk.httpclient.HttpClient.log=headers * UnauthorizedTest @@ -60,6 +61,9 @@ import jdk.httpclient.test.lib.common.HttpServerAdapters; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -71,10 +75,12 @@ public class UnauthorizedTest implements HttpServerAdapters { HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; HttpClient authClient; HttpClient noAuthClient; @@ -97,6 +103,12 @@ public class UnauthorizedTest implements HttpServerAdapters { @DataProvider(name = "all") public Object[][] positive() { return new Object[][] { + { http3URI + "/server", UNAUTHORIZED, true, ref(authClient)}, + { http3URI + "/server", UNAUTHORIZED, false, ref(authClient)}, + { http3URI + "/server", UNAUTHORIZED, true, ref(noAuthClient)}, + { http3URI + "/server", UNAUTHORIZED, false, ref(noAuthClient)}, + + { httpURI + "/server", UNAUTHORIZED, true, ref(authClient)}, { httpsURI + "/server", UNAUTHORIZED, true, ref(authClient)}, { http2URI + "/server", UNAUTHORIZED, true, ref(authClient)}, @@ -137,6 +149,15 @@ public class UnauthorizedTest implements HttpServerAdapters { static final Authenticator authenticator = new Authenticator() { }; + private HttpRequest.Builder newRequestBuilder(URI uri) { + var builder = HttpRequest.newBuilder(uri); + if (uri.getRawPath().contains("/http3/")) { + builder = builder.version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } + return builder; + } + @Test(dataProvider = "all") void test(String uriString, int code, boolean async, WeakReference clientRef) throws Throwable { HttpClient client = clientRef.get(); @@ -145,8 +166,7 @@ public class UnauthorizedTest implements HttpServerAdapters { client.authenticator().isPresent() ? "authClient" : "noAuthClient"); URI uri = URI.create(uriString); - HttpRequest.Builder requestBuilder = HttpRequest - .newBuilder(uri) + HttpRequest.Builder requestBuilder = newRequestBuilder(uri) .GET(); HttpRequest request = requestBuilder.build(); @@ -163,6 +183,7 @@ public class UnauthorizedTest implements HttpServerAdapters { try { response = client.sendAsync(request, BodyHandlers.ofString()).get(); } catch (ExecutionException ex) { + ex.printStackTrace(); throw ex.getCause(); } } @@ -204,13 +225,17 @@ public class UnauthorizedTest implements HttpServerAdapters { https2TestServer.addHandler(new UnauthorizedHandler(), "/https2/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2"; - authClient = HttpClient.newBuilder() + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new UnauthorizedHandler(), "/http3/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3"; + + authClient = newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .sslContext(sslContext) .authenticator(authenticator) .build(); - noAuthClient = HttpClient.newBuilder() + noAuthClient = newClientBuilderForH3() .proxy(HttpClient.Builder.NO_PROXY) .sslContext(sslContext) .build(); @@ -219,6 +244,7 @@ public class UnauthorizedTest implements HttpServerAdapters { httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); } @AfterTest @@ -236,6 +262,7 @@ public class UnauthorizedTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); if (error != null) throw error; } diff --git a/test/jdk/java/net/httpclient/UserAuthWithAuthenticator.java b/test/jdk/java/net/httpclient/UserAuthWithAuthenticator.java index 97be90f8587..ee28980c035 100644 --- a/test/jdk/java/net/httpclient/UserAuthWithAuthenticator.java +++ b/test/jdk/java/net/httpclient/UserAuthWithAuthenticator.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -21,34 +21,12 @@ * questions. */ -/** - * @test - * @bug 8326949 - * @summary Authorization header is removed when a proxy Authenticator is set - * @library /test/lib /test/jdk/java/net/httpclient /test/jdk/java/net/httpclient/lib - * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.HttpServerAdapters - * jdk.httpclient.test.lib.http2.Http2TestServer - * jdk.test.lib.net.IPSupport - * - * @modules java.net.http/jdk.internal.net.http.common - * java.net.http/jdk.internal.net.http.frame - * java.net.http/jdk.internal.net.http.hpack - * java.logging - * java.base/sun.net.www.http - * java.base/sun.net.www - * java.base/sun.net - * - * @run main/othervm UserAuthWithAuthenticator - */ - import java.io.*; import java.net.*; import java.net.http.HttpClient; -import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; -import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.net.http.HttpResponse; -import java.net.http.HttpResponse.BodyHandlers; import javax.net.ssl.*; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -56,132 +34,165 @@ import java.util.regex.*; import java.util.*; import jdk.test.lib.net.SimpleSSLContext; import jdk.test.lib.net.URIBuilder; -import jdk.test.lib.net.IPSupport; import jdk.httpclient.test.lib.common.HttpServerAdapters; import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; import jdk.httpclient.test.lib.http2.Http2TestServer; import com.sun.net.httpserver.BasicAuthenticator; - -import jdk.test.lib.net.URIBuilder; +import org.junit.jupiter.api.Test; import static java.nio.charset.StandardCharsets.US_ASCII; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; -public class UserAuthWithAuthenticator { - private static final String AUTH_PREFIX = "Basic "; +/* + * @test + * @bug 8326949 + * @summary Authorization header is removed when a proxy Authenticator is set + * @library /test/lib /test/jdk/java/net/httpclient /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.net.IPSupport + * @run junit UserAuthWithAuthenticator + */ +class UserAuthWithAuthenticator { - static class AuthTestHandler implements HttpTestHandler { - volatile String authValue; - final String response = "Hello world"; + private static final class AuthTestHandler implements HttpTestHandler { + private volatile String authHeaderValue; @Override public void handle(HttpTestExchange t) throws IOException { try (InputStream is = t.getRequestBody(); OutputStream os = t.getResponseBody()) { - byte[] bytes = is.readAllBytes(); - authValue = t.getRequestHeaders() + is.readAllBytes(); + authHeaderValue = t.getRequestHeaders() .firstValue("Authorization") - .orElse(AUTH_PREFIX) - .substring(AUTH_PREFIX.length()); + .orElse(""); + String response = "Hello world"; t.sendResponseHeaders(200, response.length()); os.write(response.getBytes(US_ASCII)); t.close(); } } - String authValue() {return authValue;} } - // if useHeader is true, we expect the Authenticator was not called - // and the user set header used. If false, Authenticator must - // be called and the user set header not used. + @Test + void h2Test() throws Exception { + h2Test(true, true); + h2Test(false, true); + h2Test(true, false); + } - // If rightPassword is true we expect the authentication to succeed and 200 OK - // If false, then an error should be returned. - - static void h2Test(final boolean useHeader, boolean rightPassword) throws Exception { - SSLContext ctx; - HttpTestServer h2s = null; - HttpClient client = null; - ExecutorService ex=null; - try { - ctx = new SimpleSSLContext().get(); - ex = Executors.newCachedThreadPool(); - InetAddress addr = InetAddress.getLoopbackAddress(); - - h2s = HttpTestServer.of(new Http2TestServer(addr, "::1", true, 0, ex, - 10, null, ctx, false)); - AuthTestHandler h = new AuthTestHandler(); - var context = h2s.addHandler(h, "/test1"); - context.setAuthenticator(new BasicAuthenticator("realm") { - public boolean checkCredentials(String username, String password) { - if (useHeader) { - return username.equals("user") && password.equals("pwd"); - } else { - return username.equals("serverUser") && password.equals("serverPwd"); - } - } - }); - h2s.start(); - - int port = h2s.getAddress().getPort(); - ServerAuth sa = new ServerAuth(); - var plainCreds = rightPassword? "user:pwd" : "user:wrongPwd"; - var encoded = java.util.Base64.getEncoder().encodeToString(plainCreds.getBytes(US_ASCII)); - - URI uri = URIBuilder.newBuilder() - .scheme("https") - .host(addr.getHostAddress()) - .port(port) - .path("/test1/foo.txt") - .build(); - - HttpClient.Builder builder = HttpClient.newBuilder() - .sslContext(ctx) - .executor(ex); - - builder.authenticator(sa); - client = builder.build(); - - HttpRequest req = HttpRequest.newBuilder(uri) - .version(HttpClient.Version.HTTP_2) - .header(useHeader ? "Authorization" : "X-Ignore", AUTH_PREFIX + encoded) - .GET() - .build(); - - HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.ofString()); - if (!useHeader) { - assertTrue(resp.statusCode() == 200, "Expected 200 response"); - assertTrue(!h.authValue().equals(encoded), "Expected user set header to not be set"); - assertTrue(h.authValue().equals(sa.authValue()), "Expected auth value from Authenticator"); - assertTrue(sa.wasCalled(), "Expected authenticator to be called"); - System.out.println("h2Test: using authenticator OK"); - } else if (rightPassword) { - assertTrue(resp.statusCode() == 200, "Expected 200 response"); - assertTrue(h.authValue().equals(encoded), "Expected user set header to be set"); - assertTrue(!sa.wasCalled(), "Expected authenticator not to be called"); - System.out.println("h2Test: using user set header OK"); - } else { - assertTrue(resp.statusCode() == 401, "Expected 401 response"); - assertTrue(!sa.wasCalled(), "Expected authenticator not to be called"); - System.out.println("h2Test: using user set header with wrong password OK"); - } - } finally { - if (h2s != null) - h2s.stop(); - if (client != null) - client.close(); - if (ex != null) - ex.shutdown(); + private static void h2Test(final boolean useHeader, boolean rightPassword) throws Exception { + SSLContext sslContext = new SimpleSSLContext().get(); + try (ExecutorService executor = Executors.newCachedThreadPool(); + HttpTestServer server = HttpTestServer.of(new Http2TestServer( + InetAddress.getLoopbackAddress(), + "::1", + true, + 0, + executor, + 10, + null, + sslContext, + false)); + HttpClient client = HttpClient.newBuilder() + .sslContext(sslContext) + .executor(executor) + .authenticator(new ServerAuth()) + .build()) { + hXTest(useHeader, rightPassword, server, client, HttpClient.Version.HTTP_2); } } - static final String data = "0123456789"; + @Test + void h3Test() throws Exception { + h3Test(true, true); + h3Test(false, true); + h3Test(true, false); + } - static final String data1 = "ABCDEFGHIJKL"; + private static void h3Test(final boolean useHeader, boolean rightPassword) throws Exception { + SSLContext sslContext = new SimpleSSLContext().get(); + try (ExecutorService executor = Executors.newCachedThreadPool(); + HttpTestServer server = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext, executor); + HttpClient client = HttpServerAdapters.createClientBuilderForH3() + .sslContext(sslContext) + .executor(executor) + .authenticator(new ServerAuth()) + .build()) { + hXTest(useHeader, rightPassword, server, client, HttpClient.Version.HTTP_3); + } + } - static final String[] proxyResponses = { + /** + * @param useHeader If {@code true}, we expect the authenticator was not called and the user set header used. + * If {@code false}, authenticator must be called and the user set header discarded. + * @param rightPassword If {@code true}, we expect the authentication to succeed with {@code 200 OK}. + * If {@code false}, then an error should be returned. + */ + private static void hXTest( + final boolean useHeader, + boolean rightPassword, + HttpTestServer server, + HttpClient client, + HttpClient.Version version) + throws Exception { + + AuthTestHandler handler = new AuthTestHandler(); + var context = server.addHandler(handler, "/test1"); + context.setAuthenticator(new BasicAuthenticator("realm") { + public boolean checkCredentials(String username, String password) { + if (useHeader) { + return username.equals("user") && password.equals("pwd"); + } else { + return username.equals("serverUser") && password.equals("serverPwd"); + } + } + }); + server.start(); + + URI uri = URIBuilder.newBuilder() + .scheme("https") + .host(server.getAddress().getAddress()) + .port(server.getAddress().getPort()) + .path("/test1/foo.txt") + .build(); + + var authHeaderValue = authHeaderValue("user", rightPassword ? "pwd" : "wrongPwd"); + HttpRequest req = HttpRequest.newBuilder(uri) + .version(version) + .header(useHeader ? "Authorization" : "X-Ignore", authHeaderValue) + .GET() + .build(); + + HttpResponse resp = client.send(req, HttpResponse.BodyHandlers.ofString()); + var sa = (ServerAuth) client.authenticator().orElseThrow(); + if (!useHeader) { + assertEquals(200, resp.statusCode(), "Expected 200 response"); + assertNotEquals(handler.authHeaderValue, authHeaderValue, "Expected user set header to not be set"); + assertEquals(handler.authHeaderValue, ServerAuth.AUTH_HEADER_VALUE, "Expected auth value from Authenticator"); + assertTrue(sa.called, "Expected authenticator to be called"); + } else if (rightPassword) { + assertEquals(200, resp.statusCode(), "Expected 200 response"); + assertEquals(authHeaderValue, handler.authHeaderValue, "Expected user set header to be set"); + assertFalse(sa.called, "Expected authenticator not to be called"); + } else { + assertEquals(401, resp.statusCode(), "Expected 401 response"); + assertFalse(sa.called, "Expected authenticator not to be called"); + } + + } + + private static final String data = "0123456789"; + + private static final String data1 = "ABCDEFGHIJKL"; + + private static final String[] proxyResponses = { "HTTP/1.1 407 Proxy Authentication Required\r\n"+ "Content-Length: 0\r\n" + "Proxy-Authenticate: Basic realm=\"Access to the proxy\"\r\n\r\n" @@ -192,7 +203,7 @@ public class UserAuthWithAuthenticator { "Content-Length: " + data.length() + "\r\n\r\n" + data }; - static final String[] proxyWithErrorResponses = { + private static final String[] proxyWithErrorResponses = { "HTTP/1.1 407 Proxy Authentication Required\r\n"+ "Content-Length: 0\r\n" + "Proxy-Authenticate: Basic realm=\"Access to the proxy\"\r\n\r\n" @@ -202,14 +213,14 @@ public class UserAuthWithAuthenticator { "Proxy-Authenticate: Basic realm=\"Access to the proxy\"\r\n\r\n" }; - static final String[] serverResponses = { + private static final String[] serverResponses = { "HTTP/1.1 200 OK\r\n"+ "Date: Mon, 15 Jan 2001 12:18:21 GMT\r\n" + "Server: Apache/1.3.14 (Unix)\r\n" + "Content-Length: " + data1.length() + "\r\n\r\n" + data1 }; - static final String[] authenticatorResponses = { + private static final String[] authenticatorResponses = { "HTTP/1.1 401 Authentication Required\r\n"+ "Content-Length: 0\r\n" + "WWW-Authenticate: Basic realm=\"Access to the server\"\r\n\r\n" @@ -220,119 +231,97 @@ public class UserAuthWithAuthenticator { "Content-Length: " + data1.length() + "\r\n\r\n" + data1 }; - public static void main(String[] args) throws Exception { - testServerOnly(); - testServerWithProxy(); - testServerWithProxyError(); - testServerOnlyAuthenticator(); - h2Test(true, true); - h2Test(false, true); - h2Test(true, false); - } - - static void testServerWithProxy() throws IOException, InterruptedException { - Mocker proxyMock = new Mocker(proxyResponses); - proxyMock.start(); + @Test + void h1TestServerWithProxy() throws IOException, InterruptedException { ProxyAuth p = new ProxyAuth(); - try (var client = HttpClient.newBuilder() + try (var proxyMock = new Mocker(proxyResponses); + var client = HttpClient.newBuilder() .version(java.net.http.HttpClient.Version.HTTP_1_1) .proxy(new ProxySel(proxyMock.getPort())) .authenticator(p) .build()) { - var plainCreds = "user:pwd"; - var encoded = java.util.Base64.getEncoder().encodeToString(plainCreds.getBytes(US_ASCII)); + var authHeaderValue = authHeaderValue("user", "pwd"); var request = HttpRequest.newBuilder().uri(URI.create("http://127.0.0.1/some_url")) .setHeader("User-Agent", "myUserAgent") - .setHeader("Authorization", AUTH_PREFIX + encoded) + .setHeader("Authorization", authHeaderValue) .build(); var response = client.send(request, HttpResponse.BodyHandlers.ofString()); assertEquals(200, response.statusCode()); - assertTrue(p.wasCalled(), "Proxy authenticator was not called"); + assertTrue(p.called, "Proxy authenticator was not called"); assertEquals(data, response.body()); - var proxyStr = proxyMock.getRequest(1); + var proxyStr = proxyMock.requests.get(1); assertContains(proxyStr, "/some_url"); - assertPattern(".*^Proxy-Authorization:.*Basic " + encoded + ".*", proxyStr); + assertPattern(".*^Proxy-Authorization:.*\\Q" + authHeaderValue + "\\E.*", proxyStr); assertPattern(".*^User-Agent:.*myUserAgent.*", proxyStr); assertPattern(".*^Authorization:.*Basic.*", proxyStr); - System.out.println("testServerWithProxy: OK"); - } finally { - proxyMock.stopMocker(); } } - static void testServerWithProxyError() throws IOException, InterruptedException { - Mocker proxyMock = new Mocker(proxyWithErrorResponses); - proxyMock.start(); + @Test + void h1TestServerWithProxyError() throws IOException, InterruptedException { ProxyAuth p = new ProxyAuth(); - try (var client = HttpClient.newBuilder() + try (var proxyMock = new Mocker(proxyWithErrorResponses); + var client = HttpClient.newBuilder() .version(java.net.http.HttpClient.Version.HTTP_1_1) .proxy(new ProxySel(proxyMock.getPort())) .authenticator(p) .build()) { - var badCreds = "user:wrong"; - var encoded1 = java.util.Base64.getEncoder().encodeToString(badCreds.getBytes(US_ASCII)); + var authHeaderValue = authHeaderValue("user", "wrong"); var request = HttpRequest.newBuilder().uri(URI.create("http://127.0.0.1/some_url")) .setHeader("User-Agent", "myUserAgent") - .setHeader("Proxy-Authorization", AUTH_PREFIX + encoded1) + .setHeader("Proxy-Authorization", authHeaderValue) .build(); var response = client.send(request, HttpResponse.BodyHandlers.ofString()); - var proxyStr = proxyMock.getRequest(0); + var proxyStr = proxyMock.requests.getFirst(); assertEquals(407, response.statusCode()); - assertPattern(".*^Proxy-Authorization:.*Basic " + encoded1 + ".*", proxyStr); - assertTrue(!p.wasCalled(), "Proxy Auth should not have been called"); - System.out.println("testServerWithProxyError: OK"); - } finally { - proxyMock.stopMocker(); + assertPattern(".*^Proxy-Authorization:.*\\Q" + authHeaderValue + "\\E.*", proxyStr); + assertFalse(p.called, "Proxy Auth should not have been called"); } } - static void testServerOnly() throws IOException, InterruptedException { - Mocker serverMock = new Mocker(serverResponses); - serverMock.start(); - try (var client = HttpClient.newBuilder() + @Test + void h1TestServerOnly() throws IOException, InterruptedException { + try (var serverMock = new Mocker(serverResponses); + var client = HttpClient.newBuilder() .version(java.net.http.HttpClient.Version.HTTP_1_1) .build()) { - var plainCreds = "user:pwd"; - var encoded = java.util.Base64.getEncoder().encodeToString(plainCreds.getBytes(US_ASCII)); + var authHeaderValue = authHeaderValue("user", "pwd"); var request = HttpRequest.newBuilder().uri(URI.create(serverMock.baseURL() + "/some_serv_url")) .setHeader("User-Agent", "myUserAgent") - .setHeader("Authorization", AUTH_PREFIX + encoded) + .setHeader("Authorization", authHeaderValue) .build(); var response = client.send(request, HttpResponse.BodyHandlers.ofString()); assertEquals(200, response.statusCode()); assertEquals(data1, response.body()); - var serverStr = serverMock.getRequest(0); + var serverStr = serverMock.requests.getFirst(); assertContains(serverStr, "/some_serv_url"); assertPattern(".*^User-Agent:.*myUserAgent.*", serverStr); - assertPattern(".*^Authorization:.*Basic " + encoded + ".*", serverStr); - System.out.println("testServerOnly: OK"); - } finally { - serverMock.stopMocker(); + assertPattern(".*^Authorization:.*\\Q" + authHeaderValue + "\\E.*", serverStr); } } - // This is effectively a regression test for existing behavior - static void testServerOnlyAuthenticator() throws IOException, InterruptedException { - Mocker serverMock = new Mocker(authenticatorResponses); - serverMock.start(); - try (var client = HttpClient.newBuilder() + /** + * A regression test for existing behavior. + */ + @Test + void h1TestServerOnlyAuthenticator() throws IOException, InterruptedException { + try (var serverMock = new Mocker(authenticatorResponses); + var client = HttpClient.newBuilder() .version(java.net.http.HttpClient.Version.HTTP_1_1) .authenticator(new ServerAuth()) .build()) { // credentials set in the server authenticator - var plainCreds = "serverUser:serverPwd"; - var encoded = java.util.Base64.getEncoder().encodeToString(plainCreds.getBytes(US_ASCII)); var request = HttpRequest.newBuilder().uri(URI.create(serverMock.baseURL() + "/some_serv_url")) .setHeader("User-Agent", "myUserAgent") .build(); @@ -341,42 +330,43 @@ public class UserAuthWithAuthenticator { assertEquals(200, response.statusCode()); assertEquals(data1, response.body()); - var serverStr = serverMock.getRequest(1); + var serverStr = serverMock.requests.get(1); assertContains(serverStr, "/some_serv_url"); assertPattern(".*^User-Agent:.*myUserAgent.*", serverStr); - assertPattern(".*^Authorization:.*Basic " + encoded + ".*", serverStr); - System.out.println("testServerOnlyAuthenticator: OK"); - } finally { - serverMock.stopMocker(); + assertPattern(".*^Authorization:.*\\Q" + authHeaderValue("serverUser", "serverPwd") + "\\E.*", serverStr); } } - static void close(Closeable... clarray) { - for (Closeable c : clarray) { - try { - c.close(); - } catch (Exception e) {} - } - } + private static final class Mocker extends Thread implements AutoCloseable { + private final ServerSocket ss; + private final String[] responses; + private final List requests; + private volatile InputStream in; + private volatile OutputStream out; + private volatile Socket s = null; - static class Mocker extends Thread { - final ServerSocket ss; - final String[] responses; - volatile List requests; - volatile InputStream in; - volatile OutputStream out; - volatile Socket s = null; - - public Mocker(String[] responses) throws IOException { + private Mocker(String[] responses) throws IOException { this.ss = new ServerSocket(0, 0, InetAddress.getLoopbackAddress()); this.responses = responses; this.requests = new LinkedList<>(); + start(); } - public void stopMocker() { + @Override + public void close() { close(ss, s, in, out); } + private static void close(Closeable... clarray) { + for (Closeable c : clarray) { + try { + c.close(); + } catch (Exception e) { + // Do nothing + } + } + } + public int getPort() { return ss.getLocalPort(); } @@ -404,15 +394,12 @@ public class UserAuthWithAuthenticator { in = s.getInputStream(); out = s.getOutputStream(); } - req += (char)x; + // noinspection StringConcatenationInLoop + req += (char) x; } return req; } - public String getRequest(int i) { - return requests.get(i); - } - public void run() { try { int index=0; @@ -424,15 +411,17 @@ public class UserAuthWithAuthenticator { out.write(responses[index++].getBytes(US_ASCII)); } } catch (Exception e) { + // noinspection CallToPrintStackTrace e.printStackTrace(); } } } - static class ProxySel extends ProxySelector { - final int port; + private static final class ProxySel extends ProxySelector { - ProxySel(int port) { + private final int port; + + private ProxySel(int port) { this.port = port; } @Override @@ -446,7 +435,8 @@ public class UserAuthWithAuthenticator { } - static class ProxyAuth extends Authenticator { + private static final class ProxyAuth extends Authenticator { + private volatile boolean called = false; @Override @@ -455,16 +445,17 @@ public class UserAuthWithAuthenticator { return new PasswordAuthentication("proxyUser", "proxyPwd".toCharArray()); } - boolean wasCalled() { - return called; - } } - static class ServerAuth extends Authenticator { + private static final class ServerAuth extends Authenticator { + private volatile boolean called = false; - private static String USER = "serverUser"; - private static String PASS = "serverPwd"; + private static final String USER = "serverUser"; + + private static final String PASS = "serverPwd"; + + private static final String AUTH_HEADER_VALUE = authHeaderValue(USER, PASS); @Override protected PasswordAuthentication getPasswordAuthentication() { @@ -476,49 +467,21 @@ public class UserAuthWithAuthenticator { return new PasswordAuthentication(USER, PASS.toCharArray()); } - String authValue() { - var plainCreds = USER + ":" + PASS; - return java.util.Base64.getEncoder().encodeToString(plainCreds.getBytes(US_ASCII)); - } - - boolean wasCalled() { - return called; - } } - static void assertTrue(boolean assertion, String failMsg) { - if (!assertion) { - throw new RuntimeException(failMsg); - } + private static String authHeaderValue(String username, String password) { + String credentials = username + ':' + password; + return "Basic " + java.util.Base64.getEncoder().encodeToString(credentials.getBytes(US_ASCII)); } - static void assertEquals(int a, int b) { - if (a != b) { - String msg = String.format("Error: expected %d Got %d", a, b); - throw new RuntimeException(msg); - } + private static void assertContains(String container, String containee) { + assertTrue(container.contains(containee), String.format("Error: expected %s Got %s", container, containee)); } - static void assertEquals(String s1, String s2) { - if (!s1.equals(s2)) { - String msg = String.format("Error: expected %s Got %s", s1, s2); - throw new RuntimeException(msg); - } - } - - static void assertContains(String container, String containee) { - if (!container.contains(containee)) { - String msg = String.format("Error: expected %s Got %s", container, containee); - throw new RuntimeException(msg); - } - } - - static void assertPattern(String pattern, String candidate) { + private static void assertPattern(String pattern, String candidate) { Pattern pat = Pattern.compile(pattern, Pattern.DOTALL | Pattern.MULTILINE); Matcher matcher = pat.matcher(candidate); - if (!matcher.matches()) { - String msg = String.format("Error: expected %s Got %s", pattern, candidate); - throw new RuntimeException(msg); - } + assertTrue(matcher.matches(), String.format("Error: expected %s Got %s", pattern, candidate)); } + } diff --git a/test/jdk/java/net/httpclient/UserCookieTest.java b/test/jdk/java/net/httpclient/UserCookieTest.java index b8de5f97955..50a13e48322 100644 --- a/test/jdk/java/net/httpclient/UserCookieTest.java +++ b/test/jdk/java/net/httpclient/UserCookieTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -66,11 +66,7 @@ import java.util.stream.Stream; import javax.net.ServerSocketFactory; import javax.net.ssl.SSLContext; import jdk.httpclient.test.lib.common.HttpServerAdapters; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import com.sun.net.httpserver.HttpServer; -import com.sun.net.httpserver.HttpsConfigurator; -import com.sun.net.httpserver.HttpsServer; import jdk.test.lib.net.SimpleSSLContext; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; @@ -80,22 +76,27 @@ import org.testng.annotations.Test; import static java.lang.System.out; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; public class UserCookieTest implements HttpServerAdapters { SSLContext sslContext; - HttpTestServer httpTestServer; // HTTP/1.1 [ 6 servers ] + HttpTestServer httpTestServer; // HTTP/1.1 [ 7 servers ] HttpTestServer httpsTestServer; // HTTPS/1.1 HttpTestServer http2TestServer; // HTTP/2 ( h2c ) HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + HttpTestServer http3TestServer; // HTTP/3 ( h3 ) DummyServer httpDummyServer; DummyServer httpsDummyServer; String httpURI; String httpsURI; String http2URI; String https2URI; + String http3URI; String httpDummy; String httpsDummy; @@ -113,6 +114,7 @@ public class UserCookieTest implements HttpServerAdapters { @DataProvider(name = "positive") public Object[][] positive() { return new Object[][] { + { http3URI, HTTP_3 }, { httpURI, HTTP_1_1 }, { httpsURI, HTTP_1_1 }, { httpDummy, HTTP_1_1 }, @@ -134,7 +136,10 @@ public class UserCookieTest implements HttpServerAdapters { ConcurrentHashMap> cookieHeaders = new ConcurrentHashMap<>(); CookieHandler cookieManager = new TestCookieHandler(cookieHeaders); - HttpClient client = HttpClient.newBuilder() + var builder = version == HTTP_3 + ? newClientBuilderForH3() + : HttpClient.newBuilder(); + HttpClient client = builder .followRedirects(Redirect.ALWAYS) .cookieHandler(cookieManager) .sslContext(sslContext) @@ -159,6 +164,9 @@ public class UserCookieTest implements HttpServerAdapters { .header("Cookie", userCookie); if (version != null) { requestBuilder.version(version); + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } } HttpRequest request = requestBuilder.build(); out.println("Initial request: " + request.uri()); @@ -181,9 +189,13 @@ public class UserCookieTest implements HttpServerAdapters { .header("Cookie", userCookie); if (version != null) { requestBuilder.version(version); + if (version == HTTP_3) { + requestBuilder.setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + } } request = requestBuilder.build(); } + client.close(); } // -- Infrastructure @@ -208,6 +220,10 @@ public class UserCookieTest implements HttpServerAdapters { https2TestServer.addHandler(new CookieValidationHandler(), "/https2/cookie/"); https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/cookie/retry"; + http3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new CookieValidationHandler(), "/http3/cookie/"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/cookie/retry"; + InetSocketAddress sa = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); // DummyServer httpDummyServer = DummyServer.create(sa); @@ -219,6 +235,7 @@ public class UserCookieTest implements HttpServerAdapters { httpsTestServer.start(); http2TestServer.start(); https2TestServer.start(); + http3TestServer.start(); httpDummyServer.start(); httpsDummyServer.start(); } @@ -229,6 +246,7 @@ public class UserCookieTest implements HttpServerAdapters { httpsTestServer.stop(); http2TestServer.stop(); https2TestServer.stop(); + http3TestServer.stop(); httpsDummyServer.stopServer(); httpsDummyServer.stopServer(); } diff --git a/test/jdk/java/net/httpclient/VersionTest.java b/test/jdk/java/net/httpclient/VersionTest.java index d09ce0354d2..ff864202a9a 100644 --- a/test/jdk/java/net/httpclient/VersionTest.java +++ b/test/jdk/java/net/httpclient/VersionTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2017, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -49,6 +49,7 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; public class VersionTest { static HttpServer s1 ; @@ -80,6 +81,8 @@ public class VersionTest { test(HTTP_1_1, false); test(HTTP_2, false); test(HTTP_2, true); + test(HTTP_3, false); + test(HTTP_3, true); } finally { s1.stop(0); executor.shutdownNow(); diff --git a/test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/Http3ConnectionAccess.java b/test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/Http3ConnectionAccess.java new file mode 100644 index 00000000000..3e58990fb90 --- /dev/null +++ b/test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/Http3ConnectionAccess.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http; + +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.NoSuchElementException; +import java.util.concurrent.CompletableFuture; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.http3.ConnectionSettings; + +public final class Http3ConnectionAccess { + + private Http3ConnectionAccess() { + throw new AssertionError(); + } + + static HttpClientImpl impl(HttpClient client) { + if (client instanceof HttpClientImpl impl) return impl; + if (client instanceof HttpClientFacade facade) return facade.impl; + return null; + } + + static HttpRequestImpl impl(HttpRequest request) { + if (request instanceof HttpRequestImpl impl) return impl; + return null; + } + + public static CompletableFuture peerSettings(HttpClient client, HttpResponse resp) { + try { + Http3Connection conn = impl(client) + .client3() + .get() + .findPooledConnectionFor(impl(resp.request()), null); + if (conn == null) { + return MinimalFuture.failedFuture(new NoSuchElementException("no connection found")); + } + return conn.peerSettingsCF(); + } catch (Exception ex) { + return MinimalFuture.failedFuture(ex); + } + } +} diff --git a/test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/common/ImmutableSSLSessionAccess.java b/test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/common/ImmutableSSLSessionAccess.java new file mode 100644 index 00000000000..ab915428dc9 --- /dev/null +++ b/test/jdk/java/net/httpclient/access/java.net.http/jdk/internal/net/http/common/ImmutableSSLSessionAccess.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.common; + +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.SSLSession; + +public final class ImmutableSSLSessionAccess { + + private ImmutableSSLSessionAccess() { + throw new AssertionError(); + } + + public static ImmutableSSLSession immutableSSLSession(SSLSession session) { + return new ImmutableSSLSession(session); + } + + public static ImmutableExtendedSSLSession immutableExtendedSSLSession(ExtendedSSLSession session) { + return new ImmutableExtendedSSLSession(session); + } + +} diff --git a/test/jdk/java/net/httpclient/altsvc/AltServiceReasonableAssurance.java b/test/jdk/java/net/httpclient/altsvc/AltServiceReasonableAssurance.java new file mode 100644 index 00000000000..0da1b238f60 --- /dev/null +++ b/test/jdk/java/net/httpclient/altsvc/AltServiceReasonableAssurance.java @@ -0,0 +1,688 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; + +import jdk.httpclient.test.lib.common.DynamicKeyStoreUtil; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.ServerNameMatcher; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.test.lib.net.URIBuilder; +import org.junit.jupiter.api.Test; +import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static jdk.httpclient.test.lib.common.DynamicKeyStoreUtil.generateCert; +import static jdk.httpclient.test.lib.common.DynamicKeyStoreUtil.generateKeyStore; +import static jdk.httpclient.test.lib.common.DynamicKeyStoreUtil.generateRSAKeyPair; +import static jdk.httpclient.test.lib.http3.Http3TestServer.quicServerBuilder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/* + * @test + * @summary verifies the HttpClient's usage of alternate services + * @comment The goal of this test class is to run various tests to verify that the HttpClient + * (and the underlying layers) use an alternate server for HTTP request(s) IF AND ONLY IF such an + * advertised alternate server satisfies "reasonable assurance" expectations as noted in the + * alternate service RFC-7838. Reasonable assurance can be summarized as: + * - The origin server which advertised the alternate service, MUST be running on TLS + * - The certificate presented by origin server during TLS handshake must be valid (and trusted) + * for the origin server + * - The certificate presented by alternate server (when subsequently a connection attempt is + * made to it) MUST be valid (and trusted) for the ORIGIN server + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.common.DynamicKeyStoreUtil + * jdk.test.lib.net.URIBuilder + * @modules java.base/sun.net.www.http + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.frame + * java.net.http/jdk.internal.net.http.hpack + * java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.packets + * java.net.http/jdk.internal.net.http.quic.frames + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * @modules java.base/sun.security.x509 + * java.base/jdk.internal.util + * @run junit/othervm -Djdk.net.hosts.file=${test.src}/altsvc-dns-hosts.txt + * -Djdk.internal.httpclient.debug=true -Djavax.net.debug=all + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * AltServiceReasonableAssurance + */ +public class AltServiceReasonableAssurance implements HttpServerAdapters { + + private static final String ORIGIN_SERVER_HOSTNAME = "origin.server"; + private static final String ALT_SERVER_HOSTNAME = "altservice.server"; + + private static final String ALT_SERVER_RESPONSE_MESSAGE = "Hello from an alt server"; + private static final String ORIGIN_SERVER_RESPONSE_MESSAGE = "Hello from an origin server"; + + private record TestInput(HttpTestServer originServer, HttpTestServer altServer, + URI requestURI, String expectedAltSvcHeader) { + } + + /** + * Creates and starts a origin server and an alternate server. The passed (same) SSLContext + * is used by both the origin server and the alternate server. + */ + private static TestInput startOriginAndAltServer(final SSLContext sslContext) + throws Exception { + Objects.requireNonNull(sslContext); + return startOriginAndAltServer(sslContext, sslContext); + } + + /** + * Creates and starts a origin server and an alternate server. The origin server will use + * the {@code originSrvSSLCtx} and the alternate server will use the {@code altSrvSSLCtx} + */ + private static TestInput startOriginAndAltServer(final SSLContext originSrvSSLCtx, + final SSLContext altSrvSSLCtx) + throws Exception { + Objects.requireNonNull(originSrvSSLCtx); + Objects.requireNonNull(altSrvSSLCtx); + final String requestPath = "/hello"; + final QuicServer quicServer = quicServerBuilder() + .sslContext(altSrvSSLCtx) + // the client sends a SNI for origin server. this alt server should be capable + // of matching/accepting that SNI name of the origin + .sniMatcher(new ServerNameMatcher(ORIGIN_SERVER_HOSTNAME)) + .build(); + // Alt server only supports H3 + final HttpTestServer altServer = HttpTestServer.of(new Http3TestServer(quicServer)); + altServer.addHandler(new Handler(ALT_SERVER_RESPONSE_MESSAGE), requestPath); + altServer.start(); + System.out.println("Alt server started at " + altServer.getAddress()); + + // H2 server which has a (application level) handler which advertises H3 alt service + final HttpTestServer originServer = HttpTestServer.of( + new Http2TestServer(ORIGIN_SERVER_HOSTNAME, true, originSrvSSLCtx)); + final int altServerPort = altServer.getAddress().getPort(); + final String altSvcHeaderVal = "h3=\"" + ALT_SERVER_HOSTNAME + ":" + altServerPort + "\""; + originServer.addHandler(new Handler(ORIGIN_SERVER_RESPONSE_MESSAGE, altSvcHeaderVal), + requestPath); + originServer.start(); + System.out.println("Origin server started at " + originServer.getAddress()); + // request URI should be directed to the origin server + final URI requestURI = URIBuilder.newBuilder() + .scheme("https") + .host(ORIGIN_SERVER_HOSTNAME) + .port(originServer.getAddress().getPort()) + .path(requestPath) + .build(); + return new TestInput(originServer, altServer, requestURI, altSvcHeaderVal); + } + + private TestInput startHttpOriginHttpsAltServer(final SSLContext altServerSSLCtx) + throws Exception { + Objects.requireNonNull(altServerSSLCtx); + final String requestPath = "/foo"; + // Alt server only supports H3 + final HttpTestServer altServer = HttpTestServer.create(HTTP_3_URI_ONLY, altServerSSLCtx); + altServer.addHandler(new Handler(ALT_SERVER_RESPONSE_MESSAGE), requestPath); + altServer.start(); + System.out.println("Alt server (HTTPS) started at " + altServer.getAddress()); + + // supports only HTTP server and uses a (application level) handler which advertises a H3 + // alternate service + final HttpTestServer originServer = HttpTestServer.create(HTTP_2); + final int altServerPort = altServer.getAddress().getPort(); + final String altSvcHeaderVal = "h3=\"" + ALT_SERVER_HOSTNAME + ":" + altServerPort + "\""; + originServer.addHandler(new Handler(ORIGIN_SERVER_RESPONSE_MESSAGE, altSvcHeaderVal), + requestPath); + originServer.start(); + System.out.println("Origin server (HTTP) started at " + originServer.getAddress()); + // request URI should be against (HTTP) origin server + final URI requestURI = URIBuilder.newBuilder() + .scheme("http") + .host(ORIGIN_SERVER_HOSTNAME) + .port(originServer.getAddress().getPort()) + .path(requestPath) + .build(); + return new TestInput(originServer, altServer, requestURI, altSvcHeaderVal); + } + + /** + * Stop the server (and ignore any exception) + */ + private static void safeStop(final HttpTestServer server) { + if (server == null) { + return; + } + final InetSocketAddress serverAddr = server.getAddress(); + try { + System.out.println("Stopping server " + serverAddr); + server.stop(); + } catch (Exception e) { + System.err.println("Ignoring exception: " + e.getMessage() + " that occurred " + + "during stop of server: " + serverAddr); + } + } + + /** + * Returns back a 200 HTTP response with a response body containing a response message + * that was used to construct the Handler instance. Additionally, if the Handler was constructed + * with a non-null {@code altSvcHeaderVal} then that value is sent back as a header value. in + * the response, for the {@code alt-svc} header + */ + private static final class Handler implements HttpTestHandler { + private final String responseMessage; + private final byte[] responseBytes; + private final String altSvcHeaderVal; + + private Handler(final String responseMessage) { + this(responseMessage, null); + } + + private Handler(final String responseMessage, final String altSvcHeaderVal) { + Objects.requireNonNull(responseMessage); + this.responseMessage = responseMessage; + this.responseBytes = responseMessage.getBytes(StandardCharsets.UTF_8); + this.altSvcHeaderVal = altSvcHeaderVal; + } + + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + System.out.println("Handling request " + exchange.getRequestURI()); + if (this.altSvcHeaderVal != null) { + System.out.println("Responding with alt-svc header: " + this.altSvcHeaderVal); + exchange.getResponseHeaders().addHeader("alt-svc", this.altSvcHeaderVal); + } + System.out.println("Responding with body: " + this.responseMessage); + exchange.sendResponseHeaders(200, this.responseBytes.length); + try (final OutputStream os = exchange.getResponseBody()) { + os.write(this.responseBytes); + } + } + } + + /** + * - Keystore K1 is constructed with a certificate whose subject is origin server hostname and + * subject alternative name is alternate server hostname + * - K1 is used to construct a SSLContext and thus the SSLContext uses the keys and trusted + * certificate from this keystore + * - The constructed SSLContext instance is used by the HttpClient, the origin server and the + * alternate server + * - During TLS handshake with origin server, the origin server is expected to present the + * certificate from this K1 keystore. + * - During TLS handshake with alternate server, the alternate server is expected to present + * this same certificate from K1 keystore. + * - Since the certificate is valid (and trusted by the client) for both origin server + * and alternate server (because of the valid subject name and subject alternate name), + * the TLS handshake between the HttpClient and the origin and alternate server is expected + * to pass + *

    + * Once the servers are started, this test method does the following: + *

    + * - Client constructs a HTTP_3 request addressed to origin server + * - Origin server responds with a 200 response and also with alt-svc header pointing to + * an alternate server + * - Client verifies the response as well as presence of the alt-svc header value + * - Client issues the *same* request again + * - The request is expected to be handled by the alternate server + */ + @Test + public void testOriginAltSameCert() throws Exception { + // create a keystore which contains a PrivateKey entry and a certificate associated with + // that key. the certificate's subject will be origin server's hostname and will + // additionally have the alt server hostname as a subject alternate name. Thus, the + // certificate is valid for both origin server and alternate server + final KeyStore keyStore = generateKeyStore(ORIGIN_SERVER_HOSTNAME, ALT_SERVER_HOSTNAME); + System.out.println("Generated a keystore with certificate: " + + keyStore.getCertificate(DynamicKeyStoreUtil.DEFAULT_ALIAS)); + // create a SSLContext that will be used by the servers and the HttpClient and will be + // backed by the keystore we just created. Thus, the HttpClient will trust the certificate + // belonging to that keystore + final SSLContext sslContext = DynamicKeyStoreUtil.createSSLContext(keyStore); + // start the servers + final TestInput testInput = startOriginAndAltServer(sslContext); + try { + final HttpClient client = newClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + // send a HTTP3 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI requestURI = testInput.requestURI; + final HttpRequest request = HttpRequest.newBuilder() + .GET().uri(requestURI) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + System.out.println("Issuing request " + requestURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, response.statusCode(), "Unexpected response code"); + // the origin server is expected to respond + assertEquals(ORIGIN_SERVER_RESPONSE_MESSAGE, response.body(), "Unexpected response" + + " body"); + assertEquals(HTTP_2, response.version(), "Unexpected HTTP version in response"); + + // verify the origin server sent back a alt-svc header + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + final String actualAltSvcHeader = altSvcHeader.get(); + System.out.println("Received alt-svc header value: " + actualAltSvcHeader); + assertTrue(actualAltSvcHeader.contains(testInput.expectedAltSvcHeader), + "Unexpected alt-svc header value: " + actualAltSvcHeader + + ", was expected to contain: " + testInput.expectedAltSvcHeader); + + // now issue the same request again and this time expect it to be handled + // by the alt-service + System.out.println("Again issuing request " + requestURI); + final HttpResponse secondResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, secondResponse.statusCode(), "Unexpected response code"); + // expect the alt service to respond + assertEquals(ALT_SERVER_RESPONSE_MESSAGE, secondResponse.body(), "Unexpected response" + + " body"); + assertEquals(HTTP_3, secondResponse.version(), "Unexpected HTTP version in response"); + } finally { + safeStop(testInput.originServer); + safeStop(testInput.altServer); + } + } + + /** + * - Keystore K1 is constructed with a PrivateKey PK1 and certificate whose subject is + * origin server hostname + * - Keystore K2 is constructed with the same PrivateKey PK1 and certificate whose subject is + * alternate server hostname AND has a subject alternate name of origin server hostname + * - K1 is used to construct a SSLContext S1 and that S1 is used by origin server + * - K2 is used to construct a SSLContext S2 and that S2 is used by alternate server + * - SSLContext S3 is constructed with both the certificate of origin server and + * the certificate of alternate server as trusted certificates. HttpClient uses S3 + * - During TLS handshake with origin server, the origin server is expected to present the + * certificate from this K1 keystore, with subject as origin server hostname + * - During TLS handshake with alternate server, the alternate server is expected to present + * the certificate from K2 keystore, with subject as alternate server hostname AND a subject + * alternate name of origin server + * - HttpClient (through S3 SSLContext) trusts both these certs. The cert presented + * by alt server, is valid (even) for origin server (since its subject alternate name is + * origin server hostname). Thus, the client must consider the alternate service as valid and + * use it. + *

    + * Once the servers are started, this test method does the following: + *

    + * - Client constructs a HTTP_3 request addressed to origin server + * - Origin server responds with a 200 response and also with alt-svc header pointing to + * an alternate server + * - Client verifies the response as well as presence of the alt-svc header value + * - Client issues the *same* request again + * - The request is expected to be handled by the alternate server + */ + @Test + public void testOriginAltDifferentCert() throws Exception { + final SecureRandom secureRandom = new SecureRandom(); + final KeyPair keyPair = generateRSAKeyPair(secureRandom); + + // generate a certificate for origin server, with origin server hostname as the subject + final X509Certificate originServerCert = generateCert(keyPair, secureRandom, + ORIGIN_SERVER_HOSTNAME); + // create a keystore with the private key and the cert. this keystore will then be + // used by the SSLContext of origin server + final KeyStore originServerKeyStore = generateKeyStore(keyPair.getPrivate(), + new Certificate[]{originServerCert}); + System.out.println("Generated a keystore, for origin server, with certificate: " + + originServerKeyStore.getCertificate(DynamicKeyStoreUtil.DEFAULT_ALIAS)); + // create the SSLContext for the origin server + final SSLContext originServerSSLCtx = DynamicKeyStoreUtil.createSSLContext( + originServerKeyStore); + + // create a cert for the alternate server, with alternate server hostname as the subject + // AND origin server hostname as a subject alternate name + final X509Certificate altServerCert = generateCert(keyPair, secureRandom, + ALT_SERVER_HOSTNAME, ORIGIN_SERVER_HOSTNAME); + // create keystore with the private key and the alt server's cert. this keystore will then + // be used by the SSLContext of alternate server + final KeyStore altServerKeyStore = generateKeyStore(keyPair.getPrivate(), + new Certificate[]{altServerCert}); + System.out.println("Generated a keystore, for alt server, with certificate: " + + altServerKeyStore.getCertificate(DynamicKeyStoreUtil.DEFAULT_ALIAS)); + // create SSLContext of alternate server + final SSLContext altServerSSLCtx = DynamicKeyStoreUtil.createSSLContext(altServerKeyStore); + + // now create a SSLContext for the HttpClient. This SSLContext will contain no key manager + // and will have a trust manager which trusts origin server certificate and the alternate + // server certificate + final SSLContext clientSSLCtx = sslCtxWithTrustedCerts(List.of(originServerCert, + altServerCert)); + // start the servers + final TestInput testInput = startOriginAndAltServer(originServerSSLCtx, altServerSSLCtx); + try { + final HttpClient client = newClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(clientSSLCtx) + .version(HTTP_3) + .build(); + // send a HTTP3 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI requestURI = testInput.requestURI; + final HttpRequest request = HttpRequest.newBuilder() + .GET().uri(requestURI) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + System.out.println("Issuing request " + requestURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, response.statusCode(), "Unexpected response code"); + // the origin server is expected to respond + assertEquals(ORIGIN_SERVER_RESPONSE_MESSAGE, response.body(), "Unexpected response" + + " body"); + assertEquals(HTTP_2, response.version(), "Unexpected HTTP version in response"); + + // verify the origin server sent back a alt-svc header + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + final String actualAltSvcHeader = altSvcHeader.get(); + System.out.println("Received alt-svc header value: " + actualAltSvcHeader); + assertTrue(actualAltSvcHeader.contains(testInput.expectedAltSvcHeader), + "Unexpected alt-svc header value: " + actualAltSvcHeader + + ", was expected to contain: " + testInput.expectedAltSvcHeader); + + // now issue the same request again and this time expect it to be handled + // by the alt-service + System.out.println("Again issuing request " + requestURI); + final HttpResponse secondResponse = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, secondResponse.statusCode(), "Unexpected response code"); + // expect the alt service to respond + assertEquals(ALT_SERVER_RESPONSE_MESSAGE, secondResponse.body(), "Unexpected response" + + " body"); + assertEquals(HTTP_3, secondResponse.version(), "Unexpected HTTP version in response"); + } finally { + safeStop(testInput.originServer); + safeStop(testInput.altServer); + } + } + + + /** + * - Keystore K1 is constructed with a PrivateKey PK1 and certificate whose subject is + * origin server hostname + * - Keystore K2 is constructed with the same PrivateKey PK1 and certificate whose subject is + * alternate server hostname + * - K1 is used to construct a SSLContext S1 and that S1 is used by origin server + * - K2 is used to construct a SSLContext S2 and that S2 is used by alternate server + * - SSLContext S3 is constructed with both the certificate of origin server and + * the certificate of alternate server as trusted certificates. HttpClient uses S3 + * - During TLS handshake with origin server, the origin server is expected to present the + * certificate from this K1 keystore, with subject as origin server hostname + * - During TLS handshake with alternate server, the alternate server is expected to present + * the certificate from K2 keystore, with subject as alternate server hostname + * - HttpClient (through S3 SSLContext) trusts both these certs, but the cert presented + * by alt server, although valid for the alt server, CANNOT/MUST NOT be valid for origin + * server (since it's subject nor subject alternate name is origin server hostname). + * Reasonable assurance expects that the alt server present a certificate that is valid for + * origin server host and since it doesn't, the alt server must not be used by the HttpClient. + *

    + * Once the servers are started, this test method does the following: + *

    + * - Client constructs a HTTP_3 request addressed to origin server + * - Origin server responds with a 200 response and also with alt-svc header pointing to + * an alternate server + * - Client verifies the response as well as presence of the alt-svc header value + * - Client issues the *same* request again + * - The request is expected to be handled by the origin server again and the advertised + * alternate service MUST NOT be used (due to reasons noted above) + */ + @Test + public void testAltServerWrongCert() throws Exception { + final SecureRandom secureRandom = new SecureRandom(); + final KeyPair keyPair = generateRSAKeyPair(secureRandom); + + // generate a certificate for origin server, with origin server hostname as the subject + final X509Certificate originServerCert = generateCert(keyPair, secureRandom, + ORIGIN_SERVER_HOSTNAME); + // create a keystore with the private key and the cert. this keystore will then be + // used by the SSLContext of origin server + final KeyStore originServerKeyStore = generateKeyStore(keyPair.getPrivate(), + new Certificate[]{originServerCert}); + System.out.println("Generated a keystore, for origin server, with certificate: " + + originServerKeyStore.getCertificate(DynamicKeyStoreUtil.DEFAULT_ALIAS)); + // create the SSLContext for the origin server + final SSLContext originServerSSLCtx = DynamicKeyStoreUtil.createSSLContext( + originServerKeyStore); + + // create a cert for the alternate server, with alternate server hostname as the subject + final X509Certificate altServerCert = generateCert(keyPair, secureRandom, + ALT_SERVER_HOSTNAME); + // create keystore with the private key and the alt server's cert. this keystore will then + // be used by the SSLContext of alternate server + final KeyStore altServerKeyStore = generateKeyStore(keyPair.getPrivate(), + new Certificate[]{altServerCert}); + System.out.println("Generated a keystore, for alt server, with certificate: " + + altServerKeyStore.getCertificate(DynamicKeyStoreUtil.DEFAULT_ALIAS)); + // create SSLContext of alternate server + final SSLContext altServerSSLCtx = DynamicKeyStoreUtil.createSSLContext(altServerKeyStore); + + // now create a SSLContext for the HttpClient. This SSLContext will contain no key manager + // and will have a trust manager which trusts origin server certificate and the alternate + // server certificate + final SSLContext clientSSLCtx = sslCtxWithTrustedCerts(List.of(originServerCert, + altServerCert)); + // start the servers + final TestInput testInput = startOriginAndAltServer(originServerSSLCtx, altServerSSLCtx); + try { + final HttpClient client = newClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(clientSSLCtx) + .version(HTTP_3) + .build(); + // send a HTTP3 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI requestURI = testInput.requestURI; + final HttpRequest request = HttpRequest.newBuilder() + .GET().uri(requestURI) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + System.out.println("Issuing request " + requestURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, response.statusCode(), "Unexpected response code"); + // the origin server is expected to respond + assertEquals(ORIGIN_SERVER_RESPONSE_MESSAGE, response.body(), "Unexpected response" + + " body"); + assertEquals(HTTP_2, response.version(), "Unexpected HTTP version in response"); + + // verify the origin server sent back a alt-svc header + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + final String actualAltSvcHeader = altSvcHeader.get(); + System.out.println("Received alt-svc header value: " + actualAltSvcHeader); + assertTrue(actualAltSvcHeader.contains(testInput.expectedAltSvcHeader), + "Unexpected alt-svc header value: " + actualAltSvcHeader + + ", was expected to contain: " + testInput.expectedAltSvcHeader); + + // now issue the same request again (a few times). Expect each of these requests too, + // to be handled by the origin server (since the advertised alt server isn't expected + // to satisfy the "reasonable assurances" expectations + for (int i = 1; i <= 3; i++) { + System.out.println("Again(" + i + ") issuing request " + requestURI); + final HttpResponse rsp = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, rsp.statusCode(), "Unexpected response code"); + // expect the alt service to respond + assertEquals(ORIGIN_SERVER_RESPONSE_MESSAGE, rsp.body(), + "Unexpected response body"); + assertEquals(HTTP_2, rsp.version(), "Unexpected HTTP version in response"); + } + } finally { + safeStop(testInput.originServer); + safeStop(testInput.altServer); + } + } + + + /** + * - Keystore K1 is constructed with a PrivateKey PK1 and certificate whose subject is + * alternate server hostname + * - K1 is used to construct a SSLContext S1 and that S1 is used by alternate server + * - The same SSLContext S1 is used by the HttpClient (and thus will trust the alternate + * server's certificate) + * - Origin server runs only on HTTP + * - Any alt-svc advertised by origin server MUST NOT be used by the client, because origin + * server runs on HTTP and as a result the "reasonable assurance" for the origin server + * cannot be satisfied. + *

    + * Once the servers are started, this test method does the following: + *

    + * - Client constructs a HTTP2 request addressed to origin server + * - Origin server responds with a 200 response and also with alt-svc header pointing to + * an alternate server + * - Client verifies the response as well as presence of the alt-svc header value + * - Client issues the request again, to the origin server, this time with HTTP3 as the request + * version + * - The request is expected to be handled by the origin server again and the advertised + * alternate service MUST NOT be used (due to reasons noted above) + */ + @Test + public void testAltServiceAdvertisedByHTTPOrigin() throws Exception { + // create a keystore which contains a PrivateKey entry and a certificate associated with + // that key. the certificate's subject will be alternate server's hostname. Thus, the + // certificate is valid for alternate server + final KeyStore keyStore = generateKeyStore(ALT_SERVER_HOSTNAME); + System.out.println("Generated a keystore with certificate: " + + keyStore.getCertificate(DynamicKeyStoreUtil.DEFAULT_ALIAS)); + // create a SSLContext that will be used by the alternate server and the HttpClient and + // will be backed by the keystore we just created. Thus, the HttpClient will trust the + // certificate belonging to that keystore + final SSLContext sslContext = DynamicKeyStoreUtil.createSSLContext(keyStore); + + // start the servers + final TestInput testInput = startHttpOriginHttpsAltServer(sslContext); + try { + final HttpClient client = newClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + // send a HTTP2 request to a server which is expected to respond back + // with a 200 response and an alt-svc header pointing to another/different H3 server + final URI requestURI = testInput.requestURI; + final HttpRequest request = HttpRequest.newBuilder() + .version(HTTP_2).GET() + .uri(requestURI).build(); + System.out.println("Issuing " + request.version() + " request to " + requestURI); + final HttpResponse response = client.send(request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, response.statusCode(), "Unexpected response code"); + assertEquals(HTTP_2, response.version(), "Unexpected HTTP version in response"); + // the origin server is expected to respond + assertEquals(ORIGIN_SERVER_RESPONSE_MESSAGE, response.body(), "Unexpected response" + + " body"); + + // verify the origin server sent back a alt-svc header + final Optional altSvcHeader = response.headers().firstValue("alt-svc"); + assertTrue(altSvcHeader.isPresent(), "alt-svc header is missing in response"); + final String actualAltSvcHeader = altSvcHeader.get(); + System.out.println("Received alt-svc header value: " + actualAltSvcHeader); + assertTrue(actualAltSvcHeader.contains(testInput.expectedAltSvcHeader), + "Unexpected alt-svc header value: " + actualAltSvcHeader + + ", was expected to contain: " + testInput.expectedAltSvcHeader); + + // now issue few more requests to the same address, but as a HTTP3 version. Expect each + // of these requests too, to be handled by the origin server (since the previously + // advertised alt server isn't expected to satisfy the "reasonable assurances" + // expectations) + for (int i = 1; i <= 3; i++) { + final HttpRequest h3Request = HttpRequest.newBuilder() + .version(HTTP_3).GET().uri(requestURI) + .setOption(H3_DISCOVERY, ALT_SVC) + .build(); + System.out.println("Again(" + i + ") issuing " + h3Request.version() + + " request to " + requestURI); + final HttpResponse rsp = client.send(h3Request, + HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + assertEquals(200, rsp.statusCode(), "Unexpected response code"); + // even though the request is HTTP_3 version, the client will fall back to HTTP_2 + // since the origin server does run on TLS and thus cannot use HTTP_3 (which + // requires TLS) + assertEquals(HTTP_2, rsp.version(), "Unexpected HTTP version in response"); + // expect the alt service to respond + assertEquals(ORIGIN_SERVER_RESPONSE_MESSAGE, rsp.body(), + "Unexpected response body"); + } + } finally { + safeStop(testInput.originServer); + safeStop(testInput.altServer); + } + } + + private static SSLContext sslCtxWithTrustedCerts(final List trustedCerts) + throws Exception { + Objects.requireNonNull(trustedCerts); + // start with a blank keystore + final KeyStore keyStore = DynamicKeyStoreUtil.generateBlankKeyStore(); + final String aliasPrefix = "trusted-certs-alias-"; + int i = 1; + for (final Certificate cert : trustedCerts) { + // add the cert as a trusted certificate to the keystore + keyStore.setCertificateEntry(aliasPrefix + i, cert); + i++; + } + System.out.println("Generated a keystore with (only) trusted certs: "); + for (--i; i > 0; i--) { + System.out.println(keyStore.getCertificate(aliasPrefix + i)); + } + final TrustManagerFactory tmf = TrustManagerFactory.getInstance("PKIX"); + // use the generated keystore for this trust manager + tmf.init(keyStore); + + final String protocol = "TLS"; + final SSLContext ctx = SSLContext.getInstance(protocol); + // initialize the SSLContext with the trust manager which trusts the passed certificates + ctx.init(null, tmf.getTrustManagers(), null); + return ctx; + } +} diff --git a/test/jdk/java/net/httpclient/altsvc/altsvc-dns-hosts.txt b/test/jdk/java/net/httpclient/altsvc/altsvc-dns-hosts.txt new file mode 100644 index 00000000000..67445f3b278 --- /dev/null +++ b/test/jdk/java/net/httpclient/altsvc/altsvc-dns-hosts.txt @@ -0,0 +1,23 @@ +## Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. +## DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. +## +## This code is free software; you can redistribute it and/or modify it +## under the terms of the GNU General Public License version 2 only, as +## published by the Free Software Foundation. +## +## This code is distributed in the hope that it will be useful, but WITHOUT +## ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +## FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +## version 2 for more details (a copy is included in the LICENSE file that +## accompanied this code). +## +## You should have received a copy of the GNU General Public License version +## 2 along with this work; if not, write to the Free Software Foundation, +## Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. +## +## Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA +## or visit www.oracle.com if you need additional information or have any +## questions. +## +127.0.0.1 origin.server +127.0.0.1 altservice.server diff --git a/test/jdk/java/net/httpclient/debug/java.net.http/jdk/internal/net/http/common/TestLoggerUtil.java b/test/jdk/java/net/httpclient/debug/java.net.http/jdk/internal/net/http/common/TestLoggerUtil.java new file mode 100644 index 00000000000..a5617d85127 --- /dev/null +++ b/test/jdk/java/net/httpclient/debug/java.net.http/jdk/internal/net/http/common/TestLoggerUtil.java @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.net.http.common; + +import java.util.function.Supplier; + +import jdk.internal.net.http.common.DebugLogger.LoggerConfig; + +public class TestLoggerUtil { + + public static Logger getStdoutLogger(Supplier dbgTag) { + var logLevel = Utils.DEBUG_CONFIG.logLevel(); + LoggerConfig config = Utils.DEBUG + ? LoggerConfig.STDOUT.withLogLevel(logLevel) + : LoggerConfig.OFF; + return DebugLogger.createHttpLogger(dbgTag, config); + } + + public static Logger getErrOutLogger(Supplier dbgTag) { + var logLevel = Utils.DEBUG_CONFIG.logLevel(); + LoggerConfig config = Utils.DEBUG + ? LoggerConfig.ERROUT.withLogLevel(logLevel) + : LoggerConfig.OFF; + return DebugLogger.createHttpLogger(dbgTag, config); + } +} diff --git a/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java b/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java index 73cc12ce478..c7edfd5fa8f 100644 --- a/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java +++ b/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java @@ -168,7 +168,7 @@ public class BadPushPromiseTest { } } - private void pushPromise(HttpServerAdapters.HttpTestExchange exchange) { + private void pushPromise(HttpServerAdapters.HttpTestExchange exchange) throws IOException { URI requestURI = exchange.getRequestURI(); String query = exchange.getRequestURI().getQuery(); int badHeadersIndex = Integer.parseInt(query.substring(query.indexOf("=") + 1)); diff --git a/test/jdk/java/net/httpclient/http2/ContinuationFrameTest.java b/test/jdk/java/net/httpclient/http2/ContinuationFrameTest.java index 04ce8f4f4a4..f27fdb580dd 100644 --- a/test/jdk/java/net/httpclient/http2/ContinuationFrameTest.java +++ b/test/jdk/java/net/httpclient/http2/ContinuationFrameTest.java @@ -46,6 +46,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandlers; + import jdk.internal.net.http.common.HttpHeadersBuilder; import jdk.internal.net.http.frame.ContinuationFrame; import jdk.internal.net.http.frame.HeaderFrame; @@ -154,6 +155,22 @@ public class ContinuationFrameTest { static final int ITERATION_COUNT = 20; + HttpClient sharedClient; + HttpClient httpClient(boolean shared) { + if (!shared || sharedClient == null) { + var client = HttpClient.newBuilder() + .proxy(HttpClient.Builder.NO_PROXY) + .sslContext(sslContext) + .build(); + if (sharedClient == null) { + sharedClient = client; + } + TRACKER.track(client); + return client; + } + return sharedClient; + } + @Test(dataProvider = "variants") void test(String uri, boolean sameClient, @@ -165,11 +182,7 @@ public class ContinuationFrameTest { HttpClient client = null; for (int i=0; i< ITERATION_COUNT; i++) { if (!sameClient || client == null) { - client = HttpClient.newBuilder() - .proxy(HttpClient.Builder.NO_PROXY) - .sslContext(sslContext) - .build(); - TRACKER.track(client); + client = httpClient(sameClient); } HttpRequest request = HttpRequest.newBuilder(URI.create(uri)) @@ -229,6 +242,7 @@ public class ContinuationFrameTest { @AfterTest public void teardown() throws Exception { + sharedClient = null; AssertionError fail = TRACKER.check(500); try { http2TestServer.stop(); diff --git a/test/jdk/java/net/httpclient/http2/ErrorTest.java b/test/jdk/java/net/httpclient/http2/ErrorTest.java index 061fd5cd350..e8613b9efa8 100644 --- a/test/jdk/java/net/httpclient/http2/ErrorTest.java +++ b/test/jdk/java/net/httpclient/http2/ErrorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -25,12 +25,26 @@ * @test * @bug 8157105 * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.Asserts + * jdk.test.lib.net.SimpleSSLContext * @modules java.base/sun.net.www.http * java.net.http/jdk.internal.net.http.common * java.net.http/jdk.internal.net.http.frame * java.net.http/jdk.internal.net.http.hpack + * java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.packets + * java.net.http/jdk.internal.net.http.quic.frames + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers * java.security.jgss + * @modules java.base/jdk.internal.util * @run testng/othervm/timeout=60 -Djavax.net.debug=ssl -Djdk.httpclient.HttpClient.log=all ErrorTest * @summary check exception thrown when bad TLS parameters selected */ @@ -47,7 +61,6 @@ import javax.net.ssl.SSLParameters; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import jdk.httpclient.test.lib.http2.Http2TestServer; -import jdk.httpclient.test.lib.http2.Http2TestExchange; import jdk.httpclient.test.lib.http2.Http2EchoHandler; import jdk.test.lib.net.SimpleSSLContext; diff --git a/test/jdk/java/net/httpclient/http2/HpackBinaryTestDriver.java b/test/jdk/java/net/httpclient/http2/HpackBinaryTestDriver.java index 7d5424163ba..4c63e863fee 100644 --- a/test/jdk/java/net/httpclient/http2/HpackBinaryTestDriver.java +++ b/test/jdk/java/net/httpclient/http2/HpackBinaryTestDriver.java @@ -29,6 +29,6 @@ * @compile/module=java.net.http jdk/internal/net/http/hpack/SpecHelper.java * @compile/module=java.net.http jdk/internal/net/http/hpack/TestHelper.java * @compile/module=java.net.http jdk/internal/net/http/hpack/BuffersTestingKit.java - * @run testng/othervm java.net.http/jdk.internal.net.http.hpack.BinaryPrimitivesTest + * @run testng/othervm/timeout=240 java.net.http/jdk.internal.net.http.hpack.BinaryPrimitivesTest */ public class HpackBinaryTestDriver { } diff --git a/test/jdk/java/net/httpclient/http2/HpackHuffmanDriver.java b/test/jdk/java/net/httpclient/http2/HpackHuffmanDriver.java index d2b515c6d63..650ed706c51 100644 --- a/test/jdk/java/net/httpclient/http2/HpackHuffmanDriver.java +++ b/test/jdk/java/net/httpclient/http2/HpackHuffmanDriver.java @@ -29,6 +29,6 @@ * @compile/module=java.net.http jdk/internal/net/http/hpack/SpecHelper.java * @compile/module=java.net.http jdk/internal/net/http/hpack/TestHelper.java * @compile/module=java.net.http jdk/internal/net/http/hpack/BuffersTestingKit.java - * @run testng/othervm java.net.http/jdk.internal.net.http.hpack.HuffmanTest + * @run testng/othervm/timeout=300 java.net.http/jdk.internal.net.http.hpack.HuffmanTest */ public class HpackHuffmanDriver { } diff --git a/test/jdk/java/net/httpclient/http2/IdleConnectionTimeoutTest.java b/test/jdk/java/net/httpclient/http2/IdleConnectionTimeoutTest.java deleted file mode 100644 index ac060e8722b..00000000000 --- a/test/jdk/java/net/httpclient/http2/IdleConnectionTimeoutTest.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ - -/* - * @test - * @bug 8288717 - * @summary Tests that when the idleConnectionTimeoutEvent is configured in HTTP/2, - * an HTTP/2 connection will close within the specified interval if there - * are no active streams on the connection. - * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.httpclient.test.lib.http2.Http2TestServer - * - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors -Djdk.httpclient.keepalive.timeout=1 - * IdleConnectionTimeoutTest - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors -Djdk.httpclient.keepalive.timeout=2 - * IdleConnectionTimeoutTest - * - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors -Djdk.httpclient.keepalive.timeout.h2=1 - * IdleConnectionTimeoutTest - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors -Djdk.httpclient.keepalive.timeout.h2=2 - * IdleConnectionTimeoutTest - * - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors -Djdk.httpclient.keepalive.timeout.h2=1 - * -Djdk.httpclient.keepalive.timeout=2 - * IdleConnectionTimeoutTest - * - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors IdleConnectionTimeoutTest - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors -Djdk.httpclient.keepalive.timeout.h2=-1 - * IdleConnectionTimeoutTest - * @run testng/othervm -Djdk.httpclient.HttpClient.log=errors,trace -Djdk.httpclient.keepalive.timeout.h2=abc - * IdleConnectionTimeoutTest - */ - -import jdk.httpclient.test.lib.http2.BodyOutputStream; -import jdk.httpclient.test.lib.http2.Http2TestExchangeImpl; -import jdk.httpclient.test.lib.http2.Http2TestServerConnection; -import jdk.internal.net.http.common.HttpHeadersBuilder; -import org.testng.annotations.BeforeTest; -import org.testng.annotations.Test; - -import java.io.IOException; -import java.io.InputStream; -import java.io.PrintStream; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpHeaders; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; -import jdk.httpclient.test.lib.http2.Http2TestServer; -import jdk.httpclient.test.lib.http2.Http2TestExchange; -import jdk.httpclient.test.lib.http2.Http2Handler; - -import javax.net.ssl.SSLSession; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.net.http.HttpClient.Version.HTTP_2; -import static org.testng.Assert.assertEquals; - -public class IdleConnectionTimeoutTest { - - static Http2TestServer http2TestServer; - URI timeoutUri; - URI noTimeoutUri; - final String IDLE_CONN_PROPERTY = "jdk.httpclient.keepalive.timeout.h2"; - final String KEEP_ALIVE_PROPERTY = "jdk.httpclient.keepalive.timeout"; - final String TIMEOUT_PATH = "/serverTimeoutHandler"; - final String NO_TIMEOUT_PATH = "/noServerTimeoutHandler"; - static final PrintStream testLog = System.err; - - @BeforeTest - public void setup() throws Exception { - http2TestServer = new Http2TestServer(false, 0); - http2TestServer.addHandler(new ServerTimeoutHandler(), TIMEOUT_PATH); - http2TestServer.addHandler(new ServerNoTimeoutHandler(), NO_TIMEOUT_PATH); - http2TestServer.setExchangeSupplier(TestExchangeSupplier::new); - - http2TestServer.start(); - int port = http2TestServer.getAddress().getPort(); - timeoutUri = new URI("http://localhost:" + port + TIMEOUT_PATH); - noTimeoutUri = new URI("http://localhost:" + port + NO_TIMEOUT_PATH); - } - - /* - If the InetSocketAddress of the first remote connection is not equal to the address of the - second remote connection, then the idleConnectionTimeoutEvent has occurred and a new connection - was made to carry out the second request by the client. - */ - @Test - public void test() throws InterruptedException { - String timeoutVal = System.getProperty(IDLE_CONN_PROPERTY); - String keepAliveVal = System.getProperty(KEEP_ALIVE_PROPERTY); - testLog.println("Test run for " + IDLE_CONN_PROPERTY + "=" + timeoutVal); - - int sleepTime = 0; - HttpClient hc = HttpClient.newBuilder().version(HTTP_2).build(); - HttpRequest hreq; - HttpResponse hresp; - if (timeoutVal != null) { - if (keepAliveVal != null) { - // In this case, specified h2 timeout should override keep alive timeout. - // Timeout should occur - hreq = HttpRequest.newBuilder(timeoutUri).version(HTTP_2).GET().build(); - sleepTime = 2000; - hresp = runRequest(hc, hreq, sleepTime); - assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent was expected but did not occur"); - } else if (timeoutVal.equals("1")) { - // Timeout should occur - hreq = HttpRequest.newBuilder(timeoutUri).version(HTTP_2).GET().build(); - sleepTime = 2000; - hresp = runRequest(hc, hreq, sleepTime); - assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent was expected but did not occur"); - } else if (timeoutVal.equals("2")) { - // Timeout should not occur - hreq = HttpRequest.newBuilder(noTimeoutUri).version(HTTP_2).GET().build(); - sleepTime = 1000; - hresp = runRequest(hc, hreq, sleepTime); - assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent was not expected but occurred"); - } else if (timeoutVal.equals("abc") || timeoutVal.equals("-1")) { - // Timeout should not occur - hreq = HttpRequest.newBuilder(noTimeoutUri).version(HTTP_2).GET().build(); - hresp = runRequest(hc, hreq, sleepTime); - assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent was not expected but occurred"); - } - } else { - // When no value is specified then no timeout should occur (default keep alive value of 600 used) - hreq = HttpRequest.newBuilder(noTimeoutUri).version(HTTP_2).GET().build(); - hresp = runRequest(hc, hreq, sleepTime); - assertEquals(hresp.statusCode(), 200, "idleConnectionTimeoutEvent should not occur, no value was specified for this property"); - } - } - - private HttpResponse runRequest(HttpClient hc, HttpRequest req, int sleepTime) throws InterruptedException { - CompletableFuture> request = hc.sendAsync(req, HttpResponse.BodyHandlers.ofString(UTF_8)); - HttpResponse hresp = request.join(); - assertEquals(hresp.statusCode(), 200); - - Thread.sleep(sleepTime); - request = hc.sendAsync(req, HttpResponse.BodyHandlers.ofString(UTF_8)); - return request.join(); - } - - static class ServerTimeoutHandler implements Http2Handler { - - volatile Object firstConnection = null; - - @Override - public void handle(Http2TestExchange exchange) throws IOException { - if (exchange instanceof TestExchangeSupplier exch) { - if (firstConnection == null) { - firstConnection = exch.getTestConnection(); - exch.sendResponseHeaders(200, 0); - } else { - var secondConnection = exch.getTestConnection(); - - if (firstConnection != secondConnection) { - testLog.println("ServerTimeoutHandler: New Connection was used, idleConnectionTimeoutEvent fired." - + " First Connection Hash: " + firstConnection + ", Second Connection Hash: " + secondConnection); - exch.sendResponseHeaders(200, 0); - } else { - testLog.println("ServerTimeoutHandler: Same Connection was used, idleConnectionTimeoutEvent did not fire." - + " First Connection Hash: " + firstConnection + ", Second Connection Hash: " + secondConnection); - exch.sendResponseHeaders(400, 0); - } - } - } - } - } - - static class ServerNoTimeoutHandler implements Http2Handler { - - volatile Object firstConnection = null; - - @Override - public void handle(Http2TestExchange exchange) throws IOException { - if (exchange instanceof TestExchangeSupplier exch) { - if (firstConnection == null) { - firstConnection = exch.getTestConnection(); - exch.sendResponseHeaders(200, 0); - } else { - var secondConnection = exch.getTestConnection(); - - if (firstConnection == secondConnection) { - testLog.println("ServerTimeoutHandler: Same Connection was used, idleConnectionTimeoutEvent did not fire." - + " First Connection Hash: " + firstConnection + ", Second Connection Hash: " + secondConnection); - exch.sendResponseHeaders(200, 0); - } else { - testLog.println("ServerTimeoutHandler: Different Connection was used, idleConnectionTimeoutEvent fired." - + " First Connection Hash: " + firstConnection + ", Second Connection Hash: " + secondConnection); - exch.sendResponseHeaders(400, 0); - } - } - } - } - } - - static class TestExchangeSupplier extends Http2TestExchangeImpl { - - public TestExchangeSupplier(int streamid, String method, HttpHeaders reqheaders, HttpHeadersBuilder rspheadersBuilder, URI uri, InputStream is, SSLSession sslSession, BodyOutputStream os, Http2TestServerConnection conn, boolean pushAllowed) { - super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession, os, conn, pushAllowed); - } - - public Http2TestServerConnection getTestConnection() { - return this.conn; - } - } -} \ No newline at end of file diff --git a/test/jdk/java/net/httpclient/http2/IdlePooledConnectionTest.java b/test/jdk/java/net/httpclient/http2/IdlePooledConnectionTest.java index 0f4b204fda8..907afc28fa6 100644 --- a/test/jdk/java/net/httpclient/http2/IdlePooledConnectionTest.java +++ b/test/jdk/java/net/httpclient/http2/IdlePooledConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -56,8 +56,9 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; * @summary verify that the HttpClient's HTTP2 idle connection management doesn't close a connection * when that connection has been handed out from the pool to a caller * @library /test/jdk/java/net/httpclient/lib + * /test/lib * @build jdk.httpclient.test.lib.common.HttpServerAdapters - * + * jdk.test.lib.Asserts * @run junit/othervm -Djdk.internal.httpclient.debug=true * -Djdk.httpclient.keepalive.timeout.h2=3 * IdlePooledConnectionTest diff --git a/test/jdk/java/net/httpclient/http2/ProxyTest2.java b/test/jdk/java/net/httpclient/http2/ProxyTest2.java index 733c21ffe68..2fdb1b360e1 100644 --- a/test/jdk/java/net/httpclient/http2/ProxyTest2.java +++ b/test/jdk/java/net/httpclient/http2/ProxyTest2.java @@ -21,20 +21,15 @@ * questions. */ -import com.sun.net.httpserver.HttpContext; -import com.sun.net.httpserver.HttpExchange; -import com.sun.net.httpserver.HttpHandler; -import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpsConfigurator; import com.sun.net.httpserver.HttpsParameters; -import com.sun.net.httpserver.HttpsServer; + import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.Writer; -import java.net.HttpURLConnection; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Proxy; @@ -42,9 +37,7 @@ import java.net.ProxySelector; import java.net.ServerSocket; import java.net.Socket; import java.net.URI; -import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; -import java.security.NoSuchAlgorithmException; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; @@ -72,11 +65,6 @@ public class ProxyTest2 { static { try { - HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { - public boolean verify(String hostname, SSLSession session) { - return true; - } - }); SSLContext.setDefault(new SimpleSSLContext().get()); } catch (IOException ex) { throw new ExceptionInInitializerError(ex); @@ -336,15 +324,4 @@ public class ProxyTest2 { } - static class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) { - super(ctx); - } - - @Override - public void configure (HttpsParameters params) { - params.setSSLParameters (getSSLContext().getSupportedSSLParameters()); - } - } - } diff --git a/test/jdk/java/net/httpclient/http2/PushPromiseContinuation.java b/test/jdk/java/net/httpclient/http2/PushPromiseContinuation.java index 1f9340c5c08..e9c6447e600 100644 --- a/test/jdk/java/net/httpclient/http2/PushPromiseContinuation.java +++ b/test/jdk/java/net/httpclient/http2/PushPromiseContinuation.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -233,18 +233,18 @@ public class PushPromiseContinuation { @Override - public void serverPush(URI uri, HttpHeaders headers, InputStream content) { + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) { HttpHeadersBuilder headersBuilder = new HttpHeadersBuilder(); headersBuilder.setHeader(":method", "GET"); headersBuilder.setHeader(":scheme", uri.getScheme()); headersBuilder.setHeader(":authority", uri.getAuthority()); headersBuilder.setHeader(":path", uri.getPath()); - for (Map.Entry> entry : headers.map().entrySet()) { + for (Map.Entry> entry : reqHeaders.map().entrySet()) { for (String value : entry.getValue()) headersBuilder.addHeader(entry.getKey(), value); } HttpHeaders combinedHeaders = headersBuilder.build(); - OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, content); + OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, rspHeaders, content); // Indicates to the client that a continuation should be expected pp.setFlag(0x0); try { @@ -292,7 +292,7 @@ public class PushPromiseContinuation { } @Override - public void serverPush(URI uri, HttpHeaders headers, InputStream content) { + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) { pushPromiseHeadersBuilder = new HttpHeadersBuilder(); testHeadersBuilder = new HttpHeadersBuilder(); cfs = new ArrayList<>(); @@ -301,7 +301,7 @@ public class PushPromiseContinuation { setPushHeaders(":scheme", uri.getScheme()); setPushHeaders(":authority", uri.getAuthority()); setPushHeaders(":path", uri.getPath()); - for (Map.Entry> entry : headers.map().entrySet()) { + for (Map.Entry> entry : reqHeaders.map().entrySet()) { for (String value : entry.getValue()) { setPushHeaders(entry.getKey(), value); } @@ -318,7 +318,7 @@ public class PushPromiseContinuation { HttpHeaders pushPromiseHeaders = pushPromiseHeadersBuilder.build(); testHeaders = testHeadersBuilder.build(); // Create the Push Promise Frame - OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, pushPromiseHeaders, content, cfs); + OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, pushPromiseHeaders, rspHeaders, content, cfs); // Indicates to the client that a continuation should be expected pp.setFlag(0x0); diff --git a/test/jdk/java/net/httpclient/http2/RedirectTest.java b/test/jdk/java/net/httpclient/http2/RedirectTest.java index 2d7701c5e37..1d7b894bc40 100644 --- a/test/jdk/java/net/httpclient/http2/RedirectTest.java +++ b/test/jdk/java/net/httpclient/http2/RedirectTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -25,9 +25,12 @@ * @test * @bug 8156514 * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @build jdk.httpclient.test.lib.http2.Http2TestExchange + * jdk.httpclient.test.lib.http2.Http2TestServer * jdk.httpclient.test.lib.http2.Http2EchoHandler * jdk.httpclient.test.lib.http2.Http2RedirectHandler + * jdk.test.lib.Asserts + * jdk.test.lib.net.SimpleSSLContext * @run testng/othervm * -Djdk.httpclient.HttpClient.log=frames,ssl,requests,responses,errors * -Djdk.internal.httpclient.debug=true @@ -47,7 +50,6 @@ import java.util.Arrays; import java.util.Iterator; import jdk.httpclient.test.lib.http2.Http2TestServer; import jdk.httpclient.test.lib.http2.Http2TestExchange; -import jdk.httpclient.test.lib.http2.Http2Handler; import jdk.httpclient.test.lib.http2.Http2EchoHandler; import jdk.httpclient.test.lib.http2.Http2RedirectHandler; import org.testng.annotations.Test; diff --git a/test/jdk/java/net/httpclient/http2/SimpleGet.java b/test/jdk/java/net/httpclient/http2/SimpleGet.java new file mode 100644 index 00000000000..8a65292a65d --- /dev/null +++ b/test/jdk/java/net/httpclient/http2/SimpleGet.java @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8087112 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm -XX:+CrashOnOutOfMemoryError SimpleGet + * @run testng/othervm -XX:+CrashOnOutOfMemoryError + * -Dsimpleget.repeat=1 -Dsimpleget.chunks=1 -Dsimpleget.requests=1000 + * SimpleGet + * @run testng/othervm -Dsimpleget.requests=150 + * -Dsimpleget.chunks=16384 + * -Djdk.httpclient.redirects.retrylimit=5 + * -Djdk.httpclient.HttpClient.log=errors + * -XX:+CrashOnOutOfMemoryError + * -XX:+HeapDumpOnOutOfMemoryError + * SimpleGet + */ + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.Assert; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_2; + +public class SimpleGet implements HttpServerAdapters { + static HttpTestServer httpsServer; + static HttpClient client = null; + static SSLContext sslContext; + static String httpsURIString; + static ExecutorService serverExec = Executors.newVirtualThreadPerTaskExecutor(); + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + + httpsServer = HttpTestServer.create(HTTP_2, sslContext, serverExec); + httpsServer.addHandler(new TestHandler(), "/"); + httpsURIString = "https://" + httpsServer.serverAuthority() + "/bar/"; + + httpsServer.start(); + warmup(); + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(); + throw e; + } + } + + private static void warmup() throws Exception { + SimpleSSLContext sslct = new SimpleSSLContext(); + var sslContext = sslct.get(); + + // warmup server + try (var client2 = createClient(sslContext)) { + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString)) + .version(HTTP_2) + .HEAD().build(); + client2.send(request, BodyHandlers.discarding()); + } + + // warmup client + var httpsServer2 = HttpTestServer.create(HTTP_2, sslContext, + Executors.newVirtualThreadPerTaskExecutor()); + httpsServer2.addHandler(new TestHandler(), "/"); + var httpsURIString2 = "https://" + httpsServer2.serverAuthority() + "/bar/"; + httpsServer2.start(); + try { + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString2)) + .version(HTTP_2) + .HEAD().build(); + client.send(request, BodyHandlers.discarding()); + } finally { + httpsServer2.stop(); + } + } + + public static void main(String[] args) throws Exception { + test(); + } + + @Test + public static void test() throws Exception { + try { + long prestart = System.nanoTime(); + initialize(); + long done = System.nanoTime(); + System.out.println("Stat: Initialization and warmup took " + TimeUnit.NANOSECONDS.toMillis(done - prestart) + " millis"); + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString)) + .version(HTTP_2) + .GET().build(); + long start = System.nanoTime(); + var resp = client.send(request, BodyHandlers.ofByteArrayConsumer(b -> {})); + Assert.assertEquals(resp.statusCode(), 200); + long elapsed = System.nanoTime() - start; + System.out.println("Stat: First request took: " + elapsed + " nanos (" + TimeUnit.NANOSECONDS.toMillis(elapsed) + " ms)"); + final int max = property("simpleget.requests", 50); + ; + List>> list = new ArrayList<>(max); + Set connections = new ConcurrentSkipListSet<>(); + long start2 = System.nanoTime(); + for (int i = 0; i < max; i++) { + var cf = client.sendAsync(request, BodyHandlers.ofByteArrayConsumer(b -> {})) + .whenComplete((r, t) -> Optional.ofNullable(r) + .flatMap(HttpResponse::connectionLabel) + .ifPresent(connections::add)); + list.add(cf); + //cf.get(); // uncomment to test with serial instead of concurrent requests + } + try { + CompletableFuture.allOf(list.toArray(new CompletableFuture[0])).join(); + } finally { + long elapsed2 = System.nanoTime() - start2; + long completed = list.stream().filter(CompletableFuture::isDone) + .filter(Predicate.not(CompletableFuture::isCompletedExceptionally)).count(); + connections.forEach(System.out::println); + if (completed > 0) { + System.out.println("Stat: Next " + completed + " requests took: " + elapsed2 + " nanos (" + + TimeUnit.NANOSECONDS.toMillis(elapsed2) + "ms for " + completed + " requests): " + + elapsed2 / completed + " nanos per request (" + + TimeUnit.NANOSECONDS.toMillis(elapsed2) / completed + " ms) on " + + connections.size() + " connections"); + } + } + list.forEach((cf) -> Assert.assertEquals(cf.join().statusCode(), 200)); + } catch (Throwable tt) { + System.err.println("tt caught"); + tt.printStackTrace(); + throw tt; + } finally { + httpsServer.stop(); + } + } + + static HttpClient createClient(SSLContext sslContext) { + return HttpClient.newBuilder() + .sslContext(sslContext) + .version(HTTP_2) + .proxy(Builder.NO_PROXY) + .executor(Executors.newVirtualThreadPerTaskExecutor()) + .build(); + } + + static HttpClient getClient() { + if (client == null) { + client = createClient(sslContext); + } + return client; + } + + static int property(String name, int defaultValue) { + return Integer.parseInt(System.getProperty(name, String.valueOf(defaultValue))); + } + + // 32 * 32 * 1024 * 10 chars = 10Mb responses + // 50 requests => 500Mb + // 100 requests => 1Gb + // 1000 requests => 10Gb + private final static int REPEAT = property("simpleget.repeat", 32); + private final static String RESPONSE = "abcdefghij".repeat(property("simpleget.chunks", 1024*32)); + private final static byte[] RESPONSE_BYTES = RESPONSE.getBytes(StandardCharsets.UTF_8); + + private static class TestHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + try (var in = t.getRequestBody()) { + byte[] input = in.readAllBytes(); + t.sendResponseHeaders(200, RESPONSE_BYTES.length * REPEAT); + try (var out = t.getResponseBody()) { + if (t.getRequestMethod().equals("HEAD")) return; + for (int i=0; i> entry : headers.map().entrySet()) { + for (Map.Entry> entry : reqHeaders.map().entrySet()) { for (String value : entry.getValue()) headersBuilder.addHeader(entry.getKey(), value); } HttpHeaders combinedHeaders = headersBuilder.build(); - OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, content); + OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, rspHeaders, content); pp.setFlag(HeaderFrame.END_HEADERS); try { @@ -311,7 +309,7 @@ public class TrailingHeadersTest { static final BiPredicate ACCEPT_ALL = (x, y) -> true; - private void pushPromise(Http2TestExchange exchange) { + private void pushPromise(Http2TestExchange exchange) throws IOException { URI requestURI = exchange.getRequestURI(); URI uri = requestURI.resolve("/promise"); InputStream is = new ByteArrayInputStream("Sample_Push_Data".getBytes(UTF_8)); diff --git a/test/jdk/java/net/httpclient/http2/UserInfoTest.java b/test/jdk/java/net/httpclient/http2/UserInfoTest.java index 7dafda0c1f8..69cded67297 100644 --- a/test/jdk/java/net/httpclient/http2/UserInfoTest.java +++ b/test/jdk/java/net/httpclient/http2/UserInfoTest.java @@ -32,6 +32,7 @@ import javax.net.ssl.SSLContext; import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import jdk.httpclient.test.lib.http2.Http2TestServer; @@ -46,13 +47,17 @@ import static org.junit.jupiter.api.Assertions.assertEquals; * @test * @bug 8292876 * @library /test/lib /test/jdk/java/net/httpclient/lib - * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.http2.Http2TestExchange + * @compile ../ReferenceTracker.java * @run junit UserInfoTest */ @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class UserInfoTest { + static final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; Http2TestServer server; int port; SSLContext sslContext; @@ -96,6 +101,7 @@ public class UserInfoTest { .proxy(HttpClient.Builder.NO_PROXY) .sslContext(sslContext) .build(); + TRACKER.track(client); URI uri = URIBuilder.newBuilder() .scheme("https") @@ -112,5 +118,10 @@ public class UserInfoTest { HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); assertEquals(200, response.statusCode(), "Test Failed : " + response.uri().getAuthority()); + + client = null; + System.gc(); + var error = TRACKER.check(500); + if (error != null) throw error; } } diff --git a/test/jdk/java/net/httpclient/http3/BadCipherSuiteErrorTest.java b/test/jdk/java/net/httpclient/http3/BadCipherSuiteErrorTest.java new file mode 100644 index 00000000000..abe6889e497 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/BadCipherSuiteErrorTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8157105 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.HttpServerAdapters + * @run testng/othervm/timeout=60 -Djavax.net.debug=ssl -Djdk.httpclient.HttpClient.log=all BadCipherSuiteErrorTest + * @summary check exception thrown when bad TLS parameters selected + */ + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; + +import jdk.test.lib.Asserts; +import jdk.test.lib.net.SimpleSSLContext; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +import org.testng.annotations.Test; + +/** + * When selecting an unacceptable cipher suite the TLS handshake will fail. + * But, the exception that was thrown was not being returned up to application + * causing hang problems + */ +public class BadCipherSuiteErrorTest implements HttpServerAdapters { + + static final String[] CIPHER_SUITES = new String[]{ "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" }; + + static final String SIMPLE_STRING = "Hello world Goodbye world"; + + //@Test(timeOut=5000) + @Test + public void test() throws Exception { + SSLContext sslContext = (new SimpleSSLContext()).get(); + ExecutorService exec = Executors.newCachedThreadPool(); + var builder = newClientBuilderForH3() + .executor(exec) + .sslContext(sslContext) + .version(HTTP_3); + var goodclient = builder.build(); + var badclient = builder + .sslParameters(new SSLParameters(CIPHER_SUITES)) + .build(); + + + + HttpTestServer httpsServer = null; + try { + SSLContext serverContext = (new SimpleSSLContext()).get(); + SSLParameters p = serverContext.getSupportedSSLParameters(); + p.setApplicationProtocols(new String[]{"h3"}); + httpsServer = HttpTestServer.create(HTTP_3_URI_ONLY, serverContext); + httpsServer.addHandler(new HttpTestEchoHandler(), "/"); + String httpsURIString = "https://" + httpsServer.serverAuthority() + "/bar/"; + System.out.println("HTTP/3 Server started on: " + httpsServer.serverAuthority()); + + httpsServer.start(); + URI uri = URI.create(httpsURIString); + + HttpRequest req = HttpRequest.newBuilder(uri) + .POST(BodyPublishers.ofString(SIMPLE_STRING)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + + System.out.println("Sending request with good client to " + uri); + HttpResponse response = goodclient.send(req, BodyHandlers.ofString()); + Asserts.assertEquals(response.statusCode(), 200); + Asserts.assertEquals(response.version(), HTTP_3); + Asserts.assertEquals(response.body(), SIMPLE_STRING); + System.out.println("Expected response successfully received"); + try { + System.out.println("Sending request with bad client to " + uri); + response = badclient.send(req, BodyHandlers.discarding()); + throw new RuntimeException("Unexpected response: " + response); + } catch (IOException e) { + System.out.println("Caught Expected IOException: " + e); + } + System.out.println("DONE"); + } finally { + if (httpsServer != null ) { httpsServer.stop(); } + exec.close(); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/FramesDecoderTest.java b/test/jdk/java/net/httpclient/http3/FramesDecoderTest.java new file mode 100644 index 00000000000..d406ee491cb --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/FramesDecoderTest.java @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.Http3Frame; +import jdk.internal.net.http.http3.frames.MalformedFrame; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.nio.ByteBuffer; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + + +/* + * @test + * @library /test/lib + * @modules java.net.http/jdk.internal.net.http.http3 + * @modules java.net.http/jdk.internal.net.http.http3.frames + * @modules java.net.http/jdk.internal.net.http.quic.streams + * @run junit/othervm FramesDecoderTest + * @summary Tests to check HTTP3 methods decode frames correctly + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class FramesDecoderTest { + // frames with arbitrary content, interpreted as PartialFrames + byte[][] vlframes() { + return new byte[][]{ + {0, 2, 0, 0}, // DATA frame, 2 bytes = 0,0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 0, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 2, 0, 0}, // DATA frame, 8-byte VL encoding, 2 bytes = 0,0 + {1, 2, 0, 0}, // HEADERS frame, 2 bytes = 0,0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 1, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 2, 0, 0}, // HEADERS frame, 8-byte VL encoding, 2 bytes = 0,0 + {5, 3, 0, 0, 0}, // PUSH_PROMISE frame, Push ID = 0, 2 bytes = 0,0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 5, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 10, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // PUSH_PROMISE frame, 8-byte VL encoding, Push ID = 0, 2 bytes = 0,0 + {33, 2, 0, 0}, // RESERVED frame, 2 bytes = 0,0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 33, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 2, 0, 0}, // RESERVED frame, 8-byte VL encoding, 2 bytes = 0,0 + {32, 2, 0, 0}, // UNKNOWN frame, 2 bytes = 0,0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 32, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 2, 0, 0}, // UNKNOWN frame, 8-byte VL encoding, 2 bytes = 0,0 + }; + } + + // frames with predefined content, correct + byte[][] fixedframes() { + return new byte[][]{ + {3, 1, 0}, // CANCEL_PUSH frame, Push ID = 0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 3, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 8, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 0}, // CANCEL_PUSH frame, 8-byte VL encoding, Push ID = 0 + {7, 1, 0}, // GOAWAY frame, Push ID = 0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 7, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 8, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 0}, // GOAWAY frame, 8-byte VL encoding, Push ID = 0 + {13, 1, 0}, // MAX_PUSH_ID frame, Push ID = 0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 13, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 8, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 0}, // MAX_PUSH_ID frame, 8-byte VL encoding, Push ID = 0 + {4, 0}, // SETTINGS frame, empty + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 4, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 0}, // SETTINGS frame, 8-byte VL encoding, empty + {4, 2, 31, 0}, // SETTINGS frame, 31(reserved)->0 + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 4, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 16, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 33, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 0}, // SETTINGS frame, 8-byte VL encoding, 33(reserved)->0 + {4, 3, 0x40, 33, 0}, // SETTINGS frame, 33(reserved)->0 + }; + } + + // incorrect frames + byte[][] badframes() { + return new byte[][]{ + {3, 2, 0, 0}, // CANCEL_PUSH frame, Push ID = 0, extra byte + {3, 2, (byte) 0xC0, 0}, // CANCEL_PUSH frame, Push ID = truncated VL + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 3, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 5, (byte) 0x80, 0, 0, 0, 0}, // CANCEL_PUSH frame, 8-byte VL encoding, Push ID = 0, extra byte + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 3, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 7, (byte) 0xC0, 0, 0, 0, 0, 0, 0}, // CANCEL_PUSH frame, 8-byte VL encoding, Push ID = truncated VL + {7, 2, 0, 0}, // GOAWAY frame, Push ID = 0, extra byte + {7, 2, (byte) 0xC0, 0}, // GOAWAY frame, Push ID = truncated VL + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 7, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 5, (byte) 0x80, 0, 0, 0, 0}, // GOAWAY frame, 8-byte VL encoding, Push ID = 0, extra byte + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 7, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 7, (byte) 0xC0, 0, 0, 0, 0, 0, 0}, // GOAWAY frame, 8-byte VL encoding, Push ID = truncated VL + {13, 2, 0, 0}, // MAX_PUSH_ID frame, Push ID = 0, extra byte + {13, 2, (byte) 0xC0, 0}, // MAX_PUSH_ID frame, Push ID = truncated VL + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 13, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 5, (byte) 0x80, 0, 0, 0, 0}, // MAX_PUSH_ID frame, 8-byte VL encoding, Push ID = 0, extra byte + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 13, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 7, (byte) 0xC0, 0, 0, 0, 0, 0, 0}, // MAX_PUSH_ID frame, 8-byte VL encoding, Push ID = truncated VL + {(byte) 0xC0, 0, 0, 0, 0, 0, 0, 5, (byte) 0xC0, 0, 0, 0, 0, 0, 0, 1, (byte) 0xC0 }, // PUSH_PROMISE frame, 8-byte VL encoding, Push ID = truncated VL + {4, 5, 33, 0, 0x41, 0, 0x40}, // SETTINGS frame, 33(reserved)->0, 64 -> truncated VL + {4, 4, 33, 0, 0x41, 0}, // SETTINGS frame, 33(reserved)->0, 64 -> ? + {4, 3, 33, 0, 0x41}, // SETTINGS frame, 33(reserved)->0, truncated VL + {4, 2, 33, 0x40}, // SETTINGS frame, 33(reserved)-> truncated VL + {4, 1, 33}, // SETTINGS frame, 33(reserved)->? + }; + } + + private static int bufLength(List bufs) { + return bufs.stream().mapToInt(ByteBuffer::remaining).sum(); + } + + @ParameterizedTest + @MethodSource("vlframes") + public void testFullVLFrames(byte[] frame) { + // offer the entire frame at once + FramesDecoder fd = new FramesDecoder("test"); + fd.submit(ByteBuffer.wrap(frame)); + fd.submit(QuicStreamReader.EOF); + Http3Frame h3frame = fd.poll(); + assertEquals(2, h3frame.streamingLength()); + assertEquals(h3frame, fd.poll()); + assertFalse(fd.eof()); + List bufs = fd.readPayloadBytes(); + assertEquals(2, bufLength(bufs)); + assertNull(fd.poll()); + assertTrue(fd.eof()); + } + + @ParameterizedTest + @MethodSource("vlframes") + public void testSplitVLFrames(byte[] frame) { + // offer the frame one byte at a time + FramesDecoder fd = new FramesDecoder("test"); + ByteBuffer buffer = ByteBuffer.wrap(frame); + for (int i = 1; i <= frame.length; i++) { + buffer.position(i-1); + buffer.limit(i); + fd.submit(buffer.asReadOnlyBuffer()); + if (i < frame.length - 2) { + assertNull(fd.poll()); + } else { + Http3Frame h3frame = fd.poll(); + assertEquals(2, h3frame.streamingLength()); + } + } + Http3Frame h3frame = fd.poll(); + assertEquals(2, h3frame.streamingLength()); + assertEquals(h3frame, fd.poll()); + assertFalse(fd.eof()); + List bufs = fd.readPayloadBytes(); + assertEquals(2, bufLength(bufs)); + assertNull(fd.poll()); + assertFalse(fd.eof()); + fd.submit(QuicStreamReader.EOF); + assertTrue(fd.eof()); + } + + @ParameterizedTest + @MethodSource("fixedframes") + public void testFullGoodFrames(byte[] frame) { + // offer the entire frame at once + FramesDecoder fd = new FramesDecoder("test"); + fd.submit(ByteBuffer.wrap(frame)); + fd.submit(QuicStreamReader.EOF); + Http3Frame h3frame = fd.poll(); + assertEquals(0, h3frame.streamingLength()); + assertTrue(fd.eof()); + List bufs = fd.readPayloadBytes(); + assertNull(bufs); + assertNull(fd.poll()); + assertTrue(fd.eof()); + } + + @ParameterizedTest + @MethodSource("fixedframes") + public void testSplitGoodFrames(byte[] frame) { + // offer the frame one byte at a time + FramesDecoder fd = new FramesDecoder("test"); + ByteBuffer buffer = ByteBuffer.wrap(frame); + for (int i = 1; i <= frame.length; i++) { + buffer.position(i-1); + buffer.limit(i); + fd.submit(buffer.asReadOnlyBuffer()); + if (i < frame.length) { + assertNull(fd.poll()); + } else { + Http3Frame h3frame = fd.poll(); + assertEquals(0, h3frame.streamingLength()); + } + } + assertNull(fd.poll()); + assertFalse(fd.eof()); + fd.submit(QuicStreamReader.EOF); + assertTrue(fd.eof()); + } + + @ParameterizedTest + @MethodSource("badframes") + public void testFullBadFrames(byte[] frame) { + // offer the entire frame at once + FramesDecoder fd = new FramesDecoder("test"); + fd.submit(ByteBuffer.wrap(frame)); + fd.submit(QuicStreamReader.EOF); + Http3Frame h3frame = fd.poll(); + assertInstanceOf(MalformedFrame.class, h3frame); + assertEquals(Http3Error.H3_FRAME_ERROR.code(), ((MalformedFrame)h3frame).getErrorCode()); + } + + @ParameterizedTest + @MethodSource("badframes") + public void testSplitBadFrames(byte[] frame) { + // offer the frame one byte at a time + FramesDecoder fd = new FramesDecoder("test"); + ByteBuffer buffer = ByteBuffer.wrap(frame); + for (int i = 1; i <= frame.length; i++) { + buffer.position(i-1); + buffer.limit(i); + fd.submit(buffer.asReadOnlyBuffer()); + if (i < frame.length) { + assertNull(fd.poll()); + } else { + Http3Frame h3frame = fd.poll(); + assertInstanceOf(MalformedFrame.class, h3frame); + assertEquals(Http3Error.H3_FRAME_ERROR.code(), ((MalformedFrame)h3frame).getErrorCode()); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/GetHTTP3Test.java b/test/jdk/java/net/httpclient/http3/GetHTTP3Test.java new file mode 100644 index 00000000000..a67710c485f --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/GetHTTP3Test.java @@ -0,0 +1,476 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.Builder; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import javax.net.ssl.SSLContext; + +import jdk.test.lib.net.SimpleSSLContext; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import org.testng.ITestContext; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; + +import static java.lang.System.out; + + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @compile ../ReferenceTracker.java + * @run testng/othervm/timeout=60 -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * GetHTTP3Test + * @summary Basic HTTP/3 GET test + */ +// -Djdk.httpclient.http3.maxDirectConnectionTimeout=2500 +public class GetHTTP3Test implements HttpServerAdapters { + + // The response body + static final String BODY = """ + May the road rise up to meet you. + May the wind be always at your back. + May the sun shine warm upon your face; + """; + + SSLContext sslContext; + HttpTestServer h3TestServer; // HTTP/2 ( h2 + h3) + String h3URI; + + static final int ITERATION_COUNT = 4; + // a shared executor helps reduce the amount of threads created by the test + static final Executor executor = new TestExecutor(Executors.newCachedThreadPool()); + static final ConcurrentMap FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong serverCount = new AtomicLong(); + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + final Set sharedClientHasH3 = ConcurrentHashMap.newKeySet(); + private volatile HttpClient sharedClient; + private boolean directQuicConnectionSupported; + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @java.lang.Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + System.out.printf(now() + "Task %s failed: %s%n", id, t); + System.err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + protected boolean stopAfterFirstFailure() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); + } + + @BeforeMethod + void beforeMethod(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + var x = new SkipException("Skipping: some test failed"); + x.setStackTrace(new StackTraceElement[0]); + throw x; + } + } + + @AfterClass + final void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d servers and %d clients%n", + now(), serverCount.get(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.forEach((key, value) -> { + out.printf("\t%s: %s%n", key, value); + value.printStackTrace(out); + value.printStackTrace(); + }); + if (tasksFailed) { + System.out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + private String[] uris() { + return new String[] { + h3URI, + }; + } + + @DataProvider(name = "variants") + public Object[][] variants(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + String[] uris = uris(); + Object[][] result = new Object[uris.length * 2 * 2 * 2][]; + int i = 0; + for (var version : List.of(Optional.empty(), Optional.of(HTTP_3))) { + for (Version firstRequestVersion : List.of(HTTP_2, HTTP_3)) { + for (boolean sameClient : List.of(false, true)) { + for (String uri : uris()) { + result[i++] = new Object[]{uri, firstRequestVersion, sameClient, version}; + } + } + } + } + assert i == result.length; + return result; + } + + @DataProvider(name = "uris") + public Object[][] uris(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + Object[][] result = {{h3URI}}; + return result; + } + + private HttpClient makeNewClient() { + clientCount.incrementAndGet(); + HttpClient client = newClientBuilderForH3() + .version(HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY) + .executor(executor) + .sslContext(sslContext) + .connectTimeout(Duration.ofSeconds(10)) + .build(); + return TRACKER.track(client); + } + + HttpClient newHttpClient(boolean share) { + if (!share) return makeNewClient(); + HttpClient shared = sharedClient; + if (shared != null) return shared; + synchronized (this) { + shared = sharedClient; + if (shared == null) { + shared = sharedClient = makeNewClient(); + } + return shared; + } + } + + + @Test(dataProvider = "variants") + public void testAsync(String uri, Version firstRequestVersion, boolean sameClient, Optional version) throws Exception { + System.out.println("Request to " + uri +"/Async/*" + + ", firstRequestVersion=" + firstRequestVersion + + ", sameclient=" + sameClient + ", version=" + version); + + HttpClient client = newHttpClient(sameClient); + final URI headURI = URI.create(uri + "/Async/First/HEAD"); + final Builder headBuilder = HttpRequest.newBuilder(headURI) + .version(firstRequestVersion) + .HEAD(); + Http3DiscoveryMode config = null; + if (firstRequestVersion == HTTP_3 && !directQuicConnectionSupported) { + // if the server doesn't listen for HTTP/3 on the same port than TCP, then + // do not attempt to connect to the URI host:port through UDP - as we might + // be connecting to some other server. Once the first request has gone + // through, there should be an AltService record for the server, so + // we should be able to safely use any default config (except + // HTTP_3_URI_ONLY) + config = ALT_SVC; + } + if (config != null) { + out.println("first request will use " + config); + headBuilder.setOption(H3_DISCOVERY, config); + config = null; + } + + HttpResponse response1 = client.send(headBuilder.build(), BodyHandlers.ofString()); + assertEquals(response1.statusCode(), 200, "Unexpected first response code"); + assertEquals(response1.body(), "", "Unexpected first response body"); + boolean expectH3 = sameClient && sharedClientHasH3.contains(headURI.getRawAuthority()); + if (firstRequestVersion == HTTP_3) { + if (expectH3) { + out.println("Expecting HEAD response over HTTP_3"); + assertEquals(response1.version(), HTTP_3, "Unexpected first response version"); + } + } else { + out.println("Expecting HEAD response over HTTP_2"); + assertEquals(response1.version(), HTTP_2, "Unexpected first response version"); + } + out.println("HEAD response version: " + response1.version()); + if (response1.version() == HTTP_2) { + if (sameClient) { + sharedClientHasH3.add(headURI.getRawAuthority()); + } + expectH3 = version.isEmpty() && client.version() == HTTP_3; + if (version.orElse(null) == HTTP_3 && !directQuicConnectionSupported) { + config = ALT_SVC; + expectH3 = true; + } + // we can expect H3 only if the (default) config is not ANY + if (expectH3) { + out.println("first response came over HTTP/2, so we should expect all responses over HTTP/3"); + } + } else if (response1.version() == HTTP_3) { + expectH3 = directQuicConnectionSupported && version.orElse(null) == HTTP_3; + if (expectH3) { + out.println("first response came over HTTP/3, direct connection supported: expect HTTP/3"); + } else if (firstRequestVersion == HTTP_3 && version.isEmpty() + && config == null && directQuicConnectionSupported) { + config = ANY; + expectH3 = true; + } + } + out.printf("request version: %s, directConnectionSupported: %s, first response: %s," + + " config: %s, expectH3: %s%n", + version, directQuicConnectionSupported, response1.version(), config, expectH3); + if (expectH3) { + out.println("All responses should now come through HTTP/3"); + } + + Builder builder = HttpRequest.newBuilder() + .GET(); + version.ifPresent(builder::version); + if (config != null) { + builder.setOption(H3_DISCOVERY, config); + } + Map>> responses = new HashMap<>(); + for (int i = 0; i < ITERATION_COUNT; i++) { + HttpRequest request = builder.uri(URI.create(uri+"/Async/GET/"+i)).build(); + System.out.println("Iteration: " + request.uri()); + responses.put(request.uri(), client.sendAsync(request, BodyHandlers.ofString())); + } + int h3Count = 0; + while (!responses.isEmpty()) { + CompletableFuture.anyOf(responses.values().toArray(CompletableFuture[]::new)).join(); + var done = responses.entrySet().stream() + .filter((e) -> e.getValue().isDone()).toList(); + for (var e : done) { + URI u = e.getKey(); + responses.remove(u); + out.println("Checking response: " + u); + var response = e.getValue().get(); + out.println("Response is: " + response + ", [version: " + response.version() + "]"); + assertEquals(response.statusCode(), 200,"status for " + u); + assertEquals(response.body(), BODY,"body for " + u); + if (expectH3) { + assertEquals(response.version(), HTTP_3, "version for " + u); + } + if (response.version() == HTTP_3) { + h3Count++; + } + } + } + if (client.version() == HTTP_3 || version.orElse(null) == HTTP_3) { + if (h3Count == 0) { + throw new AssertionError("No request used HTTP/3"); + } + } + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + System.out.println("test: DONE"); + } + + @Test(dataProvider = "uris") + public void testSync(String h3URI) throws Exception { + HttpClient client = makeNewClient(); + Builder builder = HttpRequest.newBuilder(URI.create(h3URI + "/Sync/GET/1")) + .version(HTTP_3) + .GET(); + if (!directQuicConnectionSupported) { + // if the server doesn't listen for HTTP/3 on the same port than TCP, then + // do not attempt to connect to the URI host:port through UDP - as we might + // be connecting to some other server. Once the first request has gone + // through, there should be an AltService record for the server, so + // we should be able to safely use any default config (except + // HTTP_3_URI_ONLY) + builder.setOption(H3_DISCOVERY, ALT_SVC); + } + + HttpRequest request = builder.build(); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + assertEquals(response.statusCode(), 200, "first response status"); + if (directQuicConnectionSupported) { + // TODO unreliable assertion + //assertEquals(response.version(), HTTP_3, "Unexpected first response version"); + } else { + assertEquals(response.version(), HTTP_2, "Unexpected first response version"); + } + assertEquals(response.body(), BODY, "first response body"); + + request = builder.uri(URI.create(h3URI + "/Sync/GET/2")).build(); + response = client.send(request, BodyHandlers.ofString()); + out.println("Response #2: " + response); + out.println("Version #2: " + response.version()); + assertEquals(response.statusCode(), 200, "second response status"); + assertEquals(response.version(), HTTP_3, "second response version"); + assertEquals(response.body(), BODY, "second response body"); + + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + + final Http2TestServer h2WithAltService = new Http2TestServer("localhost", true, + sslContext).enableH3AltServiceOnSamePort(); + h3TestServer = HttpTestServer.of(h2WithAltService); + h3TestServer.addHandler(new Handler(), "/h3/testH3/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3/testH3/GET"; + serverCount.addAndGet(1); + h3TestServer.start(); + directQuicConnectionSupported = h2WithAltService.supportsH3DirectConnection(); + } + + @AfterTest + public void teardown() throws Exception { + System.err.println("======================================================="); + System.err.println(" Tearing down test"); + System.err.println("======================================================="); + String sharedClientName = + sharedClient == null ? null : sharedClient.toString(); + sharedClient = null; + Thread.sleep(100); + AssertionError fail = TRACKER.check(500); + try { + h3TestServer.stop(); + } finally { + if (fail != null) { + if (sharedClientName != null) { + System.err.println("Shared client name is: " + sharedClientName); + } + throw fail; + } + } + } + + static class Handler implements HttpTestHandler { + public Handler() {} + + volatile int invocation = 0; + + @java.lang.Override + public void handle(HttpTestExchange t) + throws IOException { + try { + URI uri = t.getRequestURI(); + System.err.printf("Handler received request for %s\n", uri); + + if ((invocation++ % 2) == 1) { + System.err.printf("Server sending %d - chunked\n", 200); + t.sendResponseHeaders(200, -1); + } else { + System.err.printf("Server sending %d - %s length\n", 200, BODY.length()); + t.sendResponseHeaders(200, BODY.length()); + } + try (InputStream is = t.getRequestBody(); + OutputStream os = t.getResponseBody()) { + assertEquals(is.readAllBytes().length, 0); + if (!"HEAD".equals(t.getRequestMethod())) { + String[] body = BODY.split("\n"); + for (String line : body) { + os.write(line.getBytes(StandardCharsets.UTF_8)); + os.write('\n'); + os.flush(); + } + } + } + } catch (Throwable e) { + e.printStackTrace(System.err); + throw new IOException(e); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3BadHeadersTest.java b/test/jdk/java/net/httpclient/http3/H3BadHeadersTest.java new file mode 100644 index 00000000000..3bc59a4d67f --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3BadHeadersTest.java @@ -0,0 +1,330 @@ +/* + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.net.SimpleSSLContext + * @compile ../ReferenceTracker.java + * @run testng/othervm -Djdk.internal.httpclient.debug=true H3BadHeadersTest + * @summary this test verifies the behaviour of the HttpClient when presented + * with bad headers + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.List; +import java.util.Map.Entry; +import java.util.concurrent.ExecutionException; + +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.util.List.of; +import static java.util.Map.entry; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.fail; + +public class H3BadHeadersTest implements HttpServerAdapters { + + private static final List>> BAD_HEADERS = of( + of(entry(":status", "200"), entry(":hello", "GET")), // Unknown pseudo-header + of(entry(":status", "200"), entry("hell o", "value")), // Space in the name + of(entry(":status", "200"), entry("hello", "line1\r\n line2\r\n")), // Multiline value + of(entry(":status", "200"), entry("hello", "DE" + ((char) 0x7F) + "L")) // Bad byte in value + // Not easily testable with H3, because we use a HttpHeadersBuilders which sorts headers... + // of(entry("hello", "world!"), entry(":status", "200")) // Pseudo header is not the first one + ); + + static final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + + SSLContext sslContext; + HttpTestServer http3TestServer; // HTTP/3 ( h3 only ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 + h3 ) + String http3URI; + String https2URI; + + + @DataProvider(name = "variants") + public Object[][] variants() { + return new Object[][] { + { http3URI, false}, + { https2URI, false}, + { http3URI, true}, + { https2URI, true}, + }; + } + + + @Test(dataProvider = "variants") + void test(String uri, + boolean sameClient) + throws Exception + { + System.out.printf("%ntest %s, %s, STARTING%n%n", uri, sameClient); + System.err.printf("%ntest %s, %s, STARTING%n%n", uri, sameClient); + var config = uri.startsWith(http3URI) + ? Http3DiscoveryMode.HTTP_3_URI_ONLY + : https2TestServer.supportsH3DirectConnection() + ? Http3DiscoveryMode.ANY + : Http3DiscoveryMode.ALT_SVC; + + boolean sendHeadRequest = config != Http3DiscoveryMode.HTTP_3_URI_ONLY; + + HttpClient client = null; + for (int i=0; i< BAD_HEADERS.size(); i++) { + boolean needsHeadRequest = false; + if (!sameClient || client == null) { + needsHeadRequest = sendHeadRequest; + client = newClientBuilderForH3() + .version(Version.HTTP_3) + .sslContext(sslContext) + .build(); + } + + if (needsHeadRequest) { + URI simpleURI = URI.create(uri); + HttpRequest head = HttpRequest.newBuilder(simpleURI) + .version(Version.HTTP_2) + .HEAD().setOption(H3_DISCOVERY, config).build(); + System.out.println("\nSending HEAD request: " + head); + var headResponse = client.send(head, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + URI uriWithQuery = URI.create(uri + "?BAD_HEADERS=" + i); + HttpRequest request = HttpRequest.newBuilder(uriWithQuery) + .POST(BodyPublishers.ofString("Hello there!")) + .setOption(H3_DISCOVERY, config) + .version(Version.HTTP_3) + .build(); + System.out.println("\nSending request:" + uriWithQuery); + try { + HttpResponse response = client.send(request, BodyHandlers.ofString()); + fail("Expected exception, got :" + response + ", " + response.body()); + } catch (IOException ioe) { + System.out.println("Got EXPECTED: " + ioe); + assertDetailMessage(ioe, i); + } + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + } + if (client != null) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + System.out.printf("%ntest %s, %s, DONE%n%n", uri, sameClient); + System.err.printf("%ntest %s, %s, DONE%n%n", uri, sameClient); + } + + @Test(dataProvider = "variants") + void testAsync(String uri, + boolean sameClient) throws Exception + { + + System.out.printf("%ntestAsync %s, %s, STARTING%n%n", uri, sameClient); + System.err.printf("%ntestAsync %s, %s, STARTING%n%n", uri, sameClient); + var config = uri.startsWith(http3URI) + ? Http3DiscoveryMode.HTTP_3_URI_ONLY + : https2TestServer.supportsH3DirectConnection() + ? Http3DiscoveryMode.ANY + : Http3DiscoveryMode.ALT_SVC; + + boolean sendHeadRequest = config != Http3DiscoveryMode.HTTP_3_URI_ONLY; + + HttpClient client = null; + for (int i=0; i< BAD_HEADERS.size(); i++) { + boolean needsHeadRequest = false; + if (!sameClient || client == null) { + needsHeadRequest = sendHeadRequest; + client = newClientBuilderForH3() + .version(Version.HTTP_3) + .sslContext(sslContext) + .build(); + } + + if (needsHeadRequest) { + URI simpleURI = URI.create(uri); + HttpRequest head = HttpRequest.newBuilder(simpleURI) + .version(Version.HTTP_2) + .HEAD() + .setOption(H3_DISCOVERY, config) + .build(); + System.out.println("\nSending HEAD request: " + head); + + var headResponse = client.send(head, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + URI uriWithQuery = URI.create(uri + "?BAD_HEADERS=" + i); + HttpRequest request = HttpRequest.newBuilder(uriWithQuery) + .POST(BodyPublishers.ofString("Hello there!")) + .setOption(H3_DISCOVERY, config) + .version(Version.HTTP_3) + .build(); + System.out.println("\nSending request:" + uriWithQuery); + + Throwable t = null; + try { + HttpResponse response = client.sendAsync(request, BodyHandlers.ofString()).get(); + fail("Expected exception, got :" + response + ", " + response.body()); + } catch (Throwable t0) { + System.out.println("Got EXPECTED: " + t0); + if (t0 instanceof ExecutionException) { + t0 = t0.getCause(); + } + t = t0; + } + assertDetailMessage(t, i); + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + } + if (client != null) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker, 1500); + if (error != null) throw error; + } + System.out.printf("%ntestAsync %s, %s, DONE%n%n", uri, sameClient); + System.err.printf("%ntestAsync %s, %s, DONE%n%n", uri, sameClient); + } + + // Assertions based on implementation specific detail messages. Keep in + // sync with implementation. + static void assertDetailMessage(Throwable throwable, int iterationIndex) { + try { + assertTrue(throwable instanceof IOException, + "Expected IOException, got, " + throwable); + assertNotNull(throwable.getMessage(), "No message for " + throwable); + assertTrue(throwable.getMessage().contains("malformed response"), + "Expected \"malformed response\" in: " + throwable.getMessage()); + + if (iterationIndex == 0) { // unknown + assertTrue(throwable.getMessage().contains("Unknown pseudo-header"), + "Expected \"Unknown pseudo-header\" in: " + throwable.getMessage()); + } else if (iterationIndex == 4) { // unexpected + assertTrue(throwable.getMessage().contains(" Unexpected pseudo-header"), + "Expected \" Unexpected pseudo-header\" in: " + throwable.getMessage()); + } else { + assertTrue(throwable.getMessage().contains("Bad header"), + "Expected \"Bad header\" in: " + throwable.getMessage()); + } + } catch (AssertionError e) { + System.out.println("Exception does not match expectation: " + throwable); + throwable.printStackTrace(System.out); + throw e; + } + } + + @BeforeTest + public void setup() throws Exception { + System.out.println("creating servers"); + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + + http3TestServer = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext); + http3TestServer.addHandler(new BadHeadersHandler(), "/http3/echo"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/echo"; + + https2TestServer = HttpTestServer.create(Http3DiscoveryMode.ANY, sslContext); + https2TestServer.addHandler(new BadHeadersHandler(), "/https2/echo"); + https2URI = "https://" + https2TestServer.serverAuthority() + "/https2/echo"; + + http3TestServer.start(); + https2TestServer.start(); + System.out.println("server started"); + } + + @AfterTest + public void teardown() throws Exception { + System.err.println("\n\n**** stopping servers\n"); + System.out.println("stopping servers"); + http3TestServer.stop(); + https2TestServer.stop(); + System.out.println("servers stopped"); + } + + static class BadHeadersHandler implements HttpTestHandler { + + @Override + public void handle(HttpTestExchange t) throws IOException { + var uri = t.getRequestURI(); + String query = uri.getRawQuery(); + if (query != null && !query.isEmpty()) { + int badHeadersIndex = Integer.parseInt(query.substring(query.indexOf("=") + 1)); + assert badHeadersIndex >= 0 && badHeadersIndex < BAD_HEADERS.size() : + "Unexpected badHeadersIndex value: " + badHeadersIndex; + List> headers = BAD_HEADERS.get(badHeadersIndex); + var responseHeaders = t.getResponseHeaders(); + for (var e : headers) { + responseHeaders.addHeader(e.getKey(), e.getValue()); + } + } + try (InputStream is = t.getRequestBody(); + OutputStream os = t.getResponseBody()) { + byte[] bytes = is.readAllBytes(); + t.sendResponseHeaders(200, bytes.length); + if (t.getRequestMethod().equals("HEAD")) { + os.close(); + } else { + os.write(bytes); + } + } + } + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3BasicTest.java b/test/jdk/java/net/httpclient/http3/H3BasicTest.java new file mode 100644 index 00000000000..b99520ab830 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3BasicTest.java @@ -0,0 +1,405 @@ +/* + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @bug 8087112 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @compile ../ReferenceTracker.java + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.Asserts + * jdk.test.lib.Utils + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.httpclient.HttpClient.log=ssl,requests,responses,errors + * -Djdk.internal.httpclient.debug=true + * H3BasicTest + */ +// -Dseed=-163464189156654174 + +import java.io.IOException; +import java.net.*; +import javax.net.ssl.*; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.file.*; +import java.util.Random; +import java.util.concurrent.*; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.http2.Http2EchoHandler; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.test.lib.RandomFactory; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static jdk.test.lib.Asserts.assertFileContentsEqual; +import static jdk.test.lib.Utils.createTempFile; +import static jdk.test.lib.Utils.createTempFileOfSize; + +public class H3BasicTest implements HttpServerAdapters { + + private static final Random RANDOM = RandomFactory.getRandom(); + + private static final String CLASS_NAME = H3BasicTest.class.getSimpleName(); + + static int http3Port, https2Port; + static Http3TestServer http3OnlyServer; + static Http2TestServer https2AltSvcServer; + static HttpClient client = null; + static ExecutorService clientExec; + static ExecutorService serverExec; + static SSLContext sslContext; + static volatile String http3URIString, pingURIString, https2URIString; + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + + // server that only supports HTTP/3 + http3OnlyServer = new Http3TestServer(sslContext, serverExec); + http3OnlyServer.addHandler("/", new Http2EchoHandler()); + http3OnlyServer.addHandler("/ping", new EchoWithPingHandler()); + http3Port = http3OnlyServer.getAddress().getPort(); + System.out.println("HTTP/3 server started at localhost:" + http3Port); + + // server that supports both HTTP/2 and HTTP/3, with HTTP/3 on an altSvc port. + https2AltSvcServer = new Http2TestServer(true, 0, serverExec, sslContext); + if (RANDOM.nextBoolean()) { + https2AltSvcServer.enableH3AltServiceOnEphemeralPort(); + } else { + https2AltSvcServer.enableH3AltServiceOnSamePort(); + } + https2AltSvcServer.addHandler(new Http2EchoHandler(), "/"); + https2Port = https2AltSvcServer.getAddress().getPort(); + if (https2AltSvcServer.supportsH3DirectConnection()) { + System.out.println("HTTP/2 server (same HTTP/3 origin) started at localhost:" + https2Port); + } else { + System.out.println("HTTP/2 server (different HTTP/3 origin) started at localhost:" + https2Port); + } + + http3URIString = "https://localhost:" + http3Port + "/foo/"; + pingURIString = "https://localhost:" + http3Port + "/ping/"; + https2URIString = "https://localhost:" + https2Port + "/bar/"; + + http3OnlyServer.start(); + https2AltSvcServer.start(); + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(); + throw e; + } + } + + static final List> cfs = Collections + .synchronizedList( new LinkedList<>()); + + static CompletableFuture currentCF; + + static class EchoWithPingHandler extends Http2EchoHandler { + private final Object lock = new Object(); + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + // for now only one ping active at a time. don't want to saturate + System.out.println("PING handler invoked for " + exchange.getRequestURI()); + synchronized(lock) { + CompletableFuture cf = currentCF; + if (cf == null || cf.isDone()) { + cf = exchange.sendPing(); + assert cf != null; + cfs.add(cf); + currentCF = cf; + } + } + super.handle(exchange); + } + } + + @Test + public static void test() throws Exception { + try { + initialize(); + System.out.println("servers initialized"); + warmup(false); + warmup(true); + System.out.println("warmup finished"); + simpleTest(false, false); + System.out.println("simpleTest(false, false): done"); + simpleTest(false, true); + System.out.println("simpleTest(false, true): done"); + simpleTest(true, false); + System.out.println("simpleTest(true, false): done"); + System.out.println("simple tests finished"); + streamTest(false); + streamTest(true); + System.out.println("stream tests finished"); + paramsTest(); + System.out.println("params test finished"); + CompletableFuture.allOf(cfs.toArray(new CompletableFuture[0])).join(); + synchronized (cfs) { + for (CompletableFuture cf : cfs) { + System.out.printf("Ping ack received in %d millisec\n", cf.get()); + } + } + System.out.println("closing client"); + if (client != null) { + var tracker = ReferenceTracker.INSTANCE; + tracker.track(client); + client = null; + System.gc(); + var error = tracker.check(1500); + clientExec.close(); + if (error != null) throw error; + } + } catch (Throwable tt) { + System.err.println("tt caught"); + tt.printStackTrace(); + throw tt; + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + serverExec.close(); + } + } + + static HttpClient getClient() { + if (client == null) { + serverExec = Executors.newCachedThreadPool(); + clientExec = Executors.newCachedThreadPool(); + client = HttpServerAdapters.createClientBuilderForH3() + .executor(clientExec) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + } + return client; + } + + static URI getURI(boolean altSvc) { + return getURI(altSvc, false); + } + + static URI getURI(boolean altsvc, boolean ping) { + if (altsvc) + return URI.create(https2URIString); + else + return URI.create(ping ? pingURIString: http3URIString); + } + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf ("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf ("Test failed: wrong string %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static final AtomicInteger count = new AtomicInteger(); + static Http3DiscoveryMode config(boolean http3only) { + if (http3only) return HTTP_3_URI_ONLY; + // if the server supports H3 direct connection, we can + // additionally use HTTP_3_URI_ONLY; Otherwise we can + // only use ALT_SVC - or ANY (given that we should have + // preloaded an ALT_SVC in warmup) + int bound = https2AltSvcServer.supportsH3DirectConnection() ? 4 : 3; + int rand = RANDOM.nextInt(bound); + count.getAndIncrement(); + return switch (rand) { + case 1 -> ANY; + case 2 -> ALT_SVC; + case 3 -> HTTP_3_URI_ONLY; + default -> null; + }; + } + + static final String SIMPLE_STRING = "Hello world Goodbye world"; + + static final int LOOPS = 13; + static final int FILESIZE = 64 * 1024 + 200; + + static void streamTest(boolean altSvc) throws Exception { + URI uri = getURI(altSvc); + System.err.printf("streamTest %b to %s\n" , altSvc, uri); + System.out.printf("streamTest %b to %s\n" , altSvc, uri); + + HttpClient client = getClient(); + Path src = createTempFileOfSize(CLASS_NAME, ".dat", FILESIZE * 4); + var http3Only = altSvc == false; + var config = config(http3Only); + HttpRequest req = HttpRequest.newBuilder(uri) + .POST(BodyPublishers.ofFile(src)) + .setOption(H3_DISCOVERY, config) + .build(); + + Path dest = Paths.get("streamtest.txt"); + dest.toFile().delete(); + CompletableFuture response = client.sendAsync(req, BodyHandlers.ofFile(dest)) + .thenApply(resp -> { + if (resp.statusCode() != 200) + throw new RuntimeException(); + if (resp.version() != HTTP_3) { + throw new RuntimeException("wrong response version: " + resp.version()); + } + return resp.body(); + }); + response.join(); + assertFileContentsEqual(src, dest); + System.err.println("streamTest: DONE"); + } + + static void paramsTest() throws Exception { + URI u = new URI("https://localhost:"+https2Port+"/foo"); + System.out.println("paramsTest: Request to " + u); + https2AltSvcServer.addHandler((t -> { + SSLSession s = t.getSSLSession(); + String prot = s.getProtocol(); + if (prot.equals("TLSv1.3")) { + t.sendResponseHeaders(200, -1); + } else { + System.err.printf("Protocols =%s\n", prot); + t.sendResponseHeaders(500, -1); + } + }), "/"); + HttpClient client = getClient(); + HttpRequest req = HttpRequest.newBuilder(u).build(); + HttpResponse resp = client.send(req, BodyHandlers.ofString()); + int stat = resp.statusCode(); + if (stat != 200) { + throw new RuntimeException("paramsTest failed " + stat); + } + if (resp.version() != HTTP_3) { + throw new RuntimeException("wrong response version: " + resp.version()); + } + System.err.println("paramsTest: DONE"); + } + + static void warmup(boolean altSvc) throws Exception { + URI uri = getURI(altSvc); + System.out.println("Warmup: Request to " + uri); + System.err.println("Warmup: Request to " + uri); + + // Do a simple warmup request + + HttpClient client = getClient(); + var http3Only = altSvc == false; + var config = config(http3Only); + + // in the warmup phase, we want to make sure + // to preload the ALT_SVC, otherwise the first + // request that uses ALT_SVC might go through HTTP/2 + if (altSvc) config = ALT_SVC; + + HttpRequest req = HttpRequest.newBuilder(uri) + .POST(BodyPublishers.ofString(SIMPLE_STRING)) + .setOption(H3_DISCOVERY, config) + .build(); + HttpResponse response = client.send(req, BodyHandlers.ofString()); + checkStatus(200, response.statusCode()); + String responseBody = response.body(); + HttpHeaders h = response.headers(); + checkStrings(SIMPLE_STRING, responseBody); + checkStrings(h.firstValue("x-hello").get(), "world"); + checkStrings(h.firstValue("x-bye").get(), "universe"); + } + + static T logExceptionally(String desc, Throwable t) { + System.out.println(desc + " failed: " + t); + System.err.println(desc + " failed: " + t); + if (t instanceof RuntimeException r) throw r; + if (t instanceof Error e) throw e; + throw new CompletionException(t); + } + + static void simpleTest(boolean altSvc, boolean ping) throws Exception { + URI uri = getURI(altSvc, ping); + System.err.printf("simpleTest(altSvc:%s, ping:%s) Request to %s%n", + altSvc, ping, uri); + System.out.printf("simpleTest(altSvc:%s, ping:%s) Request to %s%n", + altSvc, ping, uri); + String type = altSvc ? "altSvc" : (ping ? "ping" : "http3"); + + // Do loops asynchronously + + CompletableFuture>[] responses = new CompletableFuture[LOOPS]; + final Path source = createTempFileOfSize(H3BasicTest.class.getSimpleName(), ".dat", FILESIZE); + var http3Only = altSvc == false; + for (int i = 0; i < LOOPS; i++) { + var config = config(http3Only); + HttpRequest request = HttpRequest.newBuilder(uri) + .header("X-Compare", source.toString()) + .POST(BodyPublishers.ofFile(source)) + .setOption(H3_DISCOVERY, config) + .build(); + String desc = type + ": Loop " + i; + System.out.printf("%s simpleTest(altSvc:%s, ping:%s) config(%s) Request to %s%n", + desc, altSvc, ping, config, uri); + System.err.printf("%s simpleTest(altSvc:%s, ping:%s) config(%s) Request to %s%n", + desc, altSvc, ping, config, uri); + Path requestBodyFile = createTempFile(CLASS_NAME, ".dat"); + responses[i] = client.sendAsync(request, BodyHandlers.ofFile(requestBodyFile)) + //.thenApply(resp -> assertFileContentsEqual(resp.body(), source)); + .exceptionally((t) -> logExceptionally(desc, t)) + .thenApply(resp -> { + System.out.printf("Resp %s status %d body size %d\n", + resp.version(), resp.statusCode(), + resp.body().toFile().length() + ); + assertFileContentsEqual(resp.body(), source); + if (resp.version() != HTTP_3) { + throw new RuntimeException("wrong response version: " + resp.version()); + } + return resp; + }); + Thread.sleep(100); + System.out.println(type + ": Loop " + i + " done"); + } + CompletableFuture.allOf(responses).join(); + System.err.println(type + " simpleTest: DONE"); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3ConcurrentPush.java b/test/jdk/java/net/httpclient/http3/H3ConcurrentPush.java new file mode 100644 index 00000000000..f0a9a947d13 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ConcurrentPush.java @@ -0,0 +1,471 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,responses,trace + * -Djdk.httpclient.http3.maxConcurrentPushStreams=45 + * H3ConcurrentPush + * @summary This test exercises some of the HTTP/3 specifities for PushPromises. + * It sends several concurrent requests, and the server sends a bunch of + * identical push promise frames to all of them. That is, there will be + * a push promise frame with the same push ID sent to each exchange. + * The one (and only one) of the handlers will open a push stream for + * that push id. The client checks that the expected HTTP/3 specific + * methods are invoked on the PushPromiseHandler. + */ + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.function.Supplier; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class H3ConcurrentPush implements HttpServerAdapters { + + // dummy hack to prevent the IDE complaining that calling + // println will throw NPE + static final PrintStream err = System.err; + static final PrintStream out = System.out; + + static Map PUSH_PROMISES = Map.of( + "/x/y/z/1", "the first push promise body", + "/x/y/z/2", "the second push promise body", + "/x/y/z/3", "the third push promise body", + "/x/y/z/4", "the fourth push promise body", + "/x/y/z/5", "the fifth push promise body", + "/x/y/z/6", "the sixth push promise body", + "/x/y/z/7", "the seventh push promise body", + "/x/y/z/8", "the eight push promise body", + "/x/y/z/9", "the ninth push promise body" + ); + static final String MAIN_RESPONSE_BODY = "the main response body"; + + HttpTestServer server; + URI uri; + URI headURI; + ServerPushHandler pushHandler; + + @BeforeTest + public void setup() throws Exception { + server = HttpTestServer.create(ANY, new SimpleSSLContext().get()); + pushHandler = new ServerPushHandler(MAIN_RESPONSE_BODY, PUSH_PROMISES); + server.addHandler(pushHandler, "/push/"); + server.addHandler(new HttpHeadOrGetHandler(), "/head/"); + server.start(); + err.println("Server listening on port " + server.serverAuthority()); + uri = new URI("https://" + server.serverAuthority() + "/push/a/b/c"); + headURI = new URI("https://" + server.serverAuthority() + "/head/x"); + } + + @AfterTest + public void teardown() { + server.stop(); + } + + static HttpResponse assert200ResponseCode(HttpResponse response) { + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + return response; + } + + private void sendHeadRequest(HttpClient client) throws IOException, InterruptedException { + HttpRequest headRequest = HttpRequest.newBuilder(headURI) + .HEAD().version(Version.HTTP_2).build(); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + static final class TestPushPromiseHandler implements PushPromiseHandler { + record NotifiedPromise(PushId pushId, HttpRequest initiatingRequest) {} + final Map requestToPushId = new ConcurrentHashMap<>(); + final Map pushIdToRequest = new ConcurrentHashMap<>(); + final List errors = new CopyOnWriteArrayList<>(); + final List notified = new CopyOnWriteArrayList<>(); + final ConcurrentMap>> promises + = new ConcurrentHashMap<>(); + final Supplier> bodyHandlerSupplier; + final PushPromiseHandler pph; + TestPushPromiseHandler(Supplier> bodyHandlerSupplier) { + this.bodyHandlerSupplier = bodyHandlerSupplier; + this.pph = PushPromiseHandler.of((r) -> bodyHandlerSupplier.get(), promises); + } + + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function, CompletableFuture>> acceptor) { + errors.add(new AssertionError("no pushID provided for: " + pushPromiseRequest)); + } + + @Override + public void notifyAdditionalPromise(HttpRequest initiatingRequest, PushId pushid) { + notified.add(new NotifiedPromise(pushid, initiatingRequest)); + out.println("notifyPushPromise: pushId=" + pushid); + pph.notifyAdditionalPromise(initiatingRequest, pushid); + } + + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + PushId pushid, + Function, CompletableFuture>> acceptor) { + out.println("applyPushPromise: " + pushPromiseRequest + ", pushId=" + pushid); + requestToPushId.putIfAbsent(pushPromiseRequest, pushid); + if (pushIdToRequest.putIfAbsent(pushid, pushPromiseRequest) != null) { + errors.add(new AssertionError("pushId already used: " + pushid)); + } + pph.applyPushPromise(initiatingRequest, pushPromiseRequest, pushid, acceptor); + } + + } + + @Test + public void testConcurrentPushes() throws Exception { + int maxPushes = Utils.getIntegerProperty("jdk.httpclient.http3.maxConcurrentPushStreams", -1); + out.println("maxPushes: " + maxPushes); + assertTrue(maxPushes > 0); + try (HttpClient client = newClientBuilderForH3() + .proxy(Builder.NO_PROXY) + .version(Version.HTTP_3) + .sslContext(new SimpleSSLContext().get()) + .build()) { + + sendHeadRequest(client); + + // Send with promise handler + TestPushPromiseHandler custom = new TestPushPromiseHandler<>(BodyHandlers::ofString); + var promises = custom.promises; + + for (int j=0; j < 2; j++) { + if (j == 0) out.println("\ntestCancel: First time around"); + else out.println("\ntestCancel: Second time around: should be a new connection"); + + // now make sure there's an HTTP/3 connection + client.send(HttpRequest.newBuilder(headURI).version(Version.HTTP_3) + .setOption(H3_DISCOVERY, ALT_SVC).HEAD().build(), BodyHandlers.discarding()); + + int waitForPushId; + List>> responses = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + waitForPushId = Math.min(PUSH_PROMISES.size(), maxPushes) + 1; + CompletableFuture> main = client.sendAsync( + HttpRequest.newBuilder(uri.resolve("?i=%s,j=%s".formatted(i, j))) + .header("X-WaitForPushId", String.valueOf(waitForPushId)) + .build(), + BodyHandlers.ofString(), + custom); + responses.add(main); + } + CompletableFuture.allOf(responses.toArray(CompletableFuture[]::new)).join(); + responses.forEach(cf -> { + var main = cf.join(); + var old = promises.put(main.request(), CompletableFuture.completedFuture(main)); + assertNull(old, "unexpected mapping for: " + old); + }); + + promises.forEach((key, value) -> out.println(key + ":" + value.join().body())); + + promises.forEach((request, value) -> { + HttpResponse response = value.join(); + assertEquals(response.statusCode(), 200); + if (PUSH_PROMISES.containsKey(request.uri().getPath())) { + assertEquals(response.body(), PUSH_PROMISES.get(request.uri().getPath())); + } else { + assertEquals(response.body(), MAIN_RESPONSE_BODY); + } + }); + + int expectedPushes = Math.min(PUSH_PROMISES.size(), maxPushes) + 5; + assertEquals(promises.size(), expectedPushes); + + promises.clear(); + + // Send with no promise handler + try { + client.sendAsync(HttpRequest.newBuilder(uri).build(), BodyHandlers.ofString()) + .thenApply(H3ConcurrentPush::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + assertEquals(promises.size(), 0); + + // Send with no promise handler, but use pushId bigger than allowed. + // This should cause the connection to get closed + long usePushId = maxPushes * 3 + 10; + try { + HttpRequest bigger = HttpRequest.newBuilder(uri) + .header("X-UsePushId", String.valueOf(usePushId)) + .build(); + client.sendAsync(bigger, BodyHandlers.ofString()) + .thenApply(H3ConcurrentPush::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + throw new AssertionError("Expected IOException not thrown"); + } catch (CompletionException c) { + boolean success = false; + if (c.getCause() instanceof IOException io) { + if (io.getMessage() != null && + io.getMessage().contains("Max pushId exceeded (%s >= %s)" + .formatted(usePushId, maxPushes))) { + success = true; + } + if (success) { + out.println("Got expected IOException: " + io); + } else throw io; + } + if (!success) { + throw new AssertionError("Unexpected exception: " + c.getCause(), c.getCause()); + } + } + assertEquals(promises.size(), 0); + + // the next time around we should have a new connection, + // so we can restart from scratch + pushHandler.reset(); + } + var errors = custom.errors; + errors.forEach(t -> t.printStackTrace(System.out)); + var error = errors.stream().findFirst().orElse(null); + if (error != null) throw error; + var notified = custom.notified; + assertEquals(notified.size(), 9*4*2, "Unexpected notification: " + notified); + } + } + + + // --- server push handler --- + static class ServerPushHandler implements HttpTestHandler { + + private final String mainResponseBody; + private final Map promises; + private final ReentrantLock lock = new ReentrantLock(); + private final Map sentPromises = new ConcurrentHashMap<>(); + + public ServerPushHandler(String mainResponseBody, + Map promises) + throws Exception + { + Objects.requireNonNull(promises); + this.mainResponseBody = mainResponseBody; + this.promises = promises; + } + + // The assumption is that there will be several concurrent + // exchanges, but all on the same connection + // The first exchange that emits a PushPromise sends + // a push promise frame + open the push response stream. + // The other exchanges will simply send a push promise + // frame, with the pushId allocated by the previous exchange. + // The sentPromises map is used to store that pushId. + // This obviously only works if we have a single HTTP/3 connection. + final AtomicInteger count = new AtomicInteger(); + public void handle(HttpTestExchange exchange) throws IOException { + long count = -1; + try { + count = this.count.incrementAndGet(); + err.println("Server: handle " + exchange + + " on " + exchange.getConnectionKey()); + out.println("Server: handle " + exchange.getRequestURI() + + " on " + exchange.getConnectionKey()); + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + if (exchange.serverPushAllowed()) { + pushPromises(exchange); + } + + // response data for the main response + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = mainResponseBody.getBytes(UTF_8); + exchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } catch (ClosedChannelException ex) { + out.printf("handling exchange %s, %s: %s%n", count, + exchange.getRequestURI(), exchange.getRequestHeaders()); + out.printf("Got closed channel exception sending response after sent=%s allowed=%s%n", + sent, allowed); + } + } finally { + out.printf("handled exchange %s, %s: %s%n", count, + exchange.getRequestURI(), exchange.getRequestHeaders()); + } + } + + volatile long allowed = -1; + volatile int sent = 0; + volatile int nsent = 0; + void reset() { + lock.lock(); + try { + allowed = -1; + sent = 0; + nsent = 0; + sentPromises.clear(); + } finally { + lock.unlock(); + } + } + + private void pushPromises(HttpTestExchange exchange) throws IOException { + URI requestURI = exchange.getRequestURI(); + long waitForPushId = exchange.getRequestHeaders() + .firstValueAsLong("X-WaitForPushId").orElse(-1); + long usePushId = exchange.getRequestHeaders() + .firstValueAsLong("X-UsePushId").orElse(-1); + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + err.printf("Server: waiting for pushId sent=%s allowed=%s: %s%n", + sent, allowed, waitForPushId); + var allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + err.println("Server: Got maxPushId: " + allowed); + out.println("Server: Got maxPushId: " + allowed); + lock.lock(); + if (allowed > this.allowed) this.allowed = allowed; + lock.unlock(); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + for (Map.Entry promise : promises.entrySet()) { + // if usePushId != -1 we send a single push promise, + // without checking that it's allowed. + // Otherwise, we stop sending promises when we have consumed + // the whole window + if (usePushId == -1 && allowed > 0 && sent >= allowed) { + err.println("Server: sent all allowed promises: " + sent); + break; + } + + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + err.printf("Server: waiting for pushId sent=%s allowed=%s: %s%n", + sent, allowed, waitForPushId); + var allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + err.println("Server: Got maxPushId: " + allowed); + out.println("Server: Got maxPushId: " + allowed); + lock.lock(); + if (allowed > this.allowed) this.allowed = allowed; + lock.unlock(); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + URI uri = requestURI.resolve(promise.getKey()); + InputStream is = new ByteArrayInputStream(promise.getValue().getBytes(UTF_8)); + HttpHeaders headers = HttpHeaders.of(Collections.emptyMap(), (x, y) -> true); + if (usePushId == -1) { + long pushId; + boolean send = false; + lock.lock(); + try { + Long usedPushId = sentPromises.get(promise.getKey()); + if (usedPushId == null) { + pushId = exchange.sendHttp3PushPromiseFrame(-1, uri, headers); + waitForPushId = pushId + 1; + sentPromises.put(promise.getKey(), pushId); + sent += 1; + send = true; + } else { + pushId = usedPushId; + exchange.sendHttp3PushPromiseFrame(pushId, uri, headers); + } + } finally { + lock.unlock(); + } + if (send) { + exchange.sendHttp3PushResponse(pushId, uri, headers, headers, is); + err.println("Server: Sent push promise with response: " + pushId); + } else { + err.println("Server: Sent push promise frame: " + pushId); + } + if (pushId >= waitForPushId) waitForPushId = pushId + 1; + } else { + exchange.sendHttp3PushPromiseFrame(usePushId, uri, headers); + err.println("Server: Sent push promise frame: " + usePushId); + exchange.sendHttp3PushResponse(usePushId, uri, headers, headers, is); + err.println("Server: Sent push promise response: " + usePushId); + lock.lock(); + sent += 1; + lock.unlock(); + return; + } + } + err.println("Server: All pushes sent"); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3ConnectionPoolTest.java b/test/jdk/java/net/httpclient/http3/H3ConnectionPoolTest.java new file mode 100644 index 00000000000..719067a7c3c --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ConnectionPoolTest.java @@ -0,0 +1,581 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.Asserts + * jdk.test.lib.Utils + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.httpclient.HttpClient.log=ssl,requests,responses,errors,http3,quic:hs + * -Djdk.internal.httpclient.debug=false + * H3ConnectionPoolTest + */ + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.Supplier; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http2.Http2EchoHandler; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static jdk.test.lib.Asserts.assertEquals; +import static jdk.test.lib.Asserts.assertNotEquals; +import static jdk.test.lib.Asserts.assertTrue; + +public class H3ConnectionPoolTest implements HttpServerAdapters { + + private static final String CLASS_NAME = H3ConnectionPoolTest.class.getSimpleName(); + + static int altsvcPort, https2Port, http3Port; + static Http3TestServer http3OnlyServer; + static Http2TestServer https2AltSvcServer; + static volatile HttpClient client = null; + static SSLContext sslContext; + static volatile String http3OnlyURIString, https2URIString, http3AltSvcURIString, http3DirectURIString; + + static void initialize(boolean samePort) throws Exception { + initialize(samePort, Http2EchoHandler::new); + } + + static void initialize(boolean samePort, Supplier handlers) throws Exception { + System.out.println("\nConfiguring for advertised AltSvc on " + + (samePort ? "same port" : "ephemeral port")); + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = null; + client = getClient(); + + // server that supports both HTTP/2 and HTTP/3, with HTTP/3 on an altSvc port. + https2AltSvcServer = new Http2TestServer(true, sslContext); + if (samePort) { + System.out.println("Attempting to enable advertised HTTP/3 service on same port"); + https2AltSvcServer.enableH3AltServiceOnSamePort(); + System.out.println("Advertised AltSvc on same port " + + (https2AltSvcServer.supportsH3DirectConnection() ? "enabled" : " not enabled")); + } else { + System.out.println("Attempting to enable advertised HTTP/3 service on different port"); + https2AltSvcServer.enableH3AltServiceOnEphemeralPort(); + } + https2AltSvcServer.addHandler(handlers.get(), "/" + CLASS_NAME + "/https2/"); + https2AltSvcServer.addHandler(handlers.get(), "/" + CLASS_NAME + "/h2h3/"); + https2Port = https2AltSvcServer.getAddress().getPort(); + altsvcPort = https2AltSvcServer.getH3AltService() + .map(Http3TestServer::getAddress).stream() + .mapToInt(InetSocketAddress::getPort).findFirst() + .getAsInt(); + // server that only supports HTTP/3 - we attempt to use the same port + // as the HTTP/2 server so that we can pretend that the H2 server as two H3 endpoints: + // one advertised (the alt service endpoint og the HTTP/2 server) + // one non advertised (the direct endpoint, at the same authority as HTTP/2, but which + // is in fact our http3OnlyServer) + try { + http3OnlyServer = new Http3TestServer(sslContext, samePort ? 0 : https2Port); + System.out.println("Unadvertised service enabled on " + + (samePort ? "ephemeral port" : "same port")); + } catch (IOException ex) { + System.out.println("Can't create HTTP/3 server on same port: " + ex); + http3OnlyServer = new Http3TestServer(sslContext, 0); + } + http3OnlyServer.addHandler("/" + CLASS_NAME + "/http3/", handlers.get()); + http3OnlyServer.addHandler("/" + CLASS_NAME + "/h2h3/", handlers.get()); + http3OnlyServer.start(); + http3Port = http3OnlyServer.getQuicServer().getAddress().getPort(); + + if (http3Port == https2Port) { + System.out.println("HTTP/3 server enabled on same port than HTTP/2 server"); + if (samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used ephemeral port for HTTP/3 server"); + } + } else { + System.out.println("HTTP/3 server enabled on a different port than HTTP/2 server"); + if (!samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used same port for HTTP/3 server"); + } + } + if (altsvcPort == https2Port) { + if (!samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used same port for advertised AltSvc"); + } + } else { + if (samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used ephemeral port for advertised AltSvc"); + } + } + + http3OnlyURIString = "https://" + http3OnlyServer.serverAuthority() + "/" + CLASS_NAME + "/http3/foo/"; + https2URIString = "https://" + https2AltSvcServer.serverAuthority() + "/" + CLASS_NAME + "/https2/bar/"; + http3DirectURIString = "https://" + https2AltSvcServer.serverAuthority() + "/" + CLASS_NAME + "/h2h3/direct/"; + http3AltSvcURIString = https2URIString + .replace(":" + https2Port + "/", ":" + altsvcPort + "/") + .replace("/https2/bar/", "/h2h3/altsvc/"); + System.out.println("HTTP/2 server started at: " + https2AltSvcServer.serverAuthority()); + System.out.println(" with advertised HTTP/3 endpoint at: " + + URI.create(http3AltSvcURIString).getRawAuthority()); + System.out.println("HTTP/3 server started at:" + http3OnlyServer.serverAuthority()); + + https2AltSvcServer.start(); + } catch (Throwable e) { + System.out.println("Configuration failed: " + e); + System.err.println("Throwing now: " + e); + e.printStackTrace(); + throw e; + } + } + + @Test + public static void testH3Only() throws Exception { + System.out.println("\nTesting HTTP/3 only"); + initialize(true); + try (HttpClient client = getClient()) { + var reqBuilder = HttpRequest.newBuilder() + .uri(URI.create(http3OnlyURIString)) + .version(HTTP_3) + .GET(); + HttpRequest request1 = reqBuilder.copy() + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + HttpResponse response1 = client.send(request1, BodyHandlers.ofString()); + System.out.printf("First response: (%s): %s%n", response1.connectionLabel(), response1); + response1.headers().map().entrySet().forEach((e) -> { + System.out.printf(" %s: %s%n", e.getKey(), e.getValue()); + }); + // ANY should reuse the same connection + HttpRequest request2 = reqBuilder.copy() + .setOption(H3_DISCOVERY, ANY) + .build(); + HttpResponse response2 = client.send(request2, BodyHandlers.ofString()); + HttpRequest request3 = reqBuilder.copy() + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + HttpResponse response3 = client.send(request3, BodyHandlers.ofString()); + // ANY should reuse the same connection + HttpRequest request4 = reqBuilder.copy() + .setOption(H3_DISCOVERY, ANY) + .build(); + HttpResponse response4 = client.send(request4, BodyHandlers.ofString()); + assertEquals(response1.connectionLabel().get(), response2.connectionLabel().get()); + assertEquals(response2.connectionLabel().get(), response3.connectionLabel().get()); + assertEquals(response3.connectionLabel().get(), response4.connectionLabel().get()); + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + } + } + + @Test + public static void testH2H3WithTwoAltSVC() throws Exception { + testH2H3(false); + } + + @Test + public static void testH2H3WithAltSVCOnSamePort() throws Exception { + testH2H3(true); + } + + private static void testH2H3(boolean samePort) throws Exception { + System.out.println("\nTesting with advertised AltSvc on " + + (samePort ? "same port" : "ephemeral port")); + initialize(samePort); + try (HttpClient client = getClient()) { + var req1Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + var req2Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .setOption(H3_DISCOVERY, ALT_SVC) + .version(HTTP_3) + .GET(); + + if (altsvcPort == https2Port) { + System.out.println("Testing with alt service on same port"); + + // first request with HTTP3_URI_ONLY should create H3 connection + HttpRequest request1 = req1Builder.copy().build(); + HttpResponse response1 = client.send(request1, BodyHandlers.ofString()); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + HttpResponse response2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), response1.connectionLabel().get()); + + // second request with HTTP3_URI_ONLY should reuse a created connection + // It should reuse the advertised connection (from response2) if same + // origin + HttpRequest request3 = req1Builder.copy().build(); + HttpResponse response3 = client.send(request3, BodyHandlers.ofString()); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertEquals(response1.connectionLabel().get(), response3.connectionLabel().get()); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin... + HttpRequest request4 = req2Builder.copy().build(); + HttpResponse response4 = client.send(request4, BodyHandlers.ofString()); + assertEquals(HTTP_3, response4.version()); + checkStatus(200, response4.statusCode()); + assertEquals(response4.connectionLabel().get(), response2.connectionLabel().get()); + } else if (http3Port == https2Port) { + System.out.println("Testing with two alt services"); + // first - make a direct connection + HttpRequest request1 = req1Builder.copy().build(); + HttpResponse response1 = client.send(request1, BodyHandlers.ofString()); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + + // second, get the alt service + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + HttpResponse response2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + assertNotEquals(response2.connectionLabel().get(), response1.connectionLabel().get()); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin... + HttpRequest request3 = req2Builder.copy().build(); + HttpResponse response3 = client.send(request3, BodyHandlers.ofString()); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertEquals(response3.connectionLabel().get(), response2.connectionLabel().get()); + assertNotEquals(response3.connectionLabel().get(), response1.connectionLabel().get()); + + // fourth request with HTTP_3_URI_ONLY should reuse the first connection, + // and not reuse the second. + HttpRequest request4 = req1Builder.copy().build(); + HttpResponse response4 = client.send(request1, BodyHandlers.ofString()); + assertEquals(HTTP_3, response4.version()); + assertEquals(response4.connectionLabel().get(), response1.connectionLabel().get()); + assertNotEquals(response4.connectionLabel().get(), response3.connectionLabel().get()); + checkStatus(200, response1.statusCode()); + } else { + System.out.println("WARNING: Couldn't create HTTP/3 server on same port! Can't test all..."); + // Get, get the alt service + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + HttpResponse response2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin... + HttpRequest request3 = req2Builder.copy().build(); + HttpResponse response3 = client.send(request3, BodyHandlers.ofString()); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertEquals(response3.connectionLabel().get(), response2.connectionLabel().get()); + } + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + } + } + + @Test + public static void testParallelH2H3WithTwoAltSVC() throws Exception { + testH2H3Concurrent(false); + } + + @Test + public static void testParallelH2H3WithAltSVCOnSamePort() throws Exception { + testH2H3Concurrent(true); + } + + private static void testH2H3Concurrent(boolean samePort) throws Exception { + System.out.println("\nTesting concurrent connections with advertised AltSvc on " + + (samePort ? "same port" : "ephemeral port")); + initialize(samePort); + try (HttpClient client = getClient()) { + var req1Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + var req2Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .setOption(H3_DISCOVERY, ALT_SVC) + .version(HTTP_3) + .GET(); + + if (altsvcPort == https2Port) { + System.out.println("Testing with alt service on same port"); + + // first request with HTTP3_URI_ONLY should create H3 connection + HttpRequest request1 = req1Builder.copy().build(); + HttpRequest request2 = req2Builder.copy().build(); + List>> directResponses = new ArrayList<>(); + for (int i=0; i<3; i++) { + directResponses.add(client.sendAsync(request1, BodyHandlers.ofString())); + } + // can't send requests in parallel here because if any establishes + // a connection before the H3 direct are established, then the H3 + // direct might reuse the H3 alt since the service is with same origin + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + String c1Label = null; + for (int i = 0; i < directResponses.size(); i++) { + HttpResponse response1 = directResponses.get(i).get(); + System.out.printf("direct response [%s][%s]: %s%n", i, + response1.connectionLabel(), + response1); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + if (i == 0) { + c1Label = response1.connectionLabel().get(); + } + assertEquals(c1Label, response1.connectionLabel().orElse(null)); + } + // first request with ALT_SVC is to get alt service, should be H2 + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + assertNotEquals(c1Label, h2resp2.connectionLabel().orElse(null)); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + List>> altResponses = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + altResponses.add(client.sendAsync(request2, BodyHandlers.ofString())); + } + String c2Label = null; + for (int i = 0; i < altResponses.size(); i++) { + HttpResponse response2 = altResponses.get(i).get(); + System.out.printf("alt response [%s][%s]: %s%n", i, + response2.connectionLabel(), + response2); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), c1Label); + if (i == 0) { + c2Label = response2.connectionLabel().get(); + } + assertEquals(c2Label, response2.connectionLabel().orElse(null)); + } + + // second set of requests should reuse a created connection + HttpRequest request3 = req1Builder.copy().build(); + List>> mixResponses = new ArrayList<>(); + for (int i=0; i < 3; i++) { + mixResponses.add(client.sendAsync(request3, BodyHandlers.ofString())); + mixResponses.add(client.sendAsync(request2, BodyHandlers.ofString())); + } + for (int i=0; i < mixResponses.size(); i++) { + HttpResponse response3 = mixResponses.get(i).get(); + System.out.printf("mixed response [%s][%s] %s: %s%n", i, + response3.connectionLabel(), + response3.request().getOption(H3_DISCOVERY), + response3); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + if (response3.request().getOption(H3_DISCOVERY).orElse(null) == ALT_SVC) { + assertEquals(c2Label, response3.connectionLabel().get()); + } else { + assertEquals(c1Label, response3.connectionLabel().get()); + } + } + } else if (http3Port == https2Port) { + System.out.println("Testing with two alt services"); + // first - make a direct connection + HttpRequest request1 = req1Builder.copy().build(); + + // second, use the alt service + HttpRequest request2 = req2Builder.copy().build(); + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + + // third, use ANY + HttpRequest request3 = req2Builder.copy().setOption(H3_DISCOVERY, ANY).build(); + + List>> directResponses = new ArrayList<>(); + List>> altResponses = new ArrayList<>(); + List>> anyResponses = new ArrayList<>(); + checkStatus(200, h2resp2.statusCode()); + for (int i=0; i<3; i++) { + anyResponses.add(client.sendAsync(request3, BodyHandlers.ofString())); + directResponses.add(client.sendAsync(request1, BodyHandlers.ofString())); + altResponses.add(client.sendAsync(request2, BodyHandlers.ofString())); + } + String c1Label = null; + for (int i = 0; i < directResponses.size(); i++) { + HttpResponse response1 = directResponses.get(i).get(); + System.out.printf("direct response [%s][%s] %s: %s%n", i, + response1.connectionLabel(), + response1.request().getOption(H3_DISCOVERY), + response1); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + if (i == 0) { + c1Label = response1.connectionLabel().get(); + } + assertEquals(c1Label, response1.connectionLabel().orElse(null)); + } + String c2Label = null; + for (int i = 0; i < altResponses.size(); i++) { + HttpResponse response2 = altResponses.get(i).get(); + System.out.printf("alt response [%s][%s] %s: %s%n", i, + response2.connectionLabel(), + response2.request().getOption(H3_DISCOVERY), + response2); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + if (i == 0) { + c2Label = response2.connectionLabel().get(); + } + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + assertNotEquals(response2.connectionLabel().get(), c1Label); + assertEquals(c2Label, response2.connectionLabel().orElse(null)); + } + var expectedLabels = Set.of(c1Label, c2Label); + for (int i = 0; i < anyResponses.size(); i++) { + HttpResponse response3 = anyResponses.get(i).get(); + System.out.printf("any response [%s][%s] %s: %s%n", i, + response3.connectionLabel(), + response3.request().getOption(H3_DISCOVERY), + response3); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertNotEquals(response3.connectionLabel().get(), h2resp2.connectionLabel().get()); + var label = response3.connectionLabel().orElse(""); + assertTrue(expectedLabels.contains(label), "Unexpected label: %s not in %s" + .formatted(label, expectedLabels)); + } + } else { + System.out.println("WARNING: Couldn't create HTTP/3 server on same port! Can't test all..."); + // Get, get the alt service + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + HttpResponse response2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin... + HttpRequest request3 = req2Builder.copy().build(); + HttpResponse response3 = client.send(request3, BodyHandlers.ofString()); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertEquals(response3.connectionLabel().get(), response2.connectionLabel().get()); + } + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + } + } + + static HttpClient getClient() { + if (client == null) { + client = HttpServerAdapters.createClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .build(); + } + return client; + } + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf("Test failed: wrong string %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + + static T logExceptionally(String desc, Throwable t) { + System.out.println(desc + " failed: " + t); + System.err.println(desc + " failed: " + t); + if (t instanceof RuntimeException r) throw r; + if (t instanceof Error e) throw e; + throw new CompletionException(t); + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3DataLimitsTest.java b/test/jdk/java/net/httpclient/http3/H3DataLimitsTest.java new file mode 100644 index 00000000000..0bee94e6e8a --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3DataLimitsTest.java @@ -0,0 +1,270 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.ITestContext; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.Builder; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.System.out; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.assertEquals; + + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.quic.QuicStandaloneServer + * @run testng/othervm/timeout=480 -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * -Djavax.net.debug=all + * H3DataLimitsTest + * @summary Verify handling of MAX_DATA / MAX_STREAM_DATA frames + */ +public class H3DataLimitsTest implements HttpServerAdapters { + + SSLContext sslContext; + HttpTestServer h3TestServer; + String h3URI; + + static final Executor executor = new TestExecutor(Executors.newCachedThreadPool()); + static final ConcurrentMap FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong serverCount = new AtomicLong(); + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @java.lang.Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + System.out.printf(now() + "Task %s failed: %s%n", id, t); + System.err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + protected boolean stopAfterFirstFailure() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); + } + + @BeforeMethod + void beforeMethod(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + var x = new SkipException("Skipping: some test failed"); + x.setStackTrace(new StackTraceElement[0]); + throw x; + } + } + + @AfterClass + static void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d servers and %d clients%n", + now(), serverCount.get(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.forEach((key, value) -> { + out.printf("\t%s: %s%n", key, value); + value.printStackTrace(out); + value.printStackTrace(); + }); + if (tasksFailed) { + System.out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + @DataProvider(name = "h3URIs") + public Object[][] versions(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + Object[][] result = {{h3URI}}; + return result; + } + + private HttpClient makeNewClient() { + clientCount.incrementAndGet(); + HttpClient client = newClientBuilderForH3() + .version(Version.HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY) + .executor(executor) + .sslContext(sslContext) + .connectTimeout(Duration.ofSeconds(10)) + .build(); + return client; + } + + @Test(dataProvider = "h3URIs") + public void testHugeResponse(final String h3URI) throws Exception { + HttpClient client = makeNewClient(); + URI uri = URI.create(h3URI + "?16000000"); + Builder builder = HttpRequest.newBuilder(uri) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + HttpRequest request = builder.build(); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + assertEquals(response.statusCode(), 200, "first response status"); + assertEquals(response.version(), HTTP_3, "first response version"); + + response = client.send(request, BodyHandlers.ofString()); + out.println("Response #2: " + response); + out.println("Version #2: " + response.version()); + assertEquals(response.statusCode(), 200, "second response status"); + assertEquals(response.version(), HTTP_3, "second response version"); + } + + @Test(dataProvider = "h3URIs") + public void testManySmallResponses(final String h3URI) throws Exception { + HttpClient client = makeNewClient(); + URI uri = URI.create(h3URI + "?160000"); + Builder builder = HttpRequest.newBuilder(uri) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + HttpRequest request = builder.build(); + for (int i=0; i<102; i++) { // more than 100 to exercise MAX_STREAMS + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #" + i + ": " + response); + out.println("Version #" + i + ": " + response.version()); + assertEquals(response.statusCode(), 200, "response status"); + assertEquals(response.version(), HTTP_3, "response version"); + } + } + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + + // An HTTP/3 server that only supports HTTP/3 + h3TestServer = HttpTestServer.of(new Http3TestServer(sslContext)); + final HttpTestHandler h3Handler = new Handler(); + h3TestServer.addHandler(h3Handler, "/h3/testH3/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3/testH3/x"; + + serverCount.addAndGet(1); + h3TestServer.start(); + } + + @AfterTest + public void teardown() throws Exception { + System.err.println("======================================================="); + System.err.println(" Tearing down test"); + System.err.println("======================================================="); + h3TestServer.stop(); + } + + static class Handler implements HttpTestHandler { + + public Handler() {} + + volatile int invocation = 0; + + @java.lang.Override + public void handle(HttpTestExchange t) + throws IOException { + try { + URI uri = t.getRequestURI(); + System.err.printf("Handler received request for %s\n", uri); + try (InputStream is = t.getRequestBody()) { + is.readAllBytes(); + } + System.out.println("Query: "+uri.getQuery()); + int bytesToProduce = Integer.parseInt(uri.getQuery()); + if ((invocation++ % 2) == 1) { + System.err.printf("Server sending %d - chunked\n", 200); + t.sendResponseHeaders(200, -1); + } else { + System.err.printf("Server sending %d - 0 length\n", 200); + t.sendResponseHeaders(200, bytesToProduce); + } + try (OutputStream os = t.getResponseBody()) { + os.write(new byte[bytesToProduce]); + } + } catch (Throwable e) { + e.printStackTrace(System.err); + throw new IOException(e); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3ErrorHandlingTest.java b/test/jdk/java/net/httpclient/http3/H3ErrorHandlingTest.java new file mode 100644 index 00000000000..10a3a3af8f2 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ErrorHandlingTest.java @@ -0,0 +1,1063 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.IRetryAnalyzer; +import org.testng.ITestResult; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Ignore; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.OutputStream; +import java.net.ProtocolException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.ByteBuffer; +import java.nio.channels.DatagramChannel; +import java.time.Duration; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; + +/* + * @test + * @key intermittent + * @comment testResetControlStream may fail if the client doesn't read the stream type + * before the stream is reset, + * testConnectionCloseXXX may fail because connection_close frame is not retransmitted + * @summary Verifies that the HTTP client responds with the right error codes and types + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @library ../access + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @build java.net.http/jdk.internal.net.http.Http3ConnectionAccess + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors H3ErrorHandlingTest + */ +public class H3ErrorHandlingTest implements HttpServerAdapters { + + private SSLContext sslContext; + private QuicStandaloneServer server; + private String requestURIBase; + + @DataProvider + public static Object[][] controlStreams() { + // control / encoder / decoder + return new Object[][] {{(byte)0}, {(byte)2}, {(byte)3}}; + } + + static final byte[] data = new byte[]{(byte)0,(byte)0}; + static final byte[] headers = new byte[]{(byte)1,(byte)0}; + static final byte[] reserved1 = new byte[]{(byte)2,(byte)0}; + static final byte[] cancel_push = new byte[]{(byte)3,(byte)1,(byte)0}; + static final byte[] settings = new byte[]{(byte)4,(byte)0}; + static final byte[] push_promise = new byte[]{(byte)5,(byte)1,(byte)0}; + // 48 bytes, ID 0, 47 byte headers + static final byte[] valid_push_promise = HexFormat.of().parseHex( + "0530000000"+ // push promise, length 48, id 0, section prefix + "508b089d5c0b8170dc702fbce7"+ // :authority + "d1"+ // :method:get + "51856272d141ff"+ // :path + "d7"+ // :scheme:https + "5f5094ca3ee35a74a6b589418b5258132b1aa496ca8747"); //user-agent + static final byte[] reserved2 = new byte[]{(byte)6,(byte)0}; + static final byte[] goaway = new byte[]{(byte)7,(byte)1,(byte)4}; + static final byte[] reserved3 = new byte[]{(byte)8,(byte)0}; + static final byte[] reserved4 = new byte[]{(byte)9,(byte)0}; + static final byte[] max_push_id = new byte[]{(byte)13,(byte)1,(byte)0}; + static final byte[] huge_id_push_promise = new byte[]{(byte)5,(byte)10, + (byte)255,(byte)255,(byte)255,(byte)255,(byte)255,(byte)255,(byte)255,(byte)255, + (byte)0, (byte)0}; + + /* + Truncates or expands the frame to the specified length + */ + private static Object[][] chopFrame(byte[] frame, int... lengths) { + var result = new Object[lengths.length][]; + for (int i = 0; i< lengths.length; i++) { + int length = lengths[i]; + byte[] choppedFrame = Arrays.copyOf(frame, length + 2); + choppedFrame[1] = (byte)length; + result[i] = new Object[] {choppedFrame, lengths[i]}; + } + return result; + } + + /* + Truncates or expands the byte array to the specified length + */ + private static Object[][] chopBytes(byte[] bytes, int... lengths) { + var result = new Object[lengths.length][]; + for (int i = 0; i< lengths.length; i++) { + int length = lengths[i]; + byte[] choppedBytes = Arrays.copyOf(bytes, length); + result[i] = new Object[] {choppedBytes, lengths[i]}; + } + return result; + } + + @DataProvider + public static Object[][] malformedSettingsFrames() { + // 2-byte ID, 2-byte value + byte[] settingsFrame = new byte[]{(byte)4,(byte)4,(byte)0x40, (byte)6, (byte)0x40, (byte)6}; + return chopFrame(settingsFrame, 1, 2, 3); + } + + @DataProvider + public static Object[][] malformedCancelPushFrames() { + byte[] cancelPush = new byte[]{(byte)3,(byte)2, (byte)0x40, (byte)0}; + return chopFrame(cancelPush, 0, 1, 3, 9); + } + + @DataProvider + public static Object[][] malformedGoawayFrames() { + byte[] goaway = new byte[]{(byte)7,(byte)2, (byte)0x40, (byte)0}; + return chopFrame(goaway, 0, 1, 3, 9); + } + + @DataProvider + public static Object[][] malformedResponseHeadersFrames() { + byte[] responseHeaders = HexFormat.of().parseHex( + "011a0000"+ // headers, length 26, section prefix + "d9"+ // :status:200 + "5f5094ca3ee35a74a6b589418b5258132b1aa496ca8747"); //user-agent + return chopFrame(responseHeaders, 0, 1, 4, 5, 6, 7); + } + + @DataProvider + public static Object[][] truncatedResponseFrames() { + byte[] response = HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "210100" // reserved, 1 byte + ); + return chopBytes(response, 1, 2, 3, 4, 6, 7, 9, 10); + } + + @DataProvider + public static Object[][] truncatedControlFrames() { + byte[] response = HexFormat.of().parseHex( + "00"+ // stream type: control + "04022100"+ //settings, reserved + "070104"+ //goaway, 4 + "210100" // reserved, 1 byte + ); + return chopBytes(response, 2, 3, 4, 6, 7, 9, 10); + } + + @DataProvider + public static Object[][] malformedPushPromiseFrames() { + return chopFrame(valid_push_promise, 0, 1, 2, 4, 5, 6); + } + + @DataProvider + public static Object[][] invalidControlFrames() { + // frames not valid on the server control stream (after settings) + // all except cancel_push / goaway (max_push_id is client-only) + return new Object[][] {{data}, {headers}, {settings}, {push_promise}, {max_push_id}, + {reserved1}, {reserved2}, {reserved3}, {reserved4}}; + } + + @DataProvider + public static Object[][] invalidResponseFrames() { + // frames not valid on the response stream + // all except headers / push_promise + // data is not valid as the first frame + return new Object[][] {{data}, {cancel_push}, {settings}, {goaway}, {max_push_id}, + {reserved1}, {reserved2}, {reserved3}, {reserved4}}; + } + + @DataProvider + public static Object[][] invalidPushFrames() { + // frames not valid on the push promise stream + // all except headers + // data is not valid as the first frame + return new Object[][] {{data}, {cancel_push}, {settings}, {push_promise}, {goaway}, {max_push_id}, + {reserved1}, {reserved2}, {reserved3}, {reserved4}}; + } + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext) + .alpn("h3") + .build(); + server.start(); + System.out.println("Server started at " + server.getAddress()); + requestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(server.getAddress().getPort()).build().toString(); + } + + @AfterClass + public void afterClass() throws Exception { + if (server != null) { + System.out.println("Stopping server " + server.getAddress()); + server.close(); + } + } + + /** + * Server sends a non-settings frame on the control stream + */ + @Test + public void testNonSettingsFrame() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, reserved frame, length 0 + byte[] bytesToWrite = new byte[] { 0, 0x21, 0 }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_MISSING_SETTINGS); + } + + /** + * Server opens 2 control streams + */ + @Test(dataProvider = "controlStreams") + public void testTwoControlStreams(byte type) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + + QuicSenderStream controlStream, controlStream2; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + controlStream2 = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + var writer2 = controlStream2.connectWriter(scheduler); + // control stream + byte[] bytesToWrite = new byte[] { type }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + writer2.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_STREAM_CREATION_ERROR); + } + + /** + * Server closes control stream + */ + @Test(dataProvider = "controlStreams") + public void testCloseControlStream(byte type) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var controlscheduler = SequentialScheduler.lockingScheduler(() -> {}); + var writer = controlStream.connectWriter(controlscheduler); + + byte[] bytesToWrite = new byte[] { type }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), true); + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_CLOSED_CRITICAL_STREAM); + } + + public static class RetryOnce implements IRetryAnalyzer { + boolean retried; + + @Override + public boolean retry(ITestResult iTestResult) { + if (!retried) { + retried = true; + return true; + } + return false; + } + } + + /** + * Server resets control stream + */ + @Test(dataProvider = "controlStreams", retryAnalyzer = RetryOnce.class) + public void testResetControlStream(byte type) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var controlscheduler = SequentialScheduler.lockingScheduler(() -> {}); + var writer = controlStream.connectWriter(controlscheduler); + + byte[] bytesToWrite = new byte[] { type }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + // wait for the stream data to be sent before resetting + System.out.println("Server: sending first ping"); + c.requestSendPing().join(); + // sometimes the first ping succeeds before the stream frame is delivered. + // Send another one just in case. + System.out.println("Server: sending second ping"); + c.requestSendPing().join(); + System.out.println("Server: resetting control stream " + writer.stream().streamId()); + // the test may fail if the stream type byte is not processed by HTTP3 + // before the reset is received. + writer.reset(0); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_CLOSED_CRITICAL_STREAM); + } + + /** + * Server sends unexpected frame on control stream + */ + @Test(dataProvider = "invalidControlFrames") + public void testUnexpectedControlFrame(byte[] frame) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 0 + byte[] bytesToWrite = new byte[] { 0, 4, 0 }; + ByteBuffer buf = ByteBuffer.allocate(3 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_FRAME_UNEXPECTED); + } + + /** + * Server sends malformed settings frame + */ + @Test(dataProvider = "malformedSettingsFrames") + public void testMalformedSettingsFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream + byte[] bytesToWrite = new byte[] { 0 }; + ByteBuffer buf = ByteBuffer.allocate(3 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_FRAME_ERROR); + } + + /** + * Server sends malformed goaway frame + */ + @Test(dataProvider = "malformedGoawayFrames") + public void testMalformedGoawayFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 0 + byte[] bytesToWrite = new byte[] { 0, 4, 0 }; + ByteBuffer buf = ByteBuffer.allocate(3 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_FRAME_ERROR); + } + + /** + * Server sends malformed cancel push frame + */ + @Test(dataProvider = "malformedCancelPushFrames") + public void testMalformedCancelPushFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 0 + byte[] bytesToWrite = new byte[] { 0, 4, 0 }; + ByteBuffer buf = ByteBuffer.allocate(3 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerPushError(errorCF, Http3Error.H3_FRAME_ERROR); + } + + /** + * Server sends invalid GOAWAY frame sequence + */ + @Test + public void testInvalidGoAwaySequence() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 0, GOAWAY, id = 4, GOAWAY, id = 8 + byte[] bytesToWrite = new byte[] { 0, 4, 0, 7, 1, 4, 7, 1, 8}; + ByteBuffer buf = ByteBuffer.wrap(bytesToWrite); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_ID_ERROR); + } + + /** + * Server sends invalid GOAWAY stream ID + */ + @Test + public void testInvalidGoAwayId() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 0, GOAWAY, id = 7 + byte[] bytesToWrite = new byte[] { 0, 4, 0, 7, 1, 7}; + ByteBuffer buf = ByteBuffer.wrap(bytesToWrite); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_ID_ERROR); + } + + /** + * Server sends invalid CANCEL_PUSH stream ID + */ + @Test + public void testInvalidCancelPushId() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 0, CANCEL_PUSH, id = MAX_VL_INTEGER + byte[] bytesToWrite = new byte[] { 0, 4, 0, 3, 8, (byte)255, (byte)255, + (byte)255,(byte)255,(byte)255,(byte)255,(byte)255,(byte)255}; + ByteBuffer buf = ByteBuffer.wrap(bytesToWrite); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_ID_ERROR); + } + + /** + * Server sends unexpected frame on push stream + */ + @Test(dataProvider = "invalidPushFrames") + public void testUnexpectedPushFrame(byte[] frame) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream pushStream; + pushStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + // write PUSH_PROMISE frame + s.outputStream().write(valid_push_promise); + var writer = pushStream.connectWriter(scheduler); + // push stream, id 0 + byte[] bytesToWrite = new byte[] { 1, 0 }; + ByteBuffer buf = ByteBuffer.allocate(2 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerPushError(errorCF, Http3Error.H3_FRAME_UNEXPECTED); + } + + /** + * Server sends malformed frame on push stream + */ + @Test(dataProvider = "malformedResponseHeadersFrames") + public void testMalformedPushStreamFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream pushStream; + pushStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + // write PUSH_PROMISE frame + s.outputStream().write(valid_push_promise); + var writer = pushStream.connectWriter(scheduler); + // push stream, id 0 + byte[] bytesToWrite = new byte[] { 1, 0 }; + ByteBuffer buf = ByteBuffer.allocate(2 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerPushError(errorCF, frame.length == 2 + ? Http3Error.H3_FRAME_ERROR + : Http3Error.QPACK_DECOMPRESSION_FAILED); + } + + /** + * Server sends malformed frame on push stream + */ + @Test(dataProvider = "malformedPushPromiseFrames") + public void testMalformedPushPromiseFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + // write PUSH_PROMISE frame + s.outputStream().write(frame); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerPushError(errorCF, frame.length <= 3 + ? Http3Error.H3_FRAME_ERROR + : Http3Error.QPACK_DECOMPRESSION_FAILED); + } + + /** + * Server reuses push stream ID + */ + @Test + public void testDuplicatePushStream() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream pushStream, pushStream2; + pushStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + pushStream2 = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + // write PUSH_PROMISE frame + s.outputStream().write(valid_push_promise); + + var writer = pushStream.connectWriter(scheduler); + // push stream, id 0 + byte[] bytesToWrite = new byte[] { 1, 0 }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + + writer = pushStream2.connectWriter(scheduler); + // push stream, id 0 + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerPushError(errorCF, Http3Error.H3_ID_ERROR); + } + + /** + * Server sends push promise with ID > MAX_PUSH_ID + */ + @Test + public void testInvalidPushPromiseId() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + // write PUSH_PROMISE frame + s.outputStream().write(huge_id_push_promise); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_ID_ERROR); + } + + /** + * Server opens a push stream ID > MAX_PUSH_ID + */ + @Test + public void testInvalidPushStreamId() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream pushStream; + pushStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = pushStream.connectWriter(scheduler); + // push stream, id MAX_VL_INTEGER + byte[] bytesToWrite = new byte[] { 1, + (byte)255, (byte)255, (byte)255, (byte)255, (byte)255, (byte)255, (byte)255, (byte)255 }; + ByteBuffer buf = ByteBuffer.wrap(bytesToWrite); + writer.scheduleForWriting(buf, false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_ID_ERROR); + } + + /** + * Server sends unexpected frame on response stream + */ + @Test(dataProvider = "invalidResponseFrames") + public void testUnexpectedResponseFrame(byte[] frame) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + s.outputStream().write(frame); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_FRAME_UNEXPECTED); + } + + /** + * Server sends malformed headers frame on response stream + */ + @Test(dataProvider = "malformedResponseHeadersFrames") + public void testMalformedResponseFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + s.outputStream().write(frame); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, frame.length == 2 + ? Http3Error.H3_FRAME_ERROR + : Http3Error.QPACK_DECOMPRESSION_FAILED); + } + + /** + * Server truncates a frame on the response stream + */ + @Test(dataProvider = "truncatedResponseFrames") + public void testTruncatedResponseFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + try (OutputStream outputStream = s.outputStream()) { + outputStream.write(frame); + } + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_FRAME_ERROR); + } + + /** + * Server truncates a frame on the control stream + */ + @Test(dataProvider = "truncatedControlFrames") + public void testTruncatedControlFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var controlscheduler = SequentialScheduler.lockingScheduler(() -> {}); + var writer = controlStream.connectWriter(controlscheduler); + + writer.scheduleForWriting(ByteBuffer.wrap(frame), true); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + // H3_CLOSED_CRITICAL_STREAM is also acceptable here + triggerError(errorCF, Http3Error.H3_FRAME_ERROR, Http3Error.H3_CLOSED_CRITICAL_STREAM); + } + + /** + * Server truncates a frame on the push stream + */ + @Test(dataProvider = "truncatedResponseFrames") + public void testTruncatedPushStreamFrame(byte[] frame, int bytes) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream pushStream; + pushStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + // write PUSH_PROMISE frame + s.outputStream().write(valid_push_promise); + var writer = pushStream.connectWriter(scheduler); + // push stream, id 0 + byte[] bytesToWrite = new byte[] { 1, 0 }; + ByteBuffer buf = ByteBuffer.allocate(2 + frame.length); + buf.put(bytesToWrite); + buf.put(frame); + buf.flip(); + writer.scheduleForWriting(buf, true); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerPushError(errorCF, Http3Error.H3_FRAME_ERROR); + } + + /** + * Server sends a settings frame with reserved HTTP2 settings + */ + @Test + public void testReservedSettingsFrames() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream controlStream; + controlStream = c.openNewLocalUniStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = controlStream.connectWriter(scheduler); + // control stream, settings frame, length 2, setting 4 = 0 + byte[] bytesToWrite = new byte[] { 0, 4, 2, 4, 0 }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_SETTINGS_ERROR); + } + + /** + * Server sends a stateless reset + */ + @Test + public void testStatelessReset() throws Exception { + server.addHandler((c,s)-> { + // stateless reset + QuicConnectionId localConnId = c.localConnectionId(); + ByteBuffer resetDatagram = c.endpoint().idFactory().statelessReset(localConnId.asReadOnlyBuffer(), 43); + ((DatagramChannel)c.channel()).send(resetDatagram, c.peerAddress()); + // ignore the request stream; we're expecting the client to close the connection. + // The server won't receive any notification from the client here. + // The connection will leak. + }); + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response); + } catch (Exception e) { + final String expectedMsg = "stateless reset from peer"; + if (e.getMessage() != null && e.getMessage().contains(expectedMsg)) { + // got the expected exception + return; + } + // unexpected exception, throw it back + throw e; + } finally { + client.shutdownNow(); + } + } + + /** + * Server opens a bidi stream + */ + @Test + @Ignore("BiDi streams are rejected by H3 client at QUIC level") + public void testBidiStream() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + QuicSenderStream bidiStream; + bidiStream = c.openNewLocalBidiStream(Duration.ZERO).resultNow(); + var scheduler = SequentialScheduler.lockingScheduler(() -> { + }); + var writer = bidiStream.connectWriter(scheduler); + // some data + byte[] bytesToWrite = new byte[] { 0, 4, 2, 4, 0 }; + writer.scheduleForWriting(ByteBuffer.wrap(bytesToWrite), false); + // ignore the request stream; we're expecting the client to close the connection + completeUponTermination(c, errorCF); + }); + triggerError(errorCF, Http3Error.H3_STREAM_CREATION_ERROR); + } + + /** + * Server closes the connection with a known QUIC error + */ + @Test + public void testConnectionCloseQUIC() throws Exception { + server.addHandler((c,s)-> { + TerminationCause tc = TerminationCause.forException( + new QuicTransportException("ignored", null, 0, + QuicTransportErrors.INTERNAL_ERROR) + ); + tc.peerVisibleReason("testtest"); + c.connectionTerminator().terminate(tc); + + }); + triggerClose("INTERNAL_ERROR", "testtest"); + } + + /** + * Server closes the connection with a known crypto error + */ + @Test + public void testConnectionCloseCryptoQUIC() throws Exception { + server.addHandler((c,s)-> { + TerminationCause tc = TerminationCause.forException( + new QuicTransportException("ignored", null, 0, + QuicTransportErrors.CRYPTO_ERROR.from() + 80 /*Alert.INTERNAL_ERROR.id*/, null) + ); + tc.peerVisibleReason("testtest"); + c.connectionTerminator().terminate(tc); + + }); + triggerClose("CRYPTO_ERROR", "internal_error", "testtest"); + } + + /** + * Server closes the connection with an unknown crypto error + */ + @Test + public void testConnectionCloseUnknownCryptoQUIC() throws Exception { + server.addHandler((c,s)-> { + TerminationCause tc = TerminationCause.forException( + new QuicTransportException("ignored", null, 0, + QuicTransportErrors.CRYPTO_ERROR.from() + 5, null) + ); + tc.peerVisibleReason("testtest"); + c.connectionTerminator().terminate(tc); + + }); + triggerClose("CRYPTO_ERROR", "5", "testtest"); + } + + /** + * Server closes the connection with an unknown QUIC error + */ + @Test + public void testConnectionCloseUnknownQUIC() throws Exception { + server.addHandler((c,s)-> { + TerminationCause tc = TerminationCause.forException( + new QuicTransportException("ignored", null, 0, + QuicTransportErrors.CRYPTO_ERROR.to() + 1 /*0x200*/, null) + ); + tc.peerVisibleReason("testtest"); + c.connectionTerminator().terminate(tc); + + }); + triggerClose("200", "testtest"); + } + + /** + * Server closes the connection with a known H3 error + */ + @Test + public void testConnectionCloseH3() throws Exception { + server.addHandler((c,s)-> { + TerminationCause tc = TerminationCause.appLayerClose(Http3Error.H3_EXCESSIVE_LOAD.code()); + tc.peerVisibleReason("testtest"); + c.connectionTerminator().terminate(tc); + + }); + triggerClose("H3_EXCESSIVE_LOAD", "testtest"); + } + + /** + * Server closes the connection with an unknown H3 error + */ + @Test + public void testConnectionCloseH3Unknown() throws Exception { + server.addHandler((c,s)-> { + TerminationCause tc = TerminationCause.appLayerClose(0x1f21); + tc.peerVisibleReason("testtest"); + c.connectionTerminator().terminate(tc); + + }); + triggerClose("1F21", "testtest"); + } + + + private void triggerClose(String... reasons) throws Exception { + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response); + } catch (ExecutionException e) { + System.out.println("Client exception [expected]: " + e); + var cause = e.getCause(); + assertTrue(cause instanceof IOException, "Expected IOException"); + for (String reason : reasons) { + assertTrue(cause.getMessage().contains(reason), + cause.getMessage() + " does not contain " + reason); + } + } finally { + client.shutdownNow(); + } + } + + + private void triggerError(CompletableFuture errorCF, Http3Error expected) throws Exception { + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response); + } catch (ExecutionException e) { + System.out.println("Client exception [expected]: " + e); + var cause = e.getCause(); + assertTrue(cause instanceof ProtocolException, "Expected ProtocolException"); + TerminationCause terminationCause = errorCF.get(10, TimeUnit.SECONDS); + System.out.println("Server reason: \"" + terminationCause.getPeerVisibleReason()+'"'); + final long actual = terminationCause.getCloseCode(); + // expected + assertEquals(actual, expected.code(), "Expected " + toHexString(expected) + " got 0x" + Long.toHexString(actual)); + } finally { + client.shutdownNow(); + } + } + + private void triggerError(CompletableFuture errorCF, Http3Error... expected) throws Exception { + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response); + } catch (ExecutionException e) { + System.out.println("Client exception [expected]: " + e); + var cause = e.getCause(); + assertTrue(cause instanceof ProtocolException, "Expected ProtocolException"); + TerminationCause terminationCause = errorCF.get(10, TimeUnit.SECONDS); + System.out.println("Server reason: \"" + terminationCause.getPeerVisibleReason()+'"'); + final long actual = terminationCause.getCloseCode(); + // expected + Optional h3Actual = Http3Error.fromCode(actual); + assertTrue(h3Actual.isPresent(), "Expected HTTP3 error, got 0x" + Long.toHexString(actual)); + Set expectedErrors = Set.of(expected); + assertTrue(expectedErrors.contains(h3Actual.get()), "Expected "+expectedErrors+ + ", got: "+h3Actual); + } finally { + client.shutdownNow(); + } + } + + private void triggerPushError(CompletableFuture errorCF, Http3Error http3Error) throws Exception { + HttpClient client = getHttpClient(); + // close might block; use shutdownNow instead + try { + HttpRequest request = getRequest(); + final HttpResponse response = client.sendAsync( + request, + BodyHandlers.discarding(), + (initiatingRequest, pushPromiseRequest, acceptor) -> + acceptor.apply(BodyHandlers.discarding()) + ).get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response); + } catch (ExecutionException e) { + System.out.println("Client exception [expected]: " + e); + var cause = e.getCause(); + assertTrue(cause instanceof ProtocolException, "Expected ProtocolException"); + TerminationCause terminationCause = errorCF.get(10, TimeUnit.SECONDS); + System.out.println("Server reason: \"" + terminationCause.getPeerVisibleReason()+'"'); + final long actual = terminationCause.getCloseCode(); + // expected + assertEquals(actual, http3Error.code(), "Expected " + toHexString(http3Error) + " got 0x" + Long.toHexString(actual)); + } finally { + client.shutdownNow(); + } + } + + private HttpRequest getRequest() throws URISyntaxException { + final URI reqURI = new URI(requestURIBase + "/hello"); + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(reqURI) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + return reqBuilder.build(); + } + + private HttpClient getHttpClient() { + final HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(HTTP_3) + .sslContext(sslContext).build(); + return client; + } + + private static String toHexString(final Http3Error error) { + return error.name() + "(0x" + Long.toHexString(error.code()) + ")"; + } + + private static void completeUponTermination(final QuicServerConnection serverConnection, + final CompletableFuture cf) { + serverConnection.futureTerminationCause().handle( + (r,t) -> t != null ? cf.completeExceptionally(t) : cf.complete(r)); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3FixedThreadPoolTest.java b/test/jdk/java/net/httpclient/http3/H3FixedThreadPoolTest.java new file mode 100644 index 00000000000..9e801ac2ea4 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3FixedThreadPoolTest.java @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8087112 8177935 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.Asserts + * jdk.test.lib.Utils + * jdk.test.lib.net.SimpleSSLContext + * @compile ../ReferenceTracker.java + * @run testng/othervm -Djdk.internal.httpclient.debug=err + * -Djdk.httpclient.HttpClient.log=ssl,headers,requests,responses,errors + * H3FixedThreadPoolTest + */ + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.file.Path; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static jdk.test.lib.Asserts.assertFileContentsEqual; +import static jdk.test.lib.Utils.createTempFileOfSize; + +import org.testng.annotations.Test; + +public class H3FixedThreadPoolTest implements HttpServerAdapters { + + private static final String CLASS_NAME = H3FixedThreadPoolTest.class.getSimpleName(); + + static int http3Port, https2Port; + static HttpTestServer http3Server, https2Server; + static volatile HttpClient client = null; + static ExecutorService exec; + static SSLContext sslContext; + + static String http3URIString, https2URIString; + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + exec = Executors.newCachedThreadPool(); + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext, exec); + http3Server.addHandler(new HttpTestFileEchoHandler(), "/H3FixedThreadPoolTest/http3-only/"); + http3Port = http3Server.getAddress().getPort(); + + https2Server = HttpTestServer.create(ALT_SVC, sslContext, exec); + https2Server.addHandler(new HttpTestFileEchoHandler(), "/H3FixedThreadPoolTest/http3-alt-svc/"); + https2Server.addHandler((t) -> { + t.getRequestBody().readAllBytes(); + t.sendResponseHeaders(200, 0); + }, "/H3FixedThreadPoolTest/http3-alt-svc/bar/head"); + https2Port = https2Server.getAddress().getPort(); + http3URIString = "https://" + http3Server.serverAuthority() + "/H3FixedThreadPoolTest/http3-only/foo/"; + https2URIString = "https://" + https2Server.serverAuthority() + "/H3FixedThreadPoolTest/http3-alt-svc/bar/"; + + http3Server.start(); + https2Server.start(); + + // warmup client to populate AltServiceRegistry + var head = HttpRequest.newBuilder(URI.create(https2URIString + "head")) + .setOption(H3_DISCOVERY, ALT_SVC).build(); + var resp = client.send(head, BodyHandlers.ofString()); + assert resp.statusCode() == 200; + + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(); + throw e; + } + } + + @Test + public static void test() throws Exception { + try { + initialize(); + simpleTest(false); + simpleTest(true); + streamTest(false); + streamTest(true); + paramsTest(); + if (client != null) { + ReferenceTracker.INSTANCE.track(client); + client = null; + System.gc(); + var error = ReferenceTracker.INSTANCE.check(4000); + if (error != null) throw error; + } + } catch (Exception | Error tt) { + tt.printStackTrace(); + throw tt; + } finally { + http3Server.stop(); + https2Server.stop(); + exec.shutdownNow(); + } + } + + static HttpClient getClient() { + if (client == null) { + // Executor e1 = Executors.newFixedThreadPool(1); + // Executor e = (Runnable r) -> e1.execute(() -> { + // System.out.println("[" + Thread.currentThread().getName() + // + "] Executing: " + // + r.getClass().getName()); + // r.run(); + // }); + client = HttpServerAdapters.createClientBuilderForH3() + .executor(Executors.newFixedThreadPool(2)) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + } + return client; + } + + static URI getURI(boolean http3Only) { + if (http3Only) + return URI.create(http3URIString); + else + return URI.create(https2URIString); + } + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf ("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf ("Test failed: wrong string \"%s\" != \"%s\"%n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static final String SIMPLE_STRING = "Hello world Goodbye world"; + + static final int LOOPS = 32; + static final int FILESIZE = 64 * 1024 + 200; + + static void streamTest(boolean http3only) throws Exception { + URI uri = getURI(http3only); + System.out.printf("%nstreamTest %b to %s%n" , http3only, uri); + System.err.printf("%nstreamTest %b to %s%n" , http3only, uri); + var config = http3only ? HTTP_3_URI_ONLY : ALT_SVC; + + HttpClient client = getClient(); + Path src = createTempFileOfSize(CLASS_NAME, ".dat", FILESIZE * 4); + HttpRequest req = HttpRequest.newBuilder(uri) + .setOption(H3_DISCOVERY, config) + .POST(BodyPublishers.ofFile(src)) + .build(); + + Path dest = Path.of("streamtest.txt"); + dest.toFile().delete(); + CompletableFuture response = client.sendAsync(req, BodyHandlers.ofFile(dest)) + .thenApply(resp -> { + if (resp.statusCode() != 200) + throw new RuntimeException(); + return resp.body(); + }); + response.join(); + assertFileContentsEqual(src, dest); + System.err.println("DONE"); + } + + // expect highest supported version we know about + static String expectedTLSVersion(SSLContext ctx) { + SSLParameters params = ctx.getSupportedSSLParameters(); + String[] protocols = params.getProtocols(); + for (String prot : protocols) { + if (prot.equals("TLSv1.3")) + return "TLSv1.3"; + } + return "TLSv1.2"; + } + + static void paramsTest() throws Exception { + System.out.println("\nparamsTest"); + System.err.println("\nparamsTest"); + HttpTestServer server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + server.addHandler((t -> { + SSLSession s = t.getSSLSession(); + String prot = s.getProtocol(); + System.err.println("Server: paramsTest: " + prot); + if (prot.equals(expectedTLSVersion(sslContext))) { + t.sendResponseHeaders(200, 0); + } else { + System.err.printf("Protocols =%s%n", prot); + t.sendResponseHeaders(500, 0); + } + }), "/"); + server.start(); + try { + URI u = new URI("https://" + server.serverAuthority() + "/paramsTest"); + HttpClient client = getClient(); + HttpRequest req = HttpRequest.newBuilder(u).setOption(H3_DISCOVERY, HTTP_3_URI_ONLY).build(); + HttpResponse resp = client.sendAsync(req, BodyHandlers.ofString()).get(); + int stat = resp.statusCode(); + if (stat != 200) { + throw new RuntimeException("paramsTest failed " + stat); + } + } finally { + server.stop(); + } + } + + static void simpleTest(boolean http3only) throws Exception { + System.out.println("\nsimpleTest http3-only=" + http3only); + System.err.println("\nsimpleTest http3-only=" + http3only); + URI uri = getURI(http3only); + var config = http3only ? HTTP_3_URI_ONLY : ALT_SVC; + System.err.println("Request to " + uri); + + // Do a simple warmup request + + HttpClient client = getClient(); + HttpRequest req = HttpRequest.newBuilder(uri) + .POST(BodyPublishers.ofString(SIMPLE_STRING)) + .setOption(H3_DISCOVERY, config) + .build(); + HttpResponse response = client.sendAsync(req, BodyHandlers.ofString()).get(); + HttpHeaders h = response.headers(); + + checkStatus(200, response.statusCode()); + + String responseBody = response.body(); + checkStrings(SIMPLE_STRING, responseBody); + + checkStrings(h.firstValue("x-hello").get(), "world"); + checkStrings(h.firstValue("x-bye").get(), "universe"); + + // Do loops asynchronously + + CompletableFuture[] responses = new CompletableFuture[LOOPS]; + final Path source = createTempFileOfSize(CLASS_NAME, ".dat", FILESIZE); + HttpRequest request = HttpRequest.newBuilder(uri) + .setOption(H3_DISCOVERY, config) + .POST(BodyPublishers.ofFile(source)) + .build(); + for (int i = 0; i < LOOPS; i++) { + Path requestPayloadFile = Utils.createTempFile(CLASS_NAME, ".dat"); + responses[i] = client.sendAsync(request, BodyHandlers.ofFile(requestPayloadFile)) + //.thenApply(resp -> compareFiles(resp.body(), source)); + .thenApply(resp -> { + System.out.printf("Resp status %d body size %d\n", + resp.statusCode(), resp.body().toFile().length()); + assertFileContentsEqual(resp.body(), source); + return null; + }); + } + CompletableFuture.allOf(responses).join(); + System.err.println("DONE"); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3GoAwayTest.java b/test/jdk/java/net/httpclient/http3/H3GoAwayTest.java new file mode 100644 index 00000000000..fc6166cb66c --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3GoAwayTest.java @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/* + * @test + * @summary verifies that when the server sends a GOAWAY frame then + * the client correctly handles it and retries unprocessed requests + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * @run junit/othervm + * -Djdk.httpclient.HttpClient.log=errors,headers,quic:hs,http3 + * H3GoAwayTest + */ +public class H3GoAwayTest { + + private static String REQ_URI_BASE; + private static SSLContext sslCtx; + private static Http3TestServer server; + + @BeforeAll + static void beforeAll() throws Exception { + sslCtx = new SimpleSSLContext().get(); + assertNotNull(sslCtx, "SSLContext couldn't be created"); + server = new Http3TestServer(sslCtx); + final RequestApprover reqApprover = new RequestApprover(); + server.setRequestApprover(reqApprover::allowNewRequest); + server.addHandler("/test", new Handler()); + server.start(); + System.out.println("Server started at " + server.getAddress()); + REQ_URI_BASE = URIBuilder.newBuilder().scheme("https") + .loopback() + .port(server.getAddress().getPort()) + .path("/test") + .build().toString(); + } + + @AfterAll + static void afterAll() throws Exception { + if (server != null) { + System.out.println("stopping server at " + server.getAddress()); + server.close(); + } + } + + /** + * Verifies that when several requests are sent using send() and the server + * connection is configured to send a GOAWAY after processing only a few requests, then + * the remaining requests are retried on a different connection + */ + @Test + public void testSequential() throws Exception { + try (final HttpClient client = HttpServerAdapters + .createClientBuilderFor(HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY) + .version(HTTP_3) + .sslContext(sslCtx).build()) { + final String[] reqMethods = {"HEAD", "GET", "POST"}; + for (final String reqMethod : reqMethods) { + final int numReqs = RequestApprover.MAX_REQS_PER_CONN + 3; + final Set connectionKeys = new LinkedHashSet<>(); + for (int i = 1; i <= numReqs; i++) { + final URI reqURI = new URI(REQ_URI_BASE + "?" + reqMethod + "=" + i); + final HttpRequest req = HttpRequest.newBuilder() + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .uri(reqURI) + .method(reqMethod, BodyPublishers.noBody()) + .build(); + System.out.println("initiating request " + req); + final HttpResponse resp = client.send(req, BodyHandlers.ofString()); + final String respBody = resp.body(); + System.out.println("received response: " + respBody); + assertEquals(200, resp.statusCode(), + "unexpected status code for request " + resp.request()); + // response body is the logical key of the connection on which the + // request was handled + connectionKeys.add(respBody); + } + System.out.println("connections involved in handling the requests: " + + connectionKeys); + // all requests have finished, we now just do a basic check that + // more than one connection was involved in processing these requests + assertEquals(2, connectionKeys.size(), + "unexpected number of connections " + connectionKeys); + } + } + } + + private static final class RequestApprover { + private static final int MAX_REQS_PER_CONN = 6; + private final Map numApproved = + new ConcurrentHashMap<>(); + private final Map numDisapproved = + new ConcurrentHashMap<>(); + + public boolean allowNewRequest(final String connKey) { + final AtomicInteger approved = numApproved.computeIfAbsent(connKey, + (k) -> new AtomicInteger()); + int curr = approved.get(); + while (curr < MAX_REQS_PER_CONN) { + if (approved.compareAndSet(curr, curr + 1)) { + return true; // new request allowed + } + curr = approved.get(); + } + final AtomicInteger disapproved = numDisapproved.computeIfAbsent(connKey, + (k) -> new AtomicInteger()); + final int numUnprocessed = disapproved.incrementAndGet(); + System.out.println(approved.get() + " processed, " + + numUnprocessed + " unprocessed requests so far," + + " sending GOAWAY on connection " + connKey); + return false; + } + } + + private static final class Handler implements Http2Handler { + @Override + public void handle(final Http2TestExchange exchange) throws IOException { + final String connectionKey = exchange.getConnectionKey(); + System.out.println(connectionKey + " responding to request: " + exchange.getRequestURI()); + final byte[] response = connectionKey.getBytes(UTF_8); + exchange.sendResponseHeaders(200, response.length); + try (final OutputStream os = exchange.getResponseBody()) { + os.write(response); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3HeaderSizeLimitTest.java b/test/jdk/java/net/httpclient/http3/H3HeaderSizeLimitTest.java new file mode 100644 index 00000000000..7369bb9190e --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3HeaderSizeLimitTest.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.net.ProtocolException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.time.Duration; +import java.util.concurrent.ExecutionException; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.internal.net.http.Http3ConnectionAccess; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.test.lib.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +/* + * @test + * @summary Verifies that the HTTP client respects the SETTINGS_MAX_FIELD_SECTION_SIZE setting on HTTP3 connection + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @library ../access + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @build java.net.http/jdk.internal.net.http.Http3ConnectionAccess + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors H3HeaderSizeLimitTest + */ +public class H3HeaderSizeLimitTest implements HttpServerAdapters { + + private static final long HEADER_SIZE_LIMIT_BYTES = 1024; + private SSLContext sslContext; + private HttpTestServer h3Server; + private String requestURIBase; + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + final QuicServer quicServer = Http3TestServer.quicServerBuilder() + .sslContext(sslContext) + .build(); + h3Server = HttpTestServer.of(new Http3TestServer(quicServer) + .setConnectionSettings(new ConnectionSettings(HEADER_SIZE_LIMIT_BYTES, 0, 0))); + h3Server.addHandler((exchange) -> exchange.sendResponseHeaders(200, 0), "/hello"); + h3Server.start(); + System.out.println("Server started at " + h3Server.getAddress()); + requestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(h3Server.getAddress().getPort()).build().toString(); + } + + @AfterClass + public void afterClass() throws Exception { + if (h3Server != null) { + System.out.println("Stopping server " + h3Server.getAddress()); + h3Server.stop(); + } + } + + /** + * Issues a HTTP3 request with combined request headers size exceeding the limit set by the + * test server. Verifies that such requests fail. + */ + @Test + public void testLargeHeaderSize() throws Exception { + final HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(Version.HTTP_3) + // the server drops 1 packet out of two! + .connectTimeout(Duration.ofSeconds(Utils.adjustTimeout(10))) + .sslContext(sslContext).build(); + final URI reqURI = new URI(requestURIBase + "/hello"); + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(reqURI) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + // issue a few requests so that enough time has passed to allow the SETTINGS frame from the + // server to have reached the client. + for (int i = 1; i <= 3; i++) { + System.out.println("Issuing warmup request " + i + " to " + reqURI); + final HttpResponse response = client.send( + reqBuilder.build(), + BodyHandlers.discarding()); + Assert.assertEquals(response.statusCode(), 200, "Unexpected status code"); + if (i == 3) { + var cf = Http3ConnectionAccess.peerSettings(client, response); + if (!cf.isDone()) { + System.out.println("Waiting for peer settings"); + cf.join(); + } + System.out.println("Got peer settings: " + cf.get()); + } + } + // at this point the client should have processed the SETTINGS frame the server + // and we expect it to start honouring those settings. We start the real testing now + // create headers that are larger than the headers size limit that has been configured on + // the server by this test + final String headerValue = headerValueOfLargeSize(); + for (int i = 0; i < 10; i++) { + reqBuilder.setHeader("header-" + i, headerValue); + } + final HttpRequest request = reqBuilder.build(); + System.out.println("Issuing request to " + reqURI); + final IOException thrown = Assert.expectThrows(ProtocolException.class, + () -> client.send(request, BodyHandlers.discarding())); + if (!thrown.getMessage().equals("Request headers size exceeds limit set by peer")) { + throw thrown; + } + // test same with async + System.out.println("Issuing async request to " + reqURI); + final ExecutionException asyncThrown = Assert.expectThrows(ExecutionException.class, + () -> client.sendAsync(request, BodyHandlers.discarding()).get()); + if (!(asyncThrown.getCause() instanceof ProtocolException)) { + System.err.println("Received unexpected cause"); + throw asyncThrown; + } + if (!asyncThrown.getCause().getMessage().equals("Request headers size exceeds limit set by peer")) { + System.err.println("Received unexpected message in cause"); + throw asyncThrown; + } + } + + private static String headerValueOfLargeSize() { + return "abcdefgh".repeat(250); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3HeadersEncoding.java b/test/jdk/java/net/httpclient/http3/H3HeadersEncoding.java new file mode 100644 index 00000000000..e63ae88ed32 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3HeadersEncoding.java @@ -0,0 +1,315 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.net.SimpleSSLContext + * @compile ../ReferenceTracker.java + * @run testng/othervm -Djdk.httpclient.qpack.encoderTableCapacityLimit=4096 + * -Djdk.httpclient.qpack.decoderMaxTableCapacity=4096 + * -Dhttp3.test.server.encoderAllowedHeaders=* + * -Dhttp3.test.server.decoderMaxTableCapacity=4096 + * -Dhttp3.test.server.encoderTableCapacityLimit=4096 + * -Djdk.internal.httpclient.qpack.log.level=NORMAL + * H3HeadersEncoding + * @summary this test verifies that when QPACK dynamic table is enabled multiple + * random headers can be encoded/decoded correctly + */ + +import jdk.test.lib.RandomFactory; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static jdk.httpclient.test.lib.common.HttpServerAdapters.*; + +public class H3HeadersEncoding { + + private static final int REQUESTS_COUNT = 500; + private static final int HEADERS_PER_REQUEST = 20; + SSLContext sslContext; + HttpTestServer http3TestServer; + HeadersHandler serverHeadersHandler; + String http3URI; + + @BeforeTest + public void setup() throws Exception { + System.out.println("Creating servers"); + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + + http3TestServer = HttpTestServer.create(Http3DiscoveryMode.HTTP_3_URI_ONLY, sslContext); + serverHeadersHandler = new HeadersHandler(); + http3TestServer.addHandler(serverHeadersHandler, "/http3/headers"); + http3URI = "https://" + http3TestServer.serverAuthority() + "/http3/headers"; + + http3TestServer.start(); + } + + @AfterTest + public void tearDown() { + http3TestServer.stop(); + } + + @Test + public void serialRequests() throws Exception { + try (HttpClient client = newClient()) { + for (int i = 0; i < REQUESTS_COUNT; i++) { + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(URI.create(http3URI)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + List rndHeaders = TestHeader.randomHeaders(HEADERS_PER_REQUEST); + String[] requestHeaders = rndHeaders.stream() + .flatMap(th -> Stream.of(th.name(), th.value())) + .toArray(String[]::new); + requestBuilder.headers(requestHeaders); + + // Send client request headers in request body for further check + // on the server handler side + requestBuilder.POST(HttpRequest.BodyPublishers.ofString( + TestHeader.headersToBodyContent(rndHeaders))); + + HttpResponse response = + client.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); + // Headers received by the client + var serverHeadersClientSide = TestHeader.fromHttpHeaders(response.headers()); + // Headers sent by the server handler + var serverHeadersServerSide = TestHeader.bodyContentToHeaders(response.body()); + + // Check that all headers that server sent are received on client side + checkHeaders(serverHeadersServerSide, serverHeadersClientSide, "client"); + } + } + } + + @Test + public void asyncRequests() throws Exception { + try (HttpClient client = newClient()) { + ArrayList> requestsHeaders = new ArrayList<>(REQUESTS_COUNT); + ArrayList requests = new ArrayList<>(REQUESTS_COUNT); + CopyOnWriteArrayList>> futureReplies + = new CopyOnWriteArrayList<>(); + + // Prepare all requests first + for (int i = 0; i < REQUESTS_COUNT; i++) { + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(URI.create(http3URI)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + List rndHeaders = TestHeader.randomHeaders(HEADERS_PER_REQUEST); + String[] requestHeaders = rndHeaders.stream() + .flatMap(th -> Stream.of(th.name(), th.value())) + .toArray(String[]::new); + requestBuilder.headers(requestHeaders); + requestsHeaders.add(i, rndHeaders); + + // Send client request headers in request body for further check + // on the server handler side + requestBuilder.POST(HttpRequest.BodyPublishers.ofString( + TestHeader.headersToBodyContent(rndHeaders))); + requests.add(i, requestBuilder.build()); + } + + // Send async request + for (int i = 0; i < REQUESTS_COUNT; i++) { + CompletableFuture> cf = + client.sendAsync(requests.get(i), + HttpResponse.BodyHandlers.ofString()); + futureReplies.add(i, cf); + } + + // Join on all CF responses + for (int i = 0; i < REQUESTS_COUNT; i++) { + futureReplies.get(i).join(); + } + + // Check the responses + for (int i = 0; i < REQUESTS_COUNT; i++) { + HttpResponse reply = futureReplies.get(i).get(); + // Headers received by the client + var serverHeadersClientSide = TestHeader.fromHttpHeaders(reply.headers()); + // Headers sent by the server handler + var serverHeadersServerSide = TestHeader.bodyContentToHeaders(reply.body()); + // Check that all headers that server sent are received on client side + checkHeaders(serverHeadersServerSide, serverHeadersClientSide, "client"); + } + } + } + + + private static void checkHeaders(List mustPresent, + List allHeaders, + String sideDescription) { + List notFound = new ArrayList<>(); + for (var header : mustPresent) { + if (!allHeaders.contains(header)) { + notFound.add(header); + } + } + if (!notFound.isEmpty()) { + System.err.println("The following headers was not found on " + + sideDescription + " side: " + notFound); + throw new RuntimeException("Headers not found: " + notFound); + } + } + + HttpClient newClient() { + var builder = createClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY); + return builder.build(); + } + + private static final Random RND = RandomFactory.getRandom(); + + record TestHeader(String name, String value) { + static TestHeader randomHeader() { + // It is better to have same id generated two or more + // times during the test run, therefore the range below + int headerId = RND.nextInt(10, 10 + HEADERS_PER_REQUEST * 3); + return new TestHeader("test_header" + headerId, "TestValue"); + } + + static List randomHeaders(int count) { + return IntStream + .range(0, count) + .boxed() + .map(ign -> TestHeader.randomHeader()) + .toList(); + } + + static List fromHttpHeaders(HttpHeaders httpHeaders) { + var headersMap = httpHeaders.map(); + return fromHeadersEntrySet(headersMap.entrySet()); + } + + static List fromTestHttpHeaders(HttpTestRequestHeaders requestHeaders) { + var headersSet = requestHeaders.entrySet(); + return fromHeadersEntrySet(headersSet); + } + + private static List fromHeadersEntrySet(Set>> entrySet) { + return entrySet.stream() + .flatMap(entry -> { + var name = entry.getKey(); + return entry.getValue() + .stream() + .map(value -> new TestHeader(name, value)); + }).toList(); + } + + public static String headersToBodyContent(List rndHeaders) { + return rndHeaders.stream() + .map(TestHeader::toString) + .collect(Collectors.joining(System.lineSeparator())); + } + + public static List bodyContentToHeaders(String bodyContent) { + return Arrays.stream(bodyContent.split(System.lineSeparator())) + .filter(Predicate.not(String::isBlank)) + .map(String::strip) + .map(TestHeader::fromBodyHeaderLine) + .toList(); + } + + public static TestHeader fromBodyHeaderLine(String headerLine) { + String[] parts = headerLine.split(":"); + if (parts.length != 2) { + throw new RuntimeException("Internal test error"); + } + return new TestHeader(parts[0], parts[1]); + } + + public String toString() { + return name + ":" + value; + } + } + + + private class HeadersHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + + var clientHeadersServerSide = TestHeader.fromTestHttpHeaders(t.getRequestHeaders()); + + String requestBody; + try (InputStream is = t.getRequestBody()) { + byte[] body = is.readAllBytes(); + requestBody = new String(body); + } + var clientHeadersClientSide = TestHeader.bodyContentToHeaders(requestBody); + + // Check that all headers that client sent are received on tne server side + checkHeaders(clientHeadersClientSide, clientHeadersServerSide, + "server handler"); + + // Response back with a set of random headers + var responseHeaders = t.getResponseHeaders(); + List serverResp = new ArrayList<>(); + for (TestHeader h : TestHeader.randomHeaders(HEADERS_PER_REQUEST)) { + serverResp.add(h); + responseHeaders.addHeader(h.name(), h.value()); + } + + String responseBody = TestHeader.headersToBodyContent(serverResp); + try (OutputStream os = t.getResponseBody()) { + byte[] responseBodyBytes = responseBody.getBytes(); + t.sendResponseHeaders(200, responseBodyBytes.length); + if (!t.getRequestMethod().equals("HEAD")) { + os.write(responseBodyBytes); + } + } + } + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3ImplicitPushCancel.java b/test/jdk/java/net/httpclient/http3/H3ImplicitPushCancel.java new file mode 100644 index 00000000000..735a68c88cf --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ImplicitPushCancel.java @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,responses,trace + * H3ImplicitPushCancel + * @summary This is a clone of http2/ImplicitPushCancel but for HTTP/3 + */ + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; + +public class H3ImplicitPushCancel implements HttpServerAdapters { + + static Map PUSH_PROMISES = Map.of( + "/x/y/z/1", "the first push promise body", + "/x/y/z/2", "the second push promise body", + "/x/y/z/3", "the third push promise body", + "/x/y/z/4", "the fourth push promise body", + "/x/y/z/5", "the fifth push promise body", + "/x/y/z/6", "the sixth push promise body", + "/x/y/z/7", "the seventh push promise body", + "/x/y/z/8", "the eight push promise body", + "/x/y/z/9", "the ninth push promise body" + ); + static final String MAIN_RESPONSE_BODY = "the main response body"; + + HttpTestServer server; + URI uri; + URI headURI; + + @BeforeTest + public void setup() throws Exception { + server = HttpTestServer.create(ANY, new SimpleSSLContext().get()); + HttpTestHandler pushHandler = new ServerPushHandler(MAIN_RESPONSE_BODY, + PUSH_PROMISES); + server.addHandler(pushHandler, "/push/"); + server.addHandler(new HttpHeadOrGetHandler(), "/head/"); + server.start(); + System.err.println("Server listening on port " + server.serverAuthority()); + uri = new URI("https://" + server.serverAuthority() + "/push/a/b/c"); + headURI = new URI("https://" + server.serverAuthority() + "/head/x"); + } + + @AfterTest + public void teardown() { + server.stop(); + } + + static HttpResponse assert200ResponseCode(HttpResponse response) { + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + return response; + } + + private void sendHeadRequest(HttpClient client) throws IOException, InterruptedException { + HttpRequest headRequest = HttpRequest.newBuilder(headURI) + .HEAD().version(Version.HTTP_2).build(); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + /* + * With a handler not capable of accepting push promises, then all push + * promises should be rejected / cancelled, without interfering with the + * main response. + */ + @Test + public void test() throws Exception { + try (HttpClient client = newClientBuilderForH3() + .proxy(Builder.NO_PROXY) + .version(Version.HTTP_3) + .sslContext(new SimpleSSLContext().get()) + .build()) { + + sendHeadRequest(client); + + // Send with no promise handler + try { + client.sendAsync(HttpRequest.newBuilder(uri) + .build(), BodyHandlers.ofString()) + .thenApply(H3ImplicitPushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + System.out.println("Got result before error was raised"); + throw new AssertionError("should have failed"); + } catch (CompletionException c) { + Throwable cause = Utils.getCompletionCause(c); + if (cause.getMessage().contains("Max pushId exceeded")) { + System.out.println("Got expected exception: " + cause); + } else throw new AssertionError(cause); + } + + // Send with promise handler + ConcurrentMap>> promises + = new ConcurrentHashMap<>(); + PushPromiseHandler pph = PushPromiseHandler + .of((r) -> BodyHandlers.ofString(), promises); + + HttpResponse main; + try { + main = client.sendAsync( + HttpRequest.newBuilder(uri) + .header("X-WaitForPushId", String.valueOf(1)) + .build(), + BodyHandlers.ofString(), + pph) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + + promises.forEach((key, value) -> System.out.println(key + ":" + value.join().body())); + + promises.putIfAbsent(main.request(), CompletableFuture.completedFuture(main)); + promises.forEach((request, value) -> { + HttpResponse response = value.join(); + assertEquals(response.statusCode(), 200); + if (PUSH_PROMISES.containsKey(request.uri().getPath())) { + assertEquals(response.body(), PUSH_PROMISES.get(request.uri().getPath())); + } else { + assertEquals(response.body(), MAIN_RESPONSE_BODY); + } + }); + assertEquals(promises.size(), PUSH_PROMISES.size() + 1); + + promises.clear(); + + // Send with no promise handler + try { + client.sendAsync(HttpRequest.newBuilder(uri).build(), BodyHandlers.ofString()) + .thenApply(H3ImplicitPushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + + assertEquals(promises.size(), 0); + } + } + + + // --- server push handler --- + static class ServerPushHandler implements HttpTestHandler { + + private final String mainResponseBody; + private final Map promises; + + public ServerPushHandler(String mainResponseBody, + Map promises) + throws Exception + { + Objects.requireNonNull(promises); + this.mainResponseBody = mainResponseBody; + this.promises = promises; + } + + public void handle(HttpTestExchange exchange) throws IOException { + System.err.println("Server: handle " + exchange); + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + if (exchange.serverPushAllowed()) { + long waitForPushId = exchange.getRequestHeaders() + .firstValueAsLong("X-WaitForPushId").orElse(-1); + long allowed = -1; + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + System.err.println("Got maxPushId: " + allowed); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + pushPromises(exchange); + } + + // response data for the main response + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = mainResponseBody.getBytes(UTF_8); + exchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } + } + + private void pushPromises(HttpTestExchange exchange) throws IOException { + URI requestURI = exchange.getRequestURI(); + for (Map.Entry promise : promises.entrySet()) { + URI uri = requestURI.resolve(promise.getKey()); + InputStream is = new ByteArrayInputStream(promise.getValue().getBytes(UTF_8)); + HttpHeaders headers = HttpHeaders.of(Collections.emptyMap(), (x, y) -> true); + exchange.serverPush(uri, headers, is); + } + System.err.println("Server: All pushes sent"); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3InsertionsLimitTest.java b/test/jdk/java/net/httpclient/http3/H3InsertionsLimitTest.java new file mode 100644 index 00000000000..7ca07a30ab9 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3InsertionsLimitTest.java @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.TableEntry; +import jdk.test.lib.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse.BodyHandlers; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; + +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +/* + * @test + * @summary Verifies that the HTTP client respects the maxLiteralWithIndexing + * system property value. + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @library ../access + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @build java.net.http/jdk.internal.net.http.Http3ConnectionAccess + * @run testng/othervm -Djdk.httpclient.qpack.encoderTableCapacityLimit=4096 + * -Djdk.internal.httpclient.qpack.allowBlockingEncoding=true + * -Djdk.httpclient.qpack.decoderMaxTableCapacity=4096 + * -Djdk.httpclient.qpack.decoderBlockedStreams=1024 + * -Dhttp3.test.server.encoderAllowedHeaders=* + * -Dhttp3.test.server.decoderMaxTableCapacity=4096 + * -Dhttp3.test.server.encoderTableCapacityLimit=4096 + * -Djdk.httpclient.maxLiteralWithIndexing=32 + * -Djdk.internal.httpclient.qpack.log.level=EXTRA + * H3InsertionsLimitTest + */ +public class H3InsertionsLimitTest implements HttpServerAdapters { + + private static final long HEADER_SIZE_LIMIT_BYTES = 8192; + private static final long MAX_SERVER_DT_CAPACITY = 4096; + private SSLContext sslContext; + private HttpTestServer h3Server; + private String requestURIBase; + public static final long MAX_LITERALS_WITH_INDEXING = 32L; + private static final CountDownLatch WAIT_FOR_FAILURE = new CountDownLatch(1); + + private static void handle(HttpTestExchange exchange) throws IOException { + String handlerMsg = "Server handler: " + exchange.getRequestURI(); + long unusedStreamID = 1111; + System.out.println(handlerMsg); + System.err.println(handlerMsg); + + try { + ConnectionSettings settings = exchange.clientHttp3Settings().get(); + System.err.println("Received client connection settings: " + settings); + } catch (Exception e) { + throw new RuntimeException("Test issue: failure awaiting for HTTP/3" + + " connection settings from the client", e); + } + + // Set encoder table capacity explicitly + // to avoid waiting client's settings frame + // that triggers the same DT configuration + Encoder encoder = exchange.qpackEncoder(); + encoder.setTableCapacity(MAX_SERVER_DT_CAPACITY); + // Mimic entry insertions on the server-side + try (Encoder.EncodingContext context = encoder.newEncodingContext( + unusedStreamID, 0, encoder.newHeaderFrameWriter())) { + for (int i = 0; i <= MAX_LITERALS_WITH_INDEXING; i++) { + var entry = new TableEntry("n" + i, "v" + i); + var insertedEntry = context.tryInsertEntry(entry); + if (insertedEntry.index() == -1L) { + throw new RuntimeException("Test issue: cannot insert" + + " entry to the encoder dynamic table"); + } + } + } + try { + WAIT_FOR_FAILURE.await(); + } catch (InterruptedException e) { + throw new RuntimeException("Test Issue: handler interrupted", e); + } + // Send a response + exchange.sendResponseHeaders(200, 0); + } + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + final QuicServer quicServer = Http3TestServer.quicServerBuilder() + .bindAddress(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + .sslContext(sslContext) + .build(); + var http3TestServer = new Http3TestServer(quicServer) + .setConnectionSettings(new ConnectionSettings( + HEADER_SIZE_LIMIT_BYTES, MAX_SERVER_DT_CAPACITY, 0)); + + h3Server = HttpTestServer.of(http3TestServer); + h3Server.addHandler(H3InsertionsLimitTest::handle, "/insertions"); + h3Server.start(); + System.out.println("Server started at " + h3Server.getAddress()); + requestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(h3Server.getAddress().getPort()).build().toString(); + } + + @AfterClass + public void afterClass() throws Exception { + if (h3Server != null) { + System.out.println("Stopping server " + h3Server.getAddress()); + h3Server.stop(); + } + } + + @Test + public void multipleTableInsertions() throws Exception { + final HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(Version.HTTP_3) + // the server drops 1 packet out of two! + .connectTimeout(Duration.ofSeconds(Utils.adjustTimeout(10))) + .sslContext(sslContext).build(); + final URI reqURI = new URI(requestURIBase + "/insertions"); + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(reqURI) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + final HttpRequest request = + reqBuilder.POST(HttpRequest.BodyPublishers.ofString("Hello")).build(); + System.out.println("Issuing request to " + reqURI); + try { + client.send(request, BodyHandlers.discarding()); + Assert.fail("IOException expected"); + } catch (IOException ioe) { + System.out.println("Got IOException: " + ioe); + Assert.assertTrue(ioe.getMessage() + .contains("Too many literal with indexing")); + } finally { + WAIT_FOR_FAILURE.countDown(); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3MalformedResponseTest.java b/test/jdk/java/net/httpclient/http3/H3MalformedResponseTest.java new file mode 100644 index 00000000000..06b12b6a477 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3MalformedResponseTest.java @@ -0,0 +1,437 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.OutputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.HexFormat; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; + +/* + * @test + * @summary Verifies that the HTTP client correctly handles malformed responses + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @library ../access + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @build java.net.http/jdk.internal.net.http.Http3ConnectionAccess + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors H3MalformedResponseTest + */ +public class H3MalformedResponseTest implements HttpServerAdapters { + + private SSLContext sslContext; + private QuicStandaloneServer server; + private String requestURIBase; + + // These responses are malformed and should not be accepted by the client, + // but they should not cause connection closure + @DataProvider + public static Object[][] malformedResponse() { + return new Object[][]{ + new Object[] {"empty", HexFormat.of().parseHex( + "" + )}, + new Object[] {"non-final response", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "ff00" // :status:100 + )}, + new Object[] {"uppercase header name", HexFormat.of().parseHex( + "01090000"+ // headers, length 9, section prefix + "d9"+ // :status:200 + "234147450130"+ // AGE:0 + "000100" // data, 1 byte + )}, + new Object[] {"content too long", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "d9"+ // :status:200 + "c4"+ // content-length:0 + "000100" // data, 1 byte + )}, + new Object[] {"content too short", HexFormat.of().parseHex( + "01060000"+ // headers, length 6, section prefix + "d9"+ // :status:200 + "540132"+ // content-length:2 + "000100" // data, 1 byte + )}, + new Object[] {"text in content-length", HexFormat.of().parseHex( + "01060000"+ // headers, length 6, section prefix + "d9"+ // :status:200 + "540161"+ // content-length:a + "000100" // data, 1 byte + )}, + new Object[] {"connection: close", HexFormat.of().parseHex( + "01150000"+ // headers, length 21, section prefix + "d9"+ // :status:200 + "2703636F6E6E656374696F6E05636C6F7365"+ // connection:close + "000100" // data, 1 byte + )}, + // request pseudo-headers in response + new Object[] {":method in response", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "d9"+ // :status:200 + "d1"+ // :method:get + "000100" // data, 1 byte + )}, + new Object[] {":authority in response", HexFormat.of().parseHex( + "01100000"+ // headers, length 16, section prefix + "d9"+ // :status:200 + "508b089d5c0b8170dc702fbce7"+ // :authority + "000100" // data, 1 byte + )}, + new Object[] {":path in response", HexFormat.of().parseHex( + "010a0000"+ // headers, length 10, section prefix + "d9"+ // :status:200 + "51856272d141ff"+ // :path + "000100" // data, 1 byte + )}, + new Object[] {":scheme in response", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "d9"+ // :status:200 + "d7"+ // :scheme:https + "000100" // data, 1 byte + )}, + new Object[] {"undefined pseudo-header", HexFormat.of().parseHex( + "01080000"+ // headers, length 8, section prefix + "d9"+ // :status:200 + "223A6D0130"+ // :m:0 + "000100" // data, 1 byte + )}, + new Object[] {"pseudo-header after regular", HexFormat.of().parseHex( + "011a0000"+ // headers, length 26, section prefix + "5f5094ca3ee35a74a6b589418b5258132b1aa496ca8747"+ //user-agent + "d9"+ // :status:200 + "000100" // data, 1 byte + )}, + new Object[] {"trailer", HexFormat.of().parseHex( + "01020000" // headers, length 2, section prefix + )}, + new Object[] {"trailer+data", HexFormat.of().parseHex( + "01020000"+ // headers, length 2, section prefix + "000100" // data, 1 byte + )}, + // valid characters include \t, 0x20-0x7e, 0x80-0xff (RFC 9110, section 5.5) + new Object[] {"invalid character in field value 00", HexFormat.of().parseHex( + "01060000"+ // headers, length 6, section prefix + "d9"+ // :status:200 + "570100"+ // etag:\0 + "000100" // data, 1 byte + )}, + new Object[] {"invalid character in field value 0a", HexFormat.of().parseHex( + "01060000"+ // headers, length 6, section prefix + "d9"+ // :status:200 + "57010a"+ // etag:\n + "000100" // data, 1 byte + )}, + new Object[] {"invalid character in field value 0d", HexFormat.of().parseHex( + "01060000"+ // headers, length 6, section prefix + "d9"+ // :status:200 + "57010d"+ // etag:\r + "000100" // data, 1 byte + )}, + new Object[] {"invalid character in field value 7f", HexFormat.of().parseHex( + "01060000"+ // headers, length 6, section prefix + "d9"+ // :status:200 + "57017f"+ // etag: 0x7f + "000100" // data, 1 byte + )}, + }; + } + + // These responses are malformed and should not be accepted by the client. + // They might or might not cause connection closure (H3_FRAME_UNEXPECTED) + @DataProvider + public static Object[][] malformedResponse2() { + // data before headers is covered by H3ErrorHandlingTest + return new Object[][]{ + new Object[] {"100+data", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "ff00"+ // :status:100 + "000100" // data, 1 byte + )}, + new Object[] {"100+data+200", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "ff00"+ // :status:100 + "000100"+ // data, 1 byte + "01030000"+ // headers, length 3, section prefix + "d9" // :status:200 + )}, + new Object[] {"200+data+200", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "01030000"+ // headers, length 3, section prefix + "d9" // :status:200 + )}, + new Object[] {"200+data+100", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "01040000"+ // headers, length 4, section prefix + "ff00" // :status:100 + )}, + new Object[] {"200+data+trailers+data", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "01020000"+ // trailers, length 2, section prefix + "000100" // data, 1 byte + )}, + new Object[] {"200+trailers+data", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "01020000"+ // trailers, length 2, section prefix + "000100" // data, 1 byte + )}, + new Object[] {"200+200", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "01030000"+ // headers, length 3, section prefix + "d9" // :status:200 + )}, + new Object[] {"200+100", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "01040000"+ // headers, length 4, section prefix + "ff00" // :status:100 + )}, + }; + } + + @DataProvider + public static Object[][] wellformedResponse() { + return new Object[][]{ + new Object[] {"100+200+data+reserved", HexFormat.of().parseHex( + "01040000"+ // headers, length 4, section prefix + "ff00"+ // :status:100 + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "210100" // reserved, 1 byte + )}, + new Object[] {"200+data+reserved", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "210100" // reserved, 1 byte + )}, + new Object[] {"200+data", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100" // data, 1 byte + )}, + new Object[] {"200+user-agent+data", HexFormat.of().parseHex( + "011a0000"+ // headers, length 26, section prefix + "d9"+ // :status:200 + "5f5094ca3ee35a74a6b589418b5258132b1aa496ca8747"+ //user-agent + "000100" // data, 1 byte + )}, + new Object[] {"200", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9" // :status:200 + )}, + new Object[] {"200+data+data", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "000100" // data, 1 byte + )}, + new Object[] {"200+data+trailers", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "000100"+ // data, 1 byte + "01020000" // trailers, length 2, section prefix + )}, + new Object[] {"200+trailers", HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "01020000" // trailers, length 2, section prefix + )}, + }; + } + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext) + .alpn("h3") + .build(); + server.start(); + System.out.println("Server started at " + server.getAddress()); + requestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(server.getAddress().getPort()).build().toString(); + } + + @AfterClass + public void afterClass() throws Exception { + if (server != null) { + System.out.println("Stopping server " + server.getAddress()); + server.close(); + } + } + + /** + * Server sends a well-formed response + */ + @Test(dataProvider = "wellformedResponse") + public void testWellFormedResponse(String desc, byte[] response) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + try (OutputStream outputStream = s.outputStream()) { + outputStream.write(response); + } + // verify that the connection stays open + completeUponTermination(c, errorCF); + }); + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response1 = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + assertEquals(response1.statusCode(), 200); + assertFalse(errorCF.isDone(), "Expected the connection to be open"); + } finally { + client.shutdownNow(); + } + } + + + /** + * Server sends a malformed response that should not close connection + */ + @Test(dataProvider = "malformedResponse") + public void testMalformedResponse(String desc, byte[] response) throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + server.addHandler((c,s)-> { + try (OutputStream outputStream = s.outputStream()) { + outputStream.write(response); + } + // verify that the connection stays open + completeUponTermination(c, errorCF); + }); + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response1 = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response1); + } catch (TimeoutException e) { + throw e; + } catch (Exception e) { + System.out.println("Got expected exception: " +e); + e.printStackTrace(); + assertFalse(errorCF.isDone(), "Expected the connection to be open"); + } finally { + client.shutdownNow(); + } + } + + /** + * Server sends a malformed response that might close connection + */ + @Test(dataProvider = "malformedResponse2") + public void testMalformedResponse2(String desc, byte[] response) throws Exception { + server.addHandler((c,s)-> { + try (OutputStream outputStream = s.outputStream()) { + outputStream.write(response); + } + }); + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response1 = client.sendAsync( + request, + BodyHandlers.discarding()) + .get(10, TimeUnit.SECONDS); + fail("Expected the request to fail, got " + response1); + } catch (TimeoutException e) { + throw e; + } catch (Exception e) { + System.out.println("Got expected exception: " +e); + e.printStackTrace(); + } finally { + client.shutdownNow(); + } + } + + private HttpRequest getRequest() throws URISyntaxException { + final URI reqURI = new URI(requestURIBase + "/hello"); + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(reqURI) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + return reqBuilder.build(); + } + + private HttpClient getHttpClient() { + final HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(Version.HTTP_3) + .sslContext(sslContext).build(); + return client; + } + + private static String toHexString(final Http3Error error) { + return error.name() + "(0x" + Long.toHexString(error.code()) + ")"; + } + + private static void completeUponTermination(final QuicServerConnection serverConnection, + final CompletableFuture cf) { + serverConnection.futureTerminationCause().handle( + (r,t) -> t != null ? cf.completeExceptionally(t) : cf.complete(r)); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3MaxInitialTimeoutTest.java b/test/jdk/java/net/httpclient/http3/H3MaxInitialTimeoutTest.java new file mode 100644 index 00000000000..30f48c56454 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3MaxInitialTimeoutTest.java @@ -0,0 +1,250 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpConnectTimeoutException; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.Builder; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.channels.DatagramChannel; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.ITestContext; +import org.testng.SkipException; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static java.lang.System.out; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.assertEquals; + + +/* + * @test + * @bug 8342954 + * @summary Verify jdk.httpclient.quic.maxInitialTimeout is taken into account. + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.quic.QuicStandaloneServer + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors,quic:controls + * -Djdk.httpclient.quic.maxInitialTimeout=1 + * H3MaxInitialTimeoutTest + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors,quic:controls + * -Djdk.httpclient.quic.maxInitialTimeout=2 + * H3MaxInitialTimeoutTest + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors,quic:controls + * -Djdk.httpclient.quic.maxInitialTimeout=2147483647 + * H3MaxInitialTimeoutTest + */ +public class H3MaxInitialTimeoutTest implements HttpServerAdapters { + + SSLContext sslContext; + DatagramChannel receiver; + String h3URI; + + static final Executor executor = new TestExecutor(Executors.newVirtualThreadPerTaskExecutor()); + static final ConcurrentMap FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong serverCount = new AtomicLong(); + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + System.out.printf(now() + "Task %s failed: %s%n", id, t); + System.err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + protected boolean stopAfterFirstFailure() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); + } + + @BeforeMethod + void beforeMethod(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + var x = new SkipException("Skipping: some test failed"); + x.setStackTrace(new StackTraceElement[0]); + throw x; + } + } + + @AfterClass + static void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d servers and %d clients%n", + now(), serverCount.get(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.forEach((key, value) -> { + out.printf("\t%s: %s%n", key, value); + value.printStackTrace(out); + value.printStackTrace(); + }); + if (tasksFailed) { + System.out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + @DataProvider(name = "h3URIs") + public Object[][] versions(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + Object[][] result = {{h3URI}}; + return result; + } + + private HttpClient makeNewClient(long connectionTimeout) { + clientCount.incrementAndGet(); + HttpClient client = newClientBuilderForH3() + .version(Version.HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY) + .executor(executor) + .sslContext(sslContext) + .connectTimeout(Duration.ofSeconds(connectionTimeout)) + .build(); + return client; + } + + @Test(dataProvider = "h3URIs") + public void testTimeout(final String h3URI) throws Exception { + long timeout = Long.getLong("jdk.httpclient.quic.maxInitialTimeout", 30); + long connectionTimeout = timeout == Integer.MAX_VALUE ? 2 : 10 * timeout; + + try (HttpClient client = makeNewClient(connectionTimeout)) { + URI uri = URI.create(h3URI); + Builder builder = HttpRequest.newBuilder(uri) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + HttpRequest request = builder.build(); + try { + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + assertEquals(response.statusCode(), 200, "first response status"); + assertEquals(response.version(), HTTP_3, "first response version"); + throw new AssertionError("Expected ConnectException not thrown"); + } catch (ConnectException c) { + String msg = c.getMessage(); + if (timeout != Integer.MAX_VALUE) { + if (msg != null && msg.contains("No response from peer")) { + out.println("Got expected exception: " + c); + } else throw c; + } else throw c; + } catch (HttpConnectTimeoutException hc) { + String msg = hc.getMessage(); + if (timeout == Integer.MAX_VALUE) { + if (msg != null && msg.contains("No response from peer")) { + throw new AssertionError("Unexpected message: " + msg, hc); + } else { + out.println("Got expected exception: " + hc); + return; + } + } else throw hc; + } + } + + } + + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + receiver = DatagramChannel.open(); + receiver.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); + h3URI = URIBuilder.newBuilder() + .scheme("https") + .loopback() + .port(((InetSocketAddress)receiver.getLocalAddress()).getPort()) + .path("/") + .build() + .toString(); + } + + @AfterTest + public void teardown() throws Exception { + System.err.println("======================================================="); + System.err.println(" Tearing down test"); + System.err.println("======================================================="); + receiver.close(); + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3MemoryHandlingTest.java b/test/jdk/java/net/httpclient/http3/H3MemoryHandlingTest.java new file mode 100644 index 00000000000..f270297d85c --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3MemoryHandlingTest.java @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.HexFormat; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; + +/* + * @test + * @summary Verifies that the HTTP client does not buffer excessive amounts of data + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @library ../access + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @build java.net.http/jdk.internal.net.http.Http3ConnectionAccess + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * -Djdk.httpclient.quic.maxStreamInitialData=16384 + * -Djdk.httpclient.quic.streamBufferSize=2048 H3MemoryHandlingTest + */ +public class H3MemoryHandlingTest implements HttpServerAdapters { + + private SSLContext sslContext; + private QuicStandaloneServer server; + private String requestURIBase; + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext) + .alpn("h3") + .build(); + server.start(); + System.out.println("Server started at " + server.getAddress()); + requestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(server.getAddress().getPort()).build().toString(); + } + + @AfterClass + public void afterClass() throws Exception { + if (server != null) { + System.out.println("Stopping server " + server.getAddress()); + server.close(); + } + } + + /** + * Server sends a large response, and the user code does not read from the input stream. + * Writing on the server side should block once the buffers are full. + */ + @Test + public void testOfInputStreamBlocks() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + byte[] response = HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "00ffffffffffffffff"); // data, 2^62 - 1 bytes + byte[] kilo = new byte[1024]; + final CompletableFuture serverAllWritesDone = new CompletableFuture<>(); + server.addHandler((c,s)-> { + // verify that the connection stays open + completeUponTermination(c, errorCF); + try (OutputStream outputStream = s.outputStream()) { + outputStream.write(response); + for (int i = 0; i < 20; i++) { + // 18 writes should succeed, 19th should block + outputStream.write(kilo); + System.out.println("Wrote "+(i+1)+"KB"); + } + // all 20 writes unexpectedly completed + serverAllWritesDone.complete(true); + } catch(IOException ex) { + System.out.println("Got expected exception: " + ex); + serverAllWritesDone.complete(false); + } + }); + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response1 = client.send( + request, BodyHandlers.ofInputStream()); + assertEquals(response1.statusCode(), 200); + assertFalse(errorCF.isDone(), "Expected the connection to be open"); + assertFalse(serverAllWritesDone.isDone()); + response1.body().close(); + final boolean done = serverAllWritesDone.get(10, TimeUnit.SECONDS); + assertFalse(done, "Too much data was buffered by the client"); + } finally { + client.close(); + } + } + + /** + * Server sends a large response, and the user code does not read from the input stream. + * Writing on the server side should unblock once the client starts receiving. + */ + @Test + public void testOfInputStreamUnblocks() throws Exception { + CompletableFuture errorCF = new CompletableFuture<>(); + CompletableFuture handlerCF = new CompletableFuture<>(); + byte[] response = HexFormat.of().parseHex( + "01030000"+ // headers, length 3, section prefix + "d9"+ // :status:200 + "0080008000"); // data, 32 KB + byte[] kilo = new byte[1024]; + CountDownLatch writerBlocked = new CountDownLatch(1); + + server.addHandler((c,s)-> { + // verify that the connection stays open + completeUponTermination(c, errorCF); + QuicBidiStream qs = s.underlyingBidiStream(); + + try (OutputStream outputStream = s.outputStream()) { + outputStream.write(response); + for (int i = 0;i < 32;i++) { + // 18 writes should succeed, 19th should block + if (i == 18) { + writerBlocked.countDown(); + } + outputStream.write(kilo); + System.out.println("Wrote "+(i+1)+"KB"); + } + handlerCF.complete(true); + } catch (IOException e) { + handlerCF.completeExceptionally(e); + } + }); + HttpClient client = getHttpClient(); + try { + HttpRequest request = getRequest(); + final HttpResponse response1 = client.send( + request, BodyHandlers.ofInputStream()); + assertEquals(response1.statusCode(), 200); + assertFalse(errorCF.isDone(), "Expected the connection to be open"); + assertFalse(handlerCF.isDone()); + assertTrue(writerBlocked.await(10, TimeUnit.SECONDS), + "write of 18 KB should have succeeded"); + System.out.println("Wait completed, receiving response"); + byte[] receivedResponse; + try (InputStream body = response1.body()) { + receivedResponse = body.readAllBytes(); + } + assertEquals(receivedResponse.length, 32768, + "Unexpected response length"); + } finally { + client.close(); + } + assertTrue(handlerCF.get(10, TimeUnit.SECONDS), + "Unexpected result"); + } + + private HttpRequest getRequest() throws URISyntaxException { + final URI reqURI = new URI(requestURIBase + "/hello"); + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(reqURI) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + return reqBuilder.build(); + } + + private HttpClient getHttpClient() { + final HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(Version.HTTP_3) + .sslContext(sslContext).build(); + return client; + } + + private static String toHexString(final Http3Error error) { + return error.name() + "(0x" + Long.toHexString(error.code()) + ")"; + } + + private static void completeUponTermination(final QuicServerConnection serverConnection, + final CompletableFuture cf) { + serverConnection.futureTerminationCause().handle( + (r,t) -> t != null ? cf.completeExceptionally(t) : cf.complete(r)); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3MultipleConnectionsToSameHost.java b/test/jdk/java/net/httpclient/http3/H3MultipleConnectionsToSameHost.java new file mode 100644 index 00000000000..45024c58e1f --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3MultipleConnectionsToSameHost.java @@ -0,0 +1,338 @@ +/* + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test id=with-continuations + * @bug 8087112 + * @requires os.family != "windows" | ( os.name != "Windows 10" & os.name != "Windows Server 2016" + * & os.name != "Windows Server 2019" ) + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm/timeout=360 -XX:+CrashOnOutOfMemoryError + * -Djdk.httpclient.quic.minPtoBackoffTime=60 + * -Djdk.httpclient.quic.maxPtoBackoffTime=90 + * -Djdk.httpclient.quic.maxPtoBackoff=10 + * -Djdk.internal.httpclient.quic.useNioSelector=false + * -Djdk.internal.httpclient.quic.poller.usePlatformThreads=false + * -Djdk.httpclient.quic.maxEndpoints=-1 + * -Djdk.httpclient.http3.maxStreamLimitTimeout=0 + * -Djdk.httpclient.quic.maxBidiStreams=2 + * -Djdk.httpclient.retryOnStreamlimit=50 + * -Djdk.httpclient.HttpClient.log=errors,http3,quic:retransmit + * -Dsimpleget.requests=100 + * H3MultipleConnectionsToSameHost + * @summary test multiple connections and concurrent requests with blocking IO and virtual threads + */ +/* + * @test id=without-continuations + * @bug 8087112 + * @requires os.family == "windows" & ( os.name == "Windows 10" | os.name == "Windows Server 2016" + * | os.name == "Windows Server 2019" ) + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm/timeout=360 -XX:+CrashOnOutOfMemoryError + * -Djdk.httpclient.quic.minPtoBackoffTime=45 + * -Djdk.httpclient.quic.maxPtoBackoffTime=60 + * -Djdk.httpclient.quic.maxPtoBackoff=9 + * -Djdk.internal.httpclient.quic.useNioSelector=false + * -Djdk.internal.httpclient.quic.poller.usePlatformThreads=false + * -XX:+UnlockExperimentalVMOptions -XX:-VMContinuations + * -Djdk.httpclient.quic.maxEndpoints=-1 + * -Djdk.httpclient.http3.maxStreamLimitTimeout=0 + * -Djdk.httpclient.quic.maxBidiStreams=2 + * -Djdk.httpclient.retryOnStreamlimit=50 + * -Djdk.httpclient.HttpClient.log=errors,http3,quic:retransmit + * -Dsimpleget.requests=100 + * H3MultipleConnectionsToSameHost + * @summary test multiple connections and concurrent requests with blocking IO and virtual threads + * on windows 10 and windows 2016 - but with -XX:-VMContinuations + */ +/* + * @test id=useNioSelector + * @bug 8087112 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm/timeout=360 -XX:+CrashOnOutOfMemoryError + * -Djdk.httpclient.quic.idleTimeout=120 + * -Djdk.httpclient.keepalive.timeout.h3=120 + * -Djdk.test.server.quic.idleTimeout=90 + * -Djdk.httpclient.quic.minPtoBackoffTime=60 + * -Djdk.httpclient.quic.maxPtoBackoffTime=120 + * -Djdk.httpclient.quic.maxPtoBackoff=9 + * -Djdk.internal.httpclient.quic.useNioSelector=true + * -Djdk.httpclient.http3.maxStreamLimitTimeout=0 + * -Djdk.httpclient.quic.maxEndpoints=1 + * -Djdk.httpclient.quic.maxBidiStreams=2 + * -Djdk.httpclient.retryOnStreamlimit=50 + * -Djdk.httpclient.HttpClient.log=errors,http3,quic:hs:retransmit + * -Dsimpleget.requests=100 + * H3MultipleConnectionsToSameHost + * @summary Send 100 large concurrent requests, with connections whose max stream + * limit is artificially low, in order to cause concurrent connections + * to the same host to be created, with non-blocking IO and selector + */ + +// Interesting additional settings for debugging and manual testing: +// ----------------------------------------------------------------- +// -XX:+UnlockExperimentalVMOptions -XX:-VMContinuations +// -Djdk.httpclient.HttpClient.log=errors,requests,http3,quic +// -Djdk.httpclient.HttpClient.log=requests,errors,quic:retransmit:control,http3 +// -Djdk.httpclient.HttpClient.log=errors,requests,quic:all +// -Djdk.httpclient.quic.defaultMTU=64000 +// -Djdk.httpclient.quic.defaultMTU=16384 +// -Djdk.httpclient.quic.defaultMTU=4096 +// -Djdk.httpclient.http3.maxStreamLimitTimeout=1375 +// -Xmx16g +// -Djdk.httpclient.quic.defaultMTU=16384 +// -Djdk.internal.httpclient.debug=err +// -XX:+HeapDumpOnOutOfMemoryError +// -Djdk.httpclient.HttpClient.log=errors,quic:cc +// -Djdk.httpclient.quic.sendBufferSize=16384 +// -Djdk.httpclient.quic.receiveBufferSize=16384 + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.Assert; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static jdk.internal.net.http.Http3ClientProperties.MAX_STREAM_LIMIT_WAIT_TIMEOUT; + +public class H3MultipleConnectionsToSameHost implements HttpServerAdapters { + static HttpTestServer httpsServer; + static HttpClient client = null; + static SSLContext sslContext; + static String httpsURIString; + static ExecutorService serverExec = + Executors.newThreadPerTaskExecutor(Thread.ofVirtual() + .name("server-vt-worker-", 1).factory()); + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + + httpsServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext, serverExec); + httpsServer.addHandler(new TestHandler(), "/"); + httpsURIString = "https://" + httpsServer.serverAuthority() + "/bar/"; + + httpsServer.start(); + warmup(); + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(); + throw e; + } + } + + private static void warmup() throws Exception { + SimpleSSLContext sslct = new SimpleSSLContext(); + var sslContext = sslct.get(); + + // warmup server + try (var client2 = createClient(sslContext, Executors.newVirtualThreadPerTaskExecutor())) { + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .HEAD().build(); + client2.send(request, BodyHandlers.ofByteArrayConsumer(b-> {})); + } + + // warmup client + var httpsServer2 = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext, + Executors.newVirtualThreadPerTaskExecutor()); + httpsServer2.addHandler(new TestHandler(), "/"); + var httpsURIString2 = "https://" + httpsServer2.serverAuthority() + "/bar/"; + httpsServer2.start(); + try { + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString2)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .HEAD().build(); + client.send(request, BodyHandlers.ofByteArrayConsumer(b-> {})); + } finally { + httpsServer2.stop(); + } + } + + public static void main(String[] args) throws Exception { + test(); + } + + @Test + public static void test() throws Exception { + try { + long prestart = System.nanoTime(); + initialize(); + long done = System.nanoTime(); + System.out.println("Initialization and warmup took "+ TimeUnit.NANOSECONDS.toMillis(done-prestart)+" millis"); + // Thread.sleep(30000); + int maxBidiStreams = Utils.getIntegerNetProperty("jdk.httpclient.quic.maxBidiStreams", 100); + long timeout = MAX_STREAM_LIMIT_WAIT_TIMEOUT; + + Set connections = new ConcurrentSkipListSet<>(); + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET().build(); + long start = System.nanoTime(); + var resp = client.send(request, BodyHandlers.ofByteArrayConsumer(b-> {})); + Assert.assertEquals(resp.statusCode(), 200); + long elapsed = System.nanoTime() - start; + System.out.println("First request took: " + elapsed + " nanos (" + TimeUnit.NANOSECONDS.toMillis(elapsed) + " ms)"); + final int max = property("simpleget.requests", 50); + List>> list = new ArrayList<>(max); + long start2 = System.nanoTime(); + for (int i = 0; i < max; i++) { + int rqNum = i; + var cf = client.sendAsync(request, BodyHandlers.ofByteArrayConsumer(b-> {})) + .whenComplete((r, t) -> { + Optional.ofNullable(r) + .flatMap(HttpResponse::connectionLabel) + .ifPresent(connections::add); + if (r != null) { + System.out.println(rqNum + " completed: " + r.connectionLabel()); + } else { + System.out.println(rqNum + " failed: " + t); + } + }); + list.add(cf); + //cf.get(); // uncomment to test with serial instead of concurrent requests + } + try { + CompletableFuture.allOf(list.toArray(new CompletableFuture[0])).join(); + } finally { + long elapsed2 = System.nanoTime() - start2; + long completed = list.stream().filter(CompletableFuture::isDone) + .filter(Predicate.not(CompletableFuture::isCompletedExceptionally)).count(); + if (completed > 0) { + System.out.println("Next " + completed + " requests took: " + elapsed2 + " nanos (" + + TimeUnit.NANOSECONDS.toMillis(elapsed2) + "ms for " + completed + " requests): " + + elapsed2 / completed + " nanos per request (" + TimeUnit.NANOSECONDS.toMillis(elapsed2) / completed + + " ms) on " + connections.size() + " connections"); + } + if (completed == list.size()) { + long msPerRequest = TimeUnit.NANOSECONDS.toMillis(elapsed2) / completed; + if (timeout == 0 || timeout < msPerRequest) { + int expectedCount = max / maxBidiStreams; + if (expectedCount > 2) { + if (connections.size() < expectedCount - Math.max(1, expectedCount/5)) { + throw new AssertionError( + "Too few connections: %s for %s requests with %s streams/connection (timeout %s ms)" + .formatted(connections.size(), max, maxBidiStreams, timeout)); + } + } + } + if (connections.size() > max - Math.max(1, max/5)) { + throw new AssertionError( + "Too few connections: %s for %s requests with %s streams/connection (timeout %s ms)" + .formatted(connections.size(), max, maxBidiStreams, timeout)); + } + + } + } + list.forEach((cf) -> Assert.assertEquals(cf.join().statusCode(), 200)); + client.close(); + } catch (Throwable tt) { + System.err.println("tt caught"); + tt.printStackTrace(); + throw tt; + } finally { + httpsServer.stop(); + } + } + + static HttpClient createClient(SSLContext sslContext, ExecutorService clientExec) { + var builder = HttpServerAdapters.createClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .proxy(Builder.NO_PROXY); + if (clientExec != null) { + builder = builder.executor(clientExec); + } + return builder.build(); + } + + static HttpClient getClient() { + if (client == null) { + client = createClient(sslContext, null); + } + return client; + } + + static int property(String name, int defaultValue) { + return Integer.parseInt(System.getProperty(name, String.valueOf(defaultValue))); + } + + // 32 * 32 * 1024 * 10 chars = 10Mb responses + // 50 requests => 500Mb + // 100 requests => 1Gb + // 1000 requests => 10Gb + private final static int REPEAT = property("simpleget.repeat", 32); + private final static String RESPONSE = "abcdefghij".repeat(property("simpleget.chunks", 1024*32)); + private final static byte[] RESPONSE_BYTES = RESPONSE.getBytes(StandardCharsets.UTF_8); + + private static class TestHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + try (var in = t.getRequestBody()) { + byte[] input = in.readAllBytes(); + t.sendResponseHeaders(200, RESPONSE_BYTES.length * REPEAT); + try (var out = t.getResponseBody()) { + if (t.getRequestMethod().equals("HEAD")) return; + for (int i=0; i { + he.getResponseHeaders().addHeader("encoding", "UTF-8"); + he.sendResponseHeaders(200, RESPONSE.length()); + if (!he.getRequestMethod().equals("HEAD")) { + he.getResponseBody().write(RESPONSE.getBytes(StandardCharsets.UTF_8)); + } + he.close(); + }, PATH); + + return server; + } + + static HttpTestServer createHttp3Server() throws Exception { + HttpTestServer server = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault()); + server.addHandler(he -> { + he.getResponseHeaders().addHeader("encoding", "UTF-8"); + he.sendResponseHeaders(200, RESPONSE.length()); + if (!he.getRequestMethod().equals("HEAD")) { + he.getResponseBody().write(RESPONSE.getBytes(StandardCharsets.UTF_8)); + } + he.close(); + }, PATH); + + return server; + } + + public static void main(String[] args) + throws Exception + { + HttpTestServer server = createHttps2Server(); + HttpTestServer server3 = createHttp3Server(); + server.start(); + server3.start(); + try { + if (server.supportsH3DirectConnection()) { + test(server, ANY); + try { + test(server, HTTP_3_URI_ONLY); + throw new AssertionError("expected UnsupportedProtocolVersionException not raised"); + } catch (UnsupportedProtocolVersionException upve) { + System.out.printf("%nGot expected exception: %s%n%n", upve); + } + } + test(server, ALT_SVC); + try { + test(server3, HTTP_3_URI_ONLY); + throw new AssertionError("expected UnsupportedProtocolVersionException not raised"); + } catch (UnsupportedProtocolVersionException upve) { + System.out.printf("%nGot expected exception: %s%n%n", upve); + } + } finally { + server.stop(); + server3.stop(); + System.out.println("Server stopped"); + } + } + + public static void test(HttpTestServer server, + Http3DiscoveryMode config) + throws Exception + { + System.out.println(""" + + # -------------------------------------------------- + # Server is %s + # Config is %s + # -------------------------------------------------- + """.formatted(server.getAddress(), config)); + + URI uri = new URI("https://" + server.serverAuthority() + PATH + "x"); + TunnelingProxy proxy = new TunnelingProxy(server); + proxy.start(); + try { + System.out.println("Proxy started"); + System.out.println("\nSetting up request with HttpClient for version: " + + config.name() + " URI=" + uri); + ProxySelector ps = ProxySelector.of( + InetSocketAddress.createUnresolved("localhost", proxy.getAddress().getPort())); + HttpClient client = HttpServerAdapters.createClientBuilderForH3() + .version(Version.HTTP_3) + .sslContext(new SimpleSSLContext().get()) + .proxy(ps) + .build(); + try (client) { + if (config == ALT_SVC) { + System.out.println("\nSending HEAD request to preload AltServiceRegistry"); + HttpRequest head = HttpRequest.newBuilder() + .uri(uri) + .HEAD() + .version(Version.HTTP_2) + .build(); + var headResponse = client.send(head, BodyHandlers.ofString()); + System.out.println("Got head response: " + headResponse); + if (headResponse.statusCode() != 200) { + throw new AssertionError("bad status code: " + headResponse); + } + if (!headResponse.version().equals(Version.HTTP_2)) { + throw new AssertionError("bad protocol version: " + headResponse.version()); + } + + } + + HttpRequest request = HttpRequest.newBuilder() + .uri(uri) + .GET() + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, config) + .build(); + + System.out.println("\nSending request with HttpClient: " + config); + HttpResponse response + = client.send(request, HttpResponse.BodyHandlers.ofString()); + System.out.println("Got response"); + if (response.statusCode() != 200) { + throw new AssertionError("bad status code: " + response); + } + if (!response.version().equals(Version.HTTP_2)) { + throw new AssertionError("bad protocol version: " + response.version()); + } + + String resp = response.body(); + System.out.println("Received: " + resp); + if (!RESPONSE.equals(resp)) { + throw new AssertionError("Unexpected response"); + } + } + } catch (Throwable t) { + System.out.println("Error: " + t); + throw t; + } finally { + System.out.println("Stopping proxy"); + proxy.stop(); + System.out.println("Proxy stopped"); + } + } + + static class TunnelingProxy { + final Thread accept; + final ServerSocket ss; + final boolean DEBUG = false; + final HttpTestServer serverImpl; + final CopyOnWriteArrayList> connectionCFs + = new CopyOnWriteArrayList<>(); + private volatile boolean stopped; + TunnelingProxy(HttpTestServer serverImpl) throws IOException { + this.serverImpl = serverImpl; + ss = new ServerSocket(); + accept = new Thread(this::accept); + accept.setDaemon(true); + } + + void start() throws IOException { + ss.setReuseAddress(false); + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); + accept.start(); + } + + // Pipe the input stream to the output stream. + private synchronized Thread pipe(InputStream is, OutputStream os, + char tag, CompletableFuture end) { + return new Thread("TunnelPipe("+tag+")") { + @Override + public void run() { + try { + try (os) { + int c; + while ((c = is.read()) != -1) { + os.write(c); + os.flush(); + // if DEBUG prints a + or a - for each transferred + // character. + if (DEBUG) System.out.print(tag); + } + is.close(); + } + } catch (IOException ex) { + if (DEBUG) ex.printStackTrace(System.out); + } finally { + end.complete(null); + } + } + }; + } + + public InetSocketAddress getAddress() { + return new InetSocketAddress( InetAddress.getLoopbackAddress(), ss.getLocalPort()); + } + + // This is a bit shaky. It doesn't handle continuation + // lines, but our client shouldn't send any. + // Read a line from the input stream, swallowing the final + // \r\n sequence. Stops at the first \n, doesn't complain + // if it wasn't preceded by '\r'. + // + String readLine(InputStream r) throws IOException { + StringBuilder b = new StringBuilder(); + int c; + while ((c = r.read()) != -1) { + if (c == '\n') break; + b.appendCodePoint(c); + } + if (b.codePointAt(b.length() -1) == '\r') { + b.delete(b.length() -1, b.length()); + } + return b.toString(); + } + + public void accept() { + Socket clientConnection; + try { + while (!stopped) { + System.out.println("Tunnel: Waiting for client"); + Socket toClose; + try { + toClose = clientConnection = ss.accept(); + } catch (IOException io) { + if (DEBUG) io.printStackTrace(System.out); + break; + } + System.out.println("Tunnel: Client accepted"); + Socket targetConnection; + InputStream ccis = clientConnection.getInputStream(); + OutputStream ccos = clientConnection.getOutputStream(); + Writer w = new OutputStreamWriter(ccos, StandardCharsets.UTF_8); + PrintWriter pw = new PrintWriter(w); + System.out.println("Tunnel: Reading request line"); + String requestLine = readLine(ccis); + System.out.println("Tunnel: Request status line: " + requestLine); + if (requestLine.startsWith("CONNECT ")) { + // We should probably check that the next word following + // CONNECT is the host:port of our HTTPS serverImpl. + // Some improvement for a followup! + + // Read all headers until we find the empty line that + // signals the end of all headers. + while(!requestLine.equals("")) { + System.out.println("Tunnel: Reading header: " + + (requestLine = readLine(ccis))); + } + + // Open target connection + targetConnection = new Socket( + InetAddress.getLoopbackAddress(), + serverImpl.getAddress().getPort()); + + // Then send the 200 OK response to the client + System.out.println("Tunnel: Sending " + + "HTTP/1.1 200 OK\r\n\r\n"); + pw.print("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"); + pw.flush(); + } else { + // This should not happen. If it does then just print an + // error - both on out and err, and close the accepted + // socket + System.out.println("WARNING: Tunnel: Unexpected status line: " + + requestLine + " received by " + + ss.getLocalSocketAddress() + + " from " + + toClose.getRemoteSocketAddress() + + " - closing accepted socket"); + // Print on err + System.err.println("WARNING: Tunnel: Unexpected status line: " + + requestLine + " received by " + + ss.getLocalSocketAddress() + + " from " + + toClose.getRemoteSocketAddress()); + // close accepted socket. + toClose.close(); + System.err.println("Tunnel: accepted socket closed."); + continue; + } + + // Pipe the input stream of the client connection to the + // output stream of the target connection and conversely. + // Now the client and target will just talk to each other. + System.out.println("Tunnel: Starting tunnel pipes"); + CompletableFuture end, end1, end2; + Thread t1 = pipe(ccis, targetConnection.getOutputStream(), '+', + end1 = new CompletableFuture<>()); + Thread t2 = pipe(targetConnection.getInputStream(), ccos, '-', + end2 = new CompletableFuture<>()); + end = CompletableFuture.allOf(end1, end2); + end.whenComplete( + (r,t) -> { + try { toClose.close(); } catch (IOException x) { } + finally {connectionCFs.remove(end);} + }); + connectionCFs.add(end); + t1.start(); + t2.start(); + } + } catch (Throwable ex) { + try { + ss.close(); + } catch (IOException ex1) { + ex.addSuppressed(ex1); + } + ex.printStackTrace(System.err); + } finally { + System.out.println("Tunnel: exiting (stopped=" + stopped + ")"); + connectionCFs.forEach(cf -> cf.complete(null)); + } + } + + public void stop() throws IOException { + stopped = true; + ss.close(); + try { + if (accept != Thread.currentThread()) accept.join(); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + } + + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3PushCancel.java b/test/jdk/java/net/httpclient/http3/H3PushCancel.java new file mode 100644 index 00000000000..a5293444ade --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3PushCancel.java @@ -0,0 +1,508 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,responses,trace + * -Djdk.httpclient.http3.maxConcurrentPushStreams=5 + * H3PushCancel + * @summary This test checks that not accepting one of the push promise + * will cancel it. It also verifies that receiving a pushId bigger + * than the max push ID allowed on the connection will cause + * the exchange to fail and the connection to get closed. + */ + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.net.http.HttpResponse.PushPromiseHandler.PushId; +import java.net.http.HttpResponse.PushPromiseHandler.PushId.Http3PushId; +import java.nio.channels.ClosedChannelException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class H3PushCancel implements HttpServerAdapters { + + static Map PUSH_PROMISES = Map.of( + "/x/y/z/1", "the first push promise body", + "/x/y/z/2", "the second push promise body", + "/x/y/z/3", "the third push promise body", + "/x/y/z/4", "the fourth push promise body", + "/x/y/z/5", "the fifth push promise body", + "/x/y/z/6", "the sixth push promise body", + "/x/y/z/7", "the seventh push promise body", + "/x/y/z/8", "the eight push promise body", + "/x/y/z/9", "the ninth push promise body" + ); + static final String MAIN_RESPONSE_BODY = "the main response body"; + + HttpTestServer server; + URI uri; + URI headURI; + ServerPushHandler pushHandler; + + @BeforeTest + public void setup() throws Exception { + server = HttpTestServer.create(ANY, new SimpleSSLContext().get()); + pushHandler = new ServerPushHandler(MAIN_RESPONSE_BODY, PUSH_PROMISES); + server.addHandler(pushHandler, "/push/"); + server.addHandler(new HttpHeadOrGetHandler(), "/head/"); + server.start(); + System.err.println("Server listening on port " + server.serverAuthority()); + uri = new URI("https://" + server.serverAuthority() + "/push/a/b/c"); + headURI = new URI("https://" + server.serverAuthority() + "/head/x"); + } + + @AfterTest + public void teardown() { + server.stop(); + } + + static HttpResponse assert200ResponseCode(HttpResponse response) { + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + return response; + } + + private void sendHeadRequest(HttpClient client) throws IOException, InterruptedException { + HttpRequest headRequest = HttpRequest.newBuilder(headURI) + .HEAD().version(Version.HTTP_2).build(); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + @Test + public void testNoCancel() throws Exception { + int maxPushes = Utils.getIntegerProperty("jdk.httpclient.http3.maxConcurrentPushStreams", -1); + System.out.println("maxPushes: " + maxPushes); + assertTrue(maxPushes > 0); + try (HttpClient client = newClientBuilderForH3() + .proxy(Builder.NO_PROXY) + .version(Version.HTTP_3) + .sslContext(new SimpleSSLContext().get()) + .build()) { + + sendHeadRequest(client); + + // Send with promise handler + ConcurrentMap>> promises + = new ConcurrentHashMap<>(); + PushPromiseHandler pph = PushPromiseHandler + .of((r) -> BodyHandlers.ofString(), promises); + + for (int j=0; j < 2; j++) { + if (j == 0) System.out.println("\ntestNoCancel: First time around"); + else System.out.println("\ntestNoCancel: Second time around: should be a new connection"); + + int waitForPushId; + for (int i = 0; i < 3; i++) { + HttpResponse main; + waitForPushId = i * Math.min(PUSH_PROMISES.size(), maxPushes) + 1; + try { + main = client.sendAsync( + HttpRequest.newBuilder(uri) + .header("X-WaitForPushId", String.valueOf(waitForPushId)) + .build(), + BodyHandlers.ofString(), + pph) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + + promises.forEach((key, value1) -> System.out.println(key + ":" + value1.join().body())); + + promises.putIfAbsent(main.request(), CompletableFuture.completedFuture(main)); + promises.forEach((request, value) -> { + HttpResponse response = value.join(); + assertEquals(response.statusCode(), 200); + if (PUSH_PROMISES.containsKey(request.uri().getPath())) { + assertEquals(response.body(), PUSH_PROMISES.get(request.uri().getPath())); + } else { + assertEquals(response.body(), MAIN_RESPONSE_BODY); + } + }); + assertEquals(promises.size(), Math.min(PUSH_PROMISES.size(), maxPushes) + 1); + + promises.clear(); + } + + // Send with no promise handler + try { + client.sendAsync(HttpRequest.newBuilder(uri).build(), BodyHandlers.ofString()) + .thenApply(H3PushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + assertEquals(promises.size(), 0); + + // Send with no promise handler, but use pushId bigger than allowed. + // This should cause the connection to get closed + long usePushId = maxPushes * 3 + 10; + try { + HttpRequest bigger = HttpRequest.newBuilder(uri) + .header("X-UsePushId", String.valueOf(usePushId)) + .build(); + client.sendAsync(bigger, BodyHandlers.ofString()) + .thenApply(H3PushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + throw new AssertionError("Expected IOException not thrown"); + } catch (CompletionException c) { + boolean success = false; + if (c.getCause() instanceof IOException io) { + if (io.getMessage() != null && + io.getMessage().contains("Max pushId exceeded (%s >= %s)" + .formatted(usePushId, maxPushes * 3))) { + success = true; + } + if (success) { + System.out.println("Got expected IOException: " + io); + } else throw io; + } + if (!success) { + throw new AssertionError("Unexpected exception: " + c.getCause(), c.getCause()); + } + } + assertEquals(promises.size(), 0); + + // the next time around we should have a new connection + // so we can restart from scratch + pushHandler.reset(); + } + } + } + + @Test + public void testCancel() throws Exception { + int maxPushes = Utils.getIntegerProperty("jdk.httpclient.http3.maxConcurrentPushStreams", -1); + System.out.println("maxPushes: " + maxPushes); + assertTrue(maxPushes > 0); + try (HttpClient client = newClientBuilderForH3() + .proxy(Builder.NO_PROXY) + .version(Version.HTTP_3) + .sslContext(new SimpleSSLContext().get()) + .build()) { + + sendHeadRequest(client); + + // Send with promise handler + ConcurrentMap>> promises + = new ConcurrentHashMap<>(); + PushPromiseHandler pph = PushPromiseHandler + .of((r) -> BodyHandlers.ofString(), promises); + record NotifiedPromise(PushId pushId, HttpRequest initiatingRequest) {} + final Map requestToPushId = new ConcurrentHashMap<>(); + final Map pushIdToRequest = new ConcurrentHashMap<>(); + final List errors = new CopyOnWriteArrayList<>(); + final List notified = new CopyOnWriteArrayList<>(); + PushPromiseHandler custom = new PushPromiseHandler<>() { + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function, CompletableFuture>> acceptor) { + pph.applyPushPromise(initiatingRequest, pushPromiseRequest, acceptor); + } + @Override + public void notifyAdditionalPromise(HttpRequest initiatingRequest, PushId pushid) { + notified.add(new NotifiedPromise(pushid, initiatingRequest)); + pph.notifyAdditionalPromise(initiatingRequest, pushid); + } + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + PushId pushid, + Function, CompletableFuture>> acceptor) { + System.out.println("applyPushPromise: " + pushPromiseRequest + ", pushId=" + pushid); + requestToPushId.putIfAbsent(pushPromiseRequest, pushid); + if (pushIdToRequest.putIfAbsent(pushid, pushPromiseRequest) != null) { + errors.add(new AssertionError("pushId already used: " + pushid)); + } + if (pushid instanceof Http3PushId http3PushId) { + if (http3PushId.pushId() == 1) { + System.out.println("Cancelling: " + http3PushId); + return; // cancel pushId == 1 + } + } + pph.applyPushPromise(initiatingRequest, pushPromiseRequest, pushid, acceptor); + } + }; + + for (int j=0; j < 2; j++) { + if (j == 0) System.out.println("\ntestCancel: First time around"); + else System.out.println("\ntestCancel: Second time around: should be a new connection"); + + int waitForPushId; + for (int i = 0; i < 3; i++) { + HttpResponse main; + waitForPushId = i * Math.min(PUSH_PROMISES.size(), maxPushes) + 1; + try { + main = client.sendAsync( + HttpRequest.newBuilder(uri) + .header("X-WaitForPushId", String.valueOf(waitForPushId)) + .build(), + BodyHandlers.ofString(), + custom) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + + promises.forEach((key, value) -> System.out.println(key + ":" + value.join().body())); + + promises.putIfAbsent(main.request(), CompletableFuture.completedFuture(main)); + promises.forEach((request, value) -> { + HttpResponse response = value.join(); + assertEquals(response.statusCode(), 200); + if (PUSH_PROMISES.containsKey(request.uri().getPath())) { + assertEquals(response.body(), PUSH_PROMISES.get(request.uri().getPath())); + } else { + assertEquals(response.body(), MAIN_RESPONSE_BODY); + } + }); + int expectedPushes = Math.min(PUSH_PROMISES.size(), maxPushes) + 1; + if (i == 0) expectedPushes--; // pushId == 1 was cancelled + assertEquals(promises.size(), expectedPushes); + + promises.clear(); + } + + // Send with no promise handler + try { + client.sendAsync(HttpRequest.newBuilder(uri).build(), BodyHandlers.ofString()) + .thenApply(H3PushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + assertEquals(promises.size(), 0); + + // Send with no promise handler, but use pushId bigger than allowed. + // This should cause the connection to get closed + long usePushId = maxPushes * 3 + 10; + try { + HttpRequest bigger = HttpRequest.newBuilder(uri) + .header("X-UsePushId", String.valueOf(usePushId)) + .build(); + client.sendAsync(bigger, BodyHandlers.ofString()) + .thenApply(H3PushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body ->assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + throw new AssertionError("Expected IOException not thrown"); + } catch (CompletionException c) { + boolean success = false; + if (c.getCause() instanceof IOException io) { + if (io.getMessage() != null && + io.getMessage().contains("Max pushId exceeded (%s >= %s)" + .formatted(usePushId, maxPushes * 3))) { + success = true; + } + if (success) { + System.out.println("Got expected IOException: " + io); + } else throw io; + } + if (!success) { + throw new AssertionError("Unexpected exception: " + c.getCause(), c.getCause()); + } + } + assertEquals(promises.size(), 0); + + // the next time around we should have a new connection + // so we can restart from scratch + pushHandler.reset(); + } + errors.forEach(t -> t.printStackTrace(System.out)); + var error = errors.stream().findFirst().orElse(null); + if (error != null) throw error; + assertEquals(notified.size(), 0, "Unexpected notification: " + notified); + } + } + + + // --- server push handler --- + static class ServerPushHandler implements HttpTestHandler { + + private final String mainResponseBody; + private final Map promises; + private final ReentrantLock lock = new ReentrantLock(); + + public ServerPushHandler(String mainResponseBody, + Map promises) + throws Exception + { + Objects.requireNonNull(promises); + this.mainResponseBody = mainResponseBody; + this.promises = promises; + } + + final AtomicInteger count = new AtomicInteger(); + public void handle(HttpTestExchange exchange) throws IOException { + long count = -1; + lock.lock(); + try { + count = this.count.incrementAndGet(); + System.err.println("Server: handle " + exchange); + System.out.println("Server: handle " + exchange.getRequestURI()); + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + if (exchange.serverPushAllowed()) { + pushPromises(exchange); + } + + // response data for the main response + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = mainResponseBody.getBytes(UTF_8); + exchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } catch (ClosedChannelException ex) { + System.out.printf("handling exchange %s, %s: %s%n", count, + exchange.getRequestURI(), exchange.getRequestHeaders()); + System.out.printf("Got closed channel exception sending response after sent=%s allowed=%s%n", + sent, allowed); + } + } finally { + lock.unlock(); + System.out.printf("handled exchange %s, %s: %s%n", count, + exchange.getRequestURI(), exchange.getRequestHeaders()); + } + } + + volatile long allowed = -1; + volatile int sent = 0; + void reset() { + lock.lock(); + try { + allowed = -1; + sent = 0; + } finally { + lock.unlock(); + } + } + + private void pushPromises(HttpTestExchange exchange) throws IOException { + URI requestURI = exchange.getRequestURI(); + long waitForPushId = exchange.getRequestHeaders() + .firstValueAsLong("X-WaitForPushId").orElse(-1); + long usePushId = exchange.getRequestHeaders() + .firstValueAsLong("X-UsePushId").orElse(-1); + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + System.err.printf("Server: waiting for pushId sent=%s allowed=%s: %s%n", + sent, allowed, waitForPushId); + allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + System.err.println("Server: Got maxPushId: " + allowed); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + for (Map.Entry promise : promises.entrySet()) { + // if usePushId != -1 we send a single push promise, + // without checking that's it's allowed. + // Otherwise, we stop sending promises when we have consumed + // the whole window + if (usePushId == -1 && allowed > 0 && sent >= allowed) { + System.err.println("Server: sent all allowed promises: " + sent); + break; + } + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + System.err.printf("Server: waiting for pushId sent=%s allowed=%s: %s%n", + sent, allowed, waitForPushId); + allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + System.err.println("Server: Got maxPushId: " + allowed); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + URI uri = requestURI.resolve(promise.getKey()); + InputStream is = new ByteArrayInputStream(promise.getValue().getBytes(UTF_8)); + HttpHeaders headers = HttpHeaders.of(Collections.emptyMap(), (x, y) -> true); + if (usePushId == -1) { + long pushId = exchange.http3ServerPush(uri, headers, headers, is); + System.err.println("Server: Sent push promise with response: " + pushId); + waitForPushId = pushId + 1; // assuming no concurrent requests... + sent += 1; + } else { + exchange.sendHttp3PushPromiseFrame(usePushId, uri, headers); + System.err.println("Server: Sent push promise frame: " + usePushId); + exchange.sendHttp3PushResponse(usePushId, uri, headers, headers, is); + System.err.println("Server: Sent push promise response: " + usePushId); + sent += 1; + return; + } + } + System.err.println("Server: All pushes sent"); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3QuicTLSConnection.java b/test/jdk/java/net/httpclient/http3/H3QuicTLSConnection.java new file mode 100644 index 00000000000..cd699e7774d --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3QuicTLSConnection.java @@ -0,0 +1,363 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.UnsupportedProtocolVersionException; +import java.util.List; +import java.util.Optional; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.test.lib.net.SimpleSSLContext; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; + + +/* + * @test + * @summary verifies that the SSLParameters configured with specific cipher suites + * and TLS protocol versions gets used by the HttpClient for HTTP/3 + * @library /test/jdk/java/net/httpclient/lib /test/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.net.SimpleSSLContext + * @run main/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=all + * H3QuicTLSConnection + */ +public class H3QuicTLSConnection { + + private static final SSLParameters DEFAULT_SSL_PARAMETERS = new SSLParameters(); + + // expect highest supported version we know about + private static String expectedTLSVersion(SSLContext ctx) throws Exception { + if (ctx == null) { + ctx = SSLContext.getDefault(); + } + SSLParameters params = ctx.getSupportedSSLParameters(); + String[] protocols = params.getProtocols(); + for (String prot : protocols) { + if (prot.equals("TLSv1.3")) + return "TLSv1.3"; + } + return "TLSv1.2"; + } + + public static void main(String[] args) throws Exception { + // create and set the default SSLContext + SSLContext context = new SimpleSSLContext().get(); + SSLContext.setDefault(context); + + Handler handler = new Handler(); + + try (HttpTestServer server = HttpTestServer.create(HTTP_3_URI_ONLY, SSLContext.getDefault())) { + server.addHandler(handler, "/"); + server.start(); + + String uriString = "https://" + server.serverAuthority(); + + // run test cases + boolean success = true; + + SSLParameters parameters = null; + success &= expectFailure( + "---\nTest #1: SSL parameters is null, expect NPE", + () -> connect(uriString, parameters), + NullPointerException.class, + Optional.empty()); + + success &= expectSuccess( + "---\nTest #2: default SSL parameters, " + + "expect successful connection", + () -> connect(uriString, DEFAULT_SSL_PARAMETERS)); + success &= checkProtocol(handler.getSSLSession(), expectedTLSVersion(null)); + + success &= expectFailure( + "---\nTest #3: SSL parameters with " + + "TLS_AES_128_GCM_SHA256 cipher suite, but TLSv1.2 " + + "expect UnsupportedProtocolVersionException", + () -> connect(uriString, new SSLParameters( + new String[]{"TLS_AES_128_GCM_SHA256"}, + new String[]{"TLSv1.2"})), + UnsupportedProtocolVersionException.class, + Optional.empty()); + + // set TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 and expect it to fail since + // it's not supported with TLS v1.3 + success &= expectFailure( + "---\nTest #4: SSL parameters with " + + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 cipher suite, " + + "expect No appropriate protocol " + + "(protocol is disabled or cipher suites are inappropriate)", + () -> connect(uriString, new SSLParameters( + new String[]{"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"}, + new String[]{"TLSv1.3"})), + SSLHandshakeException.class, + Optional.of("protocol is disabled or cipher suites are inappropriate")); + + // set TLS_AES_128_CCM_8_SHA256 cipher suite + // which is not supported by the (default) SunJSSE provider + // and forbidden in QUIC in general + success &= expectFailure( + "---\nTest #5: SSL parameters with " + + "TLS_AES_128_CCM_8_SHA256 cipher suite, " + + "expect IllegalArgumentException: Unsupported CipherSuite", + () -> connect(uriString, new SSLParameters( + new String[]{"TLS_AES_128_CCM_8_SHA256"}, + new String[]{"TLSv1.3"})), + IllegalArgumentException.class, + Optional.of("Unsupported CipherSuite")); + + // set TLS_AES_128_GCM_SHA256 and TLS_AES_256_GCM_SHA384 cipher suite + var suites = List.of("TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256"); + success &= expectSuccess( + "---\nTest #6: SSL parameters with " + + "TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, " + + "and TLS_CHACHA20_POLY1305_SHA256 cipher suites," + + " expect successful connection", + () -> connect(uriString, new SSLParameters( + suites.toArray(new String[0]), + new String[]{"TLSv1.3"}))); + success &= checkProtocol(handler.getSSLSession(), "TLSv1.3"); + success &= checkCipherSuite(handler.getSSLSession(), suites); + + // set TLS_AES_128_GCM_SHA256 cipher suite + success &= expectSuccess( + "---\nTest #7: SSL parameters with " + + "TLS_AES_128_GCM_SHA256 cipher suites," + + " expect successful connection", + () -> connect(uriString, new SSLParameters( + new String[]{"TLS_AES_128_GCM_SHA256"}, + new String[]{"TLSv1.3"}))); + success &= checkProtocol(handler.getSSLSession(), "TLSv1.3"); + success &= checkCipherSuite(handler.getSSLSession(), + "TLS_AES_128_GCM_SHA256"); + + // set TLS_AES_256_GCM_SHA384 cipher suite + success &= expectSuccess( + "---\nTest #8: SSL parameters with " + + "TLS_AES_256_GCM_SHA384 cipher suites," + + " expect successful connection", + () -> connect(uriString, new SSLParameters( + new String[]{"TLS_AES_256_GCM_SHA384"}, + new String[]{"TLSv1.3"}))); + success &= checkProtocol(handler.getSSLSession(), "TLSv1.3"); + success &= checkCipherSuite(handler.getSSLSession(), + "TLS_AES_256_GCM_SHA384"); + + // set TLS_CHACHA20_POLY1305_SHA256 cipher suite + success &= expectSuccess( + "---\nTest #9: SSL parameters with " + + "TLS_CHACHA20_POLY1305_SHA256 cipher suites," + + " expect successful connection", + () -> connect(uriString, new SSLParameters( + new String[]{"TLS_CHACHA20_POLY1305_SHA256"}, + new String[]{"TLSv1.3"}))); + success &= checkProtocol(handler.getSSLSession(), "TLSv1.3"); + success &= checkCipherSuite(handler.getSSLSession(), + "TLS_CHACHA20_POLY1305_SHA256"); + + if (success) { + System.out.println("Test passed"); + } else { + throw new RuntimeException("At least one test case failed"); + } + } + } + + private interface Test { + void run() throws Exception; + } + + private static class Handler implements HttpTestHandler { + + private static final byte[] BODY = "Test response".getBytes(); + + private volatile SSLSession sslSession; + + @Override + public void handle(HttpTestExchange t) throws IOException { + System.out.println("Handler: received request to " + + t.getRequestURI()); + + try (InputStream is = t.getRequestBody()) { + byte[] body = is.readAllBytes(); + System.out.println("Handler: read " + body.length + + " bytes of body: "); + System.out.println(new String(body)); + } + + sslSession = t.getSSLSession(); + + try (OutputStream os = t.getResponseBody()) { + t.sendResponseHeaders(200, BODY.length); + os.write(BODY); + } + + } + + SSLSession getSSLSession() { + return sslSession; + } + } + + private static void connect(String uriString, SSLParameters sslParameters) + throws URISyntaxException, IOException, InterruptedException { + HttpClient.Builder builder = HttpServerAdapters.createClientBuilderForH3() + .proxy(Builder.NO_PROXY) + .version(HttpClient.Version.HTTP_3); + if (sslParameters != DEFAULT_SSL_PARAMETERS) + builder.sslParameters(sslParameters); + try (final HttpClient client = builder.build()) { + HttpRequest request = HttpRequest.newBuilder(new URI(uriString)) + .POST(BodyPublishers.ofString("body")) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .version(Version.HTTP_3) + .build(); + String body = client.send(request, BodyHandlers.ofString()).body(); + System.out.println("Response: " + body); + } catch (UncheckedIOException uio) { + throw uio.getCause(); + } + } + + private static boolean checkProtocol(SSLSession session, String protocol) { + if (session == null) { + System.out.println("Check protocol: no session provided"); + return false; + } + + System.out.println("Check protocol: negotiated protocol: " + + session.getProtocol()); + System.out.println("Check protocol: expected protocol: " + + protocol); + if (!protocol.equals(session.getProtocol())) { + System.out.println("Check protocol: unexpected negotiated protocol"); + return false; + } + + return true; + } + + private static boolean checkCipherSuite(SSLSession session, String ciphersuite) { + if (session == null) { + System.out.println("Check protocol: no session provided"); + return false; + } + + System.out.println("Check protocol: negotiated ciphersuite: " + + session.getCipherSuite()); + System.out.println("Check protocol: expected ciphersuite: " + + ciphersuite); + if (!ciphersuite.equals(session.getCipherSuite())) { + System.out.println("Check protocol: unexpected negotiated ciphersuite"); + return false; + } + + return true; + } + + private static boolean checkCipherSuite(SSLSession session, List ciphersuites) { + if (session == null) { + System.out.println("Check protocol: no session provided"); + return false; + } + + System.out.println("Check protocol: negotiated ciphersuite: " + + session.getCipherSuite()); + System.out.println("Check protocol: expected ciphersuite in: " + + ciphersuites); + if (!ciphersuites.contains(session.getCipherSuite())) { + System.out.println("Check protocol: unexpected negotiated ciphersuite"); + return false; + } + + return true; + } + + private static boolean expectSuccess(String message, Test test) { + System.out.println(message); + try { + test.run(); + System.out.println("Passed"); + return true; + } catch (Exception e) { + System.out.println("Failed: unexpected exception:"); + e.printStackTrace(System.out); + return false; + } + } + + private static boolean expectFailure(String message, Test test, + Class expectedException, + Optional exceptionMsg) { + + System.out.println(message); + try { + test.run(); + System.out.println("Failed: unexpected successful connection"); + return false; + } catch (Exception e) { + System.out.println("Got an exception:"); + e.printStackTrace(System.out); + if (expectedException != null + && !expectedException.isAssignableFrom(e.getClass())) { + System.out.printf("Failed: expected %s, but got %s%n", + expectedException.getName(), + e.getClass().getName()); + return false; + } + if (exceptionMsg.isPresent()) { + final String actualMsg = e.getMessage(); + if (actualMsg == null || !actualMsg.contains(exceptionMsg.get())) { + System.out.printf("Failed: exception message was expected" + + " to contain \"%s\", but got \"%s\"%n", + exceptionMsg.get(), actualMsg); + return false; + } + } + System.out.println("Passed: expected exception"); + return true; + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3RedirectTest.java b/test/jdk/java/net/httpclient/http3/H3RedirectTest.java new file mode 100644 index 00000000000..4f6fbaeef6f --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3RedirectTest.java @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8156514 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @compile ../ReferenceTracker.java + * @run testng/othervm + * -Djdk.httpclient.HttpClient.log=frames,ssl,requests,responses,errors + * -Djdk.internal.httpclient.debug=true + * H3RedirectTest + */ + +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.Arrays; +import java.util.Iterator; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.Test; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +public class H3RedirectTest implements HttpServerAdapters { + static int httpPort; + static SSLContext sslContext; + static HttpTestServer http3Server; + static HttpClient client; + + static String httpURIString, altURIString1, altURIString2; + static URI httpURI, altURI1, altURI2; + + static Supplier sup(String... args) { + Iterator i = Arrays.asList(args).iterator(); + // need to know when to stop calling it. + return i::next; + } + + static class Redirector extends HttpTestRedirectHandler { + private InetSocketAddress remoteAddr; + private boolean error = false; + + Redirector(Supplier supplier) { + super(supplier); + } + + @Override + protected synchronized void examineExchange(HttpTestExchange ex) { + InetSocketAddress addr = ex.getRemoteAddress(); + if (remoteAddr == null) { + remoteAddr = addr; + return; + } + // check that the client addr/port stays the same, proving + // that the connection didn't get dropped. + if (!remoteAddr.equals(addr)) { + System.err.printf("Error %s/%s\n", remoteAddr, + addr.toString()); + error = true; + } + } + + @Override + protected int redirectCode() { + return 308; // we need to use a code that preserves the body + } + + public synchronized boolean error() { + return error; + } + } + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + http3Server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + httpPort = http3Server.getAddress().getPort(); + String serverAuth = http3Server.serverAuthority(); + + // urls are accessed in sequence below. The first two are on + // different servers. Third on same server as second. So, the + // client should use the same http connection. + httpURIString = "https://" + serverAuth + "/foo/"; + httpURI = URI.create(httpURIString); + altURIString1 = "https://" + serverAuth + "/redir"; + altURI1 = URI.create(altURIString1); + altURIString2 = "https://" + serverAuth + "/redir_again"; + altURI2 = URI.create(altURIString2); + + Redirector r = new Redirector(sup(altURIString1, altURIString2)); + http3Server.addHandler(r, "/foo"); + http3Server.addHandler(r, "/redir"); + http3Server.addHandler(new HttpTestEchoHandler(), "/redir_again"); + + http3Server.start(); + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(); + throw e; + } + } + + @Test + public static void test() throws Exception { + try { + initialize(); + simpleTest(); + ReferenceTracker.INSTANCE.track(client); + client = null; + System.gc(); + var error = ReferenceTracker.INSTANCE.check(1500); + if (error != null) throw error; + } finally { + http3Server.stop(); + } + } + + static HttpClient getClient() { + if (client == null) { + client = HttpServerAdapters.createClientBuilderForH3() + .followRedirects(HttpClient.Redirect.ALWAYS) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + } + return client; + } + + static URI getURI() { + return URI.create(httpURIString); + } + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf ("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkURIs(URI expected, URI found) throws Exception { + System.out.printf ("Expected: %s, Found: %s\n", expected, found); + if (!expected.equals(found)) { + System.err.printf ("Test failed: wrong URI %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf ("Test failed: wrong string %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void check(boolean cond, Object... msg) { + if (cond) + return; + StringBuilder sb = new StringBuilder(); + for (Object o : msg) + sb.append(o); + throw new RuntimeException(sb.toString()); + } + + static final String SIMPLE_STRING = "Hello world Goodbye world"; + + static void simpleTest() throws Exception { + URI uri = getURI(); + System.err.println("Request to " + uri); + + HttpClient client = getClient(); + HttpRequest req = HttpRequest.newBuilder(uri) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .POST(BodyPublishers.ofString(SIMPLE_STRING)) + .build(); + CompletableFuture> cf = client.sendAsync(req, BodyHandlers.ofString()); + HttpResponse response = cf.join(); + + checkStatus(200, response.statusCode()); + String responseBody = response.body(); + checkStrings(SIMPLE_STRING, responseBody); + checkURIs(response.uri(), altURI2); + + // check two previous responses + HttpResponse prev = response.previousResponse() + .orElseThrow(() -> new RuntimeException("no previous response")); + checkURIs(prev.uri(), altURI1); + + prev = prev.previousResponse() + .orElseThrow(() -> new RuntimeException("no previous response")); + checkURIs(prev.uri(), httpURI); + + checkPreviousRedirectResponses(req, response); + + System.err.println("DONE"); + } + + static void checkPreviousRedirectResponses(HttpRequest initialRequest, + HttpResponse finalResponse) { + // there must be at least one previous response + finalResponse.previousResponse() + .orElseThrow(() -> new RuntimeException("no previous response")); + + HttpResponse response = finalResponse; + do { + URI uri = response.uri(); + response = response.previousResponse().get(); + check(300 <= response.statusCode() && response.statusCode() <= 309, + "Expected 300 <= code <= 309, got:" + response.statusCode()); + check(response.body() == null, "Unexpected body: " + response.body()); + String locationHeader = response.headers().firstValue("Location") + .orElseThrow(() -> new RuntimeException("no previous Location")); + check(uri.toString().endsWith(locationHeader), + "URI: " + uri + ", Location: " + locationHeader); + } while (response.previousResponse().isPresent()); + + // initial + check(initialRequest.equals(response.request()), + "Expected initial request [%s] to equal last prev req [%s]", + initialRequest, response.request()); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3ServerPush.java b/test/jdk/java/net/httpclient/http3/H3ServerPush.java new file mode 100644 index 00000000000..4b271320b0d --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ServerPush.java @@ -0,0 +1,396 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8087112 8159814 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.httpclient.test.lib.http2.PushHandler + * jdk.test.lib.Utils + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm/timeout=960 + * -Djdk.httpclient.HttpClient.log=errors,requests,headers + * -Djdk.internal.httpclient.debug=false + * H3ServerPush + * @summary This is a clone of http2/ServerPush but for HTTP/3 + */ + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.BodySubscribers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Consumer; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http2.PushHandler; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static jdk.test.lib.Utils.createTempFileOfSize; +import static org.testng.Assert.assertEquals; + +public class H3ServerPush implements HttpServerAdapters { + + private static final String CLASS_NAME = H3ServerPush.class.getSimpleName(); + + static final int LOOPS = 13; + static final int FILE_SIZE = 512 * 1024 + 343; + + static Path tempFile; + + HttpTestServer server; + URI uri; + URI headURI; + + @BeforeTest + public void setup() throws Exception { + tempFile = createTempFileOfSize(CLASS_NAME, ".dat", FILE_SIZE); + var sslContext = new SimpleSSLContext().get(); + var h2Server = new Http2TestServer(true, sslContext); + h2Server.enableH3AltServiceOnSamePort(); + h2Server.addHandler(new PushHandler(tempFile, LOOPS), "/foo/"); + System.out.println("Using temp file:" + tempFile); + server = HttpTestServer.of(h2Server); + server.addHandler(new HttpHeadOrGetHandler(), "/head/"); + headURI = new URI("https://" + server.serverAuthority() + "/head/x"); + uri = new URI("https://" + server.serverAuthority() + "/foo/a/b/c"); + + System.err.println("Server listening at " + server.serverAuthority()); + server.start(); + + } + + private void sendHeadRequest(HttpClient client) throws IOException, InterruptedException { + HttpRequest headRequest = HttpRequest.newBuilder(headURI) + .HEAD().version(Version.HTTP_2).build(); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + @AfterTest + public void teardown() { + server.stop(); + } + + // Test 1 - custom written push promise handler, everything as a String + @Test + public void testTypeString() throws Exception { + System.out.println("\n**** testTypeString\n"); + String tempFileAsString = Files.readString(tempFile); + ConcurrentMap>> + resultMap = new ConcurrentHashMap<>(); + + PushPromiseHandler pph = (initial, pushRequest, acceptor) -> { + BodyHandler s = BodyHandlers.ofString(UTF_8); + CompletableFuture> cf = acceptor.apply(s); + resultMap.put(pushRequest, cf); + }; + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(new SimpleSSLContext().get()) + .version(Version.HTTP_3).build()) { + sendHeadRequest(client); + + HttpRequest request = HttpRequest.newBuilder(uri).GET() + .build(); + CompletableFuture> cf = + client.sendAsync(request, BodyHandlers.ofString(UTF_8), pph); + resultMap.put(request, cf); + System.out.println("waiting for response"); + var resp = cf.join(); + assertEquals(resp.version(), Version.HTTP_3); + var seen = new HashSet<>(); + resultMap.forEach((k, v) -> { + if (seen.add(k)) { + System.out.println("Got " + v.join()); + } + }); + + // waiting for all promises to reach us + System.out.println("waiting for promises"); + System.out.println("results.size: " + resultMap.size()); + for (HttpRequest r : resultMap.keySet()) { + System.out.println("Checking " + r); + HttpResponse response = resultMap.get(r).join(); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + assertEquals(response.body(), tempFileAsString); + } + resultMap.forEach((k, v) -> { + if (seen.add(k)) { + System.out.println("Got " + v.join()); + } + }); + assertEquals(resultMap.size(), LOOPS + 1); + } + } + + // Test 2 - of(...) populating the given Map, everything as a String + @Test + public void testTypeStringOfMap() throws Exception { + System.out.println("\n**** testTypeStringOfMap\n"); + String tempFileAsString = Files.readString(tempFile); + ConcurrentMap>> + resultMap = new ConcurrentHashMap<>(); + + PushPromiseHandler pph = + PushPromiseHandler.of(pushPromise -> BodyHandlers.ofString(UTF_8), resultMap); + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(new SimpleSSLContext().get()) + .version(Version.HTTP_3).build()) { + sendHeadRequest(client); + HttpRequest request = HttpRequest.newBuilder(uri).GET().build(); + CompletableFuture> cf = + client.sendAsync(request, BodyHandlers.ofString(UTF_8), pph); + cf.join(); + resultMap.put(request, cf); + System.err.println("results.size: " + resultMap.size()); + for (HttpRequest r : resultMap.keySet()) { + HttpResponse response = resultMap.get(r).join(); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + assertEquals(response.body(), tempFileAsString); + } + assertEquals(resultMap.size(), LOOPS + 1); + } + } + + // --- Path --- + + static final Path dir = Paths.get(".", "serverPush"); + static BodyHandler requestToPath(HttpRequest req) { + URI u = req.uri(); + Path path = Paths.get(dir.toString(), u.getPath()); + try { + Files.createDirectories(path.getParent()); + } catch (IOException ee) { + throw new UncheckedIOException(ee); + } + return BodyHandlers.ofFile(path); + } + + // Test 3 - custom written push promise handler, everything as a Path + @Test + public void testTypePath() throws Exception { + System.out.println("\n**** testTypePath\n"); + String tempFileAsString = Files.readString(tempFile); + ConcurrentMap>> resultsMap + = new ConcurrentHashMap<>(); + + PushPromiseHandler pushPromiseHandler = (initial, pushRequest, acceptor) -> { + BodyHandler pp = requestToPath(pushRequest); + CompletableFuture> cf = acceptor.apply(pp); + resultsMap.put(pushRequest, cf); + }; + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(new SimpleSSLContext().get()) + .version(Version.HTTP_3).build()) { + sendHeadRequest(client); + + HttpRequest request = HttpRequest.newBuilder(uri).GET().build(); + CompletableFuture> cf = + client.sendAsync(request, requestToPath(request), pushPromiseHandler); + cf.join(); + resultsMap.put(request, cf); + for (HttpRequest r : resultsMap.keySet()) { + HttpResponse response = resultsMap.get(r).join(); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + String fileAsString = Files.readString(response.body()); + assertEquals(fileAsString, tempFileAsString); + } + assertEquals(resultsMap.size(), LOOPS + 1); + } + } + + // Test 4 - of(...) populating the given Map, everything as a Path + @Test + public void testTypePathOfMap() throws Exception { + System.out.println("\n**** testTypePathOfMap\n"); + String tempFileAsString = Files.readString(tempFile); + ConcurrentMap>> resultsMap + = new ConcurrentHashMap<>(); + + PushPromiseHandler pushPromiseHandler = + PushPromiseHandler.of(H3ServerPush::requestToPath, resultsMap); + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(new SimpleSSLContext().get()) + .version(Version.HTTP_3).build()) { + sendHeadRequest(client); + + HttpRequest request = HttpRequest.newBuilder(uri).GET().build(); + CompletableFuture> cf = + client.sendAsync(request, requestToPath(request), pushPromiseHandler); + cf.join(); + resultsMap.put(request, cf); + for (HttpRequest r : resultsMap.keySet()) { + HttpResponse response = resultsMap.get(r).join(); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + String fileAsString = Files.readString(response.body()); + assertEquals(fileAsString, tempFileAsString); + } + assertEquals(resultsMap.size(), LOOPS + 1); + } + } + + // --- Consumer --- + + static class ByteArrayConsumer implements Consumer> { + volatile List listByteArrays = new ArrayList<>(); + volatile byte[] accumulatedBytes; + + public byte[] getAccumulatedBytes() { return accumulatedBytes; } + + @Override + public void accept(Optional optionalBytes) { + assert accumulatedBytes == null; + if (optionalBytes.isEmpty()) { + int size = listByteArrays.stream().mapToInt(ba -> ba.length).sum(); + ByteBuffer bb = ByteBuffer.allocate(size); + listByteArrays.forEach(bb::put); + accumulatedBytes = bb.array(); + } else { + listByteArrays.add(optionalBytes.get()); + } + } + } + + // Test 5 - custom written handler, everything as a consumer of optional byte[] + @Test + public void testTypeByteArrayConsumer() throws Exception { + System.out.println("\n**** testTypeByteArrayConsumer\n"); + String tempFileAsString = Files.readString(tempFile); + ConcurrentMap>> resultsMap + = new ConcurrentHashMap<>(); + Map byteArrayConsumerMap + = new ConcurrentHashMap<>(); + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(new SimpleSSLContext().get()) + .version(Version.HTTP_3).build()) { + sendHeadRequest(client); + + HttpRequest request = HttpRequest.newBuilder(uri).GET().build(); + ByteArrayConsumer bac = new ByteArrayConsumer(); + byteArrayConsumerMap.put(request, bac); + + PushPromiseHandler pushPromiseHandler = (initial, pushRequest, acceptor) -> { + CompletableFuture> cf = acceptor.apply( + (info) -> { + ByteArrayConsumer bc = new ByteArrayConsumer(); + byteArrayConsumerMap.put(pushRequest, bc); + return BodySubscribers.ofByteArrayConsumer(bc); + }); + resultsMap.put(pushRequest, cf); + }; + + CompletableFuture> cf = + client.sendAsync(request, BodyHandlers.ofByteArrayConsumer(bac), pushPromiseHandler); + cf.join(); + resultsMap.put(request, cf); + for (HttpRequest r : resultsMap.keySet()) { + HttpResponse response = resultsMap.get(r).join(); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + byte[] ba = byteArrayConsumerMap.get(r).getAccumulatedBytes(); + String result = new String(ba, UTF_8); + assertEquals(result, tempFileAsString); + } + assertEquals(resultsMap.size(), LOOPS + 1); + } + } + + // Test 6 - of(...) populating the given Map, everything as a consumer of optional byte[] + @Test + public void testTypeByteArrayConsumerOfMap() throws Exception { + System.out.println("\n**** testTypeByteArrayConsumerOfMap\n"); + String tempFileAsString = Files.readString(tempFile); + ConcurrentMap>> resultsMap + = new ConcurrentHashMap<>(); + Map byteArrayConsumerMap + = new ConcurrentHashMap<>(); + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(new SimpleSSLContext().get()) + .version(Version.HTTP_3).build()) { + sendHeadRequest(client); + + HttpRequest request = HttpRequest.newBuilder(uri).GET().build(); + ByteArrayConsumer bac = new ByteArrayConsumer(); + byteArrayConsumerMap.put(request, bac); + + PushPromiseHandler pushPromiseHandler = + PushPromiseHandler.of( + pushRequest -> { + ByteArrayConsumer bc = new ByteArrayConsumer(); + byteArrayConsumerMap.put(pushRequest, bc); + return BodyHandlers.ofByteArrayConsumer(bc); + }, + resultsMap); + + CompletableFuture> cf = + client.sendAsync(request, BodyHandlers.ofByteArrayConsumer(bac), pushPromiseHandler); + cf.join(); + resultsMap.put(request, cf); + for (HttpRequest r : resultsMap.keySet()) { + HttpResponse response = resultsMap.get(r).join(); + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), Version.HTTP_3); + byte[] ba = byteArrayConsumerMap.get(r).getAccumulatedBytes(); + String result = new String(ba, UTF_8); + assertEquals(result, tempFileAsString); + } + assertEquals(resultsMap.size(), LOOPS + 1); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3ServerPushCancel.java b/test/jdk/java/net/httpclient/http3/H3ServerPushCancel.java new file mode 100644 index 00000000000..32ca87e20f5 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ServerPushCancel.java @@ -0,0 +1,607 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,responses,trace + * -Djdk.httpclient.http3.maxConcurrentPushStreams=45 + * H3ServerPushCancel + * @summary This test checks that the client deals correctly with a + * CANCEL_PUSH frame sent by the server + */ + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.io.OutputStream; +import java.io.PrintStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.net.http.HttpResponse.PushPromiseHandler.PushId; +import java.net.http.HttpResponse.PushPromiseHandler.PushId.Http3PushId; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.function.Supplier; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.Utils; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class H3ServerPushCancel implements HttpServerAdapters { + + // dummy hack to prevent the IDE complaining that calling + // println will throw NPE + static final PrintStream err = System.err; + static final PrintStream out = System.out; + + static Map PUSH_PROMISES = Map.of( + "/x/y/z/1", "the first push promise body", + "/x/y/z/2", "the second push promise body", + "/x/y/z/3", "the third push promise body", + "/x/y/z/4", "the fourth push promise body", + "/x/y/z/5", "the fifth push promise body", + "/x/y/z/6", "the sixth push promise body", + "/x/y/z/7", "the seventh push promise body", + "/x/y/z/8", "the eight push promise body", + "/x/y/z/9", "the ninth push promise body" + ); + static final String MAIN_RESPONSE_BODY = "the main response body"; + static final int REQUESTS = 5; + + HttpTestServer server; + URI uri; + URI headURI; + ServerPushHandler pushHandler; + + @BeforeTest + public void setup() throws Exception { + server = HttpTestServer.create(ANY, new SimpleSSLContext().get()); + pushHandler = new ServerPushHandler(MAIN_RESPONSE_BODY, PUSH_PROMISES); + server.addHandler(pushHandler, "/push/"); + server.addHandler(new HttpHeadOrGetHandler(), "/head/"); + server.start(); + err.println("Server listening on port " + server.serverAuthority()); + uri = new URI("https://" + server.serverAuthority() + "/push/a/b/c"); + headURI = new URI("https://" + server.serverAuthority() + "/head/x"); + } + + @AfterTest + public void teardown() { + server.stop(); + } + + static HttpResponse assert200ResponseCode(HttpResponse response) { + assertEquals(response.statusCode(), 200); + assertEquals(response.version(), HTTP_3); + return response; + } + + private void sendHeadRequest(HttpClient client) throws IOException, InterruptedException { + HttpRequest headRequest = HttpRequest.newBuilder(headURI) + .HEAD().version(HTTP_2).build(); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), HTTP_2); + } + + static final class TestPushPromiseHandler implements PushPromiseHandler { + record NotifiedPromise(PushId pushId, HttpRequest initiatingRequest) {} + final Map requestToPushId = new ConcurrentHashMap<>(); + final Map pushIdToRequest = new ConcurrentHashMap<>(); + final List errors = new CopyOnWriteArrayList<>(); + final List notified = new CopyOnWriteArrayList<>(); + final ConcurrentMap>> promises + = new ConcurrentHashMap<>(); + final Supplier> bodyHandlerSupplier; + final PushPromiseHandler pph; + TestPushPromiseHandler(Supplier> bodyHandlerSupplier) { + this.bodyHandlerSupplier = bodyHandlerSupplier; + this.pph = PushPromiseHandler.of((r) -> bodyHandlerSupplier.get(), promises); + } + + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function, CompletableFuture>> acceptor) { + errors.add(new AssertionError("no pushID provided for: " + pushPromiseRequest)); + } + + @Override + public void notifyAdditionalPromise(HttpRequest initiatingRequest, PushId pushid) { + notified.add(new NotifiedPromise(pushid, initiatingRequest)); + out.println("notifyPushPromise: pushId=" + pushid); + pph.notifyAdditionalPromise(initiatingRequest, pushid); + } + + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + PushId pushid, + Function, CompletableFuture>> acceptor) { + out.println("applyPushPromise: " + pushPromiseRequest + ", pushId=" + pushid); + requestToPushId.putIfAbsent(pushPromiseRequest, pushid); + if (pushIdToRequest.putIfAbsent(pushid, pushPromiseRequest) != null) { + errors.add(new AssertionError("pushId already used: " + pushid)); + } + pph.applyPushPromise(initiatingRequest, pushPromiseRequest, pushid, acceptor); + } + + } + + T join(CompletableFuture cf) { + try { + return cf.join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + String describeBody(CompletableFuture> cf) { + return cf.thenApply(HttpResponse::body).exceptionally(Throwable::toString).join(); + } + + String describeKey(TestPushPromiseHandler pph, HttpRequest request) { + if (PUSH_PROMISES.containsKey(request.uri().getPath())) { + var pushId = pph.requestToPushId.get(request); + if (pushId instanceof Http3PushId h3id) { + return "[push=%s]".formatted(h3id.pushId()) + request; + } else return "[push=%s]".formatted(pushId) + request; + } else { + return "[main] " + request; + } + } + + @Test + public void testServerCancelPushes() throws Exception { + int maxPushes = Utils.getIntegerProperty("jdk.httpclient.http3.maxConcurrentPushStreams", -1); + out.println("maxPushes: " + maxPushes); + assertTrue(maxPushes > 0); + try (HttpClient client = newClientBuilderForH3() + .proxy(Builder.NO_PROXY) + .version(HTTP_3) + .sslContext(new SimpleSSLContext().get()) + .build()) { + + sendHeadRequest(client); + + // Send with promise handler + TestPushPromiseHandler custom = new TestPushPromiseHandler<>(BodyHandlers::ofString); + var promises = custom.promises; + Set expectedPushIds = new HashSet<>(); + + for (int j=0; j < 2; j++) { + if (j == 0) out.println("\ntestCancel: First time around"); + else out.println("\ntestCancel: Second time around: should be a new connection"); + + // now make sure there's an HTTP/3 connection + client.send(HttpRequest.newBuilder(headURI).version(HTTP_3) + .setOption(H3_DISCOVERY, ALT_SVC) + .HEAD().build(), BodyHandlers.discarding()); + + int waitForPushId; + List>> responses = new ArrayList<>(); + for (int i = 0; i < REQUESTS; i++) { + waitForPushId = Math.min(PUSH_PROMISES.size(), maxPushes) + 1; + CompletableFuture> main = client.sendAsync( + HttpRequest.newBuilder(uri.resolve("?i=%s,j=%s".formatted(i, j))) + .header("X-WaitForPushId", String.valueOf(waitForPushId)) + .build(), + BodyHandlers.ofString(), + custom); + responses.add(main); + } + + join(CompletableFuture.allOf(responses.toArray(CompletableFuture[]::new))); + + responses.forEach(cf -> { + var main = join(cf); + var old = promises.put(main.request(), CompletableFuture.completedFuture(main)); + assertNull(old, "unexpected mapping for: " + old); + }); + + promises.forEach((key, value) -> out.println(describeKey(custom, key) + ":" + describeBody(value))); + + promises.forEach((request, value) -> { + if (PUSH_PROMISES.containsKey(request.uri().getPath())) { + var pushId = custom.requestToPushId.get(request); + assertNotNull(pushId, "no pushId for " + request); + if (pushId instanceof Http3PushId h3id) { + long id = h3id.pushId(); + long mod = id % PUSH_PROMISES.size(); + if (mod > 0 && mod < 4) { + // should have been cancelled by server, so + // we should have an IO for these + Throwable ex = value.exceptionNow(); + var msg = ex.getMessage(); + var expected = "Push promise cancelled: %s".formatted(id); + if (!(ex instanceof IOException)) { + throw new AssertionError(ex); + } else if (!msg.contains(expected)) { + throw new AssertionError("Unexpected message: " + msg, ex); + } + } else { + assertEquals(join(value).body(), PUSH_PROMISES.get(request.uri().getPath())); + } + expectedPushIds.add(pushId); + } else assertEquals(pushId.getClass(), Http3PushId.class); + } else { + HttpResponse response = join(value); + assertEquals(response.statusCode(), 200); + assertEquals(response.body(), MAIN_RESPONSE_BODY); + } + }); + + int maxExpectedPushes = Math.min(PUSH_PROMISES.size(), maxPushes); + int minExpectedPushes = maxExpectedPushes - 3; + int countpushes = promises.size() - REQUESTS; + assert countpushes >= 0; + if (maxExpectedPushes < countpushes || minExpectedPushes > countpushes) { + throw new AssertionError("unexpected number of pushes %s should be in [%s,%s]" + .formatted(countpushes, minExpectedPushes, maxExpectedPushes)); + } + + promises.clear(); + custom.requestToPushId.clear(); + + // Send with no promise handler + try { + client.sendAsync(HttpRequest.newBuilder(uri).build(), BodyHandlers.ofString()) + .thenApply(H3ServerPushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + } catch (CompletionException c) { + throw new AssertionError(c.getCause()); + } + assertEquals(promises.size(), 0); + + // Send with no promise handler, but use pushId bigger than allowed. + // This should cause the connection to get closed + long usePushId = maxPushes * 3 + 10; + try { + HttpRequest bigger = HttpRequest.newBuilder(uri) + .header("X-UsePushId", String.valueOf(usePushId)) + .build(); + client.sendAsync(bigger, BodyHandlers.ofString()) + .thenApply(H3ServerPushCancel::assert200ResponseCode) + .thenApply(HttpResponse::body) + .thenAccept(body -> assertEquals(body, MAIN_RESPONSE_BODY)) + .join(); + throw new AssertionError("Expected IOException not thrown"); + } catch (CompletionException c) { + boolean success = false; + if (c.getCause() instanceof IOException io) { + if (io.getMessage() != null && + io.getMessage().contains("Max pushId exceeded (%s >= %s)" + .formatted(usePushId, maxPushes))) { + success = true; + } + if (success) { + out.println("Got expected IOException: " + io); + } else throw new AssertionError(io); + } + if (!success) { + throw new AssertionError("Unexpected exception: " + c.getCause(), c.getCause()); + } + } + assertEquals(promises.size(), 0); + + // the next time around we should have a new connection, + // so we can restart from scratch + pushHandler.reset(); + } + var errors = custom.errors; + errors.forEach(t -> t.printStackTrace(System.out)); + var error = errors.stream().findFirst().orElse(null); + if (error != null) throw error; + var notified = custom.notified; + int count = 0; + Set uniqueIds = new HashSet<>(); + Set cIds = new HashSet<>(); + for (var npp : notified) { + uniqueIds.add(npp.pushId); + if (npp.pushId instanceof Http3PushId h3id) { + long id = h3id.pushId(); + cIds.add(h3id.connectionLabel()); + long mod = id % PUSH_PROMISES.size(); + // we can't count the cancelled pushes as + // how many notifs we might get for those is racy. + if (mod == 0 || mod >= 4) { + // was not cancelled + count++; + } + } + } + + if (!uniqueIds.equals(expectedPushIds)) { + int problems = 0; + int missed = 0; + for (var id : uniqueIds) { + if (!expectedPushIds.contains(id)) { + problems++; + out.printf("%s was not expected%n", id); + } + } + for (var id : expectedPushIds) { + if (!uniqueIds.contains(id)) { + if (id instanceof Http3PushId h3id) { + long mod = h3id.pushId() % PUSH_PROMISES.size(); + if (mod > 0 && mod < 4) { + // this one was cancelled, so it might not + // have been notified + missed++; + continue; + } + } + problems++; + out.printf("%s was expected but not notified%n", id); + } + } + if (problems > 0) { + throw new AssertionError("%s unexpected problems with ids have been found" + .formatted(problems)); + } + } + // excluding those that got cancelled, + // we should have received REQUEST-1 notifications + // per push promise and per connection + assertEquals(count, (PUSH_PROMISES.size()-3)*2*(REQUESTS-1), + "Unexpected notification: " + notified); + } + } + + + // --- server push handler --- + static class ServerPushHandler implements HttpTestHandler { + + private final String mainResponseBody; + private final Map promises; + private final ReentrantLock lock = new ReentrantLock(); + private final Map sentPromises = new ConcurrentHashMap<>(); + record PendingPromise(long pushId, CountDownLatch latch) { + PendingPromise(long pushId) { + this(pushId, new CountDownLatch(REQUESTS)); + } + } + + public ServerPushHandler(String mainResponseBody, + Map promises) + throws Exception + { + Objects.requireNonNull(promises); + this.mainResponseBody = mainResponseBody; + this.promises = promises; + } + + // The assumption is that there will be several concurrent + // exchanges, but all on the same connection + // The first exchange that emits a PushPromise sends + // a push promise frame + open the push response stream. + // The other exchanges will simply send a push promise + // frame, with the pushId allocated by the previous exchange. + // The sentPromises map is used to store that pushId. + // This obviously only works if we have a single HTTP/3 connection. + final AtomicInteger count = new AtomicInteger(); + public void handle(HttpTestExchange exchange) throws IOException { + long count = -1; + try { + count = this.count.incrementAndGet(); + err.println("Server: handle " + exchange + + " on " + exchange.getConnectionKey()); + out.println("Server: handle " + exchange.getRequestURI() + + " on " + exchange.getConnectionKey()); + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + if (exchange.serverPushAllowed()) { + pushPromises(exchange); + } + + // response data for the main response + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = mainResponseBody.getBytes(UTF_8); + exchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } catch (ClosedChannelException ex) { + out.printf("handling exchange %s, %s: %s%n", count, + exchange.getRequestURI(), exchange.getRequestHeaders()); + out.printf("Got closed channel exception sending response after sent=%s allowed=%s%n", + sent, allowed); + } + } finally { + out.printf("handled exchange %s, %s: %s%n", count, + exchange.getRequestURI(), exchange.getRequestHeaders()); + } + } + + volatile long allowed = -1; + volatile int sent = 0; + volatile int nsent = 0; + void reset() { + lock.lock(); + try { + allowed = -1; + sent = 0; + nsent = 0; + sentPromises.clear(); + } finally { + lock.unlock(); + } + } + + private void pushPromises(HttpTestExchange exchange) throws IOException { + URI requestURI = exchange.getRequestURI(); + long waitForPushId = exchange.getRequestHeaders() + .firstValueAsLong("X-WaitForPushId").orElse(-1); + long usePushId = exchange.getRequestHeaders() + .firstValueAsLong("X-UsePushId").orElse(-1); + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + err.printf("Server: waiting for pushId sent=%s allowed=%s: %s%n", + sent, allowed, waitForPushId); + var allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + err.println("Server: Got maxPushId: " + allowed); + out.println("Server: Got maxPushId: " + allowed); + lock.lock(); + if (allowed > this.allowed) this.allowed = allowed; + lock.unlock(); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + for (Map.Entry promise : promises.entrySet()) { + // if usePushId != -1 we send a single push promise, + // without checking that it's allowed. + // Otherwise, we stop sending promises when we have consumed + // the whole window + if (usePushId == -1 && allowed > 0 && sent >= allowed) { + err.println("Server: sent all allowed promises: " + sent); + break; + } + + if (waitForPushId >= 0) { + while (allowed <= waitForPushId) { + try { + err.printf("Server: waiting for pushId sent=%s allowed=%s: %s%n", + sent, allowed, waitForPushId); + var allowed = exchange.waitForHttp3MaxPushId(waitForPushId); + err.println("Server: Got maxPushId: " + allowed); + out.println("Server: Got maxPushId: " + allowed); + lock.lock(); + if (allowed > this.allowed) this.allowed = allowed; + lock.unlock(); + } catch (InterruptedException ie) { + ie.printStackTrace(); + } + } + } + URI uri = requestURI.resolve(promise.getKey()); + InputStream is = new ByteArrayInputStream(promise.getValue().getBytes(UTF_8)); + HttpHeaders headers = HttpHeaders.of(Collections.emptyMap(), (x, y) -> true); + if (usePushId == -1) { + long pushId; + boolean send = false; + lock.lock(); + PendingPromise pendingPromise; + try { + pendingPromise = sentPromises.get(promise.getKey()); + if (pendingPromise == null) { + pushId = exchange.sendHttp3PushPromiseFrame(-1, uri, headers); + waitForPushId = pushId + 1; + pendingPromise = new PendingPromise(pushId); + sentPromises.put(promise.getKey(), pendingPromise); + sent += 1; + send = true; + } else { + pushId = pendingPromise.pushId; + exchange.sendHttp3PushPromiseFrame(pushId, uri, headers); + } + } finally { + lock.unlock(); + } + pendingPromise.latch.countDown(); + long mod = pushId % promises.size(); + if (send) { + if (mod == 1) { + // var stream = exchange.openPushIdStream(pushId); + // err.println("Server: Opened push stream: " + pushId); + } + if (mod > 0 && mod < 4) { + try { + pendingPromise.latch.await(); + } catch (InterruptedException e) { + throw new InterruptedIOException("" + e); + } + exchange.sendHttp3CancelPushFrame(pendingPromise.pushId); + err.println("Server: Cancelled push promise: " + pushId); + out.println("Server: Cancelled push promise: " + pushId); + } else { + exchange.sendHttp3PushResponse(pushId, uri, headers, headers, is); + out.println("Server: Sent push promise with response: " + pushId); + err.println("Server: Sent push promise with response: " + pushId); + } + } else { + err.println("Server: Sent push promise frame: " + pushId); + } + if (pushId >= waitForPushId) waitForPushId = pushId + 1; + } else { + exchange.sendHttp3PushPromiseFrame(usePushId, uri, headers); + err.println("Server: Sent push promise frame: " + usePushId); + exchange.sendHttp3PushResponse(usePushId, uri, headers, headers, is); + err.println("Server: Sent push promise response: " + usePushId); + lock.lock(); + sent += 1; + lock.unlock(); + return; + } + } + err.println("Server: All pushes sent"); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3ServerPushTest.java b/test/jdk/java/net/httpclient/http3/H3ServerPushTest.java new file mode 100644 index 00000000000..9b4858f50c9 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ServerPushTest.java @@ -0,0 +1,1224 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.TestMethodOrder; + +import javax.net.ssl.SSLContext; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.lang.reflect.Method; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.PushPromiseHandler.PushId.Http3PushId; +import java.time.LocalTime; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.Semaphore; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/* + * @test + * @summary Verifies the HTTP/3 server push handling of the HTTP client + * @library /test/jdk/java/net/httpclient/lib + * /test/lib + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.net.SimpleSSLContext + * @run junit H3ServerPushTest + */ + +/** + * Verifies the HTTP/3 server push handling of {@link HttpClient}. + * + * @implNote + * Some tests deliberately corrupt the HTTP/3 stream state. Hence, instead of + * creating a single {@link HttpTestServer}-{@link HttpClient} pair attached + * to the class, and sharing it across tests, each test creates its own + * server/client pair. + */ +@TestMethodOrder(OrderAnnotation.class) +class H3ServerPushTest { + + private static final HttpHeaders EMPTY_HEADERS = HttpHeaders.of(Map.of(), (_, _) -> false); + + private static final SSLContext SSL_CONTEXT = createSslContext(); + + private static SSLContext createSslContext() { + try { + return new SimpleSSLContext().get(); + } catch (IOException exception) { + throw new RuntimeException(exception); + } + } + + @Test + @Order(1) + void testBasicRequestResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + server.addHandler( + exchange -> { + try (exchange) { + exchange.sendResponseHeaders(200, 0); + } + }, + uri.getPath()); + + // Send the request and verify its response + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.discarding()); + assertEquals(200, response.statusCode()); + + } + } + + @Test + @Order(2) + void testTwoConsecutiveRequestsToSameServer(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + server.addHandler(new PushSender(), uri.getPath()); + + // Send the 1st request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response1 = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the 1st response + assertEquals(200, response1.statusCode()); + assertEquals("response0", response1.body()); + String connectionLabel = response1.connectionLabel().orElseThrow(); + final long initialPushId; + { + ReceivedPush.Promise[] push1Ref = {null}; // 1. Push(initialPushId) promise + ReceivedPush.Response[] push2Ref = {null}; // 2. Push(initialPushId) response, since this is the very first request + ReceivedPush.Promise[] push3Ref = {null}; // 3. Push(initialPushId+1) promise, the orphan one + pushReceiver.consume(push1Ref, push2Ref, push3Ref); + initialPushId = push1Ref[0].pushId.pushId(); + assertEquals(connectionLabel, push1Ref[0].pushId.connectionLabel()); + assertEquals(initialPushId, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push2Ref[0].pushId.connectionLabel()); + assertEquals("pushResponse0", push2Ref[0].responseBody); + assertEquals(initialPushId + 1, push3Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push3Ref[0].pushId.connectionLabel()); + } + + // Send the 2nd request + log("requesting `%s`...", request.uri()); + HttpResponse response2 = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the 2nd response + assertEquals(200, response2.statusCode()); + assertEquals("response1", response2.body()); + { + ReceivedPush.AdditionalPromise[] push1Ref = {null}; // 1. Push(initialPushId) additional promise + ReceivedPush.Promise[] push2Ref = {null}; // 2. Push(initialPushId+2) promise, the orphan one + pushReceiver.consume(push1Ref, push2Ref); + assertEquals(initialPushId, push1Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push1Ref[0].pushId.connectionLabel()); + assertEquals(initialPushId + 2, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push2Ref[0].pushId.connectionLabel()); + } + + } + } + + @Test + @Order(3) + void testTwoConsecutiveRequestsToDifferentServers(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server1 = createServer(); + HttpTestServer server2 = createServer()) { + + // Configure the server handlers + URI uri1 = createUri(server1, testInfo); + server1.addHandler(new PushSender(), uri1.getPath()); + URI uri2 = createUri(server2, testInfo); + server2.addHandler(new PushSender(), uri2.getPath()); + + // Send a request to the 1st server + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request1 = createRequest(uri1); + log("requesting `%s`...", request1.uri()); + HttpResponse response1 = client + .sendAsync(request1, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the response from the 1st server + assertEquals(200, response1.statusCode()); + assertEquals("response0", response1.body()); + String connectionLabel1 = response1.connectionLabel().orElseThrow(); + { + ReceivedPush.Promise[] push1Ref = {null}; // 1. Push(initialPushId) promise + ReceivedPush.Response[] push2Ref = {null}; // 2. Push(initialPushId) response, since this is the very first request + ReceivedPush.Promise[] push3Ref = {null}; // 3. Push(initialPushId+1) promise, the orphan one + pushReceiver.consume(push1Ref, push2Ref, push3Ref); + long initialPushId = push1Ref[0].pushId.pushId(); + assertEquals(connectionLabel1, push1Ref[0].pushId.connectionLabel()); + assertEquals(initialPushId, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel1, push2Ref[0].pushId.connectionLabel()); + assertEquals("pushResponse0", push2Ref[0].responseBody); + assertEquals(initialPushId + 1, push3Ref[0].pushId.pushId()); + assertEquals(connectionLabel1, push3Ref[0].pushId.connectionLabel()); + } + + // Send a request to the 2nd server + HttpRequest request2 = createRequest(uri2); + log("requesting `%s`...", request2.uri()); + HttpResponse response2 = client + .sendAsync(request2, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the response from the 2nd server + assertEquals(200, response2.statusCode()); + assertEquals("response0", response2.body()); + String connectionLabel2 = response2.connectionLabel().orElseThrow(); + { + ReceivedPush.Promise[] push1Ref = {null}; // 1. Push(initialPushId) promise + ReceivedPush.Response[] push2Ref = {null}; // 2. Push(initialPushId) response, since this is the very first request + ReceivedPush.Promise[] push3Ref = {null}; // 3. Push(initialPushId+1) promise, the orphan one + pushReceiver.consume(push1Ref, push2Ref, push3Ref); + long initialPushId = push1Ref[0].pushId.pushId(); + assertEquals(connectionLabel2, push1Ref[0].pushId.connectionLabel()); + assertEquals(initialPushId, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel2, push2Ref[0].pushId.connectionLabel()); + assertEquals("pushResponse0", push2Ref[0].responseBody); + assertEquals(initialPushId + 1, push3Ref[0].pushId.pushId()); + assertEquals(connectionLabel2, push3Ref[0].pushId.connectionLabel()); + } + + // Verify that connection labels differ + assertNotEquals(connectionLabel1, connectionLabel2); + + } + } + + /** + * A server handler responding to all requests as follows: + *

      + *
    1. push(initialPushId) promise
    2. + *
    3. push(initialPushId) response: "pushResponse" + responseIndex,
      + * iff responseIndex == 0
    4. + *
    5. push(pushId++) promise (an orphan push promise)
    6. + *
    7. response: "response" + responseIndex
    8. + *
    + */ + private static final class PushSender implements HttpTestHandler { + + private int responseIndex = 0; + + private long initialPushId = -1; + + @Override + public synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + + // Start with the push promise + assertTrue(exchange.serverPushAllowed()); + log(">>> sending push promise (responseIndex=%d, pushId=%d)", responseIndex, initialPushId); + long newInitialPushId = exchange.sendHttp3PushPromiseFrame( + initialPushId, + exchange.getRequestURI(), + EMPTY_HEADERS); + if (initialPushId != newInitialPushId) { + log(">>> updated initial pushId=%d (responseIndex=%d)", newInitialPushId, responseIndex); + initialPushId = newInitialPushId; + } + + // Send the push response iff it is the very first request + if (responseIndex == 0) { + log(">>> sending push response (responseIndex=%d, pushId=%d)", responseIndex, initialPushId); + byte[] pushResponseBody = "pushResponse%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendHttp3PushResponse( + initialPushId, + exchange.getRequestURI(), + EMPTY_HEADERS, + EMPTY_HEADERS, + new ByteArrayInputStream(pushResponseBody)); + } + + // Send the orphan push promise + log(">>> sending an orphan push promise (responseIndex=%d)", responseIndex); + long orphanPushId = exchange.sendHttp3PushPromiseFrame(-1, exchange.getRequestURI(), EMPTY_HEADERS); + log(">>> sent the orphan push promise (responseIndex=%d, pushId=%d)", responseIndex, orphanPushId); + + // Send the response + log(">>> sending response (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendResponseHeaders(200, responseBody.length); + exchange.getResponseBody().write(responseBody); + + } finally { + responseIndex++; + } + } + + } + + @Test + @Order(4) + void testTwoPushPromisesWithSameIdInOneResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + HttpTestHandler pushSender = new HttpTestHandler() { + + private int responseIndex = 0; + + @Override + public synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + + // Send the 1st push promise and receive the push ID + log(">>> sending push promise (responseIndex=%d)", responseIndex); + long pushId = exchange.sendHttp3PushPromiseFrame(-1, uri, EMPTY_HEADERS); + + // Send the 2nd push promise using the same ID + log(">>> sending push response (responseIndex=%d, pushId=%d)", responseIndex, pushId); + exchange.sendHttp3PushPromiseFrame(pushId, uri, EMPTY_HEADERS); + + // Send the response + log(">>> sending response (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendResponseHeaders(200, responseBody.length); + exchange.getResponseBody().write(responseBody); + + } finally { + responseIndex++; + } + } + + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the response + assertEquals(200, response.statusCode()); + assertEquals("response0", response.body()); + ReceivedPush.Promise[] push1Ref = {null}; // 1. Push(initialPushId) promise + ReceivedPush.AdditionalPromise[] push2Ref = {null}; // 2. Push(initialPushId) promise, again + pushReceiver.consume(push1Ref, push2Ref); + long initialPushId = push1Ref[0].pushId.pushId(); + String connectionLabel = response.connectionLabel().orElseThrow(); + assertEquals(connectionLabel, push1Ref[0].pushId.connectionLabel()); + assertEquals(initialPushId, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push2Ref[0].pushId.connectionLabel()); + + } + } + + @Test + @Order(5) + void testTwoPushPromisesWithSameIdButDifferentHeadersInOneResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + CountDownLatch responseBodyWriteLatch = new CountDownLatch(1); + var pushSender = new HttpTestHandler() { + + private long pushId = -1; + + private int responseIndex = 0; + + @Override + public synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + + // Send the 1st push promise and receive the push ID + log(">>> sending push promise (responseIndex=%d)", responseIndex); + pushId = exchange.sendHttp3PushPromiseFrame(pushId, uri, EMPTY_HEADERS); + + // Send the 2nd push promise using the same ID, but different headers + log(">>> sending push response (responseIndex=%d, pushId=%d)", responseIndex, pushId); + HttpHeaders nonEmptyHeaders = HttpHeaders.of(Map.of("Foo", List.of("Bar")), (_, _) -> true); + assertNotEquals(EMPTY_HEADERS, nonEmptyHeaders); + exchange.sendHttp3PushPromiseFrame(pushId, uri, nonEmptyHeaders); + + // Send the response + log(">>> sending response (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendResponseHeaders(200, responseBody.length); + // Block to ensure the bad push stream is received before the response + awaitLatch(responseBodyWriteLatch); + exchange.getResponseBody().write(responseBody); + + } finally { + responseIndex++; + } + } + + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the request and verify the failure + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> client + .sendAsync( + request, + HttpResponse.BodyHandlers.ofString(US_ASCII), + // Push receiver has no semantic purpose here, but provide logs to aid troubleshooting. + new PushReceiver()) + .get()); + responseBodyWriteLatch.countDown(); + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertInstanceOf(IOException.class, cause); + String actualMessage = cause.getMessage(); + String expectedMessage = "push headers do not match with previous promise for %d".formatted(pushSender.pushId); + assertEquals(expectedMessage, actualMessage); + + } + } + + @Test + @Order(6) + void testTwoPushPromisesWithSameIdInTwoResponsesOverOneConnection(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + HttpTestHandler pushSender = new HttpTestHandler() { + + private long pushId = -1; + + private int responseIndex = 0; + + @Override + public synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + + // Send the push promise, and receive the push ID, if necessary + log(">>> sending push promise (responseIndex=%d, pushId=%d)", responseIndex, pushId); + pushId = exchange.sendHttp3PushPromiseFrame(pushId, uri, EMPTY_HEADERS); + + // Send the response + log(">>> sending response (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendResponseHeaders(200, responseBody.length); + exchange.getResponseBody().write(responseBody); + + } finally { + responseIndex++; + } + } + + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the 1st request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response1 = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the 1st response + assertEquals(200, response1.statusCode()); + assertEquals("response0", response1.body()); + String connectionLabel = response1.connectionLabel().orElseThrow(); + final long initialPushId; + { + ReceivedPush.Promise[] pushRef = {null}; + pushReceiver.consume(pushRef); + initialPushId = pushRef[0].pushId.pushId(); + assertEquals(connectionLabel, pushRef[0].pushId.connectionLabel()); + } + + // Send the 2nd request + log("requesting `%s`...", request.uri()); + HttpResponse response2 = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the 2nd request + assertEquals(200, response2.statusCode()); + assertEquals("response1", response2.body()); + assertEquals(connectionLabel, response2.connectionLabel().orElseThrow()); + { + ReceivedPush.AdditionalPromise[] pushRef = {null}; + pushReceiver.consume(pushRef); + assertEquals(initialPushId, pushRef[0].pushId.pushId()); + assertEquals(connectionLabel, pushRef[0].pushId.connectionLabel()); + } + + } + } + + @Test + @Order(7) + void testTwoPushResponsesWithSameIdInOneResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + long[] pushId = {-1}; + CountDownLatch responseBodyWriteLatch = new CountDownLatch(1); + HttpTestHandler pushSender = new HttpTestHandler() { + + private int responseIndex = 0; + + @Override + public synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + + // Send the 1st push promise and receive the push ID + log(">>> sending push promise (responseIndex=%d)", responseIndex); + pushId[0] = exchange.sendHttp3PushPromiseFrame(-1, uri, EMPTY_HEADERS); + + // Send two push responses + for (int trialIndex = 0; trialIndex < 2; trialIndex++) { + log( + ">>> sending push response (responseIndex=%d, pushId=%d, trialIndex=%d)", + responseIndex, pushId[0], trialIndex); + byte[] pushResponseBody = "pushResponse%d-%d" + .formatted(responseIndex, trialIndex) + .getBytes(US_ASCII); + exchange.sendHttp3PushResponse( + pushId[0], + uri, + EMPTY_HEADERS, + EMPTY_HEADERS, + new ByteArrayInputStream(pushResponseBody)); + } + + // Send the response + log(">>> sending response (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendResponseHeaders(200, responseBody.length); + // Block to ensure the bad push stream is received before the response + awaitLatch(responseBodyWriteLatch); + exchange.getResponseBody().write(responseBody); + } finally { + responseIndex++; + } + } + + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the request and verify the failure + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> client + .sendAsync( + request, + HttpResponse.BodyHandlers.ofString(US_ASCII), + // Push receiver has no semantic purpose here, but provide logs to aid troubleshooting. + new PushReceiver()) + .get()); + responseBodyWriteLatch.countDown(); + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertInstanceOf(IOException.class, cause); + String actualMessage = cause.getMessage(); + String expectedMessage = "HTTP/3 pushId %d already used on this connection".formatted(pushId[0]); + assertEquals(expectedMessage, actualMessage); + } + } + + private static void awaitLatch(CountDownLatch latch) { + try { + latch.await(); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); // Restore the interrupt + throw new RuntimeException(ie); + } + } + + @Test + @Order(8) + void testTwoPushResponsesWithSameIdInTwoResponsesOverOneConnection(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + long[] pushId = {-1}; + HttpTestHandler pushSender = new HttpTestHandler() { + + private int responseIndex = 0; + + private final Semaphore responseBodyWriteSem = new Semaphore(1); + + @Override + public synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + + // Send the push promise and receive the push ID, iff this is the very first response + if (responseIndex == 0) { + log(">>> sending push promise (responseIndex=%d)", responseIndex); + pushId[0] = exchange.sendHttp3PushPromiseFrame(-1, uri, EMPTY_HEADERS); + } + + // Send the push response + log(">>> sending push response (responseIndex=%d, pushId=%d)", responseIndex, pushId[0]); + byte[] pushResponseBody = "pushResponse%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendHttp3PushResponse( + pushId[0], + uri, + EMPTY_HEADERS, + EMPTY_HEADERS, + new ByteArrayInputStream(pushResponseBody)); + + // Send the response + log(">>> sending response (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendResponseHeaders(200, responseBody.length); + try { + // The second request will block here, ensuring the + // bad push stream is received before the response. + responseBodyWriteSem.acquire(); + } catch (InterruptedException x) { + Thread.currentThread().interrupt(); + } + exchange.getResponseBody().write(responseBody); + + } finally { + responseIndex++; + } + } + + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the 1st request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response1 = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the 1st response + assertEquals(200, response1.statusCode()); + assertEquals("response0", response1.body()); + ReceivedPush.Promise[] push1Ref = {null}; // 1. Push(initialPushId) promise + ReceivedPush.Response[] push2Ref = {null}; // 2. Push(initialPushId) response + pushReceiver.consume(push1Ref, push2Ref); + long initialPushId = push1Ref[0].pushId.pushId(); + String connectionLabel = response1.connectionLabel().orElseThrow(); + assertEquals(connectionLabel, push1Ref[0].pushId.connectionLabel()); + assertEquals(initialPushId, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push2Ref[0].pushId.connectionLabel()); + assertEquals("pushResponse0", push2Ref[0].responseBody); + + // Send the 2nd request and verify the failure + log("requesting `%s`...", request.uri()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> client + .sendAsync( + request, + HttpResponse.BodyHandlers.ofString(US_ASCII), + // Push receiver has no semantic purpose here, but provide logs to aid troubleshooting. + pushReceiver) + .get()); + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertInstanceOf(IOException.class, cause); + String actualMessage = cause.getMessage(); + String expectedMessage = "HTTP/3 pushId %d already used on this connection".formatted(pushId[0]); + assertEquals(expectedMessage, actualMessage); + + } + } + + @Test + @Order(9) + void testPushPromiseBeforeHeader(TestInfo testInfo) throws Exception { + testPositionalPushPromise(testInfo, new PositionalPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendPushPromise(exchange); + sendHeaders(exchange); + sendBody(exchange); + } + }); + } + + @Test + @Order(10) + void testPushPromiseAfterHeaderAndBeforeBody(TestInfo testInfo) throws Exception { + testPositionalPushPromise(testInfo, new PositionalPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendHeaders(exchange); + sendPushPromise(exchange); + sendBody(exchange); + } + }); + } + + @Test + @Order(11) + void testPushPromiseAfterBody(TestInfo testInfo) throws Exception { + testPositionalPushPromise(testInfo, new PositionalPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendHeaders(exchange); + sendBody(exchange); + sendPushPromise(exchange); + } + }); + } + + private static void testPositionalPushPromise(TestInfo testInfo, PositionalPushSender pushSender) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + server.addHandler(pushSender, uri.getPath()); + + // Send the request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the response + assertEquals(200, response.statusCode()); + assertEquals("response0", response.body()); + ReceivedPush.Promise[] pushRef = {null}; + pushReceiver.consume(pushRef); + assertEquals(pushSender.pushId, pushRef[0].pushId.pushId()); + + } + } + + @Test + @Order(12) + void testPushPromiseAndResponseBeforeHeader(TestInfo testInfo) throws Exception { + testPositionalPushPromiseAndResponse(testInfo, new PositionalPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendPushPromise(exchange); + sendPushResponse(exchange); + sendHeaders(exchange); + sendBody(exchange); + } + }); + } + + @Test + @Order(13) + void testPushPromiseAndResponseAfterHeaderAndBeforeBody(TestInfo testInfo) throws Exception { + testPositionalPushPromiseAndResponse(testInfo, new PositionalPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendHeaders(exchange); + sendPushPromise(exchange); + sendPushResponse(exchange); + sendBody(exchange); + } + }); + } + + @Test + @Order(14) + void testPushPromiseAndResponseAfterBody(TestInfo testInfo) throws Exception { + testPositionalPushPromiseAndResponse(testInfo, new PositionalPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendHeaders(exchange); + sendBody(exchange); + sendPushPromise(exchange); + sendPushResponse(exchange); + } + }); + } + + private static void testPositionalPushPromiseAndResponse(TestInfo testInfo, PositionalPushSender pushSender) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + server.addHandler(pushSender, uri.getPath()); + + // Send the request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the response + assertEquals(200, response.statusCode()); + assertEquals("response0", response.body()); + ReceivedPush.Promise[] push1Ref = {null}; + ReceivedPush.Response[] push2Ref = {null}; + pushReceiver.consume(push1Ref, push2Ref); + assertEquals(pushSender.pushId, push1Ref[0].pushId.pushId()); + String connectionLabel = response.connectionLabel().orElseThrow(); + assertEquals(connectionLabel, push1Ref[0].pushId.connectionLabel()); + assertEquals(pushSender.pushId, push2Ref[0].pushId.pushId()); + assertEquals(connectionLabel, push2Ref[0].pushId.connectionLabel()); + assertEquals("pushResponse0", push2Ref[0].responseBody); + + } + } + + /** + * A server providing helper methods to send header, body, push promise & response. + * Subclasses can use these methods to inject custom server push behaviour at certain positions of the response assembly. + */ + private static abstract class PositionalPushSender implements HttpTestHandler { + + private long pushId = -1; + + private int responseIndex = 0; + + @Override + public final synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + handle0(exchange); + } finally { + responseIndex++; + } + } + + abstract void handle0(HttpTestExchange exchange) throws IOException; + + void sendHeaders(HttpTestExchange exchange) throws IOException { + log(">>> sending headers (responseIndex=%d)", responseIndex); + exchange.sendResponseHeaders( + 200, + // Use `-1` to avoid generating a single DataFrame. + // Otherwise, server closes the stream after writing + // the response body, and this makes it impossible to test + // server pushes delivered after the response body. + -1); + } + + void sendBody(HttpTestExchange exchange) throws IOException { + log(">>> sending body (responseIndex=%d)", responseIndex); + byte[] responseBody = "response%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.getResponseBody().write(responseBody); + } + + void sendPushResponse(HttpTestExchange exchange) throws IOException { + log(">>> sending push response (responseIndex=%d, pushId=%d)", responseIndex, pushId); + byte[] pushResponseBody = "pushResponse%d".formatted(responseIndex).getBytes(US_ASCII); + exchange.sendHttp3PushResponse( + pushId, + exchange.getRequestURI(), + EMPTY_HEADERS, + EMPTY_HEADERS, + new ByteArrayInputStream(pushResponseBody)); + } + + void sendPushPromise(HttpTestExchange exchange) throws IOException { + log(">>> sending push promise (responseIndex=%d, pushId=%d)", responseIndex, pushId); + pushId = exchange.sendHttp3PushPromiseFrame(pushId, exchange.getRequestURI(), EMPTY_HEADERS); + } + + } + + /** + * The maximum number of distinct push promise IDs allowed in a single response. + */ + private static final int MAX_ALLOWED_PUSH_ID_COUNT_PER_RESPONSE = 100; + + /** + * A value slightly more than {@link #MAX_ALLOWED_PUSH_ID_COUNT_PER_RESPONSE} to intentionally violate limits. + */ + private static final int EXCESSIVE_PUSH_ID_COUNT_PER_RESPONSE = + Math.addExact(10, MAX_ALLOWED_PUSH_ID_COUNT_PER_RESPONSE); + + @Test + @Order(15) + void testExcessivePushPromisesWithSameIdInOneResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + int pushCount = EXCESSIVE_PUSH_ID_COUNT_PER_RESPONSE; + HttpTestHandler pushSender = new ManyPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendPushPromise(exchange, pushCount, () -> pushId); + } + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the request + PushReceiver pushReceiver = new PushReceiver(); + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + HttpResponse response = client + .sendAsync(request, HttpResponse.BodyHandlers.ofString(US_ASCII), pushReceiver) + .get(); + + // Verify the response + assertEquals(200, response.statusCode()); + assertEquals("response0", response.body()); + ReceivedPush[][] pushRefs = new ReceivedPush[pushCount][1]; + for (int i = 0; i < pushCount; i++) { + pushRefs[i] = i == 0 ? new ReceivedPush.Promise[1] : new ReceivedPush.AdditionalPromise[1]; + } + pushReceiver.consume(pushRefs); + long initialPushId = ((ReceivedPush.Promise[]) pushRefs[0])[0].pushId.pushId(); + for (int i = 1; i < pushCount; i++) { + assertEquals( + initialPushId, ((ReceivedPush.AdditionalPromise[]) pushRefs[i])[0].pushId.pushId(), + "push ID mismatch for received server push at index %d".formatted(i)); + } + + } + } + + @Test + @Order(16) + void testExcessivePushPromisesWithDistinctIdsInOneResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + HttpTestHandler pushSender = new ManyPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + sendPushPromise(exchange, EXCESSIVE_PUSH_ID_COUNT_PER_RESPONSE, () -> -1L); + } + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the request and verify the failure + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + Exception exception = assertThrows(Exception.class, () -> client + .sendAsync( + request, + HttpResponse.BodyHandlers.ofString(US_ASCII), + // Push receiver has no semantic purpose here, but provide logs to aid troubleshooting. + new PushReceiver()) + .get()); + String exceptionMessage = exception.getMessage(); + assertTrue( + exceptionMessage.contains("Max pushId exceeded"), + "Unexpected exception message: `%s`".formatted(exceptionMessage)); + + } + } + + @Test + @Order(17) + void testExcessivePushResponsesWithDistinctIdsInOneResponse(TestInfo testInfo) throws Exception { + try (HttpClient client = createClient(); + HttpTestServer server = createServer()) { + + // Configure the server handler + URI uri = createUri(server, testInfo); + int pushCount = EXCESSIVE_PUSH_ID_COUNT_PER_RESPONSE; + HttpTestHandler pushSender = new ManyPushSender() { + @Override + void handle0(HttpTestExchange exchange) throws IOException { + long[] returnedPushIds = sendPushPromise(exchange, pushCount, () -> -1L); + Queue returnedPushIdQueue = Arrays + .stream(returnedPushIds) + .boxed() + .collect(Collectors.toCollection(LinkedList::new)); + sendPushResponse(exchange, pushCount, returnedPushIdQueue::poll); + } + }; + server.addHandler(pushSender, uri.getPath()); + + // Send the request and verify the failure + HttpRequest request = createRequest(uri); + log("requesting `%s`...", request.uri()); + Exception exception = assertThrows(Exception.class, () -> client + .sendAsync( + request, + HttpResponse.BodyHandlers.ofString(US_ASCII), + // Push receiver has no semantic purpose here, but provide logs to aid troubleshooting. + new PushReceiver()) + .get()); + String exceptionMessage = exception.getMessage(); + assertTrue( + exceptionMessage.contains("Max pushId exceeded"), + "Unexpected exception message: `%s`".formatted(exceptionMessage)); + + } + } + + /** + * A server providing helper methods subclasses can extend to send multiple push promises & responses. + */ + private static abstract class ManyPushSender implements HttpTestHandler { + + long pushId = -1; + + private int responseIndex = 0; + + @Override + public final synchronized void handle(HttpTestExchange exchange) throws IOException { + try (exchange) { + handle0(exchange); + sendHeaders(exchange); + sendBody(exchange); + } finally { + responseIndex++; + } + } + + abstract void handle0(HttpTestExchange exchange) throws IOException; + + private void sendHeaders(HttpTestExchange exchange) throws IOException { + log(">>> sending headers (responseIndex=%d)", responseIndex); + byte[] responseBody = responseBody(); + exchange.sendResponseHeaders(200, responseBody.length); + } + + private void sendBody(HttpTestExchange exchange) throws IOException { + log(">>> sending body (responseIndex=%d)", responseIndex); + byte[] responseBody = responseBody(); + exchange.getResponseBody().write(responseBody); + } + + private byte[] responseBody() { + return "response%d".formatted(responseIndex).getBytes(US_ASCII); + } + + long[] sendPushPromise(HttpTestExchange exchange, int count, Supplier pushIdProvider) throws IOException { + long[] returnedPushIds = new long[count]; + for (int i = 0; i < count; i++) { + long pushPromiseId = pushIdProvider.get(); + log( + ">>> sending push promise (responseIndex=%d, pushId=%d, i=%d/%d)", + responseIndex, pushPromiseId, i, count); + pushId = returnedPushIds[i] = + exchange.sendHttp3PushPromiseFrame(pushPromiseId, exchange.getRequestURI(), EMPTY_HEADERS); + } + return returnedPushIds; + } + + void sendPushResponse(HttpTestExchange exchange, int count, Supplier pushIdProvider) throws IOException { + for (int i = 0; i < count; i++) { + long pushResponseId = pushIdProvider.get(); + log( + ">>> sending push response (responseIndex=%d, pushId=%d, i=%d/%d)", + responseIndex, pushResponseId, i, count); + byte[] pushResponseBody = "pushResponse%d-%d".formatted(responseIndex, i).getBytes(US_ASCII); + exchange.sendHttp3PushResponse( + pushResponseId, + exchange.getRequestURI(), + EMPTY_HEADERS, + EMPTY_HEADERS, + new ByteArrayInputStream(pushResponseBody)); + } + } + + } + + private static URI createUri(HttpTestServer server, TestInfo testInfo) { + String uri = "https://%s/%s/%s".formatted( + server.serverAuthority(), + testInfo.getTestClass().map(Class::getSimpleName).orElse("UnknownClass"), + testInfo.getTestMethod().map(Method::getName).orElse("UnknownMethod")); + return URI.create(uri); + } + + private static HttpRequest createRequest(URI uri) { + return HttpRequest.newBuilder(uri).HEAD().setOption(H3_DISCOVERY, HTTP_3_URI_ONLY).build(); + } + + private static final class PushReceiver implements HttpResponse.PushPromiseHandler { + + private final BlockingQueue buffer = new LinkedBlockingQueue<>(); + + @Override + public void applyPushPromise( + HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function, CompletableFuture>> acceptor) { + fail("`applyPushPromise(...,PushId,...)` should have been called instead"); + } + + @Override + public void applyPushPromise( + HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + PushId pushId, + Function, CompletableFuture>> acceptor) { + Http3PushId http3PushId = (Http3PushId) pushId; + buffer(new ReceivedPush.Promise(http3PushId)); + acceptor.apply(HttpResponse.BodyHandlers.ofString(US_ASCII)).thenAccept(response -> { + assertEquals(200, response.statusCode()); + String responseBody = response.body(); + buffer(new ReceivedPush.Response(http3PushId, responseBody)); + }); + } + + @Override + public void notifyAdditionalPromise(HttpRequest initiatingRequest, PushId pushId) { + Http3PushId http3PushId = (Http3PushId) pushId; + buffer(new ReceivedPush.AdditionalPromise(http3PushId)); + } + + private void buffer(ReceivedPush push) { + log("<<< received push: `%s`", push); + try { + buffer.put(push); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); // Restore the interrupt + throw new RuntimeException(ie); + } + } + + @SuppressWarnings("rawtypes") + private void consume(ReceivedPush[]... pushRefs) { + int n = pushRefs.length; + Class[] pushTypes = Arrays + .stream(pushRefs) + .map(pushRef -> pushRef.getClass().componentType()) + .toArray(Class[]::new); + boolean[] foundIndices = new boolean[n]; + for (int i = 0; i < n; i++) { + ReceivedPush push; + try { + push = buffer.take(); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); // Restore the interrupt + throw new RuntimeException(ie); + } + boolean found = false; + for (int j = 0; j < n; j++) { + if (!foundIndices[j] && pushTypes[j].isInstance(push)) { + pushRefs[j][0] = push; + foundIndices[j] = true; + found = true; + break; + } + } + if (!found) { + log("pushRefs: %s", List.of(pushRefs)); + log("foundIndices: %s", List.of(foundIndices)); + log("n: %d", n); + log("i: %d", i); + log("push: %s", push); + fail("received push does not match with the expected types"); + } + } + } + + } + + private sealed interface ReceivedPush { + + record Promise(Http3PushId pushId) implements ReceivedPush {} + + record Response(Http3PushId pushId, String responseBody) implements ReceivedPush {} + + record AdditionalPromise(Http3PushId pushId) implements ReceivedPush {} + + } + + private static HttpTestServer createServer() throws IOException { + HttpTestServer server = HttpTestServer.create(HTTP_3_URI_ONLY, SSL_CONTEXT); + server.start(); + return server; + } + + private static HttpClient createClient() { + return HttpServerAdapters + .createClientBuilderFor(HTTP_3) + .proxy(NO_PROXY) + .version(HTTP_3) + .sslContext(SSL_CONTEXT) + .build(); + } + + private static void log(String format, Object... args) { + String text = format.formatted(args); + System.err.printf( + "%s [%25s] %s%n", + LocalTime.now(), + Thread.currentThread().getName(), + text); + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3ServerPushWithDiffTypes.java b/test/jdk/java/net/httpclient/http3/H3ServerPushWithDiffTypes.java new file mode 100644 index 00000000000..af22651f0a0 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3ServerPushWithDiffTypes.java @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,responses + * H3ServerPushWithDiffTypes + * @summary This is a clone of http2/ServerPushWithDiffTypes but for HTTP/3 + */ + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.BodySubscribers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Flow; +import java.util.function.BiPredicate; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.Test; + +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; + +public class H3ServerPushWithDiffTypes implements HttpServerAdapters { + + static Map PUSH_PROMISES = Map.of( + "/x/y/z/1", "the first push promise body", + "/x/y/z/2", "the second push promise body", + "/x/y/z/3", "the third push promise body", + "/x/y/z/4", "the fourth push promise body", + "/x/y/z/5", "the fifth push promise body", + "/x/y/z/6", "the sixth push promise body", + "/x/y/z/7", "the seventh push promise body", + "/x/y/z/8", "the eighth push promise body", + "/x/y/z/9", "the ninth push promise body" + ); + + private void sendHeadRequest(HttpClient client, URI headURI) throws IOException, InterruptedException { + HttpRequest headRequest = HttpRequest.newBuilder(headURI) + .HEAD().version(Version.HTTP_2).build(); + var headResponse = client.send(headRequest, BodyHandlers.ofString()); + assertEquals(headResponse.statusCode(), 200); + assertEquals(headResponse.version(), Version.HTTP_2); + } + + @Test + public void test() throws Exception { + var sslContext = new SimpleSSLContext().get(); + try (HttpTestServer server = HttpTestServer.create(ANY, sslContext)) { + HttpTestHandler pushHandler = + new ServerPushHandler("the main response body", + PUSH_PROMISES); + server.addHandler(pushHandler, "/push/"); + server.addHandler(new HttpHeadOrGetHandler(), "/head/"); + + server.start(); + System.err.println("Server listening on port " + server.serverAuthority()); + + // use multi-level path + URI uri = new URI("https://" + server.serverAuthority() + "/push/a/b/c"); + URI headURI = new URI("https://" + server.serverAuthority() + "/head/x"); + + try (HttpClient client = newClientBuilderForH3().proxy(Builder.NO_PROXY) + .sslContext(sslContext).version(Version.HTTP_3).build()) { + + sendHeadRequest(client, headURI); + + HttpRequest request = HttpRequest.newBuilder(uri).GET().build(); + + ConcurrentMap>>> + results = new ConcurrentHashMap<>(); + PushPromiseHandler> bh = PushPromiseHandler.of( + BodyAndTypeHandler::new, results); + + CompletableFuture>> cf = + client.sendAsync(request, new BodyAndTypeHandler(request), bh); + results.put(request, cf); + cf.join(); + + assertEquals(results.size(), PUSH_PROMISES.size() + 1); + + for (HttpRequest r : results.keySet()) { + URI u = r.uri(); + var resp = results.get(r).get(); + assertEquals(resp.statusCode(), 200); + assertEquals(resp.version(), Version.HTTP_3); + BodyAndType body = resp.body(); + String result; + // convert all body types to String for easier comparison + if (body.type() == String.class) { + result = (String) body.body(); + } else if (body.type() == byte[].class) { + byte[] bytes = (byte[]) body.body(); + result = new String(bytes, UTF_8); + } else if (Path.class.isAssignableFrom(body.type())) { + Path path = (Path) body.body(); + result = Files.readString(path); + } else { + throw new AssertionError("Unknown:" + body.type()); + } + + System.err.printf("%s -> %s\n", u.toString(), result); + String expected = PUSH_PROMISES.get(r.uri().getPath()); + if (expected == null) + expected = "the main response body"; + assertEquals(result, expected); + } + } + } + } + + interface BodyAndType { + Class type(); + T body(); + } + + static final Path WORK_DIR = Paths.get("."); + + static class BodyAndTypeHandler implements BodyHandler> { + int count; + final HttpRequest request; + + BodyAndTypeHandler(HttpRequest request) { + this.request = request; + } + + @Override + @SuppressWarnings("rawtypes,unchecked") + public BodySubscriber> apply(HttpResponse.ResponseInfo info) { + int whichType = count++ % 3; // real world may base this on the request metadata + switch (whichType) { + case 0: // String + return new BodyAndTypeSubscriber(BodySubscribers.ofString(UTF_8)); + case 1: // byte[] + return new BodyAndTypeSubscriber(BodySubscribers.ofByteArray()); + case 2: // Path + URI u = request.uri(); + Path path = Paths.get(WORK_DIR.toString(), u.getPath()); + try { + Files.createDirectories(path.getParent()); + } catch (IOException ee) { + throw new UncheckedIOException(ee); + } + return new BodyAndTypeSubscriber(BodySubscribers.ofFile(path)); + default: + throw new AssertionError("Unexpected " + whichType); + } + } + } + + static class BodyAndTypeSubscriber + implements BodySubscriber> + { + private record BodyAndTypeImpl(Class type, T body) implements BodyAndType { } + + private final BodySubscriber bodySubscriber; + private final CompletableFuture> cf; + + @SuppressWarnings("unchecked") + BodyAndTypeSubscriber(BodySubscriber bodySubscriber) { + this.bodySubscriber = bodySubscriber; + cf = new CompletableFuture<>(); + bodySubscriber.getBody().whenComplete( + (r,t) -> cf.complete(new BodyAndTypeImpl<>((Class) r.getClass(), r))); + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + bodySubscriber.onSubscribe(subscription); + } + + @Override + public void onNext(List item) { + bodySubscriber.onNext(item); + } + + @Override + public void onError(Throwable throwable) { + bodySubscriber.onError(throwable); + cf.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + bodySubscriber.onComplete(); + } + + @Override + public CompletionStage> getBody() { + return cf; + } + } + + // --- server push handler --- + static class ServerPushHandler implements HttpTestHandler { + + private final String mainResponseBody; + private final Map promises; + + public ServerPushHandler(String mainResponseBody, + Map promises) + throws Exception + { + Objects.requireNonNull(promises); + this.mainResponseBody = mainResponseBody; + this.promises = promises; + } + + public void handle(HttpTestExchange exchange) throws IOException { + System.err.println("Server: handle " + exchange); + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + if (exchange.serverPushAllowed()) { + pushPromises(exchange); + } + + // response data for the main response + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = mainResponseBody.getBytes(UTF_8); + exchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } + } + + static final BiPredicate ACCEPT_ALL = (x, y) -> true; + + private void pushPromises(HttpTestExchange exchange) throws IOException { + URI requestURI = exchange.getRequestURI(); + for (Map.Entry promise : promises.entrySet()) { + URI uri = requestURI.resolve(promise.getKey()); + InputStream is = new ByteArrayInputStream(promise.getValue().getBytes(UTF_8)); + Map> map = Map.of("X-Promise", List.of(promise.getKey())); + HttpHeaders headers = HttpHeaders.of(map, ACCEPT_ALL); + // TODO: add some check on headers, maybe + exchange.serverPush(uri, headers, is); + } + System.err.println("Server: All pushes sent"); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3SimpleGet.java b/test/jdk/java/net/httpclient/http3/H3SimpleGet.java new file mode 100644 index 00000000000..bccd77e7e1d --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3SimpleGet.java @@ -0,0 +1,315 @@ +/* + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test id=with-continuations + * @bug 8087112 + * @requires os.family != "windows" | ( os.name != "Windows 10" & os.name != "Windows Server 2016" + * & os.name != "Windows Server 2019" ) + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * H3SimpleGet + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -Djdk.httpclient.retryOnStreamlimit=20 + * -Djdk.httpclient.redirects.retrylimit=21 + * -Dsimpleget.repeat=1 -Dsimpleget.chunks=1 -Dsimpleget.requests=1000 + * H3SimpleGet + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -Dsimpleget.requests=150 + * -Dsimpleget.chunks=16384 + * -Djdk.httpclient.retryOnStreamlimit=5 + * -Djdk.httpclient.redirects.retrylimit=6 + * -Djdk.httpclient.quic.defaultMTU=16336 + * H3SimpleGet + */ + +/* + * @test id=without-continuation + * @bug 8087112 + * @requires os.family == "windows" & ( os.name == "Windows 10" | os.name == "Windows Server 2016" + * | os.name == "Windows Server 2019" ) + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -XX:+UnlockExperimentalVMOptions -XX:-VMContinuations + * H3SimpleGet + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -XX:+UnlockExperimentalVMOptions -XX:-VMContinuations + * -Djdk.httpclient.retryOnStreamlimit=20 + * -Djdk.httpclient.redirects.retrylimit=21 + * -Dsimpleget.repeat=1 -Dsimpleget.chunks=1 -Dsimpleget.requests=1000 + * H3SimpleGet + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -XX:+UnlockExperimentalVMOptions -XX:-VMContinuations + * -Dsimpleget.requests=150 + * -Dsimpleget.chunks=16384 + * -Djdk.httpclient.retryOnStreamlimit=5 + * -Djdk.httpclient.redirects.retrylimit=6 + * -Djdk.httpclient.quic.defaultMTU=16336 + * H3SimpleGet + */ + +/* + * @test id=useNioSelector + * @bug 8087112 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -Djdk.internal.httpclient.quic.useNioSelector=true + * H3SimpleGet + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -Djdk.internal.httpclient.quic.useNioSelector=true + * -Djdk.httpclient.retryOnStreamlimit=20 + * -Djdk.httpclient.redirects.retrylimit=21 + * -Dsimpleget.repeat=1 -Dsimpleget.chunks=1 -Dsimpleget.requests=1000 + * H3SimpleGet + * @run testng/othervm/timeout=480 -XX:+HeapDumpOnOutOfMemoryError -XX:+CrashOnOutOfMemoryError + * -Djdk.internal.httpclient.quic.useNioSelector=true + * -Dsimpleget.requests=150 + * -Dsimpleget.chunks=16384 + * -Djdk.httpclient.retryOnStreamlimit=5 + * -Djdk.httpclient.redirects.retrylimit=6 + * -Djdk.httpclient.quic.defaultMTU=16336 + * H3SimpleGet + */ + +// Interesting additional settings for debugging and manual testing: +// ----------------------------------------------------------------- +// -Djdk.httpclient.HttpClient.log=requests,errors,quic:retransmit:control,http3 +// -Djdk.httpclient.HttpClient.log=errors,requests,quic:all +// -Djdk.httpclient.quic.defaultMTU=64000 +// -Djdk.httpclient.quic.defaultMTU=16384 +// -Djdk.httpclient.quic.defaultMTU=4096 +// -Djdk.httpclient.http3.maxStreamLimitTimeout=1375 +// -Xmx16g +// -Djdk.httpclient.quic.defaultMTU=16384 +// -Djdk.internal.httpclient.debug=err +// -XX:+HeapDumpOnOutOfMemoryError +// -Xmx768m -XX:MaxRAMPercentage=12.5 +// -Djdk.httpclient.HttpClient.log=errors,requests,http3 +// -Djdk.httpclient.HttpClient.log=errors,http3,quic:retransmit:control + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.Assert; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +public class H3SimpleGet implements HttpServerAdapters { + static HttpTestServer httpsServer; + static HttpClient client = null; + static SSLContext sslContext; + static String httpsURIString; + static ExecutorService serverExec = + Executors.newThreadPerTaskExecutor(Thread.ofVirtual() + .name("server-vt-worker-", 1).factory()); + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + + httpsServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext, serverExec); + httpsServer.addHandler(new TestHandler(), "/"); + httpsURIString = "https://" + httpsServer.serverAuthority() + "/bar/"; + + httpsServer.start(); + warmup(); + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(); + throw e; + } + } + + private static void warmup() throws Exception { + SimpleSSLContext sslct = new SimpleSSLContext(); + var sslContext = sslct.get(); + + // warmup server + try (var client2 = createClient(sslContext, Executors.newThreadPerTaskExecutor( + Thread.ofVirtual().name("client-2-vt-worker", 1).factory()))) { + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .HEAD().build(); + client2.send(request, BodyHandlers.ofByteArrayConsumer(b-> {})); + } + + // warmup client + var httpsServer2 = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext, + Executors.newThreadPerTaskExecutor( + Thread.ofVirtual().name("server-2-vt-worker", 1).factory())); + httpsServer2.addHandler(new TestHandler(), "/"); + var httpsURIString2 = "https://" + httpsServer2.serverAuthority() + "/bar/"; + httpsServer2.start(); + try { + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString2)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .HEAD().build(); + client.send(request, BodyHandlers.ofByteArrayConsumer(b-> {})); + } finally { + httpsServer2.stop(); + } + } + + public static void main(String[] args) throws Exception { + test(); + } + + static volatile boolean waitBeforeTest = false; + + @Test + public static void test() throws Exception { + try { + if (waitBeforeTest) { + Thread.sleep(20000); + } + long prestart = System.nanoTime(); + initialize(); + long done = System.nanoTime(); + System.out.println("Stat: Initialization and warmup took " + + TimeUnit.NANOSECONDS.toMillis(done-prestart)+" millis"); + HttpRequest request = HttpRequest.newBuilder(URI.create(httpsURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET().build(); + long start = System.nanoTime(); + var resp = client.send(request, BodyHandlers.ofByteArrayConsumer(b-> {})); + Assert.assertEquals(resp.statusCode(), 200); + long elapsed = System.nanoTime() - start; + System.out.println("Stat: First request took: " + elapsed + + " nanos (" + TimeUnit.NANOSECONDS.toMillis(elapsed) + " ms)"); + final int max = property("simpleget.requests", 50); + List>> list = new ArrayList<>(max); + Set connections = new ConcurrentSkipListSet<>(); + long start2 = System.nanoTime(); + for (int i = 0; i < max; i++) { + var cf = client.sendAsync(request, BodyHandlers.ofByteArrayConsumer(b-> {})) + .whenComplete((r, t) -> Optional.ofNullable(r) + .flatMap(HttpResponse::connectionLabel) + .ifPresent(connections::add)); + list.add(cf); + //cf.get(); // uncomment to test with serial instead of concurrent requests + } + try { + CompletableFuture.allOf(list.toArray(new CompletableFuture[0])).join(); + } finally { + long elapsed2 = System.nanoTime() - start2; + long completed = list.stream().filter(CompletableFuture::isDone) + .filter(Predicate.not(CompletableFuture::isCompletedExceptionally)).count(); + connections.forEach(System.out::println); + if (completed > 0) { + System.out.println("Stat: Next " + completed + " requests took: " + elapsed2 + " nanos (" + + TimeUnit.NANOSECONDS.toMillis(elapsed2) + "ms for " + completed + " requests): " + + elapsed2 / completed + " nanos per request (" + + TimeUnit.NANOSECONDS.toMillis(elapsed2) / completed + " ms) on " + + connections.size() + " connections"); + } + } + list.forEach((cf) -> Assert.assertEquals(cf.join().statusCode(), 200)); + } catch (Throwable tt) { + System.err.println("tt caught"); + tt.printStackTrace(); + throw tt; + } finally { + httpsServer.stop(); + } + } + + static HttpClient createClient(SSLContext sslContext, ExecutorService clientExec) { + var builder = HttpServerAdapters.createClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .proxy(Builder.NO_PROXY); + if (clientExec != null) { + builder = builder.executor(clientExec); + } + return builder.build(); + } + + static HttpClient getClient() { + if (client == null) { + client = createClient(sslContext, null); + } + return client; + } + + static int property(String name, int defaultValue) { + return Integer.parseInt(System.getProperty(name, String.valueOf(defaultValue))); + } + + // 32 * 32 * 1024 * 10 chars = 10Mb responses + // 50 requests => 500Mb + // 100 requests => 1Gb + // 1000 requests => 10Gb + private final static int REPEAT = property("simpleget.repeat", 32); + private final static String RESPONSE = "abcdefghij".repeat(property("simpleget.chunks", 1024*32)); + private final static byte[] RESPONSE_BYTES = RESPONSE.getBytes(StandardCharsets.UTF_8); + + private static class TestHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + try (var in = t.getRequestBody()) { + byte[] input = in.readAllBytes(); + t.sendResponseHeaders(200, RESPONSE_BYTES.length * REPEAT); + try (var out = t.getResponseBody()) { + if (t.getRequestMethod().equals("HEAD")) return; + for (int i=0; i {})); + Assert.assertEquals(resp.statusCode(), 200); + long elapsed = System.nanoTime() - start; + System.out.println("First GET request took: " + elapsed + " nanos (" + TimeUnit.NANOSECONDS.toMillis(elapsed) + " ms)"); + final int max = 50; + List>> list = new ArrayList<>(max); + long start2 = System.nanoTime(); + for (int i = 0; i < max; i++) { + list.add(client.sendAsync(postRequest, BodyHandlers.ofByteArrayConsumer(b -> { + }))); + } + CompletableFuture.allOf(list.toArray(new CompletableFuture[0])).join(); + long elapsed2 = System.nanoTime() - start2; + System.out.println("Next " + max + " POST requests took: " + elapsed2 + " nanos (" + + TimeUnit.NANOSECONDS.toMillis(elapsed2) + "ms for " + max + " requests): " + + elapsed2 / max + " nanos per request (" + TimeUnit.NANOSECONDS.toMillis(elapsed2) / max + " ms)"); + list.forEach((cf) -> Assert.assertEquals(cf.join().statusCode(), 200)); + } catch (Throwable tt) { + System.err.println("tt caught"); + tt.printStackTrace(); + throw tt; + } finally { + httpsServer.stop(); + } + } + + static HttpClient createClient(SSLContext sslContext, ExecutorService clientExec) { + var builder = HttpServerAdapters.createClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .proxy(Builder.NO_PROXY); + if (clientExec != null) { + builder = builder.executor(clientExec); + } + return builder.build(); + } + + static HttpClient getClient() { + if (client == null) { + client = createClient(sslContext, null); + } + return client; + } + + private final static int REPEAT = 32; + private final static String RESPONSE = "abcdefghij".repeat(1024*32); + private final static byte[] RESPONSE_BYTES = RESPONSE.getBytes(StandardCharsets.UTF_8); + + private static class TestHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + // consume all input bytes, + try (var in = t.getRequestBody()) { + in.skip(Integer.MAX_VALUE); + t.sendResponseHeaders(200, 0); + t.getResponseBody().close(); + } + } + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3SimpleTest.java b/test/jdk/java/net/httpclient/http3/H3SimpleTest.java new file mode 100644 index 00000000000..73c766f1eab --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3SimpleTest.java @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +/* + * @test + * @summary Basic test to verify that simple GET/POST/HEAD + * requests work as expected with HTTP/3, using IPv4 + * or IPv6 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * H3SimpleTest + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * -Djava.net.preferIPv6Addresses=true + * H3SimpleTest + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors + * -Djava.net.preferIPv4Stack=true + * H3SimpleTest + */ +// -Djava.security.debug=all +public class H3SimpleTest implements HttpServerAdapters { + + private SSLContext sslContext; + private HttpTestServer h3Server; + private String requestURI; + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + // create an H3 only server + h3Server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3Server.addHandler((exchange) -> exchange.sendResponseHeaders(200, 0), "/hello"); + h3Server.start(); + System.out.println("Server started at " + h3Server.getAddress()); + requestURI = "https://" + h3Server.serverAuthority() + "/hello"; + } + + @AfterClass + public void afterClass() throws Exception { + if (h3Server != null) { + System.out.println("Stopping server " + h3Server.getAddress()); + h3Server.stop(); + } + } + + /** + * Issues various HTTP3 requests and verifies the responses are received + */ + @Test + public void testBasicRequests() throws Exception { + final HttpClient client = newClientBuilderForH3() + .proxy(NO_PROXY) + .version(HTTP_3) + .sslContext(sslContext).build(); + final URI reqURI = new URI(requestURI); + final HttpRequest.Builder reqBuilder = HttpRequest.newBuilder(reqURI) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY); + + // GET + final HttpRequest req1 = reqBuilder.copy().GET().build(); + System.out.println("Issuing request: " + req1); + final HttpResponse resp1 = client.send(req1, BodyHandlers.discarding()); + Assert.assertEquals(resp1.statusCode(), 200, "unexpected response code for GET request"); + + // POST + final HttpRequest req2 = reqBuilder.copy().POST(BodyPublishers.ofString("foo")).build(); + System.out.println("Issuing request: " + req2); + final HttpResponse resp2 = client.send(req2, BodyHandlers.discarding()); + Assert.assertEquals(resp2.statusCode(), 200, "unexpected response code for POST request"); + + // HEAD + final HttpRequest req3 = reqBuilder.copy().HEAD().build(); + System.out.println("Issuing request: " + req3); + final HttpResponse resp3 = client.send(req3, BodyHandlers.discarding()); + Assert.assertEquals(resp3.statusCode(), 200, "unexpected response code for HEAD request"); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3StopSendingTest.java b/test/jdk/java/net/httpclient/http3/H3StopSendingTest.java new file mode 100644 index 00000000000..1855cdba6ff --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3StopSendingTest.java @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @summary Verifies that the client reacts correctly to the receipt of a STOP_SENDING frame. + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * @run testng/othervm/timeout=40 -Djdk.internal.httpclient.debug=true -Djdk.httpclient.HttpClient.log=trace,errors,headers + * H3StopSendingTest + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.internal.net.http.http3.Http3Error; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.PrintStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.concurrent.ExecutionException; + +import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.testng.Assert.*; + +public class H3StopSendingTest { + + HttpTestServer h3TestServer; + HttpRequest postRequestNoError, postRequestError; + HttpRequest postRequestNoErrorWithData, postRequestErrorWithData; + URI h3TestServerUriNoError, h3TestServerUriError; + SSLContext sslContext; + + static final String TEST_ROOT_PATH = "/h3_stop_sending_test"; + static final String NO_ERROR_PATH = TEST_ROOT_PATH + "/no_error_path"; + static final String ERROR_PATH = TEST_ROOT_PATH + "/error_path"; + static final String WITH_RESPONSE_BODY_QUERY = "?withbody"; + + static final PrintStream err = System.err; + + + @Test + public void test() throws ExecutionException, InterruptedException { + HttpClient.Builder clientBuilder = HttpServerAdapters.createClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(sslContext); + try (HttpClient client = clientBuilder.build()) { + HttpResponse resp = client.sendAsync(postRequestNoError, HttpResponse.BodyHandlers.ofString()).get(); + err.println(resp.headers()); + err.println(resp.body()); + err.println(resp.statusCode()); + assertEquals(resp.statusCode(), 200); + resp = client.sendAsync(postRequestNoErrorWithData, HttpResponse.BodyHandlers.ofString()).get(); + err.println(resp.headers()); + err.println(resp.body()); + err.println(resp.statusCode()); + assertEquals(resp.statusCode(), 200); + assertEquals(resp.body(), RESPONSE_MESSAGE.repeat(MESSAGE_REPEAT)); + } + } + + @Test + public void testError() { + Throwable caught = null; + HttpClient.Builder clientBuilder = HttpServerAdapters.createClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(sslContext); + try (HttpClient client = clientBuilder.build()) { + try { + client.sendAsync(postRequestError, HttpResponse.BodyHandlers.ofString()).get(); + } catch (Throwable throwable) { + caught = throwable; + } + assertRequestCancelled(caught); + try { + client.sendAsync(postRequestErrorWithData, HttpResponse.BodyHandlers.ofString()).get(); + } catch (Throwable throwable) { + caught = throwable; + } + assertRequestCancelled(caught); + } + } + + private static void assertRequestCancelled(Throwable caught) { + assertNotNull(caught); + if (!caught.getMessage().contains(Long.toString(Http3Error.H3_REQUEST_CANCELLED.code())) + && !caught.getMessage().contains(Http3Error.H3_REQUEST_CANCELLED.name())) { + throw new AssertionError(caught.getMessage() + " does not contain " + + Http3Error.H3_REQUEST_CANCELLED.code() + " or " + + Http3Error.H3_REQUEST_CANCELLED.name(), caught); + } else { + System.out.println("Got expected exception: " + caught); + } + } + + @BeforeTest + public void setup() throws IOException { + + sslContext = new SimpleSSLContext().get(); + h3TestServer = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3TestServer.addHandler(new ServerRequestStopSendingHandler(), TEST_ROOT_PATH); + h3TestServerUriError = URI.create("https://" + h3TestServer.serverAuthority() + ERROR_PATH); + h3TestServerUriNoError = URI.create("https://" + h3TestServer.serverAuthority() + NO_ERROR_PATH); + h3TestServer.start(); + + Iterable iterable = EndlessDataChunks::new; + HttpRequest.BodyPublisher testPub = HttpRequest.BodyPublishers.ofByteArrays(iterable); + postRequestNoError = HttpRequest.newBuilder() + .POST(testPub) + .uri(h3TestServerUriNoError) + .version(HttpClient.Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + postRequestError = HttpRequest.newBuilder() + .POST(testPub) + .uri(h3TestServerUriError) + .version(HttpClient.Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + postRequestNoErrorWithData = HttpRequest.newBuilder() + .POST(testPub) + .uri(URI.create(h3TestServerUriNoError.toString() + WITH_RESPONSE_BODY_QUERY)) + .version(HttpClient.Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + postRequestErrorWithData = HttpRequest.newBuilder() + .POST(testPub) + .uri(URI.create(h3TestServerUriError.toString() + WITH_RESPONSE_BODY_QUERY)) + .version(HttpClient.Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + } + + @AfterTest + public void afterTest() { + h3TestServer.stop(); + } + + static final String RESPONSE_MESSAGE = "May the road rise up to meet you "; + static final String ERROR_MESSAGE = "Forbidden: the data won't be sent "; + static final String NO_BODY = ""; + + static final int MESSAGE_REPEAT = 5; + + static class ServerRequestStopSendingHandler implements HttpTestHandler { + + @Override + public void handle(HttpServerAdapters.HttpTestExchange t) throws IOException { + String query = t.getRequestURI().getQuery(); + System.out.println("Query is: " + query); + boolean withData = WITH_RESPONSE_BODY_QUERY.substring(1).equals(query); + switch (t.getRequestURI().getPath()) { + case NO_ERROR_PATH -> { + final String RESP = withData ? RESPONSE_MESSAGE : NO_BODY; + byte[] resp = RESP.getBytes(StandardCharsets.UTF_8); + System.err.println("Replying with 200 to " + t.getRequestURI().getPath() + + " with data " + resp.length * MESSAGE_REPEAT); + t.sendResponseHeaders(200, resp.length * MESSAGE_REPEAT); + t.requestStopSending(Http3Error.H3_NO_ERROR.code()); + if (resp.length == 0) return; + try (var os = t.getResponseBody()) { + for (int i = 0; i { + final String RESP = withData ? ERROR_MESSAGE : NO_BODY; + byte[] resp = RESP.getBytes(StandardCharsets.UTF_8); + System.err.println("Replying with 403 to " + t.getRequestURI().getPath() + + " with data " + resp.length * MESSAGE_REPEAT); + if (resp.length == 0) { + t.requestStopSending(Http3Error.H3_EXCESSIVE_LOAD.code()); + sleep(100); + t.resetStream(Http3Error.H3_REQUEST_CANCELLED.code()); + } else { + t.sendResponseHeaders(403, resp.length * MESSAGE_REPEAT); + t.requestStopSending(Http3Error.H3_EXCESSIVE_LOAD.code()); + try (var os = t.getResponseBody()) { + for (int i = 0; i < MESSAGE_REPEAT; i++) { + os.write(resp); + os.flush(); + sleep(10); + if (i == MESSAGE_REPEAT - 1) { + t.resetStream(Http3Error.H3_REQUEST_CANCELLED.code()); + } + } + } + } + } + } + } + } + + private static void sleep(long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + } + + static public class EndlessDataChunks implements Iterator { + + byte[] data = new byte[32]; + boolean hasNext = true; + @Override + public boolean hasNext() { + return hasNext; + } + @Override + public byte[] next() { + try { + Thread.sleep(500); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return data; + } + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3StreamLimitReachedTest.java b/test/jdk/java/net/httpclient/http3/H3StreamLimitReachedTest.java new file mode 100644 index 00000000000..916a6a5221d --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3StreamLimitReachedTest.java @@ -0,0 +1,971 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test id=with-default-wait + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.Asserts + * jdk.test.lib.Utils + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.httpclient.HttpClient.log=ssl,requests,responses,errors,http3,quic:control + * -Djdk.internal.httpclient.debug=false + * -Djdk.httpclient.quic.maxBidiStreams=1 + * H3StreamLimitReachedTest + */ + +/* + * @test id=with-no-wait + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.http2.Http2TestServer + * jdk.test.lib.Asserts + * jdk.test.lib.Utils + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.httpclient.HttpClient.log=ssl,requests,responses,errors,http3,quic:control + * -Djdk.internal.httpclient.debug=false + * -Djdk.httpclient.quic.maxBidiStreams=1 + * -Djdk.httpclient.http3.maxStreamLimitTimeout=0 + * -Djdk.httpclient.retryOnStreamlimit=9 + * H3StreamLimitReachedTest + */ + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpOption; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static jdk.test.lib.Asserts.assertEquals; +import static jdk.test.lib.Asserts.assertNotEquals; +import static jdk.test.lib.Asserts.assertTrue; +import static org.testng.Assert.assertFalse; + +public class H3StreamLimitReachedTest implements HttpServerAdapters { + + private static final String CLASS_NAME = H3ConnectionPoolTest.class.getSimpleName(); + + static int altsvcPort, https2Port, http3Port; + static Http3TestServer http3OnlyServer; + static Http2TestServer https2AltSvcServer; + static volatile HttpClient client = null; + static SSLContext sslContext; + static volatile String http3OnlyURIString, https2URIString, http3AltSvcURIString, http3DirectURIString; + + static void initialize(boolean samePort) throws Exception { + BlockingHandler.GATE.drainPermits(); + BlockingHandler.IN_HANDLER.drainPermits(); + BlockingHandler.PERMITS.set(0); + initialize(samePort, BlockingHandler::new); + } + + static void initialize(boolean samePort, Supplier handlers) throws Exception { + System.out.println("\nConfiguring for advertised AltSvc on " + + (samePort ? "same port" : "ephemeral port")); + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = null; + client = getClient(); + + // server that supports both HTTP/2 and HTTP/3, with HTTP/3 on an altSvc port. + https2AltSvcServer = new Http2TestServer(true, sslContext); + if (samePort) { + System.out.println("Attempting to enable advertised HTTP/3 service on same port"); + https2AltSvcServer.enableH3AltServiceOnSamePort(); + System.out.println("Advertised AltSvc on same port " + + (https2AltSvcServer.supportsH3DirectConnection() ? "enabled" : " not enabled")); + } else { + System.out.println("Attempting to enable advertised HTTP/3 service on different port"); + https2AltSvcServer.enableH3AltServiceOnEphemeralPort(); + } + https2AltSvcServer.addHandler(handlers.get(), "/" + CLASS_NAME + "/https2/"); + https2AltSvcServer.addHandler(handlers.get(), "/" + CLASS_NAME + "/h2h3/"); + https2Port = https2AltSvcServer.getAddress().getPort(); + altsvcPort = https2AltSvcServer.getH3AltService() + .map(Http3TestServer::getAddress).stream() + .mapToInt(InetSocketAddress::getPort).findFirst() + .getAsInt(); + // server that only supports HTTP/3 - we attempt to use the same port + // as the HTTP/2 server so that we can pretend that the H2 server as two H3 endpoints: + // one advertised (the alt service endpoint og the HTTP/2 server) + // one non advertised (the direct endpoint, at the same authority as HTTP/2, but which + // is in fact our http3OnlyServer) + try { + http3OnlyServer = new Http3TestServer(sslContext, samePort ? 0 : https2Port); + System.out.println("Unadvertised service enabled on " + + (samePort ? "ephemeral port" : "same port")); + } catch (IOException ex) { + System.out.println("Can't create HTTP/3 server on same port: " + ex); + http3OnlyServer = new Http3TestServer(sslContext, 0); + } + http3OnlyServer.addHandler("/" + CLASS_NAME + "/http3/", handlers.get()); + http3OnlyServer.addHandler("/" + CLASS_NAME + "/h2h3/", handlers.get()); + http3OnlyServer.start(); + http3Port = http3OnlyServer.getQuicServer().getAddress().getPort(); + + if (http3Port == https2Port) { + System.out.println("HTTP/3 server enabled on same port than HTTP/2 server"); + if (samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used ephemeral port for HTTP/3 server"); + } + } else { + System.out.println("HTTP/3 server enabled on a different port than HTTP/2 server"); + if (!samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used same port for HTTP/3 server"); + } + } + if (altsvcPort == https2Port) { + if (!samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used same port for advertised AltSvc"); + } + } else { + if (samePort) { + System.out.println("WARNING: configuration could not be respected," + + " should have used ephemeral port for advertised AltSvc"); + } + } + + http3OnlyURIString = "https://" + http3OnlyServer.serverAuthority() + "/" + CLASS_NAME + "/http3/foo/"; + https2URIString = "https://" + https2AltSvcServer.serverAuthority() + "/" + CLASS_NAME + "/https2/bar/"; + http3DirectURIString = "https://" + https2AltSvcServer.serverAuthority() + "/" + CLASS_NAME + "/h2h3/direct/"; + http3AltSvcURIString = https2URIString + .replace(":" + https2Port + "/", ":" + altsvcPort + "/") + .replace("/https2/bar/", "/h2h3/altsvc/"); + System.out.println("HTTP/2 server started at: " + https2AltSvcServer.serverAuthority()); + System.out.println(" with advertised HTTP/3 endpoint at: " + + URI.create(http3AltSvcURIString).getRawAuthority()); + System.out.println("HTTP/3 server started at:" + http3OnlyServer.serverAuthority()); + + https2AltSvcServer.start(); + } catch (Throwable e) { + System.out.println("Configuration failed: " + e); + System.err.println("Throwing now: " + e); + e.printStackTrace(); + throw e; + } + } + + static class BlockingHandler implements Http2Handler { + static final AtomicLong REQCOUNT = new AtomicLong(); + static final Semaphore IN_HANDLER = new Semaphore(0); + static final Semaphore GATE = new Semaphore(0); + static final AtomicInteger PERMITS = new AtomicInteger(); + + static void acquireGate() throws InterruptedException { + GATE.acquire(); + System.out.println("GATE acquired: remaining permits: " + + PERMITS.decrementAndGet() + + " (actual: " + GATE.availablePermits() +")"); + } + static void releaseGate() throws InterruptedException { + int permits = PERMITS.incrementAndGet(); + GATE.release(); + System.out.println("GATE released: remaining permits: " + + permits + " (actual: " + GATE.availablePermits() +")"); + } + + static void releaseGate(int permits) throws InterruptedException { + int npermits = PERMITS.addAndGet(permits); + GATE.release(npermits); + System.out.println("GATE released (" + permits + "): remaining permits: " + + npermits + " (actual: " + GATE.availablePermits() +")"); + } + + @Override + public void handle(Http2TestExchange t) throws IOException { + long count = REQCOUNT.incrementAndGet(); + byte[] resp; + int status; + try { + try { + IN_HANDLER.release(); + System.out.printf("*** Server [%s] waiting for GATE: %s%n", + count, t.getRequestURI()); + System.err.printf("*** Server [%s] waiting for GATE: %s%n", + count, t.getRequestURI()); + acquireGate(); + System.err.printf("*** Server [%s] GATE acquired: %s%n", + count, t.getRequestURI()); + status = 200; + resp = "Request %s OK".formatted(count) + .getBytes(StandardCharsets.UTF_8); + } catch (InterruptedException x) { + status = 500; + resp = "Request %s interrupted: %s" + .formatted(count, x) + .getBytes(StandardCharsets.UTF_8); + } + System.out.printf("*** Server [%s] headers for: %s%n\t%s%n", + count, t.getRequestURI(), t.getRequestHeaders()); + System.out.printf("*** Server [%s] reading body for: %s%n", + count, t.getRequestURI()); + t.getRequestBody().readAllBytes(); + System.out.printf("*** Server [%s] sending headers for: %s%n", + count, t.getRequestURI()); + t.sendResponseHeaders(status, resp.length); + System.out.printf("*** Server [%s] sending body %s (%s bytes) for: %s%n", + count, status, resp.length, t.getRequestURI()); + try (var body = t.getResponseBody()) { + body.write(resp); + } + System.out.printf("*** Server [%s] response %s sent for: %s%n", + count, status, t.getRequestURI()); + } catch (Throwable throwable) { + var msg = String.format("Server [%s] response failed for: %s", + count, t.getRequestURI()); + System.out.printf("*** %s%n\t%s%n", msg, throwable); + var error = new IOException(msg,throwable); + error.printStackTrace(System.out); + System.err.printf("*** %s%n\t%s%n", msg, throwable); + //GATE.release(); + throw error; + } + } + } + + private static void printResponse(String name, + HttpOption option, + HttpResponse response) { + printResponse("%s %s".formatted( + name, response.request().getOption(option)), + response); + } + + private static void printResponse(String name, HttpResponse response) { + System.out.printf("%s: (%s): %s%n", + name, response.connectionLabel(), response); + response.headers().map().entrySet().forEach((e) -> { + System.out.printf(" %s: %s%n", e.getKey(), e.getValue()); + }); + System.out.printf(" :body: \"%s\"%n%n", response.body()); + } + + @Test + public static void testH3Only() throws Exception { + System.out.println("\nTesting HTTP/3 only"); + initialize(true); + try (HttpClient client = getClient()) { + var reqBuilder = HttpRequest.newBuilder() + .uri(URI.create(http3OnlyURIString)) + .version(HTTP_3) + .GET(); + HttpRequest request1 = reqBuilder.copy() + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + var responseCF1 = client.sendAsync(request1, BodyHandlers.ofString()); + // wait until blocked in the handler + BlockingHandler.IN_HANDLER.acquire(); + + // ANY should reuse the same connection, get stream limit reached, + // and open a new connection + HttpRequest request2 = reqBuilder.copy() + .setOption(H3_DISCOVERY, ANY) + .build(); + var responseCF2 = client.sendAsync(request2, BodyHandlers.ofString()); + // wait until blocked in the handler + BlockingHandler.IN_HANDLER.acquire(); + + // release both + BlockingHandler.GATE.release(2); + + var response1 = responseCF1.get(); + printResponse("First response", response1); + var response2 = responseCF2.get(); + printResponse("Second response", response2); + + // set a timeout to make sure we wait long enough for + // the MAX_STREAMS update to reach us before attempting + // to create the HTTP/3 exchange. + HttpRequest request3 = reqBuilder.copy() + .timeout(Duration.ofSeconds(30)) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + var responseCF3 = client.sendAsync(request3, BodyHandlers.ofString()); + // request3 should now reuse the connection opened by request2 + // wait until blocked in the handler + BlockingHandler.IN_HANDLER.acquire(); + + // ANY should reuse the same connection as request3, + // get stream limit exception, and open a new connection + HttpRequest request4 = reqBuilder.copy() + .setOption(H3_DISCOVERY, ANY) + .build(); + var responseCF4 = client.sendAsync(request4, BodyHandlers.ofString()); + // wait until blocked in the handler + BlockingHandler.IN_HANDLER.acquire(); + + // release both response 3 and response 4 + BlockingHandler.GATE.release(2); + + var response3 = responseCF3.get(); + printResponse("Third response", response3); + var response4 = responseCF4.get(); + printResponse("Fourth response", response4); + + assertNotEquals(response1.connectionLabel().get(), response2.connectionLabel().get()); + assertEquals(response2.connectionLabel().get(), response3.connectionLabel().get()); + assertNotEquals(response1.connectionLabel().get(), response4.connectionLabel().get()); + assertNotEquals(response2.connectionLabel().get(), response4.connectionLabel().get()); + assertNotEquals(response3.connectionLabel().get(), response4.connectionLabel().get()); + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + } + } + + @Test + public static void testH2H3WithTwoAltSVC() throws Exception { + testH2H3(false); + } + + @Test + public static void testH2H3WithAltSVCOnSamePort() throws Exception { + testH2H3(true); + } + + private static void testH2H3(boolean samePort) throws Exception { + System.out.println("\nTesting with advertised AltSvc on " + + (samePort ? "same port" : "ephemeral port")); + initialize(samePort); + try (HttpClient client = getClient()) { + var req1Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + var req2Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .setOption(H3_DISCOVERY, ALT_SVC) + .version(HTTP_3) + .GET(); + + if (altsvcPort == https2Port) { + System.out.println("Testing with alt service on same port"); + + // first request with HTTP3_URI_ONLY should create H3 connection + HttpRequest request1 = req1Builder.copy().build(); + var responseCF1 = client.sendAsync(request1, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + var h2resp2CF = client.sendAsync(request2, BodyHandlers.ofString()); + // Blocking until the handler is invoked should be enough + // to ensure we get the AltSvc frame before the next request. + BlockingHandler.IN_HANDLER.acquire(); + BlockingHandler.GATE.release(3); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + // There's still a potential race here - to avoid it + // we need h2resp2 = h2resp2CF.get(); before sending request2 + var h2resp2 = h2resp2CF.get(); + System.out.printf("Got expected h2 response: %s [%s]%n", + h2resp2, h2resp2.version()); + + var responseCF2 = client.sendAsync(request2, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + + var response1 = responseCF1.get(); + printResponse("response 1", H3_DISCOVERY, response1); + var response2 = responseCF2.get(); + printResponse("response 2", H3_DISCOVERY, response2); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), response1.connectionLabel().get()); + + // second request with HTTP3_URI_ONLY should reuse a created connection + // It should reuse the advertised connection (from response2) if same + // origin + // Specify a request timeout here to make sure we wait long enough + // for the MAX_STREAMS frame to arrive before creating a new + // connection + HttpRequest request3 = req1Builder.copy() + .timeout(Duration.ofSeconds(30)).build(); + var responseCF3 = client.sendAsync(request3, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin... + // It should not reuse the connection from response1, it should + // not invoke reconnection on the connection used by response 3 + // Specify a request timeout here to make sure we wait long enough + // for the MAX_STREAMS frame to arrive before creating a new + // connection + HttpRequest request4 = req2Builder.copy() + .timeout(Duration.ofSeconds(30)).build(); + var responseCF4 = client.sendAsync(request4, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + // release both + BlockingHandler.GATE.release(2); + + var response3 = responseCF3.get(); + printResponse("response 3", H3_DISCOVERY, response3); + var response4 = responseCF4.get(); + printResponse("response 4", H3_DISCOVERY, response4); + + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertEquals(response1.connectionLabel().get(), response3.connectionLabel().get()); + + assertEquals(HTTP_3, response4.version()); + checkStatus(200, response4.statusCode()); + assertEquals(response2.connectionLabel().get(), response4.connectionLabel().get()); + + } else if (http3Port == https2Port) { + System.out.println("Testing with two alt services"); + // first, get the alt service + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + BlockingHandler.GATE.release(); + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + BlockingHandler.IN_HANDLER.acquire(); + System.out.printf("Got expected h2 response: %s [%s]%n", + h2resp2, h2resp2.version()); + + // second - make a direct connection + HttpRequest request1 = req1Builder.copy() + .uri(URI.create(http3DirectURIString+"?request1")).build(); + var responseCF1 = client.sendAsync(request1, BodyHandlers.ofString()); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + var req2 = req2Builder.copy() + .uri(URI.create(http3DirectURIString + "?request2")) + .build(); + var responseCF2 = client.sendAsync(req2, BodyHandlers.ofString()); + + BlockingHandler.IN_HANDLER.acquire(2); + BlockingHandler.GATE.release(2); + + var response1 = responseCF1.get(); + printResponse("response 1", H3_DISCOVERY, response1); + var response2 = responseCF2.get(); + printResponse("response 2", H3_DISCOVERY, response2); + + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + assertNotEquals(response2.connectionLabel().get(), response1.connectionLabel().get()); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin... + // Specify a request timeout here to make sure we wait long enough + // for the MAX_STREAMS frame to arrive before creating a new + // connection + HttpRequest request3 = req2Builder.copy() + .uri(URI.create(http3DirectURIString + "?request3")) + .timeout(Duration.ofSeconds(30)).build(); + var responseCF3 = client.sendAsync(request3, BodyHandlers.ofString()); + + // fourth request with HTTP_3_URI_ONLY should reuse the first connection, + // and not reuse the second. + // Specify a request timeout here to make sure we wait long enough + // for the MAX_STREAMS frame to arrive before creating a new + // connection + HttpRequest request4 = req1Builder.copy() + .uri(URI.create(http3DirectURIString + "?request4")) + .timeout(Duration.ofSeconds(30)).build(); + var responseCF4 = client.sendAsync(request4, BodyHandlers.ofString()); + + BlockingHandler.IN_HANDLER.acquire(2); + BlockingHandler.GATE.release(2); + + var response3 = responseCF3.get(); + printResponse("response 3", H3_DISCOVERY, response3); + var response4 = responseCF4.get(); + printResponse("response 4", H3_DISCOVERY, response4); + + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertEquals(response2.connectionLabel().get(), response3.connectionLabel().get()); + assertNotEquals(response1.connectionLabel().get(), response3.connectionLabel().get()); + assertEquals(HTTP_3, response4.version()); + assertEquals(response1.connectionLabel().get(), response4.connectionLabel().get()); + assertNotEquals(response3.connectionLabel().get(), response4.connectionLabel().get()); + checkStatus(200, response1.statusCode()); + } else { + System.out.println("WARNING: Couldn't create HTTP/3 server on same port! Can't test all..."); + // Get, get the alt service + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + BlockingHandler.GATE.release(1); + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + BlockingHandler.IN_HANDLER.acquire(); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + var responseCF2 = client.sendAsync(request2, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin, get + // StreamLimitReached exception, and create a new connection + HttpRequest request3 = req2Builder.copy().build(); + var responseCF3 = client.sendAsync(request3, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + BlockingHandler.GATE.release(2); + + var response2 = responseCF2.get(); + printResponse("response 2", H3_DISCOVERY, response2); + + var response3 = responseCF3.get(); + printResponse("response 3", H3_DISCOVERY, response3); + + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertNotEquals(response3.connectionLabel().get(), response2.connectionLabel().get()); + } + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + } + } + + @Test + public static void testParallelH2H3WithTwoAltSVC() throws Exception { + testH2H3Concurrent(false); + } + + @Test + public static void testParallelH2H3WithAltSVCOnSamePort() throws Exception { + testH2H3Concurrent(true); + } + + + private static void testH2H3Concurrent(boolean samePort) throws Exception { + System.out.println("\nTesting concurrent reconnections with advertised AltSvc on " + + (samePort ? "same port" : "ephemeral port")); + initialize(samePort); + try (HttpClient client = getClient()) { + var req1Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .version(HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .GET(); + var req2Builder = HttpRequest.newBuilder() + .uri(URI.create(http3DirectURIString)) + .setOption(H3_DISCOVERY, ALT_SVC) + .version(HTTP_3) + .GET(); + + if (altsvcPort == https2Port) { + System.out.println("Testing reconnections with alt service on same port"); + + // first request with HTTP3_URI_ONLY should create H3 connection + HttpRequest request1 = req1Builder.copy().build(); + HttpRequest request2 = req2Builder.copy().build(); + List>> directResponses = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + HttpRequest req1 = req1Builder.copy() + .uri(URI.create(http3DirectURIString+"?dir="+i)).build(); + directResponses.add(client.sendAsync(req1, BodyHandlers.ofString())); + BlockingHandler.IN_HANDLER.acquire(); + } + // can't send requests in parallel here because if any establishes + // a connection before the H3 direct are established, then the H3 + // direct might reuse the H3 alt since the service is with same origin + BlockingHandler.releaseGate(directResponses.size() + 1); + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + CompletableFuture.allOf(directResponses.stream() + .toArray(CompletableFuture[]::new)).exceptionally((t) -> null) + .join(); + + Set c1Label = new HashSet<>(); + for (int i = 0; i < directResponses.size(); i++) { + HttpResponse response1 = directResponses.get(i).get(); + System.out.printf("direct response [%s][%s]: %s%n", i, + response1.connectionLabel(), + response1); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + var cLabel = response1.connectionLabel().get(); + assertFalse(c1Label.contains(cLabel), + "%s contained in %s".formatted(cLabel, c1Label)); + c1Label.add(cLabel); + } + + // first request with ALT_SVC is to get alt service, should be H2 + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + assertFalse(c1Label.contains(h2resp2.connectionLabel().get()), + "%s contained in %s!".formatted(h2resp2.connectionLabel().get(), c1Label)); + + // second request should have ALT_SVC and create new connection with H3 + // it should not reuse the non-advertised connection + List>> altResponses = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + HttpRequest req2 = req2Builder.copy() + .uri(URI.create(http3DirectURIString+"?alt="+i)).build(); + altResponses.add(client.sendAsync(req2, BodyHandlers.ofString())); + } + + BlockingHandler.releaseGate(altResponses.size()); + BlockingHandler.IN_HANDLER.acquire(altResponses.size()); + + CompletableFuture.allOf(altResponses.stream().toArray(CompletableFuture[]::new)) + .exceptionally((t) -> null) + .join(); + + Set c2Label = new HashSet<>(); + for (int i = 0; i < altResponses.size(); i++) { + HttpResponse response2 = altResponses.get(i).get(); + System.out.printf("alt response [%s][%s]: %s%n", i, + response2.connectionLabel(), + response2); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + var cLabel = response2.connectionLabel().get(); + if (c2Label.contains(cLabel)) { + System.out.printf("Connection %s reused%n", cLabel); + } + c2Label.add(cLabel); + assertFalse(c1Label.contains(cLabel), + "%s contained in %s".formatted(cLabel, c1Label)); + // cLabel could already be in c2Label, if the previous + // request finished before the next one. + } + + System.out.println("Sending mixed requests"); + + // second set of requests should reuse a created connection + HttpRequest request3 = req1Builder.copy().build(); + List>> mixResponses = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + var builder1 = req1Builder.copy() + .uri(URI.create(http3DirectURIString+"?mix1="+i)); + if (i == 0) { + // make sure to give time for MAX_STREAMS to arrive + builder1 = builder1.timeout(Duration.ofSeconds(30)); + } + HttpRequest req1 = builder1.build(); + mixResponses.add(client.sendAsync(req1, BodyHandlers.ofString())); + if (i == 0) BlockingHandler.IN_HANDLER.acquire(); + var builder2 = req2Builder.copy() + .uri(URI.create(http3DirectURIString+"?mix2="+i)); + if (i == 0) { + // make sure to give time for MAX_STREAMS to arrive + builder2 = builder2.timeout(Duration.ofSeconds(30)); + } + HttpRequest req2 = builder2.build(); + mixResponses.add(client.sendAsync(req2, BodyHandlers.ofString())); + if (i == 0) BlockingHandler.IN_HANDLER.acquire(); + } + + System.out.println("IN_HANDLER.acquire(" + + (mixResponses.size() - 2) + ") - available: " + + BlockingHandler.IN_HANDLER.availablePermits()); + BlockingHandler.IN_HANDLER.acquire(mixResponses.size() - 2); + BlockingHandler.releaseGate(mixResponses.size()); + + System.out.println("Getting mixed responses"); + + CompletableFuture.allOf(mixResponses.stream().toArray(CompletableFuture[]::new)) + .exceptionally((t) -> null) + .join(); + + System.out.println("All mixed responses received"); + + Set mixC1Label = new HashSet<>(); + Set mixC2Label = new HashSet<>(); + for (int i = 0; i < mixResponses.size(); i++) { + HttpResponse response3 = mixResponses.get(i).get(); + System.out.printf("mixed response [%s][%s] %s: %s%n", i, + response3.connectionLabel(), + response3.request().getOption(H3_DISCOVERY), + response3); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + var cLabel = response3.connectionLabel().get(); + if (response3.request().getOption(H3_DISCOVERY).orElse(null) == ALT_SVC) { + if (i == 0 || i == 1) { + assertTrue(c2Label.contains(cLabel), + "%s not in %s".formatted(cLabel, c2Label)); + System.out.printf("first ALTSVC connection reused %s from %s%n", cLabel, c2Label); + } else { + assertFalse(c2Label.contains(cLabel), + "%s in %s".formatted(cLabel, c2Label)); + } + assertFalse(c1Label.contains(cLabel), + "%s in %s".formatted(cLabel, c1Label)); + assertFalse(mixC1Label.contains(cLabel), + "%s in %s".formatted(cLabel, mixC1Label)); + if (mixC2Label.contains(cLabel)) { + System.out.printf("ALTSVC connection reused %s from %s%n", cLabel, mixC2Label); + } + mixC2Label.add(cLabel); + } else { + if (i == 0 || i == 1) { + assertTrue(c1Label.contains(cLabel), + "%s not in %s".formatted(cLabel, c1Label)); + System.out.printf("first ALTSVC connection reused %s from %s%n", cLabel, c1Label); + } else { + assertFalse(c1Label.contains(cLabel), + "%s in %s".formatted(cLabel, c1Label)); + } + assertFalse(c2Label.contains(cLabel), + "%s in %s".formatted(cLabel, c2Label)); + assertFalse(mixC2Label.contains(cLabel), + "%s in %s".formatted(cLabel, mixC2Label)); + if (mixC1Label.contains(cLabel)) { + System.out.printf("ALTSVC connection reused %s from %s%n", cLabel, mixC1Label); + } + mixC1Label.add(cLabel); + } + } + System.out.println("All done"); + } else if (http3Port == https2Port) { + System.out.println("Testing with two alt services"); + // first - make a direct connection + HttpRequest request1 = req1Builder.copy().build(); + + // second, use the alt service + HttpRequest request2 = req2Builder.copy().build(); + BlockingHandler.GATE.release(); + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + BlockingHandler.IN_HANDLER.acquire(); + + // third, use ANY + HttpRequest request3 = req2Builder.copy().setOption(H3_DISCOVERY, ANY).build(); + + List>> directResponses = new ArrayList<>(); + List>> altResponses = new ArrayList<>(); + List>> anyResponses = new ArrayList<>(); + checkStatus(200, h2resp2.statusCode()); + + // We're going to send nine requests here. We could get + // "No more stream available on connection" when we run with + // a stream limit timeout of 0, unless we raise the retry on + // stream limit to at least 6 + for (int i = 0; i < 3; i++) { + anyResponses.add(client.sendAsync(request3, BodyHandlers.ofString())); + directResponses.add(client.sendAsync(request1, BodyHandlers.ofString())); + altResponses.add(client.sendAsync(request2, BodyHandlers.ofString())); + } + + var all = new ArrayList<>(directResponses); + all.addAll(altResponses); + all.addAll(anyResponses); + int requestCount = all.size(); + + BlockingHandler.GATE.release(requestCount); + CompletableFuture.allOf(all.stream().toArray(CompletableFuture[]::new)) + .exceptionally((t) -> null) + .join(); + + Set c1Label = new HashSet<>(); + for (int i = 0; i < directResponses.size(); i++) { + HttpResponse response1 = directResponses.get(i).get(); + System.out.printf("direct response [%s][%s] %s: %s%n", i, + response1.connectionLabel(), + response1.request().getOption(H3_DISCOVERY), + response1); + assertEquals(HTTP_3, response1.version()); + checkStatus(200, response1.statusCode()); + + var cLabel = response1.connectionLabel().get(); + if (c1Label.contains(cLabel)) { + System.out.printf(" connection %s reused from %s%n", cLabel, c1Label); + } + c1Label.add(cLabel); + } + Set c2Label = new HashSet<>(); + for (int i = 0; i < altResponses.size(); i++) { + HttpResponse response2 = altResponses.get(i).get(); + System.out.printf("alt response [%s][%s] %s: %s%n", i, + response2.connectionLabel(), + response2.request().getOption(H3_DISCOVERY), + response2); + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + + var cLabel = response2.connectionLabel().get(); + if (c2Label.contains(cLabel)) { + System.out.printf(" connection %s reused from %s%n", cLabel, c2Label); + } + assertNotEquals(cLabel, h2resp2.connectionLabel().get()); + assertFalse(c1Label.contains(cLabel), + "%s found in %s".formatted(cLabel, c1Label)); + c2Label.add(cLabel); + } + + var diff = new HashSet<>(c2Label); + diff.retainAll(c1Label); + assertTrue(diff.isEmpty()); + + var anyLabels = new HashSet(Set.of(c1Label, c2Label)); + for (int i = 0; i < anyResponses.size(); i++) { + HttpResponse response3 = anyResponses.get(i).get(); + System.out.printf("any response [%s][%s] %s: %s%n", i, + response3.connectionLabel(), + response3.request().getOption(H3_DISCOVERY), + response3); + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertNotEquals(response3.connectionLabel().get(), h2resp2.connectionLabel().get()); + var label = response3.connectionLabel().orElse(""); + if (anyLabels.contains(label)) { + System.out.printf(" connection %s reused from %s%n", label, anyLabels); + } + } + BlockingHandler.IN_HANDLER.acquire(requestCount); + } else { + System.out.println("WARNING: Couldn't create HTTP/3 server on same port! Can't test all..."); + // Get, get the alt service + HttpRequest request2 = req2Builder.copy().build(); + // first request with ALT_SVC is to get alt service, should be H2 + BlockingHandler.GATE.release(); + HttpResponse h2resp2 = client.send(request2, BodyHandlers.ofString()); + assertEquals(HTTP_2, h2resp2.version()); + checkStatus(200, h2resp2.statusCode()); + BlockingHandler.IN_HANDLER.acquire(); + + // second request should have ALT_SVC and create new connection with H3 + var responseCF2 = client.sendAsync(request2, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + // third request with ALT_SVC should reuse the same advertised + // connection (from response2), regardless of same origin, get + // stream limit, and create a new connection + HttpRequest request3 = req2Builder.copy().build(); + var responseCF3 = client.sendAsync(request3, BodyHandlers.ofString()); + BlockingHandler.IN_HANDLER.acquire(); + + BlockingHandler.GATE.release(2); + + CompletableFuture.allOf(responseCF2, responseCF3) + .exceptionally((t) -> null) + .join(); + + var response2 = responseCF2.get(); + printResponse("first HTTP/3 request", H3_DISCOVERY, response2); + var response3 = responseCF3.get(); + printResponse("second HTTP/3 request", H3_DISCOVERY, response2); + + assertEquals(HTTP_3, response2.version()); + checkStatus(200, response2.statusCode()); + assertNotEquals(response2.connectionLabel().get(), h2resp2.connectionLabel().get()); + + assertEquals(HTTP_3, response3.version()); + checkStatus(200, response3.statusCode()); + assertNotEquals(response3.connectionLabel().get(), h2resp2.connectionLabel().get()); + assertNotEquals(response3.connectionLabel().get(), response2.connectionLabel().get()); + } + } catch (Throwable t) { + t.printStackTrace(System.out); + throw t; + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + } + } + static HttpClient getClient() { + if (client == null) { + client = HttpServerAdapters.createClientBuilderForH3() + .sslContext(sslContext) + .version(HTTP_3) + .build(); + } + return client; + } + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf("Test failed: wrong string %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + + static T logExceptionally(String desc, Throwable t) { + System.out.println(desc + " failed: " + t); + System.err.println(desc + " failed: " + t); + if (t instanceof RuntimeException r) throw r; + if (t instanceof Error e) throw e; + throw new CompletionException(t); + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3Timeout.java b/test/jdk/java/net/httpclient/http3/H3Timeout.java new file mode 100644 index 00000000000..81a0f87467b --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3Timeout.java @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpConnectTimeoutException; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpTimeoutException; +import java.time.Duration; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CountDownLatch; +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.common.OperationTrackers.Tracker; +import jdk.test.lib.net.SimpleSSLContext; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +/* + * @test + * @bug 8156710 + * @summary Check if HttpTimeoutException is thrown if a server doesn't reply + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @compile ../ReferenceTracker.java + * @run main/othervm -Djdk.httpclient.HttpClient.log=ssl,requests,responses,errors H3Timeout + */ +public class H3Timeout implements HttpServerAdapters { + + private static final int TIMEOUT = 2 * 1000; // in millis + private static final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + + public static void main(String[] args) throws Exception { + SSLContext context = new SimpleSSLContext().get(); + testConnect(context, false); + testConnect(context, true); + testTimeout(context, false); + testTimeout(context, true); + } + + public static void testConnect(SSLContext context, boolean async) throws Exception { + + InetSocketAddress loopback = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); + try (DatagramSocket socket = new DatagramSocket(loopback)) { + String address = socket.getLocalAddress().getHostAddress(); + if (address.indexOf(':') >= 0) { + if (!address.startsWith("[") || !address.endsWith("]")) { + address = "[" + address + "]"; + } + } + String serverAuth = address + ":" + socket.getLocalPort(); + String uri = "https://" + serverAuth + "/"; + HttpTimeoutException expected; + if (async) { + System.out.println(uri + ": Trying to connect asynchronously"); + expected = connectAsync(context, uri); + } else { + System.out.println(uri + ": Trying to connect synchronously"); + expected = connect(context, uri); + } + if (!(expected instanceof HttpConnectTimeoutException)) { + throw new AssertionError("expected HttpConnectTimeoutException, got: " + + expected, expected); + } + } + } + + public static void testTimeout(SSLContext context, boolean async) throws Exception { + + CountDownLatch latch = new CountDownLatch(1); + HttpTestServer server = HttpTestServer.create(HTTP_3_URI_ONLY, context); + server.addHandler((exch) -> { + try { + System.err.println("server reading request"); + byte[] req = exch.getRequestBody().readAllBytes(); + System.err.printf("server got request: %s bytes", req.length); + latch.await(); + exch.sendResponseHeaders(500, 0); + } catch (Exception e) { + System.err.println("server exception: " + e); + } + }, "/"); + server.start(); + try { + String serverAuth = server.serverAuthority(); + String uri = "https://" + serverAuth + "/"; + HttpTimeoutException expected; + if (async) { + System.out.println(uri + ": Trying to connect asynchronously"); + expected = connectAsync(context, uri); + } else { + System.out.println(uri + ": Trying to connect synchronously"); + expected = connect(context, uri); + } + assert expected instanceof HttpTimeoutException; + } finally { + latch.countDown(); + server.stop(); + } + } + + private static HttpTimeoutException connect(SSLContext context, String server) throws Exception { + HttpClient client = HttpServerAdapters.createClientBuilderForH3() + .version(HTTP_3) + .sslContext(context) + .build(); + try { + HttpRequest request = HttpRequest.newBuilder(new URI(server)) + .timeout(Duration.ofMillis(TIMEOUT)) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .POST(BodyPublishers.ofString("body")) + .build(); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + System.out.println("Received unexpected reply: " + response.statusCode()); + throw new RuntimeException("unexpected successful connection"); + } catch (HttpTimeoutException e) { + System.out.println("expected exception: " + e); + return e; + } finally { + client.shutdown(); + if (!client.awaitTermination(Duration.ofSeconds(5))) { + Tracker tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 5000); + if (error != null) throw error; + } + } + } + + private static HttpTimeoutException connectAsync(SSLContext context, String server) throws Exception { + try (HttpClient client = HttpServerAdapters.createClientBuilderForH3() + .version(HTTP_3) + .sslContext(context) + .build()) { + try { + HttpRequest request = HttpRequest.newBuilder(new URI(server)) + .timeout(Duration.ofMillis(TIMEOUT)) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .POST(BodyPublishers.ofString("body")) + .build(); + HttpResponse response = client.sendAsync(request, BodyHandlers.ofString()).join(); + System.out.println("Received unexpected reply: " + response.statusCode()); + throw new RuntimeException("unexpected successful connection"); + } catch (CompletionException e) { + var cause = e.getCause(); + if (cause instanceof HttpTimeoutException timeout) { + System.out.println("expected exception: " + e.getCause()); + return timeout; + } else { + throw new RuntimeException("Unexpected exception received: " + e.getCause(), e); + } + } + } + } + +} diff --git a/test/jdk/java/net/httpclient/http3/H3UnsupportedSSLParametersTest.java b/test/jdk/java/net/httpclient/http3/H3UnsupportedSSLParametersTest.java new file mode 100644 index 00000000000..090dec2c189 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3UnsupportedSSLParametersTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +import java.io.UncheckedIOException; +import java.net.http.HttpClient; +import java.net.http.UnsupportedProtocolVersionException; + +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/* + * @test + * @summary Tests that a HttpClient configured with SSLParameters that doesn't include TLSv1.3 + * cannot be used for HTTP3 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @run junit H3UnsupportedSSLParametersTest + */ +public class H3UnsupportedSSLParametersTest { + + /** + * Configures a HttpClient builder to use a SSLParameter which doesn't list TLSv1.3 + * as one of the supported protocols. The method then uses this builder + * to create a HttpClient for HTTP3 and expects that build() to fail with + * UnsupportedProtocolVersionException + */ + @Test + public void testNoTLSv13() throws Exception { + final SSLParameters params = new SSLParameters(); + params.setProtocols(new String[]{"TLSv1.2"}); + final UncheckedIOException uioe = assertThrows(UncheckedIOException.class, + () -> HttpServerAdapters.createClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(HttpClient.Version.HTTP_3) + .sslParameters(params) + .build()); + assertTrue(uioe.getCause() instanceof UnsupportedProtocolVersionException, + "Unexpected cause " + uioe.getCause() + " in HttpClient build failure"); + } + + /** + * Builds a HttpClient with SSLParameters which explicitly lists TLSv1.3 as one of the supported + * protocol versions and expects the build() to succeed and return a HttpClient instance + */ + @Test + public void testExplicitTLSv13() throws Exception { + final SSLParameters params = new SSLParameters(); + params.setProtocols(new String[]{"TLSv1.2", "TLSv1.3"}); + final HttpClient client = HttpServerAdapters.createClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .sslParameters(params) + .version(HttpClient.Version.HTTP_3).build(); + assertNotNull(client, "HttpClient is null"); + } +} diff --git a/test/jdk/java/net/httpclient/http3/H3UserInfoTest.java b/test/jdk/java/net/httpclient/http3/H3UserInfoTest.java new file mode 100644 index 00000000000..15e25b865cd --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/H3UserInfoTest.java @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.Arguments; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.stream.Stream; + + +import static java.net.http.HttpOption.H3_DISCOVERY; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static org.junit.jupiter.api.Assertions.fail; + +/* + * @test + * @bug 8292876 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.test.lib.net.SimpleSSLContext + * @compile ../ReferenceTracker.java + * @run junit/othervm -Djdk.httpclient.HttpClient.log=quic,errors + * -Djdk.httpclient.http3.maxDirectConnectionTimeout=4000 + * -Djdk.internal.httpclient.debug=true H3UserInfoTest + */ + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class H3UserInfoTest implements HttpServerAdapters { + + static final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + static HttpTestServer server; + static HttpTestServer server3; + static String serverURI; + static String server3URI; + static SSLContext sslContext; + + @BeforeAll + static void before() throws Exception { + sslContext = new SimpleSSLContext().get(); + HttpTestHandler handler = new HttpHandler(); + + server = HttpTestServer.create(ANY, sslContext); + server.addHandler(handler, "/"); + serverURI = "https://" + server.serverAuthority() +"/http3-any/"; + server.start(); + + server3 = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + server3.addHandler(handler, "/"); + server3URI = "https://" + server3.serverAuthority() +"/http3-only/"; + server3.start(); + } + + @AfterAll + static void after() throws Exception { + server.stop(); + server3.stop(); + } + + static class HttpHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange e) throws IOException { + String authorityHeader = e.getRequestHeaders() + .firstValue(":authority") + .orElse(null); + if (authorityHeader == null || authorityHeader.contains("user@")) { + e.sendResponseHeaders(500, 0); + } else { + e.sendResponseHeaders(200, 0); + } + } + } + + public static Stream servers() { + return Stream.of( + Arguments.arguments(serverURI, server), + Arguments.arguments(server3URI, server3) + ); + } + + @ParameterizedTest + @MethodSource("servers") + public void testAuthorityHeader(String serverURI, HttpTestServer server) throws Exception { + try (HttpClient client = newClientBuilderForH3() + .proxy(HttpClient.Builder.NO_PROXY) + .version(HTTP_3) + .sslContext(sslContext) + .build()) { + TRACKER.track(client); + + URI origURI = URI.create(serverURI); + URI uri = URIBuilder.newBuilder() + .scheme("https") + .userInfo("user") + .host(origURI.getHost()) + .port(origURI.getPort()) + .path(origURI.getRawPath()) + .build(); + var config = server.h3DiscoveryConfig(); + + int numRetries = 0; + while (true) { + if (config == ALT_SVC) { + // send head request + System.out.printf("Sending head request (%s) to %s%n", config, origURI); + System.err.printf("Sending head request (%s) to %s%n", config, origURI); + HttpRequest head = HttpRequest.newBuilder(origURI) + .HEAD() + .version(HTTP_2) + .build(); + var headResponse = client.send(head, BodyHandlers.ofString()); + assertEquals(200, headResponse.statusCode()); + assertEquals(HTTP_2, headResponse.version()); + assertEquals("", headResponse.body()); + } + + HttpRequest request = HttpRequest + .newBuilder(uri) + .setOption(H3_DISCOVERY, config) + .version(HTTP_3) + .GET() + .build(); + + System.out.printf("Sending GET request (%s) to %s%n", config, origURI); + System.err.printf("Sending GET request (%s) to %s%n", config, origURI); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + + assertEquals(200, response.statusCode(), + "Test Failed : " + response.uri().getAuthority()); + assertEquals("", response.body()); + if (config != ANY) { + assertEquals(HTTP_3, response.version()); + } else if (response.version() != HTTP_3) { + // the request went through HTTP/2 - the next + // should go through HTTP/3 + if (numRetries++ < 3) { + System.out.printf("Received GET response (%s) to %s with version %s: " + + "repeating request once more%n", config, origURI, response.version()); + System.err.printf("Received GET response (%s) to %s with version %s: " + + "repeating request once more%n", config, origURI, response.version()); + assertEquals(HTTP_2, response.version()); + continue; + } else { + fail("Did not receive the expected HTTP3 response"); + } + } + break; + } + } + // the client should already be closed, but its facade ref might + // not have been cleared by GC yet. + System.gc(); + var error = TRACKER.checkClosed(1500); + if (error != null) throw error; + } +} diff --git a/test/jdk/java/net/httpclient/http3/HTTP3NoBodyTest.java b/test/jdk/java/net/httpclient/http3/HTTP3NoBodyTest.java new file mode 100644 index 00000000000..f04588e9362 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/HTTP3NoBodyTest.java @@ -0,0 +1,324 @@ +/* + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.http3.Http3TestServer + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @compile ../ReferenceTracker.java + * @run testng/othervm -Djdk.httpclient.HttpClient.log=ssl,requests,responses,errors + * -Djdk.internal.httpclient.debug=true + * HTTP3NoBodyTest + * @summary this is a copy of http2/NoBodyTest over HTTP/3 + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.*; +import javax.net.ssl.*; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpRequest.BodyPublishers; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.util.Random; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.RandomFactory; +import org.testng.annotations.Test; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.Http3DiscoveryMode.ALT_SVC; +import static java.net.http.HttpOption.Http3DiscoveryMode.ANY; +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +@Test +public class HTTP3NoBodyTest { + private static final Random RANDOM = RandomFactory.getRandom(); + + static int http3Port, https2Port; + static Http3TestServer http3OnlyServer; + static Http2TestServer https2AltSvcServer; + static HttpClient client = null; + static ExecutorService clientExec; + static ExecutorService serverExec; + static SSLContext sslContext; + static final String TEST_STRING = "The quick brown fox jumps over the lazy dog "; + + static volatile String http3URIString, https2URIString; + + static void initialize() throws Exception { + try { + SimpleSSLContext sslct = new SimpleSSLContext(); + sslContext = sslct.get(); + client = getClient(); + + // server that only supports HTTP/3 + http3OnlyServer = new Http3TestServer(sslContext, serverExec); + http3OnlyServer.addHandler("/", new Handler()); + http3Port = http3OnlyServer.getAddress().getPort(); + System.out.println("HTTP/3 server started at localhost:" + http3Port); + + // server that supports both HTTP/2 and HTTP/3, with HTTP/3 on an altSvc port. + https2AltSvcServer = new Http2TestServer(true, 0, serverExec, sslContext); + if (RANDOM.nextBoolean()) { + https2AltSvcServer.enableH3AltServiceOnEphemeralPort(); + } else { + https2AltSvcServer.enableH3AltServiceOnSamePort(); + } + https2AltSvcServer.addHandler(new Handler(), "/"); + https2Port = https2AltSvcServer.getAddress().getPort(); + if (https2AltSvcServer.supportsH3DirectConnection()) { + System.out.println("HTTP/2 server (same HTTP/3 origin) started at localhost:" + https2Port); + } else { + System.out.println("HTTP/2 server (different HTTP/3 origin) started at localhost:" + https2Port); + } + + http3URIString = "https://localhost:" + http3Port + "/foo/"; + https2URIString = "https://localhost:" + https2Port + "/bar/"; + + http3OnlyServer.start(); + https2AltSvcServer.start(); + } catch (Throwable e) { + System.err.println("Throwing now"); + e.printStackTrace(System.err); + throw e; + } + } + + @Test + public static void runtest() throws Exception { + try { + initialize(); + warmup(false); + warmup(true); + test(false); + test(true); + if (client != null) { + var tracker = ReferenceTracker.INSTANCE; + tracker.track(client); + client = null; + System.gc(); + var error = tracker.check(1500); + if (error != null) throw error; + } + } catch (Throwable tt) { + System.err.println("Unexpected Throwable caught"); + tt.printStackTrace(System.err); + throw tt; + } finally { + http3OnlyServer.stop(); + https2AltSvcServer.stop(); + serverExec.close(); + clientExec.close(); + } + } + + static HttpClient getClient() { + if (client == null) { + serverExec = Executors.newCachedThreadPool(); + clientExec = Executors.newCachedThreadPool(); + client = HttpServerAdapters.createClientBuilderForH3() + .executor(clientExec) + .sslContext(sslContext) + .version(HTTP_3) + .build(); + } + return client; + } + + static URI getURI(boolean altSvc) { + return getURI(altSvc, -1); + } + + static URI getURI(boolean altSvc, int step) { + return URI.create(getURIString(altSvc, step)); + } + + static String getURIString(boolean altSvc, int step) { + var uriStr = altSvc ? https2URIString : http3URIString; + return step >= 0 ? (uriStr + step) : uriStr; + } + + static void checkStatus(int expected, int found) throws Exception { + if (expected != found) { + System.err.printf ("Test failed: wrong status code %d/%d\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static void checkStrings(String expected, String found) throws Exception { + if (!expected.equals(found)) { + System.err.printf ("Test failed: wrong string %s/%s\n", + expected, found); + throw new RuntimeException("Test failed"); + } + } + + static final AtomicInteger count = new AtomicInteger(); + static Http3DiscoveryMode config(boolean http3only) { + if (http3only) return HTTP_3_URI_ONLY; + // if the server supports H3 direct connection, we can + // additionally use HTTP_3_URI_ONLY; Otherwise we can + // only use ALT_SVC - or ANY (given that we should have + // preloaded an ALT_SVC in warmup) + int bound = https2AltSvcServer.supportsH3DirectConnection() ? 4 : 3; + int rand = RANDOM.nextInt(bound); + count.getAndIncrement(); + return switch (rand) { + case 1 -> ANY; + case 2 -> ALT_SVC; + case 3 -> HTTP_3_URI_ONLY; + default -> null; + }; + } + + static final int LOOPS = 13; + + static void warmup(boolean altSvc) throws Exception { + URI uri = getURI(altSvc); + String type = altSvc ? "http2" : "http3"; + System.out.println("warmup: " + type); + System.err.println("Request to " + uri); + var http3Only = altSvc == false; + var config = config(http3Only); + + // in the warmup phase, we want to make sure + // to preload the ALT_SVC, otherwise the first + // request that uses ALT_SVC might go through HTTP/2 + if (altSvc) config = ALT_SVC; + + // Do a simple warmup request + + HttpClient client = getClient(); + var builder = HttpRequest.newBuilder(uri); + HttpRequest req = builder + .POST(BodyPublishers.ofString("Random text")) + .setOption(H3_DISCOVERY, config) + .build(); + HttpResponse response = client.send(req, BodyHandlers.ofString()); + checkStatus(200, response.statusCode()); + String responseBody = response.body(); + HttpHeaders h = response.headers(); + checkStrings(TEST_STRING + type, responseBody); + System.out.println("warmup: " + type + " done"); + System.err.println("warmup: " + type + " done"); + } + + static void test(boolean http2) throws Exception { + URI uri = getURI(http2); + String type = http2 ? "http2" : "http3"; + System.err.println("Request to " + uri); + var http3Only = http2 == false; + for (int i = 0; i < LOOPS; i++) { + var config = config(http3Only); + URI uri2 = getURI(http2, i); + HttpRequest request = HttpRequest.newBuilder(uri2) + .POST(BodyPublishers.ofString(TEST_STRING)) + .setOption(H3_DISCOVERY, config) + .build(); + System.out.println(type + ": Loop " + i + ", config: " + config + ", uri: " + uri2); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + int expectedResponse = (i % 2) == 0 ? 200 : 204; + if (response.statusCode() != expectedResponse) + throw new RuntimeException("wrong response code " + response.statusCode()); + if (expectedResponse == 200 && !response.body().equals(TEST_STRING + type)) { + System.err.printf(type + " response received/expected %s/%s\n", response.body(), TEST_STRING + type); + throw new RuntimeException("wrong response body"); + } + if (response.version() != HTTP_3) { + throw new RuntimeException("wrong response version: " + response.version()); + } + System.out.println(type + ": Loop " + i + " done"); + } + System.err.println("test: " + type + " DONE"); + } + + static URI base(URI uri) { + var uriStr = uri.toString(); + if (uriStr.startsWith(http3URIString)) { + if (uriStr.equals(http3URIString)) return uri; + return URI.create(http3URIString); + } else if (uri.toString().startsWith(https2URIString)) { + if (uriStr.equals(https2URIString)) return uri; + return URI.create(https2URIString); + } else return uri; + } + + static class Handler implements Http2Handler { + + public Handler() {} + + volatile int invocation = 0; + + @Override + public void handle(Http2TestExchange t) + throws IOException { + try { + URI uri = t.getRequestURI(); + System.err.printf("Handler received request to %s from %s\n", + uri, t.getRemoteAddress()); + String type = uri.toString().startsWith(http3URIString) + ? "http3" : "http2"; + InputStream is = t.getRequestBody(); + while (is.read() != -1); + is.close(); + + // every second response is 204. + var base = base(uri); + int step = base == uri ? 0 : Integer.parseInt(base.relativize(uri).toString()); + invocation++; + + if ((step++ % 2) == 1) { + System.err.println("Server sending 204"); + t.sendResponseHeaders(204, -1); + } else { + System.err.println("Server sending 200"); + String body = TEST_STRING + type; + t.sendResponseHeaders(200, body.length()); + OutputStream os = t.getResponseBody(); + os.write(body.getBytes()); + os.close(); + } + } catch (Throwable e) { + e.printStackTrace(System.err); + throw new IOException(e); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/Http3ExpectContinueTest.java b/test/jdk/java/net/httpclient/http3/Http3ExpectContinueTest.java new file mode 100644 index 00000000000..2052671b697 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/Http3ExpectContinueTest.java @@ -0,0 +1,250 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @summary Tests Http3 expect continue + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @compile ../ReferenceTracker.java + * @build jdk.httpclient.test.lib.common.HttpServerAdapters + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,headers + * Http3ExpectContinueTest + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.TestException; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; +import java.net.http.HttpRequest; +import java.net.http.HttpOption.Http3DiscoveryMode; +import java.net.http.HttpOption; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static java.net.http.HttpClient.Version.HTTP_3; +import static org.testng.Assert.*; + +public class Http3ExpectContinueTest implements HttpServerAdapters { + + ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + + Http3TestServer http3TestServer; + + URI h3postUri, h3forcePostUri, h3hangUri; + + static PrintStream err = new PrintStream(System.err); + static PrintStream out = new PrintStream(System.out); + static final String EXPECTATION_FAILED_417 = "417 Expectation Failed"; + static final String CONTINUE_100 = "100 Continue"; + static final String RESPONSE_BODY= "Verify response received"; + static final String BODY = "Post body"; + private SSLContext sslContext; + + @DataProvider(name = "uris") + public Object[][] urisData() { + return new Object[][]{ + // URI, Expected Status Code, Will finish with Exception + { h3postUri, 200, false }, + { h3forcePostUri, 200, false }, + { h3hangUri, 417, false }, + }; + } + + @Test(dataProvider = "uris") + public void test(URI uri, int expectedStatusCode, boolean exceptionally) + throws CancellationException, InterruptedException, ExecutionException, IOException { + + err.printf("\nTesting URI: %s, exceptionally: %b\n", uri, exceptionally); + out.printf("\nTesting URI: %s, exceptionally: %b\n", uri, exceptionally); + HttpClient client = newClientBuilderForH3(). + proxy(Builder.NO_PROXY) + .version(HTTP_3).sslContext(sslContext) + .build(); + AssertionError failed = null; + TRACKER.track(client); + try { + HttpResponse resp = null; + Throwable testThrowable = null; + + HttpRequest postRequest = HttpRequest.newBuilder(uri) + .version(HTTP_3) + .setOption(HttpOption.H3_DISCOVERY, + Http3DiscoveryMode.HTTP_3_URI_ONLY) + .POST(HttpRequest.BodyPublishers.ofString(BODY)) + .expectContinue(true) + .build(); + + err.printf("Sending request: %s%n", postRequest); + CompletableFuture> cf = client.sendAsync(postRequest, HttpResponse.BodyHandlers.ofString()); + try { + resp = cf.get(); + } catch (Exception e) { + testThrowable = e.getCause(); + } + verifyRequest(uri.getPath(), expectedStatusCode, resp, exceptionally, testThrowable); + } catch (Throwable x) { + failed = new AssertionError("Unexpected exception:" + x, x); + } finally { + client.shutdown(); + if (!client.awaitTermination(Duration.ofMillis(1000))) { + var tracker = TRACKER.getTracker(client); + client = null; + var error = TRACKER.check(tracker, 2000); + if (error != null || failed != null) { + var ex = failed == null ? error : failed; + err.printf("FAILED URI: %s, exceptionally: %b, error: %s\n", uri, exceptionally, ex); + out.printf("FAILED URI: %s, exceptionally: %b, error: %s\n", uri, exceptionally, ex); + } + if (error != null) { + if (failed != null) { + failed.addSuppressed(error); + throw failed; + } + throw error; + } + } + } + if (failed != null) { + err.printf("FAILED URI: %s, exceptionally: %b, error: %s\n", uri, exceptionally, failed); + out.printf("FAILED URI: %s, exceptionally: %b, error: %s\n", uri, exceptionally, failed); + throw failed; + } + } + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + final QuicServer quicServer = Http3TestServer.quicServerBuilder() + .sslContext(sslContext) + .build(); + http3TestServer = new Http3TestServer(quicServer); + http3TestServer.addHandler("/http3/post", new PostHandler().toHttp2Handler()); + http3TestServer.addHandler("/http3/forcePost", new ForcePostHandler().toHttp2Handler()); + http3TestServer.addHandler("/http3/hang", new PostHandlerCantContinue().toHttp2Handler()); + + h3postUri = new URI("https://" + http3TestServer.serverAuthority() + "/http3/post"); + h3forcePostUri = URI.create("https://" + http3TestServer.serverAuthority() + "/http3/forcePost"); + h3hangUri = URI.create("https://" + http3TestServer.serverAuthority() + "/http3/hang"); + out.printf("HTTP/3 server listening at: %s", http3TestServer.getAddress()); + + http3TestServer.start(); + } + + @AfterTest + public void teardown() throws IOException { + var error = TRACKER.check(500); + if (error != null) throw error; + http3TestServer.stop(); + } + + static class PostHandler implements HttpTestHandler { + + @java.lang.Override + public void handle(HttpTestExchange exchange) throws IOException { + System.out.printf("Server version %s and exchange version %s", exchange.getServerVersion(), exchange.getExchangeVersion()); + + if(exchange.getExchangeVersion().equals(HTTP_3)){ + // send 100 header + byte[] ContinueResponseBytes = CONTINUE_100.getBytes(); + err.println("Server send 100 (length="+ContinueResponseBytes.length+")"); + exchange.sendResponseHeaders(100, ContinueResponseBytes.length); + } + + // Read body from client and acknowledge with 200 + try (InputStream is = exchange.getRequestBody()) { + err.println("Server reading body"); + var bytes = is.readAllBytes(); + String responseBody = new String(bytes); + assert responseBody.equals(BODY); + byte[] responseBodyBytes = RESPONSE_BODY.getBytes(); + err.println("Server send 200 (length="+responseBodyBytes.length+")"); + exchange.sendResponseHeaders(200, responseBodyBytes.length); + exchange.getResponseBody().write(responseBodyBytes); + } + } + } + + static class ForcePostHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange exchange) throws IOException { + try (InputStream is = exchange.getRequestBody()) { + err.println("Server reading body inside the force Post"); + is.readAllBytes(); + err.println("Server send 200 (length=0) in the force post"); + exchange.sendResponseHeaders(200, 0); + } + } + } + + static class PostHandlerCantContinue implements HttpTestHandler { + @java.lang.Override + public void handle(HttpTestExchange exchange) throws IOException { + //Send 417 Headers, tell client to not send body + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = EXPECTATION_FAILED_417.getBytes(); + err.println("Server send 417 (length="+bytes.length+")"); + exchange.sendResponseHeaders(417, bytes.length); + err.println("Server sending Response Body"); + os.write(bytes); + } + } + } + + private void verifyRequest(String path, int expectedStatusCode, HttpResponse resp, boolean exceptionally, Throwable testThrowable) { + if (!exceptionally) { + err.printf("Response code %s received for path %s %n", resp.statusCode(), path); + } + if (exceptionally && testThrowable != null) { + err.println("Finished exceptionally Test throwable: " + testThrowable); + assertEquals(IOException.class, testThrowable.getClass()); + } else if (exceptionally) { + throw new TestException("Expected case to finish with an IOException but testException is null"); + } else if (resp != null) { + assertEquals(resp.statusCode(), expectedStatusCode); + err.println("Request completed successfully for path " + path); + err.println("Response Headers: " + resp.headers()); + err.println("Response Status Code: " + resp.statusCode()); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/PeerUniStreamDispatcherTest.java b/test/jdk/java/net/httpclient/http3/PeerUniStreamDispatcherTest.java new file mode 100644 index 00000000000..be1b8304bf1 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/PeerUniStreamDispatcherTest.java @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @run testng/othervm + * -Djdk.internal.httpclient.debug=out + * PeerUniStreamDispatcherTest + * @summary Unit test for the PeerUniStreamDispatcher + */ + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CopyOnWriteArrayList; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.streams.Http3Streams; +import jdk.internal.net.http.http3.streams.PeerUniStreamDispatcher; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreams; + +import org.testng.annotations.Test; +import static org.testng.Assert.*; + +public class PeerUniStreamDispatcherTest { + + final Logger debug = Utils.getDebugLogger(() -> "PeerUniStreamDispatcherStub"); + + enum DISPATCHED_STREAM { + CONTROL, ENCODER, DECODER, PUSH, RESERVED, UNKNOWN + } + + sealed interface DispatchedStream { + record StandardStream(DISPATCHED_STREAM type, String description, QuicReceiverStream stream) + implements DispatchedStream { } + record PushStream(DISPATCHED_STREAM type, String description, QuicReceiverStream stream, long pushId) + implements DispatchedStream { } + record UnknownStream(DISPATCHED_STREAM type, long code, QuicReceiverStream stream) + implements DispatchedStream { } + record ReservedStream(DISPATCHED_STREAM type, long code, QuicReceiverStream stream) + implements DispatchedStream { } + + static DispatchedStream of(DISPATCHED_STREAM type, String description, QuicReceiverStream stream) { + return new StandardStream(type, description, stream); + } + static DispatchedStream of(DISPATCHED_STREAM type, String description, QuicReceiverStream stream, long pushId) { + return new PushStream(type, description, stream, pushId); + } + static DispatchedStream reserved(DISPATCHED_STREAM type, long code, QuicReceiverStream stream) { + return new ReservedStream(type, code, stream); + } + static DispatchedStream unknown(DISPATCHED_STREAM type, long code, QuicReceiverStream stream) { + return new UnknownStream(type, code, stream); + } + } + + class PeerUniStreamDispatcherStub extends PeerUniStreamDispatcher { + + final List dispatched = new CopyOnWriteArrayList<>(); + + PeerUniStreamDispatcherStub(QuicReceiverStream stream) { + super(stream); + } + + private void dispatched(DISPATCHED_STREAM type, String description, QuicReceiverStream stream) { + dispatched.add(DispatchedStream.of(type, description, stream)); + } + + private void dispatched(DISPATCHED_STREAM type, String description, QuicReceiverStream stream, long pushId) { + dispatched.add(DispatchedStream.of(type, description, stream, pushId)); + } + + private void dispatched(DISPATCHED_STREAM type, long code, QuicReceiverStream stream) { + dispatched.add(switch (type) { + case UNKNOWN -> DispatchedStream.unknown(type, code, stream); + case RESERVED -> DispatchedStream.reserved(type, code, stream); + default -> throw new IllegalArgumentException(String.valueOf(type)); + }); + } + + @Override + protected Logger debug() { + return debug; + } + + @Override + protected void onControlStreamCreated(String description, QuicReceiverStream stream) { + dispatched(DISPATCHED_STREAM.CONTROL, description, stream); + } + + @Override + protected void onEncoderStreamCreated(String description, QuicReceiverStream stream) { + dispatched(DISPATCHED_STREAM.ENCODER, description, stream); + } + + @Override + protected void onDecoderStreamCreated(String description, QuicReceiverStream stream) { + dispatched(DISPATCHED_STREAM.DECODER, description, stream); + } + + @Override + protected void onPushStreamCreated(String description, QuicReceiverStream stream, long pushId) { + dispatched(DISPATCHED_STREAM.PUSH, description, stream, pushId); + } + + @Override + protected void onReservedStreamType(long code, QuicReceiverStream stream) { + dispatched(DISPATCHED_STREAM.RESERVED, code, stream); + super.onReservedStreamType(code, stream); + } + + @Override + protected void onUnknownStreamType(long code, QuicReceiverStream stream) { + dispatched(DISPATCHED_STREAM.UNKNOWN, code, stream); + super.onUnknownStreamType(code, stream); + } + + @Override + public void start() { + super.start(); + } + } + + static class QuicReceiverStreamStub implements QuicReceiverStream { + + class QuicStreamReaderStub extends QuicStreamReader { + + volatile boolean connected, started; + QuicStreamReaderStub(SequentialScheduler scheduler) { + super(scheduler); + } + + @Override + public ReceivingStreamState receivingState() { + return QuicReceiverStreamStub.this.receivingState(); + } + + @Override + public ByteBuffer poll() throws IOException { + return buffers.poll(); + } + + @Override + public ByteBuffer peek() throws IOException { + return buffers.peek(); + } + + @Override + public QuicReceiverStream stream() { + return QuicReceiverStreamStub.this; + } + + @Override + public boolean connected() { + return connected; + } + + @Override + public boolean started() { + return started; + } + + @Override + public void start() { + started = true; + if (!buffers.isEmpty()) scheduler.runOrSchedule(); + } + } + + volatile QuicStreamReaderStub reader; + volatile SequentialScheduler scheduler; + volatile long errorCode; + final long streamId; + ConcurrentLinkedQueue buffers = new ConcurrentLinkedQueue<>(); + + QuicReceiverStreamStub(long streamId) { + this.streamId = streamId; + } + + + @Override + public ReceivingStreamState receivingState() { + return ReceivingStreamState.RECV; + } + + @Override + public QuicStreamReader connectReader(SequentialScheduler scheduler) { + this.scheduler = scheduler; + var reader = this.reader + = new QuicStreamReaderStub(scheduler); + reader.connected = true; + return reader; + } + + @Override + public void disconnectReader(QuicStreamReader reader) { + this.scheduler = null; + this.reader = null; + ((QuicStreamReaderStub) reader).connected = false; + } + + @Override + public void requestStopSending(long errorCode) { + this.errorCode = errorCode; + } + + @Override + public long dataReceived() { + return 0; + } + + @Override + public long maxStreamData() { + return 0; + } + + @Override + public long rcvErrorCode() { + return errorCode; + } + + @Override + public long streamId() { + return streamId; + } + + @Override + public StreamMode mode() { + return StreamMode.READ_ONLY; + } + + @Override + public boolean isClientInitiated() { + return QuicStreams.isClientInitiated(streamId); + } + + @Override + public boolean isServerInitiated() { + return QuicStreams.isServerInitiated(streamId); + } + + @Override + public boolean isBidirectional() { + return QuicStreams.isBidirectional(streamId); + } + + @Override + public boolean isLocalInitiated() { + return isClientInitiated(); + } + + @Override + public boolean isRemoteInitiated() { + return !isClientInitiated(); + } + + @Override + public int type() { + return QuicStreams.streamType(streamId); + } + + @Override + public StreamState state() { + return ReceivingStreamState.RECV; + } + + } + + private void simpleStreamType(DISPATCHED_STREAM type, long code) { + System.out.println("Testing " + type + " with " + code); + QuicReceiverStreamStub stream = new QuicReceiverStreamStub(QuicStreams.UNI_MASK + QuicStreams.SRV_MASK); + PeerUniStreamDispatcherStub dispatcher = new PeerUniStreamDispatcherStub(stream); + QuicStreamReader reader = stream.reader; + SequentialScheduler scheduler = stream.scheduler; + assertTrue(reader.connected()); + int size = VariableLengthEncoder.getEncodedSize(code); + ByteBuffer buffer = ByteBuffer.allocate(size); + assertEquals(buffer.remaining(), size); + VariableLengthEncoder.encode(buffer, code); + buffer.flip(); + stream.buffers.add(buffer); + scheduler.runOrSchedule(); + dispatcher.start(); + if (type == DISPATCHED_STREAM.PUSH) { + // we want to encode the pushId in multiple buffers, but call + // the scheduler only once to check that the dispatcher + // will loop correctly. + size = VariableLengthEncoder.getEncodedSize(1L << 62 - 5); + ByteBuffer buffer2 = ByteBuffer.allocate(size); + assertEquals(buffer2.remaining(), size); + VariableLengthEncoder.encode(buffer2, 1L << 62 - 5); + buffer2.flip(); + stream.buffers.add(ByteBuffer.wrap(new byte[] {buffer2.get()})); + scheduler.runOrSchedule(); // call runOrSchedule after supplying the first byte. + assertTrue(reader.connected()); + assert buffer2.remaining() > 1; // should always be true + while (buffer2.hasRemaining()) { + stream.buffers.add(ByteBuffer.wrap(new byte[] {buffer2.get()})); + } + } + scheduler.runOrSchedule(); + assertFalse(reader.connected()); + assertFalse(dispatcher.dispatched.isEmpty()); + assertTrue(stream.buffers.isEmpty()); + assertEquals(dispatcher.dispatched.size(), 1); + var dispatched = dispatcher.dispatched.get(0); + checkDispatched(type, code, stream, dispatched); + } + + private void checkDispatched(DISPATCHED_STREAM type, + long code, + QuicReceiverStream stream, + DispatchedStream dispatched) { + var streamClass = switch (type) { + case CONTROL, ENCODER, DECODER -> DispatchedStream.StandardStream.class; + case PUSH -> DispatchedStream.PushStream.class; + case RESERVED -> DispatchedStream.ReservedStream.class; + case UNKNOWN -> DispatchedStream.UnknownStream.class; + }; + assertEquals(dispatched.getClass(), streamClass, + "unexpected dispatched class " + dispatched + " for " + type); + if (dispatched instanceof DispatchedStream.StandardStream st) { + System.out.println("Got expected stream: " + st); + assertEquals(st.type(), type); + assertEquals(st.stream, stream); + } else if (dispatched instanceof DispatchedStream.ReservedStream res) { + System.out.println("Got expected stream: " + res); + assertEquals(res.type(), type); + assertEquals(res.stream, stream); + assertEquals(res.code(), code); + assertTrue(Http3Streams.isReserved(res.code())); + } else if (dispatched instanceof DispatchedStream.UnknownStream unk) { + System.out.println("Got expected stream: " + unk); + assertEquals(unk.type(), type); + assertEquals(unk.stream, stream); + assertEquals(unk.code(), code); + assertFalse(Http3Streams.isReserved(unk.code())); + } else if (dispatched instanceof DispatchedStream.PushStream push) { + System.out.println("Got expected stream: " + push); + assertEquals(push.type(), type); + assertEquals(push.stream, stream); + assertEquals(push.pushId, 1L << 62 - 5); + assertEquals(push.type(), DISPATCHED_STREAM.PUSH); + } + + } + @Test + public void simpleControl() { + simpleStreamType(DISPATCHED_STREAM.CONTROL, Http3Streams.CONTROL_STREAM_CODE); + } + @Test + public void simpleDecoder() { + simpleStreamType(DISPATCHED_STREAM.DECODER, Http3Streams.QPACK_DECODER_STREAM_CODE); + } + @Test + public void simpleEncoder() { + simpleStreamType(DISPATCHED_STREAM.ENCODER, Http3Streams.QPACK_ENCODER_STREAM_CODE); + } + @Test + public void simplePush() { + simpleStreamType(DISPATCHED_STREAM.PUSH, Http3Streams.PUSH_STREAM_CODE); + } + @Test + public void simpleUknown() { + simpleStreamType(DISPATCHED_STREAM.UNKNOWN, VariableLengthEncoder.MAX_ENCODED_INTEGER); + } + @Test + public void simpleReserved() { + simpleStreamType(DISPATCHED_STREAM.RESERVED, 31 * 256 + 2); + } + + @Test + public void multyBytes() { + DISPATCHED_STREAM type = DISPATCHED_STREAM.UNKNOWN; + long code = VariableLengthEncoder.MAX_ENCODED_INTEGER; + System.out.println("Testing multi byte " + type + " with " + code); + QuicReceiverStreamStub stream = new QuicReceiverStreamStub(QuicStreams.UNI_MASK + QuicStreams.SRV_MASK); + PeerUniStreamDispatcherStub dispatcher = new PeerUniStreamDispatcherStub(stream); + QuicStreamReader reader = stream.reader; + SequentialScheduler scheduler = stream.scheduler; + assertTrue(reader.connected()); + int size = VariableLengthEncoder.getEncodedSize(code); + assertEquals(size, 8); + ByteBuffer buffer = ByteBuffer.allocate(size); + assertEquals(buffer.remaining(), size); + VariableLengthEncoder.encode(buffer, code); + buffer.flip(); + dispatcher.start(); + for (int i=0; i FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong serverCount = new AtomicLong(); + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + final ReferenceTracker TRACKER = ReferenceTracker.INSTANCE; + final Set sharedClientHasH3 = ConcurrentHashMap.newKeySet(); + private volatile HttpClient sharedClient; + private boolean directQuicConnectionSupported; + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @java.lang.Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + System.out.printf(now() + "Task %s failed: %s%n", id, t); + System.err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + protected boolean stopAfterFirstFailure() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); + } + + @BeforeMethod + void beforeMethod(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + var x = new SkipException("Skipping: some test failed"); + x.setStackTrace(new StackTraceElement[0]); + throw x; + } + } + + @AfterClass + static void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d servers and %d clients%n", + now(), serverCount.get(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.forEach((key, value) -> { + out.printf("\t%s: %s%n", key, value); + value.printStackTrace(out); + value.printStackTrace(); + }); + if (tasksFailed) { + System.out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + private String[] uris() { + return new String[] { + h3URI, + }; + } + + @DataProvider(name = "variants") + public Object[][] variants(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + String[] uris = uris(); + Object[][] result = new Object[uris.length * 2 * 2 * 2][]; + int i = 0; + for (var version : List.of(Optional.empty(), Optional.of(HTTP_3))) { + for (Version firstRequestVersion : List.of(HTTP_2, HTTP_3)) { + for (boolean sameClient : List.of(false, true)) { + for (String uri : uris()) { + result[i++] = new Object[]{uri, firstRequestVersion, sameClient, version}; + } + } + } + } + assert i == result.length; + return result; + } + + @DataProvider(name = "uris") + public Object[][] uris(ITestContext context) { + if (stopAfterFirstFailure() && context.getFailedTests().size() > 0) { + return new Object[0][]; + } + Object[][] result = {{h3URI}}; + return result; + } + + private HttpClient makeNewClient() { + clientCount.incrementAndGet(); + HttpClient client = newClientBuilderForH3() + .version(HTTP_3) + .proxy(HttpClient.Builder.NO_PROXY) + .executor(executor) + .sslContext(sslContext) + .connectTimeout(Duration.ofSeconds(10)) + .build(); + return TRACKER.track(client); + } + + HttpClient newHttpClient(boolean share) { + if (!share) return makeNewClient(); + HttpClient shared = sharedClient; + if (shared != null) return shared; + synchronized (this) { + shared = sharedClient; + if (shared == null) { + shared = sharedClient = makeNewClient(); + } + return shared; + } + } + + BodyPublisher oflines(boolean streaming, String ...lines) { + if (streaming) { + return BodyPublishers.fromPublisher(BodyPublishers.concat( + Stream.of(lines) + .map(s -> s + '\n') + .map(BodyPublishers::ofString) + .toArray(BodyPublisher[]::new))); + } else { + return BodyPublishers.fromPublisher(BodyPublishers.concat( + Stream.of(lines) + .map(s -> s + '\n') + .map(BodyPublishers::ofString) + .toArray(BodyPublisher[]::new)), + Stream.of(lines).mapToLong(String::length) + .map((l) -> l+1) + .sum()); + } + } + + + @Test(dataProvider = "variants") + public void testAsync(String uri, Version firstRequestVersion, boolean sameClient, Optional version) throws Exception { + System.out.println("Request to " + uri +"/Async/*" + + ", firstRequestVersion=" + firstRequestVersion + + ", sameclient=" + sameClient + ",version=" + version); + + HttpClient client = newHttpClient(sameClient); + final URI headURI = URI.create(uri + "/Async/First/HEAD"); + final Builder headBuilder = HttpRequest.newBuilder(headURI) + .version(firstRequestVersion) + .HEAD(); + Http3DiscoveryMode config = null; + if (firstRequestVersion == HTTP_3 && !directQuicConnectionSupported) { + // if the server doesn't listen for HTTP/3 on the same port than TCP, then + // do not attempt to connect to the URI host:port through UDP - as we might + // be connecting to some other server. Once the first request has gone + // through, there should be an AltService record for the server, so + // we should be able to safely use any default config (except + // HTTP_3_URI_ONLY) + config = ALT_SVC; + } + if (config != null) { + out.println("first request will use " + config); + headBuilder.setOption(H3_DISCOVERY, config); + config = null; + } + + HttpResponse response1 = client.send(headBuilder.build(), BodyHandlers.ofString()); + assertEquals(response1.statusCode(), 200, "Unexpected first response code"); + assertEquals(response1.body(), "", "Unexpected first response body"); + boolean expectH3 = sameClient && sharedClientHasH3.contains(headURI.getRawAuthority()); + if (firstRequestVersion == HTTP_3) { + if (expectH3) { + out.println("Expecting HEAD response over HTTP_3"); + assertEquals(response1.version(), HTTP_3, "Unexpected first response version"); + } + } else { + out.println("Expecting HEAD response over HTTP_2"); + assertEquals(response1.version(), HTTP_2, "Unexpected first response version"); + } + out.println("HEAD response version: " + response1.version()); + if (response1.version() == HTTP_2) { + if (sameClient) { + sharedClientHasH3.add(headURI.getRawAuthority()); + } + expectH3 = version.isEmpty() && client.version() == HTTP_3; + if (version.orElse(null) == HTTP_3 && !directQuicConnectionSupported) { + config = ALT_SVC; + expectH3 = true; + } + // we can expect H3 only if the (default) config is not ANY + if (expectH3) { + out.println("first response came over HTTP/2, so we should expect all responses over HTTP/3"); + } + } else if (response1.version() == HTTP_3) { + expectH3 = directQuicConnectionSupported && version.orElse(null) == HTTP_3; + if (expectH3) { + out.println("first response came over HTTP/3, direct connection supported: expect HTTP/3"); + } else if (firstRequestVersion == HTTP_3 && version.isEmpty() + && config == null && directQuicConnectionSupported) { + config = ANY; + expectH3 = true; + } + } + out.printf("request version: %s, directConnectionSupported: %s, first response: %s," + + " config: %s, expectH3: %s%n", + version, directQuicConnectionSupported, response1.version(), config, expectH3); + if (expectH3) { + out.println("All responses should now come through HTTP/3"); + } + + Builder builder = HttpRequest.newBuilder(); + version.ifPresent(builder::version); + if (config != null) { + builder.setOption(H3_DISCOVERY, config); + } + Map>> responses = new HashMap<>(); + boolean streaming = false; + int h3Count = 0; + for (int i = 0; i < ITERATION_COUNT; i++) { + streaming = !streaming; + HttpRequest request = builder.uri(URI.create(uri+"/Async/"+i)) + .POST(oflines(streaming, BODY.split("\n"))) + .build(); + System.out.println("Iteration: " + request.uri()); + responses.put(request.uri(), client.sendAsync(request, BodyHandlers.ofString())); + } + while (!responses.isEmpty()) { + CompletableFuture.anyOf(responses.values().toArray(CompletableFuture[]::new)).join(); + var done = responses.entrySet().stream() + .filter((e) -> e.getValue().isDone()).toList(); + for (var e : done) { + URI u = e.getKey(); + responses.remove(u); + out.println("Checking response: " + u); + var response = e.getValue().get(); + out.println("Response is: " + response + ", [version: " + response.version() + "]"); + assertEquals(response.statusCode(), 200,"status for " + u); + assertEquals(response.body(), BODY,"body for " + u); + if (expectH3) { + assertEquals(response.version(), HTTP_3, "version for " + u); + } + if (response.version() == HTTP_3) { + h3Count++; + } + } + } + if (client.version() == HTTP_3 || version.orElse(null) == HTTP_3) { + if (h3Count == 0) { + throw new AssertionError("No request used HTTP/3"); + } + } + if (!sameClient) { + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + System.out.println("test: DONE"); + } + + @Test(dataProvider = "uris") + public void testSync(String h3URI) throws Exception { + HttpClient client = makeNewClient(); + Builder builder = HttpRequest.newBuilder(URI.create(h3URI + "/Sync/1")) + .version(HTTP_3); + if (!directQuicConnectionSupported) { + // if the server doesn't listen for HTTP/3 on the same port than TCP, then + // do not attempt to connect to the URI host:port through UDP - as we might + // be connecting to some other server. Once the first request has gone + // through, there should be an AltService record for the server, so + // we should be able to safely use any default config (except + // HTTP_3_URI_ONLY) + builder.setOption(H3_DISCOVERY, ALT_SVC); + } + + HttpRequest request = builder + .POST(oflines(true, BODY.split("\n"))) + .build(); + HttpResponse response = client.send(request, BodyHandlers.ofString()); + out.println("Response #1: " + response); + out.println("Version #1: " + response.version()); + out.println("Body #1:\n" + response.body().indent(4)); + assertEquals(response.statusCode(), 200, "first response status"); + if (directQuicConnectionSupported) { + // TODO unreliable assertion + //assertEquals(response.version(), HTTP_3, "Unexpected first response version"); + } else { + assertEquals(response.version(), HTTP_2, "Unexpected first response version"); + } + assertEquals(response.body(), BODY, "first response body"); + + request = builder.uri(URI.create(h3URI + "/Sync/2")) + .POST(oflines(true, BODY.split("\n"))) + .build(); + response = client.send(request, BodyHandlers.ofString()); + out.println("Response #2: " + response); + out.println("Version #2: " + response.version()); + out.println("Body #2:\n" + response.body().indent(4)); + assertEquals(response.statusCode(), 200, "second response status"); + assertEquals(response.version(), HTTP_3, "second response version"); + assertEquals(response.body(), BODY, "second response body"); + + request = builder.uri(URI.create(h3URI + "/Sync/3")) + .POST(oflines(true, BODY.split("\n"))) + .build(); + response = client.send(request, BodyHandlers.ofString()); + out.println("Response #3: " + response); + out.println("Version #3: " + response.version()); + out.println("Body #3:\n" + response.body().indent(4)); + assertEquals(response.statusCode(), 200, "third response status"); + assertEquals(response.version(), HTTP_3, "third response version"); + assertEquals(response.body(), BODY, "third response body"); + + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + AssertionError error = TRACKER.check(tracker, 1000); + if (error != null) throw error; + } + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + final Http2TestServer h2WithAltService = new Http2TestServer("localhost", true, sslContext) + .enableH3AltServiceOnSamePort(); + h3TestServer = HttpTestServer.of(h2WithAltService); + h3TestServer.addHandler(new Handler(), "/h3/testH3/"); + h3URI = "https://" + h3TestServer.serverAuthority() + "/h3/testH3/POST"; + + serverCount.addAndGet(1); + h3TestServer.start(); + directQuicConnectionSupported = h2WithAltService.supportsH3DirectConnection(); + } + + @AfterTest + public void teardown() throws Exception { + System.err.println("======================================================="); + System.err.println(" Tearing down test"); + System.err.println("======================================================="); + String sharedClientName = + sharedClient == null ? null : sharedClient.toString(); + sharedClient = null; + Thread.sleep(100); + AssertionError fail = TRACKER.check(500); + try { + h3TestServer.stop(); + } finally { + if (fail != null) { + if (sharedClientName != null) { + System.err.println("Shared client name is: " + sharedClientName); + } + throw fail; + } + } + } + + static class Handler implements HttpTestHandler { + public Handler() {} + + volatile int invocation = 0; + + @java.lang.Override + public void handle(HttpTestExchange t) + throws IOException { + try { + URI uri = t.getRequestURI(); + System.err.printf("Handler received request for %s\n", uri); + + boolean head = "HEAD".equals(t.getRequestMethod()); + if ((invocation++ % 2) == 1 && !head) { + System.err.printf("Server sending %d - chunked\n", 200); + t.sendResponseHeaders(200, -1); + } else { + System.err.printf("Server sending %d - %s length\n", 200, BODY.length()); + t.sendResponseHeaders(200, BODY.length()); + } + try (InputStream is = t.getRequestBody(); + OutputStream os = t.getResponseBody()) { + BufferedReader reader = new BufferedReader(new InputStreamReader(is)); + reader.lines().forEach((line) -> { + try { + if (head) return; + os.write(line.getBytes(StandardCharsets.UTF_8)); + os.write('\n'); + os.flush(); + } catch (IOException io) { + throw new UncheckedIOException(io); + } + }); + } + } catch (Throwable e) { + e.printStackTrace(System.err); + throw new IOException(e); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/StopSendingTest.java b/test/jdk/java/net/httpclient/http3/StopSendingTest.java new file mode 100644 index 00000000000..2a029e176a2 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/StopSendingTest.java @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.internal.net.http.ResponseSubscribers; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.test.lib.net.URIBuilder; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY; +import static java.net.http.HttpOption.H3_DISCOVERY; + +/* + * @test + * @summary Exercises the HTTP3 client to send a STOP_SENDING frame + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * @compile ../ReferenceTracker.java + * @run testng/othervm -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=requests,responses,errors StopSendingTest + */ +public class StopSendingTest implements HttpServerAdapters { + + private SSLContext sslContext; + private HttpTestServer h3Server; + private String requestURIBase; + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + h3Server = HttpTestServer.create(HTTP_3_URI_ONLY, sslContext); + h3Server.addHandler(new Handler(), "/hello"); + h3Server.start(); + System.out.println("Server started at " + h3Server.getAddress()); + requestURIBase = URIBuilder.newBuilder().scheme("https").loopback() + .port(h3Server.getAddress().getPort()).build().toString(); + + } + + @AfterClass + public void afterClass() throws Exception { + if (h3Server != null) { + System.out.println("Stopping server " + h3Server.getAddress()); + h3Server.stop(); + } + } + + private static final class Handler implements HttpTestHandler { + private static final byte[] DUMMY_BODY = "foo bar hello world".getBytes(StandardCharsets.UTF_8); + private static volatile boolean stop; + private static final CountDownLatch stopped = new CountDownLatch(1); + + /** + * Keeps writing out response data (bytes) until asked to stop + */ + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + System.out.println("Handling request: " + exchange.getRequestURI()); + exchange.sendResponseHeaders(200, -1); + try (final OutputStream os = exchange.getResponseBody()) { + while (!stop) { + os.write(DUMMY_BODY); + os.flush(); + System.out.println("Wrote response data of size " + DUMMY_BODY.length); + try { + Thread.sleep(5); + } catch (InterruptedException e) { + // ignore + } + } + System.out.println("Stopped writing response"); + } catch (IOException io) { + System.out.println("Got expected exception: " + io); + } finally { + stopped.countDown(); + } + } + } + + /** + * Issues a HTTP3 request to a server handler which keeps sending data. When some amount of + * data is received on the client side, the request is cancelled by the test method. This + * internally is expected to trigger a STOP_SENDING frame from the HTTP client to the server. + */ + @Test + public void testStopSending() throws Exception { + HttpClient client = newClientBuilderForH3() + .version(Version.HTTP_3) + .sslContext(sslContext).build(); + final URI reqURI = new URI(requestURIBase + "/hello"); + final HttpRequest req = HttpRequest.newBuilder(reqURI) + .version(Version.HTTP_3) + .setOption(H3_DISCOVERY, HTTP_3_URI_ONLY) + .build(); + // used to wait and trigger a request cancellation + final CountDownLatch cancellationTrigger = new CountDownLatch(1); + System.out.println("Issuing request to " + reqURI); + final CompletableFuture> futureResp = client.sendAsync(req, + BodyHandlers.fromSubscriber(new CustomBodySubscriber(cancellationTrigger))); + // wait for the subscriber to receive some amount of response data before we trigger + // the request cancellation + System.out.println("Awaiting some response data to arrive"); + cancellationTrigger.await(); + System.out.println("Cancelling request"); + // cancel the request which will internal trigger a STOP_SENDING frame from the HTTP + // client to the server + final boolean cancelled = futureResp.cancel(true); + System.out.println("Cancelled request: " + cancelled); + try { + // we expect a CancellationException for a cancelled request, + // but due to a bug (race condition) in the HttpClient's implementation + // of the Future instance, sometimes the Future.cancel(true) results + // in an ExecutionException which wraps the CancellationException. + // TODO: fix the actual race condition and then expect only CancellationException here + final Exception actualException = Assert.expectThrows(Exception.class, futureResp::get); + if (actualException instanceof CancellationException) { + // expected + System.out.println("Received the expected CancellationException"); + } else if (actualException instanceof ExecutionException + && actualException.getCause() instanceof CancellationException) { + System.out.println("Received CancellationException wrapped as ExecutionException"); + } else { + // unexpected + throw actualException; + } + } catch (Exception | Error e) { + Handler.stop = true; + System.err.println("Unexpected exception: " + e); + e.printStackTrace(); + throw e; + } finally { + // wait until the handler stops sending + Handler.stopped.await(10, TimeUnit.SECONDS); + } + var TRACKER = ReferenceTracker.INSTANCE; + var tracker = TRACKER.getTracker(client); + client = null; + System.gc(); + var error = TRACKER.check(tracker,1000); + if (error != null) throw error; + } + + /** + * A {@link java.net.http.HttpResponse.BodySubscriber} which informs any interested parties + * whenever it receives any data in {@link #onNext(List)} + */ + private static final class CustomBodySubscriber extends ResponseSubscribers.ByteArraySubscriber { + // the latch used to inform any interested parties about data being received + private final CountDownLatch latch; + + private CustomBodySubscriber(final CountDownLatch latch) { + // a finisher which just returns the bytes back + super((bytes) -> bytes); + this.latch = latch; + } + + @Override + public void onNext(final List items) { + super.onNext(items); + long totalSize = 0; + for (final ByteBuffer bb : items) { + totalSize += bb.remaining(); + } + System.out.println("Subscriber got response data of size " + totalSize); + // inform interested party that we received some data + latch.countDown(); + } + } +} diff --git a/test/jdk/java/net/httpclient/http3/StreamLimitTest.java b/test/jdk/java/net/httpclient/http3/StreamLimitTest.java new file mode 100644 index 00000000000..d8a920c8544 --- /dev/null +++ b/test/jdk/java/net/httpclient/http3/StreamLimitTest.java @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.OutputStream; +import java.net.SocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestExchange; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestHandler; +import jdk.httpclient.test.lib.common.HttpServerAdapters.HttpTestServer; +import jdk.httpclient.test.lib.http3.Http3TestServer; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.internal.net.http.quic.QuicTransportParameters; +import jdk.internal.net.http.quic.QuicTransportParameters.ParameterId; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import static java.net.http.HttpClient.Builder.NO_PROXY; +import static java.net.http.HttpClient.Version.HTTP_3; +import static java.net.http.HttpOption.H3_DISCOVERY; + +/* + * @test + * @summary verifies that when the Quic stream limit is reached + * then HTTP3 requests are retried on newer connection + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext + * jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.http3.Http3TestServer + * @run testng/othervm -Djdk.internal.httpclient.debug=true StreamLimitTest + */ +public class StreamLimitTest { + + private SSLContext sslContext; + private HttpTestServer server; + private QuicServer quicServer; + private URI requestURI; + private volatile QuicServerConnection latestServerConn; + + private final class Handler implements HttpTestHandler { + + @Override + public void handle(HttpTestExchange exchange) throws IOException { + final String handledBy = latestServerConn.logTag(); + System.out.println(handledBy + " handling request " + exchange.getRequestURI()); + final byte[] respBody; + if (handledBy == null) { + respBody = new byte[0]; + } else { + respBody = handledBy.getBytes(StandardCharsets.UTF_8); + } + exchange.sendResponseHeaders(200, respBody.length == 0 ? -1 : respBody.length); + // write out the server's connection id as a response + if (respBody.length > 0) { + try (OutputStream os = exchange.getResponseBody()) { + os.write(respBody); + } + } + exchange.close(); + } + } + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + quicServer = Http3TestServer.quicServerBuilder().sslContext(sslContext).build(); + final Http3TestServer h3Server = new Http3TestServer(quicServer) { + @Override + public boolean acceptIncoming(SocketAddress source, QuicServerConnection quicConn) { + final boolean accepted = super.acceptIncoming(source, quicConn); + if (accepted) { + // keep track of the latest server connection + latestServerConn = quicConn; + } + return accepted; + } + }; + server = HttpTestServer.of(h3Server); + server.addHandler(new Handler(), "/foo"); + server.start(); + System.out.println("Server started at " + server.serverAuthority()); + requestURI = new URI("https://" + server.serverAuthority() + "/foo"); + } + + @AfterClass + public void afterClass() throws Exception { + latestServerConn = null; + if (server != null) { + server.stop(); + } + } + + /** + * Configures different limits for max bidi stream creation by HTTP client and then verifies + * the expected behaviour by sending HTTP3 requests + */ + @Test + public void testBidiMaxStreamLimit() throws Exception { + final QuicTransportParameters transportParameters = new QuicTransportParameters(); + final int intialMaxStreamLimit = 3; + final AtomicInteger maxStreamLimit = new AtomicInteger(intialMaxStreamLimit); + transportParameters.setIntParameter(ParameterId.initial_max_streams_bidi, + maxStreamLimit.get()); + // set the limit so that any new server connections created with advertise this limit + // to the peer (client) + quicServer.setTransportParameters(transportParameters); + // also set a MAX_STREAMS limit computer for this server so that the created + // connections use this computer for deciding MAX_STREAMS limit + quicServer.setMaxStreamLimitComputer((ignore) -> maxStreamLimit.longValue()); + final HttpClient client = HttpServerAdapters.createClientBuilderForH3() + .proxy(NO_PROXY) + .sslContext(sslContext) + .build(); + final HttpRequest req = HttpRequest.newBuilder().version(HTTP_3) + .GET().uri(requestURI) + .setOption(H3_DISCOVERY, server.h3DiscoveryConfig()) + .build(); + String requestHandledBy = null; + System.out.println("Server has been configured with a limit for max" + + " bidi streams: " + intialMaxStreamLimit); + // issue N number of requests where N == the max bidi stream creation limit that is + // advertised to the peer. All these N requests are expected to be handled by the same + // server connection + for (int i = 1; i <= intialMaxStreamLimit; i++) { + System.out.println("Sending request " + i + " to " + requestURI); + final HttpResponse resp = client.send(req, + HttpResponse.BodyHandlers.ofString()); + Assert.assertEquals(resp.version(), HTTP_3, "Unexpected response version"); + Assert.assertEquals(resp.statusCode(), 200, "Unexpected response code"); + final String respBody = resp.body(); + System.out.println("Request " + i + " was handled by server connection: " + respBody); + if (i == 1) { + // first request; keep track the server connection id which responded + // to this request + requestHandledBy = respBody; + } else { + Assert.assertEquals(respBody, requestHandledBy, "Request was handled by an" + + " unexpected server connection"); + } + } + // at this point the limit for bidi stream creation has reached on the client. + // now issue a request so that: + // - the HTTP client implementation notices that it has hit the limit + // - HTTP client sends a STREAMS_BLOCKED frame and waits for a while (timeout derived + // internally based on request timeout) for server to increase the limit. But this server + // connection will not send any MAX_STREAMS frame upon receipt of STREAMS_BLOCKED frame + // - client notices that server connection hasn't increased the stream limit, so internally + // retries the request, which should trigger a new server connection and thus this + // request should be handled by a different server connection than the last N requests + final HttpRequest reqWithTimeout = HttpRequest.newBuilder().version(HTTP_3) + .GET().uri(requestURI) + .setOption(H3_DISCOVERY, server.h3DiscoveryConfig()) + .timeout(Duration.of(10, ChronoUnit.SECONDS)) + .build(); + for (int i = 1; i <= intialMaxStreamLimit; i++) { + System.out.println("Sending request " + i + " (configured with timeout) to " + + requestURI); + final HttpResponse resp = client.send(reqWithTimeout, + HttpResponse.BodyHandlers.ofString()); + Assert.assertEquals(resp.version(), HTTP_3, "Unexpected response version"); + Assert.assertEquals(resp.statusCode(), 200, "Unexpected response code"); + final String respBody = resp.body(); + System.out.println("Request " + i + " was handled by server connection: " + respBody); + if (i == 1) { + // first request after the limit was hit. + // verify that it was handled by a new connection and not the one that handled + // the previous N requests + Assert.assertNotEquals(respBody, requestHandledBy, "Request was expected to be" + + " handled by a new server connection, but wasn't"); + // keep track this new server connection id which responded to this request + requestHandledBy = respBody; + } else { + Assert.assertEquals(respBody, requestHandledBy, "Request was handled by an" + + " unexpected server connection"); + } + } + // at this point the limit for bidi stream creation has reached on the client, for + // this new server connection. we now configure this current server connection to + // increment the limit to new higher limit + maxStreamLimit.set(intialMaxStreamLimit + 2); + System.out.println("Server connection " + latestServerConn + " has now been configured to" + + " increment the bidi stream creation limit to " + maxStreamLimit.get()); + // we now issue new requests, with timeout specified + // - the HTTP client implementation notices that it has hit the limit + // - HTTP client sends a STREAMS_BLOCKED frame and waits for a while (timeout derived + // internally based on request timeout) for server to increase the limit. This server + // connection, since it is configured to increment the stream limit, will send + // MAX_STREAMS frame upon receipt of STREAMS_BLOCKED frame + // - client receives the MAX_STREAMS frame (hopefully within the timeout) and + // notices that server connection has increased the stream limit, so opens a new bidi + // stream and lets the request move forward (on the same server connection) + final int numNewRequests = maxStreamLimit.get() - intialMaxStreamLimit; + for (int i = 1; i <= numNewRequests; i++) { + System.out.println("Sending request " + i + " (after stream limit has been increased)" + + " to " + requestURI); + final HttpResponse resp = client.send(reqWithTimeout, + HttpResponse.BodyHandlers.ofString()); + Assert.assertEquals(resp.version(), HTTP_3, "Unexpected response version"); + Assert.assertEquals(resp.statusCode(), 200, "Unexpected response code"); + final String respBody = resp.body(); + System.out.println("Request " + i + " was handled by server connection: " + respBody); + // all these requests should be handled by the same server connection which handled + // the previous requests + Assert.assertEquals(respBody, requestHandledBy, "Request was handled by an" + + " unexpected server connection"); + } + // at this point the newer limit for bidi stream creation has reached on the client. + // we now issue a new request without any timeout configured on the request, so that the + // client internally just immediately retries and uses a different connection on noticing + // that the stream creation limit for the current server connection has been reached. + System.out.println("Server connection " + latestServerConn + " has now been configured to" + + " not increase max stream limit for bidi streams created by the client"); + final HttpRequest finalReq = HttpRequest.newBuilder() + .version(HTTP_3) + .setOption(H3_DISCOVERY, server.h3DiscoveryConfig()) + .GET().uri(requestURI) + .build(); + System.out.println("Sending request, without timeout, to " + requestURI); + final HttpResponse finalResp = client.send(finalReq, + HttpResponse.BodyHandlers.ofString()); + Assert.assertEquals(finalResp.version(), HTTP_3, "Unexpected response version"); + Assert.assertEquals(finalResp.statusCode(), 200, "Unexpected response code"); + final String finalRespBody = finalResp.body(); + System.out.println("Request was handled by server connection: " + finalRespBody); + // this request should have been handled by a new server connection + Assert.assertNotEquals(finalRespBody, requestHandledBy, "Request was handled by an" + + " unexpected server connection"); + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/DynamicKeyStoreUtil.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/DynamicKeyStoreUtil.java new file mode 100644 index 00000000000..5b96d4c0b76 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/DynamicKeyStoreUtil.java @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.httpclient.test.lib.common; + +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.util.Date; +import java.util.Objects; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; + +import sun.security.x509.CertificateExtensions; +import sun.security.x509.CertificateSerialNumber; +import sun.security.x509.CertificateValidity; +import sun.security.x509.CertificateVersion; +import sun.security.x509.CertificateX509Key; +import sun.security.x509.DNSName; +import sun.security.x509.GeneralName; +import sun.security.x509.GeneralNames; +import sun.security.x509.SubjectAlternativeNameExtension; +import sun.security.x509.X500Name; +import sun.security.x509.X509CertImpl; +import sun.security.x509.X509CertInfo; + +/** + * A utility for generating dynamic {@link java.security.KeyStore}s in tests. The keystores + * generated by this utility are dynamic in the sense that they are generated as necessary and + * don't require keys/certificates to be present on the filesystem. The generated keystores too + * aren't saved to the filesystem, by this utility. + */ +public class DynamicKeyStoreUtil { + + private static final String PKCS12_KEYSTORE_TYPE = "PKCS12"; + + /** + * The default alias that will be used in the KeyStore generated by + * {@link #generateKeyStore(String, String...)} + */ + public static final String DEFAULT_ALIAS = "foobar-key-alias"; + + private static final String DEFAULT_SUBJECT_OU = "foobar-org-unit"; + private static final String DEFAULT_SUBJECT_ON = "foobar-org-name"; + private static final String DEFAULT_SUBJECT_COUNTRY = "US"; + + + /** + * Generates a PKCS12 type {@link KeyStore} which has one + * {@link KeyStore#getKey(String, char[]) key entry} which corresponds to a newly generated + * {@link java.security.PrivateKey} and accompanied by a newly generated self-signed certificate, + * certifying the corresponding public key. + *

    + * The newly generated {@link X509Certificate} certificate will use the {@code certSubject} + * as the certificate's subject. If the {@code certSubjectAltNames} is non-null then + * the certificate will be created with a + * {@link sun.security.x509.SubjectAlternativeNameExtension subject alternative name extension} + * which will include the {@code certSubject} and each of the {@code certSubjectAltNames} as + * subject alternative names (represented as DNS names) + *

    + * The generated KeyStore won't be password protected + * + * @param certSubject The subject to be used for the newly generated certificate + * @param certSubjectAltNames Optional subject alternative names to be used in the generated + * certificate + * @return The newly generated KeyStore + * @throws NullPointerException If {@code certSubject} is null + */ + public static KeyStore generateKeyStore(final String certSubject, final String... certSubjectAltNames) + throws Exception { + Objects.requireNonNull(certSubject); + final SecureRandom secureRandom = new SecureRandom(); + final KeyPair keyPair = generateRSAKeyPair(secureRandom); + final X509Certificate cert = generateCert(keyPair, secureRandom, certSubject, certSubjectAltNames); + final KeyStore.Builder keystoreBuilder = KeyStore.Builder.newInstance(PKCS12_KEYSTORE_TYPE, + null, new KeyStore.PasswordProtection(null)); + final KeyStore keyStore = keystoreBuilder.getKeyStore(); + // write a private key (with cert chain for the public key) + final char[] keyPassword = null; + keyStore.setKeyEntry(DEFAULT_ALIAS, keyPair.getPrivate(), keyPassword, new Certificate[]{cert}); + return keyStore; + } + + /** + * Generates a PKCS12 type {@link KeyStore} which has one + * {@link KeyStore#getKey(String, char[]) key entry} which corresponds to the {@code privateKey} + * and accompanied by the {@code certChain} certifying the corresponding public key. + *

    + * The generated KeyStore won't be password protected + * + * @param privateKey The private key to include in the keystore + * @param certChain The certificate chain + * @return The newly generated KeyStore + * @throws NullPointerException If {@code privateKey} or {@code certChain} is null + */ + public static KeyStore generateKeyStore(final PrivateKey privateKey, final Certificate[] certChain) + throws Exception { + Objects.requireNonNull(privateKey); + Objects.requireNonNull(certChain); + final KeyStore.Builder keystoreBuilder = KeyStore.Builder.newInstance(PKCS12_KEYSTORE_TYPE, + null, new KeyStore.PasswordProtection(null)); + final KeyStore keyStore = keystoreBuilder.getKeyStore(); + // write a private key (with cert chain for the public key) + final char[] keyPassword = null; + keyStore.setKeyEntry(DEFAULT_ALIAS, privateKey, keyPassword, certChain); + return keyStore; + } + + /** + * Creates and returns a new PKCS12 type keystore without any entries in the keystore + * + * @return The newly created keystore + * @throws Exception if any exception occurs during keystore generation + */ + public static KeyStore generateBlankKeyStore() throws Exception { + final KeyStore.Builder keystoreBuilder = KeyStore.Builder.newInstance(PKCS12_KEYSTORE_TYPE, + null, new KeyStore.PasswordProtection(null)); + return keystoreBuilder.getKeyStore(); + } + + /** + * Generates a {@link KeyPair} using the {@code secureRandom} + * + * @param secureRandom The SecureRandom + * @return The newly generated KeyPair + * @throws NoSuchAlgorithmException + */ + public static KeyPair generateRSAKeyPair(final SecureRandom secureRandom) + throws NoSuchAlgorithmException { + final String keyType = "RSA"; + final int defaultRSAKeySize = 3072; + final KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(keyType); + keyPairGenerator.initialize(defaultRSAKeySize, secureRandom); + //final String sigAlgName = "SHA384withRSA"; + final KeyPair pair = keyPairGenerator.generateKeyPair(); + final PublicKey publicKey = pair.getPublic(); + // publicKey's format must be X.509 otherwise + if (!"X.509".equalsIgnoreCase(publicKey.getFormat())) { + throw new IllegalArgumentException("Public key format " + publicKey.getFormat() + + ", isn't X.509"); + } + return pair; + } + + /** + * Generates a X509 certificate using the passed {@code keyPair} + * + * @param keyPair The KeyPair to use + * @param secureRandom The SecureRandom + * @param subjectName The subject to use in the certificate + * @param subjectAltNames Any subject alternate names to use in the certificate + * @return The newly generated certificate + * @throws Exception + */ + public static X509Certificate generateCert(final KeyPair keyPair, final SecureRandom secureRandom, + final String subjectName, final String... subjectAltNames) + throws Exception { + final X500Name subject = new X500Name(subjectName, DEFAULT_SUBJECT_OU, DEFAULT_SUBJECT_ON, + DEFAULT_SUBJECT_COUNTRY); + final X500Name issuer = subject; // self-signed cert + final GeneralNames generalNames; + if (subjectAltNames == null) { + generalNames = null; + } else { + generalNames = new GeneralNames(); + generalNames.add(new GeneralName(new DNSName(subjectName, true))); + for (final String san : subjectAltNames) { + if (san == null) { + continue; + } + final DNSName dnsName = new DNSName(san, true); + generalNames.add(new GeneralName(dnsName)); + } + } + return generateCert(keyPair, secureRandom, subject, issuer, generalNames); + } + + private static X509Certificate generateCert(final KeyPair keyPair, final SecureRandom secureRandom, + final X500Name subjectName, final X500Name issuerName, + final GeneralNames subjectAltNames) + throws Exception { + final X509CertInfo certInfo = new X509CertInfo(); + certInfo.setVersion(new CertificateVersion(CertificateVersion.V3)); + certInfo.setSerialNumber(CertificateSerialNumber.newRandom64bit(secureRandom)); + certInfo.setSubject(subjectName); + final PublicKey publicKey = keyPair.getPublic(); + certInfo.setKey(new CertificateX509Key(publicKey)); + certInfo.setValidity(certDuration()); + certInfo.setIssuer(issuerName); + if (subjectAltNames != null && !subjectAltNames.isEmpty()) { + final SubjectAlternativeNameExtension sanExtn = new SubjectAlternativeNameExtension( + true, subjectAltNames); + final CertificateExtensions certExtensions = new CertificateExtensions(); + certExtensions.setExtension(sanExtn.getId(), sanExtn); + certInfo.setExtensions(certExtensions); + } + final PrivateKey privateKey = keyPair.getPrivate(); + final String sigAlgName = "SHA384withRSA"; + return X509CertImpl.newSigned(certInfo, privateKey, sigAlgName); + } + + private static CertificateValidity certDuration() { + // create a cert with 1 day validity, starting from 1 minute back + final long currentTime = System.currentTimeMillis(); + final long oneMinuteInThePast = currentTime - (60 * 1000); + final long oneDayInTheFuture = currentTime + (24 * 60 * 60 * 1000); + final Date startDate = new Date(oneMinuteInThePast); + final Date expiryDate = new Date(oneDayInTheFuture); + return new CertificateValidity(startDate, expiryDate); + } + + /** + * Creates a {@link SSLContext} which is + * {@link SSLContext#init(KeyManager[], TrustManager[], SecureRandom) initialized} using the + * {@link javax.net.ssl.KeyManager}s and {@link javax.net.ssl.TrustManager}s available in the + * {@code keyStore}. The {@code keyStore} is expected to be password-less. + * + * @param keyStore The password-less keystore + * @return the SSLContext + * @throws Exception + */ + public static SSLContext createSSLContext(final KeyStore keyStore) throws Exception { + Objects.requireNonNull(keyStore); + final char[] password = null; + final KeyManagerFactory kmf = KeyManagerFactory.getInstance("PKIX"); + kmf.init(keyStore, password); + + final TrustManagerFactory tmf = TrustManagerFactory.getInstance("PKIX"); + tmf.init(keyStore); + + final String protocol = "TLS"; + final SSLContext ctx = SSLContext.getInstance(protocol); + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + return ctx; + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/HttpServerAdapters.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/HttpServerAdapters.java index 93a93ad25d2..10633340a66 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/HttpServerAdapters.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/HttpServerAdapters.java @@ -34,31 +34,46 @@ import com.sun.net.httpserver.HttpsServer; import jdk.httpclient.test.lib.http2.Http2Handler; import jdk.httpclient.test.lib.http2.Http2TestExchange; import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.qpack.Encoder; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.net.InetAddress; import java.io.ByteArrayInputStream; +import java.net.ProxySelector; +import java.net.http.HttpClient; +import java.net.http.HttpClient.Builder; import java.net.http.HttpClient.Version; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintStream; -import java.io.UncheckedIOException; -import java.math.BigInteger; import java.net.InetSocketAddress; import java.net.URI; import java.net.http.HttpHeaders; +import java.net.http.HttpOption.Http3DiscoveryMode; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Base64; +import java.util.HashSet; +import java.util.HexFormat; import java.util.List; import java.util.ListIterator; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.OptionalLong; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; +import java.util.function.Supplier; import java.util.function.Predicate; import java.util.logging.Level; import java.util.logging.Logger; @@ -66,9 +81,12 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; import static java.net.http.HttpClient.Version.HTTP_1_1; import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; +import static jdk.test.lib.Asserts.assertFileContentsEqual; /** * Defines an adaptation layers so that a test server handlers and filters @@ -93,24 +111,14 @@ import static java.net.http.HttpClient.Version.HTTP_2; */ public interface HttpServerAdapters { - static final boolean PRINTSTACK = - Boolean.getBoolean("jdk.internal.httpclient.debug"); - - static void uncheckedWrite(ByteArrayOutputStream baos, byte[] ba) { - try { - baos.write(ba); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + static final boolean PRINTSTACK = getPrintStack(); + private static boolean getPrintStack() { + return Boolean.getBoolean("jdk.internal.httpclient.debug"); } + static final HexFormat HEX_FORMAT = HexFormat.ofDelimiter(":").withUpperCase(); static void printBytes(PrintStream out, String prefix, byte[] bytes) { - int padding = 4 + 4 - (bytes.length % 4); - padding = padding > 4 ? padding - 4 : 4; - byte[] bigbytes = new byte[bytes.length + padding]; - System.arraycopy(bytes, 0, bigbytes, padding, bytes.length); - out.println(prefix + bytes.length + " " - + new BigInteger(bigbytes).toString(16)); + out.println(prefix + bytes.length + " " + HEX_FORMAT.formatHex(bytes)); } /** @@ -122,6 +130,7 @@ public interface HttpServerAdapters { public abstract Set>> entrySet(); public abstract List get(String name); public abstract boolean containsKey(String name); + public abstract OptionalLong firstValueAsLong(String name); @Override public boolean equals(Object o) { if (this == o) return true; @@ -166,6 +175,11 @@ public interface HttpServerAdapters { return headers.containsKey(name); } @Override + public OptionalLong firstValueAsLong(String name) { + return Optional.ofNullable(headers.getFirst(name)) + .stream().mapToLong(Long::parseLong).findFirst(); + } + @Override public String toString() { return String.valueOf(headers); } @@ -192,6 +206,10 @@ public interface HttpServerAdapters { return headers.firstValue(name).isPresent(); } @Override + public OptionalLong firstValueAsLong(String name) { + return headers.firstValueAsLong(name); + } + @Override public String toString() { return String.valueOf(headers); } @@ -244,18 +262,169 @@ public interface HttpServerAdapters { public abstract String getRequestMethod(); public abstract void close(); public abstract InetSocketAddress getRemoteAddress(); - public abstract String getConnectionKey(); public abstract InetSocketAddress getLocalAddress(); - public void serverPush(URI uri, HttpHeaders headers, byte[] body) { + public abstract String getConnectionKey(); + public abstract SSLSession getSSLSession(); + public void serverPush(URI uri, HttpHeaders reqHeaders, byte[] body) throws IOException { ByteArrayInputStream bais = new ByteArrayInputStream(body); - serverPush(uri, headers, bais); + serverPush(uri, reqHeaders, bais); } - public void serverPush(URI uri, HttpHeaders headers, InputStream body) { + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, byte[] body) throws IOException { + ByteArrayInputStream bais = new ByteArrayInputStream(body); + serverPush(uri, reqHeaders, rspHeaders, bais); + } + public void serverPush(URI uri, HttpHeaders reqHeaders, InputStream body) + throws IOException { + serverPush(uri, reqHeaders, HttpHeaders.of(Map.of(), (n,v) -> true), body); + } + + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream body) + throws IOException { throw new UnsupportedOperationException("serverPush with " + getExchangeVersion()); } + + public void requestStopSending(long errorCode) { + throw new UnsupportedOperationException("sendHttp3ConnectionClose with " + getExchangeVersion()); + } + + /** + * Sends an HTTP/3 PUSH_PROMISE frame, for the given {@code uri}, + * with the given request {@code reqHeaders}, and opens a push promise + * stream to send the given response {@code rspHeaders} and {@code body}. + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param uri the push promise URI + * @param reqHeaders the push promise request headers + * @param rspHeaders the push promise request headers + * @param body the push response body + * + * @return the pushId used to push the promise + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + public long http3ServerPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream body) + throws IOException { + throw new UnsupportedOperationException("serverPushWithId with " + getExchangeVersion()); + } + /** + * Sends an HTTP/3 PUSH_PROMISE frame, for the given {@code uri}, + * with the given request {@code headers}, and with the given + * {@code pushId}. This method only sends the PUSH_PROMISE frame + * and doesn't open any push stream. + * + * @apiNote + * This method can be used to send a PUSH_PROMISE whose body has + * already been promised by calling {@link + * #http3ServerPush(URI, HttpHeaders, HttpHeaders, InputStream)}. In that case + * the {@code pushId} returned by {@link + * #http3ServerPush(URI, HttpHeaders, HttpHeaders, InputStream)} should be passed + * as parameter. Otherwise, if {@code pushId=-1} is passed as parameter, + * a new pushId will be allocated. The push response headers and body + * can be later sent using {@link + * #sendHttp3PushResponse(long, URI, HttpHeaders, HttpHeaders, InputStream)}. + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param pushId the pushId to use, or {@code -1} if a new + * pushId should be allocated. + * @param uri the push promise URI + * @param headers the push promise request headers + * @return the given pushId, if positive, otherwise the new allocated pushId + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + public long sendHttp3PushPromiseFrame(long pushId, URI uri, HttpHeaders headers) + throws IOException { + throw new UnsupportedOperationException("serverPushId with " + getExchangeVersion()); + } + /** + * Opens an HTTP/3 PUSH_STREAM to send a push promise response headers + * and body. + * + * @apiNote + * No check is performed on the provided pushId + * + * @param pushId a positive pushId obtained from {@link + * #sendHttp3PushPromiseFrame(long, URI, HttpHeaders)} + * @param uri the push request URI + * @param reqHeaders the push promise request headers + * @param rspHeaders the push promise response headers + * @param body the push response body + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + */ + public void sendHttp3PushResponse(long pushId, URI uri, + HttpHeaders reqHeaders, + HttpHeaders rspHeaders, + InputStream body) + throws IOException { + throw new UnsupportedOperationException("serverPushWithId with " + getExchangeVersion()); + } + /** + * Sends an HTTP/3 CANCEL_PUSH frame to cancel a push that has been + * promised by either {@link #http3ServerPush(URI, HttpHeaders, HttpHeaders, InputStream)} + * or {@link #sendHttp3PushPromiseFrame(long, URI, HttpHeaders)}. + * + * This method doesn't cancel the push stream but just sends + * a CANCEL_PUSH frame. + * Note that if the push stream has already been opened this + * method may not have any effect. + * + * @apiNote + * No check is performed on the provided pushId + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param pushId the cancelled pushId + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + public void sendHttp3CancelPushFrame(long pushId) + throws IOException { + throw new UnsupportedOperationException("cancelPushId with " + getExchangeVersion()); + } + /** + * Waits until the given {@code pushId} is allowed by the HTTP/3 peer + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param pushId a pushId + * + * @return the maximum pushId allowed (exclusive) + * + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + public long waitForHttp3MaxPushId(long pushId) + throws InterruptedException { + throw new UnsupportedOperationException("waitForMaxPushId with " + getExchangeVersion()); + } public boolean serverPushAllowed() { return false; } + public Encoder qpackEncoder() { + throw new UnsupportedOperationException("qpackEncoder with " + getExchangeVersion()); + } + public CompletableFuture clientHttp3Settings() { + throw new UnsupportedOperationException("HTTP/3 client connection settings with " + + getExchangeVersion()); + } public static HttpTestExchange of(HttpExchange exchange) { return new Http1TestExchange(exchange); } @@ -265,6 +434,10 @@ public interface HttpServerAdapters { abstract void doFilter(Filter.Chain chain) throws IOException; + public void resetStream(long code) throws IOException { + throw new UnsupportedOperationException(String.valueOf(this.getServerVersion())); + } + // implementations... private static final class Http1TestExchange extends HttpTestExchange { private final HttpExchange exchange; @@ -303,7 +476,8 @@ public interface HttpServerAdapters { } @Override public void close() { exchange.close(); } - + @Override + public SSLSession getSSLSession() { return null; } @Override public InetSocketAddress getRemoteAddress() { return exchange.getRemoteAddress(); @@ -334,9 +508,9 @@ public interface HttpServerAdapters { this.exchange = exch; } @Override - public Version getServerVersion() { return HTTP_2; } + public Version getServerVersion() { return exchange.getServerVersion(); } @Override - public Version getExchangeVersion() { return HTTP_2; } + public Version getExchangeVersion() { return exchange.getServerVersion(); } @Override public InputStream getRequestBody() { return exchange.getRequestBody(); @@ -365,15 +539,61 @@ public interface HttpServerAdapters { return exchange.serverPushAllowed(); } @Override - public void serverPush(URI uri, HttpHeaders headers, InputStream body) { - exchange.serverPush(uri, headers, body); + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream body) + throws IOException { + exchange.serverPush(uri, reqHeaders, rspHeaders, body); } + @Override + public void requestStopSending(long errorCode) { + exchange.requestStopSending(errorCode); + } + @Override + public void resetStream(long code) throws IOException { + exchange.resetStream(code); + } + + @Override + public long http3ServerPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream body) throws IOException { + return exchange.serverPushWithId(uri, reqHeaders, rspHeaders, body); + } + @Override + public long sendHttp3PushPromiseFrame(long pushId, URI uri, HttpHeaders reqHeaders) throws IOException { + return exchange.sendPushId(pushId, uri, reqHeaders); + } + @Override + public void sendHttp3CancelPushFrame(long pushId) throws IOException { + exchange.cancelPushId(pushId); + } + @Override + public void sendHttp3PushResponse(long pushId, + URI uri, + HttpHeaders reqHeaders, + HttpHeaders rspHeaders, + InputStream body) throws IOException { + exchange.sendPushResponse(pushId, uri, reqHeaders, rspHeaders, body); + } + @Override + public long waitForHttp3MaxPushId(long pushId) throws InterruptedException { + return exchange.waitForMaxPushId(pushId); + } + @Override + public Encoder qpackEncoder() { + return exchange.qpackEncoder(); + } + + @Override + public CompletableFuture clientHttp3Settings() { + return exchange.clientHttp3Settings(); + } + + @Override void doFilter(Filter.Chain filter) throws IOException { throw new IOException("cannot use HTTP/1.1 filter with HTTP/2 server"); } @Override public void close() { exchange.close();} - + @Override + public SSLSession getSSLSession() { return exchange.getSSLSession();} @Override public InetSocketAddress getRemoteAddress() { return exchange.getRemoteAddress(); @@ -400,31 +620,6 @@ public interface HttpServerAdapters { } - - /** - * A version agnostic adapter class for HTTP Server Handlers. - */ - public interface HttpTestHandler { - void handle(HttpTestExchange t) throws IOException; - - default HttpHandler toHttpHandler() { - return (t) -> doHandle(HttpTestExchange.of(t)); - } - default Http2Handler toHttp2Handler() { - return (t) -> doHandle(HttpTestExchange.of(t)); - } - private void doHandle(HttpTestExchange t) throws IOException { - try { - handle(t); - } catch (Throwable x) { - System.out.println("WARNING: exception caught in HttpTestHandler::handle " + x); - System.err.println("WARNING: exception caught in HttpTestHandler::handle " + x); - if (PRINTSTACK && !expectException(t)) x.printStackTrace(System.out); - throw x; - } - } - } - /** * An {@link HttpTestHandler} that handles only HEAD and GET * requests. If another method is used 405 is returned with @@ -473,24 +668,192 @@ public interface HttpServerAdapters { } + /** + * A version agnostic adapter class for HTTP Server Handlers. + */ + public interface HttpTestHandler { + void handle(HttpTestExchange t) throws IOException; + + default HttpHandler toHttpHandler() { + return (t) -> doHandle(HttpTestExchange.of(t)); + } + default Http2Handler toHttp2Handler() { + return (t) -> doHandle(HttpTestExchange.of(t)); + } + + default void handleFailure(final HttpTestExchange exchange, Throwable failure) { + System.out.println("WARNING: exception caught in HttpTestHandler::handle " + failure); + System.err.println("WARNING: exception caught in HttpTestHandler::handle " + failure); + if (PRINTSTACK && !expectException(exchange)) { + failure.printStackTrace(System.out); + } + } + + private void doHandle(HttpTestExchange exchange) throws IOException { + try { + handle(exchange); + } catch (Throwable failure) { + handleFailure(exchange, failure); + throw failure; + } + } + } + + /** + * An echo handler that can be used to transfer large amount of data, and + * uses file on the file system to download the input. + */ + // TODO: it would be good if we could merge this with the Http2EchoHandler, + // from which this code was copied and adapted. + public static class HttpTestFileEchoHandler implements HttpTestHandler { + static final Path CWD = Paths.get("."); + + @Override + public void handle(HttpTestExchange t) throws IOException { + try { + System.err.printf("EchoHandler received request to %s from %s (version %s)%n", + t.getRequestURI(), t.getRemoteAddress(), t.getExchangeVersion()); + InputStream is = t.getRequestBody(); + var requestHeaders = t.getRequestHeaders(); + var responseHeaders = t.getResponseHeaders(); + responseHeaders.addHeader("X-Hello", "world"); + responseHeaders.addHeader("X-Bye", "universe"); + String fixedrequest = requestHeaders.firstValue("XFixed").orElse(null); + File outfile = Files.createTempFile(CWD, "foo", "bar").toFile(); + //System.err.println ("QQQ = " + outfile.toString()); + FileOutputStream fos = new FileOutputStream(outfile); + long count = is.transferTo(fos); + System.err.printf("EchoHandler read %s bytes\n", count); + is.close(); + fos.close(); + InputStream is1 = new FileInputStream(outfile); + OutputStream os = null; + + Path check = requestHeaders.firstValue("X-Compare") + .map((String s) -> Path.of(s)).orElse(null); + if (check != null) { + System.err.println("EchoHandler checking file match: " + check); + try { + assertFileContentsEqual(check, outfile.toPath()); + } catch (Throwable x) { + System.err.println("Files do not match: " + x); + t.sendResponseHeaders(500, -1); + outfile.delete(); + os.close(); + return; + } + } + + // return the number of bytes received (no echo) + String summary = requestHeaders.firstValue("XSummary").orElse(null); + if (fixedrequest != null && summary == null) { + t.sendResponseHeaders(200, count); + os = t.getResponseBody(); + if (!t.getRequestMethod().equals("HEAD")) { + long count1 = is1.transferTo(os); + System.err.printf("EchoHandler wrote %s bytes%n", count1); + } else { + System.err.printf("EchoHandler HEAD received, no bytes sent%n"); + } + } else { + t.sendResponseHeaders(200, -1); + os = t.getResponseBody(); + if (!t.getRequestMethod().equals("HEAD")) { + long count1 = is1.transferTo(os); + System.err.printf("EchoHandler wrote %s bytes\n", count1); + + if (summary != null) { + String s = Long.toString(count); + os.write(s.getBytes()); + } + } else { + System.err.printf("EchoHandler HEAD received, no bytes sent%n"); + } + } + outfile.delete(); + os.close(); + is1.close(); + } catch (Throwable e) { + e.printStackTrace(); + throw new IOException(e); + } + } + } + public static class HttpTestEchoHandler implements HttpTestHandler { + + private final boolean printBytes; + public HttpTestEchoHandler() { + this(true); + } + + public HttpTestEchoHandler(boolean printBytes) { + this.printBytes = printBytes; + } + @Override public void handle(HttpTestExchange t) throws IOException { try (InputStream is = t.getRequestBody(); OutputStream os = t.getResponseBody()) { byte[] bytes = is.readAllBytes(); - printBytes(System.out,"Echo server got " - + t.getExchangeVersion() + " bytes: ", bytes); + if (printBytes) { + printBytes(System.out, "Echo server got " + + t.getExchangeVersion() + " bytes: ", bytes); + } if (t.getRequestHeaders().firstValue("Content-type").isPresent()) { t.getResponseHeaders().addHeader("Content-type", t.getRequestHeaders().firstValue("Content-type").get()); } t.sendResponseHeaders(200, bytes.length); - os.write(bytes); + if (!t.getRequestMethod().equals("HEAD")) { + os.write(bytes); + } } } } + public static class HttpTestRedirectHandler implements HttpTestHandler { + + final Supplier supplier; + + public HttpTestRedirectHandler(Supplier redirectSupplier) { + supplier = redirectSupplier; + } + + @Override + public void handle(HttpTestExchange t) throws IOException { + examineExchange(t); + try (InputStream is = t.getRequestBody()) { + is.readAllBytes(); + String location = supplier.get(); + System.err.printf("RedirectHandler request to %s from %s\n", + t.getRequestURI().toString(), t.getRemoteAddress().toString()); + System.err.println("Redirecting to: " + location); + var headersBuilder = t.getResponseHeaders(); + headersBuilder.addHeader("Location", location); + byte[] bb = getResponseBytes(); + t.sendResponseHeaders(redirectCode(), bb.length); + OutputStream os = t.getResponseBody(); + os.write(bb); + os.close(); + t.close(); + } + } + + protected byte[] getResponseBytes() { + return new byte[1024]; + } + + protected int redirectCode() { + return 301; + } + + // override in sub-class to examine the exchange, but don't + // alter transaction state by reading the request body etc. + protected void examineExchange(HttpTestExchange t) { + } + } + public static boolean expectException(HttpTestExchange e) { HttpTestRequestHeaders h = e.getRequestHeaders(); Optional expectException = h.firstValue("X-expect-exception"); @@ -780,6 +1143,41 @@ public interface HttpServerAdapters { public abstract HttpTestContext addHandler(HttpTestHandler handler, String root); public abstract InetSocketAddress getAddress(); public abstract Version getVersion(); + + /** + * {@return the HTTP3 test server which is acting as an alt-service for this server, + * if any} + */ + public Optional getH3AltService() { + return Optional.empty(); + } + + /** + * {@return true if any HTTP3 test server is acting as an alt-service for this server and the + * HTTP3 test server listens on the same host and port as this server. + * Returns false otherwise} + */ + public boolean supportsH3DirectConnection() { + return false; + } + + public Http3DiscoveryMode h3DiscoveryConfig() { + return null; + } + + @Override + public String toString() { + var conf = Optional.ofNullable(h3DiscoveryConfig()).orElse(getVersion()); + return "HttpTestServer(%s: %s)".formatted(conf, serverAuthority()); + } + + /** + * @param version the HTTP version + * @param more additional HTTP versions + * {@return true if the handlers registered with this server can be accessed (through + * request URIs) using all of the passed HTTP versions. Returns false otherwise} + */ + public abstract boolean canHandle(Version version, Version... more); public abstract void setRequestApprover(final Predicate approver); @Override @@ -799,18 +1197,26 @@ public interface HttpServerAdapters { return hostString + ":" + address.getPort(); } - public static HttpTestServer of(HttpServer server) { + public static HttpTestServer of(final HttpServer server) { + Objects.requireNonNull(server); return new Http1TestServer(server); } - public static HttpTestServer of(HttpServer server, ExecutorService executor) { + public static HttpTestServer of(final HttpServer server, ExecutorService executor) { + Objects.requireNonNull(server); return new Http1TestServer(server, executor); } - public static HttpTestServer of(Http2TestServer server) { + public static HttpTestServer of(final Http2TestServer server) { + Objects.requireNonNull(server); return new Http2TestServerImpl(server); } + public static HttpTestServer of(final Http3TestServer server) { + Objects.requireNonNull(server); + return new H3ServerAdapter(server); + } + /** * Creates a {@link HttpTestServer} which supports the {@code serverVersion}. The server * will only be available on {@code http} protocol. {@code https} will not be supported @@ -841,7 +1247,7 @@ public interface HttpServerAdapters { public static HttpTestServer create(Version serverVersion, SSLContext sslContext) throws IOException { Objects.requireNonNull(serverVersion); - return create(serverVersion, sslContext, null); + return create(serverVersion, sslContext, null, null); } /** @@ -860,7 +1266,130 @@ public interface HttpServerAdapters { public static HttpTestServer create(Version serverVersion, SSLContext sslContext, ExecutorService executor) throws IOException { Objects.requireNonNull(serverVersion); + return create(serverVersion, sslContext, null, executor); + } + + /** + * Creates a {@link HttpTestServer} which supports HTTP_3 version. + * + * @param h3DiscoveryCfg Discovery config for HTTP_3 connection creation. Can be null + * @param sslContext SSLContext. Cannot be null + * @return The newly created server + * @throws IOException if any exception occurs during the server creation + */ + public static HttpTestServer create(Http3DiscoveryMode h3DiscoveryCfg, + SSLContext sslContext) + throws IOException { + Objects.requireNonNull(sslContext, "SSLContext"); + return create(h3DiscoveryCfg, sslContext, null); + } + + /** + * Creates a {@link HttpTestServer} which supports HTTP_3 version. + * + * @param h3DiscoveryCfg Discovery config for HTTP_3 connection creation. Can be null + * @param sslContext SSLContext. Cannot be null + * @param executor The executor to be used by the server. Can be null + * @return The newly created server + * @throws IOException if any exception occurs during the server creation + */ + public static HttpTestServer create(Http3DiscoveryMode h3DiscoveryCfg, + SSLContext sslContext, ExecutorService executor) + throws IOException { + Objects.requireNonNull(sslContext, "SSLContext"); + return create(HTTP_3, sslContext, h3DiscoveryCfg, executor); + } + + + /** + * Creates a {@link HttpTestServer} which supports the {@code serverVersion}. If the + * {@code sslContext} is null, then only {@code http} protocol will be supported by the + * server. Else, the server will be configured with the {@code sslContext} and will support + * {@code https} protocol. + * + * If {@code serverVersion} is {@link Version#HTTP_3 HTTP_3}, then a {@code h3DiscoveryCfg} + * can be passed to decide how the HTTP_3 server will be created. The following table + * summarizes how {@code h3DiscoveryCfg} is used: + *
      + *
    • HTTP3_ONLY - A server which only supports HTTP_3 is created
    • + *
    • HTTP3_ALTSVC - A HTTP_2 server is created and a HTTP_3 server is created. + * The HTTP_2 server advertises the HTTP_3 server as an alternate service. When + * creating the HTTP_3 server, an ephemeral port is used and thus the alternate + * service will be advertised on a different port than the HTTP_2 server's port
    • + *
    • ANY - A HTTP_2 server is created and a HTTP_3 server is created. + * The HTTP_2 server advertises the HTTP_3 server as an alternate service. When + * creating the HTTP_3 server, the same port as that of the HTTP_2 server is used + * to bind the HTTP_3 server. If that bind attempt fails, then an ephemeral port + * is used to bind the HTTP_3 server
    • + *
    + * + * @param serverVersion The HTTP version of the server + * @param sslContext The SSLContext to use. Can be null + * @param h3DiscoveryCfg The Http3DiscoveryMode for HTTP_3 server. Can be null, + * in which case it defaults to {@code ALT_SVC} for HTTP_3 + * server + * @param executor The executor to be used by the server. Can be null + * @return The newly created server + * @throws IllegalArgumentException if {@code serverVersion} is not supported by this method + * @throws IllegalArgumentException if {@code h3DiscoveryCfg} is not null when + * {@code serverVersion} is not {@code HTTP_3} + * @throws IOException if any exception occurs during the server creation + */ + private static HttpTestServer create(final Version serverVersion, final SSLContext sslContext, + final Http3DiscoveryMode h3DiscoveryCfg, + final ExecutorService executor) throws IOException { + Objects.requireNonNull(serverVersion); + if (h3DiscoveryCfg != null && serverVersion != HTTP_3) { + // Http3DiscoveryMode is only supported when version of HTTP_3 + throw new IllegalArgumentException("Http3DiscoveryMode" + + " isn't allowed for " + serverVersion + " version"); + } switch (serverVersion) { + case HTTP_3 -> { + if (sslContext == null) { + throw new IllegalArgumentException("SSLContext cannot be null when" + + " constructing a HTTP_3 server"); + } + final Http3DiscoveryMode effectiveDiscoveryCfg = h3DiscoveryCfg == null + ? Http3DiscoveryMode.ALT_SVC + : h3DiscoveryCfg; + switch (effectiveDiscoveryCfg) { + case HTTP_3_URI_ONLY -> { + // create only a HTTP3 server + return HttpTestServer.of(new Http3TestServer(sslContext, executor)); + } + case ALT_SVC -> { + // create a HTTP2 server which advertises an HTTP3 alternate service. + // that alternate service will be using an ephemeral port for the server + final Http2TestServer h2WithAltService; + try { + h2WithAltService = new Http2TestServer( + "localhost", true, 0, executor, sslContext) + .enableH3AltServiceOnEphemeralPort(); + } catch (Exception e) { + throw new IOException(e); + } + return HttpTestServer.of(h2WithAltService); + } + case ANY -> { + // create a HTTP2 server which advertises an HTTP3 alternate service. + // that alternate service will first attempt to use the same port as the + // HTTP2 server and if binding to that port fails, then will attempt + // to use a ephemeral port. + final Http2TestServer h2WithAltService; + try { + h2WithAltService = new Http2TestServer( + "localhost", true, 0, executor, sslContext) + .enableH3AltServiceOnSamePort(); + } catch (Exception e) { + throw new IOException(e); + } + return HttpTestServer.of(h2WithAltService); + } + default -> throw new IllegalArgumentException("Unsupported" + + " Http3DiscoveryMode: " + effectiveDiscoveryCfg); + } + } case HTTP_2 -> { Http2TestServer underlying; try { @@ -874,7 +1403,7 @@ public interface HttpServerAdapters { } return HttpTestServer.of(underlying); } - case HTTP_1_1 -> { + case HTTP_1_1 -> { InetAddress loopback = InetAddress.getLoopbackAddress(); InetSocketAddress sa = new InetSocketAddress(loopback, 0); HttpServer underlying; @@ -933,8 +1462,22 @@ public interface HttpServerAdapters { return new InetSocketAddress(InetAddress.getLoopbackAddress(), impl.getAddress().getPort()); } + @Override public Version getVersion() { return HTTP_1_1; } + @Override + public boolean canHandle(final Version version, final Version... more) { + if (version != HTTP_1_1) { + return false; + } + for (var v : more) { + if (v != HTTP_1_1) { + return false; + } + } + return true; + } + @Override public void setRequestApprover(final Predicate approver) { throw new UnsupportedOperationException("not supported"); @@ -996,8 +1539,41 @@ public interface HttpServerAdapters { return new InetSocketAddress(InetAddress.getLoopbackAddress(), impl.getAddress().getPort()); } + + @Override + public Optional getH3AltService() { + return impl.getH3AltService(); + } + + @Override + public boolean supportsH3DirectConnection() { + return impl.supportsH3DirectConnection(); + } + + public Http3DiscoveryMode h3DiscoveryConfig() { + return supportsH3DirectConnection() + ? Http3DiscoveryMode.ANY + : Http3DiscoveryMode.ALT_SVC; + } + public Version getVersion() { return HTTP_2; } + @Override + public boolean canHandle(final Version version, final Version... more) { + final Set supported = new HashSet<>(); + supported.add(HTTP_2); + impl.getH3AltService().ifPresent((unused)-> supported.add(HTTP_3)); + if (!supported.contains(version)) { + return false; + } + for (var v : more) { + if (!supported.contains(v)) { + return false; + } + } + return true; + } + @Override public void setRequestApprover(final Predicate approver) { this.impl.setRequestApprover(approver); @@ -1036,6 +1612,119 @@ public interface HttpServerAdapters { } @Override public Version getVersion() { return HTTP_2; } } + + private static final class H3ServerAdapter extends HttpTestServer { + private final Http3TestServer underlyingH3Server; + + private H3ServerAdapter(final Http3TestServer server) { + this.underlyingH3Server = server; + } + + @Override + public void start() { + underlyingH3Server.start(); + } + + @Override + public void stop() { + underlyingH3Server.stop(); + } + + @Override + public HttpTestContext addHandler(final HttpTestHandler handler, final String path) { + Objects.requireNonNull(path); + Objects.requireNonNull(handler); + final H3RootCtx h3Ctx = new H3RootCtx(path, handler); + this.underlyingH3Server.addHandler(path, h3Ctx::doHandle); + return h3Ctx; + } + + @Override + public InetSocketAddress getAddress() { + return underlyingH3Server.getAddress(); + } + + @Override + public Version getVersion() { + return HTTP_3; + } + + @Override + public Http3DiscoveryMode h3DiscoveryConfig() { + return Http3DiscoveryMode.HTTP_3_URI_ONLY; + } + + @Override + public boolean canHandle(Version version, Version... more) { + if (version != HTTP_3) { + return false; + } + for (var v : more) { + if (v != HTTP_3) { + return false; + } + } + return true; + } + + @Override + public void setRequestApprover(final Predicate approver) { + underlyingH3Server.setRequestApprover(approver); + } + + } + + private static final class H3RootCtx extends HttpTestContext implements HttpTestHandler { + private final String path; + private final HttpTestHandler handler; + private final List filters = new CopyOnWriteArrayList<>(); + + private H3RootCtx(final String path, final HttpTestHandler handler) { + this.path = path; + this.handler = handler; + } + + @Override + public String getPath() { + return this.path; + } + + @Override + public void addFilter(final HttpTestFilter filter) { + Objects.requireNonNull(filter); + this.filters.add(filter); + } + + @Override + public Version getVersion() { + return HTTP_3; + } + + @Override + public void setAuthenticator(final Authenticator authenticator) { + if (authenticator instanceof BasicAuthenticator basicAuth) { + addFilter(new HttpBasicAuthFilter(basicAuth)); + } else { + throw new UnsupportedOperationException( + "Only BasicAuthenticator is supported on an H3 context"); + } + } + + @Override + public void handle(final HttpTestExchange exchange) throws IOException { + HttpChain.of(this.filters, this.handler).doFilter(exchange); + } + + private void doHandle(final Http2TestExchange exchange) throws IOException { + final HttpTestExchange adapted = HttpTestExchange.of(exchange); + try { + H3RootCtx.this.handle(adapted); + } catch (Throwable failure) { + handleFailure(adapted, failure); + throw failure; + } + } + } } public static void enableServerLogging() { @@ -1044,4 +1733,73 @@ public interface HttpServerAdapters { HttpTestServer.ServerLogging.enableLogging(); } + public default HttpClient.Builder newClientBuilderForH3() { + return createClientBuilderForH3(); + } + + /** + * {@return a client builder suitable for interacting with the specified + * version} + * The builder's {@linkplain HttpClient.Builder#version(Version) version}, + * {@linkplain HttpClient.Builder#proxy(ProxySelector) proxy selector} + * and {@linkplain HttpClient.Builder#sslContext(SSLContext) SSL context} + * are not set. + * @apiNote This method sets the {@linkplain HttpClient.Builder#localAddress(InetAddress) + * bind address} to the {@linkplain InetAddress#getLoopbackAddress() loopback address} + * if version is HTTP/3, the OS is Mac, and the OS version is 10.X, in order to + * avoid conflicting with system allocated ephemeral UDP ports. + * @param version the highest version the client is assumed to interact with. + */ + public static HttpClient.Builder createClientBuilderFor(Version version) { + var builder = HttpClient.newBuilder(); + return switch (version) { + case HTTP_3 -> configureForH3(builder); + default -> builder; + }; + } + + /** + * {@return a client builder suitable for interacting with HTTP/3} + * The builder's {@linkplain HttpClient.Builder#version(Version) version}, + * {@linkplain HttpClient.Builder#proxy(ProxySelector) proxy selector} + * and {@linkplain HttpClient.Builder#sslContext(SSLContext) SSL context} + * are not set. + * @apiNote This method sets the {@linkplain HttpClient.Builder#localAddress(InetAddress) + * bind address} to the {@linkplain InetAddress#getLoopbackAddress() loopback address} + * if version is HTTP/3, the OS is Mac, and the OS version is 10.X, in order to + * avoid conflicting with system allocated ephemeral UDP ports. + * @implSpec This is identical to calling {@link #createClientBuilderFor(Version) + * newClientBuilderFor(Version.HTTP_3)} or {@link #configureForH3(Builder) + * configureForH3(HttpClient.newBuilder())} + */ + public static HttpClient.Builder createClientBuilderForH3() { + return configureForH3(HttpClient.newBuilder()); + } + + /** + * Configure a builder to be suitable for a client that may send requests + * through HTTP/3. + * The builder's {@linkplain HttpClient.Builder#version(Version) version}, + * {@linkplain HttpClient.Builder#proxy(ProxySelector) proxy selector} + * and {@linkplain HttpClient.Builder#sslContext(SSLContext) SSL context} + * are not set. + * @apiNote This method sets the {@linkplain HttpClient.Builder#localAddress(InetAddress) + * bind address} to the {@linkplain InetAddress#getLoopbackAddress() loopback address} + * if the OS is Mac, and the OS version is 10.X, in order to + * avoid conflicting with system allocated ephemeral UDP ports. + * @return a client builder suitable for interacting with HTTP/3 + */ + public static HttpClient.Builder configureForH3(HttpClient.Builder builder) { + if (TestUtil.sysPortsMayConflict()) { + return builder.localAddress(InetAddress.getLoopbackAddress()); + } + return builder; + } + + public static InetAddress clientLocalBindAddress() { + if (TestUtil.sysPortsMayConflict()) { + return InetAddress.getLoopbackAddress(); + } + return new InetSocketAddress(0).getAddress(); + } } diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/RequestPathMatcherUtil.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/RequestPathMatcherUtil.java new file mode 100644 index 00000000000..e732f45efe4 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/RequestPathMatcherUtil.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.common; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Utility which parses a request path and finds a best match registered handler + */ +public class RequestPathMatcherUtil { + + public record Resolved(String bestMatchedPath, T handler) { + } + + /** + * Matches the {@code path} against the registered {@code pathHandlers} and returns the best + * matched handler. + * + * @param path The request path + * @param pathHandlers The handlers for each of the registered paths + * @param + * @return The resolved result or an {@linkplain Optional#empty() empty Optional} if no + * handler could be found for the {@code path} + * @throws NullPointerException if {@code pathHandlers} is null + */ + public static Optional> findHandler(final String path, + final Map pathHandlers) { + Objects.requireNonNull(pathHandlers, "pathHandlers is null"); + final String fpath = (path == null || path.isEmpty()) ? "/" : path; + final AtomicReference bestMatch = new AtomicReference<>(""); + final AtomicReference result = new AtomicReference<>(); + pathHandlers.forEach((key, value) -> { + if (fpath.startsWith(key) && key.length() > bestMatch.get().length()) { + bestMatch.set(key); + result.set(value); + } + }); + final T handler = result.get(); + if (handler == null) { + System.err.println("No handler found for path: " + path); + return Optional.empty(); + } + return Optional.of(new Resolved(bestMatch.get(), handler)); + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestServerConfigurator.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestServerConfigurator.java index a471f3ce07f..0d525a4f2c8 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestServerConfigurator.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestServerConfigurator.java @@ -57,11 +57,22 @@ public final class TestServerConfigurator extends HttpsConfigurator { @Override public void configure(final HttpsParameters params) { final SSLParameters sslParams = getSSLContext().getDefaultSSLParameters(); - final String hostname = serverAddr.getHostName(); - - final List sniMatchers = List.of(new ServerNameMatcher(hostname)); - sslParams.setSNIMatchers(sniMatchers); + addSNIMatcher(serverAddr, sslParams); // configure the server with these custom SSLParameters params.setSSLParameters(sslParams); } + + public static void addSNIMatcher(final InetAddress serverAddr, final SSLParameters sslParams) { + final String hostname; + if (serverAddr.isLoopbackAddress()) { + // when it's loopback address, don't rely on InetAddress.getHostName() to get us the + // hostname, since it has been observed on Windows setups that InetAddress.getHostName() + // can return an IP address (127.0.0.1) instead of the hostname for loopback address + hostname = "localhost"; + } else { + hostname = serverAddr.getHostName(); + } + final List sniMatchers = List.of(new ServerNameMatcher(hostname)); + sslParams.setSNIMatchers(sniMatchers); + } } diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestUtil.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestUtil.java new file mode 100644 index 00000000000..37ce49cd901 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/common/TestUtil.java @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2015, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.httpclient.test.lib.common; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Optional; + +import jdk.internal.util.OperatingSystem; + +public final class TestUtil { + + private TestUtil() {} + + public static boolean sysPortsMayConflict() { + if (OperatingSystem.isMacOS()) { + // syslogd udp_in module may be dynamically started and opens an udp4 port + // on the wildcard address. In addition, macOS will allow different processes + // to bind to the same port on the wildcard, if one uses udp4 and the other + // binds using udp46 (dual IPv4 IPv6 socket). + // Binding to the loopback (or a specific interface) instead of binding + // to the wildcard can prevent such conflicts. + return true; + } + return false; + } + + public static Optional chooseClientBindAddress() { + if (!TestUtil.sysPortsMayConflict()) { + return Optional.empty(); + } + final InetSocketAddress address = new InetSocketAddress( + InetAddress.getLoopbackAddress(), 0); + return Optional.of(address); + } + +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/BodyOutputStream.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/BodyOutputStream.java index c6eee5afabf..33cca60ff38 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/BodyOutputStream.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/BodyOutputStream.java @@ -26,6 +26,7 @@ package jdk.httpclient.test.lib.http2; import java.io.*; import java.nio.ByteBuffer; import java.util.Objects; +import java.util.concurrent.Semaphore; import jdk.internal.net.http.frame.DataFrame; import jdk.internal.net.http.frame.ResetFrame; @@ -39,7 +40,8 @@ public class BodyOutputStream extends OutputStream { final static byte[] EMPTY_BARRAY = new byte[0]; final int streamid; - int window; + // stream level send window, permits = bytes + final Semaphore window; volatile boolean closed; volatile BodyInputStream bis; volatile int resetErrorCode; @@ -48,7 +50,7 @@ public class BodyOutputStream extends OutputStream { final Queue outputQ; BodyOutputStream(int streamid, int initialWindow, Http2TestServerConnection conn) { - this.window = initialWindow; + this.window = new Semaphore(initialWindow); this.streamid = streamid; this.conn = conn; this.outputQ = conn.outputQ; @@ -57,9 +59,8 @@ public class BodyOutputStream extends OutputStream { // called from connection reader thread as all incoming window // updates are handled there. - synchronized void updateWindow(int update) { - window += update; - notifyAll(); + void updateWindow(int update) { + window.release(update); } void waitForWindow(int demand) throws InterruptedException { @@ -69,23 +70,8 @@ public class BodyOutputStream extends OutputStream { conn.obtainConnectionWindow(demand); } - public void waitForStreamWindow(int amount) throws InterruptedException { - int demand = amount; - try { - synchronized (this) { - while (amount > 0) { - int n = Math.min(amount, window); - amount -= n; - window -= n; - if (amount > 0) { - wait(); - } - } - } - } catch (Throwable t) { - window += (demand - amount); - throw t; - } + public void waitForStreamWindow(int demand) throws InterruptedException { + window.acquire(demand); } public void goodToGo() { @@ -176,7 +162,7 @@ public class BodyOutputStream extends OutputStream { sendEndStream(); if (bis!= null && bis.unconsumed()) { // Send a reset if there is still unconsumed data in the input stream - sendReset(EMPTY_BARRAY, 0, 0, ResetFrame.NO_ERROR); + sendReset(ResetFrame.NO_ERROR); } } catch (IOException ex) { ex.printStackTrace(); @@ -187,12 +173,18 @@ public class BodyOutputStream extends OutputStream { send(EMPTY_BARRAY, 0, 0, DataFrame.END_STREAM); } - public void sendReset(byte[] buf, int offset, int len, int flags) throws IOException { - ByteBuffer buffer = ByteBuffer.allocate(len); - buffer.put(buf, offset, len); - buffer.flip(); + public void sendReset(int resetErrorCode) throws IOException { assert streamid != 0; - ResetFrame rf = new ResetFrame(streamid, flags); + ResetFrame rf = new ResetFrame(streamid, resetErrorCode); outputQ.put(rf); } + + public void reset(int resetErrorCode) throws IOException { + if (closed) return; + synchronized (this) { + if (closed) return; + this.closed = true; + } + sendReset(resetErrorCode); + } } diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/EchoHandler.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/EchoHandler.java index 24ac59d96dc..047b52b3699 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/EchoHandler.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/EchoHandler.java @@ -24,11 +24,12 @@ package jdk.httpclient.test.lib.http2; import java.io.*; +import java.net.http.HttpHeaders; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import jdk.internal.net.http.common.HttpHeadersImpl; +import jdk.internal.net.http.common.HttpHeadersBuilder; public class EchoHandler implements Http2Handler { static final Path CWD = Paths.get("."); @@ -41,8 +42,8 @@ public class EchoHandler implements Http2Handler { try { System.err.println("EchoHandler received request to " + t.getRequestURI()); InputStream is = t.getRequestBody(); - HttpHeadersImpl map = t.getRequestHeaders(); - HttpHeadersImpl map1 = t.getResponseHeaders(); + HttpHeaders map = t.getRequestHeaders(); + HttpHeadersBuilder map1 = t.getResponseHeaders(); map1.addHeader("X-Hello", "world"); map1.addHeader("X-Bye", "universe"); String fixedrequest = map.firstValue("XFixed").orElse(null); diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2EchoHandler.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2EchoHandler.java index 7b13724ac51..fd0b03ac691 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2EchoHandler.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2EchoHandler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2005, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2005, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,8 +27,11 @@ import java.net.http.HttpHeaders; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; + import jdk.internal.net.http.common.HttpHeadersBuilder; +import static jdk.test.lib.Asserts.assertFileContentsEqual; + public class Http2EchoHandler implements Http2Handler { static final Path CWD = Paths.get("."); @@ -49,27 +52,42 @@ public class Http2EchoHandler implements Http2Handler { File outfile = Files.createTempFile(CWD, "foo", "bar").toFile(); //System.err.println ("QQQ = " + outfile.toString()); FileOutputStream fos = new FileOutputStream(outfile); - int count = (int) is.transferTo(fos); - System.err.printf("EchoHandler read %d bytes\n", count); + long count = is.transferTo(fos); + System.err.printf("EchoHandler read %s bytes\n", count); is.close(); fos.close(); InputStream is1 = new FileInputStream(outfile); OutputStream os = null; + + Path check = map.firstValue("X-Compare").map((String s) -> Path.of(s)).orElse(null); + if (check != null) { + System.err.println("EchoHandler checking file match: " + check); + try { + assertFileContentsEqual(check, outfile.toPath()); + } catch (Throwable x) { + System.err.println("Files do not match: " + x); + t.sendResponseHeaders(500, -1); + outfile.delete(); + os.close(); + return; + } + } + // return the number of bytes received (no echo) String summary = map.firstValue("XSummary").orElse(null); if (fixedrequest != null && summary == null) { t.sendResponseHeaders(200, count); os = t.getResponseBody(); - int count1 = (int)is1.transferTo(os); - System.err.printf("EchoHandler wrote %d bytes\n", count1); + long count1 = is1.transferTo(os); + System.err.printf("EchoHandler wrote %s bytes\n", count1); } else { t.sendResponseHeaders(200, 0); os = t.getResponseBody(); - int count1 = (int)is1.transferTo(os); - System.err.printf("EchoHandler wrote %d bytes\n", count1); + long count1 = is1.transferTo(os); + System.err.printf("EchoHandler wrote %s bytes\n", count1); if (summary != null) { - String s = Integer.toString(count); + String s = Long.toString(count); os.write(s.getBytes()); } } diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2Handler.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2Handler.java index 8871f0c2cd5..c1a5538b95d 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2Handler.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2Handler.java @@ -37,6 +37,5 @@ public interface Http2Handler { * client and used to send the response * @throws NullPointerException if exchange is null */ - void handle (Http2TestExchange exchange) throws IOException; + void handle(Http2TestExchange exchange) throws IOException; } - diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2RedirectHandler.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2RedirectHandler.java index 69e4344aff2..ab817d0c79d 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2RedirectHandler.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2RedirectHandler.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.function.Supplier; + import jdk.internal.net.http.common.HttpHeadersBuilder; public class Http2RedirectHandler implements Http2Handler { @@ -48,8 +49,8 @@ public class Http2RedirectHandler implements Http2Handler { System.err.println("Redirecting to: " + location); HttpHeadersBuilder headersBuilder = t.getResponseHeaders(); headersBuilder.addHeader("Location", location); - t.sendResponseHeaders(redirectCode(), 1024); - byte[] bb = new byte[1024]; + byte[] bb = getResponseBytes(); + t.sendResponseHeaders(redirectCode(), bb.length == 0 ? -1 : bb.length); OutputStream os = t.getResponseBody(); os.write(bb); os.close(); @@ -57,6 +58,10 @@ public class Http2RedirectHandler implements Http2Handler { } } + protected byte[] getResponseBytes() { + return new byte[1024]; + } + protected int redirectCode() { return 301; } diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchange.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchange.java index 828c939f53f..77c9e831fb0 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchange.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchange.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2017, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,12 +28,18 @@ import java.io.InputStream; import java.io.OutputStream; import java.net.URI; import java.net.InetSocketAddress; +import java.net.http.HttpClient.Version; import java.net.http.HttpHeaders; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.BiPredicate; import javax.net.ssl.SSLSession; + import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.quic.VariableLengthEncoder; import jdk.internal.net.http.frame.Http2Frame; public interface Http2TestExchange { @@ -72,19 +78,177 @@ public interface Http2TestExchange { boolean serverPushAllowed(); - void serverPush(URI uri, HttpHeaders headers, InputStream content); + default void serverPush(URI uri, HttpHeaders headers, InputStream content) + throws IOException { + serverPush(uri, headers, HttpHeaders.of(Map.of(), (n,v) -> true), content); + } + + void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) + throws IOException; + + // For HTTP/3 only: send push promise + push stream, returns pushId + // + + /** + * For HTTP/3 only: send push promise + push stream, returns pushId. + * The pushId can be promised again using {@link + * #sendPushId(long, URI, HttpHeaders)} + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param uri the push promise URI + * @param reqHeaders the push promise request headers + * @param rspHeaders the push promise response headers + * @param content the push response body + * + * @return the pushId used to push the promise + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + default long serverPushWithId(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) + throws IOException { + throw new UnsupportedOperationException("serverPushWithId " + getExchangeVersion()); + } + + /** + * For HTTP/3 only: only sends a push promise frame. If a positive + * pushId is provided, uses the provided pushId and returns it. + * Otherwise, a new pushId will be allocated and returned. + * This allows to send an additional promise after {@linkplain + * #serverPushWithId(URI, HttpHeaders, HttpHeaders, InputStream) sending the first}, + * or to send one or several push promise frames before {@linkplain + * #sendPushResponse(long, URI, HttpHeaders, HttpHeaders, InputStream) sending + * the response}. + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param pushId the pushId to use, or {@code -1} if a new + * pushId should be allocated. + * @param uri the push promise URI + * @param headers the push promise request headers + * + * @return the given pushId, if positive, otherwise the new allocated pushId + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + default long sendPushId(long pushId, URI uri, HttpHeaders headers) throws IOException { + throw new UnsupportedOperationException("sendPushId with " + getExchangeVersion()); + } + + /** + * For HTTP/3 only: sends an HTTP/3 CANCEL_PUSH frame to cancel + * a push that has been promised by either {@link + * #serverPushWithId(URI, HttpHeaders, HttpHeaders, InputStream)} or {@link + * #sendPushId(long, URI, HttpHeaders)}. + * + * This method just sends a CANCEL_PUSH frame. + * Note that if the push stream has already been opened this + * sending a CANCEL_PUSH frame may have no effect. + * + * @apiNote + * No check is performed on the provided pushId + * + * @implSpec + * The default implementation of this method throws {@link + * UnsupportedOperationException} + * + * @param pushId the cancelled pushId + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + default void cancelPushId(long pushId) throws IOException { + throw new UnsupportedOperationException("cancelPush with " + getExchangeVersion()); + } + + /** + * For HTTP/3 only: opens an HTTP/3 PUSH_STREAM to send a + * push promise response headers and body. + * + * @apiNote + * No check is performed on the provided pushId + * + * @param pushId a positive pushId obtained from {@link + * #sendPushId(long, URI, HttpHeaders)} + * @param uri the push request URI + * @param reqHeaders the push promise request headers + * @param rspHeaders the push promise response headers + * @param content the push response body + * + * @throws IOException if an error occurs + * @throws UnsupportedOperationException if the exchange is not {@link + */ + default void sendPushResponse(long pushId, URI uri, + HttpHeaders reqHeaders, + HttpHeaders rspHeaders, + InputStream content) + throws IOException { + throw new UnsupportedOperationException("sendPushResponse with " + getExchangeVersion()); + } + + default void requestStopSending(long errorCode) { + throw new UnsupportedOperationException("sendStopSendingFrame with " + getExchangeVersion()); + } default void sendFrames(List frames) throws IOException { throw new UnsupportedOperationException("not implemented"); } /** - * Send a PING on this exchanges connection, and completes the returned CF + * For HTTP/3 only: waits until the given {@code pushId} is allowed by + * the HTTP/3 peer. + * + * @implSpec + * The default implementation of this method returns the larger + * possible variable length integer. + * + * @param pushId a pushId + * + * @return the upper bound pf the maximum pushId allowed (exclusive) + * + * @throws UnsupportedOperationException if the exchange is not {@link + * #getExchangeVersion() HTTP_3} + */ + default long waitForMaxPushId(long pushId) throws InterruptedException { + return VariableLengthEncoder.MAX_ENCODED_INTEGER; + } + + default Encoder qpackEncoder() { + throw new UnsupportedOperationException("QPack encoder not supported: " + getExchangeVersion()); + } + + default CompletableFuture clientHttp3Settings() { + throw new UnsupportedOperationException("HTTP/3 client connection settings not supported: " + getExchangeVersion()); + } + + /** + * Send a PING on this exchange connection, and completes the returned CF * with the number of milliseconds it took to get a valid response. * It may also complete exceptionally */ CompletableFuture sendPing(); + default void close(IOException closed) throws IOException { + close(); + } + + default Version getServerVersion() { return Version.HTTP_2; } + + default Version getExchangeVersion() { return Version.HTTP_2; } + + default void resetStream(long code) throws IOException { + throw new UnsupportedOperationException("resetStream with " + getExchangeVersion()); + } + /** * {@return the identification of the connection on which this exchange is being * processed} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchangeImpl.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchangeImpl.java index 12324b3ba0b..8bc0fd2473c 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchangeImpl.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestExchangeImpl.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -149,12 +149,12 @@ public class Http2TestExchangeImpl implements Http2TestExchange { long clen = responseLength > 0 ? responseLength : 0; rspheadersBuilder.setHeader("Content-length", Long.toString(clen)); } - - rspheadersBuilder.setHeader(":status", Integer.toString(rCode)); - HttpHeaders headers = rspheadersBuilder.build(); - + final HttpHeadersBuilder pseudoHeadersBuilder = new HttpHeadersBuilder(); + pseudoHeadersBuilder.setHeader(":status", Integer.toString(rCode)); + final HttpHeaders pseudoHeaders = pseudoHeadersBuilder.build(); + final HttpHeaders headers = rspheadersBuilder.build(); ResponseHeaders response - = new ResponseHeaders(headers, insertionPolicy); + = new ResponseHeaders(pseudoHeaders, headers, insertionPolicy); response.streamid(streamid); response.setFlag(HeaderFrame.END_HEADERS); @@ -184,6 +184,13 @@ public class Http2TestExchangeImpl implements Http2TestExchange { conn.sendFrames(frames); } + @Override + public void resetStream(long code) throws IOException { + // will close the os if not closed. + // reset will be sent only if the os is not closed. + os.sendReset((int) code); + } + @Override public InetSocketAddress getRemoteAddress() { return (InetSocketAddress) conn.socket.getRemoteSocketAddress(); @@ -210,18 +217,18 @@ public class Http2TestExchangeImpl implements Http2TestExchange { } @Override - public void serverPush(URI uri, HttpHeaders headers, InputStream content) { + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) { HttpHeadersBuilder headersBuilder = new HttpHeadersBuilder(); headersBuilder.setHeader(":method", "GET"); headersBuilder.setHeader(":scheme", uri.getScheme()); headersBuilder.setHeader(":authority", uri.getAuthority()); headersBuilder.setHeader(":path", uri.getPath()); - for (Map.Entry> entry : headers.map().entrySet()) { + for (Map.Entry> entry : reqHeaders.map().entrySet()) { for (String value : entry.getValue()) headersBuilder.addHeader(entry.getKey(), value); } HttpHeaders combinedHeaders = headersBuilder.build(); - OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, content); + OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, rspHeaders, content); pp.setFlag(HeaderFrame.END_HEADERS); try { diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServer.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServer.java index a4d696ee53c..543c4079c14 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServer.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServer.java @@ -24,22 +24,39 @@ package jdk.httpclient.test.lib.http2; import java.io.IOException; -import java.net.*; -import java.util.*; +import java.net.BindException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Predicate; import javax.net.ServerSocketFactory; +import javax.net.ssl.SNIMatcher; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLServerSocket; +import jdk.httpclient.test.lib.common.RequestPathMatcherUtil; +import jdk.httpclient.test.lib.common.RequestPathMatcherUtil.Resolved; +import jdk.httpclient.test.lib.common.ServerNameMatcher; +import jdk.httpclient.test.lib.http3.Http3TestServer; import jdk.internal.net.http.frame.ErrorFrame; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.quic.QuicVersion; +import jdk.httpclient.test.lib.quic.QuicServer; /** * Waits for incoming TCP connections from a client and establishes @@ -48,7 +65,10 @@ import jdk.internal.net.http.frame.ErrorFrame; * Http2Handler on additional threads. All threads * obtained from the supplied ExecutorService. */ -public class Http2TestServer implements AutoCloseable { +public final class Http2TestServer implements AutoCloseable { + + record AltSvcAddr(String host, int port, InetSocketAddress original) {} + static final AtomicLong IDS = new AtomicLong(); final long id = IDS.incrementAndGet(); final ServerSocket server; @@ -61,7 +81,12 @@ public class Http2TestServer implements AutoCloseable { final String serverName; final Set connections; final Properties properties; + volatile Http3TestServer h3Server; + volatile AltSvcAddr h3AltSvcAddr; final String name; + private final SNIMatcher sniMatcher; + volatile boolean altSvcAsRespHeader; + private final ReentrantLock serverLock = new ReentrantLock(); // request approver which takes the server connection key as the input private volatile Predicate newRequestApprover; @@ -179,7 +204,6 @@ public class Http2TestServer implements AutoCloseable { throws Exception { this.name = "TestServer(%d)".formatted(id); - this.serverName = serverName; this.supportsHTTP11 = supportsHTTP11; if (secure) { if (context != null) @@ -192,12 +216,166 @@ public class Http2TestServer implements AutoCloseable { server = initPlaintext(port, backlog); } this.secure = secure; + this.serverName = serverName; + this.sniMatcher = serverName == null + ? new ServerNameMatcher(localAddr.getHostName()) + : new ServerNameMatcher(this.serverName); this.exec = exec == null ? createExecutor(name) : exec; this.handlers = Collections.synchronizedMap(new HashMap<>()); this.properties = properties == null ? new Properties() : properties; this.connections = ConcurrentHashMap.newKeySet(); } + /** + * {@return the {@link SNIMatcher} configured for this server. Returns {@code null} + * if none is configured} + */ + public SNIMatcher getSniMatcher() { + return this.sniMatcher; + } + + /** + * Creates a H3 server which will attempt to use the same host/port as the one used + * by this current H2 server (except that the H3 server will use UDP). If that host/port + * isn't available for the H3 server then it uses an ephemeral port on loopback address. + * That H3 server then acts as the alternate service for this H2 server and will be advertised + * as such when this H2 server responds to any HTTP requests. + */ + public Http2TestServer enableH3AltServiceOnSamePort() throws IOException { + this.enableH3AltService(false); + return this; + } + + /** + * Creates a H3 server that acts as the alternate service for this H2 server and will be advertised + * as such when this H2 server responds to any HTTP requests. + *

    + * The H3 server will be created using an ephemeral port on loopback address. + *

    + */ + public Http2TestServer enableH3AltServiceOnEphemeralPort() throws IOException { + return enableH3AltService(true); + } + + /** + * Creates a H3 server that acts as the alternate service for this H2 server and will be advertised + * as such when this H2 server responds to any HTTP requests. + *

    + * The H3 server will be created using an ephemeral port on loopback address. + *

    + * The server will switch to selected QUIC version using compatible or incompatible negotiation. + */ + public Http2TestServer enableH3AltServiceOnEphemeralPortWithVersion(QuicVersion version, boolean compatible) throws IOException { + return enableH3AltService(0, new QuicVersion[]{version}, compatible); + } + + public Http2TestServer enableH3AltServiceOnPort(int port) throws IOException { + return enableH3AltService(port, new QuicVersion[]{QuicVersion.QUIC_V1}); + } + + /** + * Creates a H3 server that acts as the alternate service for this H2 server and will be advertised + * as such when this H2 server responds to any HTTP requests. + *

    + * If {@code useEphemeralAddr} is {@code true} then the H3 server will be created using an + * ephemeral port on loopback address. Otherwise, an attempt will be made by this current H2 + * server (except that the H3 server will use UDP). If that attempt fails then this method + * implementation will fallback to using an ephemeral port for creating the H3 server. + *

    + * @param useEphemeralAddr If true then the H3 server will be created using an ephemeral port + */ + Http2TestServer enableH3AltService(final boolean useEphemeralAddr) throws IOException { + return enableH3AltService(useEphemeralAddr ? 0 : getAddress().getPort(), new QuicVersion[]{QuicVersion.QUIC_V1}); + } + + Http2TestServer enableH3AltService(final int port, QuicVersion[] quicVersions) throws IOException { + return enableH3AltService(port, quicVersions, false); + } + + Http2TestServer enableH3AltService(final int port, QuicVersion[] quicVersions, boolean compatible) throws IOException { + if (this.h3Server != null) { + // already enabled + // TODO: throw exception instead? + return this; + } + if (!secure) { + throw new IllegalStateException("Cannot enable H3 alt service for a non-secure H2 server"); + } + serverLock.lock(); + try { + if (this.h3Server != null) { + return this; + } + QuicServer.Builder quicServerBuilder = Http3TestServer.quicServerBuilder(); + quicServerBuilder.sslContext(this.sslContext).serverId("h2-server-" + id) + .executor(this.exec) + .sniMatcher(this.sniMatcher) + .availableVersions(quicVersions) + .compatibleNegotiation(compatible) + .appErrorCodeToString(Http3Error::stringForCode) + .bindAddress(new InetSocketAddress(InetAddress.getLoopbackAddress(), port)); + try { + this.h3Server = new Http3TestServer(quicServerBuilder.build(), this::getHandlerFor); + } catch (BindException be) { + if (port == 0) { + // this means that we already attempted to bind with an ephemeral port + // and it failed, so no need to attempt again. Just throw back the original + // exception + throw be; + } + // try with an ephemeral port + quicServerBuilder.bindAddress(new InetSocketAddress( + InetAddress.getLoopbackAddress(), 0)); + this.h3Server = new Http3TestServer(quicServerBuilder.build(), this::getHandlerFor); + } + // we keep track of the InetSocketAddress.getHostString() when the alt service address + // was created and keep using the same host string irrespective of whether the + // underlying/original InetSocketAddress' hostname resolution could potentially have + // changed the value returned by getHostString(). This allows us to use a consistent + // host in the alt-svc that we advertise. + this.h3AltSvcAddr = new AltSvcAddr(h3Server.getAddress().getHostString(), + h3Server.getAddress().getPort(), h3Server.getAddress()); + } finally { + serverLock.unlock(); + } + return this; + } + + public Optional getH3AltService() { + return Optional.ofNullable(h3Server); + } + + /** + * {@return true if this H2 server is configured with an H3 alternate service and that + * H3 alternate service listens on the same host and port as that of this H2 server (except + * that H3 uses UDP). Returns false otherwise} + */ + public boolean supportsH3DirectConnection() { + final AltSvcAddr h3Addr = this.h3AltSvcAddr; + if (h3Addr == null) { + return false; + } + final InetSocketAddress h2Addr = this.getAddress(); + return h2Addr.equals(h3Addr.original); + } + + /** + * Controls whether this H2 server sends a alt-svc response header or an AltSvc frame + * when H3 alternate service is enable on this server. The alt-svc header or the frame + * will be sent whenever this server responds next to an HTTP request. + * + * @param enable If {@code true} then the alt-svc response header is sent. Else AltSvc frame + * is sent. + * @return The current Http2TestServer + */ + public Http2TestServer advertiseAltSvcResponseHeader(final boolean enable) { + // TODO: this is only set as a flag currently. We need to implement the logic + // which sends the alt-svc response header. we currently send a alt-svc frame + // whenever a alt-svc is present. + this.altSvcAsRespHeader = enable; + return this; + } + /** * Adds the given handler for the given path */ @@ -218,24 +396,14 @@ public class Http2TestServer implements AutoCloseable { } Http2Handler getHandlerFor(String path) { - if (path == null || path.equals("")) - path = "/"; - - final String fpath = path; - AtomicReference bestMatch = new AtomicReference<>(""); - AtomicReference href = new AtomicReference<>(); - - handlers.forEach((key, value) -> { - if (fpath.startsWith(key) && key.length() > bestMatch.get().length()) { - bestMatch.set(key); - href.set(value); - } - }); - Http2Handler handler = href.get(); - if (handler == null) - throw new RuntimeException("No handler found for path " + path); - System.err.println(name + ": Using handler for: " + bestMatch.get()); - return handler; + final Optional> match = RequestPathMatcherUtil.findHandler(path, handlers); + if (match.isEmpty()) { + // no handler available for the path + return null; + } + final Resolved resolved = match.get(); + System.err.println(name + ": Using handler for: " + resolved.bestMatchedPath()); + return resolved.handler(); } final ServerSocket initPlaintext(int port, int backlog) throws Exception { @@ -245,7 +413,16 @@ public class Http2TestServer implements AutoCloseable { return ss; } - public synchronized void stop() { + public void stop() { + serverLock.lock(); + try { + implStop(); + } finally { + serverLock.unlock(); + } + } + + private void implStop() { // TODO: clean shutdown GoAway stopping = true; System.err.printf("%s: stopping %d connections\n", name, connections.size()); @@ -255,6 +432,12 @@ public class Http2TestServer implements AutoCloseable { try { server.close(); } catch (IOException e) {} + try { + var h3Server = this.h3Server; + if (h3Server != null) { + h3Server.close(); + } + } catch (IOException e) {} exec.shutdownNow(); } @@ -284,6 +467,16 @@ public class Http2TestServer implements AutoCloseable { return serverName; } + private void putConnection(InetSocketAddress addr, Http2TestServerConnection c) { + serverLock.lock(); + try { + if (!stopping) + connections.add(c); + } finally { + serverLock.unlock(); + } + } + public void setRequestApprover(final Predicate approver) { this.newRequestApprover = approver; } @@ -292,13 +485,13 @@ public class Http2TestServer implements AutoCloseable { return this.newRequestApprover; } - private synchronized void putConnection(InetSocketAddress addr, Http2TestServerConnection c) { - if (!stopping) - connections.add(c); - } - - private synchronized void removeConnection(InetSocketAddress addr, Http2TestServerConnection c) { - connections.remove(c); + private void removeConnection(InetSocketAddress addr, Http2TestServerConnection c) { + serverLock.lock(); + try { + connections.remove(c); + } finally { + serverLock.unlock(); + } } record AcceptedConnection(Http2TestServer server, @@ -350,6 +543,10 @@ public class Http2TestServer implements AutoCloseable { * Starts a thread which waits for incoming connections. */ public void start() { + var h3Server = this.h3Server; + if (h3Server != null) { + h3Server.start(); + } exec.submit(() -> { try { while (!stopping) { diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServerConnection.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServerConnection.java index deec3ec2c24..20668d281c8 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServerConnection.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/Http2TestServerConnection.java @@ -41,25 +41,19 @@ import jdk.internal.net.http.frame.SettingsFrame; import jdk.internal.net.http.frame.WindowUpdateFrame; import jdk.internal.net.http.hpack.Decoder; import jdk.internal.net.http.hpack.DecodingCallback; -import jdk.internal.net.http.hpack.Encoder; import sun.net.www.http.ChunkedInputStream; import sun.net.www.http.HttpClient; -import javax.net.ssl.SNIHostName; import javax.net.ssl.SNIMatcher; -import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; -import javax.net.ssl.StandardConstants; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.Closeable; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.io.UncheckedIOException; -import java.net.InetAddress; import java.net.Socket; import java.net.URI; import java.net.URISyntaxException; @@ -74,20 +68,25 @@ import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Properties; import java.util.Random; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.Semaphore; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiPredicate; import java.util.function.Consumer; + +import jdk.internal.net.http.frame.AltSvcFrame; + import java.util.function.Predicate; import static java.nio.charset.StandardCharsets.ISO_8859_1; +import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.UTF_8; import static jdk.internal.net.http.frame.ErrorFrame.REFUSED_STREAM; import static jdk.internal.net.http.frame.SettingsFrame.DEFAULT_MAX_FRAME_SIZE; @@ -125,7 +124,6 @@ public class Http2TestServerConnection { private final AtomicInteger goAwayRequestStreamId = new AtomicInteger(-1); final static ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); - final static byte[] EMPTY_BARRAY = new byte[0]; final Random random; final static byte[] clientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".getBytes(); @@ -175,7 +173,7 @@ public class Http2TestServerConnection { if (socket instanceof SSLSocket) { SSLSocket sslSocket = (SSLSocket)socket; - handshake(server.serverName(), sslSocket); + handshake(server.getSniMatcher(), sslSocket); if (!server.supportsHTTP11 && !"h2".equals(sslSocket.getApplicationProtocol())) { throw new IOException("Unexpected ALPN: [" + sslSocket.getApplicationProtocol() + "]"); } @@ -312,39 +310,12 @@ public class Http2TestServerConnection { outputQ.put(frame); } - private static boolean compareIPAddrs(InetAddress addr1, String host) { - try { - InetAddress addr2 = InetAddress.getByName(host); - return addr1.equals(addr2); - } catch (IOException e) { - throw new UncheckedIOException(e); + private static void handshake(final SNIMatcher sniMatcher, final SSLSocket sock) throws IOException { + if (sniMatcher != null) { + final SSLParameters params = sock.getSSLParameters(); + params.setSNIMatchers(List.of(sniMatcher)); + sock.setSSLParameters(params); } - } - - private static void handshake(String name, SSLSocket sock) throws IOException { - if (name == null) { - sock.startHandshake(); // blocks until handshake done - return; - } else if (name.equals("localhost")) { - name = "localhost"; - } - final String fname = name; - final InetAddress addr1 = InetAddress.getByName(name); - SSLParameters params = sock.getSSLParameters(); - SNIMatcher matcher = new SNIMatcher(StandardConstants.SNI_HOST_NAME) { - public boolean matches (SNIServerName n) { - String host = ((SNIHostName)n).getAsciiName(); - if (host.equals("localhost")) - host = "localhost"; - boolean cmp = host.equalsIgnoreCase(fname); - if (cmp) - return true; - return compareIPAddrs(addr1, host); - } - }; - List list = List.of(matcher); - params.setSNIMatchers(list); - sock.setSSLParameters(params); sock.startHandshake(); // blocks until handshake done } @@ -356,8 +327,9 @@ public class Http2TestServerConnection { if (stopping) return; stopping = true; - System.err.printf(server.name + ": Server connection to %s stopping. %d streams\n", - socket.getRemoteSocketAddress().toString(), streams.size()); + System.err.printf(server.name + ": Server connection to %s stopping (%s). %d streams\n", + socket.getRemoteSocketAddress().toString(), + (error == -1 ? "no error" : ("error="+error)), streams.size()); streams.forEach((i, q) -> { q.orderlyClose(); }); @@ -374,15 +346,20 @@ public class Http2TestServerConnection { private void readPreface() throws IOException { int len = clientPreface.length; byte[] bytes = new byte[len]; + System.err.println("reading preface"); int n = is.readNBytes(bytes, 0, len); - if (Arrays.compare(clientPreface, bytes) != 0) { - String msg = String.format("Invalid preface: read %s/%s bytes", n, len); - System.err.println(server.name + ": " + msg); - throw new IOException(msg +": \"" + - new String(bytes, 0, n, ISO_8859_1) - .replace("\r", "\\r") - .replace("\n", "\\n") - + "\""); + if (n >= 0) { + if (Arrays.compare(clientPreface, bytes) != 0) { + String msg = String.format("Invalid preface: read %s/%s bytes", n, len); + System.err.println(server.name + ": " + msg); + throw new IOException(msg +": \"" + + new String(bytes, 0, n, ISO_8859_1) + .replace("\r", "\\r") + .replace("\n", "\\n") + + "\""); + } + } else { + throw new IOException("EOF while reading preface"); } } @@ -686,7 +663,7 @@ public class Http2TestServerConnection { // all other streams created here @SuppressWarnings({"rawtypes","unchecked"}) - void createStream(HeaderFrame frame) throws IOException { + private boolean createStream(HeaderFrame frame, Http2TestServer.AltSvcAddr altSvcAddr) throws IOException { List frames = new LinkedList<>(); frames.add(frame); int streamid = frame.streamid(); @@ -725,7 +702,7 @@ public class Http2TestServerConnection { if (disallowedHeader.isPresent()) { throw new IOException("Unexpected HTTP2-Settings in headers:" + headers); } - + boolean altSvcSent = false; // skip processing the request if the server is configured to do so final String connKey = connectionKey(); final String path = headers.firstValue(":path").orElse(""); @@ -734,10 +711,21 @@ public class Http2TestServerConnection { + " and sending GOAWAY on server connection " + connKey + ", for request: " + path); sendGoAway(ErrorFrame.NO_ERROR); - return; + return altSvcSent; } + Queue q = new Queue(sentinel); streams.put(streamid, q); + + if (altSvcAddr != null) { + String originHost = headers.firstValue("host") + .or(() -> headers.firstValue(":authority")) + .orElse(null); + if (originHost != null) { + altSvcSent = sendAltSvc(originHost, altSvcAddr); + } + } + // keep track of the largest request id that we have processed int currentLargest = maxProcessedRequestStreamId.get(); while (streamid > currentLargest) { @@ -749,6 +737,7 @@ public class Http2TestServerConnection { exec.submit(() -> { handleRequest(headers, q, streamid, endStreamReceived); }); + return altSvcSent; } // runs in own thread. Handles request from start to finish. Incoming frames @@ -793,14 +782,19 @@ public class Http2TestServerConnection { headers, rspheadersBuilder, uri, bis, getSSLSession(), bos, this, pushAllowed); - // give to user - Http2Handler handler = server.getHandlerFor(uri.getPath()); - - // Need to pass the BodyInputStream reference to the BodyOutputStream, so it can determine if the stream - // must be reset due to the BodyInputStream not being consumed by the handler when invoked. - if (bis instanceof BodyInputStream bodyInputStream) bos.bis = bodyInputStream; - + final String reqPath = uri.getPath(); + // locate a handler for the request + final Http2Handler handler = server.getHandlerFor(reqPath); try { + // no handler available for the request path, respond with 404 + if (handler == null) { + respondForMissingHandler(exchange); + return; + } + // Need to pass the BodyInputStream reference to the BodyOutputStream, so it can determine if the stream + // must be reset due to the BodyInputStream not being consumed by the handler when invoked. + if (bis instanceof BodyInputStream bodyInputStream) bos.bis = bodyInputStream; + handler.handle(exchange); } catch (IOException closed) { if (bos.closed) { @@ -821,6 +815,16 @@ public class Http2TestServerConnection { close(-1); } } + private void respondForMissingHandler(final Http2TestExchange exchange) + throws IOException { + final byte[] responseBody = (this.getClass().getSimpleName() + + " - No handler available to handle request " + + exchange.getRequestURI()).getBytes(US_ASCII); + try (final OutputStream os = exchange.getResponseBody()) { + exchange.sendResponseHeaders(404, responseBody.length); + os.write(responseBody); + } + } public void sendFrames(List frames) throws IOException { synchronized (outputQ) { @@ -845,6 +849,7 @@ public class Http2TestServerConnection { @SuppressWarnings({"rawtypes","unchecked"}) void readLoop() { try { + boolean altSvcSent = false; while (!stopping) { Http2Frame frame = readFrameImpl(); if (frame == null) { @@ -885,7 +890,9 @@ public class Http2TestServerConnection { outputQ.put(rst); continue; } - createStream((HeadersFrame) frame); + final Http2TestServer.AltSvcAddr altSvcAddr = server.h3AltSvcAddr; + final boolean sendAltSvc = secure && !altSvcSent && altSvcAddr != null; + altSvcSent = createStream((HeadersFrame) frame, sendAltSvc ? altSvcAddr : null); } } else { if (q == null && !pushStreams.contains(stream)) { @@ -896,9 +903,12 @@ public class Http2TestServerConnection { } if (frame.type() == WindowUpdateFrame.TYPE) { WindowUpdateFrame wup = (WindowUpdateFrame) frame; - synchronized (updaters) { + updatersLock.lock(); + try { Consumer r = updaters.get(stream); r.accept(wup.getUpdate()); + } finally { + updatersLock.unlock(); } } else if (frame.type() == ResetFrame.TYPE) { // do orderly close on input q @@ -949,6 +959,28 @@ public class Http2TestServerConnection { } } + boolean sendAltSvc(final String originHost, final Http2TestServer.AltSvcAddr altSvcAddr) { + Objects.requireNonNull(originHost); + System.err.printf("TestServer: AltSvcFrame for: %s%n", originHost); + try { + URI url = new URI("https://" + originHost); + String origin = url.toASCIIString(); + String svc = "h3=\"" + altSvcAddr.host() + ":" + altSvcAddr.port() + "\""; + svc = "fooh2=\":443\"; ma=2592000; persist=1, " + svc; + svc = svc + ", bar3=\":446\"; ma=2592000; persist=1"; + svc = svc + ", h3-34=\"" + altSvcAddr.host() + ":" + altSvcAddr.port() + +"\"; ma=2592000; persist=1"; + AltSvcFrame frame = new AltSvcFrame(0, 0, Optional.of(origin), svc); + System.err.printf("TestServer: Sending AltSvcFrame for: %s [%s]%n", origin, svc); + outputQ.put(frame); + return true; + } catch (IOException | URISyntaxException x) { + System.err.println("TestServer: Failed to send AltSvcFrame: " + x); + x.printStackTrace(); + } + return false; + } + static boolean isClientStreamId(int streamid) { return (streamid & 0x01) == 0x01; } @@ -967,12 +999,13 @@ public class Http2TestServerConnection { public List encodeHeaders(HttpHeaders headers, BiPredicate insertionPolicy) { List buffers = new LinkedList<>(); - + var entrySet = headers.map().entrySet(); + if (entrySet.isEmpty()) return buffers; ByteBuffer buf = getBuffer(); boolean encoded; headersLock.lock(); try { - for (Map.Entry> entry : headers.map().entrySet()) { + for (Map.Entry> entry : entrySet) { List values = entry.getValue(); String key = entry.getKey().toLowerCase(); for (String value : values) { @@ -998,6 +1031,7 @@ public class Http2TestServerConnection { /** Encodes an ordered list of headers. */ public List encodeHeadersOrdered(List> headers) { List buffers = new LinkedList<>(); + if (headers.isEmpty()) return buffers; ByteBuffer buf = getBuffer(); boolean encoded; @@ -1045,7 +1079,9 @@ public class Http2TestServerConnection { } else throw x; } if (frame instanceof ResponseHeaders rh) { - var buffers = encodeHeaders(rh.headers, rh.insertionPolicy); + // order of headers matters - pseudo headers first followed by rest of the headers + final List encodedHeaders = new ArrayList(encodeHeaders(rh.pseudoHeaders, rh.insertionPolicy)); + encodedHeaders.addAll(encodeHeaders(rh.headers, rh.insertionPolicy)); int maxFrameSize = Math.min(rh.getMaxFrameSize(), getMaxFrameSize() - 64); int next = 0; int cont = 0; @@ -1054,9 +1090,9 @@ public class Http2TestServerConnection { // size we need to split the headers into one // HeadersFrame + N x ContinuationFrames int remaining = maxFrameSize; - var list = new ArrayList(buffers.size()); - for (; next < buffers.size(); next++) { - var b = buffers.get(next); + var list = new ArrayList(encodedHeaders.size()); + for (; next < encodedHeaders.size(); next++) { + var b = encodedHeaders.get(next); var len = b.remaining(); if (!b.hasRemaining()) continue; if (len <= remaining) { @@ -1072,7 +1108,7 @@ public class Http2TestServerConnection { } } int flags = rh.getFlags(); - if (next != buffers.size()) { + if (next != encodedHeaders.size()) { flags = flags & ~HeadersFrame.END_HEADERS; } if (cont > 0) { @@ -1087,7 +1123,7 @@ public class Http2TestServerConnection { } writeFrame(hf); cont++; - } while (next < buffers.size()); + } while (next < encodedHeaders.size()); } else if (frame instanceof OutgoingPushPromise) { handlePush((OutgoingPushPromise)frame); } else @@ -1109,7 +1145,7 @@ public class Http2TestServerConnection { PushPromiseFrame pp = new PushPromiseFrame(op.parentStream, op.getFlags(), promisedStreamid, - encodeHeaders(op.headers), + encodeHeaders(op.reqHeaders), 0); pushStreams.add(promisedStreamid); nextPushStreamId += 2; @@ -1140,7 +1176,7 @@ public class Http2TestServerConnection { oo.goodToGo(); exec.submit(() -> { try { - ResponseHeaders oh = getPushResponse(promisedStreamid); + ResponseHeaders oh = getPushResponse(promisedStreamid, op.rspHeaders); outputQ.put(oh); ii.transferTo(oo); @@ -1162,10 +1198,10 @@ public class Http2TestServerConnection { // returns a minimal response with status 200 // that is the response to the push promise just sent - private ResponseHeaders getPushResponse(int streamid) { - HttpHeadersBuilder hb = createNewHeadersBuilder(); - hb.addHeader(":status", "200"); - ResponseHeaders oh = new ResponseHeaders(hb.build()); + private ResponseHeaders getPushResponse(int streamid, HttpHeaders rspHeaders) { + HttpHeadersBuilder pseudoHeaders = createNewHeadersBuilder(); + pseudoHeaders.addHeader(":status", "200"); + ResponseHeaders oh = new ResponseHeaders(pseudoHeaders.build(), rspHeaders); oh.streamid(streamid); oh.setFlag(HeaderFrame.END_HEADERS); return oh; @@ -1187,7 +1223,9 @@ public class Http2TestServerConnection { try { byte[] buf = new byte[9]; int ret; + // System.err.println("TestServer: reading frame headers"); ret=is.readNBytes(buf, 0, 9); + // System.err.println("TestServer: got frame headers"); if (ret == 0) { return null; } else if (ret != 9) { @@ -1200,6 +1238,7 @@ public class Http2TestServerConnection { len = (len << 8) + n; } byte[] rest = new byte[len]; + // System.err.println("TestServer: reading frame body"); int n = is.readNBytes(rest, 0, len); if (n != len) throw new IOException("Error reading frame"); @@ -1364,82 +1403,60 @@ public class Http2TestServerConnection { // window updates done in main reader thread because they may // be used to unblock BodyOutputStreams waiting for WUPs - HashMap> updaters = new HashMap<>(); + final HashMap> updaters = new HashMap<>(); + final ReentrantLock updatersLock = new ReentrantLock(); void registerStreamWindowUpdater(int streamid, Consumer r) { - synchronized(updaters) { + updatersLock.lock(); + try { updaters.put(streamid, r); + } finally { + updatersLock.unlock(); } } - int sendWindow = 64 * 1024 - 1; // connection level send window + // connection level send window, permits = bytes + final Semaphore sendWindow = new Semaphore(64 * 1024 - 1); /** * BodyOutputStreams call this to get the connection window first. * * @param amount */ - public synchronized void obtainConnectionWindow(int amount) throws InterruptedException { - int demand = amount; - try { - int waited = 0; - while (amount > 0) { - int n = Math.min(amount, sendWindow); - amount -= n; - sendWindow -= n; - if (amount > 0) { - // Do not include this print line on a version that does not have - // JDK-8337395 - System.err.printf("%s: blocked waiting for %s connection window, obtained %s%n", - server.name, amount, demand - amount); - waited++; - wait(); - } - } - if (waited > 0) { - // Do not backport this print line on a version that does not have - // JDK-8337395 - System.err.printf("%s: obtained %s connection window, remaining %s%n", - server.name, demand, sendWindow); - } - assert amount == 0; - } catch (Throwable t) { - sendWindow += (demand - amount); - throw t; - } + public void obtainConnectionWindow(int amount) throws InterruptedException { + sendWindow.acquire(amount); } public void updateConnectionWindow(int amount) { - synchronized (this) { - // Do not backport this print line on a version that does not have - // JDK-8337395 - System.err.printf(server.name + ": update sendWindow (window=%s, amount=%s) is now: %s%n", - sendWindow, amount, sendWindow + amount); - sendWindow += amount; - notifyAll(); - } + System.err.printf("%s: sendWindow (available:%s, released amount=%s) is now: %s%n", + server.name, + sendWindow.availablePermits(), amount, sendWindow.availablePermits() + amount); + sendWindow.release(amount); } // simplified output headers class. really just a type safe container // for the hashmap. public static class ResponseHeaders extends Http2Frame { + final HttpHeaders pseudoHeaders; final HttpHeaders headers; final BiPredicate insertionPolicy; final int maxFrameSize; - public ResponseHeaders(HttpHeaders headers) { - this(headers, (n,v) -> false); + public ResponseHeaders(HttpHeaders pseudoHeaders, HttpHeaders headers) { + this(pseudoHeaders, headers, (n,v) -> false); } - public ResponseHeaders(HttpHeaders headers, BiPredicate insertionPolicy) { - this(headers, insertionPolicy, Integer.MAX_VALUE); + public ResponseHeaders(HttpHeaders pseudoHeaders, HttpHeaders headers, BiPredicate insertionPolicy) { + this(pseudoHeaders, headers, insertionPolicy, Integer.MAX_VALUE); } - public ResponseHeaders(HttpHeaders headers, + public ResponseHeaders(HttpHeaders pseudoHeaders, + HttpHeaders headers, BiPredicate insertionPolicy, int maxFrameSize) { super(0, 0); + this.pseudoHeaders = pseudoHeaders; this.headers = headers; this.insertionPolicy = insertionPolicy; this.maxFrameSize = maxFrameSize; diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/OutgoingPushPromise.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/OutgoingPushPromise.java index 908a901133a..ace975e94fb 100644 --- a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/OutgoingPushPromise.java +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http2/OutgoingPushPromise.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -34,7 +34,8 @@ import jdk.internal.net.http.frame.Http2Frame; // will be converted to a PushPromiseFrame in the writeLoop // a thread is then created to produce the DataFrames from the InputStream public class OutgoingPushPromise extends Http2Frame { - final HttpHeaders headers; + final HttpHeaders reqHeaders; + final HttpHeaders rspHeaders; final URI uri; final InputStream is; final int parentStream; // not the pushed streamid @@ -42,19 +43,22 @@ public class OutgoingPushPromise extends Http2Frame { public OutgoingPushPromise(int parentStream, URI uri, - HttpHeaders headers, + HttpHeaders reqHeaders, + HttpHeaders rspHeaders, InputStream is) { - this(parentStream, uri, headers, is, List.of()); + this(parentStream, uri, reqHeaders, rspHeaders, is, List.of()); } public OutgoingPushPromise(int parentStream, URI uri, - HttpHeaders headers, + HttpHeaders reqHeaders, + HttpHeaders rspHeaders, InputStream is, List continuations) { super(0,0); this.uri = uri; - this.headers = headers; + this.reqHeaders = reqHeaders; + this.rspHeaders = rspHeaders; this.is = is; this.parentStream = parentStream; this.continuations = List.copyOf(continuations); diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerConnection.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerConnection.java new file mode 100644 index 00000000000..4b0ea351a09 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerConnection.java @@ -0,0 +1,801 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.http3; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.http.HttpHeaders; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; + +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.CancelPushFrame; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.GoAwayFrame; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.Http3Frame; +import jdk.internal.net.http.http3.frames.MalformedFrame; +import jdk.internal.net.http.http3.frames.MaxPushIdFrame; +import jdk.internal.net.http.http3.frames.PartialFrame; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.http3.streams.Http3Streams; +import jdk.internal.net.http.http3.streams.Http3Streams.StreamType; +import jdk.internal.net.http.http3.streams.PeerUniStreamDispatcher; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.http3.streams.UniStreamPair; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.qpack.writers.HeaderFrameWriter; +import jdk.internal.net.http.qpack.TableEntry; +import jdk.internal.net.http.quic.ConnectionTerminator; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; + +import static jdk.internal.net.http.http3.Http3Error.H3_STREAM_CREATION_ERROR; +import static jdk.internal.net.http.http3.frames.SettingsFrame.SETTINGS_MAX_FIELD_SECTION_SIZE; +import static jdk.internal.net.http.http3.frames.SettingsFrame.SETTINGS_QPACK_BLOCKED_STREAMS; +import static jdk.internal.net.http.http3.frames.SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY; +import static jdk.internal.net.http.quic.TerminationCause.appLayerClose; + +public class Http3ServerConnection { + private final Http3TestServer server; + private final QuicServerConnection quicConnection; + private final ConnectionTerminator quicConnTerminator; + private final SocketAddress peerAddress; + private final String dbgTag; + private final Logger debug; + private final UniStreamPair controlStreams; + private final UniStreamPair encoderStreams; + private final UniStreamPair decoderStreams; + private final Encoder qpackEncoder; + private final Decoder qpackDecoder; + private final FramesDecoder controlFramesDecoder; + private final AtomicLong nextPushId = new AtomicLong(); + private volatile long maxPushId = 0; + private final ReentrantLock pushIdLock = new ReentrantLock(); + private final Condition pushIdChanged = pushIdLock.newCondition(); + private final ConcurrentHashMap> requests = + new ConcurrentHashMap<>(); + private volatile boolean closeRequested; + // the max stream id of a processed H3 request. -1 implies none were processed. + private final AtomicLong maxProcessedRequestStreamId = new AtomicLong(-1); + // the stream id that was sent in a GOAWAY frame. -1 implies no GOAWAY frame was sent. + private final AtomicLong goAwayRequestStreamId = new AtomicLong(-1); + + private final CompletableFuture afterSettings = new MinimalFuture<>(); + private final CompletableFuture clientSettings = new MinimalFuture<>(); + + private final ConcurrentLinkedQueue lcsWriterQueue = + new ConcurrentLinkedQueue<>(); + + // A class used to dispatch peer initiated unidirectional streams + // according to their type. + private final class Http3StreamDispatcher extends PeerUniStreamDispatcher { + Http3StreamDispatcher(QuicReceiverStream stream) { + super(stream); + } + + @Override + protected Logger debug() { + return debug; + } + + @Override + protected void onStreamAbandoned(QuicReceiverStream stream) { + if (debug.on()) debug.log("Stream " + stream.streamId() + " abandoned!"); + qpackDecoder.cancelStream(stream.streamId()); + } + + @Override + protected void onControlStreamCreated(String description, QuicReceiverStream stream) { + if (debug.on()) { + debug.log("peerControlStream %s dispatched", stream.streamId()); + } + complete(description, stream, controlStreams.futureReceiverStream()); + } + + @Override + protected void onEncoderStreamCreated(String description, QuicReceiverStream stream) { + if (debug.on()) debug.log("peer opened QPack encoder stream"); + complete(description, stream, decoderStreams.futureReceiverStream()); + } + + @Override + protected void onDecoderStreamCreated(String description, QuicReceiverStream stream) { + if (debug.on()) debug.log("peer opened QPack decoder stream"); + complete(description, stream, encoderStreams.futureReceiverStream()); + } + + @Override + protected void onPushStreamCreated(String description, QuicReceiverStream stream, long pushId) { + // From RFC 9114: + // Only servers can push; if a server receives a client-initiated push stream, + // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR. + close(H3_STREAM_CREATION_ERROR.code(), + "Push Stream %s opened by client" + .formatted(stream.streamId())); + } + + // completes the given completable future with the given stream + private void complete(String description, QuicReceiverStream stream, + CompletableFuture cf) { + if (debug.on()) { + debug.log("completing CF for %s with stream %s", description, stream.streamId()); + } + boolean completed = cf.complete(stream); + if (!completed) { + if (!cf.isCompletedExceptionally()) { + debug.log("CF for %s already completed with stream %s!", description, cf.resultNow().streamId()); + close(Http3Error.H3_STREAM_CREATION_ERROR, + "%s already created".formatted(description)); + } else { + debug.log("CF for %s already completed exceptionally!", description); + } + } + } + + static CompletableFuture dispatch(Http3ServerConnection conn, + QuicReceiverStream stream) { + var dispatcher = conn.new Http3StreamDispatcher(stream); + dispatcher.start(); + return dispatcher.dispatchCF(); + } + } + + /** + * Creates a new {@code Http3ServerConnection}. + * Once created, the connection must be {@linkplain #start()} started. + * @param server the HTTP/3 server creating this connection + * @param connection the underlying Quic connection + */ + Http3ServerConnection(Http3TestServer server, + QuicServerConnection connection, + SocketAddress peerAddress) { + this.server = server; + this.quicConnection = connection; + this.quicConnTerminator = connection.connectionTerminator(); + this.peerAddress = peerAddress; + var qtag = connection.dbgTag(); + dbgTag = "H3-Server(" + qtag + ")"; + debug = Utils.getDebugLogger(this::dbgTag); + controlFramesDecoder = new FramesDecoder("H3-Server-control("+qtag+")", + FramesDecoder::isAllowedOnClientControlStream); + controlStreams = new UniStreamPair(StreamType.CONTROL, + quicConnection, this::receiveControlBytes, + this::lcsWriterLoop, + this::onControlStreamError, debug); + qpackEncoder = new Encoder(this::qpackInsertionPolicy, + this::createEncoderStreams, + this::connectionError); + encoderStreams = qpackEncoder.encoderStreams(); + qpackDecoder = new Decoder(this::createDecoderStreams, + this::connectionError); + decoderStreams = qpackDecoder.decoderStreams(); + } + + boolean qpackInsertionPolicy(TableEntry entry) { + List allowedHeaders = Http3TestServer.ENCODER_ALLOWED_HEADERS; + if (allowedHeaders.isEmpty()) { + return false; + } + if (allowedHeaders.contains(Http3TestServer.ALL_ALLOWED)) { + return true; + } + return allowedHeaders.contains(entry.name()); + } + + /** + * Starts this {@code Http3ServerConnection}. + */ + public void start() { + quicConnection.addRemoteStreamListener(this::onNewRemoteStream); + quicConnection.onHandshakeCompletion(this::handshakeDone); + } + + // push bytes to the local control stream queue + void writeControlStream(ByteBuffer buffer) { + lcsWriterQueue.add(buffer); + controlStreams.localWriteScheduler().runOrSchedule(); + } + + QuicConnectionImpl quicConnection() { + return this.quicConnection; + } + + Http3TestServer server() { + return this.server; + } + + String connectionKey() { + // assuming the localConnectionId never changes; + // this will return QuicServerConnectionId(NNN), which should + // be enough to detect whether two exchanges are made on the + // same connection + return quicConnection.logTag(); + } + + // The local control stream write loop + private void lcsWriterLoop() { + var controlStreams = this.controlStreams; + if (controlStreams == null) return; + var writer = controlStreams.localWriter(); + if (writer == null) return; + ByteBuffer buffer; + if (debug.on()) + debug.log("start control writing loop: credit=" + writer.credit()); + while (writer.credit() > 0 && (buffer = lcsWriterQueue.poll()) != null) { + try { + if (debug.on()) + debug.log("schedule %s bytes for writing on control stream", buffer.remaining()); + writer.scheduleForWriting(buffer, buffer == QuicStreamReader.EOF); + } catch (Throwable t) { + var stream = writer.stream(); + Http3Streams.debugErrorCode(debug, stream, "Control stream"); + if (!closeRequested && quicConnection.isOpen()) { + if (!Http3Error.isNoError(stream.sndErrorCode())) { + if (debug.on()) debug.log("Failed to write to control stream", t); + } + close(Http3Error.H3_CLOSED_CRITICAL_STREAM, "Failed to write to control stream"); + return; + } + } + } + } + + private boolean hasHttp3Error(QuicStream stream) { + if (stream instanceof QuicReceiverStream rcvs) { + var code = rcvs.rcvErrorCode(); + if (code > 0 && !Http3Error.isNoError(code)) return true; + } + if (stream instanceof QuicSenderStream snds) { + var code = snds.sndErrorCode(); + if (code > 0 && !Http3Error.isNoError(code)) return true; + } + return false; + } + + + private void onControlStreamError(final QuicStream stream, final UniStreamPair uniStreamPair, + final Throwable throwable) { + // TODO: implement this! + try { + Http3Streams.debugErrorCode(debug, stream, "Control stream"); + if (!closeRequested && quicConnection.isOpen()) { + if (hasHttp3Error(stream)) { + if (debug.on()) { + debug.log("control stream " + stream.mode() + " failed", throwable); + } + } + close(Http3Error.H3_CLOSED_CRITICAL_STREAM, + "Control stream " + stream.mode() + " failed"); + } + } catch (Throwable t) { + if (debug.on() && !closeRequested) { + debug.log("onControlStreamError: handling ", throwable); + debug.log("onControlStreamError: exception while handling error: ", t); + } + } + } + + private void receiveControlBytes(ByteBuffer buffer) { + if (debug.on()) debug.log("received client control: %s bytes", buffer.remaining()); + controlFramesDecoder.submit(buffer); + Http3Frame frame; + while ((frame = controlFramesDecoder.poll()) != null) { + if (debug.on()) debug.log("client control frame: %s", frame); + if (frame instanceof MalformedFrame malformed) { + var cause = malformed.getCause(); + if (cause != null && debug.on()) { + debug.log(malformed.toString(), cause); + } + close(malformed.getErrorCode(), malformed.getMessage()); + controlStreams.stopSchedulers(); + controlFramesDecoder.clear(); + return; + } else if (frame instanceof PartialFrame) { + var payloadBytes = controlFramesDecoder.readPayloadBytes(); + if (debug.on()) { + debug.log("added %s bytes to %s", + Utils.remaining(payloadBytes), + frame); + } + } else if (frame instanceof CancelPushFrame cpf) { + cancelPushReceived(cpf.getPushId()); + } else if (frame instanceof MaxPushIdFrame mpf) { + maxPushIdReceived(mpf.getMaxPushId()); + } else if (frame instanceof SettingsFrame sf) { + ConnectionSettings clientSettings = ConnectionSettings.createFrom(sf); + // Set max and current capacity of the QPack encoder + qpackEncoder.configure(clientSettings); + long clientMaxTableCapacity = clientSettings.qpackMaxTableCapacity(); + long capacity = Math.min(Http3TestServer.ENCODER_CAPACITY_LIMIT, + clientMaxTableCapacity); + // RFC9204 3.2.3. Maximum Dynamic Table Capacity: + // "When the maximum table capacity is zero, the encoder MUST NOT + // insert entries into the dynamic table and MUST NOT send any encoder + // instructions on the encoder stream." + if (clientMaxTableCapacity != 0) { + qpackEncoder.setTableCapacity(capacity); + } + this.clientSettings.complete(clientSettings); + } + if (controlFramesDecoder.eof()) break; + } + if (controlFramesDecoder.eof()) { + close(Http3Error.H3_CLOSED_CRITICAL_STREAM, + "EOF reached while reading client control stream"); + } + } + + private void handshakeDone(Throwable t) { + if (t == null) { + controlStreams.futureSenderStream() + .thenApply(this::sendSettings) + .exceptionally(this::exceptionallyAndClose) + .thenApply(afterSettings::complete); + } else { + if (debug.on()) debug.log("Handshake failed: " + t, t); + // the connection is probably closed already, but just in case... + close(Http3Error.H3_INTERNAL_ERROR, "Handshake failed"); + } + } + + private T exceptionallyAndClose(Throwable t) { + try { + return exceptionally(t); + } finally { + // TODO: should we distinguish close due to + // exception from graceful close? + close(Http3Error.H3_INTERNAL_ERROR, message(t)); + } + } + + String message(Throwable t) { + return t == null ? "No Error" : t.getClass().getSimpleName(); + } + + private T exceptionally(Throwable t) { + try { + if (debug.on()) debug.log(t.getMessage(), t); + throw t; + } catch (RuntimeException | Error r) { + throw r; + } catch (ExecutionException x) { + throw new CompletionException(x.getMessage(), x.getCause()); + } catch (Throwable e) { + throw new CompletionException(e.getMessage(), e); + } + } + + private QuicSenderStream sendSettings(final QuicSenderStream localControlStream) { + final ConnectionSettings settings = server.getConfiguredConnectionSettings(); + final SettingsFrame settingsFrame = new SettingsFrame(); + + settingsFrame.setParameter(SETTINGS_MAX_FIELD_SECTION_SIZE, settings.maxFieldSectionSize()); + settingsFrame.setParameter(SETTINGS_QPACK_MAX_TABLE_CAPACITY, settings.qpackMaxTableCapacity()); + settingsFrame.setParameter(SETTINGS_QPACK_BLOCKED_STREAMS, settings.qpackBlockedStreams()); + qpackDecoder.configure(settings); + + if (debug.on()) { + debug.log("sending server settings %s for connection %s", settingsFrame, this); + } + final long size = settingsFrame.size(); + assert size >= 0 && size < Integer.MAX_VALUE; + var buf = ByteBuffer.allocate((int)settingsFrame.size()); + settingsFrame.writeFrame(buf); + buf.flip(); + writeControlStream(buf); + return localControlStream; + } + + + private boolean onNewRemoteStream(QuicReceiverStream stream) { + boolean closeRequested = this.closeRequested; + if (closeRequested) return false; + + if (stream instanceof QuicBidiStream bidiStream) { + onNewHttpRequest(bidiStream); + } else { + Http3StreamDispatcher.dispatch(this, stream).whenComplete((r, t) -> { + if (t != null) dispatchFailed(t); + }); + } + if (debug.on()) { + debug.log("New stream %s accepted", stream.streamId()); + } + return true; + } + + private void onNewHttpRequest(QuicBidiStream stream) { + if (!this.server.shouldProcessNewHTTPRequest(this)) { + if (debug.on()) { + debug.log("Rejecting HTTP request on stream %s of connection %s", + stream.streamId(), this); + } + // consider the request as unprocessed and send a GOAWAY on the connection + try { + sendGoAway(); + } catch (IOException ioe) { + System.err.println("Failed to send GOAWAY on connection " + this + + " due to: " + ioe); + ioe.printStackTrace(); + } + return; + } + var streamId = stream.streamId(); + // keep track of the largest request id that we have processed + long currentLargest = maxProcessedRequestStreamId.get(); + while (streamId > currentLargest) { + if (maxProcessedRequestStreamId.compareAndSet(currentLargest, streamId)) { + break; + } + currentLargest = maxProcessedRequestStreamId.get(); + } + if (debug.on()) { + debug.log("new incoming HTTP request on stream %s", streamId); + } + if (requests.containsKey(stream.streamId())) { + if (debug.on()) { + debug.log("Stream %s already created!", streamId); + } + quicConnTerminator.terminate(appLayerClose(H3_STREAM_CREATION_ERROR.code()) + .loggedAs("stream already created")); + return; + } + // creation of the Http3ServerExchangeImpl involves connecting its reader, which + // takes a fair amount of (JIT?) time. Since this method is called from + // within the decrypt loop, it prevents decrypting the following ONERTT packets, + // which can unnecessarily delay the processing of ACKs and cause excessive + // retransmission + MinimalFuture exchCf = new MinimalFuture<>(); + requests.put(stream.streamId(), exchCf); + if (debug.on()) { + debug.log("HTTP/3 exchange future for stream %s registered", streamId); + } + server.getQuicServer().executor().execute(() -> createExchange(exchCf, stream)); + if (debug.on()) { + debug.log("HTTP/3 exchange creation for stream %s triggered", streamId); + } + } + + private void createExchange(CompletableFuture exchCf, + QuicBidiStream stream) { + var streamId = stream.streamId(); + if (debug.on()) { + debug.log("Completing HTTP/3 exchange future for stream %s", streamId); + } + exchCf.complete(new Http3ServerStreamImpl(this, stream)); + if (debug.on()) { + debug.log("HTTP/3 exchange future for stream %s Completed", streamId); + } + } + + public final String dbgTag() { return dbgTag; } + + private void dispatchFailed(Throwable throwable) { + // TODO: anything to do? + if (debug.on()) debug.log("dispatch failed: " + throwable); + } + + QueuingStreamPair createEncoderStreams(Consumer encoderReceiver) { + return new QueuingStreamPair(StreamType.QPACK_ENCODER, quicConnection, + encoderReceiver, this::onEncoderStreamsFailed, debug); + } + + private void onEncoderStreamsFailed(final QuicStream stream, final UniStreamPair uniStreamPair, + final Throwable throwable) { + // TODO: implement this! + // close connection here. + if (!closeRequested) { + String message = stream != null ? stream.mode() + " failed" : "is null"; + if (quicConnection().isOpen()) { + if (debug.on()) { + debug.log("QPack encoder stream " + message, throwable); + } + } else { + if (debug.on()) { + debug.log("QPack encoder stream " + message + ": " + throwable); + } + } + } + } + + QueuingStreamPair createDecoderStreams(Consumer encoderReceiver) { + return new QueuingStreamPair(StreamType.QPACK_DECODER, quicConnection, + encoderReceiver, this::onDecoderStreamsFailed, debug); + } + + private void onDecoderStreamsFailed(final QuicStream stream, final UniStreamPair uniStreamPair, + final Throwable throwable) { + // TODO: implement this! + // close connection here. + if (!closeRequested) { + String message = stream != null ? stream.mode() + " failed" : "is null"; + if (quicConnection().isOpen()) { + if (debug.on()) { + debug.log("QPack decoder stream " + message, throwable); + } + } else { + debug.log("QPack decoder stream " + message + ": " + throwable); + } + } + } + + // public, to allow invocations from within tests + public void sendGoAway() throws IOException { + final QuicStreamWriter writer = controlStreams.localWriter(); + if (writer == null || !quicConnection.isOpen()) { + return; + } + // RFC-9114, section 5.2: + // Requests ... with the indicated identifier or greater + // are rejected ... by the sender of the GOAWAY. + final long maxProcessedStreamId = maxProcessedRequestStreamId.get(); + // adding 4 gets us the next stream id for the stream type + final long streamIdToReject = maxProcessedStreamId == -1 ? 0 : maxProcessedStreamId + 4; + // An endpoint MAY send multiple GOAWAY frames indicating different + // identifiers, but the identifier in each frame MUST NOT be greater + // than the identifier in any previous frame, since clients might + // already have retried unprocessed requests on another HTTP connection. + long currentGoAwayReqStrmId = goAwayRequestStreamId.get(); + while (currentGoAwayReqStrmId != -1 && streamIdToReject < currentGoAwayReqStrmId) { + if (goAwayRequestStreamId.compareAndSet(currentGoAwayReqStrmId, streamIdToReject)) { + break; + } + currentGoAwayReqStrmId = goAwayRequestStreamId.get(); + } + final GoAwayFrame frame = new GoAwayFrame(streamIdToReject); + final long size = frame.size(); + assert size >= 0 && size < Integer.MAX_VALUE; + final var buf = ByteBuffer.allocate((int) size); + frame.writeFrame(buf); + buf.flip(); + if (debug.on()) { + debug.log("Sending GOAWAY frame %s from server connection %s", frame, this); + } + writer.scheduleForWriting(buf, false); + } + + public void close(Http3Error error, String reason) { + close(error.code(), reason); + } + + void connectionError(Throwable throwable, Http3Error error) { + close(error, throwable.getMessage()); + } + + private boolean markCloseRequested() { + var closeRequested = this.closeRequested; + if (!closeRequested) { + synchronized (this) { + closeRequested = this.closeRequested; + if (!closeRequested) { + return this.closeRequested = true; + } + } + } + assert closeRequested; + if (debug.on()) debug.log("close already requested"); + return false; + } + + public void close(long error, String reason) { + if (markCloseRequested()) { + try { + sendGoAway(); + } catch (IOException e) { + // it's OK if we couldn't send a GOAWAY + if (debug.on()) { + debug.log("ignoring failure to send GOAWAY from server connection " + + this + " due to " + e); + } + } + if (quicConnection.isOpen()) { + if (debug.on()) debug.log("closing quic connection: " + reason); + quicConnTerminator.terminate(appLayerClose(error).loggedAs(reason)); + } else { + if (debug.on()) debug.log("quic connection already closed"); + } + } + } + + HeaderFrameReader newHeaderFrameReader(DecodingCallback decodingCallback) { + return qpackDecoder.newHeaderFrameReader(decodingCallback); + } + + void exchangeClosed(Http3ServerStreamImpl http3ServerExchange) { + requests.remove(http3ServerExchange.streamId()); + } + + sealed interface PushPromise permits CancelledPush, CompletedPush, PendingPush { + long pushId(); + } + + record CancelledPush(long pushId) implements PushPromise {} + record CompletedPush(long pushId, HttpHeaders headers) implements PushPromise {} + record PendingPush(long pushId, + CompletableFuture stream, + HttpHeaders headers, + Http3ServerExchange exchange) implements PushPromise { + } + + private final Map promiseMap = new ConcurrentHashMap<>(); + + PushPromise addPendingPush(long pushId, + CompletableFuture stream, + HttpHeaders headers, + Http3ServerExchange exchange) { + var push = new PendingPush(pushId, stream, headers, exchange); + expungePromiseMap(); + var previous = promiseMap.putIfAbsent(pushId, push); + if (previous == null || !(previous instanceof CancelledPush)) { + // allow to open multiple streams for the same pushId + // in order to test client behavior. We will return + // push even if the map contains a pending or completed + // push; + return push; + } + return previous; + } + + void addPushPromise(final long promiseId, final PushPromise promise) { + this.promiseMap.put(promiseId, promise); + } + + PushPromise getPushPromise(final long promiseId) { + return this.promiseMap.get(promiseId); + } + + void cancelPush(long pushId) { + expungePromiseMap(); + var push = promiseMap.putIfAbsent(pushId, new CancelledPush(pushId)); + if (push == null || push instanceof CancelledPush) return; + if (push instanceof CompletedPush) return; + if (push instanceof PendingPush pp) { + promiseMap.put(pushId, new CancelledPush(pushId)); + var ps = pp.stream(); + if (ps == null) { + try { + sendCancelPush(pushId); + } catch (IOException io) { + if (debug.on()) { + debug.log("Failed to send CANCEL_PUSH pushId=%s: %s", pushId, io); + } + } + } else { + ps.thenAccept(s -> { + try { + s.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + } catch (IOException io) { + if (debug.on()) { + debug.log("Failed to reset push stream pushId=%s, stream=%s: %s", + pushId, s.streamId(), io); + } + } + }); + } + } + } + + void sendCancelPush(long pushId) throws IOException { + CancelPushFrame cancelPushFrame = new CancelPushFrame(pushId); + ByteBuffer buf = ByteBuffer.allocate((int)cancelPushFrame.size()); + cancelPushFrame.writeFrame(buf); + buf.flip(); + // need to wait until after settings are sent. + afterSettings.thenAccept((s) -> writeControlStream(buf)); + } + + + void cancelPushReceived(long pushId) { + cancelPush(pushId); + } + + void maxPushIdReceived(long pushId) { + pushIdLock.lock(); + try { + if (pushId > maxPushId) { + if (debug.on()) debug.log("max pushId: " + pushId); + maxPushId = pushId; + pushIdChanged.signalAll(); + } + } finally { + pushIdLock.unlock(); + } + } + + final AtomicLong minPush = new AtomicLong(); + static final int MAX_PUSH_HISTORY = 100; + void expungePromiseMap() { + assert MAX_PUSH_HISTORY > 0; + while (promiseMap.size() >= MAX_PUSH_HISTORY) { + long lowest = minPush.getAndIncrement(); + var pp = promiseMap.remove(lowest); + if (pp instanceof PendingPush ppp) { + cancelPush(ppp.pushId); + } + } + } + + List encodeHeaders(int bufferSize, long streamId, HttpHeaders... headers) { + HeaderFrameWriter writer = qpackEncoder.newHeaderFrameWriter(); + return qpackEncoder.encodeHeaders(writer, streamId, bufferSize, headers); + } + + void decodeHeaders(final HeadersFrame partialHeadersFrame, final ByteBuffer buffer, + final HeaderFrameReader headersReader) throws IOException { + ByteBuffer received = partialHeadersFrame.nextPayloadBytes(buffer); + boolean done = partialHeadersFrame.remaining() == 0; + this.qpackDecoder.decodeHeader(received, done, headersReader); + } + + long nextPushId() { + return this.nextPushId.getAndIncrement(); + } + + long waitForMaxPushId(long pushId) throws InterruptedException { + long maxPushId = this.maxPushId; + if (maxPushId > pushId) return maxPushId; + do { + this.pushIdLock.lock(); + try { + maxPushId = this.maxPushId; + if (maxPushId > pushId) return maxPushId; + this.pushIdChanged.await(); + } finally { + this.pushIdLock.unlock(); + } + } while (true); + } + + public Encoder qpackEncoder() { + return qpackEncoder; + } + + public CompletableFuture clientHttp3Settings() { + return clientSettings; + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerExchange.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerExchange.java new file mode 100644 index 00000000000..9a39ba4358d --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerExchange.java @@ -0,0 +1,801 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.http3; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.function.LongSupplier; + +import javax.net.ssl.SSLSession; + +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.DataFrame; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.PushPromiseFrame; +import jdk.internal.net.http.http3.streams.Http3Streams; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; + +public final class Http3ServerExchange implements Http2TestExchange { + + private final Http3ServerStreamImpl serverStream; + private final Http3ServerConnection serverConn; + private final Logger debug; + final HttpHeaders requestHeaders; + final String method; + final String scheme; + final String authority; + final String path; + final URI uri; + final HttpHeadersBuilder rspheadersBuilder; + final SSLSession sslSession; + volatile long responseLength = 0; // 0 is unknown, -1 is 0 + volatile int responseCode; + final Http3ServerStreamImpl.RequestBodyInputStream is; + final ResponseBodyOutputStream os; + private boolean unknownReservedFrameAlreadySent; + + Http3ServerExchange(Http3ServerStreamImpl serverStream, HttpHeaders requestHeaders, + Http3ServerStreamImpl.RequestBodyInputStream is, SSLSession sslSession) { + this.serverStream = serverStream; + this.serverConn = serverStream.serverConnection(); + this.debug = Utils.getDebugLogger(this.serverConn::dbgTag); + this.requestHeaders = requestHeaders; + this.sslSession = sslSession; + this.is = is; + this.os = new ResponseBodyOutputStream(connectionTag(), debug, serverStream.writer, serverStream.writeLock, + serverStream.writeEnabled, this::getResponseLength); + method = requestHeaders.firstValue(":method").orElse(""); + //System.out.println("method = " + method); + path = requestHeaders.firstValue(":path").orElse(""); + //System.out.println("path = " + path); + scheme = requestHeaders.firstValue(":scheme").orElse(""); + //System.out.println("scheme = " + scheme); + authority = requestHeaders.firstValue(":authority").orElse(""); + if (!path.isEmpty() && !path.startsWith("/")) { + throw new IllegalArgumentException("Path is not absolute: " + path); + } + uri = URI.create(scheme + "://" + authority + path); + rspheadersBuilder = new HttpHeadersBuilder(); + } + + String connectionTag() { + return serverConn.quicConnection().logTag(); + } + + long getResponseLength() { + return responseLength; + } + + @Override + public String toString() { + return "H3Server Http3ServerExchange(%s)".formatted(serverStream.streamId()); + } + + @Override + public HttpHeaders getRequestHeaders() { + return requestHeaders; + } + + @Override + public HttpHeadersBuilder getResponseHeaders() { + return rspheadersBuilder; + } + + @Override + public URI getRequestURI() { + return uri; + } + + @Override + public String getRequestMethod() { + return method; + } + + @Override + public SSLSession getSSLSession() { + return sslSession; + } + + @Override + public void close() { + try { + is.close(); + os.close(); + serverStream.close(); + } catch (IOException e) { + if (debug.on()) { + debug.log(this + ".close exception: " + e, e); + } + } + } + + @Override + public void close(IOException io) throws IOException { + if (debug.on()) { + debug.log(this + " closed with exception: " + io); + } + if (serverStream.writer.sendingState().isSending()) { + if (debug.on()) { + debug.log(this + " resetting writer with H3_INTERNAL_ERROR"); + } + serverStream.writer.reset(Http3Error.H3_INTERNAL_ERROR.code()); + } + is.close(io); + os.closeInternal(); + close(); + } + + public Http3ServerExchange streamResetByPeer(IOException io) { + try { + if (debug.on()) + debug.log("H3 Server closing exchange: " + io); + close(io); + } catch (IOException e) { + if (debug.on()) + debug.log("Failed to close stream %s", serverStream.streamId()); + } + return this; + } + + @Override + public InputStream getRequestBody() { + return is; + } + + @Override + public OutputStream getResponseBody() { + return os; + } + + @Override + public void sendResponseHeaders(int rCode, long responseLength) throws IOException { + // occasionally send an unknown/reserved HTTP3 frame to exercise the case + // where the client is expected to ignore such frames + try { + optionallySendUnknownOrReservedFrame(); + this.responseLength = responseLength; + sendResponseHeaders(serverStream.streamId(), serverStream.writer, isHeadRequest(), + rCode, responseLength, rspheadersBuilder, os); + } catch (Exception ex) { + throw new IOException("failed to send headers: " + ex, ex); + } + } + + // WARNING: this method is also called for PushStreams, which has + // a different writer, streamId, request etc... + // The only fields that can be safely used in this method is debug and + // http3ServerConnection + private void sendResponseHeaders(long streamId, + QuicStreamWriter writer, + boolean isHeadRequest, + int rCode, + long responseLength, + HttpHeadersBuilder rspheadersBuilder, + ResponseBodyOutputStream os) + throws IOException { + String tag = "streamId=" + streamId + " "; + // in case of HEAD request the caller is supposed to set Content-Length + // directly - and the responseLength passed here is supposed to be -1 + if (responseLength != 0 && rCode != 204 && !isHeadRequest) { + long clen = responseLength > 0 ? responseLength : 0; + rspheadersBuilder.setHeader("Content-length", Long.toString(clen)); + } + final HttpHeadersBuilder pseudoHeadersBuilder = new HttpHeadersBuilder(); + pseudoHeadersBuilder.setHeader(":status", Integer.toString(rCode)); + final HttpHeaders pseudoHeaders = pseudoHeadersBuilder.build(); + final HttpHeaders headers = rspheadersBuilder.build(); + // order of headers matters - pseudo headers first followed by rest of the headers + var payload = serverConn.encodeHeaders(1024, streamId, pseudoHeaders, headers); + if (debug.on()) + debug.log(tag + "headers payload: " + Utils.remaining(payload)); + HeadersFrame frame = new HeadersFrame(Utils.remaining(payload)); + ByteBuffer buffer = ByteBuffer.allocate(frame.headersSize()); + frame.writeHeaders(buffer); + buffer.flip(); + if (debug.on()) { + debug.log(tag + "Writing HeaderFrame headers: " + Utils.asHexString(buffer)); + } + boolean noBody = rCode >= 200 && (responseLength < 0 || rCode == 204); + boolean last = frame.length() == 0 && noBody; + if (last) { + if (debug.on()) { + debug.log(tag + "last payload sent: empty headers, no body"); + } + writer.scheduleForWriting(buffer, true); + } else { + writer.queueForWriting(buffer); + } + int size = payload.size(); + for (int i = 0; i < size; i++) { + last = i == size - 1; + var buf = payload.get(i); + if (debug.on()) { + debug.log(tag + "Writing HeaderFrame payload: " + Utils.asHexString(buf)); + } + if (last) { + if (debug.on()) { + debug.log(tag + "last headers bytes sent, %s", + noBody ? "no body" : "body should follow"); + } + writer.scheduleForWriting(buf, noBody); + } else { + writer.queueForWriting(buf); + } + } + if (noBody) { + if (debug.on()) { + debug.log(tag + "no body: closing os"); + } + os.closeInternal(); + } + os.goodToGo(); + if (debug.on()) { + debug.log(this + " Sent response headers " + tag + rCode); + } + } + + private void optionallySendUnknownOrReservedFrame() { + if (this.unknownReservedFrameAlreadySent) { + // don't send it more than once + return; + } + UnknownOrReservedFrame.tryGenerateFrame().ifPresent((f) -> { + if (debug.on()) { + debug.log("queueing to send an unknown/reserved HTTP3 frame: " + f); + } + try { + serverStream.writer.queueForWriting(f.toByteBuffer()); + } catch (IOException e) { + // ignore + if (debug.on()) { + debug.log("failed to queue unknown/reserved HTTP3 frame: " + f, e); + } + } + this.unknownReservedFrameAlreadySent = true; + }); + } + + @Override + public InetSocketAddress getRemoteAddress() { + return serverConn.quicConnection().peerAddress(); + } + + @Override + public int getResponseCode() { + return responseCode; + } + + @Override + public InetSocketAddress getLocalAddress() { + return (InetSocketAddress) serverConn.quicConnection().localAddress(); + } + + @Override + public String getConnectionKey() { + return serverConn.connectionKey(); + } + + @Override + public String getProtocol() { + return "HTTP/3"; + } + + @Override + public HttpClient.Version getServerVersion() { + return HttpClient.Version.HTTP_3; + } + + @Override + public HttpClient.Version getExchangeVersion() { + return HttpClient.Version.HTTP_3; + } + + @Override + public boolean serverPushAllowed() { + return true; + } + + @Override + public void serverPush(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) + throws IOException { + try { + serverPushWithId(uri, reqHeaders, rspHeaders, content); + } catch (IOException io) { + if (debug.on()) + debug.log("Failed to push " + uri + ": " + io); + throw io; + } + } + + @Override + public long serverPushWithId(URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) + throws IOException { + HttpHeaders combinePromiseHeaders = combinePromiseHeaders(uri, reqHeaders); + long pushId = serverConn.nextPushId(); + if (debug.on()) { + debug.log("Server sending serverPushWithId(" + pushId + "): " + uri); + } + // send PUSH_PROMISE frame + sendPushPromiseFrame(pushId, uri, combinePromiseHeaders); + if (debug.on()) + debug.log("Server sent PUSH_PROMISE(" + pushId + ")"); + // now open push stream and send response headers + body + Http3ServerConnection.PushPromise pp = sendPushResponse(pushId, combinePromiseHeaders, rspHeaders, content); + assert pushId == pp.pushId(); + return pp.pushId(); + } + + @Override + public long sendPushId(long pushId, URI uri, HttpHeaders headers) throws IOException { + HttpHeaders combinePromiseHeaders = combinePromiseHeaders(uri, headers); + return sendPushPromiseFrame(pushId, uri, combinePromiseHeaders); + } + + @Override + public void sendPushResponse(long pushId, URI uri, HttpHeaders reqHeaders, HttpHeaders rspHeaders, InputStream content) + throws IOException { + HttpHeaders combinePromiseHeaders = combinePromiseHeaders(uri, reqHeaders); + Http3ServerConnection.PushPromise pp = sendPushResponse(pushId, combinePromiseHeaders, rspHeaders, content); + assert pushId == pp.pushId(); + } + + @Override + public void resetStream(long code) throws IOException { + os.resetStream(code); + } + + @Override + public void cancelPushId(long pushId) throws IOException { + serverConn.sendCancelPush(pushId); + } + + @Override + public long waitForMaxPushId(long pushId) throws InterruptedException { + return serverConn.waitForMaxPushId(pushId); + } + + @Override + public Encoder qpackEncoder() { + return serverConn.qpackEncoder(); + } + + @Override + public CompletableFuture clientHttp3Settings() { + return serverConn.clientHttp3Settings(); + } + + private long sendPushPromiseFrame(long pushId, URI uri, HttpHeaders headers) + throws IOException { + if (pushId == -1) pushId = serverConn.nextPushId(); + List payload = serverConn.encodeHeaders(1024, serverStream.streamId(), headers); + PushPromiseFrame frame = new PushPromiseFrame(pushId, Utils.remaining(payload)); + ByteBuffer buffer = ByteBuffer.allocate(frame.headersSize()); + frame.writeHeaders(buffer); + buffer.flip(); + boolean last = frame.length() == 0; + if (last) { + if (debug.on()) { + debug.log("last payload sent: empty headers, no body"); + } + serverStream.writer.scheduleForWriting(buffer, false); + } else { + serverStream.writer.queueForWriting(buffer); + } + int size = payload.size(); + for (int i = 0; i < size; i++) { + last = i == size - 1; + var buf = payload.get(i); + if (last) { + serverStream.writer.scheduleForWriting(buf, false); + } else { + serverStream.writer.queueForWriting(buf); + } + } + return pushId; + } + + private static HttpHeaders combinePromiseHeaders(URI uri, HttpHeaders headers) { + HttpHeadersBuilder headersBuilder = new HttpHeadersBuilder(); + headersBuilder.setHeader(":method", "GET"); + headersBuilder.setHeader(":scheme", uri.getScheme()); + headersBuilder.setHeader(":authority", uri.getAuthority()); + headersBuilder.setHeader(":path", uri.getPath()); + for (Map.Entry> entry : headers.map().entrySet()) { + for (String value : entry.getValue()) + headersBuilder.addHeader(entry.getKey(), value); + } + return headersBuilder.build(); + } + + @Override + public void requestStopSending(long errorCode) { + serverStream.reader.stream().requestStopSending(errorCode); + } + + private QuicSenderStream cancel(QuicSenderStream s) { + try { + switch (s.sendingState()) { + case READY, SEND, DATA_SENT -> s.reset(Http3Error.H3_REQUEST_CANCELLED.code()); + } + } catch (IOException io) { + throw new UncheckedIOException(io); + } + return s; + } + + private Http3ServerConnection.PushPromise sendPushResponse(long pushId, + HttpHeaders reqHeaders, + HttpHeaders rspHeaders, + InputStream body) { + var stream = serverConn.quicConnection() + .openNewLocalUniStream(Duration.ofSeconds(10)); + final Http3ServerConnection.PushPromise promise = + serverConn.addPendingPush(pushId, stream, reqHeaders, this); + if (!(promise instanceof Http3ServerConnection.PendingPush)) { + stream.thenApply(this::cancel); + return promise; + } + stream.thenApplyAsync(s -> { + if (debug.on()) { + debug.log("Server open(streamId=" + s.streamId() + ", pushId=" + pushId + ")"); + } + String tag = "streamId=" + s.streamId() + ": "; + var push = serverConn.getPushPromise(pushId); + if (push instanceof Http3ServerConnection.CancelledPush) { + this.cancel(s); + return push; + } + // no write loop: just buffer everything + final ReentrantLock pushLock = new ReentrantLock(); + final Condition writePushEnabled = pushLock.newCondition(); + final Runnable writeLoop = () -> { + pushLock.lock(); + try { + writePushEnabled.signalAll(); + } finally { + pushLock.unlock(); + } + }; + var pushw = s.connectWriter(SequentialScheduler.lockingScheduler(writeLoop)); + int tlen = VariableLengthEncoder.getEncodedSize(Http3Streams.PUSH_STREAM_CODE); + int plen = VariableLengthEncoder.getEncodedSize(pushId); + ByteBuffer buf = ByteBuffer.allocate(tlen + plen); + VariableLengthEncoder.encode(buf, Http3Streams.PUSH_STREAM_CODE); + VariableLengthEncoder.encode(buf, pushId); + buf.flip(); + try { + pushw.queueForWriting(buf); + if (debug.on()) { + debug.log(tag + "Server queued push stream type pushId=" + pushId + + " 0x" + Utils.asHexString(buf)); + } + ResponseBodyOutputStream os = new ResponseBodyOutputStream(connectionTag(), + debug, pushw, pushLock, writePushEnabled, () -> 0); + sendResponseHeaders(s.streamId(), pushw, false, 200, 0, + new HttpHeadersBuilder(rspHeaders), os); + if (debug.on()) { + debug.log(tag + "Server push response headers sent pushId=" + pushId); + } + switch (s.sendingState()) { + case SEND, READY -> { + if (!s.stopSendingReceived()) { + body.transferTo(os); + serverConn.addPushPromise(pushId, new Http3ServerConnection.CompletedPush(pushId, reqHeaders)); + os.close(); + if (debug.on()) { + debug.log(tag + "Server push response body sent pushId=" + pushId); + } + } else { + if (debug.on()) { + debug.log(tag + "Server push response body cancelled pushId=" + pushId); + } + serverConn.addPushPromise(pushId, new Http3ServerConnection.CancelledPush(pushId)); + cancel(s); + } + } + case RESET_SENT, RESET_RECVD -> { + if (debug.on()) { + debug.log(tag + "Server push response body cancelled pushId=" + pushId); + } + // benign race if already cancelled, stateless marker + serverConn.addPushPromise(pushId, new Http3ServerConnection.CancelledPush(pushId)); + } + default -> { + if (debug.on()) { + debug.log(tag + "Server push response body cancelled pushId=" + pushId); + } + serverConn.addPushPromise(pushId, new Http3ServerConnection.CancelledPush(pushId)); + cancel(s); + } + } + body.close(); + } catch (IOException io) { + if (debug.on()) { + debug.log(tag + "Server failed to send pushId=" + pushId + ": " + io); + } + throw new UncheckedIOException(io); + } + return serverConn.getPushPromise(pushId); + }, serverConn.server().getQuicServer().executor()).exceptionally(t -> { + if (debug.on()) { + debug.log("Server failed to send PushPromise(pushId=" + pushId + "): " + t); + } + serverConn.addPushPromise(pushId, new Http3ServerConnection.CancelledPush(pushId)); + try { + body.close(); + } catch (IOException io) { + if (debug.on()) { + debug.log("Failed to close PushPromise stream(pushId=" + + pushId + "): " + io); + } + } + return serverConn.getPushPromise(pushId); + }); + return promise; + } + + @Override + public CompletableFuture sendPing() { + final QuicConnectionImpl quicConn = serverConn.quicConnection(); + var executor = quicConn.quicInstance().executor(); + return quicConn.requestSendPing() + // ensure that dependent actions will not be executed in the + // thread that completes the CF + .thenApplyAsync(Function.identity(), executor) + .exceptionallyAsync(this::rethrow, executor); + } + + private T rethrow(Throwable t) { + if (t instanceof RuntimeException r) throw r; + if (t instanceof Error e) throw e; + if (t instanceof ExecutionException x) return rethrow(x.getCause()); + throw new CompletionException(t); + } + + private boolean isHeadRequest() { + return "HEAD".equals(method); + } + + static class ResponseBodyOutputStream extends OutputStream { + + volatile boolean closed; + volatile boolean goodToGo; + boolean headersWritten; + long sent; + private final QuicStreamWriter osw; + private final ReentrantLock writeLock; + private final Condition writeEnabled; + private final LongSupplier responseLength; + private final Logger debug; + private final String connectionTag; + + ResponseBodyOutputStream(String connectionTag, + Logger debug, + QuicStreamWriter writer, + ReentrantLock writeLock, + Condition writeEnabled, + LongSupplier responseLength) { + this.debug = debug; + this.writeLock = writeLock; + this.writeEnabled = writeEnabled; + this.responseLength = responseLength; + this.osw = writer; + this.connectionTag = connectionTag; + } + + private void writeHeadersIfNeeded(ByteBuffer buffer) + throws IOException { + assert writeLock.isHeldByCurrentThread(); + long responseLength = this.responseLength.getAsLong(); + boolean streaming = responseLength == 0; + if (streaming) { + if (buffer.hasRemaining()) { + int len = buffer.remaining(); + if (debug.on()) { + debug.log("Streaming BodyResponse: streamId=%s writing DataFrame(%s)", + osw.stream().streamId(), len); + } + var data = new DataFrame(len); + var headers = ByteBuffer.allocate(data.headersSize()); + data.writeHeaders(headers); + headers.flip(); + osw.queueForWriting(headers); + } + } else if (!headersWritten) { + long len = responseLength > 0 ? responseLength : 0; + if (debug.on()) { + debug.log("BodyResponse: streamId=%s writing DataFrame(%s)", + osw.stream().streamId(), len); + } + var data = new DataFrame(len); + var headers = ByteBuffer.allocate(data.headersSize()); + data.writeHeaders(headers); + headers.flip(); + osw.queueForWriting(headers); + headersWritten = true; + } + } + + @Override + public void write(int b) throws IOException { + var buffer = ByteBuffer.allocate(1); + buffer.put((byte) b); + buffer.flip(); + submit(buffer); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + // the data is not written immediately, and therefore + // it needs to be copied. + // maybe we should find a way to wait until the data + // has been written, but that sounds complex. + ByteBuffer buffer = ByteBuffer.wrap(b.clone(), off, len); + submit(buffer); + } + + String logTag() { + return connectionTag + " streamId=" + osw.stream().streamId(); + } + + /** + * Schedule the ByteBuffer for writing. The buffer must never + * be reused. + * + * @param buffer response data + * @throws IOException if the channel is closed + */ + public void submit(ByteBuffer buffer) throws IOException { + writeLock.lock(); + try { + if (closed && buffer.hasRemaining()) { + throw new ClosedChannelException(); + } + if (osw.credit() <= 0) { + if (Log.requests()) { + Log.logResponse(() -> logTag() + ": HTTP/3 Server waiting for credits"); + } + writeEnabled.awaitUninterruptibly(); + if (Log.requests()) { + Log.logResponse(() -> logTag() + ": HTTP/3 Server unblocked - credits: " + + osw.credit() + ", closed: " + closed); + } + } + if (closed) { + if (buffer.hasRemaining()) { + throw new ClosedChannelException(); + } else return; + } + int len = buffer.remaining(); + sent = sent + len; + writeHeadersIfNeeded(buffer); + long responseLength = this.responseLength.getAsLong(); + boolean streaming = responseLength == 0; + boolean last = !streaming && (sent == responseLength + || (sent == 0 && responseLength == -1)); + osw.scheduleForWriting(buffer, last); + if (last) closeInternal(); + if (!streaming && sent != 0 && sent > responseLength) { + throw new IOException("sent more bytes than expected"); + } + } finally { + writeLock.unlock(); + } + } + + public void closeInternal() { + if (debug.on()) { + debug.log("BodyResponse: streamId=%s closeInternal", osw.stream().streamId()); + } + if (closed) return; + writeLock.lock(); + try { + closed = true; + } finally { + writeLock.unlock(); + } + } + + public void close() throws IOException { + if (debug.on()) { + debug.log("BodyResponse: streamId=%s close", osw.stream().streamId()); + } + if (closed) return; + writeLock.lock(); + try { + if (closed) return; + closed = true; + switch (osw.sendingState()) { + case READY, SEND -> { + if (debug.on()) { + debug.log("BodyResponse: streamId=%s sending EOF", + osw.stream().streamId()); + } + osw.scheduleForWriting(QuicStreamReader.EOF, true); + writeEnabled.signalAll(); + } + default -> { + } + } + } catch (IOException io) { + throw new IOException(io); + } finally { + writeLock.unlock(); + } + } + + public void goodToGo() { + this.goodToGo = true; + } + + public void resetStream(long code) throws IOException { + if (closed) return; + writeLock.lock(); + try { + if (closed) return; + closed = true; + switch (osw.sendingState()) { + case READY, SEND, DATA_SENT: + osw.reset(code); + default: + break; + } + } catch (IOException io) { + throw new IOException(io); + } finally { + writeLock.unlock(); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerStreamImpl.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerStreamImpl.java new file mode 100644 index 00000000000..c5a8709346c --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3ServerStreamImpl.java @@ -0,0 +1,489 @@ +/* + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.http3; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.io.UncheckedIOException; +import java.net.http.HttpHeaders; +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.common.ValidatingHeadersConsumer; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.CancelPushFrame; +import jdk.internal.net.http.http3.frames.DataFrame; +import jdk.internal.net.http.http3.frames.FramesDecoder; +import jdk.internal.net.http.http3.frames.HeadersFrame; +import jdk.internal.net.http.http3.frames.Http3Frame; +import jdk.internal.net.http.http3.frames.MalformedFrame; +import jdk.internal.net.http.http3.frames.PartialFrame; +import jdk.internal.net.http.http3.frames.UnknownFrame; +import jdk.internal.net.http.http3.streams.Http3Streams; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; + +final class Http3ServerStreamImpl { + private final Http3ServerConnection serverConn; + private final Logger debug; + final QuicBidiStream stream; + final SequentialScheduler readScheduler = SequentialScheduler.lockingScheduler(this::readLoop); + final SequentialScheduler writeScheduler = SequentialScheduler.lockingScheduler(this::writeLoop); + final QuicStreamReader reader; + final QuicStreamWriter writer; + final BuffersReader.ListBuffersReader incoming = BuffersReader.list(); + final DecodingCallback headersConsumer; + final HeaderFrameReader headersReader; + final HttpHeadersBuilder requestHeadersBuilder; + final ReentrantLock writeLock = new ReentrantLock(); + final Condition writeEnabled = writeLock.newCondition(); + final CompletableFuture requestHeadersCF = new MinimalFuture<>(); + final CompletableFuture exchangeCF; + final BlockingQueue requestBodyQueue = new LinkedBlockingQueue<>(); + private volatile boolean eof; + + volatile PartialFrame partialFrame; + + Http3ServerStreamImpl(Http3ServerConnection http3ServerConnection, QuicBidiStream stream) { + this.serverConn = http3ServerConnection; + this.debug = Utils.getDebugLogger(this.serverConn::dbgTag); + this.stream = stream; + requestHeadersBuilder = new HttpHeadersBuilder(); + headersConsumer = new HeadersConsumer(); + headersReader = http3ServerConnection.newHeaderFrameReader(headersConsumer); + writer = stream.connectWriter(writeScheduler); + reader = stream.connectReader(readScheduler); + exchangeCF = requestHeadersCF.thenApply(this::startExchange); + // TODO: add a start() method that calls reader.start(), and + // call it outside of the constructor + reader.start(); + } + + Http3ServerConnection serverConnection() { + return this.serverConn; + } + + private void readLoop() { + try { + readLoop0(); + } catch (QPackException qe) { + boolean isConnectionError = qe.isConnectionError(); + Http3Error error = qe.http3Error(); + Throwable cause = qe.getCause(); + if (isConnectionError) { + headersConsumer.onConnectionError(cause, error); + } else { + headersConsumer.onStreamError(cause, error); + } + } + } + + private void readLoop0() { + ByteBuffer buffer; + + // reader can be null if the readLoop is invoked + // before reader is assigned. + if (reader == null) return; + + if (debug.on()) { + debug.log("H3Server: entering readLoop(stream=%s)", stream.streamId()); + } + try { + while (!reader.isReset() && (buffer = reader.poll()) != null) { + if (buffer == QuicStreamReader.EOF) { + if (debug.on()) { + debug.log("H3Server: EOF on stream=" + stream.streamId()); + } + if (!eof) requestBodyQueue.add(buffer); + eof = true; + return; + } + if (debug.on()) { + debug.log("H3Server: got %s bytes on stream %s", buffer.remaining(), stream.streamId()); + } + + var partialFrame = this.partialFrame; + if (partialFrame != null && partialFrame.remaining() == 0) { + this.partialFrame = partialFrame = null; + } + if (partialFrame instanceof HeadersFrame partialHeaders) { + serverConn.decodeHeaders(partialHeaders, buffer, headersReader); + } else if (partialFrame instanceof DataFrame partialData) { + receiveData(partialData, buffer); + } else if (partialFrame != null) { + partialFrame.nextPayloadBytes(buffer); + } + if (!buffer.hasRemaining()) { + continue; + } + + incoming.add(buffer); + Http3Frame frame = Http3Frame.decode(incoming, FramesDecoder::isAllowedOnRequestStream, debug); + if (frame == null) continue; + if (frame instanceof PartialFrame partial) { + this.partialFrame = partialFrame = partial; + if (frame instanceof HeadersFrame partialHeaders) { + if (debug.on()) { + debug.log("H3Server Got headers: " + frame + " on stream=" + stream.streamId()); + } + long remaining = partial.remaining(); + long available = incoming.remaining(); + long read = Math.min(remaining, available); + if (read > 0) { + for (ByteBuffer buf : incoming.getAndRelease(read)) { + serverConn.decodeHeaders(partialHeaders, buf, headersReader); + } + } + } else if (frame instanceof DataFrame partialData) { + if (debug.on()) { + debug.log("H3Server Got request body: " + frame + " on stream=" + stream.streamId()); + } + long remaining = partial.remaining(); + long available = incoming.remaining(); + long read = Math.min(remaining, available); + if (read > 0) { + for (ByteBuffer buf : incoming.getAndRelease(read)) { + receiveData(partialData, buf); + } + } + + } else if (frame instanceof UnknownFrame unknown) { + unknown.nextPayloadBytes(incoming); + } else { + if (debug.on()) { + debug.log("H3Server Got unexpected partial frame: " + + frame + " on stream=" + stream.streamId()); + } + serverConn.close(Http3Error.H3_FRAME_UNEXPECTED, + "unexpected frame type=" + frame.type() + + " on stream=" + stream.streamId()); + readScheduler.stop(); + writeScheduler.stop(); + return; + } + } else if (frame instanceof MalformedFrame malformed) { + if (debug.on()) { + debug.log("H3Server Got frame: " + frame + " on stream=" + stream.streamId()); + } + serverConn.close(malformed.getErrorCode(), + malformed.getMessage()); + readScheduler.stop(); + writeScheduler.stop(); + return; + } else { + if (debug.on()) { + debug.log("H3Server Got frame: " + frame + " on stream=" + stream.streamId()); + } + } + } + if (reader.isReset()) { + if (debug.on()) + debug.log("H3 Server: stream %s reset", reader.stream().streamId()); + readScheduler.stop(); + resetReceived(); + } + if (debug.on()) + debug.log("H3 Server: exiting read loop"); + } catch (Throwable t) { + if (debug.on()) + debug.log("H3 Server: read loop failed: " + t); + if (reader.isReset()) { + if (debug.on()) { + debug.log("H3 Server: stream %s reset", reader.stream()); + } + readScheduler.stop(); + resetReceived(); + } else { + if (debug.on()) { + debug.log("H3 Server: closing connection due to: " + t, t); + } + serverConn.close(Http3Error.H3_INTERNAL_ERROR, serverConn.message(t)); + readScheduler.stop(); + writeScheduler.stop(); + } + } + } + + String readErrorString(String defVal) { + return Http3Streams.errorCodeAsString(reader.stream()).orElse(defVal); + } + + void resetReceived() { + // If stop_sending sent and reset received (implied by this method being called) + // then exit normally and don't send a reset + if (debug.on()) { + debug.log("resetReceived: stream:%s, isStopSendingRequested:%s, errorCode:%s, isNoError:%s", + stream.streamId(), stream.isStopSendingRequested(), stream.rcvErrorCode(), + Http3Error.isNoError(reader.stream().rcvErrorCode())); + } + + if (reader.stream().isStopSendingRequested() + && requestHeadersCF.isDone()) { + // we can only request stop sending in the handler after having + // parsed the headers, therefore, if requestHeadersCF is not + // completed when we reach here we should reset the stream. + + // We have requested stopSending and received a reset in response: + // nothing to do - let the response be sent to the client, but throw an + // exception if `is` is used again. + exchangeCF.thenApply(en -> { + en.is.close(new IOException("stopSendingRequested")); + return en; + }); + return; + } + + String msg = "Stream %s reset by peer: %s" + .formatted(streamId(), readErrorString("no error code")); + if (debug.on()) + debug.log("H3 Server: reset received: " + msg); + var io = new IOException(msg); + requestHeadersCF.completeExceptionally(io); + exchangeCF.thenApply((e) -> e.streamResetByPeer(io)); + } + + void receiveData(DataFrame partialDataFrame, ByteBuffer buffer) { + if (debug.on()) { + debug.log("receiving data: " + buffer.remaining() + " on stream=" + stream.streamId()); + } + ByteBuffer received = partialDataFrame.nextPayloadBytes(buffer); + requestBodyQueue.add(received); + } + + void cancelPushFrameReceived(CancelPushFrame cancel) { + serverConn.cancelPush(cancel.getPushId()); + } + + class RequestBodyInputStream extends InputStream { + volatile IOException error; + volatile boolean closed; + // uses an unbounded blocking queue in which the readrLoop + // publishes the DataFrames payload... + ByteBuffer current; + // Use lock to avoid pinned threads on the blocking queue + final ReentrantLock lock = new ReentrantLock(); + + ByteBuffer current() throws IOException { + lock.lock(); + try { + while (true) { + if (current != null && current.hasRemaining()) { + return current; + } + if (current == QuicStreamReader.EOF) return current; + try { + if (debug.on()) + debug.log("Taking buffer from queue"); + // Blocking call + current = requestBodyQueue.take(); + } catch (InterruptedException e) { + var io = new InterruptedIOException(); + Thread.currentThread().interrupt(); + io.initCause(e); + close(io); + var error = this.error; + if (error != null) throw error; + } + } + } finally { + lock.unlock(); + } + } + + @Override + public int read() throws IOException { + ByteBuffer buffer = current(); + if (buffer == QuicStreamReader.EOF) { + var error = this.error; + if (error == null) return -1; + throw error; + } + return buffer.get() & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + int remaining = len; + while (remaining > 0) { + ByteBuffer buffer = current(); + if (buffer == QuicStreamReader.EOF) { + if (len == remaining) { + var error = this.error; + if (error == null) return -1; + throw error; + } else return len - remaining; + } + int count = Math.min(buffer.remaining(), remaining); + buffer.get(b, off + (len - remaining), count); + remaining -= count; + } + return len - remaining; + } + + @Override + public void close() throws IOException { + lock.lock(); + try { + if (closed) return; + closed = true; + + } finally { + lock.unlock(); + } + if (debug.on()) + debug.log("Closing request body input stream"); + requestBodyQueue.add(QuicStreamReader.EOF); + } + + void close(IOException io) { + lock.lock(); + try { + if (closed) return; + closed = true; + error = io; + } finally { + lock.unlock(); + } + if (debug.on()) { + debug.log("Closing request body input stream: " + io); + } + requestBodyQueue.clear(); + requestBodyQueue.add(QuicStreamReader.EOF); + } + } + + Http3ServerExchange startExchange(HttpHeaders headers) { + Http3ServerExchange exchange = new Http3ServerExchange(this, headers, + new RequestBodyInputStream(), + serverConn.quicConnection().getTLSEngine().getSession()); + try { + serverConn.server().submitExchange(exchange); + } catch (Exception e) { + try { + exchange.close(new IOException(e)); + } catch (IOException ex) { + if (debug.on()) + debug.log("Failed to close exchange: " + ex); + } + } + return exchange; + } + + long streamId() { + return stream.streamId(); + } + + private void writeLoop() { + writeLock.lock(); + try { + writeEnabled.signalAll(); + } finally { + writeLock.unlock(); + } + } + + void close() { + serverConn.exchangeClosed(this); + } + + private final class HeadersConsumer extends ValidatingHeadersConsumer + implements DecodingCallback { + + private HeadersConsumer() { + super(Context.REQUEST); + } + + @Override + public void reset() { + super.reset(); + requestHeadersBuilder.clear(); + if (debug.on()) { + debug.log("Response builder cleared, ready to receive new headers."); + } + } + + @Override + public void onDecoded(CharSequence name, CharSequence value) + throws UncheckedIOException { + String n = name.toString(); + String v = value.toString(); + super.onDecoded(n, v); + requestHeadersBuilder.addHeader(n, v); + if (Log.headers() && Log.trace()) { + Log.logTrace("RECEIVED HEADER (streamid={0}): {1}: {2}", + streamId(), n, v); + } + } + + @Override + public void onComplete() { + HttpHeaders requestHeaders = requestHeadersBuilder.build(); + headersReader.reset(); + requestHeadersCF.complete(requestHeaders); + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + try { + stream.reset(http3Error.code()); + serverConn.connectionError(throwable, http3Error); + } catch (IOException ioe) { + serverConn.close(http3Error.code(), + ioe.getMessage()); + } + } + + @Override + public void onStreamError(Throwable throwable, Http3Error http3Error) { + try { + stream.reset(http3Error.code()); + } catch (IOException ioe) { + serverConn.close(http3Error.code(), + ioe.getMessage()); + } + } + + @Override + public long streamId() { + return stream.streamId(); + } + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3TestServer.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3TestServer.java new file mode 100644 index 00000000000..ff76c67ed7e --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/Http3TestServer.java @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2022, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.http3; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.lang.System.Logger.Level; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.function.Predicate; + +import javax.net.ssl.SSLContext; + +import jdk.httpclient.test.lib.common.RequestPathMatcherUtil; +import jdk.httpclient.test.lib.common.RequestPathMatcherUtil.Resolved; +import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import static java.nio.charset.StandardCharsets.US_ASCII; + +public class Http3TestServer implements QuicServer.ConnectionAcceptor, AutoCloseable { + private static final AtomicLong IDS = new AtomicLong(); + + static final long DECODER_MAX_CAPACITY = + Utils.getLongProperty("http3.test.server.decoderMaxTableCapacity", 4096); + static final long ENCODER_CAPACITY_LIMIT = + Utils.getLongProperty("http3.test.server.encoderTableCapacityLimit", 4096); + + private static final String ALLOWED_HEADERS_PROP_NAME = "http3.test.server.encoderAllowedHeaders"; + static final String ALL_ALLOWED = "*"; + static final List ENCODER_ALLOWED_HEADERS = readEncoderAllowedHeadersProp(); + + private static List readEncoderAllowedHeadersProp() { + String properties = Utils.getProperty(ALLOWED_HEADERS_PROP_NAME); + if (properties == null) { + // If the system property is not set all headers are allowed to be encoded + return List.of(ALL_ALLOWED); + } else if(properties.isBlank()) { + // If system property value is a blank string - no headers are + // allowed to be encoded + return List.of(); + } + var headers = Arrays.stream(properties.split(",")) + .filter(Predicate.not(String::isBlank)) + .toList(); + if (headers.contains(ALL_ALLOWED)) { + return List.of(ALL_ALLOWED); + } + return headers; + } + + private final QuicServer quicServer; + private volatile boolean stopping; + private final Map handlers = new ConcurrentHashMap<>(); + private final Function handlerProvider; + private final Logger debug; + private final InetSocketAddress serverAddr; + private volatile ConnectionSettings ourSettings; + // request approver which takes the server connection key as the input + private volatile Predicate newRequestApprover; + + private static String nextName() { + return "h3-server-" + IDS.incrementAndGet(); + } + + /** + * Same as calling {@code Http3TestServer(sslContext, null)} + * + * @param sslContext SSLContext + * @throws IOException if the server could not be created + */ + public Http3TestServer(final SSLContext sslContext) throws IOException { + this(sslContext, null); + } + + /** + * Same as calling {@code Http3TestServer(sslContext, + * new InetSocketAddress(InetAddress.getLoopbackAddress(), port), null)} + * + * @param sslContext SSLContext + * @throws IOException if the server could not be created + */ + public Http3TestServer(final SSLContext sslContext, int port) throws IOException { + this(sslContext, new InetSocketAddress(InetAddress.getLoopbackAddress(), port), null); + } + + public Http3TestServer(final SSLContext sslContext, InetSocketAddress address, final ExecutorService executor) throws IOException { + this(quicServerBuilder().sslContext(sslContext).executor(executor).bindAddress(address).build(), null); + } + + public Http3TestServer(final SSLContext sslContext, final ExecutorService executor) throws IOException { + this(quicServerBuilder().sslContext(sslContext).executor(executor).build(), null); + } + + public Http3TestServer(final QuicServer quicServer) throws IOException { + this(quicServer, null); + } + + public Http3TestServer(final QuicServer quicServer, + final Function handlerProvider) + throws IOException { + Objects.requireNonNull(quicServer); + this.debug = Utils.getDebugLogger(quicServer::name); + this.quicServer = quicServer; + this.handlerProvider = handlerProvider; + this.quicServer.setConnectionAcceptor(this); + this.serverAddr = bindServer(this.quicServer); + debug.log(Level.INFO, "H3 server is listening at " + this.serverAddr); + } + + public void start() { + quicServer.start(); + } + + public void stop() { + try { + close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public QuicServer getQuicServer() { + return this.quicServer; + } + + public InetSocketAddress getAddress() { + return this.serverAddr; + } + + public String serverAuthority() { + final InetSocketAddress inetSockAddr = getAddress(); + final String hostIP = inetSockAddr.getAddress().getHostAddress(); + // escape for ipv6 + final String h = hostIP.contains(":") ? "[" + hostIP + "]" : hostIP; + return h + ":" + inetSockAddr.getPort(); + } + + public void addHandler(final String path, final Http2Handler handler) { + if (this.handlerProvider != null) { + throw new IllegalStateException("Cannot add handler to H3 server which uses a handler provider"); + } + Objects.requireNonNull(path); + Objects.requireNonNull(handler); + this.handlers.put(path, handler); + } + + public void setRequestApprover(final Predicate approver) { + this.newRequestApprover = approver; + } + + /** + * Sets the connection settings that will be used by this server to generate a SETTINGS frame + * to send to client peers + * + * @param connectionSettings The connection settings + * @return The instance of this server + */ + public Http3TestServer setConnectionSettings(final ConnectionSettings connectionSettings) { + Objects.requireNonNull(connectionSettings); + this.ourSettings = connectionSettings; + return this; + } + + /** + * {@return the connection settings of this server, which will be sent to + * client peers in a SETTINGS frame. If none have been configured then this method returns + * {@link Optional#empty() empty}} + */ + public ConnectionSettings getConfiguredConnectionSettings() { + if (this.ourSettings == null) { + SettingsFrame settings = SettingsFrame.defaultRFCSettings(); + settings.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, DECODER_MAX_CAPACITY); + return ConnectionSettings.createFrom(settings); + } + return this.ourSettings; + } + + private static InetSocketAddress bindServer(final QuicServer server) throws IOException { + // bind the quic server to the socket + final SocketAddress addr = server.getEndpoint().getLocalAddress(); + if (addr instanceof InetSocketAddress inetsaddr) { + return inetsaddr; + } + throw new IOException(new IOException("Unexpected socket address type " + + addr.getClass().getName())); + } + + void submitExchange(final Http2TestExchange exchange) { + debug.log("H3 server handling exchange for: %s%n\t\t" + + "(Memory: max=%s, free=%s, total=%s)%n", + exchange.getRequestURI(), Runtime.getRuntime().maxMemory(), + Runtime.getRuntime().freeMemory(), Runtime.getRuntime().totalMemory()); + final String reqPath = exchange.getRequestURI().getPath(); + final Http2Handler handler; + if (this.handlerProvider != null) { + handler = this.handlerProvider.apply(reqPath); + } else { + Optional> match = + RequestPathMatcherUtil.findHandler(reqPath, this.handlers); + handler = match.isPresent() ? match.get().handler() : null; + } + // The server Http3ServerExchange uses a BlockingQueue to + // read data so handling the exchange in the current thread would + // wedge it. The executor must have at least one thread and must not + // execute inline - otherwise, we'd be wedged. + Thread currentThread = Thread.currentThread(); + this.quicServer.executorService().execute(() -> { + try { + // if no handler was located, we respond with a 404 + if (handler == null) { + respondForMissingHandler(exchange); + return; + } + // This assertion is too strong: there are cases + // where the calling task might terminate before + // the submitted task is executed. In which case + // it can safely run on the same thread. + // assert Thread.currentThread() != currentThread + // : "HTTP/3 server executor must have at least one thread"; + handler.handle(exchange); + } catch (Throwable failure) { + System.err.println("Failed to handle exchange: " + failure); + failure.printStackTrace(); + final var ioException = (failure instanceof IOException) + ? (IOException) failure + : new IOException(failure); + try { + exchange.close(ioException); + } catch (IOException x) { + System.err.println("Failed to close exchange: " + x); + } + } + }); + } + + private void respondForMissingHandler(final Http2TestExchange exchange) + throws IOException { + final byte[] responseBody = (this.getClass().getSimpleName() + + " - No handler available to handle request " + + exchange.getRequestURI()).getBytes(US_ASCII); + try (final OutputStream os = exchange.getResponseBody()) { + exchange.sendResponseHeaders(404, responseBody.length); + os.write(responseBody); + } + } + + /** + * Called by the {@link QuicServer} when a new connection has been added to the endpoint's + * connection map. + * + * @param source The client address + * @param quicConn the new connection + * @return true if the new connection should be accepted, false + * if it should be closed + */ + @Override + public boolean acceptIncoming(final SocketAddress source, final QuicServerConnection quicConn) { + if (stopping) { + return false; + } + debug.log("New connection %s accepted from %s", quicConn.dbgTag(), source); + quicConn.onSuccessfulHandshake(() -> { + var http3Connection = new Http3ServerConnection(this, quicConn, source); + http3Connection.start(); + }); + return true; + } + + boolean shouldProcessNewHTTPRequest(final Http3ServerConnection serverConn) { + final Predicate approver = this.newRequestApprover; + if (approver == null) { + // by the default the server will process new requests + return true; + } + final String connKey = serverConn.connectionKey(); + return approver.test(connKey); + } + + @Override + public void close() throws IOException { + stopping = true; + if (debug.on()) { + debug.log("Stopping H3 server " + this.serverAddr); + } + // TODO: should we close the quic server only if we "own" it + if (this.quicServer != null) { + this.quicServer.close(); + } + } + + public static QuicServer.Builder quicServerBuilder() { + return new H3QuicBuilder(); + } + + private static final class H3QuicBuilder extends QuicServer.Builder { + + public H3QuicBuilder() { + alpn = "h3"; + } + + @Override + public QuicServer build() throws IOException { + QuicVersion[] versions = availableQuicVersions; + if (versions == null) { + // default to v1 and v2 + versions = new QuicVersion[]{QuicVersion.QUIC_V1, QuicVersion.QUIC_V2}; + } + if (versions.length == 0) { + throw new IllegalStateException("Empty available QUIC versions"); + } + InetSocketAddress addr = bindAddress; + if (addr == null) { + // default to loopback address and ephemeral port + addr = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); + } + SSLContext ctx = sslContext; + if (ctx == null) { + try { + ctx = SSLContext.getDefault(); + } catch (NoSuchAlgorithmException e) { + throw new IOException(e); + } + } + final QuicTLSContext quicTLSContext = new QuicTLSContext(ctx); + final String name = serverId == null ? nextName() : serverId; + return new QuicServer(name, addr, executor, versions, compatible, quicTLSContext, sniMatcher, + incomingDeliveryPolicy, outgoingDeliveryPolicy, alpn, Http3Error::stringForCode); + } + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/UnknownOrReservedFrame.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/UnknownOrReservedFrame.java new file mode 100644 index 00000000000..eddad466daa --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/http3/UnknownOrReservedFrame.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.http3; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.Random; + +import jdk.internal.net.http.http3.frames.AbstractHttp3Frame; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +/** + * A test-only HTTP3 frame used to exercise the case where (client) implementation + * is expected to ignore unknown/reserved frames, as expected by RFC-9114, section + * 7.2.8 and section 9 + */ +final class UnknownOrReservedFrame extends AbstractHttp3Frame { + + private static final Random random = new Random(getSeed()); + + private final int payloadLength; + private final byte[] payload; + + public UnknownOrReservedFrame() { + super(generateRandomFrameType()); + this.payloadLength = random.nextInt(13); // arbitrary upper bound + this.payload = new byte[this.payloadLength]; + random.nextBytes(this.payload); + } + + @Override + public long length() { + return this.payloadLength; + } + + ByteBuffer toByteBuffer() { + final int frameSize = (int) this.size(); // cast is OK - value expected to be within range + final ByteBuffer buf = ByteBuffer.allocate(frameSize); + // write the type of the frame + VariableLengthEncoder.encode(buf, this.type()); + // write the length of the payload + VariableLengthEncoder.encode(buf, this.payloadLength); + // write the payload + buf.put(this.payload); + buf.flip(); + return buf; + } + + private static long generateRandomFrameType() { + final boolean useReservedFrameType = random.nextBoolean(); + if (useReservedFrameType) { + final int N = random.nextInt(100); // arbitrary upper bound + // RFC-9114, section 7.2.8: Frame types of the format 0x1f * N + 0x21 for non-negative + // integer values of N are reserved to exercise the requirement + // that unknown types be ignored + return 0x1F * N + 0x21; + } + // arbitrary lower bound of 0x45 + return random.nextLong(0x45, VariableLengthEncoder.MAX_ENCODED_INTEGER); + } + + private static long getSeed() { + Long seed = Long.getLong("seed"); + return seed != null ? seed : System.nanoTime() ^ new Random().nextLong(); + } + + static Optional tryGenerateFrame() { + // an arbitrary decision to create a new unknown/reserved frame + if (random.nextInt() % 8 == 0) { + return Optional.of(new UnknownOrReservedFrame()); + } + return Optional.empty(); + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ClientConnection.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ClientConnection.java new file mode 100644 index 00000000000..96bb04a4ac7 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ClientConnection.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.QuicEndpoint; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import static jdk.internal.net.http.quic.TerminationCause.forTransportError; +import static jdk.internal.net.quic.QuicTransportErrors.NO_ERROR; + +/** + * A client initiated QUIC connection to a server + */ +public final class ClientConnection implements AutoCloseable { + + private static final Logger debug = Utils.getDebugLogger(() -> ClientConnection.class.getName()); + + private static final ByteBuffer EOF = ByteBuffer.allocate(0); + private static final String ALPN = QuicStandaloneServer.ALPN; + + private final QuicConnection connection; + private final QuicEndpoint endpoint; + + /** + * Establishes a connection between a Quic client and a target Quic server. This includes completing + * the handshake between the client and the server. + * + * @param client The Quic client + * @param serverAddr the target server address + * @return a ClientConnection + * @throws IOException If there was any exception while establishing the connection + */ + public static ClientConnection establishConnection(final QuicClient client, + final InetSocketAddress serverAddr) + throws IOException { + Objects.requireNonNull(client); + Objects.requireNonNull(serverAddr); + final QuicConnection conn = client.createConnectionFor(serverAddr, new String[]{ALPN}); + assert conn instanceof QuicConnectionImpl : "unexpected QUIC connection type: " + + conn.getClass(); + final CompletableFuture handshakeCf = + conn.startHandshake(); + final QuicEndpoint endpoint; + try { + endpoint = handshakeCf.get(); + } catch (InterruptedException e) { + throw new IOException(e); + } catch (ExecutionException e) { + throw new IOException(e.getCause()); + } + assert endpoint != null : "null endpoint after handshake completion"; + if (debug.on()) { + debug.log("Quic connection established for client: " + client.name() + + ", local addr: " + conn.localAddress() + + ", peer addr: " + serverAddr + + ", endpoint: " + endpoint); + } + return new ClientConnection(conn, endpoint); + } + + private ClientConnection(final QuicConnection connection, final QuicEndpoint endpoint) { + this.connection = Objects.requireNonNull(connection); + this.endpoint = endpoint; + } + + /** + * Creates a new client initiated bidirectional stream to the server. The returned + * {@code ConnectedBidiStream} will have the reader and writer tasks started, thus + * allowing the caller of this method to then use the returned {@code ConnectedBidiStream} + * for sending or received data through the {@link ConnectedBidiStream#outputStream() output stream} + * or {@link ConnectedBidiStream#inputStream() input stream} respectively. + * + * @return + * @throws IOException + */ + public ConnectedBidiStream initiateNewBidiStream() throws IOException { + final QuicBidiStream quicBidiStream; + try { + // TODO: review the duration being passed and whether it needs to be something + // that should be taken as an input to the initiateNewBidiStream() method + quicBidiStream = this.connection.openNewLocalBidiStream(Duration.ZERO).get(); + } catch (InterruptedException e) { + throw new IOException(e); + } catch (ExecutionException e) { + throw new IOException(e.getCause()); + } + return new ConnectedBidiStream(quicBidiStream); + } + + public QuicConnection underlyingQuicConnection() { + return this.connection; + } + + public QuicEndpoint endpoint() { + return this.endpoint; + } + + @Override + public void close() throws Exception { + this.connection.connectionTerminator().terminate(forTransportError(NO_ERROR)); + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ConnectedBidiStream.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ConnectedBidiStream.java new file mode 100644 index 00000000000..9aa2e9040e2 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/ConnectedBidiStream.java @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Semaphore; + +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; + +/** + * A {@code ConnectedBidiStream} represents a {@link + * jdk.internal.net.http.quic.streams.QuicBidiStream} + * which has a reader and writer task/loop started for it + */ +public class ConnectedBidiStream implements AutoCloseable { + + private final QuicBidiStream bidiStream; + private final QuicStreamReader quicStreamReader; + private final QuicStreamWriter quicStreamWriter; + private final BlockingQueue incomingData; + private final Semaphore writeSemaphore = new Semaphore(1); + private final OutputStream outputStream; + private final QueueInputStream inputStream; + private final SequentialScheduler readScheduler; + private volatile boolean closed; + + ConnectedBidiStream(final QuicBidiStream bidiStream) { + Objects.requireNonNull(bidiStream); + this.bidiStream = bidiStream; + incomingData = new ArrayBlockingQueue<>(1024, true); + this.quicStreamReader = bidiStream.connectReader( + readScheduler = SequentialScheduler.lockingScheduler(new ReaderLoop())); + this.inputStream = new QueueInputStream(this.incomingData, QuicStreamReader.EOF, quicStreamReader); + this.quicStreamWriter = bidiStream.connectWriter( + SequentialScheduler.lockingScheduler(() -> { + System.out.println("Server writer task called"); + writeSemaphore.release(); + })); + this.outputStream = new OutStream(this.quicStreamWriter, writeSemaphore); + // TODO: start the reader when inputStream() is called instead? + this.quicStreamReader.start(); + } + + public InputStream inputStream() { + return this.inputStream; + } + + public OutputStream outputStream() { + return this.outputStream; + } + + public QuicBidiStream underlyingBidiStream() { + return this.bidiStream; + } + + @Override + public void close() throws Exception { + this.closed = true; + // TODO: use runOrSchedule(executor)? + this.readScheduler.runOrSchedule(); + } + + + private final class ReaderLoop implements Runnable { + + private volatile boolean alreadyLogged; + + @Override + public void run() { + try { + if (quicStreamReader == null) return; + while (true) { + final var bb = quicStreamReader.poll(); + if (closed) { + return; + } + if (bb == null) { + return; + } + incomingData.add(bb); + if (bb == QuicStreamReader.EOF) { + break; + } + } + } catch (Throwable e) { + if (closed && e instanceof IOException) { + // the stream has been closed so we ignore any IOExceptions + return; + } + System.err.println("Error in " + getClass()); + e.printStackTrace(); + var in = inputStream; + if (in != null) { + in.error(e); + } + } + } + } + +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/DatagramDeliveryPolicy.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/DatagramDeliveryPolicy.java new file mode 100644 index 00000000000..277e0b289a5 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/DatagramDeliveryPolicy.java @@ -0,0 +1,315 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.text.ParseException; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; + +import jdk.internal.net.http.quic.packets.QuicPacket; + +/** + * Used by the {@link QuicServer Quic server} and the {@link QuicServerConnection Quic server + * connection} to decide whether an incoming datagram needs to be dropped + */ +public interface DatagramDeliveryPolicy { + + /** + * System property for configuring the incoming datagram delivery policy for Quic server + */ + public static final String SYS_PROP_INCOMING_DELIVERY_POLICY = + "jdk.internal.httpclient.test.quic.incoming"; + + /** + * System property for configuring the outgoing datagram delivery policy for Quic server + */ + public static final String SYS_PROP_OUTGOING_DELIVERY_POLICY = + "jdk.internal.httpclient.test.quic.outgoing"; + + /** + * Will be called to decide if an incoming or an outgoing datagram should be dropped + * + * @param address The source or the destination address for the datagram + * @param payload The datagram payload + * @param connection The connection which was chosen to handle the Quic packet + * @return true if the datagram should be dropped, false otherwise + */ + boolean shouldDrop(SocketAddress address, ByteBuffer payload, QuicServerConnection connection, + QuicPacket.HeadersType headersType); + + /** + * Will be called to decide if an incoming datagram, which wasn't matched against + * any specific Quic connection, should be dropped + * + * @param source The source address which transmitted the datagram + * @param payload The datagram payload + * @return true if the datagram should be dropped, false otherwise + */ + boolean shouldDrop(SocketAddress source, ByteBuffer payload, QuicPacket.HeadersType headersType); + + static final DatagramDeliveryPolicy ALWAYS_DELIVER = new DatagramDeliveryPolicy() { + @Override + public boolean shouldDrop(final SocketAddress address, + final ByteBuffer payload, + final QuicServerConnection connection, + final QuicPacket.HeadersType headersType) { + return false; + } + + @Override + public boolean shouldDrop(final SocketAddress source, final ByteBuffer payload, + final QuicPacket.HeadersType headersType) { + return false; + } + + @Override + public String toString() { + return "[DatagramDeliveryPolicy=always deliver]"; + } + }; + + static final DatagramDeliveryPolicy NEVER_DELIVER = new DatagramDeliveryPolicy() { + @Override + public boolean shouldDrop(final SocketAddress address, + final ByteBuffer payload, + final QuicServerConnection connection, + final QuicPacket.HeadersType headersType) { + return true; + } + + @Override + public boolean shouldDrop(final SocketAddress source, final ByteBuffer payload, + final QuicPacket.HeadersType headersType) { + return true; + } + + @Override + public String toString() { + return "[DatagramDeliveryPolicy=never deliver]"; + } + }; + + static final class FixedRate implements DatagramDeliveryPolicy { + private final int n; + private final AtomicLong counter = new AtomicLong(); + + FixedRate(final int n) { + if (n <= 0) { + throw new IllegalArgumentException("n should be greater than 0"); + } + this.n = n; + } + + @Override + public boolean shouldDrop(final SocketAddress address, + final ByteBuffer payload, + final QuicServerConnection connection, + final QuicPacket.HeadersType headersType) { + final long current = counter.incrementAndGet(); + return current % n == 0; // drop every nth + } + + @Override + public boolean shouldDrop(final SocketAddress source, final ByteBuffer payload, + final QuicPacket.HeadersType headersType) { + final long current = counter.incrementAndGet(); + return current % n == 0; // drop every nth + } + + @Override + public String toString() { + return "[DatagramDeliveryPolicy=drop every " + n + "]"; + } + } + + static final class RandomDrop implements DatagramDeliveryPolicy { + private final long seed; + private final Random random; + + RandomDrop() { + Long s = null; + try { + // note that Long.valueOf(null) also throws a + // NumberFormatException so if the property is undefined this + // will still work correctly + s = Long.valueOf(System.getProperty("seed")); + } catch (NumberFormatException e) { + // do nothing: seed is still null + } + this.seed = s != null ? s : new Random().nextLong(); + this.random = new Random(seed); + } + + @Override + public boolean shouldDrop(final SocketAddress address, + final ByteBuffer payload, + final QuicServerConnection connection, + final QuicPacket.HeadersType headersType) { + return this.random.nextLong() % 42 == 0; + } + + @Override + public boolean shouldDrop(final SocketAddress source, final ByteBuffer payload, + final QuicPacket.HeadersType headersType) { + return this.random.nextLong() % 42 == 0; + } + + @Override + public String toString() { + return "[DatagramDeliveryPolicy=drop randomly, seed=" + seed + "]"; + } + } + + /** + * {@return a DatagramDeliveryPolicy which always returns false from the {@code shouldDrop} + * methods} + */ + public static DatagramDeliveryPolicy alwaysDeliver() { + return ALWAYS_DELIVER; + } + + /** + * {@return a DatagramDeliveryPolicy which always returns true from the {@code shouldDrop} + * methods} + */ + public static DatagramDeliveryPolicy neverDeliver() { + return NEVER_DELIVER; + } + + /** + * @param n the repeat count at which the datagram will be dropped + * @return a DatagramDeliveryPolicy which will return true on every {@code n}th call to + * either of the {@code shouldDrop} methods + */ + public static DatagramDeliveryPolicy dropEveryNth(final int n) { + return new FixedRate(n); + } + + /** + * @return a DatagramDeliveryPolicy which will randomly return true from the {@code shouldDrop} + * methods. If the {@code seed} system property is set then the {@code Random} instance used by + * this policy will use that seed. + */ + public static DatagramDeliveryPolicy dropRandomly() { + return new RandomDrop(); + } + + private static String privilegedGetProperty(String property) { + return privilegedGetProperty(property, null); + } + + private static String privilegedGetProperty(String property, String defval) { + return System.getProperty(property, defval); + } + + /** + * Reads the system property {@code sysPropName} and parses the value into a + * {@link DatagramDeliveryPolicy}. If the {@code sysPropName} system property isn't set or + * is set to a value of {@link String#isBlank() blank}, then this method returns a + * {@link DatagramDeliveryPolicy#alwaysDeliver() always deliver policy}. + *

    + * The {@code sysPropName} if set is expected to have either of the following values: + *

      + *
    • {@code always} - this returns a {@link DatagramDeliveryPolicy#alwaysDeliver() + * always deliver policy}
    • + *
    • {@code never} - this returns a {@link DatagramDeliveryPolicy#neverDeliver() + * never deliver policy}
    • + *
    • {@code fixed=} - where n is a positive integer, this returns a + * {@link DatagramDeliveryPolicy#dropEveryNth(int) dropEveryNth policy}
    • + *
    • {@code random} - this returns a + * {@link DatagramDeliveryPolicy#dropRandomly() dropRandomly policy}
    • + *
    + *

    + * + * @param sysPropName The system property name to use + * @return a DatagramDeliveryPolicy + * @throws ParseException If the system property value cannot be parsed into a + * DatagramDeliveryPolicy + */ + private static DatagramDeliveryPolicy fromSystemProperty(final String sysPropName) + throws ParseException { + String val = privilegedGetProperty(sysPropName); + if (val == null || val.isBlank()) { + return ALWAYS_DELIVER; + } + val = val.trim(); + if (val.startsWith("fixed=")) { + // read the characters following "fixed=" + String rateVal = val.substring("fixed=".length()); + final int n; + try { + n = Integer.parseInt(rateVal); + } catch (NumberFormatException nfe) { + throw new ParseException("Unexpected value: " + val, "fixed=".length()); + } + return dropEveryNth(n); + } else if (val.equals("random")) { + return dropRandomly(); + } else if (val.equals("always")) { + return ALWAYS_DELIVER; + } else if (val.equals("never")) { + return NEVER_DELIVER; + } else { + throw new ParseException("Unexpected value: " + val, 0); + } + } + + /** + * Returns the default incoming datagram delivery policy. This takes into account the + * {@link DatagramDeliveryPolicy#SYS_PROP_INCOMING_DELIVERY_POLICY} system property to decide + * the default policy + * + * @return the default incoming datagram delivery policy + * @throws ParseException If the {@link DatagramDeliveryPolicy#SYS_PROP_INCOMING_DELIVERY_POLICY} + * was configured and there was a problem parsing its value + */ + public static DatagramDeliveryPolicy defaultIncomingPolicy() throws ParseException { + try { + return fromSystemProperty(SYS_PROP_INCOMING_DELIVERY_POLICY); + } catch (Throwable t) { + t.printStackTrace(); + throw t; + } + } + + /** + * Returns the default outgoing datagram delivery policy. This takes into account the + * {@link DatagramDeliveryPolicy#SYS_PROP_OUTGOING_DELIVERY_POLICY} system property to decide + * the default policy + * + * @return the default outgoing datagram delivery policy + * @throws ParseException If the {@link DatagramDeliveryPolicy#SYS_PROP_OUTGOING_DELIVERY_POLICY} + * was configured and there was a problem parsing its value + */ + + public static DatagramDeliveryPolicy defaultOutgoingPolicy() throws ParseException { + try { + return fromSystemProperty(SYS_PROP_OUTGOING_DELIVERY_POLICY); + } catch (Throwable t) { + t.printStackTrace(); + throw t; + } + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/OutStream.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/OutStream.java new file mode 100644 index 00000000000..659c68682ef --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/OutStream.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.Semaphore; + +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; + +/** + * An {@link OutputStream} which writes using a {@link QuicStreamWriter} + */ +final class OutStream extends OutputStream { + + private final QuicStreamWriter quicStreamWriter; + private final Semaphore writeSemaphore; + + OutStream(final QuicStreamWriter quicStreamWriter, Semaphore writeSemaphore) { + Objects.requireNonNull(quicStreamWriter); + this.quicStreamWriter = quicStreamWriter; + this.writeSemaphore = Objects.requireNonNull(writeSemaphore); + } + + @Override + public void write(final int b) throws IOException { + this.write(new byte[]{(byte) (b & 0xff)}); + } + + @Override + public void write(final byte[] b, final int off, final int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + while (quicStreamWriter.credit() < 0 + && !quicStreamWriter.stopSendingReceived()) { + try { + writeSemaphore.acquire(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + // the data will be queued and won't be written immediately, and therefore + // it needs to be copied. + final ByteBuffer data = ByteBuffer.wrap(b.clone(), off, len); + quicStreamWriter.scheduleForWriting(data, false); + } + + @Override + public void close() throws IOException { + quicStreamWriter.scheduleForWriting(QuicStreamReader.EOF, true); + super.close(); + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QueueInputStream.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QueueInputStream.java new file mode 100644 index 00000000000..2756fa732fa --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QueueInputStream.java @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; + +import jdk.internal.net.http.quic.streams.QuicStreamReader; +import jdk.internal.net.quic.QuicTransportErrors; + +/** + * An {@code InputStream} which reads its data from a {@link BlockingQueue} + */ +final class QueueInputStream extends InputStream { + private final BlockingQueue incomingData; + private final ByteBuffer eofIndicator; + private final QuicStreamReader streamReader; + // error needs volatile access as it is set by a different thread + private volatile Throwable error; + // current might not need volatile access as it should only + // be read/set by the reading thread. However available() + // might conceivably be called by multiple threads. + private volatile ByteBuffer current; + + QueueInputStream(final BlockingQueue incomingData, + final ByteBuffer eofIndicator, + QuicStreamReader streamReader) { + this.incomingData = incomingData; + this.eofIndicator = eofIndicator; + this.streamReader = streamReader; + } + + private ByteBuffer current() throws InterruptedException { + ByteBuffer current = this.current; + // if eof, there should no more byte buffer + if (current == eofIndicator) return eofIndicator; + if (current == null || !current.hasRemaining()) { + return (current = this.current = incomingData.take()); + } + return current; + } + + private boolean eof() { + ByteBuffer current = this.current; + return current == eofIndicator; + } + + @Override + public int read() throws IOException { + final byte[] data = new byte[1]; + final int numRead = this.read(data, 0, data.length); + // can't be 0, since we block till we receive at least 1 byte of data + assert numRead != 0 : "No data read"; + if (numRead == -1) { + return -1; + } + return data[0]; + } + + // concurrent calls to read() should not and are not supported + @Override + public int read(final byte[] b, final int off, final int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + int totalRead = 0; + while (totalRead < len) { + ByteBuffer bb = null; + checkError(); + try { + bb = current(); + } catch (InterruptedException e) { + streamReader.stream().requestStopSending(QuicTransportErrors.NO_ERROR.code()); + // TODO: should close here + error(e); + Thread.currentThread().interrupt(); + throw toIOException(e); + } + if (bb == eofIndicator) { + return totalRead == 0 ? -1 : totalRead; + } + final int available = bb.remaining(); + if (available > 0) { + final int numToTransfer = Math.min(available, (len - totalRead)); + bb.get(b, off + totalRead, numToTransfer); + totalRead += numToTransfer; + } + // if more data is available, take more, else if we have read at least 1 byte + // then return back + if (totalRead > 0 && incomingData.peek() == null) { + return totalRead; + } + } + if (totalRead > 0) return totalRead; + // if we reach here then len must be 0 + checkError(); + assert len == 0; + return eof() ? -1 : 0; + } + + @Override + public int available() throws IOException { + var bb = current; + if (bb == null || !bb.hasRemaining()) bb = incomingData.peek(); + if (bb == null || bb == eofIndicator) return 0; + return bb.remaining(); + } + + // we only check for errors after the incoming data queue + // has been emptied - except for interrupt. + private void checkError() throws IOException { + var error = this.error; + if (error == null) return; + if (error instanceof InterruptedException) + throw new IOException("closed by interrupt"); + var bb = current; + if (bb == eofIndicator || (bb != null && bb.hasRemaining())) return; + // we create a new exception to have the caller in the + // stack trace. + if (incomingData.isEmpty()) throw toIOException(error); + } + + // called if an error comes from upstream + void error(Throwable error) { + boolean firstError = false; + // only keep the first error + synchronized (this) { + var e = this.error; + if (e == null) { + e = this.error = error; + firstError = true; + } + } + // unblock read if needed + if (firstError) { + incomingData.add(ByteBuffer.allocate(0)); + } + } + + static IOException toIOException(Throwable error) { + return new IOException(error); + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServer.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServer.java new file mode 100644 index 00000000000..ae600a7a592 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServer.java @@ -0,0 +1,913 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.security.AlgorithmConstraints; +import java.security.AlgorithmParameters; +import java.security.CryptoPrimitive; +import java.security.Key; +import java.text.ParseException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.function.LongFunction; + +import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.ServerNameMatcher; +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.PeerConnectionId; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.QuicEndpoint; +import jdk.internal.net.http.quic.QuicEndpoint.QuicEndpointFactory; +import jdk.internal.net.http.quic.QuicInstance; +import jdk.internal.net.http.quic.QuicSelector; +import jdk.internal.net.http.quic.QuicTransportParameters; +import jdk.internal.net.http.quic.packets.LongHeader; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacketDecoder; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import static jdk.internal.net.http.quic.TerminationCause.forTransportError; +import static jdk.internal.net.quic.QuicTransportErrors.CONNECTION_REFUSED; + +/** + * This class represents a QuicServer. + */ +public sealed class QuicServer implements QuicInstance, AutoCloseable permits QuicStandaloneServer { + + public interface ConnectionAcceptor { + boolean acceptIncoming(SocketAddress source, QuicServerConnection quicConnection); + } + + final Logger debug = Utils.getDebugLogger(this::name); + + private final String serverId; + private final String name; + private final ExecutorService executor; + private final boolean ownExecutor; + private volatile ConnectionAcceptor newConnectionAcceptor; + private final String alpn; + private final InetSocketAddress bindAddress; + private final SNIMatcher sniMatcher; + private volatile InetSocketAddress listenAddress; + private final QuicTLSContext quicTLSContext; + private volatile boolean started; + private volatile boolean sendRetry; + protected final List availableQuicVersions; + private final QuicVersion preferredVersion; + private volatile QuicEndpoint endpoint; + private volatile QuicSelector selector; + private volatile boolean closed; + private volatile QuicTransportParameters transportParameters; + private final byte[] retryTokenPrefixBytes = new byte[4]; + private final byte[] newTokenPrefixBytes = new byte[4]; + private final DatagramDeliveryPolicy incomingDeliveryPolicy; + private final DatagramDeliveryPolicy outgoingDeliveryPolicy; + private boolean wantClientAuth; + private boolean needClientAuth; + // set of KeyAgreement algorithms to reject; used to force a HelloRetryRequest + private Set rejectedKAAlgos; + // used to compute MAX_STREAMS limit by connections created on this server instance. + // if null, then an internal algorithm is used to compute the limit. + // The Function takes a boolean argument whose value is true if the computation is for bidi + // streams and false for uni streams. The returned value from this function is expected + // to be the MAX_STREAMS limit to be imposed on the peer for that particular stream type + private volatile Function maxStreamLimitComputer; + + private final ReentrantLock quickServerLock = new ReentrantLock(); + private final LongFunction appErrorCodeToString; + + record RetryData(QuicConnectionId originalServerConnId, + QuicConnectionId serverChosenConnId) { + } + + public static abstract class Builder { + protected SSLContext sslContext; + protected InetSocketAddress bindAddress; + protected DatagramDeliveryPolicy incomingDeliveryPolicy; + protected DatagramDeliveryPolicy outgoingDeliveryPolicy; + protected String serverId; + protected SNIMatcher sniMatcher; + protected ExecutorService executor; + protected QuicVersion[] availableQuicVersions; + protected boolean compatible; + protected LongFunction appErrorCodeToString; + + protected ConnectionAcceptor connAcceptor = + (source, conn) -> { + System.err.println("Rejecting connection " + conn + " attempt from source " + + source); + return false; + }; + + protected String alpn; + + protected Builder() { + try { + incomingDeliveryPolicy = DatagramDeliveryPolicy.defaultIncomingPolicy(); + outgoingDeliveryPolicy = DatagramDeliveryPolicy.defaultOutgoingPolicy(); + } catch (ParseException e) { + throw new RuntimeException(e); + } + } + + public Builder availableVersions(final QuicVersion[] available) { + Objects.requireNonNull(available); + if (available.length == 0) { + throw new IllegalArgumentException("Empty available versions"); + } + this.availableQuicVersions = available; + return this; + } + + public Builder appErrorCodeToString(LongFunction appErrorCodeToString) { + this.appErrorCodeToString = appErrorCodeToString; + return this; + } + + public Builder compatibleNegotiation(boolean compatible) { + this.compatible = compatible; + return this; + } + + public Builder sslContext(final SSLContext sslContext) { + this.sslContext = sslContext; + return this; + } + + public Builder sniMatcher(final SNIMatcher sniMatcher) { + this.sniMatcher = sniMatcher; + return this; + } + + public Builder bindAddress(final InetSocketAddress addr) { + Objects.requireNonNull(addr); + this.bindAddress = addr; + return this; + } + + public Builder serverId(final String serverId) { + Objects.requireNonNull(serverId); + this.serverId = serverId; + return this; + } + + public Builder incomingDeliveryPolicy(final DatagramDeliveryPolicy policy) { + Objects.requireNonNull(policy); + this.incomingDeliveryPolicy = policy; + return this; + } + + public Builder outgoingDeliveryPolicy(final DatagramDeliveryPolicy policy) { + Objects.requireNonNull(policy); + this.outgoingDeliveryPolicy = policy; + return this; + } + + public Builder executor(final ExecutorService executor) { + this.executor = executor; + return this; + } + + public Builder alpn(String alpn) { + this.alpn = alpn; + return this; + } + + public abstract T build() throws IOException; + } + + private final QuicEndpointFactory endpointFactory; + public QuicServer(final String serverId, final InetSocketAddress bindAddress, + final ExecutorService executor, final QuicVersion[] availableQuicVersions, + boolean compatible, final QuicTLSContext quicTLSContext, final SNIMatcher sniMatcher, + final DatagramDeliveryPolicy incomingDeliveryPolicy, + final DatagramDeliveryPolicy outgoingDeliveryPolicy, String alpn, + final LongFunction appErrorCodeToString) { + this.bindAddress = bindAddress; + this.sniMatcher = sniMatcher == null + ? new ServerNameMatcher(this.bindAddress.getHostName()) + : sniMatcher; + this.alpn = Objects.requireNonNull(alpn); + this.appErrorCodeToString = appErrorCodeToString == null + ? QuicInstance.super::appErrorToString + : appErrorCodeToString; + if (executor != null) { + this.executor = executor; + this.ownExecutor = false; + } else { + this.executor = Utils.safeExecutor( + createExecutor(serverId), + (_, t) -> debug.log("rejected task - using ASYNC_POOL", t)); + this.ownExecutor = true; + } + this.serverId = serverId; + this.quicTLSContext = quicTLSContext; + this.name = "QuicServer(%s)".formatted(serverId); + if (compatible) { + this.availableQuicVersions = Arrays.asList(QuicVersion.values()); + } else { + this.availableQuicVersions = Arrays.asList(availableQuicVersions); + } + this.preferredVersion = availableQuicVersions[0]; + this.incomingDeliveryPolicy = incomingDeliveryPolicy; + this.outgoingDeliveryPolicy = outgoingDeliveryPolicy; + final Random random = new Random(); + random.nextBytes(retryTokenPrefixBytes); + random.nextBytes(newTokenPrefixBytes); + this.endpointFactory = newQuicEndpointFactory(); + if (debug.on()) { + debug.log("server created, incoming delivery policy %s, outgoing delivery policy %s", + this.incomingDeliveryPolicy, this.outgoingDeliveryPolicy); + } + } + + private static ExecutorService createExecutor(String name) { + String threadNamePrefix = "%s-quic-pool".formatted(name); + ThreadFactory threadFactory = Thread.ofPlatform().name(threadNamePrefix, 0).factory(); + return Executors.newCachedThreadPool(threadFactory); + } + + @Override + public String appErrorToString(long errorCode) { + return appErrorCodeToString.apply(errorCode); + } + + static QuicEndpointFactory newQuicEndpointFactory() { + return new QuicEndpointFactory(); + } + + public void setConnectionAcceptor(final ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor); + quickServerLock.lock(); + try { + var current = this.newConnectionAcceptor; + if (current != null) { + throw new IllegalStateException("An connection acceptor already exists for" + + " this quic server " + this); + } + this.newConnectionAcceptor = acceptor; + } finally { + quickServerLock.unlock(); + } + } + + public void setWantClientAuth(boolean wantClientAuth) { + this.wantClientAuth = wantClientAuth; + } + + public void setNeedClientAuth(boolean needClientAuth) { + this.needClientAuth = needClientAuth; + } + + public void setRejectKeyAgreement(Set rejectedKAAlgos) { + this.rejectedKAAlgos = rejectedKAAlgos; + } + + Function getMaxStreamLimitComputer() { + return this.maxStreamLimitComputer; + } + + /** + * Sets a new MAX_STREAMS limit computer for this server. + * @param computer the limit computer. can be null, in which case an internal computation + * algorithm with decide the MAX_STREAMS limit for connections on this server + * instance + */ + public void setMaxStreamLimitComputer(final Function computer) { + this.maxStreamLimitComputer = computer; + } + + @Override + public String instanceId() { + return serverId; + } + + @Override + public QuicTLSContext getQuicTLSContext() { + return quicTLSContext; + } + + @Override + public boolean isVersionAvailable(QuicVersion quicVersion) { + return availableQuicVersions.contains(quicVersion); + } + + @Override + public List getAvailableVersions() { + return availableQuicVersions; + } + + public void sendRetry(final boolean enable) { + this.sendRetry = enable; + } + + public void start() { + this.started = true; + try { + final QuicEndpoint endpoint = getEndpoint(); + final InetSocketAddress addr = this.listenAddress = (InetSocketAddress) endpoint.getLocalAddress(); + if (debug.on()) { + debug.log("Quic server listening at: " + addr + + " supported versions " + this.availableQuicVersions); + } + } catch (IOException io) { + throw new UncheckedIOException(io); + } + } + + /** + * {@return the address on which the server is listening on} + * + * @throws IllegalStateException If server hasn't yet started + */ + public InetSocketAddress getAddress() { + final var addr = this.listenAddress; + if (addr == null) { + throw new IllegalArgumentException("Server hasn't started"); + } + return addr; + } + + /** + * The name identifying this QuicServer, used in debug traces. + * + * @return the name identifying this QuicServer. + * @implNote This is {@code "QuicServer()"}. + */ + @Override + public String name() { + return name; + } + + /** + * The executor used by this QuicServer when a task needs to + * be offloaded to a separate thread. + * + * @return the executor used by this QuicServer. + * @implNote This is the server internal executor. + */ + @Override + public Executor executor() { + return executor; + } + + public ExecutorService executorService() { + return this.executor; + } + + /** + * Get the QuicEndpoint for the given transport. + * + * @return the QuicEndpoint for the given transport. + * @throws IOException if an error occurs when setting up the endpoint + * or linking the transport with the endpoint. + * @throws IllegalStateException if the server is closed. + */ + @Override + public QuicEndpoint getEndpoint() throws IOException { + var endpoint = this.endpoint; + if (endpoint != null) return endpoint; + var selector = getSelector(); + quickServerLock.lock(); + try { + if (closed) throw new IllegalStateException("QuicServer is closed"); + endpoint = this.endpoint; + if (endpoint != null) return endpoint; + final String endpointName = "QuicEndpoint(" + serverId + ")"; + endpoint = this.endpoint = switch (QuicEndpoint.CONFIGURED_CHANNEL_TYPE) { + case NON_BLOCKING_WITH_SELECTOR -> + endpointFactory.createSelectableEndpoint(this, endpointName, + bindAddress, selector.timer()); + case BLOCKING_WITH_VIRTUAL_THREADS -> + endpointFactory.createVirtualThreadedEndpoint(this, endpointName, + bindAddress, selector.timer()); + }; + } finally { + quickServerLock.unlock(); + } + // register the newly created endpoint with the selector + QuicEndpoint.registerWithSelector(endpoint, selector, debug); + return endpoint; + } + + /** + * Gets the QuicSelector for the transport. + * + * @return the QuicSelector for the given transport. + * @throws IOException if an error occurs when setting up the selector + * or linking the transport with the selector. + * @throws IllegalStateException if the server is closed. + */ + private QuicSelector getSelector() throws IOException { + var selector = this.selector; + if (selector != null) return selector; + quickServerLock.lock(); + try { + if (closed) throw new IllegalStateException("QuicServer is closed"); + selector = this.selector; + if (selector != null) return selector; + final String selectorName = "QuicSelector(" + serverId + ")"; + selector = this.selector = switch (QuicEndpoint.CONFIGURED_CHANNEL_TYPE) { + case NON_BLOCKING_WITH_SELECTOR -> + QuicSelector.createQuicNioSelector(this, selectorName); + case BLOCKING_WITH_VIRTUAL_THREADS -> + QuicSelector.createQuicVirtualThreadPoller(this, selectorName); + }; + } finally { + quickServerLock.unlock(); + } + // we may be closed when we reach here. It doesn't matter though. + // if the selector is closed before it's started the thread will + // immediately exit (or exit after the first wakeup) + debug.log("starting selector"); + selector.start(); + return selector; + } + + @Override + public void unmatchedQuicPacket(SocketAddress source, QuicPacket.HeadersType type, ByteBuffer buffer) { + if (debug.on()) { + debug.log("Received datagram %s(src=%s, payload(%d))", type, source, buffer.remaining()); + } + // consult the delivery policy to see if we should silently drop this packet + if (this.incomingDeliveryPolicy.shouldDrop(source, buffer, type)) { + silentIgnorePacket(source, buffer, type, false); + return; + } + // check packet type. If Initial, it may be a connection attempt + int pos = buffer.position(); + if (type != QuicPacket.HeadersType.LONG) { + if (debug.on()) { + debug.log("Dropping unmatched datagram %s(src=%s, payload(%d))", + type, source, buffer.remaining()); + } + return; + } + // INITIAL packet + // decode packet here + // TODO: FIXME + // Transport: is this needed? + // ALPN, etc... + // Move this to a dedicated method + // Double check how the serverId provided by the client should + // be replaced + // Should the new connection have 2 connections id for a time? + // => the initial one that the client sent, and that will + // be used until the client receives our response, + // => the new connection id that we are sending back to the + // client? + LongHeader header = QuicPacketDecoder.peekLongHeader(buffer); + if (header == null) { + if (debug.on()) { + debug.log("Dropping invalid datagram %s(src=%s, payload(%d))", + type, source, buffer.remaining()); + } + return; + } + // need to assert that dest.remaining() >= 8 and drop the packet + // if this is not the case. + if (header.destinationId().length() < 8) { + debug.log("destination connection id has not enough bytes: %d", + header.destinationId().length()); + return; + } + + final QuicVersion version = QuicVersion.of(header.version()).orElse(null); + try { + // check that the server supports the version, send a version + if (header.version() == 0) { + if (debug.on()) { + debug.log("Stray version negotiation packet"); + } + return; + } + if (version == null || !availableQuicVersions.contains(version)) { + if (debug.on()) { + debug.log("Unsupported version number 0x%x in incoming packet (len=%d)", header.version(), buffer.remaining()); + } + if (buffer.remaining() >= QuicConnectionImpl.SMALLEST_MAXIMUM_DATAGRAM_SIZE) { + // A server might not send a Version Negotiation packet if the datagram it receives is smaller than the minimum size specified in a different version + int[] supported = availableQuicVersions.stream().mapToInt(QuicVersion::versionNumber).toArray(); + var negotiate = QuicPacketEncoder + .newVersionNegotiationPacket(header.destinationId(), + header.sourceId(), supported); + ByteBuffer datagram = ByteBuffer.allocateDirect(negotiate.size()); + QuicPacketEncoder.of(QuicVersion.QUIC_V1).encode(negotiate, datagram, null); + datagram.flip(); + sendDatagram(source, datagram); + } + return; + } + } catch (Throwable t) { + debug.log("Failed to decode packet", t); + return; + } + assert availableQuicVersions.contains(version); + final InetSocketAddress peerAddress = (InetSocketAddress) source; + final ByteBuffer token = QuicPacketDecoder.peekInitialPacketToken(buffer); + if (token == null) { + // packet is malformed: token will be an empty ByteBuffer if + // the packet doesn't contain a token. + debug.log("failed to read connection token"); + return; + } + var localAddress = this.getAddress(); + var conflict = Utils.addressConflict(localAddress, peerAddress); + if (conflict != null) { + String msg = "%s: %s (local:%s == peer:%s)!"; + System.out.println(msg.formatted(this, conflict, localAddress, peerAddress)); + debug.log(msg, "WARNING", conflict, localAddress, peerAddress); + Log.logError(msg.formatted(this, conflict, localAddress, peerAddress)); + } + final QuicServerConnection connection; + if (token.hasRemaining()) { + // the INITIAL packet contains a token. This token is then expected to either match + // the RETRY packet token that this server might have sent (if any) or a NEW_TOKEN frame + // token that this server might have sent (if any). If the token doesn't match either + // of these expectations then drop the packet (or send CLOSE_CONNECTION frame) + final RetryData retryData = isRetryToken(token.asReadOnlyBuffer()); + if (retryData != null) { + // the token matches one that this server could have sent as a RETRY token. verify + // that this server was indeed configured to send a RETRY token. + if (!sendRetry) { + // although the token looks like a RETRY token, this server wasn't configured + // to send a RETRY token, so consider this an invalid token and drop the packet + // (or send CLOSE_CONNECTION frame) + debug.log("Server dropping INITIAL packet due to token " + + "(which looks like an unexpected retry token) from " + peerAddress); + return; + } + // verify the dest connection id in the INITIAL packet is the one that we had asked + // the client to use through our RETRY packet + if (!retryData.serverChosenConnId.equals(header.destinationId())) { + // drop the packet + debug.log("Invalid dest connection id in INITIAL packet," + + " expected the one sent in RETRY packet " + retryData.serverChosenConnId + + " but found a different one " + header.destinationId()); + return; + } + // at this point we have verified that the token is a valid retry token that this server + // sent. We can now create a connection + final SSLParameters sslParameters = createSSLParameters(peerAddress); + final byte[] clientInitialToken = new byte[token.remaining()]; + token.get(clientInitialToken); + connection = new QuicServerConnection(this, version, preferredVersion, + peerAddress, header.destinationId(), + sslParameters, clientInitialToken, retryData); + debug.log("Created new server connection " + connection + " (with a retry token) " + + "to client " + peerAddress); + } else { + // token doesn't match a RETRY token. check if it is a NEW_TOKEN that this server + // sent + final boolean isNewToken = isNewToken(token.asReadOnlyBuffer()); + if (!isNewToken) { + // invalid token in the INITIAL packet. drop packet (or send CLOSE_CONNECTION + // frame) + debug.log("Server dropping INITIAL packet due to unexpected token from " + + peerAddress); + return; + } + // matches a NEW_TOKEN token. create the connection + final SSLParameters sslParameters = createSSLParameters(peerAddress); + final byte[] clientInitialToken = new byte[token.remaining()]; + token.get(clientInitialToken); + connection = new QuicServerConnection(this, version, preferredVersion, + peerAddress, header.destinationId(), + sslParameters, clientInitialToken); + debug.log("Created new server connection " + connection + + " (with NEW_TOKEN initial token) to client " + peerAddress); + } + } else { + // token is empty in INITIAL packet. send a RETRY packet if the server is configured + // to do so. The spec allows us to send the RETRY packet more than once to the same + // client, so we don't have to maintain any state to check if we already have sent one. + if (sendRetry) { + // send RETRY packet + final QuicConnectionId serverConnId = this.endpoint.idFactory().newConnectionId(); + final byte[] retryToken = buildRetryToken(header.destinationId(), serverConnId); + QuicPacketEncoder encoder = QuicPacketEncoder.of(version); + final var retry = encoder + .newRetryPacket(serverConnId, header.sourceId(), retryToken); + final ByteBuffer datagram = ByteBuffer.allocateDirect(retry.size()); + try { + encoder.encode(retry, datagram, new RetryCodingContext(header.destinationId(), quicTLSContext)); + } catch (Throwable t) { + // TODO: should we throw exception? + debug.log("Failed to encode packet", t); + return; + } + datagram.flip(); + debug.log("Sending RETRY packet to client " + peerAddress); + sendDatagram(source, datagram); + return; + } + // no token in INITIAL frame and the server isn't configured to send a RETRY packet. + // we are now ready to create the connection + final SSLParameters sslParameters = createSSLParameters(peerAddress); + connection = new QuicServerConnection(this, version, preferredVersion, + peerAddress, header.destinationId(), + sslParameters, null); + debug.log("Created new server connection " + connection + + " (without any initial token) to client " + peerAddress); + } + assert connection.quicVersion() == version; + + // TODO: maybe we should coalesce some dummy packet in the datagram + // to make sure the client will ignore it + // => this might require slightly altering the algorithm for + // encoding packets: + // we may need to build a packet, and then only encode it + // instead of having a toByteBuffer() method. + try { + endpoint.registerNewConnection(connection); + } catch (IOException io) { + if (closed) { + debug.log("Can't register new connection: server closed"); + } else if (debug.on()) { + debug.log("Can't register new connection", io); + } + // drop all bytes in the payload + buffer.position(buffer.limit()); + connection.connectionTerminator().terminate( + forTransportError(CONNECTION_REFUSED).loggedAs(io.getMessage())); + return; + } + connection.processIncoming(source, header.destinationId().asReadOnlyBuffer(), type, buffer); + + final ConnectionAcceptor connAcceptor = this.newConnectionAcceptor; + if (connAcceptor == null || !connAcceptor.acceptIncoming(source, connection)) { + buffer.position(buffer.limit()); + final String msg = "Quic server " + this.serverId + " refused connection"; + connection.connectionTerminator().terminate( + forTransportError(CONNECTION_REFUSED).loggedAs(msg)); + return; + } + } + + void sendDatagram(final SocketAddress dest, final ByteBuffer datagram) { + final QuicPacket.HeadersType headersType = QuicPacketDecoder.peekHeaderType(datagram, + datagram.position()); + if (this.outgoingDeliveryPolicy.shouldDrop(dest, datagram, headersType)) { + silentIgnorePacket(dest, datagram, headersType, true); + return; + } + endpoint.pushDatagram(null, dest, datagram); + return; + } + + private SSLParameters createSSLParameters(final InetSocketAddress peerAddress) { + final SSLParameters sslParameters = Utils.copySSLParameters(this.getSSLParameters()); + sslParameters.setApplicationProtocols(new String[]{this.alpn}); + sslParameters.setProtocols(new String[] {"TLSv1.3"}); + if (this.sniMatcher != null) { + sslParameters.setSNIMatchers(List.of(this.sniMatcher)); + } + if (wantClientAuth) { + sslParameters.setWantClientAuth(true); + } else if (needClientAuth) { + sslParameters.setNeedClientAuth(true); + } + if (rejectedKAAlgos != null) { + sslParameters.setAlgorithmConstraints(new TestAlgorithmConstraints(rejectedKAAlgos)); + } + return sslParameters; + } + + private byte[] buildRetryToken(final QuicConnectionId originalServerConnId, + final QuicConnectionId serverChosenNewConnId) { + // TODO this token is too simple to provide authenticity guarantee; use for testing only + final int NUM_BYTES_FOR_CONN_ID_LENGTH = 1; + final byte[] result = new byte[retryTokenPrefixBytes.length + + NUM_BYTES_FOR_CONN_ID_LENGTH + originalServerConnId.length() + + NUM_BYTES_FOR_CONN_ID_LENGTH + serverChosenNewConnId.length()]; + // copy the retry token prefix + System.arraycopy(retryTokenPrefixBytes, 0, result, 0, retryTokenPrefixBytes.length); + int currentIndex = retryTokenPrefixBytes.length; + // copy over the length of the original dest conn id + result[currentIndex++] = (byte) originalServerConnId.length(); + // copy over the original dest connection id sent by the client + originalServerConnId.asReadOnlyBuffer().get(0, result, currentIndex, originalServerConnId.length()); + currentIndex += originalServerConnId.length(); + // copy over the length of the server chosen dest conn id + result[currentIndex++] = (byte) serverChosenNewConnId.length(); + // copy over the connection id that the server has chosen and expects clients to use as new dest conn id + serverChosenNewConnId.asReadOnlyBuffer().get(0, result, currentIndex, serverChosenNewConnId.length()); + return result; + } + + private RetryData isRetryToken(final ByteBuffer token) { + Objects.requireNonNull(token); + if (!token.hasRemaining()) { + return null; + } + final int NUM_BYTES_FOR_CONN_ID_LENGTH = 1; + // we expect the retry token prefix and 2 connection ids. so expected length = retry token prefix + // length plus the length of each of the connection ids + final int expectedLength = retryTokenPrefixBytes.length + NUM_BYTES_FOR_CONN_ID_LENGTH + + NUM_BYTES_FOR_CONN_ID_LENGTH; + if (token.remaining() <= expectedLength) { + return null; + } + final byte[] tokenPrefixBytes = new byte[retryTokenPrefixBytes.length]; + token.get(tokenPrefixBytes); + int mismatchIndex = Arrays.mismatch(retryTokenPrefixBytes, 0, retryTokenPrefixBytes.length, + tokenPrefixBytes, 0, retryTokenPrefixBytes.length); + if (mismatchIndex != -1) { + // token doesn't start with the expected retry token prefix. Not a valid retry token + return null; + } + // now find the length of the original connection id + final int originalServerConnIdLen = token.get(); + final byte[] originalServerConnId = new byte[originalServerConnIdLen]; + // read the original dest conn id + token.get(originalServerConnId); + + // now find the length of the server generated dest connection id + final int serverChosenDestConnIdLen = token.get(); + final byte[] serverChosenDestConnId = new byte[serverChosenDestConnIdLen]; + // read the server chosen dest conn id + token.get(serverChosenDestConnId); + + // TODO: the use of PeerConnectionId is only for convenience + return new RetryData(new PeerConnectionId(originalServerConnId), new PeerConnectionId(serverChosenDestConnId)); + } + + DatagramDeliveryPolicy incomingDeliveryPolicy() { + return this.incomingDeliveryPolicy; + } + + DatagramDeliveryPolicy outgoingDeliveryPolicy() { + return this.outgoingDeliveryPolicy; + } + + byte[] buildNewToken() { + // TODO this token is too simple to provide authenticity guarantee; use for testing only + final byte[] token = new byte[newTokenPrefixBytes.length]; + // copy the new_token prefix + System.arraycopy(newTokenPrefixBytes, 0, token, 0, newTokenPrefixBytes.length); + return token; + } + + private boolean isNewToken(final ByteBuffer token) { + Objects.requireNonNull(token); + if (!token.hasRemaining()) { + return false; + } + if (token.remaining() != newTokenPrefixBytes.length) { + return false; + } + final byte[] tokenBytes = new byte[newTokenPrefixBytes.length]; + token.get(tokenBytes); + int mismatchIndex = Arrays.mismatch(newTokenPrefixBytes, 0, newTokenPrefixBytes.length, + tokenBytes, 0, newTokenPrefixBytes.length); + if (mismatchIndex != -1) { + // token doesn't start with the expected new_token prefix. Not a valid token + return false; + } + return true; + } + + private void silentIgnorePacket(final SocketAddress source, final ByteBuffer payload, + final QuicPacket.HeadersType headersType, final boolean outgoing) { + if (debug.on()) debug.log("silently dropping %s packet %s %s", headersType, + (outgoing ? "to dest" : "from source"), source); + } + + @Override + public void close() throws IOException { + // TODO: ignore exceptions while closing? + quickServerLock.lock(); + try { + if (closed) return; + closed = true; + } finally { + quickServerLock.unlock(); + } + //http3Server.stop(); + var endpoint = this.endpoint; + if (endpoint != null) endpoint.close(); + debug.log("endpoint closed"); + var selector = this.selector; + if (selector != null) selector.close(); + debug.log("selector closed"); + if (ownExecutor && executor != null) { + debug.log("shutting down executor"); + this.executor.shutdown(); + try { + debug.log("awaiting termination"); + this.executor.awaitTermination(QuicSelector.IDLE_PERIOD_MS, + TimeUnit.MILLISECONDS); + } catch (InterruptedException ie) { + this.executor.shutdownNow(); + } + } + } + + /** + * The transport parameters that will be sent to the peer by any new subsequent server connections + * that are created by this server + * + * @param params transport parameters. Can be null, in which case the new server connection + * (whenever it is created) will use internal defaults + */ + public void setTransportParameters(QuicTransportParameters params) { + this.transportParameters = params; + } + + /** + * {@return the current configured transport parameters for new server connections. null if none + * configured} + */ + @Override + public QuicTransportParameters getTransportParameters() { + final QuicTransportParameters qtp = this.transportParameters; + if (qtp == null) { + return null; + } + // return a copy + return new QuicTransportParameters(qtp); + } + + /** + * Called + * + * @param connection + * @param originalConnectionId + * @param localConnectionId + */ + void connectionAcknowledged(QuicConnection connection, + QuicConnectionId originalConnectionId, + QuicConnectionId localConnectionId) { + // endpoint.removeConnectionId(originalConnectionId); + } + + private static class TestAlgorithmConstraints implements AlgorithmConstraints { + private final Set rejectedKAAlgos; + + public TestAlgorithmConstraints(Set rejectedKAAlgos) { + this.rejectedKAAlgos = rejectedKAAlgos; + } + + @Override + public boolean permits(Set primitives, String algorithm, AlgorithmParameters parameters) { + if (primitives.contains(CryptoPrimitive.KEY_AGREEMENT) && + rejectedKAAlgos.contains(algorithm)) { + return false; + } + return true; + } + + @Override + public boolean permits(Set primitives, Key key) { + return true; + } + + @Override + public boolean permits(Set primitives, String algorithm, Key key, AlgorithmParameters parameters) { + return true; + } + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerConnection.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerConnection.java new file mode 100644 index 00000000000..53a9129d452 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerConnection.java @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Stream; + +import javax.net.ssl.SSLParameters; + +import jdk.internal.net.http.common.Log; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.QuicEndpoint; +import jdk.internal.net.http.quic.QuicTransportParameters.VersionInformation; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.QuicTransportParameters; +import jdk.internal.net.http.quic.QuicTransportParameters.ParameterId; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.HandshakeDoneFrame; +import jdk.internal.net.http.quic.frames.NewTokenFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.packets.InitialPacket; +import jdk.internal.net.http.quic.packets.OneRttPacket; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.http.quic.packets.QuicPacketDecoder; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTLSEngine.HandshakeState; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_data; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_local; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_remote; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_stream_data_uni; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_streams_bidi; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_max_streams_uni; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.initial_source_connection_id; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.max_idle_timeout; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.original_destination_connection_id; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.retry_source_connection_id; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.stateless_reset_token; +import static jdk.internal.net.http.quic.QuicTransportParameters.ParameterId.version_information; +import static jdk.internal.net.quic.QuicTransportErrors.PROTOCOL_VIOLATION; + +public final class QuicServerConnection extends QuicConnectionImpl { + private static final AtomicLong CONNECTIONS = new AtomicLong(); + private final QuicVersion preferredQuicVersion; + private volatile boolean connectionIdAcknowledged; + private final QuicServer server; + private final byte[] clientInitialToken; + private final QuicConnectionId clientSentDestConnId; + private final QuicConnectionId originalServerConnId; + private final QuicServer.RetryData retryData; + private final AtomicBoolean firstHandshakePktProcessed = new AtomicBoolean(); + + public static final boolean FILTER_SENDER_ADDRESS = Utils.getBooleanProperty( + "test.quic.server.filterSenderAddress", true); + + QuicServerConnection(QuicServer server, + QuicVersion quicVersion, + QuicVersion preferredQuicVersion, + InetSocketAddress peerAddress, + QuicConnectionId clientSentDestConnId, + SSLParameters sslParameters, + byte[] initialToken) { + this(server, quicVersion, preferredQuicVersion, peerAddress, clientSentDestConnId, + sslParameters, initialToken, null); + + } + + QuicServerConnection(QuicServer server, + QuicVersion quicVersion, + QuicVersion preferredQuicVersion, + InetSocketAddress peerAddress, + QuicConnectionId clientSentDestConnId, + SSLParameters sslParameters, + byte[] initialToken, + QuicServer.RetryData retryData) { + super(quicVersion, server, peerAddress, null, -1, sslParameters, + "QuicServerConnection(%s)", CONNECTIONS.incrementAndGet()); + this.preferredQuicVersion = preferredQuicVersion; + // this should have been first statement in this constructor but compiler doesn't allow it + Objects.requireNonNull(quicVersion, "quic version"); + this.clientInitialToken = initialToken; + this.server = server; + this.clientSentDestConnId = clientSentDestConnId; + this.retryData = retryData; + this.originalServerConnId = retryData == null ? clientSentDestConnId : retryData.originalServerConnId(); + handshakeFlow().handshakeCF().thenAccept((hs) -> { + try { + onHandshakeCompletion(hs); + } catch (Exception e) { + // TODO: consider if this needs to be propagated somehow. for now just log + System.err.println("onHandshakeCompletion() failed: " + e); + e.printStackTrace(); + } + }); + assert quicVersion == quicVersion() : "unexpected quic version on" + + " server connection, expected " + quicVersion + " but found " + quicVersion(); + try { + getTLSEngine().deriveInitialKeys(quicVersion, clientSentDestConnId.asReadOnlyBuffer()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + protected QuicConnectionId originalServerConnId() { + return this.originalServerConnId; + } + + @Override + protected boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + return Arrays.equals(clientInitialToken, token); + } + + @Override + public List connectionIds() { + var connectionIds = super.connectionIds(); + // we can stop using the original connection id if we have + // received the ClientHello fully. + // TODO: find when/where to switch connectionIdAcknowledged to true + // TODO: what if the ClientHello is in 3 initial packets and the packet number 2 + // gets lost? How do we know? I guess we can assume that the client hello + // was fully receive when we send (or receive) the first handshake packet. + if (!connectionIdAcknowledged) { + // Add client's initial connection ID (original or retry) + QuicConnectionId initial = this.clientSentDestConnId; + connectionIds = Stream.concat(connectionIds.stream(), Stream.of(initial)).toList(); + } + return connectionIds; + } + + @Override + public Optional initialConnectionId() { + return Optional.ofNullable(clientSentDestConnId); + } + + @Override + public void processIncoming(final SocketAddress source, final ByteBuffer destConnId, + final QuicPacket.HeadersType headersType, final ByteBuffer buffer) { + // consult the delivery policy if this packet should be dropped + if (this.server.incomingDeliveryPolicy().shouldDrop(source, buffer, this, headersType)) { + silentIgnorePacket(source, buffer, headersType, false, "incoming delivery policy"); + return; + } + if (!connectionIdAcknowledged && localConnectionId().matches(destConnId)) { + debug.log("connection acknowledged"); + connectionIdAcknowledged = true; + server.connectionAcknowledged(this, clientSentDestConnId, localConnectionId()); + endpoint.removeConnectionId(clientSentDestConnId, this); + } + super.processIncoming(source, destConnId, headersType, buffer); + } + + @Override + public boolean accepts(SocketAddress source) { + if (FILTER_SENDER_ADDRESS && !source.equals(peerAddress())) { + // We do not support path migration yet, so we only accept + // packets from the endpoint to which we send them. + if (debug.on()) { + debug.log("unexpected sender %s, skipping packet", source); + } + return false; + } + return true; + } + + @Override + protected void processRetryPacket(final QuicPacket quicPacket) { + // server is not supposed to receive retry packet: + // ignore it? + Objects.requireNonNull(quicPacket); + if (quicPacket.packetType() != PacketType.RETRY) { + throw new IllegalArgumentException("Not a RETRY packet: " + quicPacket.packetType()); + } + if (Log.errors()) { + Log.logError("Server received RETRY packet - discarding it"); + } + } + + @Override + protected void incoming1RTTFrame(HandshakeDoneFrame frame) throws QuicTransportException { + // RFC-9000, section 19.20: A HANDSHAKE_DONE frame can only be sent by + // the server. ... A server MUST treat receipt of a HANDSHAKE_DONE frame + // as a connection error of type PROTOCOL_VIOLATION + throw new QuicTransportException("HANDSHAKE_DONE frame isn't allowed from clients", + null, + frame.getTypeField(), PROTOCOL_VIOLATION); + } + + @Override + protected void incoming1RTTFrame(NewTokenFrame frame) throws QuicTransportException { + // This is a server connection and as per RFC-9000, section 19.7, clients + // aren't supposed to send NEW_TOKEN frames and if a server receives such + // a frame then it is considered a connection error + // of type PROTOCOL_VIOLATION. + throw new QuicTransportException("NEW_TOKEN frame isn't allowed from clients", + null, + frame.getTypeField(), PROTOCOL_VIOLATION); + } + + @Override + protected void pushDatagram(final SocketAddress destination, final ByteBuffer datagram) { + final QuicPacket.HeadersType headersType = QuicPacketDecoder.peekHeaderType(datagram, + datagram.position()); + // consult the delivery policy if this packet should be dropped + if (this.server.outgoingDeliveryPolicy().shouldDrop(destination, datagram, + this, headersType)) { + silentIgnorePacket(destination, datagram, headersType, true, "outgoing delivery policy"); + return; + } + super.pushDatagram(destination, datagram); + } + + @Override + protected void processInitialPacket(final QuicPacket quicPacket) { + try { + if (!(quicPacket instanceof InitialPacket initialPacket)) { + throw new AssertionError("Bad packet type: " + quicPacket); + } + updatePeerConnectionId(initialPacket); + var initialPayloadLength = initialPacket.payloadSize(); + assert initialPayloadLength < Integer.MAX_VALUE; + if (debug.on()) { + debug.log("Initial payload (count=%d, remaining=%d)", + initialPacket.frames().size(), initialPayloadLength); + } + long total = processInitialPacketPayload(initialPacket); + assert total == initialPayloadLength; + if (initialPacket.frames().stream().anyMatch(f -> f instanceof CryptoFrame)) { + debug.log("ClientHello received"); + } + var hsState = getTLSEngine().getHandshakeState(); + debug.log("hsState: " + hsState); + if (hsState == QuicTLSEngine.HandshakeState.NEED_SEND_CRYPTO) { + debug.log("Continuing handshake"); + continueHandshake(); + } else if (quicPacket.isAckEliciting() && + getTLSEngine().getCurrentSendKeySpace() == QuicTLSEngine.KeySpace.HANDSHAKE) { + packetNumberSpaces().initial().fastRetransmit(); + } + } catch (Throwable t) { + debug.log("Unexpected exception handling initial packet", t); + connectionTerminator().terminate(TerminationCause.forException(t)); + } + } + + @Override + protected void processHandshakePacket(final QuicPacket quicPacket) { + super.processHandshakePacket(quicPacket); + if (this.firstHandshakePktProcessed.compareAndSet(false, true)) { + // close INITIAL packet space and discard INITIAL keys as expected by + // RFC-9001, section 4.9.1: ... a server MUST discard Initial keys when + // it first successfully processes a Handshake packet. Endpoints MUST NOT send + // Initial packets after this point. + if (debug.on()) { + debug.log("server processed first handshake packet, initiating close of" + + " INITIAL packet space"); + } + packetNumberSpaces().initial().close(); + } + QuicTLSEngine engine = getTLSEngine(); + switch (engine.getHandshakeState()) { + case NEED_SEND_HANDSHAKE_DONE -> { + // should ack handshake and possibly send HandshakeDoneFrame + // the HANDSHAKE space will be closed after sending the + // HANDSHAKE_DONE frame (see sendStreamData) + packetSpace(PacketNumberSpace.HANDSHAKE).runTransmitter(); + engine.tryMarkHandshakeDone(); + enqueue1RTTFrame(new HandshakeDoneFrame()); + debug.log("Adding HandshakeDoneFrame"); + completeHandshakeCF(); + packetSpace(PacketNumberSpace.APPLICATION).runTransmitter(); + } + } + } + + @Override + protected void completeHandshakeCF() { + completeHandshakeCF(null); + } + + @Override + protected void send1RTTPacket(final OneRttPacket packet) + throws QuicKeyUnavailableException, QuicTransportException { + boolean closeHandshake = false; + var handshakeSpace = packetNumberSpaces().handshake(); + if (!handshakeSpace.isClosed()) { + closeHandshake = packet.frames() + .stream() + .anyMatch(HandshakeDoneFrame.class::isInstance); + } + super.send1RTTPacket(packet); + if (closeHandshake) { + // close handshake space after sending + // HANDSHAKE_DONE + handshakeSpace.close(); + } + } + + /** + * This method can be invoked if a certain {@code action} needs to be performed, + * by this server connection, on a successful completion of Quic connection handshake + * (initiated by a client). + * + * @param action The action to be performed on successful completion of the handshake + */ + public void onSuccessfulHandshake(final Runnable action) { + this.handshakeFlow().handshakeCF().thenRun(action); + } + + /** + * This method can be invoked if a certain {@code action} needs to be performed, + * by this server connection, when a Quic connection handshake (initiated by a client), completes. + * The handshake could either have succeeded or failed. If the handshake succeeded, then the + * {@code Throwable} passed to the {@code action} will be {@code null}, else it will represent + * the handshake failure. + * + * @param action The action to be performed on completion of the handshake + */ + public void onHandshakeCompletion(final Consumer action) { + this.handshakeFlow().handshakeCF().handle((unused, failure) -> { + action.accept(failure); + return null; + }); + } + + + @Override + public QuicServer quicInstance() { + return server; + } + + @Override + protected ByteBuffer buildInitialParameters() { + final QuicTransportParameters params = new QuicTransportParameters(this.transportParams); + if (!params.isPresent(original_destination_connection_id)) { + params.setParameter(original_destination_connection_id, originalServerConnId.getBytes()); + } + if (!params.isPresent(initial_source_connection_id)) { + params.setParameter(initial_source_connection_id, localConnectionId().getBytes()); + } + if (!params.isPresent(stateless_reset_token)) { + params.setParameter(stateless_reset_token, + endpoint.idFactory().statelessTokenFor(localConnectionId())); + } + if (!params.isPresent(max_idle_timeout)) { + final long idleTimeoutMillis = TimeUnit.SECONDS.toMillis( + Utils.getLongProperty("jdk.test.server.quic.idleTimeout", 30)); + params.setIntParameter(max_idle_timeout, idleTimeoutMillis); + } + if (retryData != null && !params.isPresent(retry_source_connection_id)) { + // include the connection id that was directed by this server's RETRY packet + // for usage in INITIAL packets sent by client + params.setParameter(retry_source_connection_id, + retryData.serverChosenConnId().getBytes()); + } + setIntParamIfNotSet(params, initial_max_data, () -> DEFAULT_INITIAL_MAX_DATA); + setIntParamIfNotSet(params, initial_max_stream_data_bidi_local, () -> DEFAULT_INITIAL_STREAM_MAX_DATA); + setIntParamIfNotSet(params, initial_max_stream_data_bidi_remote, () -> DEFAULT_INITIAL_STREAM_MAX_DATA); + setIntParamIfNotSet(params, initial_max_stream_data_uni, () -> DEFAULT_INITIAL_STREAM_MAX_DATA); + setIntParamIfNotSet(params, initial_max_streams_bidi, () -> (long) DEFAULT_MAX_BIDI_STREAMS); + setIntParamIfNotSet(params, initial_max_streams_uni, () -> (long) DEFAULT_MAX_UNI_STREAMS); + // params.setParameter(QuicTransportParameters.ParameterId.stateless_reset_token, ...); // no token + // params.setIntParameter(QuicTransportParameters.ParameterId.ack_delay_exponent, 3); // unit 2^3 microseconds + // params.setIntParameter(QuicTransportParameters.ParameterId.max_ack_delay, 25); //25 millis + // params.setBooleanParameter(QuicTransportParameters.ParameterId.disable_active_migration, false); + // params.setPreferedAddressParameter(QuicTransportParameters.ParameterId.preferred_address, ...); + // params.setIntParameter(QuicTransportParameters.ParameterId.active_connection_id_limit, 2); + if (!params.isPresent(version_information)) { + final VersionInformation vi = QuicTransportParameters.buildVersionInformation( + quicVersion(), quicInstance().getAvailableVersions()); + params.setVersionInformationParameter(version_information, vi); + } + final byte[] unsupportedTransportParam = encodeRandomUnsupportedTransportParameter(); + final int capacity = params.size() + unsupportedTransportParam.length; + final ByteBuffer buf = ByteBuffer.allocate(capacity); + params.encode(buf); + // add an unsupported transport param id so that we can exercise the case where endpoints + // are expected to ignore unsupported transport parameters (RFC-9000, section 7.4.2) + buf.put(unsupportedTransportParam); + buf.flip(); + newLocalTransportParameters(params); + return buf; + } + + // returns the encoded representation of a random unsupported transport parameter + private static byte[] encodeRandomUnsupportedTransportParameter() { + final int n = new Random().nextInt(1, 100); + final long unsupportedParamId = 31 * n + 27; // RFC-9000, section 18.1 + final int value = 42; + // Transport Parameter { + // Transport Parameter ID (i), + // Transport Parameter Length (i), + // Transport Parameter Value (..), + // } + int size = 0; + size += VariableLengthEncoder.getEncodedSize(unsupportedParamId); + final int paramLength = VariableLengthEncoder.getEncodedSize(value); + size += VariableLengthEncoder.getEncodedSize(paramLength); + size += paramLength; + final byte[] encoded = new byte[size]; + final ByteBuffer buf = ByteBuffer.wrap(encoded); + + VariableLengthEncoder.encode(buf, unsupportedParamId); // write out the id, as a variable length integer + VariableLengthEncoder.encode(buf, paramLength); // write out the len, as a variable length integer + VariableLengthEncoder.encode(buf, value); // write out the actual value + return encoded; + } + + @Override + protected void consumeQuicParameters(ByteBuffer byteBuffer) + throws QuicTransportException { + final QuicTransportParameters params = QuicTransportParameters.decode(byteBuffer); + if (debug.on()) { + debug.log("Received (from client) Quic transport params: " + params.toStringWithValues()); + } + if (params.isPresent(retry_source_connection_id)) { + throw new QuicTransportException("Retry connection ID not expected here", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (params.isPresent(original_destination_connection_id)) { + throw new QuicTransportException("Original connection ID not expected here", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (params.isPresent(stateless_reset_token)) { + throw new QuicTransportException("Reset token not expected here", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (params.isPresent(ParameterId.preferred_address)) { + throw new QuicTransportException("Preferred address not expected here", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (!params.matches(initial_source_connection_id, + getIncomingInitialPacketSourceId())) { + throw new QuicTransportException("Peer connection ID does not match", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + VersionInformation vi = + params.getVersionInformationParameter(version_information); + if (vi != null) { + boolean found = false; + for (int v: vi.availableVersions()) { + if (v == vi.chosenVersion()) { + found = true; + break; + } + } + if (!found) { + throw new QuicTransportException( + "[version_information] Chosen Version not in Available Versions", + null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + } + if (vi.chosenVersion() != quicVersion().versionNumber()) { + throw new QuicTransportException( + "[version_information] Chosen Version %s does not match version in use %s" + .formatted(vi.chosenVersion(), quicVersion().versionNumber()), + null, 0, QuicTransportErrors.VERSION_NEGOTIATION_ERROR); + } + assert Arrays.stream(vi.availableVersions()).anyMatch(v -> v == preferredQuicVersion.versionNumber()); + if (preferredQuicVersion != quicVersion()) { + if (!switchVersion(preferredQuicVersion)) { + throw new QuicTransportException("Switching version failed", + null, 0, QuicTransportErrors.VERSION_NEGOTIATION_ERROR); + } + } + } else { + assert preferredQuicVersion == quicVersion(); + } + markVersionNegotiated(preferredQuicVersion.versionNumber()); + handleIncomingPeerTransportParams(params); + + // build our parameters + final ByteBuffer quicInitialParameters = buildInitialParameters(); + getTLSEngine().setLocalQuicTransportParameters(quicInitialParameters); + // params.setIntParameter(QuicTransportParameters.ParameterId.initial_max_data, DEFAULT_INITIAL_MAX_DATA); + // params.setIntParameter(QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_local, DEFAULT_INITIAL_STREAM_MAX_DATA); + // params.setIntParameter(QuicTransportParameters.ParameterId.initial_max_stream_data_bidi_remote, DEFAULT_INITIAL_STREAM_MAX_DATA); + // params.setIntParameter(QuicTransportParameters.ParameterId.initial_max_stream_data_uni, DEFAULT_INITIAL_STREAM_MAX_DATA); + // params.setIntParameter(QuicTransportParameters.ParameterId.initial_max_streams_bidi, DEFAULT_MAX_STREAMS); + // params.setIntParameter(QuicTransportParameters.ParameterId.initial_max_streams_uni, DEFAULT_MAX_STREAMS); + // params.setIntParameter(QuicTransportParameters.ParameterId.ack_delay_exponent, 3); // unit 2^3 microseconds + // params.setIntParameter(QuicTransportParameters.ParameterId.max_ack_delay, 25); //25 millis + // params.setBooleanParameter(QuicTransportParameters.ParameterId.disable_active_migration, false); + // params.setIntParameter(QuicTransportParameters.ParameterId.active_connection_id_limit, 2); + } + + @Override + protected void processVersionNegotiationPacket(final QuicPacket quicPacket) { + // ignore the packet: the server doesn't reply to version negotiation. + debug.log("Server ignores version negotiation packet: " + quicPacket); + } + + @Override + public boolean isClientConnection() { + return false; + } + + @Override + protected QuicEndpoint onHandshakeCompletion(HandshakeState result) { + super.onHandshakeCompletion(result); + // send a new token frame to the client, for use in new connection attempts. + sendNewToken(); + return this.endpoint; + } + + private void sendNewToken() { + final byte[] token = server.buildNewToken(); + final QuicFrame newTokenFrame = new NewTokenFrame(ByteBuffer.wrap(token)); + enqueue1RTTFrame(newTokenFrame); + } + + @Override + public long nextMaxStreamsLimit(final boolean bidi) { + final Function limitComputer = this.server.getMaxStreamLimitComputer(); + if (limitComputer != null) { + return limitComputer.apply(bidi); + } + return super.nextMaxStreamsLimit(bidi); + } + + private void silentIgnorePacket(final SocketAddress source, final ByteBuffer payload, + final QuicPacket.HeadersType headersType, + final boolean outgoing, + final String reason) { + if (debug.on()) { + debug.log("silently dropping %s packet %s %s, reason: %s", headersType, + (outgoing ? "to dest" : "from source"), source, reason); + } + } +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerHandler.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerHandler.java new file mode 100644 index 00000000000..886f15821a6 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicServerHandler.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.net.SocketAddress; + +import jdk.internal.net.http.quic.streams.QuicBidiStream; + +/** + * Used by server side application code to handle incoming Quic connections and streams associated + * with those connections + * + * @see QuicStandaloneServer#addHandler(QuicServerHandler) + */ +// TODO: should we make this an abstract class instead of interface? +public interface QuicServerHandler { + + /** + * @param source The (client) source of the incoming connection + * @param serverConn The {@link QuicServerConnection}, constructed on the server side, + * representing in the incoming connection + * {@return true if the incoming connection should be accepted. false otherwise} + */ + default boolean acceptIncomingConnection(final SocketAddress source, + final QuicServerConnection serverConn) { + // by default accept new connections + return true; + } + + /** + * Called whenever a client initiated bidirectional stream has been received on the + * Quic connection which this {@code QuicServerHandler} previously + * {@link #acceptIncomingConnection(SocketAddress, QuicServerConnection) accepted} + * + * @param conn The connection on which the stream was created + * @param bidi The client initiated bidirectional stream + * @throws IOException + */ + default void onClientInitiatedBidiStream(final QuicServerConnection conn, + final QuicBidiStream bidi) throws IOException { + // start the reader/writer loops for this stream, by creating a ConnectedBidiStream + try (final ConnectedBidiStream connectedBidiStream = new ConnectedBidiStream(bidi)) { + // let the handler use this connected stream to read/write data, if it wants to + this.handleBidiStream(conn, connectedBidiStream); + } catch (IOException e) { + throw e; + } catch (Exception e) { + throw new IOException(e); + } + + } + + /** + * Called whenever a client initiated bidirectional stream has been received on the connection + * previously accepted by this handler. This method is called from within + * {@link #onClientInitiatedBidiStream(QuicBidiStream)} with a {@code ConnectedBidiStream} which + * has the reader and writer tasks started. + * + * @param conn The connection on which the stream was created + * @param bidiStream The bidirectional stream which has the reader and writer tasks started + * @throws IOException + */ + void handleBidiStream(final QuicServerConnection conn, + final ConnectedBidiStream bidiStream) throws IOException; + +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicStandaloneServer.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicStandaloneServer.java new file mode 100644 index 00000000000..b1697064556 --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/QuicStandaloneServer.java @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.NoSuchAlgorithmException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongFunction; + +import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SSLContext; + +import jdk.internal.net.http.common.Logger; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; + +public final class QuicStandaloneServer extends QuicServer { + public static final String ALPN = "quic-standalone-test-alpn"; + + private static final AtomicLong IDS = new AtomicLong(); + + private final Logger debug = Utils.getDebugLogger(this::name); + private QuicServerHandler handler; + + private static String nextName() { + return "quic-standalone-server-" + IDS.incrementAndGet(); + } + + QuicStandaloneServer(String serverId, InetSocketAddress bindAddress, + ExecutorService executor, QuicVersion[] supportedQuicVersions, + boolean compatible, QuicTLSContext quicTLSContext, SNIMatcher sniMatcher, + DatagramDeliveryPolicy incomingDeliveryPolicy, + DatagramDeliveryPolicy outgoingDeliveryPolicy, String alpn, + LongFunction appErrorCodeToString) { + super(serverId, bindAddress, executor, supportedQuicVersions, compatible, quicTLSContext, sniMatcher, + incomingDeliveryPolicy, outgoingDeliveryPolicy, alpn, appErrorCodeToString); + // set a connection acceptor + setConnectionAcceptor(QuicStandaloneServer::acceptIncoming); + } + + public void addHandler(final QuicServerHandler handler) { + this.handler = handler; + } + + QuicServerHandler getHandler() { + return this.handler; + } + + static boolean acceptIncoming(final SocketAddress source, final QuicServerConnection serverConn) { + try { + final QuicStandaloneServer server = (QuicStandaloneServer) serverConn.quicInstance(); + final QuicServerHandler handler = server.getHandler(); + if (handler == null) { + if (server.debug.on()) { + server.debug.log("Handler absent - rejecting new connection " + + serverConn + " from source " + source); + } + return false; + } + if (!handler.acceptIncomingConnection(source, serverConn)) { + if (server.debug.on()) { + server.debug.log("Handler " + handler + " rejected new connection " + + serverConn + " from source " + source); + } + return false; + } + if (server.debug.on()) { + server.debug.log("New connection " + serverConn + " accepted from " + source); + } + serverConn.onSuccessfulHandshake(() -> { + if (server.debug.on()) { + server.debug.log("Registering a listener for remote streams on connection " + + serverConn); + } + // add a listener for streams that have been created by the remote side + // (i.e. initiated by the client) + serverConn.addRemoteStreamListener((stream) -> { + if (stream.isBidirectional()) { + // invoke the handler (application code) as a async work + server.asyncHandleBidiStream(source, serverConn, (QuicBidiStream) stream); + return true; // true implies that this listener wishes to acquire the stream + } else { + if (server.debug.on()) { + server.debug.log("Ignoring stream " + stream + " on connection " + serverConn); + } + return false; + } + }); + }); + return true; // true implies we wish to accept this incoming connection + } catch (Throwable t) { + // TODO: re-evaluate why this try/catch block is there. it's likely + // that in the absence of this block, the call just "disappears"/hangs when an exception + // occurs in this method + System.err.println("Exception while accepting incoming connection: " + t.getMessage()); + t.printStackTrace(); + return false; + } + } + + private void asyncHandleBidiStream(final SocketAddress source, final QuicServerConnection serverConn, + final QuicBidiStream stream) { + this.executor().execute(() -> { + try { + if (debug.on()) { + debug.log("Invoking handler " + handler + " for handling bidi stream " + + stream + " on connection " + serverConn); + } + handler.onClientInitiatedBidiStream(serverConn, stream); + } catch (Throwable t) { + System.err.println("Failed to handle client initiated" + + " bidi stream for connection " + serverConn + + " from source " + source); + t.printStackTrace(); + } + }); + } + + public static Builder newBuilder() { + return new StandaloneBuilder(); + } + + private static final class StandaloneBuilder extends Builder { + + public StandaloneBuilder() { + this.alpn = ALPN; + } + + @Override + public QuicStandaloneServer build() throws IOException { + QuicVersion[] versions = availableQuicVersions; + if (versions == null) { + // default to v1 and v2 + versions = new QuicVersion[]{QuicVersion.QUIC_V1, QuicVersion.QUIC_V2}; + } + if (versions.length == 0) { + throw new IllegalStateException("Empty supported QUIC versions"); + } + InetSocketAddress addr = bindAddress; + if (addr == null) { + // default to loopback address and ephemeral port + addr = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); + } + SSLContext ctx = sslContext; + if (ctx == null) { + try { + ctx = SSLContext.getDefault(); + } catch (NoSuchAlgorithmException e) { + throw new IOException(e); + } + } + final QuicTLSContext quicTLSContext = new QuicTLSContext(ctx); + final String name = serverId == null ? nextName() : serverId; + return new QuicStandaloneServer(name, addr, executor, versions, compatible, quicTLSContext, + sniMatcher, incomingDeliveryPolicy, outgoingDeliveryPolicy, alpn, appErrorCodeToString); + } + } + +} diff --git a/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/RetryCodingContext.java b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/RetryCodingContext.java new file mode 100644 index 00000000000..c627ed52fab --- /dev/null +++ b/test/jdk/java/net/httpclient/lib/jdk/httpclient/test/lib/quic/RetryCodingContext.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.httpclient.test.lib.quic; + +import jdk.internal.net.http.quic.CodingContext; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicTLSEngine; + +import java.io.IOException; +import java.nio.ByteBuffer; + +public class RetryCodingContext implements CodingContext { + private final QuicConnectionId connectionId; + private final QuicTLSEngine engine; + + public RetryCodingContext(QuicConnectionId connectionId, QuicTLSContext quicTLSContext) { + this.connectionId = connectionId; + engine = quicTLSContext.createEngine(); + } + + @Override + public long largestProcessedPN(QuicPacket.PacketNumberSpace packetSpace) { + throw new UnsupportedOperationException(); + } + + @Override + public long largestAckedPN(QuicPacket.PacketNumberSpace packetSpace) { + throw new UnsupportedOperationException(); + } + + @Override + public int connectionIdLength() { + throw new UnsupportedOperationException(); + } + + @Override + public int writePacket(QuicPacket packet, ByteBuffer buffer) { + throw new UnsupportedOperationException(); + } + + @Override + public QuicPacket parsePacket(ByteBuffer src) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public QuicConnectionId originalServerConnId() { + return connectionId; + } + + @Override + public QuicTLSEngine getTLSEngine() { + return engine; + } + + @Override + public boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + throw new UnsupportedOperationException(); + } +} diff --git a/test/jdk/java/net/httpclient/qpack/BlockingDecodingTest.java b/test/jdk/java/net/httpclient/qpack/BlockingDecodingTest.java new file mode 100644 index 00000000000..c588f8117ae --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/BlockingDecodingTest.java @@ -0,0 +1,374 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.Encoder; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static org.testng.Assert.assertNotEquals; + +/* + * @test + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=EXTRA BlockingDecodingTest + */ + + +public class BlockingDecodingTest { + @Test + public void blockedStreamsSettingDefaultValueTest() throws Exception { + // Default SETTINGS_QPACK_BLOCKED_STREAMS value (0) doesn't allow blocked streams + var encoderEh = new TestErrorHandler(); + var decoderEh = new TestErrorHandler(); + var streamError = new AtomicReference(); + EncoderDecoderConnector.EncoderDecoderPair pair = + newPreconfiguredEncoderDecoder(encoderEh, decoderEh, + streamError, -1L, 1); + + // Get Encoder and Decoder instances from a newly established connector + var encoder = pair.encoder(); + var decoder = pair.decoder(); + + // Create a decoding callback to check for completion and to log failures + TestDecodingCallback decodingCallback = new TestDecodingCallback(); + // Start encoding Headers Frame + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var headerFrameReader = decoder.newHeaderFrameReader(decodingCallback); + + // create encoding context and buffer to hold encoded headers + List buffers = new ArrayList<>(); + + ByteBuffer headersBb = ByteBuffer.allocate(2048); + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + var header = TestHeader.withId(0); + encoder.header(context, header.name(), header.value(), + false, IGNORE_RECEIVED_COUNT_CHECK); + headerFrameWriter.write(headersBb); + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + + // It is expected to get QPACK_DECOMPRESSION_FAILED here since decoder is + // expected to be blocked due to missing entry with index 0 in the decoder table, + // and the default number of blocked streams (0) will be exceeded (1). + var lastHttp3Error = decodingCallback.lastHttp3Error.get(); + System.err.println("Last Http3Error: " + lastHttp3Error); + Assert.assertEquals(lastHttp3Error, Http3Error.QPACK_DECOMPRESSION_FAILED); + Assert.assertFalse(decodingCallback.completed.isDone()); + } + + @Test + public void noBlockedStreamsTest() throws Exception { + // No blocked streams - with default SETTINGS_QPACK_BLOCKED_STREAMS value + // Default SETTINGS_QPACK_BLOCKED_STREAMS value (0) doesn't allow blocked streams + var encoderEh = new TestErrorHandler(); + var decoderEh = new TestErrorHandler(); + var streamError = new AtomicReference(); + EncoderDecoderConnector.EncoderDecoderPair pair = + newPreconfiguredEncoderDecoder(encoderEh, decoderEh, + streamError, -1L, 1); + + // Populate decoder table with an entry - there should be no blocked streams + // observed during decoding + prepopulateDynamicTable(pair.decoderTable(), 1); + + // Get Encoder and Decoder instances from a newly established connector + var encoder = pair.encoder(); + var decoder = pair.decoder(); + + // Create a decoding callback to check for completion and to log failures + TestDecodingCallback decodingCallback = new TestDecodingCallback(); + // Start encoding Headers Frame + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var headerFrameReader = decoder.newHeaderFrameReader(decodingCallback); + + // create encoding context and buffer to hold encoded headers + List buffers = new ArrayList<>(); + + ByteBuffer headersBb = ByteBuffer.allocate(2048); + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + var expectedHeader = TestHeader.withId(0); + encoder.header(context, expectedHeader.name, + expectedHeader.value, false, IGNORE_RECEIVED_COUNT_CHECK); + headerFrameWriter.write(headersBb); + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + + // It is expected to get QPACK_DECOMPRESSION_FAILED here since decoder is + // expected to be blocked due to missing entry with index 0 in the decoder table, + // and the default number of blocked streams (0) will be exceeded (1). + var lastHttp3Error = decodingCallback.lastHttp3Error.get(); + System.err.println("Last Http3Error: " + lastHttp3Error); + Assert.assertNull(lastHttp3Error); + Assert.assertNull(decodingCallback.lastThrowable.get()); + Assert.assertTrue(decodingCallback.completed.isDone()); + // Check that onDecoded was called for the test entry + var decodedHeader = decodingCallback.decodedHeaders.get(0); + Assert.assertEquals(decodedHeader, expectedHeader); + } + + @Test + public void awaitBlockedStreamsTest() throws Exception { + // Max number of blocked streams is not exceeded + // No blocked streams - with default SETTINGS_QPACK_BLOCKED_STREAMS value + // Default SETTINGS_QPACK_BLOCKED_STREAMS value (0) doesn't allow blocked streams + final int numberOfMaxAllowedBlockedStreams = 5; + final int numberOfHeaders = 4; + final int base = 2; + var encoderEh = new TestErrorHandler(); + var decoderEh = new TestErrorHandler(); + var streamError = new AtomicReference(); + EncoderDecoderConnector.EncoderDecoderPair pair = + newPreconfiguredEncoderDecoder(encoderEh, decoderEh, + streamError, numberOfMaxAllowedBlockedStreams, numberOfHeaders); + + // Create list of headers to encode for each thread + List expectedHeaders = Collections.synchronizedList(new ArrayList<>()); + for (int headerId = 0; headerId < numberOfHeaders; headerId++) { + expectedHeaders.add(TestHeader.withId(headerId)); + } + + // Create virtual threads executor + var vtExecutor = Executors.newVirtualThreadPerTaskExecutor(); + List> decodingTaskResults = new ArrayList<>(); + + // Create 10 blocked tasks + for (int taskCount = 0; taskCount < numberOfMaxAllowedBlockedStreams; taskCount++) { + var decodingTask = new Callable() { + final EncoderDecoderConnector.EncoderDecoderPair ed = pair; + @Override + public TestDecodingCallback call() throws Exception { + var encoder = ed.encoder(); + var decoder = ed.decoder(); + // Create a decoding callback to check for completion and to log failures + TestDecodingCallback decodingCallback = new TestDecodingCallback(); + // Start encoding Headers Frame + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var headerFrameReader = decoder.newHeaderFrameReader(decodingCallback); + + // create encoding context and buffer to hold encoded headers + List buffers = new ArrayList<>(); + + ByteBuffer headersBb = ByteBuffer.allocate(2048); + Encoder.EncodingContext context = + encoder.newEncodingContext(0, base, headerFrameWriter); + + for (var header : expectedHeaders) { + encoder.header(context, header.name, header.value, false, + IGNORE_RECEIVED_COUNT_CHECK); + headerFrameWriter.write(headersBb); + } + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + return decodingCallback; + } + }; + decodingTaskResults.add(vtExecutor.submit(decodingTask)); + } + + // Schedule the delayed update to the decoders dynamic table + var delayedExecutor = CompletableFuture.delayedExecutor(100, TimeUnit.MILLISECONDS, + vtExecutor); + AtomicLong updateDoneTimestamp = new AtomicLong(); + delayedExecutor.execute(() -> { + updateDoneTimestamp.set(System.nanoTime()); + prepopulateDynamicTable(pair.decoderTable(), numberOfHeaders); + }); + + // Await completion of all tasks + for (var decodingResultFuture : decodingTaskResults) { + decodingResultFuture.get().completed.get(); + } + // Acquire the timestamp + long updateDoneTimeStamp = updateDoneTimestamp.get(); + + System.err.println("All decoding tasks are done"); + System.err.println("Decoder table update timestamp: " + updateDoneTimeStamp); + // Check results of each decoding task + for (var decodingResultFuture : decodingTaskResults) { + var taskCallback = decodingResultFuture.get(); + Assert.assertNull(taskCallback.lastHttp3Error.get()); + Assert.assertNull(taskCallback.lastThrowable.get()); + long decodingTaskCompleted = taskCallback.completedTimestamp.get(); + System.err.println("Decoding task completion timestamp: " + decodingTaskCompleted); + Assert.assertTrue(decodingTaskCompleted >= updateDoneTimeStamp); + var decodedHeaders = taskCallback.decodedHeaders; + Assert.assertEquals(decodedHeaders, expectedHeaders); + } + } + + private static EncoderDecoderConnector.EncoderDecoderPair newPreconfiguredEncoderDecoder( + TestErrorHandler encoderEh, + TestErrorHandler decoderEh, + AtomicReference streamError, + long maxBlockedStreams, + int numberOfEntriesInEncoderDT) { + EncoderDecoderConnector conn = new EncoderDecoderConnector(); + var pair = conn.newEncoderDecoderPair( + e -> false, + encoderEh::qpackErrorHandler, + decoderEh::qpackErrorHandler, + streamError::set); + // Create settings frame with dynamic table capacity and number of blocked streams + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + // 4k should be enough for storing dynamic table entries added by 'prepopulateDynamicTable' + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, DT_CAPACITY); + if (maxBlockedStreams > 0) { + // Set max number of blocked decoder streams if the provided value is positive, otherwise + // use the default RFC setting which is 0 + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_BLOCKED_STREAMS, maxBlockedStreams); + } + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + + // Configure encoder and decoder with constructed ConnectionSettings + pair.encoder().configure(settings); + pair.decoder().configure(settings); + pair.encoderTable().setCapacity(DT_CAPACITY); + pair.decoderTable().setCapacity(DT_CAPACITY); + + // Prepopulate encoder dynamic table with test entries. Decoder dynamic table will be pre-populated with + // a test-case specific code to reproduce blocked decoding scenario + prepopulateDynamicTable(pair.encoderTable(), numberOfEntriesInEncoderDT); + + return pair; + } + + private static void prepopulateDynamicTable(DynamicTable dynamicTable, int numEntries) { + for (int count = 0; count < numEntries; count++) { + var header = TestHeader.withId(count); + dynamicTable.insert(header.name(), header.value()); + } + } + + private static class TestDecodingCallback implements DecodingCallback { + + final List decodedHeaders = new CopyOnWriteArrayList<>(); + final CompletableFuture completed = new CompletableFuture<>(); + final AtomicLong completedTimestamp = new AtomicLong(); + + final AtomicReference lastThrowable = new AtomicReference<>(); + final AtomicReference lastHttp3Error = new AtomicReference<>(); + + @Override + public void onDecoded(CharSequence name, CharSequence value) { + var nameValue = new TestHeader(name.toString(), value.toString()); + decodedHeaders.add(nameValue); + System.err.println("Decoding callback 'onDecoded': " + nameValue); + } + + @Override + public void onComplete() { + System.err.println("Decoding callback 'onComplete'"); + completedTimestamp.set(System.nanoTime()); + completed.complete(null); + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + System.err.println("Decoding callback 'onError': " + http3Error); + lastThrowable.set(throwable); + lastHttp3Error.set(http3Error); + } + + @Override + public long streamId() { + return 0; + } + } + + private static class TestErrorHandler { + final AtomicReference error = new AtomicReference<>(); + final AtomicReference http3Error = new AtomicReference<>(); + + public void qpackErrorHandler(Throwable error, Http3Error http3Error) { + this.error.set(error); + this.http3Error.set(http3Error); + throw new RuntimeException("http3 error: " + http3Error, error); + } + } + + record TestHeader(String name, String value) { + public static TestHeader withId(int id) { + return new TestHeader(NAME + id, VALUE + id); + } + } + + private static final String NAME = "test"; + private static final String VALUE = "valueTest"; + private static final long DT_CAPACITY = 4096L; + private static final long IGNORE_RECEIVED_COUNT_CHECK = -1L; +} diff --git a/test/jdk/java/net/httpclient/qpack/DecoderInstructionsReaderTest.java b/test/jdk/java/net/httpclient/qpack/DecoderInstructionsReaderTest.java new file mode 100644 index 00000000000..e231d0dded3 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/DecoderInstructionsReaderTest.java @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @library /test/lib + * @run junit/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL DecoderInstructionsReaderTest + */ + +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.readers.DecoderInstructionsReader; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import jdk.internal.net.http.qpack.writers.IntegerWriter; +import jdk.test.lib.RandomFactory; +import org.junit.jupiter.api.RepeatedTest; + +import java.nio.ByteBuffer; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class DecoderInstructionsReaderTest { + + DecoderInstructionsReader decoderInstructionsReader; + private static final Random RANDOM = RandomFactory.getRandom(); + + @RepeatedTest(10) + public void acknowledgementTest() { + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 1 | Stream ID (7+) | + // +---+---------------------------+ + + TestDecoderInstructionsCallback callback = new TestDecoderInstructionsCallback(); + decoderInstructionsReader = new DecoderInstructionsReader(callback, QPACK.getLogger()); + + long streamId = RANDOM.nextLong(0, IntegerReader.QPACK_MAX_INTEGER_VALUE); + IntegerWriter writer = new IntegerWriter(); + int bufferSize = requiredBufferSize(7, streamId); + var payload = 0b1000_0000; + + ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize); + writer.configure(streamId, 7, payload); + writer.write(byteBuffer); + byteBuffer.flip(); + + decoderInstructionsReader.read(byteBuffer); + assertEquals(streamId, callback.lastSectionAckStreamId.get()); + } + + @RepeatedTest(10) + public void cancellationTest() { + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 1 | Stream ID (6+) | + // +---+---+-----------------------+ + + TestDecoderInstructionsCallback callback = new TestDecoderInstructionsCallback(); + decoderInstructionsReader = new DecoderInstructionsReader(callback, QPACK.getLogger()); + + long streamId = RANDOM.nextLong(0, IntegerReader.QPACK_MAX_INTEGER_VALUE); + IntegerWriter writer = new IntegerWriter(); + int bufferSize = requiredBufferSize(6, streamId); + var payload = 0b0100_0000; + + ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize); + writer.configure(streamId, 6, payload); + writer.write(byteBuffer); + byteBuffer.flip(); + + decoderInstructionsReader.read(byteBuffer); + assertEquals(streamId, callback.lastCancelStreamId.get()); + } + + @RepeatedTest(10) + public void incrementTest() { + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | Increment (6+) | + // +---+---+-----------------------+ + + TestDecoderInstructionsCallback callback = new TestDecoderInstructionsCallback(); + decoderInstructionsReader = new DecoderInstructionsReader(callback, QPACK.getLogger()); + + long increaseCountInc = RANDOM.nextLong(0, IntegerReader.QPACK_MAX_INTEGER_VALUE); + IntegerWriter writer = new IntegerWriter(); + int bufferSize = requiredBufferSize(6, increaseCountInc); + var payload = 0b0000_0000; + + ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize); + writer.configure(increaseCountInc, 6, payload); + writer.write(byteBuffer); + byteBuffer.flip(); + + decoderInstructionsReader.read(byteBuffer); + assertEquals(increaseCountInc, callback.lastInsertCountInc.get()); + + } + + static int requiredBufferSize(int N, long value) { + checkPrefix(N); + int size = 1; + int max = (2 << (N - 1)) - 1; + if (value < max) { + return size; + } + size++; + value -= max; + while (value >= 128) { + value /= 128; + size++; + } + return size; + } + + private static void checkPrefix(int N) { + if (N < 1 || N > 8) { + throw new IllegalArgumentException("1 <= N <= 8: N= " + N); + } + } + + private static class TestDecoderInstructionsCallback implements DecoderInstructionsReader.Callback { + final AtomicLong lastSectionAckStreamId = new AtomicLong(-1L); + final AtomicLong lastCancelStreamId = new AtomicLong(-1L); + final AtomicLong lastInsertCountInc = new AtomicLong(-1L); + + @Override + public void onSectionAck(long streamId) { + lastSectionAckStreamId.set(streamId); + } + + @Override + public void onStreamCancel(long streamId) { + lastCancelStreamId.set(streamId); + } + + @Override + public void onInsertCountIncrement(long increment) { + lastInsertCountInc.set(increment); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/DecoderInstructionsWriterTest.java b/test/jdk/java/net/httpclient/qpack/DecoderInstructionsWriterTest.java new file mode 100644 index 00000000000..275dec34927 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/DecoderInstructionsWriterTest.java @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @library /test/lib + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @run junit/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL DecoderInstructionsWriterTest + */ + +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.readers.DecoderInstructionsReader; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import jdk.internal.net.http.qpack.writers.DecoderInstructionsWriter; +import jdk.test.lib.RandomFactory; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class DecoderInstructionsWriterTest { + + @ParameterizedTest + @MethodSource("decoderInstructionsSource") + public void decoderInstructionsTest(DecoderInstruction instruction, long value) throws Exception { + testRunner(instruction, value); + } + + private static Stream decoderInstructionsSource() { + // "Section Acknowledgment" + Stream sectionAck = RANDOM.longs(10, + 0, IntegerReader.QPACK_MAX_INTEGER_VALUE) + .boxed() + .map(l -> Arguments.of(DecoderInstruction.SECTION_ACK, l)); + + // "Stream Cancellation" + Stream streamCancel = RANDOM.longs(10, + 0, IntegerReader.QPACK_MAX_INTEGER_VALUE) + .boxed() + .map(l -> Arguments.of(DecoderInstruction.STREAM_CANCEL, l)); + + // "Insert Count Increment" + Stream insertCountInc = RANDOM.longs(10, + 0, IntegerReader.QPACK_MAX_INTEGER_VALUE) + .boxed() + .map(l -> Arguments.of(DecoderInstruction.INSERT_COUNT_INC, l)); + + return Stream.concat(sectionAck, Stream.concat(streamCancel, insertCountInc)); + } + + + enum DecoderInstruction { + SECTION_ACK, + STREAM_CANCEL, + INSERT_COUNT_INC; + } + + private static void testRunner(DecoderInstruction instruction, long value) throws Exception { + var writer = new DecoderInstructionsWriter(); + int calculatedInstructionSize = configureWriter(writer, instruction, value); + var logger = QPACK.getLogger(); + var dynamicTable = new DynamicTable(logger); + + var buffers = new ArrayList(); + + boolean writeDone = false; + int writtenBytes = 0; + // Write instruction to a byte buffers of a random size + while (!writeDone) { + int allocSize = RANDOM.nextInt(1, 9); + var buffer = ByteBuffer.allocate(allocSize); + + writeDone = writer.write(buffer); + writtenBytes += buffer.position(); + buffer.flip(); + buffers.add(buffer); + } + // Check that instruction size calculated by the writer matches + // the number of written bytes + assertEquals(writtenBytes, calculatedInstructionSize); + + // Read back the data from byte buffers + var callback = new TestDecoderInstructionsCallback(); + var reader = new DecoderInstructionsReader(callback, logger); + for (var bb : buffers) { + reader.read(bb); + } + // Check that reader callback values match values supplied to the writer + long instructionValue = extractCallbackValue(instruction, callback); + assertEquals(value, instructionValue); + } + + private static long extractCallbackValue(DecoderInstruction instruction, + TestDecoderInstructionsCallback callback) { + return switch (instruction) { + case SECTION_ACK -> callback.lastSectionAckStreamId.get(); + case STREAM_CANCEL -> callback.lastCancelStreamId.get(); + case INSERT_COUNT_INC -> callback.lastInsertCountInc.get(); + }; + } + + + private static int configureWriter(DecoderInstructionsWriter writer, + DecoderInstruction instruction, + long instructionValue) { + return switch (instruction) { + case SECTION_ACK -> writer.configureForSectionAck(instructionValue); + case STREAM_CANCEL -> writer.configureForStreamCancel(instructionValue); + case INSERT_COUNT_INC -> writer.configureForInsertCountInc(instructionValue); + }; + } + + private static final Random RANDOM = RandomFactory.getRandom(); + + private static class TestDecoderInstructionsCallback implements DecoderInstructionsReader.Callback { + final AtomicLong lastSectionAckStreamId = new AtomicLong(-1L); + final AtomicLong lastCancelStreamId = new AtomicLong(-1L); + final AtomicLong lastInsertCountInc = new AtomicLong(-1L); + + @Override + public void onSectionAck(long streamId) { + lastSectionAckStreamId.set(streamId); + } + + @Override + public void onStreamCancel(long streamId) { + lastCancelStreamId.set(streamId); + } + + @Override + public void onInsertCountIncrement(long increment) { + lastInsertCountInc.set(increment); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/DecoderSectionSizeLimitTest.java b/test/jdk/java/net/httpclient/qpack/DecoderSectionSizeLimitTest.java new file mode 100644 index 00000000000..4b715a46e08 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/DecoderSectionSizeLimitTest.java @@ -0,0 +1,265 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @library /test/lib + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL + * DecoderSectionSizeLimitTest + */ + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.Encoder; +import jdk.test.lib.RandomFactory; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; + +public class DecoderSectionSizeLimitTest { + @Test(dataProvider = "headerSequences") + public void fieldSectionSizeLimitExceeded(List headersSequence, + long maxFieldSectionSize) { + + boolean decoderErrorExpected = + maxFieldSectionSize > 0 && maxFieldSectionSize < REQUIRED_FIELD_SECTION_SIZE; + + System.err.println("=".repeat(50)); + System.err.println("Max Field Section Size = " + maxFieldSectionSize); + System.err.println("Max Field Section Size is" + (decoderErrorExpected ? " not" : "") + + " enough to encode headers"); + + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + DecoderSectionSizeLimitTest.TestErrorHandler encoderErrorHandler = + new DecoderSectionSizeLimitTest.TestErrorHandler(); + DecoderSectionSizeLimitTest.TestErrorHandler decoderErrorHandler = + new DecoderSectionSizeLimitTest.TestErrorHandler(); + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> false, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, + error::set); + + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // This test emulates a scenario with an Encoder that doesn't respect + // the SETTINGS_MAX_FIELD_SECTION_SIZE setting value while encoding the headers frame + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, 512L); + settingsFrame.setParameter(SettingsFrame.SETTINGS_MAX_FIELD_SECTION_SIZE, maxFieldSectionSize); + ConnectionSettings decoderConnectionSetting = ConnectionSettings.createFrom(settingsFrame); + + // Encoder imposes no limit on the field section size + settingsFrame.setParameter(SettingsFrame.SETTINGS_MAX_FIELD_SECTION_SIZE, -1L); + ConnectionSettings encoderConnectionSetting = ConnectionSettings.createFrom(settingsFrame); + + // Configure encoder and decoder + encoder.configure(encoderConnectionSetting); + decoder.configure(decoderConnectionSetting); + + // Configure dynamic tables + configureDynamicTable(conn.encoderTable()); + configureDynamicTable(conn.decoderTable()); + + // Encode headers + // Create header frame writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext( + 0, BASE, headerFrameWriter); + + ByteBuffer buffer = ByteBuffer.allocate(RANDOM.nextInt(1, 65)); + List buffers = new ArrayList<>(); + for (TestHeader header : headersSequence) { + // Configures encoder for writing the header name:value pair + encoder.header(context, header.name, header.value, + false, -1L); + + // Write the header + while (!headerFrameWriter.write(buffer)) { + buffer.flip(); + buffers.add(buffer); + buffer = ByteBuffer.allocate(RANDOM.nextInt(1, 65)); + } + } + buffer.flip(); + buffers.add(buffer); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + System.err.println("Number of generated header buffers:" + buffers.size()); + + // Decode header buffers and check if expected HTTP/3 error is reported + // via decoding callback + var decodingCallback = new TestDecodingCallback(); + var headerFrameReader = decoder.newHeaderFrameReader(decodingCallback); + for (int bufferIdx = 0; bufferIdx < buffers.size(); bufferIdx++) { + headerFrameReader.read(buffers.get(bufferIdx), + bufferIdx == buffers.size() - 1); + Http3Error decodingError = decodingCallback.lastHttp3Error.get(); + if (decodingError != null) { + System.err.printf("Decoding error observed during buffer #%d processing: %s throwable: %s%n", + bufferIdx, decodingError, decodingCallback.lastThrowable.get()); + if (decoderErrorExpected) { + Assert.assertEquals(decodingError, Http3Error.QPACK_DECOMPRESSION_FAILED); + return; + } else { + Assert.fail("No HTTP/3 error was expected"); + } + } else { + System.err.println("Buffer #" + bufferIdx + " readout completed without errors"); + } + } + if (decoderErrorExpected) { + Assert.fail("HTTP/3 error was expected but was not observed"); + } + } + + @DataProvider + public Object[][] headerSequences() { + List testCases = new ArrayList<>(); + for (var sequence : generateHeaderSequences()) { + // Decoding should complete without failure + testCases.add(new Object[]{sequence, -1L}); + // No failure since it is enough bytes specified in the SETTINGS_MAX_FIELD_SECTION_SIZE + // setting value + testCases.add(new Object[]{sequence, REQUIRED_FIELD_SECTION_SIZE}); + // Failure is expected - not enough bytes specified in the SETTINGS_MAX_FIELD_SECTION_SIZE + // setting value + testCases.add(new Object[]{sequence, REQUIRED_FIELD_SECTION_SIZE - 1}); + } + return testCases.toArray(Object[][]::new); + } + + private static List> generateHeaderSequences() { + List> headersSequences = new ArrayList<>(); + headersSequences.add(TEST_HEADERS); + // startIndex == 0 - the TEST_HEADERS sequence that is already + // added to the sequences list + for (int startIndex = 1; startIndex < TEST_HEADERS.size(); startIndex++) { + List firstPart = TEST_HEADERS.subList(startIndex, TEST_HEADERS.size()); + List secondPart = TEST_HEADERS.subList(0, startIndex); + List sequence = new ArrayList<>(); + sequence.addAll(firstPart); + sequence.addAll(secondPart); + headersSequences.add(sequence); + } + return headersSequences; + } + + record TestHeader(String name, String value, long size) { + public TestHeader(String name, String value) { + this(name, value, name.length() + value.length() + 32L); + } + } + + private static void configureDynamicTable(DynamicTable table) { + table.setCapacity(512L); + table.insert(NAME_IN_TABLE, VALUE_IN_TABLE); + table.insert(NAME_IN_TABLE_POSTBASE, VALUE_IN_TABLE_POSTBASE); + } + + private static class TestErrorHandler { + final AtomicReference error = new AtomicReference<>(); + final AtomicReference http3Error = new AtomicReference<>(); + + public void qpackErrorHandler(Throwable error, Http3Error http3Error) { + this.error.set(error); + this.http3Error.set(http3Error); + } + } + + private static class TestDecodingCallback implements DecodingCallback { + + final AtomicReference lastHttp3Error = new AtomicReference<>(); + final AtomicReference lastThrowable = new AtomicReference<>(); + + @Override + public void onDecoded(CharSequence name, CharSequence value) { + } + + @Override + public void onComplete() { + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + lastHttp3Error.set(http3Error); + lastThrowable.set(throwable); + } + + @Override + public long streamId() { + return 0; + } + } + + private static final String NAME_IN_TABLE = "HEADER_NAME_FROM_TABLE"; + private static final String VALUE_IN_TABLE = "HEADER_VALUE_FROM_TABLE"; + private static final String NAME_IN_TABLE_POSTBASE = "HEADER_NAME_FROM_TABLE_POSTBASE"; + private static final String VALUE_IN_TABLE_POSTBASE = "HEADER_VALUE_FROM_TABLE_POSTBASE"; + private static final String NAME_NOT_IN_TABLE = "NAME_NOT_IN_TABLE"; + private static final String VALUE_NOT_IN_TABLE = "VALUE_NOT_IN_TABLE"; + + private static List TEST_HEADERS = List.of( + // Relative index + new TestHeader(NAME_IN_TABLE, VALUE_IN_TABLE), + // Relative name index + new TestHeader(NAME_IN_TABLE, VALUE_NOT_IN_TABLE), + // Post-base index + new TestHeader(NAME_IN_TABLE_POSTBASE, VALUE_IN_TABLE_POSTBASE), + // Post-base name index + new TestHeader(NAME_IN_TABLE_POSTBASE, VALUE_NOT_IN_TABLE), + // Literal + new TestHeader(NAME_NOT_IN_TABLE, VALUE_NOT_IN_TABLE) + ); + + private static long REQUIRED_FIELD_SECTION_SIZE = TEST_HEADERS.stream() + .mapToLong(TestHeader::size) + .sum(); + private static final long BASE = 1L; + private static final Random RANDOM = RandomFactory.getRandom(); +} diff --git a/test/jdk/java/net/httpclient/qpack/DecoderTest.java b/test/jdk/java/net/httpclient/qpack/DecoderTest.java new file mode 100644 index 00000000000..37f87bdda59 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/DecoderTest.java @@ -0,0 +1,392 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.hpack.QuickHuffman; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.qpack.writers.IntegerWriter; +import jdk.internal.net.http.qpack.StaticTable; +import jdk.internal.net.http.qpack.writers.StringWriter; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import static org.testng.Assert.*; + +/* + * @test + * @modules java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @run testng/othervm DecoderTest + */ +public class DecoderTest { + + private final Random random = new Random(); + private final IntegerWriter intWriter = new IntegerWriter(); + private final StringWriter stringWriter = new StringWriter(); + + record DecoderWithReader(Decoder decoder, HeaderFrameReader reader) { + } + + private static DecoderWithReader newDecoderWithReader(DecodingCallback decodingCallback) throws IOException { + var decoder = new Decoder(DecoderTest::createDecoderStreams, DecoderTest::qpackErrorHandler); + var headerFrameReader = decoder.newHeaderFrameReader(decodingCallback); + // Supply byte buffer with two bytes of zeroed Field Section Prefix + ByteBuffer prefix = ByteBuffer.allocate(2); + decoder.decodeHeader(prefix, true, headerFrameReader); + // Return record with decoder/reader tuple + return new DecoderWithReader(decoder, headerFrameReader); + } + + private static final int TEST_STR_MAX_LENGTH = 10; + + static QueuingStreamPair createDecoderStreams(Consumer receiver) { + return null; + } + + @DataProvider(name = "indexProvider") + public Object[][] indexProvider() { + AtomicInteger tableIndex = new AtomicInteger(); + return StaticTable.HTTP3_HEADER_FIELDS.stream() + .map(headerField -> List.of(tableIndex.getAndIncrement(), headerField)) + .map(List::toArray) + .toArray(Object[][]::new); + } + + @DataProvider(name = "nameReferenceProvider") + public Object[][] nameReferenceProvider() { + AtomicInteger tableIndex = new AtomicInteger(); + return StaticTable.HTTP3_HEADER_FIELDS.stream() + .map(h -> List.of(tableIndex.getAndIncrement(), h.name(), randomString())) + .map(List::toArray).toArray(Object[][]::new); + } + + @DataProvider(name = "literalProvider") + public Object[][] literalProvider() { + var output = new String[100][]; + for (int i = 0; i < 100; i++) { + output[i] = new String[]{ randomString(), randomString() }; + } + return output; + } + + @Test(dataProvider = "indexProvider") + public void testIndexedOnStaticTable(int index, HeaderField h) throws IOException { + var actual = writeIndex(index); + var callback = new TestingCallBack(index, h.name(), h.value()); + var dr = newDecoderWithReader(callback); + dr.decoder().decodeHeader(actual, true, dr.reader()); + } + + @Test(dataProvider = "nameReferenceProvider") + public void testLiteralWithNameReferenceOnStaticTable(int index, String name, String value) throws IOException { + boolean sensitive = random.nextBoolean(); + + var actual = writeNameRef(index, sensitive, value); + var callback = new TestingCallBack(index, sensitive, name, value); + var dr = newDecoderWithReader(callback); + dr.decoder().decodeHeader(actual, true, dr.reader()); + } + + @Test(dataProvider = "literalProvider") + public void testLiteralWithLiteralNameOnStaticTable(String name, String value) throws IOException { + boolean sensitive = random.nextBoolean(); + + var actual = writeLiteral(sensitive, name, value); + var callback = new TestingCallBack(sensitive, name, value); + var dr = newDecoderWithReader(callback); + dr.decoder().decodeHeader(actual, true, dr.reader()); + } + + @Test + public void stateCheckSingle() throws IOException { + boolean sensitive = random.nextBoolean(); + var name = "foo"; + var value = "bar"; + + var bb = writeLiteral(sensitive, name, value); + var callback = new TestingCallBack(sensitive, name, value); + + var dr = newDecoderWithReader(callback); + int len = bb.capacity(); + for (int i = 0; i < len; i++) { + var b = ByteBuffer.wrap(new byte[]{ bb.get() }); + dr.decoder().decodeHeader(b, (i == len - 1), dr.reader()); + } + } + + /* Test Methods */ + private void debug(ByteBuffer bb, String msg, boolean verbose) { + if (verbose) { + System.out.printf("DEBUG[%s]: pos=%d, limit=%d, remaining=%d\n", + msg, bb.position(), bb.limit(), bb.remaining()); + } + System.out.printf("DEBUG[%s]: ", msg); + for (byte b : bb.array()) { + System.out.printf("(%s,%d) ", Integer.toBinaryString(b & 0xFF), (int)(b & 0xFF)); + } + System.out.print("\n"); + } + + private ByteBuffer writeIndex(int index) { + int N = 6; + int payload = 0b1100_0000; // static table = true + var bb = ByteBuffer.allocate(2); + + intWriter.configure(index, N, payload); + intWriter.write(bb); + intWriter.reset(); + + bb.flip(); + return bb; + } + + private ByteBuffer writeNameRef(int index, boolean sensitive, String value) { + int N = 4; + int payload = 0b0101_0000; // static table = true + if (sensitive) + payload |= 0b0010_0000; + var bb = allocateNameRefBuffer(N, index, value); + intWriter.configure(index, N, payload); + intWriter.write(bb); + intWriter.reset(); + + boolean huffman = QuickHuffman.isHuffmanBetterFor(value); + int huffmanMask = 0b0000_0000; + if (huffman) + huffmanMask = 0b1000_0000; + stringWriter.configure(value, 7, huffmanMask, huffman); + stringWriter.write(bb); + stringWriter.reset(); + + bb.flip(); + return bb; + } + + private ByteBuffer writeLiteral(boolean sensitive, String name, String value) { + int N = 3; + //boolean hasInSt = Sta; + int payload = 0b0010_0000; // static table = true + var bb = allocateLiteralBuffer(N, name, value); + + if (sensitive) + payload |= 0b0001_0000; + boolean huffmanName = QuickHuffman.isHuffmanBetterFor(name); + if (huffmanName) + payload |= 0b0000_1000; + stringWriter.configure(name, N, payload, huffmanName); + stringWriter.write(bb); + stringWriter.reset(); + + boolean huffmanValue = QuickHuffman.isHuffmanBetterFor(value); + int huffmanMask = 0b0000_0000; + if (huffmanValue) + huffmanMask = 0b1000_0000; + stringWriter.configure(value, 7, huffmanMask, huffmanValue); + stringWriter.write(bb); + stringWriter.reset(); + + bb.flip(); + return bb; + } + + private ByteBuffer allocateIndexBuffer(int index) { + /* + * Note on Integer Representation used for storing the length of name and value strings. + * Taken from RFC 7541 Section 5.1 + * + * "An integer is represented in two parts: a prefix that fills the current octet and an + * optional list of octets that are used if the integer value does not fit within the + * prefix. The number of bits of the prefix (called N) is a parameter of the integer + * representation. If the integer value is small enough, i.e., strictly less than 2N-1, it + * is encoded within the N-bit prefix. + * + * ... + * + * Otherwise, all the bits of the prefix are set to 1, and the value, decreased by 2N-1, is + * encoded using a list of one or more octets. The most significant bit of each octet is + * used as a continuation flag: its value is set to 1 except for the last octet in the list. + * The remaining bits of the octets are used to encode the decreased value." + * + * Use "null" for name, if name isn't being provided (i.e. for a nameRef); otherwise, buffer + * will be too large. + * + */ + int N = 6; // bits available in first byte + int size = 1; + index -= Math.pow(2, N) - 1; // number that you can store in first N bits + while (index >= 0) { + index -= 127; + size++; + } + return ByteBuffer.allocate(size); + } + + private ByteBuffer allocateNameRefBuffer(int N, int index, CharSequence value) { + int vlen = Math.min(QuickHuffman.lengthOf(value), value.length()); + int size = 1 + vlen; + + index -= Math.pow(2, N) - 1; + while (index >= 0) { + index -= 127; + size++; + } + vlen -= 127; + size++; + while (vlen >= 0) { + vlen -= 127; + size++; + } + return ByteBuffer.allocate(size); + } + + private ByteBuffer allocateLiteralBuffer(int N, CharSequence name, CharSequence value) { + int nlen = Math.min(QuickHuffman.lengthOf(name), name.length()); + int vlen = Math.min(QuickHuffman.lengthOf(value), value.length()); + int size = nlen + vlen; + + nlen -= Math.pow(2, N) - 1; + size++; + while (nlen >= 0) { + nlen -= 127; + size++; + } + + vlen -= 127; + size++; + while (vlen >= 0) { + vlen -= 127; + size++; + } + return ByteBuffer.allocate(size); + } + + private static final String LOREM = """ + Lorem ipsum dolor sit amet, consectetur adipiscing + elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris + nisi ut aliquip ex ea commodo consequat.Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur.Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum."""; + + private String randomString() { + int lower = random.nextInt(LOREM.length() - TEST_STR_MAX_LENGTH); + /** + * The empty string ("") is a valid value String in the static table and the random + * String returned cannot refer to a entry in the table. Therefore, we set the upper + * bound below to a minimum of 1. + */ + return LOREM.substring(lower, 1 + lower + random.nextInt(TEST_STR_MAX_LENGTH)); + } + + private static class TestingCallBack implements DecodingCallback { + final int index; + final boolean huffmanName, huffmanValue; + final boolean sensitive; + final String name, value; + + // Indexed + TestingCallBack(int index, String name, String value) { + this(index, false, name, value); + } + // Literal w/Literal Name + TestingCallBack(boolean sensitive, String name, String value) { + this(-1, sensitive, name, value); + } + // Literal w/Name Reference + TestingCallBack(int index, boolean sensitive, String name, String value) { + this.index = index; + this.sensitive = sensitive; + this.name = name; + this.value = value; + this.huffmanName = QuickHuffman.isHuffmanBetterFor(name); + this.huffmanValue = QuickHuffman.isHuffmanBetterFor(value); + } + + @Override + public void onDecoded(CharSequence name, CharSequence value) { + fail("should not be called"); + } + + @Override + public void onIndexed(long index, CharSequence name, CharSequence value) { + assertEquals(this.index, index); + assertEquals(this.name, name.toString()); + assertEquals(this.value, value.toString()); + } + + @Override + public void onLiteralWithNameReference(long index, CharSequence name, + CharSequence value, boolean huffmanValue, + boolean sensitive) { + assertEquals(this.index, index); + assertEquals(this.value, value.toString()); + assertEquals(this.huffmanValue, huffmanValue); + assertEquals(this.sensitive, sensitive); + } + + @Override + public void onLiteralWithLiteralName(CharSequence name, boolean huffmanName, + CharSequence value, boolean huffmanValue, + boolean sensitive) { + assertEquals(this.name, name.toString()); + assertEquals(this.huffmanName, huffmanName); + assertEquals(this.value, value.toString()); + assertEquals(this.huffmanValue, huffmanValue); + assertEquals(this.sensitive, sensitive); + } + + @Override + public void onComplete() { + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + fail(http3Error + "Decoding error:" + http3Error, throwable); + } + + @Override + public long streamId() { + return 0; + } + } + + private static void qpackErrorHandler(Throwable error, Http3Error http3Error) { + fail("QPACK error:" + http3Error, error); + } +} diff --git a/test/jdk/java/net/httpclient/qpack/DynamicTableFieldLineRepresentationTest.java b/test/jdk/java/net/httpclient/qpack/DynamicTableFieldLineRepresentationTest.java new file mode 100644 index 00000000000..cdf584d1854 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/DynamicTableFieldLineRepresentationTest.java @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.Encoder; +import jdk.test.lib.RandomFactory; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; + +import static org.testng.Assert.*; + +/* + * @test + * @key randomness + * @library /test/lib + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=EXTRA + * -Djdk.http.qpack.allowBlockingEncoding=true + * -Djdk.http.qpack.decoderBlockedStreams=4 + * DynamicTableFieldLineRepresentationTest + */ +public class DynamicTableFieldLineRepresentationTest { + + private static final Random RANDOM = RandomFactory.getRandom(); + private static final long DT_CAPACITY = 4096L; + + //4.5.2. Indexed Field Line + @Test + public void indexedFieldLineOnDynamicTable() throws IOException { + boolean sensitive = RANDOM.nextBoolean(); + List buffers = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> true, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, + error::set); + + // Create encoder and decoder + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Configure settingsFrames and prepopulate table + configureConnector(conn,1); + + var name = conn.decoderTable().get(0).name(); + var value = conn.decoderTable().get(0).value(); + + // Create header frame reader and writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var callback = new TestingDynamicCallBack( name, value); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext(0, 1, + headerFrameWriter); + ByteBuffer headersBb = ByteBuffer.allocate(2048); + + // Configures encoder for writing the header name:value pair + encoder.header(context, name, value, sensitive, -1); + + // Write the header + headerFrameWriter.write(headersBb); + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertEquals(callback.lastIndexedName,name); + } + + //4.5.3. Indexed Field Line with Post-Base Index + @Test + public void indexedFieldLineOnDynamicTablePostBase() throws IOException { + System.err.println("start indexedFieldLineOnDynamicTablePostBase"); + boolean sensitive = RANDOM.nextBoolean(); + + List buffers = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> true, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, + error::set); + + // Create encoder and decoder + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Create settings frame with dynamic table capacity and number of blocked streams + configureConnector(conn, 3); + + var name = conn.decoderTable().get(1).name(); + var value = conn.decoderTable().get(1).value(); + + // Create header frame reader and writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var callback = new TestingDynamicCallBack( name, value); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext(0, 1, + headerFrameWriter); + ByteBuffer headersBb = ByteBuffer.allocate(2048); + + // Configures encoder for writing the header name:value pair + encoder.header(context, name, value, sensitive, -1); + + // Write the header + headerFrameWriter.write(headersBb); + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertEquals(callback.lastIndexedName,name); + assertEquals(callback.lastValue,value); + } + + // 4.5.4. Literal Field Line with Name Reference + // A literal field line with name reference representation encodes a field line + // where the field name matches the field name of an entry in the static table + // or the field name of an entry in the dynamic table with an absolute index + // less than the value of the Base. + @Test + public void literalFieldLineNameReference() throws IOException { + System.err.println("start literalFieldLineNameReference"); + boolean sensitive = RANDOM.nextBoolean(); + + List buffers = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> false, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, + error::set); + + // Create encoder and decoder + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Create settings frame with dynamic table capacity and number of blocked streams + configureConnector(conn, 3); + + var name = conn.decoderTable().get(1).name(); + var value = conn.decoderTable().get(2).value(); // don't want value to match + + + // Create header frame reader and writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var callback = new TestingDynamicCallBack(name, value); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext(0, 3, + headerFrameWriter); + ByteBuffer headersBb = ByteBuffer.allocate(2048); + + // Configures encoder for writing the header name:value pair + encoder.header(context, name, value, sensitive, -1); + + // Write the header + headerFrameWriter.write(headersBb); + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertEquals(callback.lastReferenceName, name); + } + + //4.5.5. Literal Field Line with Post-Base Name Reference + @Test + public void literalFieldLineNameReferencePostBase() throws IOException { + System.err.println("start literalFieldLineNameReferencePostBase"); + boolean sensitive = RANDOM.nextBoolean(); + + List buffers = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> false, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, + error::set); + + // Create encoder and decoder + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Create settings frame with dynamic table capacity and number of blocked streams + configureConnector(conn,4); + + var name = conn.decoderTable().get(3).name(); + var value = conn.decoderTable().get(2).value(); // don't want value to match + + // Create header frame reader and writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var callback = new TestingDynamicCallBack(name, value); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext(0, 2, + headerFrameWriter); + ByteBuffer headersBb = ByteBuffer.allocate(2048); + + // Configures encoder for writing the header name:value pair + encoder.header(context, name, value, sensitive, -1); + + // Write the header + headerFrameWriter.write(headersBb); + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertEquals(callback.lastReferenceName, name); + assertEquals(callback.lastValue, value); + } + + private void configureConnector(EncoderDecoderConnector.EncoderDecoderPair connector, int numberOfEntries){ + // Create settings frame with dynamic table capacity and number of blocked streams + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + // 4k should be enough for storing dynamic table entries added by 'prepopulateDynamicTable' + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, DT_CAPACITY); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_BLOCKED_STREAMS,2); + + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + + // Configure encoder and decoder with constructed ConnectionSettings + connector.encoder().configure(settings); + connector.decoder().configure(settings); + connector.encoderTable().setCapacity(DT_CAPACITY); + connector.decoderTable().setCapacity(DT_CAPACITY); + + // add basic matching entries to both + prepopulateDynamicTable(connector.encoderTable(), numberOfEntries); + prepopulateDynamicTable(connector.decoderTable(), numberOfEntries); + } + + private static void prepopulateDynamicTable(DynamicTable dynamicTable, int numEntries) { + for (int count = 0; count < numEntries; count++) { + var header = TestHeader.withId(count); + dynamicTable.insert(header.name(), header.value()); + } + } + + private static class TestErrorHandler { + final AtomicReference error = new AtomicReference<>(); + final AtomicReference http3Error = new AtomicReference<>(); + + public void qpackErrorHandler(Throwable error, Http3Error http3Error) { + this.error.set(error); + this.http3Error.set(http3Error); + } + } + + private static class TestingDynamicCallBack implements DecodingCallback { + final String name, value; + String lastLiteralName = null; + String lastReferenceName = null; + String lastValue = null; + String lastIndexedName = null; + + TestingDynamicCallBack(String name, String value) { + this.name = name; + this.value = value; + } + + @Override + public void onDecoded(CharSequence actualName, CharSequence value) { + fail("onDecoded should not be called"); + } + + @Override + public void onComplete() { + System.out.println("completed it"); + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + fail("Decoding error: " + http3Error, throwable); + } + + @Override + public long streamId() { + return 0; + } + + @Override + public void onIndexed(long actualIndex, CharSequence actualName, CharSequence actualValue) { + System.out.println("Indexed called"); + assertEquals(actualName, name); + assertEquals(actualValue, value); + lastValue = value; + lastIndexedName = name; + } + + @Override + public void onLiteralWithNameReference(long index, + CharSequence actualName, + CharSequence actualValue, + boolean valueHuffman, + boolean hideIntermediary) { + System.out.println("Literal with name reference called"); + assertEquals(actualName.toString(), name); + assertEquals(actualValue.toString(), value); + lastReferenceName = name; + lastValue = value; + } + + @Override + public void onLiteralWithLiteralName(CharSequence actualName, boolean nameHuffman, + CharSequence actualValue, boolean valueHuffman, + boolean hideIntermediary) { + System.out.println("Literal with literal name called"); + assertEquals(actualName.toString(), name); + assertEquals(actualValue.toString(), value); + lastLiteralName = name; + lastValue = value; + } + } + + record TestHeader(String name, String value) { + public static BlockingDecodingTest.TestHeader withId(int id) { + return new BlockingDecodingTest.TestHeader(NAME + id, VALUE + id); + } + } + + private static final String NAME = "test"; + private static final String VALUE = "valueTest"; +} diff --git a/test/jdk/java/net/httpclient/qpack/DynamicTableTest.java b/test/jdk/java/net/httpclient/qpack/DynamicTableTest.java new file mode 100644 index 00000000000..80765e41787 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/DynamicTableTest.java @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @library /test/lib + * @modules java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL DynamicTableTest + */ + +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import jdk.test.lib.RandomFactory; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.invoke.VarHandle; +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.stream.IntStream; + +public class DynamicTableTest { + + // Test for addition to the table and that indices are growing monotonically, + // and they can be used to retrieve previously added entries + @Test + public void monotonicIndexes() { + int tableMaxCapacityBytes = 2048; + int numberOfElementsToAdd = 1024; + int charsPerNumber = (int) Math.ceil(Math.log10(numberOfElementsToAdd)); + int oneElementSize = 32 + HEADER_NAME_PREFIX.length() + HEADER_VALUE_PREFIX.length() + charsPerNumber * 2; + + // Expected table capacity in elements + int maxElementsInTable = tableMaxCapacityBytes / oneElementSize; + + // Test element id counter + long lastAddedId; + var dynamicTable = new DynamicTable(QPACK.getLogger().subLogger("monotonicIndexes")); + dynamicTable.setMaxTableCapacity(2048); + dynamicTable.setCapacity(tableMaxCapacityBytes); + + for (lastAddedId = 0; lastAddedId < numberOfElementsToAdd; lastAddedId++) { + var name = generateHeaderString(lastAddedId, true, charsPerNumber); + var value = generateHeaderString(lastAddedId, false, charsPerNumber); + long addedId = dynamicTable.insert(name, value); + + // Check that dynamic table put gives back monotonically increasing indexes + Assert.assertEquals(addedId, lastAddedId); + + if (lastAddedId > maxElementsInTable) { + // Check that oldest element is available and not reclaimed + long oldestAliveId = lastAddedId - maxElementsInTable + 1; + dynamicTable.get(oldestAliveId); + + // Check that relative indexing can be used to get oldest and newest entry + dynamicTable.getRelative(maxElementsInTable - 1); + dynamicTable.getRelative(0); + + // Check that reverse lookup is working for a random index from not reclaimed region + long rid = RANDOM.nextLong(oldestAliveId, lastAddedId); + String rName = generateHeaderString(rid, true, charsPerNumber); + String rValue = generateHeaderString(rid, false, charsPerNumber); + + // The reverse lookup search result range is shifted by 1 to implement search result indexing: + // full match found in a table: searchResult = idx + 1 + // partial match (name) in a table: searchResult = -idx - 1 + // no match: 0 + long fullMatchSearchResult = dynamicTable.search(rName, rValue); + long onlyNameSearchResult = dynamicTable.search(rName, "notFoundInTable"); + long noMatchResult = dynamicTable.search(HEADER_NAME_PREFIX, HEADER_VALUE_PREFIX); + + Assert.assertEquals(fullMatchSearchResult - 1, rid); + Assert.assertEquals(-onlyNameSearchResult - 1, rid); + Assert.assertEquals(noMatchResult, 0); + } + } + } + + @Test(dataProvider = "randomTableResizeData") + public void randomTableResize(int initialSize, long tail, long head, int resizeTo) + throws Throwable { + HeaderField[] initial = generateHeadersArray(initialSize, tail, head); + resizeTestRunner(initial, tail, head, resizeTo); + } + + @DataProvider + public Object[][] randomTableResizeData() { + return IntStream.range(0, 1000) + .boxed() + .map(i -> newRandomTableConfiguration()) + .toArray(Object[][]::new); + } + + @Test + public void holderArrayLengthTest() { + // Test that holder array size for storing elements is increased according to demand on array + // elements, and that by default its length is set to 64 elements. + var dynamicTable = new DynamicTable(QPACK.getLogger().subLogger("tableResizeTests")); + + // Check that the initial array length is DynamicTable.INITIAL_HOLDER_ARRAY_SIZE + Assert.assertEquals(getElementsArrayLength(dynamicTable), + INITIAL_HOLDER_ARRAY_SIZE); + + // Update dynamic table capacity to maximum allowed value and check + // that holder array is not changed + dynamicTable.setMaxTableCapacity(IntegerReader.QPACK_MAX_INTEGER_VALUE); + dynamicTable.setCapacity(IntegerReader.QPACK_MAX_INTEGER_VALUE); + Assert.assertEquals(getElementsArrayLength(dynamicTable), + INITIAL_HOLDER_ARRAY_SIZE); + + // Add DynamicTable.INITIAL_HOLDER_ARRAY_SIZE + 1 element to the dynamic table + // and check that its length is increased 2 times + for (int i = 0; i <= INITIAL_HOLDER_ARRAY_SIZE; i++) { + dynamicTable.insert("name" + i, "value" + i); + } + Assert.assertEquals(getElementsArrayLength(dynamicTable), INITIAL_HOLDER_ARRAY_SIZE << 1); + } + + // Test for a simple resize that checks that unique indexes still reference the correct entry + @Test(dataProvider = "simpleTableResizeData") + public void simpleTableResize(HeaderField[] array, long tail, long head, int resizeTo) throws Throwable { + resizeTestRunner(array, tail, head, resizeTo); + } + + @DataProvider + public Object[][] simpleTableResizeData() { + return new Object[][]{ + tableResizeScenario1(), tableResizeScenario2(), + tableResizeScenario3(), tableResizeScenario4(), + tableResizeScenario5()}; + } + + private Object[] tableResizeScenario1() { + HeaderField[] elements = new HeaderField[8]; + elements[5] = new HeaderField("5", "5"); // Tail + elements[6] = new HeaderField("6", "6"); + elements[7] = new HeaderField("7", "7"); // Head + return new Object[]{elements, 21L, 24L, 4}; + } + + private Object[] tableResizeScenario2() { + HeaderField[] elements = new HeaderField[8]; + elements[2] = new HeaderField("2", "2"); // Tail + elements[3] = new HeaderField("3", "3"); + elements[4] = new HeaderField("4", "4"); + elements[5] = new HeaderField("5", "5"); // Head + return new Object[]{elements, 26L, 30L, 4}; + } + + private Object[] tableResizeScenario3() { + HeaderField[] elements = new HeaderField[8]; + elements[0] = new HeaderField("4", "4"); + elements[1] = new HeaderField("5", "5"); // Head + elements[6] = new HeaderField("2", "2"); // Tail + elements[7] = new HeaderField("3", "3"); + return new Object[]{elements, 30L, 34L, 64}; + } + + private Object[] tableResizeScenario4() { + HeaderField[] elements = new HeaderField[8]; + elements[0] = new HeaderField("4", "4"); + elements[1] = new HeaderField("5", "5"); // Head + elements[5] = new HeaderField("1", "1"); // Tail + elements[6] = new HeaderField("2", "2"); + elements[7] = new HeaderField("3", "3"); + return new Object[]{elements, 29L, 34L, 16}; + } + + private Object[] tableResizeScenario5() { + HeaderField[] elements = new HeaderField[64]; + elements[10] = new HeaderField("1", "1"); + return new Object[]{elements, 3977L, 3978L, 16}; + } + + private static void resizeTestRunner(HeaderField[] array, long tail, long head, int resizeTo) throws Throwable { + assert tail < head; + var dynamicTable = new DynamicTable(QPACK.getLogger().subLogger("tableResizeTests")); + dynamicTable.setMaxTableCapacity(2048); + dynamicTable.setCapacity(2048); + // Prepare dynamic table state for the resize operation + DT_ELEMENTS_VH.set(dynamicTable, array); + DT_HEAD_VH.set(dynamicTable, head); + DT_TAIL_VH.set(dynamicTable, tail); + + // Call resize + ReentrantReadWriteLock lock = (ReentrantReadWriteLock) DT_LOCK_VH.get(dynamicTable); + lock.writeLock().lock(); + HeaderField[] resizeResult; + try { + // Call DynamicTable.resize + DT_RESIZE_MH.invoke(dynamicTable, resizeTo); + // Acquire resize result + resizeResult = (HeaderField[]) DT_ELEMENTS_VH.get(dynamicTable); + } finally { + lock.writeLock().unlock(); + } + + // Check the resulting array by calculating the expected array + HeaderField[] expectedResult = calcResizeResult(array, tail, head, resizeTo); + + // Check the resulting of the resize operation + checkResizeResult(array, resizeResult, expectedResult); + } + + private static HeaderField[] generateHeadersArray(int size, long tail, long head) { + assert head > tail; + HeaderField[] res = new HeaderField[size]; + assert head > 0L; + int charsPerNumber = (int) (Math.log10(head) + 1); + for (long eid = tail; eid < head; eid++) { + int idx = (int) (eid % size); + res[idx] = new HeaderField(generateHeaderString(eid, true, charsPerNumber), + generateHeaderString(eid, false, charsPerNumber)); + } + return res; + } + + private static int getElementsArrayLength(DynamicTable dynamicTable) { + HeaderField[] array = (HeaderField[]) DT_ELEMENTS_VH.get(dynamicTable); + return array.length; + } + + private static Object[] newRandomTableConfiguration() { + boolean shrink = RANDOM.nextBoolean(); + int initialSize = pow2size(RANDOM.nextInt(2, 2048)); + int resizeTo = shrink ? pow2size(RANDOM.nextInt(1, initialSize)) : pow2size(RANDOM.nextInt(initialSize, 4096)); + int elementsCount = RANDOM.nextInt(0, Math.min(initialSize, resizeTo)); + long tail = RANDOM.nextLong(100000); + long head = tail + elementsCount + 1; + return new Object[]{initialSize, tail, head, resizeTo}; + } + + private static HeaderField[] calcResizeResult(HeaderField[] array, long tail, long head, int resizeTo) { + HeaderField[] result = new HeaderField[resizeTo]; + for (long p = tail; p < head; p++) { + int newIdx = (int) (p % resizeTo); + int oldIdx = (int) (p % array.length); + result[newIdx] = array[oldIdx]; + } + return result; + } + + private static void checkResizeResult(HeaderField[] initial, HeaderField[] resized, HeaderField[] expected) { + Assert.assertEquals(resized.length, expected.length); + for (int index = 0; index < expected.length; index++) { + if (!sameHeaderField(expected[index], resized[index])) { + System.err.println("Initial Array:" + Arrays.deepToString(initial)); + System.err.println("Resized Array:" + Arrays.deepToString(resized)); + System.err.println("Expected Array:" + Arrays.deepToString(expected)); + Assert.fail("DynamicTable.resize failed"); + } + } + } + + private static boolean sameHeaderField(HeaderField a, HeaderField b) { + // Check if one HeaderField is null and another is not null + if (a == null ^ b == null) { + return false; + } + // Given previous check, check if both HeaderField are null + if (a == null) { + return true; + } + // Both HFs are not null - will check name() and value() values + return a.name().equals(b.name()) && a.value().equals(b.value()); + } + + private static MethodHandle findDynamicTableResizeMH() { + try { + MethodType mt = MethodType.methodType(void.class, int.class); + return DT_LOOKUP.findVirtual(DynamicTable.class, "resize", mt); + } catch (Exception e) { + Assert.fail("Failed to initialize DynamicTable.resize MH", e); + return null; + } + } + + private static VarHandle findDynamicTableFieldVH(String fieldName, Class fieldType) { + try { + return DT_LOOKUP.findVarHandle(DynamicTable.class, fieldName, fieldType); + } catch (Exception e) { + Assert.fail("Failed to initialize DynamicTable private Lookup instance", e); + return null; + } + } + + private static T readDynamicTableStaticFieldValue(String fieldName, Class fieldType) { + try { + var vh = DT_LOOKUP.findStaticVarHandle(DynamicTable.class, fieldName, fieldType); + return (T) vh.get(); + } catch (Exception e) { + Assert.fail("Failed to read DynamicTable static field value", e); + return null; + } + } + + private static MethodHandles.Lookup initializeDtLookup() { + try { + return MethodHandles.privateLookupIn(DynamicTable.class, MethodHandles.lookup()); + } catch (IllegalAccessException e) { + Assert.fail("Failed to initialize DynamicTable private Lookup instance", e); + return null; + } + } + + + private static final MethodHandles.Lookup DT_LOOKUP; + private static final MethodHandle DT_RESIZE_MH; + private static final VarHandle DT_HEAD_VH; + private static final VarHandle DT_TAIL_VH; + private static final VarHandle DT_ELEMENTS_VH; + private static final VarHandle DT_LOCK_VH; + private static final int INITIAL_HOLDER_ARRAY_SIZE; + + static { + DT_LOOKUP = initializeDtLookup(); + DT_RESIZE_MH = findDynamicTableResizeMH(); + DT_HEAD_VH = findDynamicTableFieldVH("head", long.class); + DT_TAIL_VH = findDynamicTableFieldVH("tail", long.class); + DT_ELEMENTS_VH = findDynamicTableFieldVH("elements", HeaderField[].class); + DT_LOCK_VH = findDynamicTableFieldVH("lock", ReentrantReadWriteLock.class); + INITIAL_HOLDER_ARRAY_SIZE = readDynamicTableStaticFieldValue( + "INITIAL_HOLDER_ARRAY_LENGTH", int.class); + } + + private static String generateHeaderString(long id, boolean generateName, int charsPerNumber) { + return (generateName ? HEADER_NAME_PREFIX : HEADER_VALUE_PREFIX) + ("%0" + charsPerNumber + "d").formatted(id); + } + + private static int pow2size(int size) { + return 1 << (32 - Integer.numberOfLeadingZeros(size - 1)); + } + + private static final String HEADER_NAME_PREFIX = "HeaderName"; + private static final String HEADER_VALUE_PREFIX = "HeaderValue"; + private static final Random RANDOM = RandomFactory.getRandom(); +} diff --git a/test/jdk/java/net/httpclient/qpack/EncoderDecoderConnectionTest.java b/test/jdk/java/net/httpclient/qpack/EncoderDecoderConnectionTest.java new file mode 100644 index 00000000000..418af9be2ad --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EncoderDecoderConnectionTest.java @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.qpack.TableEntry; +import jdk.internal.net.http.qpack.writers.EncoderInstructionsWriter; +import jdk.internal.net.http.qpack.writers.HeaderFrameWriter; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicReference; + +/* + * @test + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=EXTRA + * EncoderDecoderConnectionTest + */ +public class EncoderDecoderConnectionTest { + + @Test + public void capacityUpdateTest() { + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> true, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, error::set); + + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Set encoder and decoder maximum dynamic table capacity + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, 2048L); + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + decoder.configure(settings); + encoder.configure(settings); + + // Encoder - update DT capacity + final long capacityToSet = 1024L; + encoder.setTableCapacity(capacityToSet); + + // Check that no errors observed + Assert.assertNull(encoderErrorHandler.error.get()); + Assert.assertNull(encoderErrorHandler.http3Error.get()); + Assert.assertNull(decoderErrorHandler.error.get()); + Assert.assertNull(decoderErrorHandler.http3Error.get()); + + // Check that encoder's table capacity is updated + Assert.assertEquals(conn.encoderTable().capacity(), capacityToSet); + // Since encoder/decoder streams are cross-wired we expect see dynamic + // table capacity updated for the decoder too + Assert.assertEquals(conn.decoderTable().capacity(), + conn.encoderTable().capacity()); + } + + @Test + public void entryInsertionTest() { + AtomicReference error = new AtomicReference<>(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + + var conn = encoderDecoderConnector.newEncoderDecoderPair(entry -> true, + encoderErrorHandler::qpackErrorHandler, decoderErrorHandler::qpackErrorHandler, + error::set); + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Set encoder and decoder maximum dynamic table capacity + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, 2048L); + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + decoder.configure(settings); + encoder.configure(settings); + + // Update encoder and decoder DTs capacity - note that "Set Dynamic Table Capacity" + // is issued by the encoder that updates capacity on the decoder side + encoder.setTableCapacity(1024L); + + // Create table entry for insertion to the dynamic table + var entryToInsert = new TableEntry("test", "testValue"); + + // Create encoder instruction writer for generating "Insert with Literal Name" + // encoder instruction + var encoderInstructionWriter = new EncoderInstructionsWriter(); + + // Check that no errors observed + Assert.assertNull(encoderErrorHandler.error.get()); + Assert.assertNull(encoderErrorHandler.http3Error.get()); + Assert.assertNull(decoderErrorHandler.error.get()); + Assert.assertNull(decoderErrorHandler.http3Error.get()); + + // Issue the insert instruction on encoder stream + conn.encoderTable().insertWithEncoderStreamUpdate(entryToInsert, + encoderInstructionWriter, conn.encoderStreams(), + encoder.newEncodingContext(0, 0, new HeaderFrameWriter())); + var encoderHeader = conn.encoderTable().get(0); + var decoderHeader = conn.decoderTable().get(0); + Assert.assertEquals(encoderHeader.name(), decoderHeader.name()); + Assert.assertEquals(encoderHeader.value(), decoderHeader.value()); + } + + @Test + public void decoderErrorReportingTest() { + AtomicReference error = new AtomicReference<>(); + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + var conn = encoderDecoderConnector.newEncoderDecoderPair(e -> false, + encoderErrorHandler::qpackErrorHandler, + decoderErrorHandler::qpackErrorHandler, + error::set); + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, 2048L); + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + conn.encoder().configure(settings); + conn.encoderTable().setCapacity(1024L); + conn.encoderTable().insertWithEncoderStreamUpdate( + new TableEntry("a", "b"), + new EncoderInstructionsWriter(), + conn.encoderStreams(), + conn.encoder().newEncodingContext(0, 0, new HeaderFrameWriter())); + + // QPACK_ENCODER_STREAM_ERROR is expected on the decoder side + // since the decoder dynamic table capacity was not updated + Assert.assertEquals(decoderErrorHandler.http3Error.get(), + Http3Error.QPACK_ENCODER_STREAM_ERROR); + + // It is expected that http3 error reported to + // the decoder error handler only + Assert.assertNull(encoderErrorHandler.http3Error.get()); + } + + @Test + public void overflowIntegerInInstructions() { + AtomicReference error = new AtomicReference<>(); + TestErrorHandler encoderErrorHandler = new TestErrorHandler(); + TestErrorHandler decoderErrorHandler = new TestErrorHandler(); + EncoderDecoderConnector encoderDecoderConnector = new EncoderDecoderConnector(); + var conn = encoderDecoderConnector.newEncoderDecoderPair(e -> false, + encoderErrorHandler::qpackErrorHandler, + decoderErrorHandler::qpackErrorHandler, + error::set); + + // Forge byte buffer with encoder instruction containing integer > + // QPACK_MAX_INTEGER_VALUE + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 1 | Capacity (5+) | + // +---+---+---+-------------------+ + var encoderInstBb = instructionWithOverflowInteger(5, 0b0010_0000); + conn.encoderStreams().submitData(encoderInstBb); + + // Send bad decoder instruction back to encoder + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 1 | Stream ID (6+) | + // +---+---+-----------------------+ + var buffer = instructionWithOverflowInteger(6, 0b0100_0000); + conn.decoderStreams().submitData(buffer); + + // Analyze errors for expected results + Throwable encoderError = encoderErrorHandler.error.get(); + Http3Error encoderHttp3Error = encoderErrorHandler.http3Error.get(); + Throwable decoderError = decoderErrorHandler.error.get(); + Http3Error decoderHttp3Error = decoderErrorHandler.http3Error.get(); + + System.err.println("Encoder Error: " + encoderError); + System.err.println("Encoder Http3 error: " + encoderHttp3Error); + System.err.println("Decoder Error: " + decoderError); + System.err.println("Decoder Http3 error: " + decoderHttp3Error); + + if (encoderError == null || !(encoderError instanceof IOException)) { + Assert.fail("Incorrect encoder error type", encoderError); + } + if (decoderError == null || !(decoderError instanceof IOException)) { + Assert.fail("Incorrect decoder error type", decoderError); + } + Assert.assertEquals(encoderHttp3Error, Http3Error.QPACK_DECODER_STREAM_ERROR); + Assert.assertEquals(decoderHttp3Error, Http3Error.QPACK_ENCODER_STREAM_ERROR); + } + + private static ByteBuffer instructionWithOverflowInteger(int N, int payload) { + var buffer = ByteBuffer.allocate(11); + int max = (2 << (N - 1)) - 1; + buffer.put((byte) (payload | max)); + for (int i = 0; i < 9; i++) { + buffer.put((byte) 128); + } + buffer.put((byte)10); + buffer.flip(); + return buffer; + } + + private static class TestErrorHandler { + final AtomicReference error = new AtomicReference<>(); + final AtomicReference http3Error = new AtomicReference<>(); + + public void qpackErrorHandler(Throwable error, Http3Error http3Error) { + this.error.set(error); + this.http3Error.set(http3Error); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/EncoderDecoderConnector.java b/test/jdk/java/net/httpclient/qpack/EncoderDecoderConnector.java new file mode 100644 index 00000000000..3cf185ff8a7 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EncoderDecoderConnector.java @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.http3.streams.Http3Streams; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.http3.streams.UniStreamPair; +import jdk.internal.net.http.qpack.Decoder; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.InsertionPolicy; +import jdk.internal.net.http.qpack.QPACK.QPACKErrorHandler; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import jdk.internal.net.http.quic.ConnectionTerminator; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.QuicEndpoint; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; +import jdk.internal.net.quic.QuicTLSEngine; +import org.testng.Assert; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Stream; + +/** + * Instance of this class provides a stubbed Quic Connection implementation that + * cross-wires encoder/decoder streams, and also provides access to + * encoder and decoder dynamic tables. + */ +public class EncoderDecoderConnector { + /** + * Constructs test connector instance capable of instantiating cross-wired encoder/decoder pair. + * It is achieved by implementing stubs for Quic connection and writer classes. + */ + public EncoderDecoderConnector() { + this(null, null); + } + + /** + * Constructs test connector instance capable of instantiating encoder/decoder pair. + * The encoder/decoder connections are not cross-wired. The provided byte buffer consumers + * are used by the underlying Quic connection instead. + * + * @param encoderBytesConsumer consumer of the encoder byte buffers + * @param decoderBytesConsumer consumer of the decoder byte buffers + */ + public EncoderDecoderConnector(Consumer encoderBytesConsumer, Consumer decoderBytesConsumer) { + encoderReceiverFuture = new CompletableFuture<>(); + decoderReceiverFuture = new CompletableFuture<>(); + encoderConnection = new TestQuicConnection(decoderReceiverFuture, encoderBytesConsumer); + decoderConnection = new TestQuicConnection(encoderReceiverFuture, decoderBytesConsumer); + } + + /** + * Create new encoder/decoder pair and establish Quic stub connection between them. + * + * @param encoderInsertionPolicy encoder insertion policy + * @param encoderErrorHandler encoder stream error handler + * @param decoderErrorHandler decoder stream error handler + * @param streamsErrorHandler streams error handler + * @return encoder/decoder pair + */ + public EncoderDecoderConnector.EncoderDecoderPair + newEncoderDecoderPair(InsertionPolicy encoderInsertionPolicy, + QPACKErrorHandler encoderErrorHandler, + QPACKErrorHandler decoderErrorHandler, + Consumer streamsErrorHandler) { + // One instance of this class supports only one encoder/decoder pair + if (connectionCreated) { + throw new IllegalStateException("Encoder/decoder pair was already instantiated"); + } + connectionCreated = true; + + // Create encoder + var encoder = new Encoder(encoderInsertionPolicy, (receiver) -> + createEncoderStreams(receiver, streamsErrorHandler), + encoderErrorHandler); + + // Create decoder + var decoder = new Decoder((receiver) -> + createDecoderStreams(receiver, streamsErrorHandler), + decoderErrorHandler); + // Extract encoder and decoder dynamic tables + DynamicTable encoderTable = dynamicTable(encoder); + DynamicTable decoderTable = dynamicTable(decoder); + return new EncoderDecoderConnector.EncoderDecoderPair(encoder, decoder, + encoderTable, decoderTable, encoderStreamPair, decoderStreamPair); + } + + /** + * Record describing {@linkplain EncoderDecoderConnector#EncoderDecoderConnector() cross-wired} + * OR {@link EncoderDecoderConnector#EncoderDecoderConnector(Consumer, Consumer) decoupled} + * encoder and decoder pair. + * The references for encoder and decoder dynamic tables also provided for testing purposes. + * + * @param encoder encoder + * @param decoder decoder + * @param encoderTable encoder's dynamic table + * @param decoderTable decoder's dynamic table + * @param encoderStreams encoder streams + */ + record EncoderDecoderPair(Encoder encoder, Decoder decoder, + DynamicTable encoderTable, + DynamicTable decoderTable, + QueuingStreamPair encoderStreams, + QueuingStreamPair decoderStreams) { + } + + + private CompletableFuture> decoderReceiverFuture; + private CompletableFuture> encoderReceiverFuture; + private final QuicConnection encoderConnection; + private final QuicConnection decoderConnection; + private volatile QueuingStreamPair encoderStreamPair; + private volatile QueuingStreamPair decoderStreamPair; + + private volatile boolean connectionCreated; + + private static DynamicTable dynamicTable(Encoder encoder) { + return (DynamicTable) ENCODER_DT_VH.get(encoder); + } + + private static DynamicTable dynamicTable(Decoder decoder) { + return (DynamicTable) DECODER_DT_VH.get(decoder); + } + + private static final MethodHandles.Lookup ENCODER_LOOKUP = + initializeLookup(Encoder.class); + private static final MethodHandles.Lookup DECODER_LOOKUP = + initializeLookup(Decoder.class); + private static final VarHandle ENCODER_DT_VH = findDynamicTableVH( + ENCODER_LOOKUP, Encoder.class); + private static final VarHandle DECODER_DT_VH = findDynamicTableVH( + DECODER_LOOKUP, Decoder.class); + + + private static MethodHandles.Lookup initializeLookup(Class clz) { + try { + return MethodHandles.privateLookupIn(clz, MethodHandles.lookup()); + } catch (IllegalAccessException e) { + Assert.fail("Failed to initialize private Lookup instance", e); + return null; + } + } + + private static VarHandle findDynamicTableVH( + final MethodHandles.Lookup lookup, Class recv) { + try { + return lookup.findVarHandle(recv, "dynamicTable", DynamicTable.class); + } catch (Exception e) { + Assert.fail("Failed to acquire dynamic table VarHandle instance", e); + return null; + } + } + + + QueuingStreamPair createEncoderStreams(Consumer receiver, + Consumer errorHandler) { + QueuingStreamPair streamPair = new QueuingStreamPair( + Http3Streams.StreamType.QPACK_ENCODER, + encoderConnection, + receiver, + TestErrorHandler.of(errorHandler), + Utils.getDebugLogger(() -> "test-encoder")); + encoderReceiverFuture.complete(receiver); + encoderStreamPair = streamPair; + return streamPair; + } + + QueuingStreamPair createDecoderStreams(Consumer receiver, + Consumer errorHandler) { + QueuingStreamPair streamPair = new QueuingStreamPair( + Http3Streams.StreamType.QPACK_DECODER, + decoderConnection, + receiver, + TestErrorHandler.of(errorHandler), + Utils.getDebugLogger(() -> "test-decoder")); + decoderReceiverFuture.complete(receiver); + decoderStreamPair = streamPair; + return streamPair; + } + + private class TestQuicConnection extends QuicConnection { + + public TestQuicConnection(CompletableFuture> receiverFuture, + Consumer bytesWriter) { + this.sender = new TestQuicSenderStream(receiverFuture, bytesWriter); + } + + final TestQuicSenderStream sender; + + @Override + public boolean isOpen() { + return true; + } + + @Override + public TerminationCause terminationCause() { + return null; + } + + @Override + public QuicTLSEngine getTLSEngine() { + return null; + } + + @Override + public InetSocketAddress peerAddress() { + return null; + } + + @Override + public SocketAddress localAddress() { + return null; + } + + @Override + public CompletableFuture startHandshake() { + return null; + } + + @Override + public CompletableFuture openNewLocalBidiStream( + Duration duration) { + return null; + } + + @Override + public CompletableFuture openNewLocalUniStream( + Duration duration) { + // This method is called to create two unidirectional streams: + // one for decoder, one for encoder + return MinimalFuture.completedFuture(sender); + } + + @Override + public void addRemoteStreamListener( + Predicate streamConsumer) { + } + + @Override + public boolean removeRemoteStreamListener( + Predicate streamConsumer) { + return false; + } + + @Override + public Stream quicStreams() { + return null; + } + + @Override + public CompletableFuture handshakeReachedPeer() { + return MinimalFuture.completedFuture(null); + } + + @Override + public CompletableFuture requestSendPing() { + return MinimalFuture.completedFuture(-1L); + } + + @Override + public ConnectionTerminator connectionTerminator() { + return null; + } + + @Override + public String dbgTag() { + return null; + } + + @Override + public String logTag() { + return null; + } + } + + + private class TestQuicStreamWriter extends QuicStreamWriter { + final TestQuicSenderStream sender; + volatile boolean gotStreamType; + volatile Http3Streams.StreamType associatedStreamType; + + final CompletableFuture> receiverFuture; + final Consumer bytesWriter; + + TestQuicStreamWriter(SequentialScheduler scheduler, TestQuicSenderStream sender, + CompletableFuture> receiverFuture, + Consumer bytesWriter) { + super(scheduler); + this.sender = sender; + this.gotStreamType = false; + this.receiverFuture = receiverFuture; + this.bytesWriter = bytesWriter; + } + + private void write(ByteBuffer bb) { + if (bytesWriter == null) { + if (!gotStreamType) { + IntegerReader integerReader = new IntegerReader(); + integerReader.configure(8); + try { + integerReader.read(bb); + } catch (QPackException e) { + System.err.println("Can't read stream type byte"); + } + Http3Streams.StreamType type = Http3Streams.StreamType.ofCode((int) integerReader.get()).get(); + System.err.println("Stream opened with type=" + type); + gotStreamType = true; + associatedStreamType = type; + } else { + if (receiverFuture.isDone() && !receiverFuture.isCompletedExceptionally()) { + Consumer receiver = receiverFuture.getNow(null); + if (receiver != null) { + receiver.accept(bb); + } + } + } + } else { + bytesWriter.accept(bb); + } + } + + @Override + public QuicSenderStream.SendingStreamState sendingState() { + return null; + } + + @Override + public void scheduleForWriting(ByteBuffer buffer, boolean last) + throws IOException { + write(buffer); + } + + @Override + public void queueForWriting(ByteBuffer buffer) throws IOException { + write(buffer); + } + + @Override + public long credit() { + return Long.MAX_VALUE; + } + + @Override + public void reset(long errorCode) { + } + + @Override + public QuicSenderStream stream() { + return connected() ? sender : null; + } + + @Override + public boolean connected() { + return sender.writer == this; + } + } + + class TestQuicSenderStream implements QuicSenderStream { + private static AtomicLong ids = new AtomicLong(); + private final long id; + TestQuicStreamWriter writer; + Consumer bytesWriter; + final CompletableFuture> receiverFuture; + + TestQuicSenderStream(CompletableFuture> receiverFuture, + Consumer bytesWriter) { + id = ids.getAndIncrement() * 4 + type(); + this.receiverFuture = receiverFuture; + this.bytesWriter = bytesWriter; + } + + @Override + public SendingStreamState sendingState() { + return SendingStreamState.READY; + } + + @Override + public QuicStreamWriter connectWriter(SequentialScheduler scheduler) { + return writer == null ? writer = new TestQuicStreamWriter( + scheduler, this, receiverFuture, bytesWriter) : writer; + } + + @Override + public void disconnectWriter(QuicStreamWriter writer) { + } + + @Override + public void reset(long errorCode) { + } + + @Override + public long dataSent() { + return 0; + } + + @Override + public long streamId() { + return id; + } + + @Override + public StreamMode mode() { + return null; + } + + @Override + public boolean isClientInitiated() { + return true; + } + + @Override + public boolean isServerInitiated() { + return false; + } + + @Override + public boolean isBidirectional() { + return false; + } + + @Override + public boolean isLocalInitiated() { + return true; + } + + @Override + public boolean isRemoteInitiated() { + return false; + } + + @Override + public int type() { + return 0x02; + } + + @Override + public StreamState state() { + return SendingStreamState.READY; + } + + @Override + public long sndErrorCode() { + return -1; + } + + @Override + public boolean stopSendingReceived() { + return false; + } + } + + private static class TestErrorHandler + implements UniStreamPair.StreamErrorHandler { + final Consumer handler; + + private TestErrorHandler(Consumer handler) { + this.handler = handler; + } + + @Override + public void onError(QuicStream stream, UniStreamPair uniStreamPair, + Throwable throwable) { + handler.accept(throwable); + } + + public static TestErrorHandler of(Consumer handler) { + return new TestErrorHandler(handler); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/EncoderDecoderTest.java b/test/jdk/java/net/httpclient/qpack/EncoderDecoderTest.java new file mode 100644 index 00000000000..a1fa9745e05 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EncoderDecoderTest.java @@ -0,0 +1,442 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.hpack.QuickHuffman; +import jdk.internal.net.http.qpack.StaticTable; +import org.testng.annotations.Test; +import org.testng.annotations.DataProvider; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + +/* + * @test + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm EncoderDecoderTest + */ +public class EncoderDecoderTest { + private final Random random = new Random(); + + private static final int TEST_STR_MAX_LENGTH = 10; + + + private static void qpackErrorHandler(Throwable error, Http3Error http3Error) { + fail(http3Error + "QPACK error:" + http3Error, error); + } + + @DataProvider(name = "indexProvider") + public Object[][] indexProvider() { + AtomicLong index = new AtomicLong(); + return StaticTable.HTTP3_HEADER_FIELDS.stream() + .map(headerField -> List.of(index.getAndIncrement(), headerField)) + .map(List::toArray) + .toArray(Object[][]::new); + } + + @DataProvider(name = "nameReferenceProvider") + public Object[][] nameReferenceProvider() { + AtomicLong tableIndex = new AtomicLong(); + Map> map = new HashMap<>(); + for (var headerField : StaticTable.HTTP3_HEADER_FIELDS) { + var name = headerField.name(); + var index = tableIndex.getAndIncrement(); + + if (!map.containsKey(name)) + map.put(name, new ArrayList<>()); + map.get(name).add(index); + } + return map.entrySet().stream() + .map(e -> List.of(e.getKey(), randomString(), e.getValue())) + .map(List::toArray).toArray(Object[][]::new); + } + + @DataProvider(name = "literalProvider") + public Object[][] literalProvider() { + var output = new String[100][]; + for (int i = 0; i < 100; i++) { + output[i] = new String[]{randomString(), randomString()}; + } + return output; + } + + private void assertNotFailed(AtomicReference errorRef) { + var error = errorRef.get(); + if (error != null) throw new AssertionError(error); + } + + @Test(dataProvider = "indexProvider") + public void encodeDecodeIndexedOnStaticTable(long index, HeaderField h) throws IOException { + var actual = allocateIndexTestBuffer(index); + List buffers = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + AtomicReference writerStub = new AtomicReference<>(); + EncoderDecoderConnector connector = new EncoderDecoderConnector(writerStub::set, writerStub::set); + var conn = connector.newEncoderDecoderPair(e -> false, + EncoderDecoderTest::qpackErrorHandler, + EncoderDecoderTest::qpackErrorHandler, + error::set); + + // Create encoder and decoder + var encoder = conn.encoder(); + + // Set encoder maximum dynamic table capacity + conn.encoderTable().setMaxTableCapacity(256); + // Set dynamic table capacity that doesn't exceed the max capacity value + conn.encoderTable().setCapacity(256); + + var decoder = conn.decoder(); + + // Create header frame reader and writer + var callback = new TestingCallBack(index, h.name(), h.value()); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + var headerFrameWriter = encoder.newHeaderFrameWriter(); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext( + 0, 0, headerFrameWriter); + + // Configures encoder for writing the header name:value pair + encoder.header(context, h.name(), h.value(), false); + + // Write the header + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + buffers.add(actual); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode generated prefix bytes and encoded headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertNotFailed(error); + } + + @Test(dataProvider = "nameReferenceProvider") + public void encodeDecodeLiteralWithNameRefOnStaticTable(String name, String value, List validIndices) throws IOException { + long index = Collections.max(validIndices); + boolean sensitive = random.nextBoolean(); + + var actual = allocateNameRefBuffer(index, value); + List buffers = new ArrayList<>(); + AtomicReference error = new AtomicReference<>(); + AtomicReference writerStub = new AtomicReference<>(); + + // Create encoder and decoder + EncoderDecoderConnector connector = new EncoderDecoderConnector(writerStub::set, writerStub::set); + var conn = connector.newEncoderDecoderPair(e -> false, + EncoderDecoderTest::qpackErrorHandler, + EncoderDecoderTest::qpackErrorHandler, + error::set); + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Create header frame reader and writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var callback = new TestingCallBack(validIndices, name, value, sensitive); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext( + 0, 0, headerFrameWriter); + + // Configures encoder for writing the header name:value pair + encoder.header(context, name, value, sensitive); + + // Write the header + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + buffers.add(actual); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertNotFailed(error); + } + + @Test(dataProvider = "literalProvider") + public void encodeDecodeLiteralWithLiteralNameOnStaticTable(String name, String value) throws IOException { + boolean sensitive = random.nextBoolean(); + List buffers = new ArrayList<>(); + var actual = allocateLiteralBuffer(name, value); + AtomicReference error = new AtomicReference<>(); + AtomicReference writerStub = new AtomicReference<>(); + + // Create encoder and decoder + EncoderDecoderConnector connector = new EncoderDecoderConnector(writerStub::set, writerStub::set); + var conn = connector.newEncoderDecoderPair(e -> false, + EncoderDecoderTest::qpackErrorHandler, + EncoderDecoderTest::qpackErrorHandler, + error::set); + var encoder = conn.encoder(); + var decoder = conn.decoder(); + + // Create header frame reader and writer + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var callback = new TestingCallBack(name, value, sensitive); + var headerFrameReader = decoder.newHeaderFrameReader(callback); + + // create encoding context + Encoder.EncodingContext context = encoder.newEncodingContext( + 0, 0, headerFrameWriter); + + // Configures encoder for writing the header name:value conn + encoder.header(context, name, value, sensitive); + // Write the header + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + buffers.add(actual); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Decode headers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + assertNotFailed(error); + } + + /* Test Methods */ + private void debug(ByteBuffer bb, String msg, boolean verbose) { + if (verbose) + System.out.printf("DEBUG[%s]: pos=%d, limit=%d, remaining=%d%n", + msg, bb.position(), bb.limit(), bb.remaining()); + System.out.printf("DEBUG[%s]: ", msg); + for (byte b : bb.array()) { + System.out.printf("(%s,%d) ", Integer.toBinaryString(b & 0xFF), b & 0xFF); + } + System.out.println(); + } + + private ByteBuffer allocateIndexTestBuffer(long index) { + /* + * Note on Integer Representation used for storing the length of name and value strings. + * Taken from RFC 7541 Section 5.1 + * + * "An integer is represented in two parts: a prefix that fills the current octet and an + * optional list of octets that are used if the integer value does not fit within the + * prefix. The number of bits of the prefix (called N) is a parameter of the integer + * representation. If the integer value is small enough, i.e., strictly less than 2N-1, it + * is encoded within the N-bit prefix. + * + * ... + * + * Otherwise, all the bits of the prefix are set to 1, and the value, decreased by 2N-1, is + * encoded using a list of one or more octets. The most significant bit of each octet is + * used as a continuation flag: its value is set to 1 except for the last octet in the list. + * The remaining bits of the octets are used to encode the decreased value." + * + * Use "null" for name, if name isn't being provided (i.e. for a nameRef); otherwise, buffer + * will be too large. + * + */ + int N = 6; // bits available in first byte + int size = 1; + index -= Math.pow(2, N) - 1; // number that you can store in first N bits + while (index >= 0) { + index -= 127; + size++; + } + return ByteBuffer.allocate(size + 2); + } + + private ByteBuffer allocateNameRefBuffer(long index, CharSequence value) { + int N = 4; + return allocateNameRefBuffer(N, index, value); + } + + private ByteBuffer allocateNameRefBuffer(int N, long index, CharSequence value) { + int vlen = Math.min(QuickHuffman.lengthOf(value), value.length()); + int size = 1 + vlen; + + index -= Math.pow(2, N) - 1; + while (index >= 0) { + index -= 127; + size++; + } + vlen -= 127; + size++; + while (vlen >= 0) { + vlen -= 127; + size++; + } + return ByteBuffer.allocate(size + 2); + } + + private ByteBuffer allocateLiteralBuffer(CharSequence name, CharSequence value) { + int N = 3; + return allocateLiteralBuffer(N, name, value); + } + + private ByteBuffer allocateLiteralBuffer(int N, CharSequence name, CharSequence value) { + int nlen = Math.min(QuickHuffman.lengthOf(name), name.length()); + int vlen = Math.min(QuickHuffman.lengthOf(value), value.length()); + int size = nlen + vlen; + + nlen -= Math.pow(2, N) - 1; + size++; + while (nlen >= 0) { + nlen -= 127; + size++; + } + + vlen -= 127; + size++; + while (vlen >= 0) { + vlen -= 127; + size++; + } + return ByteBuffer.allocate(size + 2); + } + + static final String LOREM = """ + Lorem ipsum dolor sit amet, consectetur adipiscing + elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris + nisi ut aliquip ex ea commodo consequat.Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur.Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum."""; + + private String randomString() { + int lower = random.nextInt(LOREM.length() - TEST_STR_MAX_LENGTH); + /** + * The empty string ("") is a valid value String in the static table and the random + * String returned cannot refer to an entry in the table. Therefore, we set the upper + * bound below to a minimum of 1. + */ + return LOREM.substring(lower, 1 + lower + random.nextInt(TEST_STR_MAX_LENGTH)); + } + + private static class TestingCallBack implements DecodingCallback { + final long index; + final boolean huffmanName, huffmanValue; + final boolean sensitive; + final String name, value; + final List validIndices; + + // Indexed + TestingCallBack(long index, String name, String value) { + this(index, null, name, value, false); + } + // Literal w/Literal Name + TestingCallBack(String name, String value, boolean sensitive) { + this(-1L, null, name, value, sensitive); + } + // Literal w/Name Reference + TestingCallBack(List validIndices, String name, String value, boolean sensitive) { + this(-1L, validIndices, name, value, sensitive); + } + TestingCallBack(long index, List validIndices, String name, String value, boolean sensitive) { + this.index = index; + this.validIndices = validIndices; + this.huffmanName = QuickHuffman.isHuffmanBetterFor(name); + this.huffmanValue = QuickHuffman.isHuffmanBetterFor(value); + this.sensitive = sensitive; + this.name = name; + this.value = value; + } + + @Override + public void onDecoded(CharSequence actualName, CharSequence value) { + fail("onDecoded should not be called"); + } + + @Override + public void onComplete() { + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + fail("Decoding error: " + http3Error, throwable); + } + + @Override + public long streamId() { + return 0; + } + + @Override + public void onIndexed(long actualIndex, CharSequence actualName, CharSequence actualValue) { + assertEquals(actualIndex, index); + assertEquals(actualName, name); + assertEquals(actualValue, value); + } + + @Override + public void onLiteralWithNameReference(long actualIndex, CharSequence actualName, + CharSequence actualValue, boolean huffmanValue, + boolean actualHideIntermediary) { + assertTrue(validIndices.contains(actualIndex)); + assertEquals(actualName.toString(), name); + assertEquals(actualValue.toString(), value); + assertEquals(huffmanValue, huffmanValue); + assertEquals(actualHideIntermediary, sensitive); + } + + @Override + public void onLiteralWithLiteralName(CharSequence actualName, boolean actualHuffmanName, + CharSequence actualValue, boolean actualHuffmanValue, + boolean actualHideIntermediary) { + assertEquals(actualName.toString(), name); + assertEquals(actualHuffmanName, huffmanName); + assertEquals(actualValue.toString(), value); + assertEquals(actualHuffmanValue, huffmanValue); + assertEquals(actualHideIntermediary, sensitive); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/EncoderInstructionsReaderTest.java b/test/jdk/java/net/httpclient/qpack/EncoderInstructionsReaderTest.java new file mode 100644 index 00000000000..5c31762cdb3 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EncoderInstructionsReaderTest.java @@ -0,0 +1,373 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @key randomness + * @library /test/lib + * @run junit/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL EncoderInstructionsReaderTest + */ + +import jdk.internal.net.http.hpack.QuickHuffman; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import java.nio.ByteBuffer; +import jdk.internal.net.http.qpack.writers.IntegerWriter; +import jdk.internal.net.http.qpack.writers.StringWriter; +import jdk.test.lib.RandomFactory; +import org.junit.jupiter.api.RepeatedTest; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jdk.internal.net.http.qpack.readers.EncoderInstructionsReader; +import jdk.internal.net.http.qpack.readers.EncoderInstructionsReader.Callback; + +import java.util.*; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.IntStream; + +public class EncoderInstructionsReaderTest { + + EncoderInstructionsReader encoderInstructionsReader; + private static final Random RANDOM = RandomFactory.getRandom(); + + @RepeatedTest(5) + public void testCapacity() { + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 1 | Capacity (5+) | + // +---+---+---+-------------------+ + + // create logger and callback + QPACK.Logger logger = QPACK.getLogger().subLogger("testCapacity"); + TestCallback callback = new TestCallback(); + + // create a random value to be assigned as capacity + Long expectedCapacity = RANDOM.nextLong(IntegerReader.QPACK_MAX_INTEGER_VALUE); + + // create integerWriter, set expected size for the bytebuffer and write to it + IntegerWriter integerWriter = new IntegerWriter(); + int bufferSize = requiredBufferSize(5, expectedCapacity); + ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize); + int payload = 0b0010_0000; + integerWriter.configure(expectedCapacity, 5, payload); + boolean result = integerWriter.write(byteBuffer); + + // assert that the writer finished and isn't expecting another bytebuffer + assert result; + + byteBuffer.flip(); + + // use EncoderInstructionReader and check it successfully reads the input + encoderInstructionsReader = new EncoderInstructionsReader(callback, logger); + encoderInstructionsReader.read(byteBuffer, -1); + + long actualCapacity = callback.capacity.get(); + assertEquals(expectedCapacity, actualCapacity, "expected capacity differed from actual result"); + } + + @RepeatedTest(10) + public void testInsertWithNameReference() { + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 1 | T | Name Index (6+) | + // +---+---+-----------------------+ + // | H | Value Length (7+) | + // +---+---------------------------+ + // | Value String (Length bytes) | + // +-------------------------------+ + + QPACK.Logger logger = QPACK.getLogger() + .subLogger("testInsertWithNameReference"); + TestCallback callback = new TestCallback(); + + // Needs both an integer and String writer + IntegerWriter integerWriter = new IntegerWriter(); + StringWriter stringWriter = new StringWriter(); + boolean huffman = RANDOM.nextBoolean(); + boolean staticTable = RANDOM.nextBoolean(); + + int payload; + if (staticTable) { + payload = 0b1100_0000; + } else { + payload = 0b1000_0000; + } + + // get a random member of the dynamic table and create a random string to update it with + long index = RANDOM.nextLong(IntegerReader.QPACK_MAX_INTEGER_VALUE); + String value = randomString(); + + // calculate the size of the byteBuffer + int bufferSize = requiredBufferSize(6, index); + bufferSize += requiredBufferSize(7, value.length()); + + if (huffman) { + bufferSize += QuickHuffman.lengthOf(value); + } else { + bufferSize += value.length(); + } + + integerWriter.configure(index, 6, payload); + stringWriter.configure(value, huffman); + + boolean intWriterFinished = false; + boolean stringWriterFinished = false; + List byteBufferList = new ArrayList<>(); + + // Feed the writers with bytebuffers of random size until they total bufferSize, + // once each bytebuffer is full add it to byteBufferList to be read later + while (!stringWriterFinished) { + int randomSize = RANDOM.nextInt(0, bufferSize + 1); + bufferSize -= randomSize; + ByteBuffer byteBuffer = ByteBuffer.allocate(randomSize); + + if (!intWriterFinished) { + intWriterFinished = integerWriter.write(byteBuffer); + if (!intWriterFinished) { + // writer not finished, add bytebuffer to list of full bytebuffers + // then loop + byteBuffer.flip(); + byteBufferList.add(byteBuffer); + continue; + } + } + + // this stage should only be reached if the intWriter is finished + stringWriterFinished = stringWriter.write(byteBuffer); + byteBuffer.flip(); + byteBufferList.add(byteBuffer); + } + + encoderInstructionsReader = new EncoderInstructionsReader(callback, logger); + for (var byteBuffer : byteBufferList) { + encoderInstructionsReader.read(byteBuffer, -1); + } + + assertEquals(index, callback.indexInsert.nameIndex); + assertEquals(value, callback.indexInsert.value); + assertEquals(staticTable, callback.indexInsert.staticTable); + } + + @RepeatedTest(10) + public void testInsertWithLiteralName() { + + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 1 | H | Name Length (5+) | + // +---+---+---+-------------------+ + // | Name String (Length bytes) | + // +---+---------------------------+ + // | H | Value Length (7+) | + // +---+---------------------------+ + // | Value String (Length bytes) | + // +-------------------------------+ + + QPACK.Logger logger = QPACK.getLogger() + .subLogger("testInsertWithLiteralName"); + TestCallback callback = new TestCallback(); + + StringWriter stringWriter = new StringWriter(); + boolean huffman = RANDOM.nextBoolean(); + + int payload; + if (huffman) { + payload = 0b0110_0000; // static table = true + } else { + payload = 0b0100_0000; // static table = false + } + + //generate name and value then calculate the size required for the bytebuffer + String name = randomString(); + String value = randomString(); + + int bufferSize = requiredBufferSize(5, name.length()); + bufferSize += requiredBufferSize(7, value.length()); + bufferSize += value.length() + name.length(); + + List byteBuffers = new ArrayList<>(); + + // configure the stringWriter to take name + stringWriter.configure(name, 5, payload, huffman); + + boolean firstStringWriterFinished = false; + boolean secondStringWriterFinished = false; + + // Feed the writers with bytebuffers of random size until they total bufferSize, + // once each bytebuffer is full add it to byteBufferList to be read later + while (!secondStringWriterFinished) { + int randomSize = RANDOM.nextInt(0, bufferSize + 1); + bufferSize -= randomSize; + + ByteBuffer byteBuffer = ByteBuffer.allocate(randomSize); + + if (!firstStringWriterFinished) { + firstStringWriterFinished = stringWriter.write(byteBuffer); + + if (!firstStringWriterFinished) { + // writer not finished, add bytebuffer to array of full bytebuffers + // then loop + byteBuffer.flip(); + byteBuffers.add(byteBuffer); + continue; + } else { + // if the name has been written then reset the stringWriter + // so that it can be reused for value + stringWriter.reset(); + stringWriter.configure(value, huffman); + } + } + + // this stage should only be reached if the name has already been written + secondStringWriterFinished = stringWriter.write(byteBuffer); + byteBuffer.flip(); + byteBuffers.add(byteBuffer); + } + + System.err.println(name + " Attempting to insert value: " + value); + + encoderInstructionsReader = new EncoderInstructionsReader(callback, logger); + + for (var byteBuffer : byteBuffers) { + encoderInstructionsReader.read(byteBuffer, -1); + } + + assertEquals(name, callback.lastInsert.name); + assertEquals(value, callback.lastInsert.value); + } + + @RepeatedTest(5) + public void testDuplicate() { + // + // 0 1 2 3 4 5 6 7 + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | Index (5+) | + // +---+---+---+-------------------+ + // + + QPACK.Logger logger = QPACK.getLogger() + .subLogger("testDuplicate"); + TestCallback callback = new TestCallback(); + + long index = RANDOM.nextLong(0, DT_NAMES.size()); + + IntegerWriter integerWriter = new IntegerWriter(); + int bufferSize = requiredBufferSize(5, index); + ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize); + int payload = 0b0000_0000; + integerWriter.configure(index, 5, payload); + boolean result = integerWriter.write(byteBuffer); + assert result; + + byteBuffer.flip(); + + encoderInstructionsReader = new EncoderInstructionsReader(callback, logger); + encoderInstructionsReader.read(byteBuffer, -1); + + assertEquals(index, callback.duplicate.get()); + } + + private static void checkPrefix(int N) { + if (N < 1 || N > 8) { + throw new IllegalArgumentException("1 <= N <= 8: N= " + N); + } + } + static int requiredBufferSize(int N, long value) { + checkPrefix(N); + int size = 1; + int max = (2 << (N - 1)) - 1; + if (value < max) { + return size; + } + size++; + value -= max; + while (value >= 128) { + value /= 128; + size++; + } + return size; + } + + private static final int TEST_STR_MAX_LENGTH = 20; + private static final String LOREM = """ + Lorem ipsum dolor sit amet, consectetur adipiscing + elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris + nisi ut aliquip ex ea commodo consequat.Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur.Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum.""" + .replaceAll(" ", "") + .replaceAll("\\W", ""); + + private static final List DT_NAMES = generateTableNames(40); + + private static List generateTableNames(int count) { + return IntStream.range(0, count) + .boxed() + .map(i -> randomString()) + .toList(); + } + + private static String randomString() { + int lower = RANDOM.nextInt(LOREM.length() - TEST_STR_MAX_LENGTH); + return LOREM.substring(lower, 1 + lower + RANDOM.nextInt(TEST_STR_MAX_LENGTH)); + } + + private static class TestCallback implements Callback { + + record LiteralInsert(String name, String value) { + } + + record IndexedInsert(boolean staticTable, Long nameIndex, String value) { + } + + final AtomicLong capacity = new AtomicLong(-1L); + final AtomicLong duplicate = new AtomicLong(-1L); + LiteralInsert lastInsert; + + IndexedInsert indexInsert; + + @Override + public void onCapacityUpdate(long capacity) { + this.capacity.set(capacity); + } + + @Override + public void onInsert(String name, String value) { + lastInsert = new LiteralInsert(name, value); + } + + @Override + public void onInsertIndexedName(boolean indexInStaticTable, long nameIndex, String valueString) { + indexInsert = new IndexedInsert(indexInStaticTable, nameIndex, valueString); + } + + @Override + public void onDuplicate(long duplicateValue) { + this.duplicate.set(duplicateValue); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/EncoderInstructionsWriterTest.java b/test/jdk/java/net/httpclient/qpack/EncoderInstructionsWriterTest.java new file mode 100644 index 00000000000..fd1fc41d8f7 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EncoderInstructionsWriterTest.java @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.StaticTable; +import jdk.internal.net.http.qpack.TableEntry; +import jdk.internal.net.http.qpack.readers.EncoderInstructionsReader; +import jdk.internal.net.http.qpack.readers.EncoderInstructionsReader.Callback; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import jdk.internal.net.http.qpack.writers.EncoderInstructionsWriter; +import jdk.test.lib.RandomFactory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import org.junit.Assert; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static jdk.internal.net.http.qpack.TableEntry.EntryType.NAME; +import static org.junit.jupiter.api.Assertions.*; + +/* + * @test + * @key randomness + * @library /test/lib + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @run junit/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL EncoderInstructionsWriterTest + */ + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class EncoderInstructionsWriterTest { + + @RepeatedTest(10) + public void tableCapacityInstructionTest() throws Exception { + // Get test-case specific logger and a dynamic table instance + QPACK.Logger logger = QPACK.getLogger() + .subLogger("tableCapacityInstructionTest"); + var dynamicTable = new DynamicTable(logger); + // Generate random capacity value + long capacity = RANDOM.nextLong(IntegerReader.QPACK_MAX_INTEGER_VALUE); + logger.log(System.Logger.Level.TRACE, "Capacity value = " + capacity); + + // Initial dynamic table capacity - required to check for + // writer changing the capacity once the instruction is written. + long initialTableCapacity = capacity == 1234L ? 1235L : 1234L; + + // Create and configure encoder instruction writer for writing the table + // capacity update instruction + var encoderInstructionsWriter = new EncoderInstructionsWriter(); + int calculatedInstructionSize = + encoderInstructionsWriter.configureForTableCapacityUpdate(capacity); + + // Set max capacity to maximum possible value + dynamicTable.setMaxTableCapacity(IntegerReader.QPACK_MAX_INTEGER_VALUE); + + // Create dynamic table with initial capacity + dynamicTable.setCapacity(initialTableCapacity); + + // Perform write operation and then read the capacity update instruction + var callback = new TestEncoderInstructionsCallback(dynamicTable); + + int bytesWritten = writeThenReadInstruction(encoderInstructionsWriter, + callback, -1, dynamicTable, + (dt) -> dt.capacity() == initialTableCapacity, + logger); + + // We expect here to get a callback with the capacity value supplied + // to the instruction writer + assertEquals(capacity, callback.capacityFromCallback.get()); + + // We don't expect dynamic table capacity to be updated by the encoder + // instruction reader + assertNotEquals(capacity, dynamicTable.capacity()); + + // Check if size calculated by the EncoderInstructionsWriter matches + // the number of bytes written to the byte buffers + assertEquals(calculatedInstructionSize, bytesWritten); + } + + @ParameterizedTest + @MethodSource("nameReferenceInsertSource") + public void insertWithNameReferenceInstructionTest(boolean referencingStatic, + long nameIndex, int byteBufferSize) + throws Exception { + // Get test-case specific logger and a dynamic table instance + QPACK.Logger logger = QPACK.getLogger() + .subLogger("insertWithNameReferenceInstructionTest"); + var dynamicTable = dynamicTable(logger); + + // generate random value String + String value = randomString(); + TableEntry entry = new TableEntry(referencingStatic, nameIndex, "", value, NAME); + + // Create and configure encoder instruction writer for writing + // the "Insert With Name Reference" instruction + var writer = new EncoderInstructionsWriter(); + int calculatedInstructionSize = writer.configureForEntryInsertion(entry); + + // Perform write operation and then read back the insert entry instruction + var callback = new TestEncoderInstructionsCallback(dynamicTable); + int bytesWritten = writeThenReadInstruction(writer, callback, byteBufferSize, + dynamicTable, (dt) -> true, logger); + + // Check that reader callback values match values supplied to the writer + assertEquals(nameIndex, callback.lastNameInsert.index()); + assertEquals(value, callback.lastNameInsert.value()); + assertEquals(referencingStatic, callback.lastNameInsert.isStaticTable()); + + // Check if size calculated by the EncoderInstructionsWriter matches + // the number of bytes written to the byte buffers + assertEquals(calculatedInstructionSize, bytesWritten); + } + + private static Stream nameReferenceInsertSource() { + Stream staticTableCases = + RANDOM.longs(10, 0, + StaticTable.HTTP3_HEADER_FIELDS.size()) + .boxed() + .map(index -> Arguments.of(true, index, + RANDOM.nextInt(1, 65))); + Stream dynamicTableCases = + RANDOM.longs(10, 0, + DT_NAMES.size()) + .boxed() + .map(index -> Arguments.of(false, index, + RANDOM.nextInt(1, 65))); + return Stream.concat(staticTableCases, dynamicTableCases); + } + + @RepeatedTest(10) + public void insertWithLiteralInstructionTest() throws Exception { + // Get test-case specific logger and a dynamic table instance + QPACK.Logger logger = QPACK.getLogger() + .subLogger("insertWithLiteralInstructionTest"); + var dynamicTable = dynamicTable(logger); + + // Generate random strings for name:value entry + String name = randomString(); + String value = randomString(); + var tableEntry = new TableEntry(name, value); + // Create and configure encoder instruction writer for writing the "Insert With Literal Name" + // instruction + var writer = new EncoderInstructionsWriter(); + int calculatedInstructionSize = writer.configureForEntryInsertion(tableEntry); + + var callback = new TestEncoderInstructionsCallback(dynamicTable); + int writtenBytes = writeThenReadInstruction(writer, callback, -1, + dynamicTable, (dt) -> true, logger); + + // Check that reader callback values match values supplied to the writer + assertEquals(name, callback.lastLiteralInsert.name()); + assertEquals(value, callback.lastLiteralInsert.value()); + // Check if size calculated by the EncoderInstructionsWriter matches the number of + // bytes written to the byte buffers + assertEquals(calculatedInstructionSize, writtenBytes); + } + + @RepeatedTest(10) + public void duplicateInstructionTest() throws Exception { + // Get test-case specific logger and a dynamic table instance + QPACK.Logger logger = QPACK.getLogger() + .subLogger("duplicateInstructionTest"); + var dynamicTable = dynamicTable(logger); + // Absolute id to duplicate + // size() - 1 is used to not duplicate the last entry, otherwise the + // DynamicTable.duplicate(relativeIndex) checks below will fail + long idToDuplicate = RANDOM.nextLong(0, DT_NAMES.size() - 1); + // Create and configure encoder instruction writer for writing the "Duplicate" + // instruction + var writer = new EncoderInstructionsWriter(); + // insert count - 1 is the head element index + long relativeIndex = dynamicTable.insertCount() - 1 - idToDuplicate; + int calculatedInstructionSize = writer.configureForEntryDuplication(relativeIndex); + var callback = new TestEncoderInstructionsCallback(dynamicTable); + int writtenBytes = writeThenReadInstruction(writer, callback, -1, + dynamicTable, (dt) -> true, logger); + HeaderField original = dynamicTable.get(idToDuplicate); + HeaderField head = dynamicTable.getRelative(0); + + // Check that reader callback values match values supplied to the writer + assertEquals(relativeIndex, callback.duplicateIdFromCallback.get()); + // Check that DynamicTable.duplicate(relativeIndex) properly recreates + // the referenced entry + assertEquals(head.name(), original.name()); + assertEquals(head.value(), original.value()); + // Check if size calculated by the EncoderInstructionsWriter matches the number + // of bytes written to the byte buffers + assertEquals(calculatedInstructionSize, writtenBytes); + + } + + // Test runner method that writes an instruction with the supplied encoder + // instructions writer pre-configured for writing of a specific instruction. + private static int writeThenReadInstruction( + EncoderInstructionsWriter writer, TestEncoderInstructionsCallback callback, + int bufferSize, DynamicTable dynamicTable, + Function partialWriteCheck, + QPACK.Logger logger) throws Exception { + var buffers = new ArrayList(); + boolean writeDone = false; + int writtenBytes = 0; + // Write instruction to the list of byte buffers + while (!writeDone) { + int allocSize = bufferSize == -1 ? RANDOM.nextInt(1, 65) : bufferSize; + var buffer = ByteBuffer.allocate(allocSize); + writeDone = writer.write(buffer); + writtenBytes += buffer.position(); + if (!writeDone && !partialWriteCheck.apply(dynamicTable)) { + Assert.fail("Wrong dynamic table state after partial write"); + } + buffer.flip(); + buffers.add(buffer); + } + + // Read back the data from byte buffers + var encoderInstructionReader = new EncoderInstructionsReader(callback, logger); + + // Read out an instruction and return the callback instance + for (var bb : buffers) { + encoderInstructionReader.read(bb, -1); + } + return writtenBytes; + } + + private static final Random RANDOM = RandomFactory.getRandom(); + private static final int TEST_STR_MAX_LENGTH = 20; + private static final String LOREM = """ + Lorem ipsum dolor sit amet, consectetur adipiscing + elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris + nisi ut aliquip ex ea commodo consequat.Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur.Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum.""" + .replaceAll(" ", "") + .replaceAll("\\W", ""); + + private static final List DT_NAMES = generateTableNames(40); + + private static List generateTableNames(int count) { + return IntStream.range(0, count) + .boxed() + .map(i -> randomString()) + .toList(); + } + + private static DynamicTable dynamicTable(QPACK.Logger logger) { + var dt = new DynamicTable(logger.subLogger("dynamicTable")); + dt.setMaxTableCapacity(4096); + dt.setCapacity(4096); + for (var name : DT_NAMES) { + dt.insert(name, randomString()); + } + return dt; + } + + private static String randomString() { + int lower = RANDOM.nextInt(LOREM.length() - TEST_STR_MAX_LENGTH); + return LOREM.substring(lower, 1 + lower + RANDOM.nextInt(TEST_STR_MAX_LENGTH)); + } + + private static class TestEncoderInstructionsCallback implements Callback { + + final DynamicTable dynamicTable; + public TestEncoderInstructionsCallback(DynamicTable dynamicTable) { + this.dynamicTable = dynamicTable; + } + + record LiteralInsert(String name, String value) { + } + + record IndexedNameInsert(boolean isStaticTable, long index, String value) { + } + + final AtomicLong capacityFromCallback = new AtomicLong(-1L); + final AtomicLong duplicateIdFromCallback = new AtomicLong(-1L); + LiteralInsert lastLiteralInsert; + IndexedNameInsert lastNameInsert; + + @Override + public void onCapacityUpdate(long capacity) { + capacityFromCallback.set(capacity); + } + + @Override + public void onInsert(String name, String value) { + lastLiteralInsert = new LiteralInsert(name, value); + } + + @Override + public void onInsertIndexedName(boolean indexInStaticTable, long nameIndex, String valueString) { + lastNameInsert = new IndexedNameInsert(indexInStaticTable, nameIndex, valueString); + } + + @Override + public void onDuplicate(long l) { + dynamicTable.duplicate(l); + duplicateIdFromCallback.set(l); + } + } +} diff --git a/test/jdk/java/net/httpclient/qpack/EncoderTest.java b/test/jdk/java/net/httpclient/qpack/EncoderTest.java new file mode 100644 index 00000000000..a80a0899420 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EncoderTest.java @@ -0,0 +1,670 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.hpack.QuickHuffman; +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.http3.streams.Http3Streams.StreamType; +import jdk.internal.net.http.http3.streams.QueuingStreamPair; +import jdk.internal.net.http.http3.streams.UniStreamPair; +import jdk.internal.net.http.qpack.Encoder; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.writers.IntegerWriter; +import jdk.internal.net.http.qpack.writers.StringWriter; +import jdk.internal.net.http.qpack.StaticTable; +import jdk.internal.net.http.quic.ConnectionTerminator; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.http.quic.QuicEndpoint; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.http.quic.streams.QuicBidiStream; +import jdk.internal.net.http.quic.streams.QuicReceiverStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream; +import jdk.internal.net.http.quic.streams.QuicSenderStream.SendingStreamState; +import jdk.internal.net.http.quic.streams.QuicStream; +import jdk.internal.net.http.quic.streams.QuicStreamWriter; +import jdk.internal.net.quic.QuicTLSEngine; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import static org.testng.Assert.*; + +/* + * @test + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @run testng/othervm EncoderTest + */ +public class EncoderTest { + private final Random random = new Random(); + private final IntegerWriter intWriter = new IntegerWriter(); + private final StringWriter stringWriter = new StringWriter(); + private static final int TEST_STR_MAX_LENGTH = 10; + + @DataProvider(name = "indexProvider") + public Object[][] indexProvider() { + AtomicInteger tableIndex = new AtomicInteger(); + return StaticTable.HTTP3_HEADER_FIELDS.stream() + .map(headerField -> List.of(tableIndex.getAndIncrement(), headerField)) + .map(List::toArray) + .toArray(Object[][]::new); + } + + @DataProvider + public Object[][] staticNameReferenceProvider() { + AtomicInteger tableIndex = new AtomicInteger(); + Map> map = new HashMap<>(); + for (var headerField : StaticTable.HTTP3_HEADER_FIELDS) { + var name = headerField.name(); + var index = tableIndex.getAndIncrement(); + if (!map.containsKey(name)) + map.put(name, new ArrayList<>()); + map.get(name).add(index); + } + return map.entrySet().stream() + .map(e -> List.of(e.getKey(), randomString(), e.getValue())) + .map(List::toArray).toArray(Object[][]::new); + } + + @DataProvider + public Object[][] literalsProvider() { + var output = new String[100][]; + for (int i = 0; i < 100; i++) { + output[i] = new String[]{ randomString(), randomString() }; + } + return output; + } + + private static class TestErrorHandler implements + UniStreamPair.StreamErrorHandler { + final Consumer handler; + private TestErrorHandler(Consumer handler) { + this.handler = handler; + } + @Override + public void onError(QuicStream stream, UniStreamPair uniStreamPair, Throwable throwable) { + handler.accept(throwable); + } + public static TestErrorHandler of(Consumer handler) { + return new TestErrorHandler(handler); + } + } + + QueuingStreamPair createEncoderStreams(Consumer receiver, + Consumer errorHandler, + TestQuicConnection quicConnection) { + return new QueuingStreamPair(StreamType.QPACK_ENCODER, + quicConnection, + receiver, + TestErrorHandler.of(errorHandler), + Utils.getDebugLogger(() -> "quic-encoder-test") + ); + } + + private void assertNotFailed(AtomicReference errorRef) { + var error = errorRef.get(); + if (error != null) throw new AssertionError(error); + } + + private static void qpackErrorHandler(Throwable error, Http3Error http3Error) { + fail(http3Error + "QPACK error:" + http3Error, error); + } + + @Test(dataProvider = "indexProvider") + public void testFieldLineWriterWithStaticIndex(int index, HeaderField h) { + var actual = allocateIndexBuffer(index); + var expected = writeIndex(index); + var quicConnection = new TestQuicConnection(); + AtomicReference error = new AtomicReference<>(); + var encoder = new Encoder(insert -> false, + (receiver) -> createEncoderStreams(receiver, error::set, quicConnection), + EncoderTest::qpackErrorHandler); + var headerFrameWriter = encoder.newHeaderFrameWriter(); + // create encoding context + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + + encoder.header(context, h.name(), h.value(), false); + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + + assertEquals(actual, expected, debug(h.name(), h.value(), actual, expected)); + assertNotFailed(error); + } + + @Test(dataProvider = "staticNameReferenceProvider") + public void testInsertWithStaticTableNameReference(String name, String value, List validIndices) { + int index = Collections.max(validIndices); + + var actual = allocateInsertNameRefBuffer(index, value); + var expected = writeInsertNameRef(value, validIndices); + var quicConnection = new TestQuicConnection(); + AtomicReference error = new AtomicReference<>(); + var encoder = new Encoder(insert -> true, + (receiver) -> createEncoderStreams(receiver, error::set, quicConnection), + EncoderTest::qpackErrorHandler); + configureDynamicTableSize(encoder); + + var headerFrameWriter = encoder.newHeaderFrameWriter(); + // create encoding context + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + encoder.header(context, name, value, false); + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + + TestQuicStreamWriter quicStreamWriter = quicConnection.sender.writer; + + assertTrue(expected.contains(quicStreamWriter.get()), debug(name, value, quicStreamWriter.get(), expected)); + assertNotFailed(error); + } + + @Test(dataProvider = "staticNameReferenceProvider") + public void testFieldLineWithStaticTableNameReference(String name, String value, List validIndices) { + int index = Collections.max(validIndices); + boolean sensitive = random.nextBoolean(); + + var actual = allocateNameRefBuffer(index, value); + var expected = writeNameRef(sensitive, value, validIndices); + var quicConnection = new TestQuicConnection(); + AtomicReference error = new AtomicReference<>(); + var encoder = new Encoder(insert -> false, (receiver) -> + createEncoderStreams(receiver, error::set, quicConnection), + EncoderTest::qpackErrorHandler); + + var headerFrameWriter = encoder.newHeaderFrameWriter(); + // create encoding context + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + encoder.header(context, name, value, sensitive); + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + + assertTrue(expected.contains(actual), debug(name, value, actual, expected)); + assertNotFailed(error); + } + + @Test(dataProvider = "literalsProvider") + public void testInsertWithLiterals(String name, String value) { + var expected = writeInsertLiteral(name, value); + var actual = allocateInsertLiteralBuffer(name, value); + var quicConnection = new TestQuicConnection(); + AtomicReference error = new AtomicReference<>(); + var encoder = new Encoder(insert -> true, (receiver) -> + createEncoderStreams(receiver, error::set, quicConnection), + EncoderTest::qpackErrorHandler); + configureDynamicTableSize(encoder); + + var headerFrameWriter = encoder.newHeaderFrameWriter(); + // create encoding context + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + encoder.header(context, name, value, false); + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + TestQuicStreamWriter quicStreamWriter = quicConnection.sender.writer; + assertEquals(quicStreamWriter.get(), expected, debug(name, value, quicStreamWriter.get(), expected)); + assertNotFailed(error); + } + + @Test(dataProvider = "literalsProvider") + public void testFieldLineEncodingWithLiterals(String name, String value) { + boolean sensitive = random.nextBoolean(); + + var expected = writeLiteral(sensitive, name, value); + var actual = allocateLiteralBuffer(name, value); + var quicConnection = new TestQuicConnection(); + AtomicReference error = new AtomicReference<>(); + var encoder = new Encoder(insert -> false, (receiver) -> + createEncoderStreams(receiver, error::set, quicConnection), + EncoderTest::qpackErrorHandler); + + var headerFrameWriter = encoder.newHeaderFrameWriter(); + // create encoding context + Encoder.EncodingContext context = + encoder.newEncodingContext(0, 0, headerFrameWriter); + encoder.header(context, name, value, sensitive); + headerFrameWriter.write(actual); + assertNotEquals(actual.position(), 0); + actual.flip(); + + assertEquals(actual, expected, debug(name, value, actual, expected)); + assertNotFailed(error); + } + + // Test cases which test insertion of entries to the dynamic need to have + // dynamic table with non-zero capacity + private static void configureDynamicTableSize(Encoder encoder) { + // Set encoder maximum dynamic table capacity + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, 256); + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + encoder.configure(settings); + // Set dynamic table capacity that doesn't exceed the max capacity value + encoder.setTableCapacity(256); + } + + /* Test Methods */ + private class TestQuicStreamWriter extends QuicStreamWriter { + volatile ByteBuffer b = null; + final TestQuicSenderStream sender; + + TestQuicStreamWriter(SequentialScheduler scheduler, TestQuicSenderStream sender) { + super(scheduler); + this.sender = sender; + } + + private void write(ByteBuffer bb) { + b = bb; + } + public ByteBuffer get() { + if (b == null) { + fail("TestQuicStreamWriter buffer is null"); + } + return b; + } + @Override + public SendingStreamState sendingState() { return null;} + @Override + public void scheduleForWriting(ByteBuffer buffer, boolean last) throws IOException { + write(buffer); + } + @Override + public void queueForWriting(ByteBuffer buffer) throws IOException { + write(buffer); + } + @Override + public long credit() { return Long.MAX_VALUE;} + @Override + public void reset(long errorCode) {} + @Override + public QuicSenderStream stream() { + return connected() ? sender : null; + } + @Override + public boolean connected() { + return sender.writer == this; + } + } + + private class TestQuicConnection extends QuicConnection { + final TestQuicSenderStream sender = new TestQuicSenderStream(); + @Override + public boolean isOpen() {return true;} + @Override + public TerminationCause terminationCause() {return null;} + @Override + public QuicTLSEngine getTLSEngine() {return null;} + @Override + public InetSocketAddress peerAddress() {return null;} + @Override + public SocketAddress localAddress() {return null;} + @Override + public CompletableFuture startHandshake() {return null;} + @Override + public CompletableFuture openNewLocalBidiStream(Duration duration) { + return null; + } + @Override + public CompletableFuture openNewLocalUniStream(Duration duration) { + return MinimalFuture.completedFuture(sender); + } + @Override + public void addRemoteStreamListener(Predicate streamConsumer) { + } + @Override + public boolean removeRemoteStreamListener(Predicate streamConsumer) { + return false; + } + @Override + public Stream quicStreams() { + return null; + } + @Override + public CompletableFuture handshakeReachedPeer() { + return MinimalFuture.completedFuture(null); + } + @Override + public CompletableFuture requestSendPing() { + return MinimalFuture.completedFuture(-1L); + } + + @Override + public ConnectionTerminator connectionTerminator() { + return null; + } + + @Override + public String dbgTag() { return null; } + + @Override + public String logTag() { + return null; + } + } + + class TestQuicSenderStream implements QuicSenderStream { + private static AtomicLong ids = new AtomicLong(); + private final long id; + TestQuicStreamWriter writer; + TestQuicSenderStream() { + id = ids.getAndIncrement() * 4 + type(); + } + @Override + public SendingStreamState sendingState() { return SendingStreamState.READY; } + @Override + public QuicStreamWriter connectWriter(SequentialScheduler scheduler) { + return writer = new TestQuicStreamWriter(scheduler, this); + } + @Override + public void disconnectWriter(QuicStreamWriter writer) { } + @Override + public void reset(long errorCode) { } + @Override + public long dataSent() { return 0; } + @Override + public long streamId() { return id; } + @Override + public StreamMode mode() { return null; } + @Override + public boolean isClientInitiated() { return true; } + @Override + public boolean isServerInitiated() { return false; } + @Override + public boolean isBidirectional() { return false; } + @Override + public boolean isLocalInitiated() { return true; } + @Override + public boolean isRemoteInitiated() { return false; } + @Override + public int type() { return 0x02; } + @Override + public StreamState state() { return SendingStreamState.READY; } + + @Override + public long sndErrorCode() { return -1; } + @Override + public boolean stopSendingReceived() { return false; } + } + + private String debug(String name, String value, ByteBuffer actual, ByteBuffer expected) { + return debug(name, value, actual, List.of(expected)); + } + + private String debug(String name, String value, ByteBuffer actual, List expected) { + var output = new StringBuilder(); + output.append("\n\nBUFFER CONTENTS\n"); + output.append("----------------\n"); + output.append("DEBUG[NAME]: %s\nDEBUG[VALUE]: %s\n".formatted(name, value)); + output.append("DEBUG[ACTUAL]: "); + for (byte b : actual.array()) { + output.append("(%s,%d) ".formatted(Integer.toBinaryString(b & 0xFF), (int)(b & 0xFF))); + } + output.append("\n"); + + output.append("DEBUG[EXPECTED]: "); + for (var bb : expected) { + for (byte b : bb.array()) { + output.append("(%s,%d) ".formatted(Integer.toBinaryString(b & 0xFF), (int) (b & 0xFF))); + } + output.append("\n"); + } + return output.toString(); + } + + private ByteBuffer writeIndex(int index) { + int N = 6; + int payload = 0b1100_0000; // use static table = true; + var bb = ByteBuffer.allocate(2); + + intWriter.configure(index, N, payload); + intWriter.write(bb); + intWriter.reset(); + + bb.flip(); + return bb; + } + + private List writeNameRef(boolean sensitive, String value, List validIndices) { + int N = 4; + int payload = 0b0101_0000; // static table = true + return writeNameRef(N, payload, sensitive, value, validIndices); + } + + private List writeInsertNameRef(String value, List validIndices) { + int N = 6; + int payload = 0b1100_0000; // static table = true + return writeNameRef(N, payload, false, value, validIndices); + } + + private List writeNameRef(int N, int payload, boolean sensitive, String value, List validIndices) { + // Each Header name may have several valid indices associated with it. + List output = new ArrayList<>(); + for (int index : validIndices) { + if (sensitive) + payload |= 0b0010_0000; + var bb = allocateNameRefBuffer(N, index, value); + intWriter.configure(index, N, payload); + intWriter.write(bb); + intWriter.reset(); + + boolean huffman = QuickHuffman.isHuffmanBetterFor(value); + int huffmanMask = 0b0000_0000; + if (huffman) + huffmanMask = 0b1000_0000; + stringWriter.configure(value, 7, huffmanMask, huffman); + stringWriter.write(bb); + stringWriter.reset(); + + bb.flip(); + output.add(bb); + } + + return output; + } + + private ByteBuffer writeInsertLiteral(String name, String value) { + int N = 5; + int payload = 0b0100_0000; + boolean huffmanName = QuickHuffman.isHuffmanBetterFor(name); + if (huffmanName) + payload |= 0b0010_0000; + return writeLiteral(N, payload, name, value); + } + + private ByteBuffer writeLiteral(boolean sensitive, String name, String value) { + int N = 3; + int payload = 0b0010_0000; // static table = true + if (sensitive) + payload |= 0b0001_0000; + + if (QuickHuffman.isHuffmanBetterFor(name)) + payload |= 0b0000_1000; + return writeLiteral(N, payload, name, value); + } + + private ByteBuffer writeLiteral(int N, int payload, String name, String value) { + var bb = allocateLiteralBuffer(N, name, value); + + boolean huffmanName = QuickHuffman.isHuffmanBetterFor(name); + stringWriter.configure(name, N, payload, huffmanName); + stringWriter.write(bb); + stringWriter.reset(); + + boolean huffmanValue = QuickHuffman.isHuffmanBetterFor(value); + int huffmanMask = 0b0000_0000; + if (huffmanValue) + huffmanMask = 0b1000_0000; + stringWriter.configure(value, 7, huffmanMask, huffmanValue); + stringWriter.write(bb); + stringWriter.reset(); + + bb.flip(); + return bb; + } + + private ByteBuffer allocateIndexBuffer(int index) { + /* + * Note on Integer Representation used for storing the length of name and value strings. + * Taken from RFC 7541 Section 5.1 + * + * "An integer is represented in two parts: a prefix that fills the current octet and an + * optional list of octets that are used if the integer value does not fit within the + * prefix. The number of bits of the prefix (called N) is a parameter of the integer + * representation. If the integer value is small enough, i.e., strictly less than 2N-1, it + * is encoded within the N-bit prefix. + * + * ... + * + * Otherwise, all the bits of the prefix are set to 1, and the value, decreased by 2N-1, is + * encoded using a list of one or more octets. The most significant bit of each octet is + * used as a continuation flag: its value is set to 1 except for the last octet in the list. + * The remaining bits of the octets are used to encode the decreased value." + * + * Use "null" for name, if name isn't being provided (i.e. for a nameRef); otherwise, buffer + * will be too large. + * + */ + int N = 6; // bits available in first byte + int size = 1; + index -= Math.pow(2, N) - 1; // number that you can store in first N bits + while (index >= 0) { + index -= 127; + size++; + } + return ByteBuffer.allocate(size); + } + + private ByteBuffer allocateInsertNameRefBuffer(int index, CharSequence value) { + int N = 6; + return allocateNameRefBuffer(N, index, value); + } + + private ByteBuffer allocateNameRefBuffer(int index, CharSequence value) { + int N = 4; + return allocateNameRefBuffer(N, index, value); + } + + private ByteBuffer allocateNameRefBuffer(int N, int index, CharSequence value) { + int vlen = Math.min(QuickHuffman.lengthOf(value), value.length()); + int size = 1 + vlen; + + index -= Math.pow(2, N) - 1; + while (index >= 0) { + index -= 127; + size++; + } + vlen -= 127; + size++; + while (vlen >= 0) { + vlen -= 127; + size++; + } + return ByteBuffer.allocate(size); + } + + private ByteBuffer allocateInsertLiteralBuffer(CharSequence name, CharSequence value) { + int N = 5; + return allocateLiteralBuffer(N, name, value); + } + + private ByteBuffer allocateLiteralBuffer(CharSequence name, CharSequence value) { + int N = 3; + return allocateLiteralBuffer(N, name, value); + } + + private ByteBuffer allocateLiteralBuffer(int N, CharSequence name, CharSequence value) { + int nlen = Math.min(QuickHuffman.lengthOf(name), name.length()); + int vlen = Math.min(QuickHuffman.lengthOf(value), value.length()); + int size = nlen + vlen; + + nlen -= Math.pow(2, N) - 1; + size++; + while (nlen >= 0) { + nlen -= 127; + size++; + } + + vlen -= 127; + size++; + while (vlen >= 0) { + vlen -= 127; + size++; + } + return ByteBuffer.allocate(size); + } + + static final String LOREM = """ + Lorem ipsum dolor sit amet, consectetur adipiscing + elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris + nisi ut aliquip ex ea commodo consequat.Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur.Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum."""; + + private String randomString() { + int lower = random.nextInt(LOREM.length() - TEST_STR_MAX_LENGTH); + /** + * The empty string ("") is a valid value String in the static table and the random + * String returned cannot refer to a entry in the table. Therefore, we set the upper + * bound below to a minimum of 1. + */ + return LOREM.substring(lower, 1 + lower + random.nextInt(TEST_STR_MAX_LENGTH)); + } +} diff --git a/test/jdk/java/net/httpclient/qpack/EntriesEvictionTest.java b/test/jdk/java/net/httpclient/qpack/EntriesEvictionTest.java new file mode 100644 index 00000000000..c295ec5d01d --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/EntriesEvictionTest.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @summary test dynamic table entry eviction scenarios + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=EXTRA EntriesEvictionTest + */ + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.Encoder.SectionReference; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.QPACK.Logger; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +public class EntriesEvictionTest { + + @Test(dataProvider = "evictionScenarios") + public void evictionInsertionTest(TestHeader headerToAdd, + SectionReference sectionReference, + long insertedId, + long largestEvictedId) { + Logger logger = QPACK.getLogger().subLogger("evictionInsertionTest"); + DynamicTable dynamicTable = new DynamicTable(logger); + + dynamicTable.setMaxTableCapacity(TABLE_CAPACITY); + dynamicTable.setCapacity(TABLE_CAPACITY); + + for (TestHeader header : TEST_HEADERS) { + dynamicTable.insert(header.name, header.value); + } + + // Insert last entry + long id = dynamicTable.insert(headerToAdd.name, headerToAdd.value, sectionReference); + + Assert.assertEquals(id, insertedId); + + if (largestEvictedId != -1) { + // Check that evicted entry with the largest absolute index + // is not accessible + Assert.assertThrows(() -> dynamicTable.get(largestEvictedId)); + // Check that an entry after that can be acquired with its + // absolute index + dynamicTable.get(largestEvictedId + 1); + } + + if (insertedId != -1) { + HeaderField insertedField = dynamicTable.get(insertedId); + Assert.assertEquals(insertedField, + new HeaderField(headerToAdd.name(), headerToAdd.value())); + } + } + + @DataProvider + public static Object[][] evictionScenarios() { + + // Header that requires only one entry to be evicted + String oneSizedValue = HEADER_PREXIX + String.format("%03d", HEADERS_COUNT); + TestHeader headerToAddWithOneEviction = TestHeader.newHeader( + oneSizedValue, oneSizedValue); + + // Header that requires two entries to be evicted + // "e" is repeated 16 times to compensate 32 bytes - 16 in header name, + // another 16 in header value + String doubleSizedHeaderValue = (HEADER_PREXIX + "Dbl").repeat(2) + + "e".repeat(16); + TestHeader headerToAddWithTwoEvictions = TestHeader.newHeader( + doubleSizedHeaderValue, doubleSizedHeaderValue); + + // Construct header with size equals to the dynamic table capacity + // / 2 since the string used two times - for headers name and value + String hugeStrPart1 = TEST_HEADERS.stream().map(TestHeader::name) + .collect(Collectors.joining()); + String hugeStrToCompensate32PerElement = "a".repeat(32 * (HEADERS_COUNT - 1) / 2); + String hugeStr = hugeStrPart1 + hugeStrToCompensate32PerElement; + + TestHeader hugeEntryWithAllEviction = TestHeader.newHeader(hugeStr, hugeStr); + + // Header with size 2 bytes bigger than the dynamic table capacity + TestHeader hugeEntryExceedsCapacity = TestHeader.newHeader( + hugeStr + "H", hugeStr + "H"); + + return new Object[][]{ + // Evict one to have space for a new entry + {headerToAddWithOneEviction, SectionReference.noReferences(), + HEADERS_COUNT, 0}, + + // Evict all entries to have space for a new entry + {hugeEntryWithAllEviction, SectionReference.noReferences(), + HEADERS_COUNT, HEADERS_COUNT - 1}, + + // Not enough capacity for a new entry even if all entries are evicted + {hugeEntryExceedsCapacity, SectionReference.noReferences(), + -1, -1}, + + // Entry with size == capacity and there are section references preventing + // eviction of all entries + {hugeEntryWithAllEviction, new SectionReference(0, 1), + -1, -1}, + + // Element with 0 absolute id is not referenced and therefore can be evicted + {headerToAddWithOneEviction, new SectionReference(1, 2), + HEADERS_COUNT, 0}, + + // Elements with 0 and 1 ids are not referenced and should be + // evicted to insert double-sized entry + {headerToAddWithTwoEvictions, new SectionReference(2, 3), + HEADERS_COUNT, 1}, + + // Element with 1 id cannot be evicted since it is + // referenced + {headerToAddWithTwoEvictions, new SectionReference(1, 3), + -1, -1} + }; + } + + record TestHeader(String name, String value, long size) { + public static TestHeader newHeader(String name, String value) { + return new TestHeader(name, value, 32L + name.length() + value.length()); + } + + @Override + public String toString() { + return name + ":" + value + "[" + size + "]"; + } + } + + // Number of headers to insert before running an eviction scenario + private static final int HEADERS_COUNT = 3; + // Test header prefix + private static final String HEADER_PREXIX = "HeaderPrefix"; + // List of headers to insert before running an eviction scenario + private static final List TEST_HEADERS; + // Table capacity required by test scenarios + private static final long TABLE_CAPACITY; + + static { + List testHeaders = new ArrayList<>(); + long capacity = 0; + + // List of headers to prepopulate dynamic table before running + // test cases + for (int i = 0; i < HEADERS_COUNT; i++) { + String headerStr = HEADER_PREXIX + String.format("%03d", i); + var header = TestHeader.newHeader(headerStr, headerStr); + capacity += header.size(); + testHeaders.add(header); + } + TEST_HEADERS = testHeaders; + TABLE_CAPACITY = capacity; + } +} diff --git a/test/jdk/java/net/httpclient/qpack/FieldSectionPrefixTest.java b/test/jdk/java/net/httpclient/qpack/FieldSectionPrefixTest.java new file mode 100644 index 00000000000..e9a7a86e975 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/FieldSectionPrefixTest.java @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @modules java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @run testng FieldSectionPrefixTest + */ + + +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.FieldSectionPrefix; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.qpack.writers.FieldLineSectionPrefixWriter; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicLong; + +public class FieldSectionPrefixTest { + + private static final long DT_CAPACITY = 220L; + private static final long MAX_ENTRIES = DT_CAPACITY / 32L; + + @Test(dataProvider = "encodingCases") + public void encodingTest(long base, long requiredInsertCount, + byte expectedRic, byte expectedBase) { + var fieldSectionPrefix = new FieldSectionPrefix(requiredInsertCount, base); + FieldLineSectionPrefixWriter writer = new FieldLineSectionPrefixWriter(); + int bytesNeeded = writer.configure(fieldSectionPrefix, MAX_ENTRIES); + var byteBuffer = ByteBuffer.allocate(bytesNeeded); + writer.write(byteBuffer); + byteBuffer.flip(); + Assert.assertEquals(byteBuffer.get(0), expectedRic); + Assert.assertEquals(byteBuffer.get(1), expectedBase); + } + + @DataProvider(name = "encodingCases") + public Object[][] encodingCases() { + var cases = new ArrayList(); + // Simple with 0 values + cases.add(new Object[]{0L, 0L, (byte) 0x0, (byte) 0x0}); + // Based on RFC-9204: "B.2. Dynamic Table example" + cases.add(new Object[]{0L, 2L, (byte) 0x3, (byte) 0x81}); + // Based on RFC-9204: "Duplicate Instruction, Stream Cancellation" + cases.add(new Object[]{4L, 4L, (byte) 0x5, (byte) 0x0}); + return cases.toArray(Object[][]::new); + } + + @Test(dataProvider = "decodingCases") + public void decodingTest(long expectedRIC, long expectedBase, byte... bytes) throws IOException { + var logger = QPACK.getLogger().subLogger("decodingTest"); + var dt = new DynamicTable(logger, false); + dt.setMaxTableCapacity(DT_CAPACITY); + dt.setCapacity(DT_CAPACITY); + var callback = new DecodingCallback() { + @Override + public void onDecoded(CharSequence name, CharSequence value) { + } + + @Override + public void onComplete() { + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + throw new RuntimeException("Error during Field Line Section Prefix decoding - " + + http3Error + ": " + throwable.getMessage()); + } + + @Override + public long streamId() { + return 0; + } + }; + AtomicLong blockedStreamsCounter = new AtomicLong(); + // maxBlockStreams = 1 is needed for tests with Required Insert Count > 0 - otherwise test + // fails with "QPACK_DECOMPRESSION_FAILED: too many blocked streams" + HeaderFrameReader reader = new HeaderFrameReader(dt, callback, blockedStreamsCounter, + 1, -1, logger); + var bb = ByteBuffer.wrap(bytes); + reader.read(bb, false); + var fsp = reader.decodedSectionPrefix(); + + System.err.println("Required Insert Count:" + fsp.requiredInsertCount()); + System.err.println("Base:" + fsp.base()); + Assert.assertEquals(fsp.requiredInsertCount(), expectedRIC); + Assert.assertEquals(fsp.base(), expectedBase); + + } + + @DataProvider(name = "decodingCases") + public Object[][] decodingCases() { + var cases = new ArrayList(); + cases.add(new Object[]{0L, 0L, (byte) 0x0, (byte) 0x0}); + cases.add(new Object[]{4L, 4L, (byte) 0x5, (byte) 0x0}); + cases.add(new Object[]{2L, 0L, (byte) 0x3, (byte) 0x81}); + return cases.toArray(Object[][]::new); + } +} diff --git a/test/jdk/java/net/httpclient/qpack/IntegerReaderMaxValuesTest.java b/test/jdk/java/net/httpclient/qpack/IntegerReaderMaxValuesTest.java new file mode 100644 index 00000000000..6f21549f497 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/IntegerReaderMaxValuesTest.java @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.readers.IntegerReader; +import jdk.internal.net.http.qpack.writers.IntegerWriter; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.stream.IntStream; + +/* + * @test + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.http3 + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=INFO + * IntegerReaderMaxValuesTest + */ +public class IntegerReaderMaxValuesTest { + @DataProvider + public Object[][] nValues() { + return IntStream.range(1, 8) + .boxed() + .map(N -> new Object[]{N}) + .toArray(Object[][]::new); + } + + @Test(dataProvider = "nValues") + public void maxIntegerWriteRead(int N) { + IntegerWriter writer = new IntegerWriter(); + writer.configure(IntegerReader.QPACK_MAX_INTEGER_VALUE, N, 0); + ByteBuffer buffer = ByteBuffer.allocate(1024); + writer.write(buffer); + IntegerReader reader = new IntegerReader(); + reader.configure(N); + buffer.flip(); + reader.read(buffer); + long result = reader.get(); + Assert.assertEquals(result, IntegerReader.QPACK_MAX_INTEGER_VALUE); + } + + @Test(dataProvider = "nValues", expectedExceptions = QPackException.class) + public void overflowInteger(int N) { + // Construct buffer with overflowed integer + ByteBuffer overflowBuffer = ByteBuffer.allocate(11); + + overflowBuffer.put((byte) ((2 << (N - 1)) - 1)); + for (int i = 0; i < 9; i++) { + overflowBuffer.put((byte) 128); + } + overflowBuffer.put((byte) 10); + overflowBuffer.flip(); + // Read the buffer with IntegerReader + IntegerReader reader = new IntegerReader(); + reader.configure(N); + reader.read(overflowBuffer); + } +} diff --git a/test/jdk/java/net/httpclient/qpack/StaticTableFieldsTest.java b/test/jdk/java/net/httpclient/qpack/StaticTableFieldsTest.java new file mode 100644 index 00000000000..662d0c264d8 --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/StaticTableFieldsTest.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @modules java.net.http/jdk.internal.net.http.qpack + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=NORMAL StaticTableFieldsTest + */ + +import jdk.internal.net.http.qpack.StaticTable; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.testng.Assert.assertEquals; + +public class StaticTableFieldsTest { + + @BeforeTest + public void setUp() { + // Populate expected table as defined by RFC + expectedTable = new ArrayList<>(); + String[] arr = staticTableFields.split("\n"); + for (String s : arr) { + s = s.replaceAll("( )+", " "); + // index + int endOfIndex = s.indexOf(" "); + var i = Integer.parseInt(s.substring(0, endOfIndex)); + // name + int endOfName = s.indexOf(" ", endOfIndex + 1); + var n = s.substring(endOfIndex + 1, endOfName).trim(); + // value + var v = s.substring(endOfName + 1).strip(); + expectedTable.add(new TableLine(i, n, v)); + } + // Populate actual static table currently being used by QPACK + actualTable = new ArrayList<>(); + for (int i = 0; i < StaticTable.HTTP3_HEADER_FIELDS.size(); i++) { + var n = StaticTable.HTTP3_HEADER_FIELDS.get(i).name(); + var v = StaticTable.HTTP3_HEADER_FIELDS.get(i).value(); + actualTable.add(new TableLine(i, n, v)); + } + } + + @Test + public void testStaticTable() { + assertEquals(actualTable.size(), expectedTable.size()); + for (int i = 0; i < expectedTable.size(); i++) { + assertEquals(actualTable.get(i).name(), expectedTable.get(i).name()); + assertEquals(actualTable.get(i).value(), expectedTable.get(i).value()); + } + } + + // Copy-Paste of static table from RFC 9204 for QPACK Appendix A + // https://www.rfc-editor.org/rfc/rfc9204.html#name-static-table-2 + String staticTableFields = """ + 0 :authority \s + 1 :path / + 2 age 0 + 3 content-disposition \s + 4 content-length 0 + 5 cookie \s + 6 date \s + 7 etag \s + 8 if-modified-since \s + 9 if-none-match \s + 10 last-modified \s + 11 link \s + 12 location \s + 13 referer \s + 14 set-cookie \s + 15 :method CONNECT + 16 :method DELETE + 17 :method GET + 18 :method HEAD + 19 :method OPTIONS + 20 :method POST + 21 :method PUT + 22 :scheme http + 23 :scheme https + 24 :status 103 + 25 :status 200 + 26 :status 304 + 27 :status 404 + 28 :status 503 + 29 accept */* + 30 accept application/dns-message + 31 accept-encoding gzip, deflate, br + 32 accept-ranges bytes + 33 access-control-allow-headers cache-control + 34 access-control-allow-headers content-type + 35 access-control-allow-origin * + 36 cache-control max-age=0 + 37 cache-control max-age=2592000 + 38 cache-control max-age=604800 + 39 cache-control no-cache + 40 cache-control no-store + 41 cache-control public, max-age=31536000 + 42 content-encoding br + 43 content-encoding gzip + 44 content-type application/dns-message + 45 content-type application/javascript + 46 content-type application/json + 47 content-type application/x-www-form-urlencoded + 48 content-type image/gif + 49 content-type image/jpeg + 50 content-type image/png + 51 content-type text/css + 52 content-type text/html; charset=utf-8 + 53 content-type text/plain + 54 content-type text/plain;charset=utf-8 + 55 range bytes=0- + 56 strict-transport-security max-age=31536000 + 57 strict-transport-security max-age=31536000; includesubdomains + 58 strict-transport-security max-age=31536000; includesubdomains; preload + 59 vary accept-encoding + 60 vary origin + 61 x-content-type-options nosniff + 62 x-xss-protection 1; mode=block + 63 :status 100 + 64 :status 204 + 65 :status 206 + 66 :status 302 + 67 :status 400 + 68 :status 403 + 69 :status 421 + 70 :status 425 + 71 :status 500 + 72 accept-language \s + 73 access-control-allow-credentials FALSE + 74 access-control-allow-credentials TRUE + 75 access-control-allow-headers * + 76 access-control-allow-methods get + 77 access-control-allow-methods get, post, options + 78 access-control-allow-methods options + 79 access-control-expose-headers content-length + 80 access-control-request-headers content-type + 81 access-control-request-method get + 82 access-control-request-method post + 83 alt-svc clear + 84 authorization \s + 85 content-security-policy script-src 'none'; object-src 'none'; base-uri 'none' + 86 early-data 1 + 87 expect-ct \s + 88 forwarded \s + 89 if-range \s + 90 origin \s + 91 purpose prefetch + 92 server \s + 93 timing-allow-origin * + 94 upgrade-insecure-requests 1 + 95 user-agent \s + 96 x-forwarded-for \s + 97 x-frame-options deny + 98 x-frame-options sameorigin + """; + + private List actualTable, expectedTable; + private record TableLine(int index, String name, String value) { } +} diff --git a/test/jdk/java/net/httpclient/qpack/StringLengthLimitsTest.java b/test/jdk/java/net/httpclient/qpack/StringLengthLimitsTest.java new file mode 100644 index 00000000000..ade50d2843b --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/StringLengthLimitsTest.java @@ -0,0 +1,482 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.QPackException; +import jdk.internal.net.http.qpack.readers.HeaderFrameReader; +import jdk.internal.net.http.qpack.readers.StringReader; +import jdk.internal.net.http.qpack.writers.HeaderFrameWriter; +import jdk.internal.net.http.qpack.writers.IntegerWriter; +import jdk.internal.net.http.qpack.writers.StringWriter; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.ProtocolException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +/* + * @test + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm -Djdk.http.qpack.allowBlockingEncoding=true + * StringLengthLimitsTest + */ +public class StringLengthLimitsTest { + + @DataProvider + Object[][] stringReaderLimitsData() { + return new Object[][]{ + {STRING_READER_STRING_LENGTH, STRING_READER_STRING_LENGTH, false, false}, + {STRING_READER_STRING_LENGTH, STRING_READER_STRING_LENGTH - 1, false, true}, + {STRING_READER_STRING_LENGTH, STRING_READER_STRING_LENGTH / 4, true, false}, + {STRING_READER_STRING_LENGTH, STRING_READER_STRING_LENGTH / 4 - 1, true, true} + }; + } + + @Test(dataProvider = "stringReaderLimitsData") + public void stringReaderLimits(int length, int limit, boolean huffmanBit, + boolean exceptionExpected) throws IOException { + IntegerWriter intWriter = new IntegerWriter(); + intWriter.configure(length, 7, huffmanBit ? STRING_READER_HUFFMAN_PAYLOAD : STRING_READER_PAYLOAD); + + var byteBuffer = ByteBuffer.allocate(2); + if (!intWriter.write(byteBuffer)) { + Assert.fail("Error with test buffer preparations"); + } + byteBuffer.flip(); + + StringReader stringReader = new StringReader(); + StringBuilder unusedOutput = new StringBuilder(); + if (exceptionExpected) { + QPackException exception = Assert.expectThrows(QPackException.class, + () -> stringReader.read(byteBuffer, unusedOutput, limit)); + Throwable cause = exception.getCause(); + Assert.assertNotNull(cause); + Assert.assertTrue(cause instanceof ProtocolException); + System.err.println("Got expected ProtocolException: " + cause); + } else { + boolean done = stringReader.read(byteBuffer, unusedOutput, limit); + Assert.assertFalse(done, "read done"); + } + } + + @DataProvider + Object[][] encoderInstructionLimitsData() { + int maxEntrySize = ENCODER_INSTRUCTIONS_DT_CAPACITY - 32; + return new Object[][]{ + // "Insert with Literal Name" instruction tests + // No Huffman, incomplete instruction, enough space in the DT + {EncoderInstruction.INSERT_LITERAL_NAME, maxEntrySize, + false, DO_NOT_GENERATE_PART, false, true}, + // No Huffman, incomplete instruction, not enough space in the DT + {EncoderInstruction.INSERT_LITERAL_NAME, maxEntrySize + 1, + false, DO_NOT_GENERATE_PART, false, false}, + // No Huffman, full instruction, enough space in the DT + {EncoderInstruction.INSERT_LITERAL_NAME, maxEntrySize / 2, + false, maxEntrySize / 2, false, true}, + // No Huffman, full instruction, not enough space in the DT + {EncoderInstruction.INSERT_LITERAL_NAME, maxEntrySize / 2, + false, 1 + maxEntrySize / 2, false, false}, + // Huffman (name + value), full instruction, enough space + // in the DT + {EncoderInstruction.INSERT_LITERAL_NAME, maxEntrySize / 4 / 2, + true, maxEntrySize / 4 / 2, true, true}, + // Huffman (value only), full instruction, not enough space + // in the DT. + // +16 term is added to make sure huffman estimate exceeds the limit + {EncoderInstruction.INSERT_LITERAL_NAME, maxEntrySize / 2, + false, 2 * maxEntrySize + 16, true, false}, + + // "Insert with Name Reference" instruction tests + // Enough space in the DT for the value part + {EncoderInstruction.INSERT_NAME_REFERENCE, DO_NOT_GENERATE_PART, + false, maxEntrySize, false, true}, + // Not enough space in the DT for the value part + {EncoderInstruction.INSERT_NAME_REFERENCE, DO_NOT_GENERATE_PART, + false, maxEntrySize + 1, false, false}, + // Enough space in the DT for the Huffman encoded value part + {EncoderInstruction.INSERT_NAME_REFERENCE, DO_NOT_GENERATE_PART, + false, maxEntrySize * 4, true, true}, + // Not enough space in the DT for the Huffman encoded value part + {EncoderInstruction.INSERT_NAME_REFERENCE, DO_NOT_GENERATE_PART, + false, maxEntrySize * 4 + 4, true, false} + }; + } + + @Test(dataProvider = "encoderInstructionLimitsData") + public void encoderInstructionLimits(EncoderInstruction instruction, + int nameLength, boolean nameHuffman, + int valueLength, boolean valueHuffman, + boolean successExpected) { + // Encoder/decoder pair with instruction that has string and with length > limit + var connector = new EncoderDecoderConnector(); + AtomicReference observedError = new AtomicReference<>(); + QPACK.QPACKErrorHandler errorHandler = (throwable, error) -> { + System.err.println("QPACK error observed: " + error); + observedError.set(throwable); + }; + var streamError = new AtomicReference(); + var pair = connector.newEncoderDecoderPair((_) -> true, errorHandler, errorHandler, streamError::set); + + // Configure dynamic tables + var decoderDT = pair.decoderTable(); + var encoderDT = pair.encoderTable(); + encoderDT.setMaxTableCapacity(ENCODER_INSTRUCTIONS_DT_CAPACITY); + encoderDT.setCapacity(ENCODER_INSTRUCTIONS_DT_CAPACITY); + decoderDT.setMaxTableCapacity(ENCODER_INSTRUCTIONS_DT_CAPACITY); + decoderDT.setCapacity(ENCODER_INSTRUCTIONS_DT_CAPACITY); + + // Generate buffers with encoder instruction bytes + var instructionBuffers = generateInstructionBuffers(instruction, + nameLength, nameHuffman, valueLength, valueHuffman); + for (var buffer : instructionBuffers) { + // Submit encoder instruction with test instructions which + // could be incomplete + pair.encoderStreams().submitData(buffer); + } + Throwable error = observedError.get(); + if (successExpected && error != null) { + Assert.fail("Unexpected error", error); + } else if (error == null && !successExpected) { + Assert.fail("Expected error"); + } + } + + /* + * Generate a list of instruction buffers. + * First buffer contains a name part (index or String), + * Second buffer contains a value part (index or String). + * If instruction type is INSERT_NAME_REFERENCE - + * the nameLength and nameHuffman are ignored. + */ + private static List generateInstructionBuffers( + EncoderInstruction instruction, + int nameLength, boolean nameHuffman, + int valueLength, boolean valueHuffman) { + IntegerWriter intWriter = new IntegerWriter(); + StringWriter stringWriter = new StringWriter(); + List instructionBuffers = new ArrayList<>(); + int valuePartPayload = valueHuffman ? + STRING_READER_HUFFMAN_PAYLOAD : STRING_READER_PAYLOAD; + // Configure writers for an instruction + switch (instruction) { + case INSERT_LITERAL_NAME: + int namePartPayload = nameHuffman ? + INSERT_INSTRUCTION_LITERAL_NAME_HUFFMAN_PAYLOAD : + INSERT_INSTRUCTION_LITERAL_NAME_PAYLOAD; + if (valueLength != DO_NOT_GENERATE_PART) { + // Generate data for the name part + var namePartBB = ByteBuffer.allocate(nameLength + 1); + stringWriter.configure("T".repeat(nameLength), 5, namePartPayload, valueHuffman); + boolean nameDone = stringWriter.write(namePartBB); + assert nameDone; + namePartBB.flip(); + instructionBuffers.add(namePartBB); + // Generate data for the value part + var valuePartBB = generatePartialString(7, valuePartPayload, valueLength); + instructionBuffers.add(valuePartBB); + } else { + // Generate data for the name part only + var namePartBB = generatePartialString(5, namePartPayload, nameLength); + instructionBuffers.add(namePartBB); + } + break; + case INSERT_NAME_REFERENCE: + var nameIndexPart = ByteBuffer.allocate(1); + // Write some static table name id + // Referencing static table entry with id = 16, ie ":method" + // nameLength and nameHuffman are ignored + intWriter.configure(16, 6, INSERT_INSTRUCTION_WITH_NAME_REFERENCE_PAYLOAD); + boolean nameIndexDone = intWriter.write(nameIndexPart); + assert nameIndexDone; + nameIndexPart.flip(); + intWriter.reset(); + // Write value part with specified length and huffman encoding + // Generate data for the value part + var valueLengthPart = + generatePartialString(7, valuePartPayload, valueLength); + // Add both parts to the list of forged instruction buffers + instructionBuffers.add(nameIndexPart); + instructionBuffers.add(valueLengthPart); + break; + } + return instructionBuffers; + } + + @DataProvider + Object[][] fieldLineLimitsData() { + return new Object[][]{ + // Post-Base Index + {-1, -1, ENTRY_NAME, ENTRY_VALUE, true, true}, + // Relative Index + {-1, -1, ENTRY_NAME, ENTRY_VALUE, false, true}, + // Post-Base Name Index + {-1, -1, ENTRY_NAME, "X".repeat(ENTRY_VALUE.length()), true, true}, + // Relative Name Index + {-1, -1, ENTRY_NAME, "X".repeat(ENTRY_VALUE.length()), false, true}, + // Post-Base Index, limit is exceeded + {-1, -1, BIG_ENTRY_NAME, BIG_ENTRY_VALUE, true, false}, + // Relative Index, limit is exceeded + {-1, -1, BIG_ENTRY_NAME, BIG_ENTRY_VALUE, false, false}, + // Post-Base Name Index, limit is exceeded + {-1, -1, BIG_ENTRY_NAME, ENTRY_VALUE, true, false}, + // Relative Name Index, limit is exceeded + {-1, -1, ENTRY_NAME, BIG_ENTRY_VALUE, false, false}, + // Name and Value are literals, limit is not exceeded + {ENTRY_NAME.length(), ENTRY_VALUE.length(), null, null, false, true}, + // Name and Value are literals, limit is exceeded in name part + {ENTRY_NAME.length() + ENTRY_VALUE.length() + 1, 0, null, null, false, false}, + // Name and Value are literals, limit is exceeded in value part + {1, ENTRY_NAME.length() + ENTRY_VALUE.length(), null, null, false, false} + + }; + } + + @Test(dataProvider = "fieldLineLimitsData") + public void fieldLineLimits(int nameLength, int valueLength, + String name, String value, + boolean isPostBase, boolean successExpected) throws IOException { + // QPACK writers for test data generations + IntegerWriter intWriter = new IntegerWriter(); + StringWriter stringWriter = new StringWriter(); + + // QPACK error handlers, stream error capture and decoding callback + var encoderError = new AtomicReference(); + var decoderError = new AtomicReference(); + var streamError = new AtomicReference(); + var decodingCallbackError = new AtomicReference(); + + QPACK.QPACKErrorHandler encoderErrorHandler = (throwable, error) -> { + System.err.println("Encoder error observed: " + error); + encoderError.set(throwable); + }; + + QPACK.QPACKErrorHandler decoderErrorHandler = (throwable, error) -> { + System.err.println("Decoder error observed: " + error); + decoderError.set(throwable); + }; + + var decodingCallback = new FieldLineDecodingCallback(decodingCallbackError); + + // Create encoder/decoder pair + var conn = new EncoderDecoderConnector(); + var pair = conn.newEncoderDecoderPair( + // Disallow entries insertion. + // The dynamic table is pre-populated with needed entries + // before the test execution + _ -> false, + encoderErrorHandler, + decoderErrorHandler, + streamError::set); + var encoder = pair.encoder(); + var decoder = pair.decoder(); + + // Set MAX_HEADER_SIZE limit on a decoder side + // Create settings frame with MAX_FIELD_SECTION_SIZE + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + settingsFrame.setParameter(SettingsFrame.SETTINGS_MAX_FIELD_SECTION_SIZE, MAX_FIELD_SECTION_SIZE_SETTING_VALUE); + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, FIELD_LINES_DT_CAPACITY); + pair.decoder().configure(ConnectionSettings.createFrom(settingsFrame)); + + // Configure tables + configureTablesForFieldLinesTest(pair); + + // Encode the section prefix + long RIC = pair.encoderTable().insertCount(); + long base = isPostBase ? 0 : RIC; + HeaderFrameWriter writer = encoder.newHeaderFrameWriter(); + HeaderFrameReader reader = decoder.newHeaderFrameReader(decodingCallback); + var encodingContext = encoder.newEncodingContext(123, base, writer); + List buffers = new ArrayList<>(); + + if (nameLength == -1 && valueLength == -1) { + // Test configuration for all indexed field line wire formats + // (indexed name and indexed entry) + var headersBuffer = ByteBuffer.allocate(1024); + // knownReceivedCount == InsertCount to allow DT reference encodings, + // since there is one entry in the test dynamic table + encoder.header(encodingContext, name, value, false, RIC); + writer.write(headersBuffer); + headersBuffer.flip(); + buffers.add(headersBuffer); + } else { + if (nameLength > MAX_FIELD_SECTION_SIZE_SETTING_VALUE - DynamicTable.ENTRY_SIZE) { + // if nameLength > limit - only need to generate name part + // We only write partial name part of the "Literal Field Line with Literal + // Name" instruction with String length value + var nameLengthBB = + generatePartialString(3, FIELD_LINE_NAME_VALUE_LITERALS_PAYLOAD, nameLength); + buffers.add(nameLengthBB); + } else if (nameLength + valueLength > + MAX_FIELD_SECTION_SIZE_SETTING_VALUE - DynamicTable.ENTRY_SIZE) { + // if nameLength + valueLength > limit - + // the whole instruction needs to be generated with basic writers + var fieldLineBB = ByteBuffer.allocate(1024); + stringWriter.configure("Z".repeat(nameLength), 3, + FIELD_LINE_NAME_VALUE_LITERALS_PAYLOAD, false); + intWriter.configure(valueLength, 7, 0); + stringWriter.write(fieldLineBB); + intWriter.write(fieldLineBB); + fieldLineBB.flip(); + buffers.add(fieldLineBB); + } else { + // name + value doesn't exceed MAX_FIELD_SECTION_SIZE + var headersBuffer = ByteBuffer.allocate(1024); + // We use 'X' and 'Z' letters to prevent encoder from + // huffman encoding. + encoder.header(encodingContext, + "X".repeat(nameLength), "Z".repeat(valueLength), + false, RIC); + writer.write(headersBuffer); + headersBuffer.flip(); + buffers.add(headersBuffer); + } + } + // Generate field lines section prefix + encoder.generateFieldLineSectionPrefix(encodingContext, buffers); + assert buffers.size() == 2; + + // Decode generated header buffers + decoder.decodeHeader(buffers.get(0), false, reader); + decoder.decodeHeader(buffers.get(1), true, reader); + + // Check if any error is observed and it meets the test expectations + var error = decodingCallbackError.get(); + System.err.println("Decoding callback error: " + error); + if (successExpected && error != null) { + Assert.fail("Unexpected error", error); + } else if (error == null && !successExpected) { + Assert.fail("Error expected"); + } + } + + private static void configureTablesForFieldLinesTest( + EncoderDecoderConnector.EncoderDecoderPair pair) { + // Encoder + pair.encoderTable().setMaxTableCapacity(StringLengthLimitsTest.FIELD_LINES_DT_CAPACITY); + pair.encoderTable().setCapacity(StringLengthLimitsTest.FIELD_LINES_DT_CAPACITY); + + // Decoder max table capacity is set via the settings frame + pair.decoderTable().setCapacity(StringLengthLimitsTest.FIELD_LINES_DT_CAPACITY); + + // Insert test entry to both tables + pair.decoderTable().insert(ENTRY_NAME, ENTRY_VALUE); + pair.encoderTable().insert(ENTRY_NAME, ENTRY_VALUE); + pair.decoderTable().insert(BIG_ENTRY_NAME, BIG_ENTRY_VALUE); + pair.encoderTable().insert(BIG_ENTRY_NAME, BIG_ENTRY_VALUE); + } + + // Encoder instructions under test + public enum EncoderInstruction { + INSERT_LITERAL_NAME, + INSERT_NAME_REFERENCE, + } + + // Decoding callback used by Field Line Representation tests + private static class FieldLineDecodingCallback implements DecodingCallback { + private final AtomicReference decodingError; + + public FieldLineDecodingCallback(AtomicReference decodingCallbackError) { + this.decodingError = decodingCallbackError; + } + + @Override + public void onDecoded(CharSequence name, CharSequence value) { + } + + @Override + public void onComplete() { + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + decodingError.set(throwable); + } + + @Override + public long streamId() { + return 0; + } + } + + // Utility method to generate partial QPack string with length part only + private static ByteBuffer generatePartialString(int N, int payload, int length) { + IntegerWriter intWriter = new IntegerWriter(); + var partialStringBB = ByteBuffer.allocate( + IntegerWriter.requiredBufferSize(N, length)); + intWriter.configure(length, N, payload); + boolean done = intWriter.write(partialStringBB); + assert done; + partialStringBB.flip(); + return partialStringBB; + } + + // Constants for StringReader tests + private static final int STRING_READER_HUFFMAN_PAYLOAD = 0b1000_0000; + private static final int STRING_READER_PAYLOAD = 0b0000_0000; + private static final int STRING_READER_STRING_LENGTH = 32; + + // Constants for Encoder Instructions Reader tests + private static final int ENCODER_INSTRUCTIONS_DT_CAPACITY = 64; + // This constant is used to instruct test data generator not to generate + // value or name (if name is referenced by the DT index) part of the + // decoder instruction + private static final int DO_NOT_GENERATE_PART = -1; + // Encoder instruction payloads used to forge instruction buffers + private static final int INSERT_INSTRUCTION_LITERAL_NAME_HUFFMAN_PAYLOAD = 0b0110_0000; + private static final int INSERT_INSTRUCTION_LITERAL_NAME_PAYLOAD = 0b0100_0000; + private static final int INSERT_INSTRUCTION_WITH_NAME_REFERENCE_PAYLOAD = 0b1100_0000; + private static final int FIELD_LINE_NAME_VALUE_LITERALS_PAYLOAD = 0b0010_0000; + + // Constants for Field Line Representation tests + // Table capacity big enough for insertion of all entries + private static final long FIELD_LINES_DT_CAPACITY = 1024; + private static final String ENTRY_NAME = "FullEntryName"; + private static final String ENTRY_VALUE = "FullEntryValue"; + private static final String BIG_ENTRY_NAME = "FullEntryName_Big"; + private static final String BIG_ENTRY_VALUE = "FullEntryValue_Big"; + private static final long MAX_FIELD_SECTION_SIZE_SETTING_VALUE = + ENTRY_NAME.length() + ENTRY_VALUE.length() + DynamicTable.ENTRY_SIZE; + +} diff --git a/test/jdk/java/net/httpclient/qpack/TablesIndexerTest.java b/test/jdk/java/net/httpclient/qpack/TablesIndexerTest.java new file mode 100644 index 00000000000..4ed489ca6dd --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/TablesIndexerTest.java @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @modules java.net.http/jdk.internal.net.http.qpack + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=INFO TablesIndexerTest + */ + +import jdk.internal.net.http.qpack.DynamicTable; +import jdk.internal.net.http.qpack.HeaderField; +import jdk.internal.net.http.qpack.QPACK; +import jdk.internal.net.http.qpack.StaticTable; +import jdk.internal.net.http.qpack.TableEntry; +import jdk.internal.net.http.qpack.TablesIndexer; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; + +import static jdk.internal.net.http.qpack.TableEntry.EntryType; + +public class TablesIndexerTest { + + @DataProvider(name = "indicesLookupData") + public Object[][] indicesData() { + List tcs = new ArrayList<>(); + + long index = 0; + for (HeaderField f : StaticTable.HTTP3_HEADER_FIELDS) { + // Full name:value match + tcs.add(new Object[]{ + f.name(), f.value(), f.value(), + Set.of(index), EntryType.NAME_VALUE + }); + // Static and Dynamic tables contain only name match - we expect static + // NAME only entry to be returned + tcs.add(new Object[]{ + f.name(), "NotInStatic", "InDynamic", + acceptableIndicesForName(f.name()), EntryType.NAME}); + index++; + } + return tcs.toArray(Object[][]::new); + } + + @Test(dataProvider = "indicesLookupData") + public void checkIndicesLookup(String name, String value, + String dynamicTableValue, + Set indices, EntryType type) { + // We construct dynamic table with the same name value to check that + // static entry is returned first + DynamicTable dt = new DynamicTable(QPACK.getLogger().subLogger( + "checkStaticIndicesLookup")); + dt.setMaxTableCapacity(256); + dt.setCapacity(256); + dt.insert(name, dynamicTableValue); + + // Construct TablesIndexer + TablesIndexer tablesIndexer = new TablesIndexer(STATIC_TABLE, dt); + + // Use TablesIndexer to locate TableEntry + TableEntry tableEntry = tablesIndexer.entryOf(name, value, + IGNORE_RECEIVED_COUNT_CHECK); + + // TableEntry should be for static table only + Assert.assertTrue(tableEntry.isStaticTable()); + + // If value is not equal to dynamicTableValue, the full name:dynamicTableValue + // should be found in the dynamic table with index 0 + if (!value.equals(dynamicTableValue)) { + TableEntry dtEntry = tablesIndexer.entryOf(name, dynamicTableValue, + IGNORE_RECEIVED_COUNT_CHECK); + Assert.assertFalse(dtEntry.isStaticTable()); + Assert.assertEquals(dtEntry.type(), EntryType.NAME_VALUE); + Assert.assertEquals(dtEntry.index(), 0L); + } + + // Check that found index is contained in a set and returned indices match + Assert.assertTrue(indices.contains(tableEntry.index())); + + // Check that entry type matches + Assert.assertEquals(tableEntry.type(), type); + + var headerField = STATIC_TABLE.get(tableEntry.index()); + // Check that name and/or value matches the one that can be acquired by + // using looked-up index + if (tableEntry.type() == EntryType.NAME) { + Assert.assertEquals(headerField.name(), name); + // If only name entry is found huffmanName should be set to false + Assert.assertFalse(tableEntry.huffmanName()); + } else if (tableEntry.type() == EntryType.NAME_VALUE) { + Assert.assertEquals(headerField.name(), name); + Assert.assertEquals(headerField.value(), value); + // If "name:value" match is found huffmanName and huffmanValue should + // be set to false + Assert.assertFalse(tableEntry.huffmanName()); + Assert.assertFalse(tableEntry.huffmanValue()); + + } else { + Assert.fail("Unexpected TableEntry type returned:" + tableEntry); + } + } + + @Test(dataProvider = "unacknowledgedEntriesLookupData") + public void unacknowledgedEntryLookup(String headerName, String headerValue, + boolean staticEntryExpected, + EntryType expectedType) { + // Construct dynamic table with pre-populated entries + DynamicTable dynamicTable = dynamicTableForUnackedEntriesTest(); + // Construct TablesIndexer + TablesIndexer tablesIndexer = new TablesIndexer(STATIC_TABLE, dynamicTable); + // Search for an entry in the dynamic and the static tables + var entry = tablesIndexer.entryOf(headerName, headerValue, TEST_KNOWN_RECEIVED_COUNT); + // Check that entry references expected table + Assert.assertEquals(entry.isStaticTable(), staticEntryExpected); + // And the type of found entry matches expectations + Assert.assertEquals(entry.type(), expectedType); + } + + @DataProvider + public Object[][] unacknowledgedEntriesLookupData() { + List data = new ArrayList<>(); + data.add(new Object[]{USER_AGENT_ST_NAME, "not-in-dynamic", true, EntryType.NAME}); + data.add(new Object[]{USER_AGENT_ST_NAME, USER_AGENT_DT_VALUE, false, EntryType.NAME_VALUE}); + data.add(new Object[]{TEST_ACKED_ENTRY, "not-in-dynamic", false, EntryType.NAME}); + data.add(new Object[]{TEST_ACKED_ENTRY, TEST_ACKED_ENTRY, false, EntryType.NAME_VALUE}); + data.add(new Object[]{CONTENT_TYPE_ST_NAME, "what/ever", true, EntryType.NAME}); + data.add(new Object[]{TEST_UNACKED_ENTRY, TEST_UNACKED_ENTRY, false, EntryType.NEITHER}); + return data.toArray(Object[][]::new); + } + + private static DynamicTable dynamicTableForUnackedEntriesTest() { + DynamicTable dt = new DynamicTable(QPACK.getLogger() + .subLogger("unacknowledgedEntryLookup")); + dt.setMaxTableCapacity(1024); + dt.setCapacity(1024); + // Acknowledged entry with name available in static and dynamic table + dt.insert(USER_AGENT_ST_NAME, USER_AGENT_DT_VALUE); // 0 + // Acknowledged entry with name available in dynamic table only + dt.insert(TEST_ACKED_ENTRY, TEST_ACKED_ENTRY); // 1 + // Unacknowledged entry with name available in static table + dt.insert(CONTENT_TYPE_ST_NAME, "what/ever"); // 2 + // Unacknowledged entry with name available in dynamic table only + dt.insert(TEST_UNACKED_ENTRY, TEST_UNACKED_ENTRY); // 3 + return dt; + } + + private static final long IGNORE_RECEIVED_COUNT_CHECK = -1L; + private static final long TEST_KNOWN_RECEIVED_COUNT = 2L; + + private static final String USER_AGENT_ST_NAME = "user-agent"; + private static final String USER_AGENT_DT_VALUE = "qpack-test-client"; + private static final String CONTENT_TYPE_ST_NAME = "content-type"; + private static final String TEST_ACKED_ENTRY = "test-acked-entry"; + private static final String TEST_UNACKED_ENTRY = "test-unacked-entry"; + + private final static StaticTable STATIC_TABLE = StaticTable.HTTP3; + + private Set acceptableIndicesForName(String name) { + AtomicLong enumerator = new AtomicLong(); + return StaticTable.HTTP3_HEADER_FIELDS.stream() + .map(f -> new NameEnumeration(enumerator.getAndIncrement(), f.name())) + .filter(ne -> name.equals(ne.name())) + .mapToLong(NameEnumeration::id) + .boxed() + .collect(Collectors.toUnmodifiableSet()); + } + + record NameEnumeration(long id, String name) { + } +} diff --git a/test/jdk/java/net/httpclient/qpack/UnacknowledgedInsertionTest.java b/test/jdk/java/net/httpclient/qpack/UnacknowledgedInsertionTest.java new file mode 100644 index 00000000000..7b621e2d69d --- /dev/null +++ b/test/jdk/java/net/httpclient/qpack/UnacknowledgedInsertionTest.java @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.http3.ConnectionSettings; +import jdk.internal.net.http.http3.Http3Error; +import jdk.internal.net.http.http3.frames.SettingsFrame; +import jdk.internal.net.http.qpack.DecodingCallback; +import jdk.internal.net.http.qpack.Encoder; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static org.testng.Assert.assertNotEquals; + +/* + * @test + * @summary check that unacknowledged header is not inserted + * twice to the dynamic table + * @modules java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.hpack + * java.net.http/jdk.internal.net.http.qpack:+open + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * @build EncoderDecoderConnector + * @run testng/othervm -Djdk.internal.httpclient.qpack.log.level=EXTRA UnacknowledgedInsertionTest + */ +public class UnacknowledgedInsertionTest { + + @Test(dataProvider = "duplicateEntryInsertions") + public void unacknowledgedDoubleInsertion(long knownReceiveCount, List expectedHeadersEncodingType) throws Exception { + // When knownReceiveCount is set to -1 the Encoder.knownReceiveCount() + // value is used to encode headers - otherwise the provided value is used + var encoderEh = new UnacknowledgedInsertionTest.TestErrorHandler(); + var decoderEh = new UnacknowledgedInsertionTest.TestErrorHandler(); + var streamError = new AtomicReference(); + EncoderDecoderConnector.EncoderDecoderPair ed = + newPreconfiguredEncoderDecoder(encoderEh, decoderEh, streamError); + + + var encoder = ed.encoder(); + var decoder = ed.decoder(); + // Create a decoding callback to check for completion and to log failures + UnacknowledgedInsertionTest.TestDecodingCallback decodingCallback = + new UnacknowledgedInsertionTest.TestDecodingCallback(); + // Start encoding Headers Frame + var headerFrameWriter = encoder.newHeaderFrameWriter(); + var headerFrameReader = decoder.newHeaderFrameReader(decodingCallback); + + // create encoding context and buffer to hold encoded headers + List buffers = new ArrayList<>(); + + ByteBuffer headersBb = ByteBuffer.allocate(2048); + Encoder.EncodingContext context = encoder.newEncodingContext(0, 0, + headerFrameWriter); + + String name = "name"; + String value = "value"; + + for (int i = 0; i < 3; i++) { + long krcToUse = knownReceiveCount == -1L ? encoder.knownReceivedCount() : knownReceiveCount; + if (i == 1) { + encoder.header(context, name, "nameMatchOnly", false, krcToUse); + } else { + encoder.header(context, name, value, false, krcToUse); + } + headerFrameWriter.write(headersBb); + } + + // Only two entries are expected to be inserted to the dynamic table + Assert.assertEquals(ed.decoderTable().insertCount(), 2); + + // Check that headers byte buffer is not empty + assertNotEquals(headersBb.position(), 0); + headersBb.flip(); + buffers.add(headersBb); + + // Generate field section prefix bytes + encoder.generateFieldLineSectionPrefix(context, buffers); + + // Use decoder to process generated byte buffers + decoder.decodeHeader(buffers.get(0), false, headerFrameReader); + decoder.decodeHeader(buffers.get(1), true, headerFrameReader); + + var actualHeaderEncodingTypes = decodingCallback.decodedHeaders + .stream() + .map(DecodedHeader::encodingType) + .toList(); + Assert.assertEquals(actualHeaderEncodingTypes, expectedHeadersEncodingType); + } + + + @DataProvider(name = "duplicateEntryInsertions") + private Object[][] duplicateEntryInsertionsData() { + return new Object[][]{ + {0, List.of(EncodingType.LITERAL, EncodingType.LITERAL, EncodingType.LITERAL)}, + {-1, List.of(EncodingType.LITERAL, EncodingType.NAME_REF, EncodingType.INDEXED)}, + }; + } + + private static EncoderDecoderConnector.EncoderDecoderPair newPreconfiguredEncoderDecoder( + UnacknowledgedInsertionTest.TestErrorHandler encoderEh, + UnacknowledgedInsertionTest.TestErrorHandler decoderEh, + AtomicReference streamError) { + EncoderDecoderConnector conn = new EncoderDecoderConnector(); + var pair = conn.newEncoderDecoderPair( + e -> true, + encoderEh::qpackErrorHandler, + decoderEh::qpackErrorHandler, + streamError::set); + // Create settings frame with dynamic table capacity and number of blocked streams + SettingsFrame settingsFrame = SettingsFrame.defaultRFCSettings(); + // 4k should be enough for storing dynamic table entries added by 'prepopulateDynamicTable' + settingsFrame.setParameter(SettingsFrame.SETTINGS_QPACK_MAX_TABLE_CAPACITY, DT_CAPACITY); + ConnectionSettings settings = ConnectionSettings.createFrom(settingsFrame); + + // Configure encoder and decoder with constructed ConnectionSettings + pair.encoder().configure(settings); + pair.decoder().configure(settings); + pair.encoderTable().setCapacity(DT_CAPACITY); + pair.decoderTable().setCapacity(DT_CAPACITY); + + return pair; + } + + private static class TestDecodingCallback implements DecodingCallback { + + final List decodedHeaders = new CopyOnWriteArrayList<>(); + final CompletableFuture completed = new CompletableFuture<>(); + final AtomicLong completedTimestamp = new AtomicLong(); + + final AtomicReference lastThrowable = new AtomicReference<>(); + final AtomicReference lastHttp3Error = new AtomicReference<>(); + + @Override + public void onDecoded(CharSequence name, CharSequence value) { + Assert.fail("onDecoded not expected to be called"); + } + + @Override + public void onLiteralWithLiteralName(CharSequence name, boolean nameHuffman, + CharSequence value, boolean valueHuffman, + boolean hideIntermediary) { + var header = new DecodedHeader(name.toString(), value.toString(), EncodingType.LITERAL); + decodedHeaders.add(header); + System.err.println("Decoding callback 'onLiteralWithLiteralName': " + header); + } + + @Override + public void onLiteralWithNameReference(long index, + CharSequence name, + CharSequence value, + boolean valueHuffman, + boolean hideIntermediary) { + var header = new DecodedHeader(name.toString(), value.toString(), EncodingType.NAME_REF); + decodedHeaders.add(header); + System.err.println("Decoding callback 'onLiteralWithNameReference': " + header); + + } + + @Override + public void onIndexed(long index, CharSequence name, CharSequence value) { + var header = new DecodedHeader(name.toString(), value.toString(), EncodingType.INDEXED); + decodedHeaders.add(header); + System.err.println("Decoding callback 'onIndexed': " + header); + } + + + @Override + public void onComplete() { + System.err.println("Decoding callback 'onComplete'"); + completedTimestamp.set(System.nanoTime()); + completed.complete(null); + } + + @Override + public void onConnectionError(Throwable throwable, Http3Error http3Error) { + System.err.println("Decoding callback 'onError': " + http3Error); + lastThrowable.set(throwable); + lastHttp3Error.set(http3Error); + } + + @Override + public long streamId() { + return 0; + } + } + + private static class TestErrorHandler { + final AtomicReference error = new AtomicReference<>(); + final AtomicReference http3Error = new AtomicReference<>(); + + public void qpackErrorHandler(Throwable error, Http3Error http3Error) { + this.error.set(error); + this.http3Error.set(http3Error); + throw new RuntimeException("http3 error: " + http3Error, error); + } + } + + private record DecodedHeader(String name, String value, EncodingType encodingType) { + } + + enum EncodingType { + LITERAL, + NAME_REF, + INDEXED + } + + private static final long DT_CAPACITY = 4096L; +} diff --git a/test/jdk/java/net/httpclient/quic/AckElicitingTest.java b/test/jdk/java/net/httpclient/quic/AckElicitingTest.java new file mode 100644 index 00000000000..47fdb935598 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/AckElicitingTest.java @@ -0,0 +1,729 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HexFormat; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.function.Function; +import java.util.function.IntFunction; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import jdk.internal.net.http.quic.PeerConnectionId; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicOneRttContext; +import jdk.internal.net.quic.QuicVersion; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.frames.AckFrame.AckRange; +import jdk.internal.net.http.quic.frames.ConnectionCloseFrame; +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.DataBlockedFrame; +import jdk.internal.net.http.quic.frames.HandshakeDoneFrame; +import jdk.internal.net.http.quic.frames.MaxDataFrame; +import jdk.internal.net.http.quic.frames.MaxStreamDataFrame; +import jdk.internal.net.http.quic.frames.MaxStreamsFrame; +import jdk.internal.net.http.quic.frames.NewConnectionIDFrame; +import jdk.internal.net.http.quic.frames.NewTokenFrame; +import jdk.internal.net.http.quic.frames.PaddingFrame; +import jdk.internal.net.http.quic.frames.PathChallengeFrame; +import jdk.internal.net.http.quic.frames.PathResponseFrame; +import jdk.internal.net.http.quic.frames.PingFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.ResetStreamFrame; +import jdk.internal.net.http.quic.frames.RetireConnectionIDFrame; +import jdk.internal.net.http.quic.frames.StopSendingFrame; +import jdk.internal.net.http.quic.frames.StreamDataBlockedFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; +import jdk.internal.net.http.quic.frames.StreamsBlockedFrame; +import jdk.internal.net.http.quic.packets.QuicPacketDecoder; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder; +import jdk.internal.net.http.quic.CodingContext; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicTransportParametersConsumer; +import jdk.test.lib.RandomFactory; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.crypto.AEADBadTagException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; + +import static jdk.internal.net.http.quic.frames.QuicFrame.*; +import static jdk.internal.net.http.quic.frames.ConnectionCloseFrame.CONNECTION_CLOSE_VARIANT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +/** + * @test + * @summary tests the logic to decide whether a packet or + * a frame is ACK-eliciting. + * @library /test/lib + * @run testng AckElicitingTest + * @run testng/othervm -Dseed=-7997973196290088038 AckElicitingTest + */ +public class AckElicitingTest { + + static final Random RANDOM = RandomFactory.getRandom(); + + private static class DummyQuicTLSEngine implements QuicTLSEngine { + @Override + public HandshakeState getHandshakeState() { + throw new AssertionError("should not come here!"); + } + + @Override + public boolean isTLSHandshakeComplete() { + return true; + } + + @Override + public KeySpace getCurrentSendKeySpace() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean keysAvailable(KeySpace keySpace) { + return true; + } + + @Override + public void discardKeys(KeySpace keySpace) { + // no-op + } + + @Override + public void setLocalQuicTransportParameters(ByteBuffer params) { + throw new AssertionError("should not come here!"); + } + + @Override + public void restartHandshake() throws IOException { + throw new AssertionError("should not come here!"); + } + + @Override + public void setRemoteQuicTransportParametersConsumer(QuicTransportParametersConsumer consumer) { + throw new AssertionError("should not come here!"); + } + + @Override + public void deriveInitialKeys(QuicVersion version, ByteBuffer connectionId) { } + + @Override + public int getHeaderProtectionSampleSize(KeySpace keySpace) { + return 0; + } + @Override + public ByteBuffer computeHeaderProtectionMask(KeySpace keySpace, boolean incoming, ByteBuffer sample) { + return ByteBuffer.allocate(5); + } + + @Override + public int getAuthTagSize() { + return 0; + } + + @Override + public void encryptPacket(KeySpace keySpace, long packetNumber, + IntFunction headerGenerator, + ByteBuffer packetPayload, ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException { + // this dummy QUIC TLS engine doesn't do any encryption. + // we just copy over the raw packet payload into the output buffer + output.put(packetPayload); + } + + @Override + public void decryptPacket(KeySpace keySpace, long packetNumber, int keyPhase, + ByteBuffer packet, int headerLength, ByteBuffer output) { + packet.position(packet.position() + headerLength); + output.put(packet); + } + @Override + public void signRetryPacket(QuicVersion quicVersion, + ByteBuffer originalConnectionId, ByteBuffer packet, ByteBuffer output) { + throw new AssertionError("should not come here!"); + } + @Override + public void verifyRetryPacket(QuicVersion quicVersion, + ByteBuffer originalConnectionId, ByteBuffer packet) throws AEADBadTagException { + throw new AssertionError("should not come here!"); + } + @Override + public ByteBuffer getHandshakeBytes(KeySpace keySpace) { + throw new AssertionError("should not come here!"); + } + @Override + public void consumeHandshakeBytes(KeySpace keySpace, ByteBuffer payload) { + throw new AssertionError("should not come here!"); + } + @Override + public Runnable getDelegatedTask() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean tryMarkHandshakeDone() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean tryReceiveHandshakeDone() { + throw new AssertionError("should not come here!"); + } + + @Override + public Set getSupportedQuicVersions() { + return Set.of(QuicVersion.QUIC_V1); + } + + @Override + public void setUseClientMode(boolean mode) { + throw new AssertionError("should not come here!"); + } + + @Override + public boolean getUseClientMode() { + throw new AssertionError("should not come here!"); + } + + @Override + public SSLParameters getSSLParameters() { + throw new AssertionError("should not come here!"); + } + + @Override + public void setSSLParameters(SSLParameters sslParameters) { + throw new AssertionError("should not come here!"); + } + + @Override + public String getApplicationProtocol() { + return null; + } + + @Override + public SSLSession getSession() { throw new AssertionError("should not come here!"); } + + @Override + public SSLSession getHandshakeSession() { throw new AssertionError("should not come here!"); } + + @Override + public void versionNegotiated(QuicVersion quicVersion) { + // no-op + } + + @Override + public void setOneRttContext(QuicOneRttContext ctx) { + // no-op + } + } + private static final QuicTLSEngine TLS_ENGINE = new DummyQuicTLSEngine(); + private static abstract class TestCodingContext implements CodingContext { + TestCodingContext() { } + @Override + public int writePacket(QuicPacket packet, ByteBuffer buffer) { + throw new AssertionError("should not come here!"); + } + @Override + public QuicPacket parsePacket(ByteBuffer src) throws IOException { + throw new AssertionError("should not come here!"); + } + @Override + public boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + return true; + } + @Override + public QuicTLSEngine getTLSEngine() { + return TLS_ENGINE; + } + } + + static final int CIDLEN = RANDOM.nextInt(5, QuicConnectionId.MAX_CONNECTION_ID_LENGTH + 1); + + private static final TestCodingContext CONTEXT = new TestCodingContext() { + + @Override + public long largestProcessedPN(PacketNumberSpace packetSpace) { + return 0; + } + + @Override + public long largestAckedPN(PacketNumberSpace packetSpace) { + return 0; + } + + @Override + public int connectionIdLength() { + return CIDLEN; + } + + @Override + public QuicConnectionId originalServerConnId() { + return null; + } + }; + + /** + * A record to store all the input of a given test case. + * @param type a concrete frame type or packet type + * @param describer a function to describe the {@code obj} instance + * for tracing/diagnosis purposes + * @param ackEliciting the function we want to test. This is either + * {@link QuicFrame#isAckEliciting() + * QuicFrame::isAckEliciting} or + * {@link QuicPacket#isAckEliciting() + * QuicPacket::isAckEliciting} + * @param obj the instance on which to call the {@code ackEliciting} + * function. + * @param expected the expected result of calling + * {@code obj.ackEliciting()} + * @param A concrete subclass of {@link QuicFrame} or {@link QuicPacket} + */ + static record TestCase(Class type, + Function describer, + Predicate ackEliciting, + T obj, + boolean expected) { + + @Override + public String toString() { + // shorter & better toString than the default + return "%s(%s)" + .formatted(type.getSimpleName(), describer.apply(obj)); + } + + private static String describeFrame(QuicFrame frame) { + long type = frame.getTypeField(); + return HexFormat.of().toHexDigits((byte)type); + } + + /** + * Creates an instance of {@code TestCase} for a concrete frame type + * @param type the concrete frame class + * @param frame the concrete instance + * @param expected whether {@link QuicFrame#isAckEliciting()} + * should return true for that instance. + * @param a concrete subclass of {@code QuicFrame} + * @return a new instance of {@code TestCase} + */ + public static TestCase + of(Class type, T frame, boolean expected) { + return new TestCase(type, TestCase::describeFrame, + QuicFrame::isAckEliciting, frame, expected); + } + + /** + * Creates an instance of {@code TestCase} for a concrete frame type + * @param frame the concrete frame instance + * @param expected whether {@link QuicFrame#isAckEliciting()} + * should return true for that instance. + * @param a concrete subclass of {@code QuicFrame} + * @return a new instance of {@code TestCase} + */ + public static TestCase of(T frame, boolean expected) { + return new TestCase((Class)frame.getClass(), + TestCase::describeFrame, + QuicFrame::isAckEliciting, + frame, expected); + } + + /** + * Creates an instance of {@code TestCase} for a concrete packet type + * @param packet the concrete packet instance + * @param expected whether {@link QuicPacket#isAckEliciting()} + * should return true for that instance. + * @param a concrete subclass of {@code QuicPacket} + * @return a new instance of {@code TestCase} + */ + public static TestCase of(T packet, boolean expected) { + return new TestCase((Class)packet.getClass(), + (p) -> p.frames().stream() + .map(Object::getClass) + .map(Class::getSimpleName) + .collect(Collectors.joining(", ")), + QuicPacket::isAckEliciting, + packet, expected); + } + } + + // convenient alias to shorten lines in data providers + private static TestCase of(T frame, boolean ackEliciting) { + return TestCase.of(frame, ackEliciting); + } + + // convenient alias to shorten lines in data providers + private static TestCase of(T packet, boolean ackEliciting) { + return TestCase.of(packet, ackEliciting); + } + + /** + * Create a new instance of the given frame type, populated with + * dummy values. + * @param frameClass the frame type + * @param a concrete subclass of {@code QuicFrame} + * @return a new instance of the given concrete class. + */ + T newFrame(Class frameClass) { + var frameType = QuicFrame.frameTypeOf(frameClass); + if (frameType == CONNECTION_CLOSE) { + if (RANDOM.nextBoolean()) { + frameType = CONNECTION_CLOSE_VARIANT; + } + } + long streamId = 4; + long largestAcknowledge = 3; + long ackDelay = 20; + long gap = 0; + long range = largestAcknowledge; + long offset = 0; + boolean fin = false; + int length = 10; + long errorCode = 1; + long finalSize = 10; + int size = 3; + long maxData = 10; + boolean maxStreamsBidi = true; + long maxStreams = 100; + long maxStreamData = 10; + long sequenceNumber = 4; + long retirePriorTo = 3; + String reason = "none"; + long errorFrameType = ACK; + int pathChallengeLen = PathChallengeFrame.LENGTH; + int pathResponseLen = PathResponseFrame.LENGTH; + var frame = switch (frameType) { + case ACK -> new AckFrame(largestAcknowledge, ackDelay, List.of(new AckRange(gap, range))); + case STREAM -> new StreamFrame(streamId, offset, length, fin, ByteBuffer.allocate(length)); + case RESET_STREAM -> new ResetStreamFrame(streamId, errorCode, finalSize); + case PADDING -> new PaddingFrame(size); + case PING -> new PingFrame(); + case STOP_SENDING -> new StopSendingFrame(streamId, errorCode); + case CRYPTO -> new CryptoFrame(offset, length, ByteBuffer.allocate(length)); + case NEW_TOKEN -> new NewTokenFrame(ByteBuffer.allocate(length)); + case DATA_BLOCKED -> new DataBlockedFrame(maxData); + case MAX_DATA -> new MaxDataFrame(maxData); + case MAX_STREAMS -> new MaxStreamsFrame(maxStreamsBidi, maxStreams); + case MAX_STREAM_DATA -> new MaxStreamDataFrame(streamId, maxStreamData); + case STREAM_DATA_BLOCKED -> new StreamDataBlockedFrame(streamId, maxStreamData); + case STREAMS_BLOCKED -> new StreamsBlockedFrame(maxStreamsBidi, maxStreams); + case NEW_CONNECTION_ID -> new NewConnectionIDFrame(sequenceNumber, retirePriorTo, + ByteBuffer.allocate(length), ByteBuffer.allocate(16)); + case RETIRE_CONNECTION_ID -> new RetireConnectionIDFrame(sequenceNumber); + case PATH_CHALLENGE -> new PathChallengeFrame(ByteBuffer.allocate(pathChallengeLen)); + case PATH_RESPONSE -> new PathResponseFrame(ByteBuffer.allocate(pathResponseLen)); + case CONNECTION_CLOSE -> new ConnectionCloseFrame(errorCode, errorFrameType, reason); + case CONNECTION_CLOSE_VARIANT -> new ConnectionCloseFrame(errorCode, reason); + case HANDSHAKE_DONE -> new HandshakeDoneFrame(); + default -> throw new IllegalArgumentException("Unrecognised frame"); + }; + return frameClass.cast(frame); + } + + /** + * Creates a list of {@code TestCase} to test all possible concrete + * subclasses of {@code QuicFrame}. + * + * @return a list of {@code TestCase} to test all possible concrete + * subclasses of {@code QuicFrame} + */ + public List> createFramesTests() { + List> frames = new ArrayList<>(); + frames.add(of(newFrame(AckFrame.class), false)); + frames.add(of(newFrame(ConnectionCloseFrame.class), false)); + frames.add(of(newFrame(PaddingFrame.class), false)); + + for (var frameType : QuicFrame.class.getPermittedSubclasses()) { + if (frameType == AckFrame.class) continue; + if (frameType == ConnectionCloseFrame.class) continue; + if (frameType == PaddingFrame.class) continue; + Class quicFrameClass = (Class)frameType; + frames.add(of(newFrame(quicFrameClass), true)); + } + + return List.copyOf(frames); + } + + /** + * Creates a {@code QuicPacket} containing the given list of frames. + * @param frames a list of frames + * @return a new instance of {@code QuicPacket} + */ + QuicPacket createPacket(List frames) { + PacketType[] values = PacketType.values(); + int index = PacketType.NONE.ordinal(); + while (index == PacketType.NONE.ordinal()) { + index = RANDOM.nextInt(0, values.length); + } + PacketType packetType = values[index]; + QuicPacketEncoder encoder = QuicPacketEncoder.of(QuicVersion.QUIC_V1); + byte[] scid = new byte[CIDLEN]; + RANDOM.nextBytes(scid); + byte[] dcid = new byte[CIDLEN]; + RANDOM.nextBytes(dcid); + QuicConnectionId source = new PeerConnectionId(scid); + QuicConnectionId dest = new PeerConnectionId(dcid); + long largestAckedPacket = CONTEXT.largestAckedPN(packetType); + QuicPacket packet = switch (packetType) { + case NONE -> throw new AssertionError("should not come here"); + // TODO: add more packet types + default -> encoder.newInitialPacket(source, dest, null, largestAckedPacket + 1 , + largestAckedPacket, frames, CONTEXT); + }; + return packet; + + } + + /** + * Creates a random instance of {@code QuicPacket} containing a + * pseudo random list of concrete {@link QuicFrame} instances. + * @param ackEliciting whether the returned packet should be + * ack eliciting. + * @return + */ + QuicPacket createPacket(boolean ackEliciting) { + List frames = new ArrayList<>(); + int mincount = ackEliciting ? 1 : 0; + int ackCount = RANDOM.nextInt(mincount, 5); + int nackCount = RANDOM.nextInt(0, 10); + + // TODO: maybe refactor this to make sure the frame + // we use are compatible with the packet type. + List> noAckFrames = List.of(AckFrame.class, + PaddingFrame.class, ConnectionCloseFrame.class); + for (int i=0; i < nackCount ; i++) { + frames.add(newFrame(noAckFrames.get(i % noAckFrames.size()))); + } + if (ackEliciting) { + // TODO: maybe refactor this to make sure the frame + // we use are compatible with the packet type. + Class[] frameClasses = QuicFrame.class.getPermittedSubclasses(); + for (int i=0; i < ackCount; i++) { + Class selected; + do { + int fx = RANDOM.nextInt(0, frameClasses.length); + selected = frameClasses[fx]; + } while (noAckFrames.contains(selected)); + frames.add(newFrame((Class) selected)); + } + } + if (!ackEliciting || RANDOM.nextBoolean()) { + // if !ackEliciting we always shuffle. + // Otherwise, we only shuffle half the time. + Collections.shuffle(frames, RANDOM); + } + return createPacket(mergeConsecutivePaddingFrames(frames)); + } + + private List mergeConsecutivePaddingFrames(List frames) { + var iterator = frames.listIterator(); + QuicFrame previous = null; + + while (iterator.hasNext()) { + var frame = iterator.next(); + if (previous instanceof PaddingFrame prevPad + && frame instanceof PaddingFrame nextPad) { + int previousIndex = iterator.previousIndex(); + QuicFrame merged = new PaddingFrame(prevPad.size() + nextPad.size()); + frames.set(previousIndex, merged); + iterator.remove(); + } else { + previous = frame; + } + } + return frames; + } + + /** + * Creates a list of {@code TestCase} to test random instances of + * {@code QuicPacket} containing random instances of {@link QuicFrame} + * @return a list of {@code TestCase} to test random instances of + * {@code QuicPacket} containing random instances of {@link QuicFrame} + */ + public List> createPacketsTests() { + List> packets = new ArrayList<>(); + packets.add(of(createPacket(List.of(newFrame(AckFrame.class))), false)); + packets.add(of(createPacket(List.of(newFrame(ConnectionCloseFrame.class))), false)); + packets.add(of(createPacket(List.of(newFrame(PaddingFrame.class))), false)); + var frames = new ArrayList<>(List.of( + newFrame(PaddingFrame.class), + newFrame(AckFrame.class), + newFrame(ConnectionCloseFrame.class))); + Collections.shuffle(frames, RANDOM); + packets.add(of(createPacket(List.copyOf(frames)), false)); + + int maxPackets = RANDOM.nextInt(5, 11); + for (int i = 0; i < maxPackets ; i++) { + packets.add(of(createPacket(true), true)); + } + return List.copyOf(packets); + } + + + /** + * A provider of test case to test + * {@link QuicFrame#isAckEliciting()}. + * @return test case to test + * {@link QuicFrame#isAckEliciting()} + */ + @DataProvider(name = "frames") + public Object[][] framesDataProvider() { + return createFramesTests().stream() + .map(List::of) + .map(List::toArray) + .toArray(Object[][]::new); + } + + /** + * A provider of test case to test + * {@link QuicPacket#isAckEliciting()}. + * @return test case to test + * {@link QuicPacket#isAckEliciting()} + */ + @DataProvider(name = "packets") + public Object[][] packetsDataProvider() { + return createPacketsTests().stream() + .map(List::of) + .map(List::toArray) + .toArray(Object[][]::new); + } + + /** + * Verifies the behavior of {@link QuicFrame#isAckEliciting()} + * with the given test case inputs. + * @param test the test inputs + * @param a concrete subclass of QuicFrame + */ + @Test(dataProvider = "frames") + public void testFrames(TestCase test) { + testAckEliciting(test.type(), + test.describer(), + test.ackEliciting(), + test.obj(), + test.expected()); + } + + /** + * Verifies the behavior of {@link QuicPacket#isAckEliciting()} + * with the given test case inputs. + * @param test the test inputs + * @param a concrete subclass of QuickPacket + */ + @Test(dataProvider = "packets") + public void testPackets(TestCase test) { + testAckEliciting(test.type(), + test.describer(), + test.ackEliciting(), + test.obj(), + test.expected()); + } + + + /** + * Asserts that {@code ackEliciting.test(obj) == expected}. + * @param type the concrete type of {@code obj} + * @param describer a function to describe {@code obj} + * @param ackEliciting the function being tested + * @param obj the instance on which to call the function being tested + * @param expected the expected result of {@code ackEliciting.test(obj)} + * @param the concrete class being tested + */ + private void testAckEliciting(Class type, + Function describer, + Predicate ackEliciting, + T obj, + boolean expected) { + System.out.printf("%ntestAckEliciting: %s(%s) - expecting %s%n", + type.getSimpleName(), + describer.apply(obj), + expected); + assertEquals(ackEliciting.test(obj), expected, describer.apply(obj)); + if (obj instanceof QuicFrame frame) { + checkFrame(frame); + } else if (obj instanceof QuicPacket packet) { + checkPacket(packet); + } + } + + // This is not a full-fledged test for frame encoding/decoding. + // Just a smoke test to verify that the ACK-eliciting property + // survives encoding/decoding + private void checkFrame(QuicFrame frame) { + int size = frame.size(); + ByteBuffer buffer = ByteBuffer.allocate(size); + System.out.println("Checking frame: " + frame.getClass()); + try { + frame.encode(buffer); + buffer.flip(); + var decoded = QuicFrame.decode(buffer); + checkFrame(decoded, frame); + } catch (QuicTransportException x) { + throw new AssertionError(frame.getClass().getName(), x); + } + } + + // This is not a full-fledged test for frame equality: + // And we still need a proper decoding/encoding test for frames + private void checkFrame(QuicFrame decoded, QuicFrame expected) { + System.out.printf("Comparing frames: %s with %s%n", + decoded.getClass().getSimpleName(), + expected.getClass().getSimpleName()); + assertEquals(decoded.getClass(), expected.getClass()); + assertEquals(decoded.size(), expected.size()); + assertEquals(decoded.getTypeField(), expected.getTypeField()); + assertEquals(decoded.isAckEliciting(), expected.isAckEliciting()); + } + + // This is not a full-fledged test for packet encoding/decoding. + // Just a smoke test to verify that the ACK-eliciting property + // survives encoding/decoding + private void checkPacket(QuicPacket packet) { + int size = packet.size(); + ByteBuffer buffer = ByteBuffer.allocate(size); + System.out.println("Checking packet: " + packet.getClass()); + try { + var encoder = QuicPacketEncoder.of(QuicVersion.QUIC_V1); + var decoder = QuicPacketDecoder.of(QuicVersion.QUIC_V1); + encoder.encode(packet, buffer, CONTEXT); + buffer.flip(); + var decoded = decoder.decode(buffer, CONTEXT); + assertEquals(decoded.size(), packet.size()); + assertEquals(decoded.packetType(), packet.packetType()); + assertEquals(decoded.payloadSize(), packet.payloadSize()); + assertEquals(decoded.isAckEliciting(), packet.isAckEliciting()); + var frames = packet.frames(); + var decodedFrames = decoded.frames(); + assertEquals(decodedFrames.size(), frames.size()); + } catch (Exception x) { + throw new AssertionError(packet.getClass().getName(), x); + } + + } +} diff --git a/test/jdk/java/net/httpclient/quic/AckFrameTest.java b/test/jdk/java/net/httpclient/quic/AckFrameTest.java new file mode 100644 index 00000000000..129394f126c --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/AckFrameTest.java @@ -0,0 +1,387 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.quic.CodingContext; +import jdk.internal.net.http.quic.frames.AckFrame; +import jdk.internal.net.http.quic.frames.AckFrame.AckFrameBuilder; +import jdk.internal.net.http.quic.frames.AckFrame.AckRange; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.test.lib.RandomFactory; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.function.LongPredicate; +import java.util.stream.LongStream; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.assertFalse; + +/** + * @test + * @summary tests the logic to build an AckFrame + * @library /test/lib + * @run testng AckFrameTest + */ +public class AckFrameTest { + + static final Random RANDOM = RandomFactory.getRandom(); + + private static abstract class TestCodingContext implements CodingContext { + TestCodingContext() { } + @Override + public int writePacket(QuicPacket packet, ByteBuffer buffer) { + throw new AssertionError("should not come here!"); + } + @Override + public QuicPacket parsePacket(ByteBuffer src) throws IOException { + throw new AssertionError("should not come here!"); + } + @Override + public boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + return true; + } + @Override + public QuicTLSEngine getTLSEngine() { + throw new AssertionError("should not come here!"); + } + } + + static final int CIDLEN = RANDOM.nextInt(5, QuicConnectionId.MAX_CONNECTION_ID_LENGTH + 1); + + private static final TestCodingContext CONTEXT = new TestCodingContext() { + + @Override + public long largestProcessedPN(PacketNumberSpace packetSpace) { + return 0; + } + + @Override + public long largestAckedPN(PacketNumberSpace packetSpace) { + return 0; + } + + @Override + public int connectionIdLength() { + return CIDLEN; + } + + @Override + public QuicConnectionId originalServerConnId() { + return null; + } + }; + + public static record Acknowledged(long first, long last) { + public boolean contains(long packet) { + return first <= packet && last >= packet; + } + public static List of(long... numbers) { + if (numbers == null || numbers.length == 0) return List.of(); + if (numbers.length%2 != 0) throw new IllegalArgumentException(); + List res = new ArrayList<>(numbers.length/2); + for (int i = 0; i < numbers.length; i += 2) { + res.add(new Acknowledged(numbers[i], numbers[i+1])); + } + return List.copyOf(res); + } + } + public static record Packet(long packetNumber) { + static List ofAcks(List acks) { + return packets(acks); + } + static List of(long... numbers) { + return LongStream.of(numbers).mapToObj(Packet::new).toList(); + } + } + + public static record TestCase(List acks, List packets, boolean shuffled) { + public TestCase(List acks) { + this(acks, Packet.ofAcks(acks), false); + } + public TestCase shuffle() { + List shuffled = new ArrayList<>(); + shuffled.addAll(packets); + Collections.shuffle(shuffled, RANDOM); + return new TestCase(acks, List.copyOf(shuffled), true); + } + } + + List generateTests() { + List tests = new ArrayList<>(); + List simples = List.of( + new TestCase(List.of(new Acknowledged(5,5))), + new TestCase(List.of(new Acknowledged(5,7))), + new TestCase(List.of(new Acknowledged(3, 5), new Acknowledged(7,9))), + new TestCase(List.of(new Acknowledged(3, 5), new Acknowledged(7,7))), + new TestCase(List.of(new Acknowledged(3,3), new Acknowledged(5,7))) + ); + tests.addAll(simples); + List specials = List.of( + new TestCase(Acknowledged.of(5,5,7,7), Packet.of(5,7), false), + new TestCase(Acknowledged.of(5,7), Packet.of(5,7,6), true), + new TestCase(Acknowledged.of(6,7), Packet.of(6,7), false), + new TestCase(Acknowledged.of(5,7), Packet.of(6,7,5), true), + new TestCase(Acknowledged.of(5,7), Packet.of(5,6,7), true), + new TestCase(Acknowledged.of(5,5,7,8), Packet.of(5, 7, 8), true), + new TestCase(Acknowledged.of(5,5,8,8), Packet.of(8, 5), true), + new TestCase(Acknowledged.of(5,5,7,8), Packet.of(8, 5, 7), true), + new TestCase(Acknowledged.of(3,5,7,9), Packet.of(8,5,7,4,9,3), true), + new TestCase(Acknowledged.of(27,27,31,31), + Packet.of(27, 31), true), + new TestCase(Acknowledged.of(27,27,29,29,31,31), + Packet.of(27, 31, 29), true), + new TestCase(Acknowledged.of(3,5,7,7,9,9,22,22,27,27,29,29,31,31), + Packet.of(4,22,27,31,9,29,7,5,3), true) + ); + tests.addAll(specials); + for (int i=0; i < 5; i++) { + List acks = generateAcks(); + List packets = packets(acks); + TestCase test = new TestCase(acks, List.copyOf(packets), false); + tests.add(test); + for (int j = 0; j < 5; j++) { + tests.add(test.shuffle()); + } + } + return tests; + } + + List generateAcks() { + int count = RANDOM.nextInt(3, 10); + List acks = new ArrayList<>(count); + long prev = -1; + for (int i=0; i packets(List acks) { + List res = new ArrayList<>(); + for (Acknowledged ack : acks) { + for (long i = ack.first() ; i<= ack.last() ; i++) { + var packet = new Packet(i); + assert !res.contains(packet); + res.add(packet); + } + } + return res; + } + + @DataProvider(name = "tests") + public Object[][] tests() { + return generateTests().stream() + .map(List::of) + .map(List::toArray) + .toArray(Object[][]::new); + } + + @Test(dataProvider = "tests") + public void testAckFrames(TestCase testCase) { + AckFrameBuilder builder = new AckFrameBuilder(); + List acks = testCase.acks; + List packets = testCase.packets; + long largest = packets.stream() + .mapToLong(Packet::packetNumber) + .max().getAsLong(); + System.out.printf("%ntestAckFrames(%s, %s)%n", acks, testCase.shuffled); + builder.ackDelay(250); + packets.stream().mapToLong(Packet::packetNumber).forEach(builder::addAck); + AckFrame frame = builder.build(); + System.out.printf(" -> %s%n", frame); + checkFrame(frame, testCase, packets, frame); + checkAcknowledging(builder::isAcknowledging, testCase, packets); + + AckFrameBuilder dup = new AckFrameBuilder(frame); + assertEquals(frame, dup.build()); + assertEquals(frame, builder.build()); + checkAcknowledging(dup::isAcknowledging, testCase, packets); + + packets.stream().mapToLong(Packet::packetNumber).forEach(builder::addAck); + checkFrame(builder.build(), testCase, packets, frame); + checkAcknowledging(builder::isAcknowledging, testCase, packets); + + packets.stream().mapToLong(Packet::packetNumber).forEach(dup::addAck); + checkFrame(dup.build(), testCase, packets, frame); + checkAcknowledging(dup::isAcknowledging, testCase, packets); + + AckFrameBuilder dupdup = new AckFrameBuilder(); + dupdup.ackDelay(250); + List dups = new ArrayList<>(packets); + dups.addAll(packets); + dups.addAll(packets); + Collections.shuffle(dups, RANDOM); + dups.stream().mapToLong(Packet::packetNumber).forEach(dupdup::addAck); + checkFrame(dupdup.build(), testCase, dups, frame); + checkAcknowledging(dupdup::isAcknowledging, testCase, packets); + + } + + private void checkFrame(AckFrame frame, TestCase testCase, List packets, AckFrame reference) { + long largest = testCase.packets.stream() + .mapToLong(Packet::packetNumber) + .max().getAsLong(); + assertEquals(frame.largestAcknowledged(), largest); + checkAcknowledging(frame::isAcknowledging, testCase, packets); + for (var ack : testCase.acks) { + checkRangeAcknowledged(frame, ack.first, ack.last); + } + assertEquals(frame, reference); + int size = frame.size(); + ByteBuffer buffer = ByteBuffer.allocate(size + 10); + buffer.position(5); + buffer.limit(size + 5); + try { + frame.encode(buffer); + assertEquals(buffer.position(), buffer.limit()); + buffer.position(5); + buffer.limit(buffer.capacity()); + var decoded = QuicFrame.decode(buffer); + assertEquals(buffer.position(), size + 5); + assertEquals(decoded, frame); + assertEquals(decoded, reference); + } catch (Exception e) { + throw new AssertionError("Can't encode or decode frame: " + frame, e); + } + } + + private void checkRangeAcknowledged(AckFrame frame, long first, long last) { + assertTrue(frame.isRangeAcknowledged(first, last), + "range [%s, %s] should be acked".formatted(first, last)); + if (first > 0) { + if (!frame.isAcknowledging(first - 1)) { + assertFalse(frame.isRangeAcknowledged(first -1, last), + "range [%s, %s] should not be acked".formatted(first -1, last)); + } else { + assertTrue(frame.isRangeAcknowledged(first - 1, last), + "range [%s, %s] should be acked".formatted(first - 1, last)); + if (frame.isAcknowledging(last + 1)) { + assertTrue(frame.isRangeAcknowledged(first -1, last + 1), + "range [%s, %s] should be acked".formatted(first -1, last+1)); + } + } + } + if (!frame.isAcknowledging(last + 1)) { + assertFalse(frame.isRangeAcknowledged(first, last + 1), + "range [%s, %s] should not be acked".formatted(first, last + 1)); + } else { + assertTrue(frame.isRangeAcknowledged(first, last+1), + "range [%s, %s] should be acked".formatted(first, last + 1)); + } + if (last - 1 >= first) { + assertTrue(frame.isRangeAcknowledged(first + 1, last), + "range [%s, %s] should be acked".formatted(first + 1, last)); + assertTrue(frame.isRangeAcknowledged(first, last - 1), + "range [%s, %s] should be acked".formatted(first, last - 1)); + } + if (last - 2 >= first) { + assertTrue(frame.isRangeAcknowledged(first + 1, last - 1), + "range [%s, %s] should be acked".formatted(first + 1, last - 1)); + } + } + + private void checkAcknowledging(LongPredicate isAckPredicate, + TestCase testCase, + List packets) { + long largest = testCase.packets.stream() + .mapToLong(Packet::packetNumber) + .max().getAsLong(); + for (long i = largest + 10; i >= 0; i--) { + long pn = i; + boolean expected = testCase.acks.stream().anyMatch((a) -> a.contains(pn)); + boolean isAcknowledging = isAckPredicate.test(pn); + if (isAcknowledging != expected && testCase.shuffled) { + System.out.printf(" -> %s%n", packets); + } + assertEquals(isAcknowledging, expected, String.valueOf(pn)); + } + for (var p : testCase.packets) { + boolean isAcknowledging = isAckPredicate.test(p.packetNumber); + if (!isAcknowledging && testCase.shuffled) { + System.out.printf(" -> %s%n", packets); + } + assertEquals(isAcknowledging, true, p.toString()); + } + } + + @Test + public void simpleTest() { + AckFrame frame = new AckFrame(1, 0, List.of(new AckRange(0,0))); + System.out.println("simpleTest: " + frame); + assertTrue(frame.isAcknowledging(1), "1 should be acked"); + assertFalse(frame.isAcknowledging(0), "0 should not be acked"); + assertFalse(frame.isAcknowledging(2), "2 should not be acked"); + assertEquals(frame.smallestAcknowledged(), 1); + assertEquals(frame.largestAcknowledged(), 1); + assertEquals(frame.acknowledged().toArray(), new long[] {1L}); + assertTrue(frame.isRangeAcknowledged(1,1), "[1,1] should be acked"); + assertFalse(frame.isRangeAcknowledged(0, 1), "[0,1] should not be acked"); + assertFalse(frame.isRangeAcknowledged(1, 2), "[1,2] should not be acked"); + assertFalse(frame.isRangeAcknowledged(0, 2), "[0,2] should not be acked"); + + frame = new AckFrame(1, 0, List.of(new AckRange(0,1))); + System.out.println("simpleTest: " + frame); + assertTrue(frame.isAcknowledging(1), "1 should be acked"); + assertTrue(frame.isAcknowledging(0), "0 should be acked"); + assertFalse(frame.isAcknowledging(2), "2 should not be acked"); + assertEquals(frame.smallestAcknowledged(), 0); + assertEquals(frame.largestAcknowledged(), 1); + assertEquals(frame.acknowledged().toArray(), new long[] {1L, 0L}); + assertTrue(frame.isRangeAcknowledged(0,0), "[0,0] should be acked"); + assertTrue(frame.isRangeAcknowledged(1,1), "[1,1] should be acked"); + assertTrue(frame.isRangeAcknowledged(0, 1), "[0,1] should be acked"); + assertFalse(frame.isRangeAcknowledged(1, 2), "[1,2] should not be acked"); + assertFalse(frame.isRangeAcknowledged(0, 2), "[0,2] should not be acked"); + + frame = new AckFrame(10, 0, List.of(new AckRange(0,3), new AckRange(2, 3))); + System.out.println("simpleTest: " + frame); + assertTrue(frame.isAcknowledging(10), "10 should be acked"); + assertTrue(frame.isAcknowledging(0), "0 should be acked"); + assertTrue(frame.isRangeAcknowledged(0, 3), "[0,3] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[7,10] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[0,2] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[1,3] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[1,2] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[7,9] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[8,10] should be acked"); + assertTrue(frame.isRangeAcknowledged(7, 10), "[8,9] should be acked"); + assertFalse(frame.isRangeAcknowledged(0, 10), "[0,10] should not be acked"); + assertFalse(frame.isRangeAcknowledged(4, 6), "[4,6] should not be acked"); + assertFalse(frame.isRangeAcknowledged(4, 6), "[3,7] should not be acked"); + assertFalse(frame.isRangeAcknowledged(4, 6), "[2,8] should not be acked"); + } + } diff --git a/test/jdk/java/net/httpclient/quic/BuffersReaderTest.java b/test/jdk/java/net/httpclient/quic/BuffersReaderTest.java new file mode 100644 index 00000000000..8d03f4265b3 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/BuffersReaderTest.java @@ -0,0 +1,520 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.LongStream; + +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.BuffersReader.ListBuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; + +import jdk.test.lib.RandomFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.junit.jupiter.api.Assertions.*; + + +/* + * @test + * @library /test/lib + * @modules java.net.http/jdk.internal.net.http.quic + * @run junit/othervm BuffersReaderTest + * @summary Tests various BuffersReader methods + * work as expected. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class BuffersReaderTest { + static final Class IAE = IllegalArgumentException.class; + + static final Random RAND = RandomFactory.getRandom(); + static final int GENERATED = 2; + + // describes a byte buffer at a given global offset + record BB(long globalOffset, int position, int length, int capacity) {} + record Simple(long position, long index, int expected) {} + + record TestCase(List bbs, List simples) {} + + + // describes a BuffersReader configuration composed of 5 bytes buffer + // added with various position, limit, and capacity (limit = position + length) + List specialCases = List.of(new TestCase(List.of( + new BB(0, 10, 10, 30), + new BB(10, 5, 10, 20), + new BB(20, 15, 10, 40), + new BB(30, 0, 10, 20), + new BB(40, 5, 10, 20)), + List.of(new Simple(11, 50, 40)) + )); + + + private List tests() { + int generated = 2; + List allcases = new ArrayList<>(specialCases.size() + GENERATED); + allcases.addAll(specialCases); + for (int i = 0; i < GENERATED; i++) { + allcases.add(new TestCase(generateBBs(), List.of())); + } + return allcases; + } + + private List generateBBs() { + var bbscount = RAND.nextInt(1, 11); + List bbs = new ArrayList<>(bbscount); + long globalOffset = 0; + for (int i = 0; i < bbscount; i++) { + int length = RAND.nextInt(1, 11); + int offset = RAND.nextInt(0,3); + int tail = RAND.nextInt(0,3); + bbs.add(new BB(globalOffset, offset, length, offset + length + tail)); + globalOffset += length; + } + return List.copyOf(bbs); + } + + + @Test + public void testGet() { + test("hello world".getBytes(StandardCharsets.US_ASCII), 2, 10); + } + + @Test + public void testGetPos6() { + test("May the road rise up to meet you".getBytes(StandardCharsets.US_ASCII), 6, 23); + } + + @Test + public void testGetPos0() { + test("May the wind always be at your back".getBytes(StandardCharsets.US_ASCII), 0, 29); + } + + public void test(byte[] values, int position, int limit) { + ByteBuffer bb = ByteBuffer.wrap(values); + bb.position(position); + bb.limit(limit); + + ListBuffersReader br = BuffersReader.list(bb); + assertEquals(br.position(), position); + assertEquals(br.limit(), limit); + for (int i = position; i < limit; i++) { + int j = limit - (i - position) - 1; + System.err.printf("%ntesting(v[i:%s]=%s, v[j:%s]=%s)%n", + i, values[i], j, values[j]); + assertEquals(br.position(), i); + System.err.printf("assertEquals((char)br.get(%s), (char)values[%s])%n", i, i); + assertEquals((char)br.get(i), (char)values[i]); + System.err.printf("assertEquals((char)br.get(%s), (char)values[%s])%n", j, j); + assertEquals((char)br.get(j), (char)values[j]); + assertEquals(br.position(), i); + System.err.printf("assertEquals((char)br.get(), (char)values[%s])%n", i); + assertEquals((char)br.get(), (char)values[i]); + assertEquals(br.position(), i+1); + System.err.printf("assertEquals((char)br.get(%s), (char)values[%s])%n", i, i); + assertEquals((char)br.get(i), (char)values[i]); + System.err.printf("assertEquals((char)br.get(%s), (char)values[%s])%n", j, j); + assertEquals((char)br.get(j), (char)values[j]); + } + assertEquals(br.position(), br.limit()); + br.release(); + assertEquals(br.position(), 0); + assertEquals(br.limit(), 0); + bb.position(0); + bb.limit(bb.capacity()); + int start = 0; + limit = bb.limit(); + br.add(bb); + + final int N = 3; + for (int i = 1 ; i < N; i++) { + ByteBuffer bbb = ByteBuffer.allocate(bb.limit() + 4); + bbb.put((byte)-1); + bbb.put((byte)-2); + bbb.put(bb.slice()); + bbb.put((byte)-3); + bbb.put((byte)-4); + bbb.position(2); + bbb.limit(2 + bb.limit()); + br.add(bbb); + } + + long read = br.read(); + for (int i = start; i < N*limit; i++) { + var vi = values[i%limit]; + var j = N*limit - i - 1; + var vj = values[j%limit]; + System.err.printf("%ndouble testing(v[i:%s]=%s, v[j:%s]=%s) position: %s%n", + i, vi, j, vj, br.position()); + assertEquals(br.get(i), vi); + assertEquals(br.get(j), vj); + assertEquals(br.get(), vi); + assertEquals(br.get(i), vi); + assertEquals(br.get(j), vj); + } + assertEquals(br.position(), N * values.length); + assertEquals(br.read() - read, N * values.length - start); + + if (N > 2) { + System.err.printf("testing getAndRelease()%n"); + br.position(values.length + position); + assertEquals(br.position(), values.length + position); + assertEquals(br.read() - read, values.length + position - start); + var bbl = br.getAndRelease(values.length); + assertEquals(bbl.size(), (position == 0 ? 1 : 2)); + // We expect bbl.getFirst() to be the second byte buffer, which will + // have an offset of 2. The position in that byte buffer + // should therefore be position + 2, since we moved the + // position of the buffers reader to values.length + + // position before calling getAndRelease. + assertEquals(position + 2, bbl.getFirst().position()); + int rstart = (int) bbl.getFirst().position(); + ListBuffersReader br2 = BuffersReader.list(bbl); + System.err.printf("position=%s, bbl[0].position=%s%n", position, rstart); + // br2 initial position should reflect the initial position + // of the first buffer in the bbl list. + assertEquals(br2.position(), rstart); + try { + br2.position(rstart - 1); + throw new AssertionError("Expected IllegalArgumentException not thrown"); + } catch (IllegalArgumentException iae) { + System.err.printf("Got expected exception" + + " trying to move before initial position: %s%n", iae); + } + assertEquals(br2.limit(), values.length + rstart); + for (int i = 0; i < values.length; i++) { + assertEquals(br2.get(), values[(i + position) % values.length]); + } + } + } + + // Encode the given length and then decodes it and compares + // the results, asserting various invariants along the way. + @Test + public void testEncodeDecodeVL() { + testEncodeDecodeVL(4611686018427387903L, 3); + } + + public void testEncodeDecodeVL(long length, int expectedPrefix) { + var actualSize = VariableLengthEncoder.getEncodedSize(length); + assertEquals(actualSize, 1 << expectedPrefix); + assertTrue(actualSize > 0, "length is negative or zero: " + actualSize); + assertTrue(actualSize < 9, "length is too big: " + actualSize); + + // Use different offsets for the position at which to encode/decode + for (int offset : List.of(10)) { + System.err.printf("Encode/Decode %s on %s bytes with offset %s%n", + length, actualSize, offset); + + // allocate buffers: one exact, one too short, one too long + ByteBuffer exact = ByteBuffer.allocate(actualSize + offset); + exact.position(offset); + ByteBuffer shorter = ByteBuffer.allocate(actualSize - 1 + offset); + shorter.position(offset); + ByteBuffer shorterref = ByteBuffer.allocate(actualSize - 1 + offset); + shorterref.position(offset); + ByteBuffer longer = ByteBuffer.allocate(actualSize + 10 + offset); + longer.position(offset); + + // attempt to encode with a buffer that has the exact size + var exactres = VariableLengthEncoder.encode(exact, length); + assertEquals(exactres, actualSize); + assertEquals(exact.position(), actualSize + offset); + assertFalse(exact.hasRemaining()); + + // attempt to encode with a buffer that has more bytes + var longres = VariableLengthEncoder.encode(longer, length); + assertEquals(longres, actualSize); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.limit(), longer.capacity()); + assertEquals(longer.remaining(), 10); + + // compare encodings + + // first reset buffer positions for reading. + exact.position(offset); + longer.position(offset); + assertEquals(longer.mismatch(exact), actualSize); + assertEquals(exact.mismatch(longer), actualSize); + + // decode with a buffer that is missing the last + // byte... + var shortSlice = exact.duplicate(); + shortSlice.position(offset); + shortSlice.limit(offset + actualSize - 1); + ListBuffersReader br = BuffersReader.list(shortSlice); + var actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, -1L); + assertEquals(shortSlice.position(), offset); + assertEquals(shortSlice.limit(), offset + actualSize - 1); + assertEquals(br.position(), offset); + assertEquals(br.limit(), offset + actualSize - 1); + br.release(); + + // decode with the exact buffer + br = BuffersReader.list(exact); + actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, length); + assertEquals(exact.position(), offset + actualSize); + assertFalse(exact.hasRemaining()); + assertEquals(br.position(), offset + actualSize); + assertFalse(br.hasRemaining()); + br.release(); + assertEquals(br.read(), actualSize); + assertFalse(br.hasRemaining()); + + + // decode with the longer buffer + long read = br.read(); + assertEquals(br.limit(), 0); + assertEquals(br.position(), 0); + br.add(longer); + actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, length); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.remaining(), 10); + assertEquals(br.position(), offset + actualSize); + assertEquals(br.remaining(), 10); + br.release(); + assertEquals(br.read() - read, actualSize); + assertEquals(br.remaining(), 10); + } + } + + @ParameterizedTest + @MethodSource("tests") + void testAbsolutes(TestCase testCase) { + + List bbs = testCase.bbs(); + // Add byte buffers that match the description in bbs to the BuffersReader. + // The byte buffer bytes that should never be read are set to -1, this way + // if a get returns -1 we know it's peeking outside the expected range. + // bytes at any valid readable position are set to (position - start) % 128 + var reader = BuffersReader.list(); + int val = 0; + for (var bb : bbs) { + var b = ByteBuffer.allocate(bb.capacity); + for (int i=0; i reader.get()); + System.err.printf("Got expected BufferUnderflowException for %s: %s%n", reader.position(), bue); + + if (!testCase.simples.isEmpty()) { + System.err.println("\n*** Simple tests\n"); + } + for (var simple : testCase.simples) { + System.err.printf("get(%s) with position=%s, expect %s%n", + simple.index, simple.position, simple.expected); + long p0 = reader.position(); + reader.position(simple.position); + assertEquals(reader.get(simple.index), simple.expected); + reader.position(p0); + assertEquals(reader.position(), reader.limit()); + } + + System.err.println("\n*** Testing BuffersReader::get(long)\n"); + for (long i=0; i < limit; i++) { + final long pos = i; + if (pos < start) { + var ioobe = assertThrows(IndexOutOfBoundsException.class, () -> reader.get(pos)); + System.err.printf("Got expected IndexOutOfBoundsException for %s: %s%n", pos, ioobe); + } else { + assertEquals(reader.get(pos), (pos - start) % 128, + "get failed at index " + pos + " " + + "(start: " + start + ", limit: " + limit + ")"); + } + } + System.err.println("\n*** Testing BuffersReader::position(long)\n"); + for (long i=0; i <= limit; i++) { + final long pos = limit-i; + final long rpos = i; + if (pos < start) { + try { + var iae = assertThrows(IAE, () -> reader.position(pos)); + System.err.printf("Got expected IllegalArgumentException for %s: %s%n", pos, iae); + } catch (AssertionError error) { + System.err.printf(error.getMessage() + " for start: %s, index: %s, limit: %s", + start, pos, limit); + throw error; + } + } else { + System.err.printf("> reader.position(%s -> %s)%n", reader.position(), pos); + reader.position(pos); + if (pos < limit) { + try { + assertEquals(reader.get(), (pos - start) % 128, + "get failed at index " + pos + " " + + "(start: " + start + ", limit: " + limit + ")"); + System.err.printf("> reader.position is now %s%n", reader.position()); + assertEquals(reader.read(), pos - start + 1); + } catch (RuntimeException x) { + System.err.println("get failed at index " + pos + + " (start: " + start + ", limit: " + limit + ")" + x); + throw x; + } + } + } + if (rpos >= start && rpos < limit) { + try { + System.err.printf("get(%s) with position=%s, expect %s%n", + rpos, reader.position(), (rpos - start) % 128); + assertEquals(reader.get(rpos), (rpos - start) % 128, + "get failed at index " + rpos + " " + + "(start: " + start + ", limit: " + limit + ")"); + } catch (RuntimeException x) { + System.err.println("get failed at index " + rpos + + " (start: " + start + ", limit: " + limit + ")" + x); + throw x; + } + } + assertEquals(reader.read(), reader.position() - start); + if (rpos < start) { + var iae = assertThrows(IAE, () -> reader.position(rpos)); + System.err.printf("Got expected IllegalArgumentException for %s: %s%n", rpos, iae); + } else { + System.err.printf("< reader.position(%s -> %s)%n", reader.position(), rpos); + reader.position(rpos); + if (rpos < limit) { + try { + assertEquals(reader.get(), (rpos - start) % 128, + "get failed at index " + rpos + " " + + "(start: " + start + ", limit: " + limit + ")"); + assertEquals(reader.read(), rpos - start + 1); + System.err.printf("< reader.position is now %s%n", reader.position()); + } catch (RuntimeException x) { + System.err.println("get failed at index " + rpos + + " (start: " + start + ", limit: " + limit + ")" + x); + throw x; + } + } + } + if (pos >= start && pos < limit) { + try { + System.err.printf("get(%s) with position=%s, expect %s%n", + pos, reader.position(), (pos - start) % 128); + assertEquals(reader.get(pos), (pos - start) % 128, + "get failed at index " + pos + " " + + "(start: " + start + ", limit: " + limit + ")"); + } catch (RuntimeException x) { + System.err.println("get failed at index " + pos + + " (start: " + start + ", limit: " + limit + ")" + x); + throw x; + } + } + assertEquals(reader.read(), reader.position() - start); + } + + System.err.println("\n*** Testing BuffersReader::position(rand1) and get(rand2)\n"); + List positions = LongStream.range(0, limit+1).mapToObj(Long::valueOf) + .collect(Collectors.toCollection(ArrayList::new)); + Collections.shuffle(positions, RAND); + List indices = LongStream.range(0, limit+1).mapToObj(Long::valueOf) + .collect(Collectors.toCollection(ArrayList::new)); + Collections.shuffle(indices, RAND); + for (int i = 0; i <= limit; i++) { + long pos = positions.get(i); + long index = indices.get(i); + System.err.printf("position(%s) -> get() -> get(%s)%n", pos, index); + if (pos < start) { + try { + var iae = assertThrows(IAE, () -> reader.position(pos)); + System.err.printf("Got expected IllegalArgumentException for %s: %s%n", pos, iae); + } catch (AssertionError error) { + System.err.printf(error.getMessage() + " for start: %s, index: %s, limit: %s", + start, pos, limit); + throw error; + } + } else { + System.err.printf("> reader.position(%s -> %s)%n", reader.position(), pos); + reader.position(pos); + if (pos < limit) { + try { + assertEquals(reader.get(), (pos - start) % 128, + "get failed at index " + pos + " " + + "(start: " + start + ", limit: " + limit + ")"); + System.err.printf("> reader.position is now %s%n", reader.position()); + assertEquals(reader.read(), pos - start + 1); + } catch (RuntimeException x) { + System.err.println("get failed at index " + pos + + " (start: " + start + ", limit: " + limit + ")" + x); + throw x; + } + } + } + if (index < start || index >= limit) { + var ioobe = assertThrows(IndexOutOfBoundsException.class, () -> reader.get(index)); + System.err.printf("Got expected IndexOutOfBoundsException for %s: %s%n", index, ioobe); + } else { + assertEquals(reader.get(index), (index - start) % 128, + "get failed at index " + index + " " + + "(start: " + start + ", limit: " + limit + ")"); + } + } + + } +} diff --git a/test/jdk/java/net/httpclient/quic/BuffersReaderVLTest.java b/test/jdk/java/net/httpclient/quic/BuffersReaderVLTest.java new file mode 100644 index 00000000000..5064bbd5cb3 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/BuffersReaderVLTest.java @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.nio.ByteBuffer; +import java.util.List; + +import jdk.internal.net.http.quic.BuffersReader; +import jdk.internal.net.http.quic.BuffersReader.ListBuffersReader; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jtreg.SkippedException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +/* + * @test + * @library /test/lib + * @modules java.net.http/jdk.internal.net.http.quic + * @run testng/othervm BuffersReaderVLTest + * @summary Tests to check quic/util methods encode/decodeVariableLength methods + * work as expected. + */ +public class BuffersReaderVLTest { + static final Class IAE = IllegalArgumentException.class; + + @DataProvider(name = "decode invariants") + public Object[][] decodeInvariants() { + return new Object[][] + { + { new byte[]{7}, 7, 1 }, // 00 + { new byte[]{65, 11}, 267, 2 }, // 01 + { new byte[]{-65, 11, 22, 33}, 1057691169, 4 }, // 10 + { new byte[]{-1, 11, 22, 33, 44, 55, 66, 77}, 4542748980864827981L, 8 }, // 11 + { new byte[]{-1, -11, -22, -33, -44, -55, -66, -77}, 4608848040752168627L, 8 }, + { new byte[]{}, -1, 0 }, + { new byte[]{-65}, -1, 0 }, + }; + } + @DataProvider(name = "prefix invariants") + public Object[][] prefixInvariants() { + return new Object[][] + { + { Long.MAX_VALUE, 0, IAE }, + { 4611686018427387903L+1, 0, IAE }, + { 4611686018427387903L, 3, null }, + { 4611686018427387903L-1, 3, null }, + { 1073741823+1, 3, null }, + { 1073741823, 2, null }, // (length > (1L << 30)-1) + { 1073741823-1, 2, null }, + { 16383+1, 2, null }, + { 16383, 1, null }, // (length > (1L << 14)-1 + { 16383-1, 1, null }, + { 63+1, 1, null }, + { 63 , 0, null }, // (length > (1L << 6)-1 + { 63-1, 0, null }, + { 100, 1, null }, + { 10, 0, null }, + { 1, 0, null }, + { 0, 0, null }, // (length >= 0) + { -1, 0, IAE }, + { -10, 0, IAE }, + { -100, 0, IAE }, + { Long.MIN_VALUE, 0, IAE }, + { -4611686018427387903L-1, 0, IAE }, + { -4611686018427387903L, 0, IAE }, + { -4611686018427387903L+1, 0, IAE }, + { -1073741823-1, 0, IAE }, + { -1073741823, 0, IAE }, // (length > (1L << 30)-1) + { -1073741823+1, 0, IAE }, + { -16383-1, 0, IAE }, + { -16383, 0, IAE }, // (length > (1L << 14)-1 + { -16383+1, 0, IAE }, + { -63-1, 0, IAE }, + { -63 , 0, IAE }, // (length > (1L << 6)-1 + { -63+1, 0, IAE }, + }; + } + + @Test(dataProvider = "decode invariants") + public void testDecode(byte[] values, long expectedLength, int expectedPosition) { + ByteBuffer bb = ByteBuffer.wrap(values); + BuffersReader br = BuffersReader.list(bb); + var actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, expectedLength); + + var actualPosition = bb.position(); + assertEquals(actualPosition, expectedPosition); + assertEquals(br.position(), expectedPosition); + br.release(); + assertEquals(br.read(), expectedPosition); + } + + @Test(dataProvider = "decode invariants") + public void testPeek(byte[] values, long expectedLength, int expectedPosition) { + ByteBuffer bb = ByteBuffer.wrap(values); + BuffersReader br = BuffersReader.list(bb); + var actualLength = VariableLengthEncoder.peekEncodedValue(br, 0); + assertEquals(actualLength, expectedLength); + + var actualPosition = bb.position(); + assertEquals(actualPosition, 0); + assertEquals(br.position(), 0); + br.release(); + assertEquals(br.read(), 0); + } + + // Encode the given length and then decodes it and compares + // the results, asserting various invariants along the way. + @Test(dataProvider = "prefix invariants") + public void testEncodeDecode(long length, int expectedPrefix, Class exception) { + if (exception != null) { + assertThrows(exception, () -> VariableLengthEncoder.getEncodedSize(length)); + assertThrows(exception, () -> VariableLengthEncoder.encode(ByteBuffer.allocate(16), length)); + } else { + var actualSize = VariableLengthEncoder.getEncodedSize(length); + assertEquals(actualSize, 1 << expectedPrefix); + assertTrue(actualSize > 0, "length is negative or zero: " + actualSize); + assertTrue(actualSize < 9, "length is too big: " + actualSize); + + // Use different offsets for the position at which to encode/decode + for (int offset : List.of(0, 10)) { + System.out.printf("Encode/Decode %s on %s bytes with offset %s%n", + length, actualSize, offset); + + // allocate buffers: one exact, one too short, one too long + ByteBuffer exact = ByteBuffer.allocate(actualSize + offset); + exact.position(offset); + ByteBuffer shorter = ByteBuffer.allocate(actualSize - 1 + offset); + shorter.position(offset); + ByteBuffer shorterref = ByteBuffer.allocate(actualSize - 1 + offset); + shorterref.position(offset); + ByteBuffer longer = ByteBuffer.allocate(actualSize + 10 + offset); + longer.position(offset); + + // attempt to encode with a buffer too short + expectThrows(IAE, () -> VariableLengthEncoder.encode(shorter, length)); + assertEquals(shorter.position(), offset); + assertEquals(shorter.limit(), shorter.capacity()); + + assertEquals(shorter.mismatch(shorterref), -1); + assertEquals(shorterref.mismatch(shorter), -1); + + // attempt to encode with a buffer that has the exact size + var exactres = VariableLengthEncoder.encode(exact, length); + assertEquals(exactres, actualSize); + assertEquals(exact.position(), actualSize + offset); + assertFalse(exact.hasRemaining()); + + // attempt to encode with a buffer that has more bytes + var longres = VariableLengthEncoder.encode(longer, length); + assertEquals(longres, actualSize); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.limit(), longer.capacity()); + assertEquals(longer.remaining(), 10); + + // compare encodings + + // first reset buffer positions for reading. + exact.position(offset); + longer.position(offset); + assertEquals(longer.mismatch(exact), actualSize); + assertEquals(exact.mismatch(longer), actualSize); + + // decode with a buffer that is missing the last + // byte... + var shortSlice = exact.duplicate(); + shortSlice.position(offset); + shortSlice.limit(offset + actualSize -1); + ListBuffersReader br = BuffersReader.list(shortSlice); + var actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, -1L); + assertEquals(shortSlice.position(), offset); + assertEquals(shortSlice.limit(), offset + actualSize - 1); + assertEquals(br.position(), offset); + assertEquals(br.limit(), offset + actualSize - 1); + br.release(); + + // decode with the exact buffer + br = BuffersReader.list(exact); + actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, length); + assertEquals(exact.position(), offset + actualSize); + assertFalse(exact.hasRemaining()); + assertEquals(br.position(), offset + actualSize); + assertFalse(br.hasRemaining()); + br.release(); + assertEquals(br.read(), actualSize); + assertFalse(br.hasRemaining()); + + + // decode with the longer buffer + long read = br.read(); + br.add(longer); + actualLength = VariableLengthEncoder.decode(br); + assertEquals(actualLength, length); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.remaining(), 10); + assertEquals(br.position(), offset + actualSize); + assertEquals(br.remaining(), 10); + br.release(); + assertEquals(br.read() - read, actualSize); + assertEquals(br.remaining(), 10); + } + + } + } + + // Encode the given length and then peeks it and compares + // the results, asserting various invariants along the way. + @Test(dataProvider = "prefix invariants") + public void testEncodePeek(long length, int expectedPrefix, Class exception) { + if (exception != null) { + assertThrows(exception, () -> VariableLengthEncoder.getEncodedSize(length)); + assertThrows(exception, () -> VariableLengthEncoder.encode(ByteBuffer.allocate(16), length)); + return; + } + + var actualSize = VariableLengthEncoder.getEncodedSize(length); + assertEquals(actualSize, 1 << expectedPrefix); + assertTrue(actualSize > 0, "length is negative or zero: " + actualSize); + assertTrue(actualSize < 9, "length is too big: " + actualSize); + + // Use different offsets for the position at which to encode/decode + for (int offset : List.of(0, 10)) { + System.out.printf("Encode/Peek %s on %s bytes with offset %s%n", + length, actualSize, offset); + + // allocate buffers: one exact, one too long + ByteBuffer exact = ByteBuffer.allocate(actualSize + offset); + exact.position(offset); + ByteBuffer longer = ByteBuffer.allocate(actualSize + 10 + offset); + longer.position(offset); + + // attempt to encode with a buffer that has the exact size + var exactres = VariableLengthEncoder.encode(exact, length); + assertEquals(exactres, actualSize); + assertEquals(exact.position(), actualSize + offset); + assertFalse(exact.hasRemaining()); + + // attempt to encode with a buffer that has more bytes + var longres = VariableLengthEncoder.encode(longer, length); + assertEquals(longres, actualSize); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.limit(), longer.capacity()); + assertEquals(longer.remaining(), 10); + + // compare encodings + + // first reset buffer positions for reading. + exact.position(offset); + longer.position(offset); + assertEquals(longer.mismatch(exact), actualSize); + assertEquals(exact.mismatch(longer), actualSize); + exact.position(0); + longer.position(0); + exact.limit(exact.capacity()); + longer.limit(longer.capacity()); + + // decode with a buffer that is missing the last + // byte... + var shortSlice = exact.duplicate(); + shortSlice.position(0); + shortSlice.limit(offset + actualSize - 1); + // need at least one byte to decode the size len... + var expectedSize = shortSlice.limit() <= offset ? -1 : actualSize; + assertEquals(VariableLengthEncoder.peekEncodedValueSize(shortSlice, offset), expectedSize); + var actualLength = VariableLengthEncoder.peekEncodedValue(shortSlice, offset); + assertEquals(actualLength, -1L); + assertEquals(shortSlice.position(), 0); + assertEquals(shortSlice.limit(), offset + actualSize - 1); + + // decode with the exact buffer + assertEquals(VariableLengthEncoder.peekEncodedValueSize(exact, offset), actualSize); + actualLength = VariableLengthEncoder.peekEncodedValue(exact, offset); + assertEquals(actualLength, length); + assertEquals(exact.position(), 0); + assertEquals(exact.limit(), exact.capacity()); + + // decode with the longer buffer + assertEquals(VariableLengthEncoder.peekEncodedValueSize(longer, offset), actualSize); + actualLength = VariableLengthEncoder.peekEncodedValue(longer, offset); + assertEquals(actualLength, length); + assertEquals(longer.position(), 0); + assertEquals(longer.limit(), longer.capacity()); + } + + } + + + private ByteBuffer getTestBuffer(long length, int capacity) { + return switch (capacity) { + case 0 -> ByteBuffer.allocate(1).put((byte) length); + case 1 -> ByteBuffer.allocate(capacity).put((byte) length); + case 2 -> ByteBuffer.allocate(capacity).putShort((short) length); + case 4 -> ByteBuffer.allocate(capacity).putInt((int) length); + case 8 -> ByteBuffer.allocate(capacity).putLong(length); + default -> throw new SkippedException("bad value used for capacity"); + }; + } +} diff --git a/test/jdk/java/net/httpclient/quic/ConnectionIDSTest.java b/test/jdk/java/net/httpclient/quic/ConnectionIDSTest.java new file mode 100644 index 00000000000..e7fabeeecac --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/ConnectionIDSTest.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HexFormat; +import java.util.List; +import java.util.Map; + +import jdk.internal.net.http.quic.QuicConnectionIdFactory; +import org.testng.annotations.Test; +import static org.testng.Assert.*; + +/** + * @test + * @run testng/othervm ConnectionIDSTest + */ +public class ConnectionIDSTest { + + record ConnID(long token, byte[] bytes) { + ConnID { + bytes = bytes.clone(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConnID connID = (ConnID) o; + return Arrays.equals(bytes, connID.bytes); + } + + @Override + public int hashCode() { + return Arrays.hashCode(bytes); + } + + @Override + public String toString() { + return "ConnID{" + + "token=" + token + + ", bytes=" + HexFormat.of().formatHex(bytes) + + '}'; + } + } + + @Test + public void testConnectionIDS() { + List ids = new ArrayList<>(); + + // regular test, for length in [-21, 21] + long previous = 0; + QuicConnectionIdFactory idFactory = QuicConnectionIdFactory.getClient(); + for (int length = -21; length <= 22 ; length++) { + int expectedLength = Math.min(length, 20); + expectedLength = Math.max(9, expectedLength); + long token = idFactory.newToken(); + assertEquals(token, previous +1); + previous = token; + var id = idFactory.newConnectionId(length, token); + var cid = new ConnID(token, id); + System.out.printf("%s: %s/%s%n", length, token, cid); + assertEquals(id.length, expectedLength); + assertEquals(idFactory.getConnectionIdLength(id), expectedLength); + assertEquals(idFactory.getConnectionIdToken(id), token); + ids.add(cid); + } + + // token length test, for token coded on [1, 8] bytes, + // with positive and negative values, for cid length=9 + // Ox7F, -Ox7F, 0x7F7F, -0x7F7F, etc... + previous = 0; + int length = 9; + for (int i=0; i<8; i++) { + long ptoken = (previous << 8) + 0x7F; + long ntoken = - ptoken; + previous = ptoken; + for (long token : List.of(ptoken, ntoken)) { + long expectedToken = token >= 0 ? token : -token -1; + var id = idFactory.newConnectionId(length, token); + var cid = new ConnID(expectedToken, id); + System.out.printf("%s: %s/%s%n", length, token, cid); + assertEquals(id.length, length); + assertEquals(idFactory.getConnectionIdLength(id), length); + assertEquals(idFactory.getConnectionIdToken(id), expectedToken); + ids.add(cid); + } + } + + // test token bounds, for various cid length... + var bounds = List.of(Long.MIN_VALUE, Long.MIN_VALUE + 1L, Long.MIN_VALUE + 255L, -1L, + 0L, 1L, Long.MAX_VALUE -255L, Long.MAX_VALUE - 1L, Long.MAX_VALUE); + // test the bounds twice to try to trigger duplicates with length = 9 + bounds = bounds.stream().mapMulti((n,c) -> {c.accept(n); c.accept(n);}).toList(); + for (length=9; length <= 20; length++) { + for (long token : bounds) { + long expectedToken = token >= 0 ? token : -token - 1; + var id = idFactory.newConnectionId(length, token); + var cid = new ConnID(expectedToken, id); + System.out.printf("%s: %s/%s%n", length, token, cid); + assertEquals(id.length, length); + assertEquals(idFactory.getConnectionIdLength(id), length); + assertEquals(idFactory.getConnectionIdToken(id), expectedToken); + ids.add(cid); + } + } + + // now verify uniqueness + Map tested = new HashMap(); + record duplicates(ConnID first, ConnID second) {} + List duplicates = new ArrayList<>(); + for (var cid : ids) { + if (tested.containsKey(cid)) { + var dup = new duplicates(tested.get(cid), cid); + System.out.printf("duplicate ids: %s%n", dup); + duplicates.add(dup); + } else { + tested.put(cid, cid); + } + } + + // some duplicates can be expected if the connection id is too short + // and the token value is too big; check and remove them + for (var iter = duplicates.iterator(); iter.hasNext(); ) { + var dup = iter.next(); + assertEquals(dup.first.token(), dup.second.token()); + assertEquals(dup.first.bytes().length, dup.second.bytes().length); + assertEquals(dup.first.bytes(), dup.second.bytes()); + long mask = 0x00FFFFFF00000000L; + for (int i=0; i<3; i++) { + mask = mask << 8; + assert (mask & 0xFF00000000000000L) != 0 + : "mask: " + Long.toHexString(mask); + if (dup.first.bytes().length == (9+i)) { + if ((dup.first.token() & mask) != 0L) { + iter.remove(); + System.out.println("duplicates expected due to lack of entropy: " + dup); + } + } + } + } + + // verify no unexpected duplicates + for (var dup : duplicates) { + System.out.println("unexpected duplicate: " + dup); + } + assertEquals(duplicates.size(), 0); + } +} diff --git a/test/jdk/java/net/httpclient/quic/CryptoWriterQueueTest.java b/test/jdk/java/net/httpclient/quic/CryptoWriterQueueTest.java new file mode 100644 index 00000000000..cfc62625605 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/CryptoWriterQueueTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.streams.CryptoWriterQueue; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; + +import static org.testng.Assert.*; + +/** + * @test + * @summary Tests jdk.internal.net.http.quic.streams,CryptoWriterQueue + * @modules java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.quic.frames + * @run testng CryptoWriterQueueTest + */ +public class CryptoWriterQueueTest { + + /** + * {@link CryptoWriterQueue#enqueue(ByteBuffer) enqueues} data from multiple ByteBuffer + * instances and then expects the {@link CryptoWriterQueue#produceFrame(int)} to process + * the enqueued data correctly. + */ + @Test + public void testProduceFrame() throws Exception { + final CryptoWriterQueue writerQueue = new CryptoWriterQueue(); + final ByteBuffer buff1 = createByteBuffer(83); + final ByteBuffer buff2 = createByteBuffer(1429); + final ByteBuffer buff3 = createByteBuffer(4); + // enqueue them + writerQueue.enqueue(buff1); + writerQueue.enqueue(buff2); + writerQueue.enqueue(buff3); + final int expectedRemaining = buff1.remaining() + buff2.remaining() + buff3.remaining(); + assertEquals(writerQueue.remaining(), expectedRemaining, + "Unexpected remaining bytes in CryptoWriterQueue"); + // create frame(s) from the enqueued buffers + final int maxPayloadSize = 1134; + while (writerQueue.remaining() > 0) { + final CryptoFrame frame = writerQueue.produceFrame(maxPayloadSize); + assertNotNull(frame, "Crypto frame is null"); + assertTrue(frame.size() <= maxPayloadSize, "Crypto size " + frame.size() + + " exceeds max payload size of " + maxPayloadSize); + } + } + + private static ByteBuffer createByteBuffer(final int numBytes) { + return ByteBuffer.wrap(new byte[numBytes]); + } + +} diff --git a/test/jdk/java/net/httpclient/quic/KeyUpdateTest.java b/test/jdk/java/net/httpclient/quic/KeyUpdateTest.java new file mode 100644 index 00000000000..631a054b2f9 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/KeyUpdateTest.java @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Stack; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.TestUtil; +import jdk.httpclient.test.lib.quic.ClientConnection; +import jdk.httpclient.test.lib.quic.ConnectedBidiStream; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicServerHandler; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.http.quic.QuicConnection; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import sun.security.ssl.QuicTLSEngineImpl; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +/* + * @test + * @summary verifies the QUIC TLS key update process + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @modules java.net.http/jdk.internal.net.http + * java.net.http/jdk.internal.net.http.common + * java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.packets + * java.net.http/jdk.internal.net.http.quic.frames + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * + * @modules java.base/jdk.internal.util + * java.base/sun.security.ssl + * @build jdk.httpclient.test.lib.quic.QuicStandaloneServer + * jdk.httpclient.test.lib.quic.ClientConnection + * jdk.httpclient.test.lib.common.TestUtil + * jdk.test.lib.net.SimpleSSLContext + * @comment the test is run with -Djava.security.properties= to augment + * the master java.security file + * @run testng/othervm -Djava.security.properties=${test.src}/quic-tls-keylimits-java.security + * -Djdk.internal.httpclient.debug=true + * -Djavax.net.debug=all + * KeyUpdateTest + */ +public class KeyUpdateTest { + + private QuicStandaloneServer server; + private SSLContext sslContext; + private ExecutorService executor; + + private static final byte[] HELLO_MSG = "Hello Quic".getBytes(StandardCharsets.UTF_8); + private static final EchoHandler handler = new EchoHandler(HELLO_MSG.length); + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + executor = Executors.newCachedThreadPool(); + server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext) + .build(); + // add a handler which deals with incoming connections + server.addHandler(handler); + server.start(); + System.out.println("Server started at " + server.getAddress()); + } + + @AfterClass + public void afterClass() throws Exception { + if (server != null) { + System.out.println("Stopping server " + server.getAddress()); + server.close(); + } + if (executor != null) { + executor.close(); + } + } + + private QuicClient createClient() { + var versions = List.of(QuicVersion.QUIC_V1); + var context = new QuicTLSContext(sslContext); + var params = new SSLParameters(); + return new QuicClient.Builder() + .availableVersions(versions) + .tlsContext(context) + .sslParameters(params) + .executor(executor) + .bindAddress(TestUtil.chooseClientBindAddress().orElse(null)) + .build(); + } + + @Test + public void test() throws Exception { + try (final QuicClient client = createClient()) { + // create a QUIC connection to the server + final ClientConnection conn = ClientConnection.establishConnection(client, + server.getAddress()); + final Stack clientConnKeyPhases = new Stack<>(); + for (int i = 1; i <= 100; i++) { + System.out.println("Iteration: " + i); + // open a bidi stream + final ConnectedBidiStream bidiStream = conn.initiateNewBidiStream(); + // write data on the stream + try (final OutputStream os = bidiStream.outputStream()) { + os.write(HELLO_MSG); + System.out.println("client: Client wrote message to bidi stream's output stream"); + } + // wait for response + try (final InputStream is = bidiStream.inputStream()) { + System.out.println("client: reading from bidi stream's input stream"); + final byte[] data = is.readAllBytes(); + System.out.println("client: Received response of size " + data.length); + final String response = new String(data, StandardCharsets.UTF_8); + // verify response + System.out.println("client: Response: " + response); + if (!Arrays.equals(response.getBytes(StandardCharsets.UTF_8), HELLO_MSG)) { + throw new AssertionError("Unexpected response: " + response); + } + } finally { + System.err.println("client: Closing bidi stream from test"); + bidiStream.close(); + } + // keep track of the 1-RTT key phase that was used by the client connection + final int invocation = i; + getKeyPhase(conn.underlyingQuicConnection()).ifPresent((keyPhase) -> { + if (clientConnKeyPhases.empty() || clientConnKeyPhases.peek() != keyPhase) { + // new key phase detected, add it + clientConnKeyPhases.push(keyPhase); + System.out.println("Detected client 1-RTT key phase " + keyPhase + + " on connection " + conn + " for invocation " + invocation); + } + }); + } + // verify that the client and server did do a key update + // stacks should contain at least a sequence of 0, 1, 0 + System.out.println("Number of 1-RTT keys used by client connection: " + + clientConnKeyPhases.size() + ", key phase switches: " + clientConnKeyPhases); + System.out.println("Number of 1-RTT keys used by server connection: " + + handler.serverConnKeyPhases.size() + ", key phase switches: " + + handler.serverConnKeyPhases); + assertTrue(clientConnKeyPhases.size() >= 3, "Client connection" + + " didn't do a key update"); + assertTrue(handler.serverConnKeyPhases.size() >= 3, "Server connection" + + " didn't do a key update"); + + assertEquals(0, (int) clientConnKeyPhases.getFirst(), "Client connection used" + + " unexpected first key phase"); + assertEquals(0, (int) handler.serverConnKeyPhases.getFirst(), "Server connection used" + + " unexpected first key phase"); + + assertEquals(1, (int) clientConnKeyPhases.get(1), "Client connection used" + + " unexpected second key phase"); + assertEquals(1, (int) handler.serverConnKeyPhases.get(1), "Server connection used" + + " unexpected second key phase"); + + assertEquals(0, (int) clientConnKeyPhases.get(2), "Client connection used" + + " unexpected third key phase"); + assertEquals(0, (int) handler.serverConnKeyPhases.get(2), "Server connection used" + + " unexpected third key phase"); + } + } + + /** + * Reads data from incoming client initiated bidirectional stream of a Quic connection + * and writes back a response which is same as the read data + */ + private static final class EchoHandler implements QuicServerHandler { + + private final int numBytesToRead; + private final AtomicInteger numInvocations = new AtomicInteger(); + private final Stack serverConnKeyPhases = new Stack<>(); + + private EchoHandler(final int numBytesToRead) { + this.numBytesToRead = numBytesToRead; + } + + @Override + public void handleBidiStream(final QuicServerConnection conn, + final ConnectedBidiStream bidiStream) throws IOException { + final int invocation = numInvocations.incrementAndGet(); + System.out.println("Handling incoming bidi stream " + bidiStream + + " on connection " + conn); + // keep track of the 1-RTT key phase that was used by the server connection + getKeyPhase(conn).ifPresent((keyPhase) -> { + if (this.serverConnKeyPhases.empty() + || this.serverConnKeyPhases.peek() != keyPhase) { + // new key phase detected, add it + this.serverConnKeyPhases.push(keyPhase); + System.out.println("Detected server 1-RTT key phase " + keyPhase + + " on connection " + conn + " for invocation " + invocation); + } + }); + final byte[] data; + // read the request content + try (final InputStream is = bidiStream.inputStream()) { + System.out.println("Handler reading data from bidi stream's inputstream " + is); + data = is.readAllBytes(); + System.out.println("Handler read " + data.length + " bytes of data"); + } + if (data.length != numBytesToRead) { + throw new IOException("Expected to read " + numBytesToRead + + " bytes but read only " + data.length + " bytes"); + } + // write response + try (final OutputStream os = bidiStream.outputStream()) { + System.out.println("Handler writing data to bidi stream's outputstream " + os); + os.write(data); + } + System.out.println("Handler invocation complete"); + } + } + + private static Optional getKeyPhase(final QuicConnection conn) throws IOException { + if (!(conn.getTLSEngine() instanceof QuicTLSEngineImpl qtls)) { + return Optional.empty(); + } + final int keyPhase; + try { + keyPhase = qtls.getOneRttKeyPhase(); + } catch (QuicKeyUnavailableException e) { + throw new IOException("failed to get key phase, reason: " + e.getMessage()); + } + if (keyPhase != 0 && keyPhase != 1) { + throw new IOException("Unexpected 1-RTT key phase on connection: " + conn); + } + return Optional.of(keyPhase); + } +} diff --git a/test/jdk/java/net/httpclient/quic/OrderedFlowTest.java b/test/jdk/java/net/httpclient/quic/OrderedFlowTest.java new file mode 100644 index 00000000000..cb95620c00e --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/OrderedFlowTest.java @@ -0,0 +1,401 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Random; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.function.ToLongFunction; +import java.util.stream.Collectors; + +import jdk.internal.net.http.quic.OrderedFlow; +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.frames.StreamFrame; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +/** + * @test + * @summary tests the reordering logic implemented by OrderedFlow + * and its two concrete subclasses + * @library /test/lib + * @run testng OrderedFlowTest + * @run testng/othervm -Dseed=-2680947227866359853 OrderedFlowTest + * @run testng/othervm -Dseed=-273117134353023275 OrderedFlowTest + * @run testng/othervm -Dseed=3649132517916066643 OrderedFlowTest + * @run testng/othervm -Dseed=4568737726943220431 OrderedFlowTest + */ +public class OrderedFlowTest { + + static final int WITH_DUPS = 1; + static final int WITH_OVERLAPS = 2; + + record TestData(Class frameType, + Supplier> flowSupplier, + Function payloadAccessor, + Comparator framesComparator, + List frames, + String expectedResult, + boolean duplicates, + boolean shuffled) { + + boolean hasEmptyFrames() { + return frames.stream().map(payloadAccessor) + .mapToInt(String::length) + .anyMatch((i) -> i == 0); + } + + @Override + public String toString() { + return frameType.getSimpleName() + + "(frames=" + frames.size() + + ", duplicates=" + duplicates + + ", shuffled=" + shuffled + + ", hasEmptyFrames=" + hasEmptyFrames() + + ")"; + } + } + + static final Random RANDOM = jdk.test.lib.RandomFactory.getRandom(); + static final String LOREM = """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer + id elementum sem. In rhoncus nisi a ante convallis, at iaculis augue + elementum. Ut eget imperdiet justo, sed sodales est. In nec laoreet + lorem. Integer et arcu nibh. Quisque quis felis consectetur, luctus + libero eu, facilisis risus. Aliquam at viverra diam. Sed nec lacus + eget dui hendrerit porttitor et nec ligula. Suspendisse rutrum, + augue non ultricies vestibulum, metus orci faucibus est, et tempus + ante diam a quam. Proin venenatis justo eleifend vestibulum tincidunt. + + Nullam nec elementum sem. Class aptent taciti sociosqu ad litora + torquent per conubia nostra, per inceptos himenaeos. Integer euismod, + purus ut sollicitudin semper, quam turpis condimentum arcu, sit amet + suscipit sapien elit ac nisi. Orci varius natoque penatibus et magnis + dis parturient montes, nascetur ridiculus mus. Duis vel tortor non purus + scelerisque iaculis at efficitur dolor. Nam dapibus tellus non aliquet + suscipit. Nulla facilisis mi eget ex blandit sodales. Pellentesque enim + sem, aliquet non luctus id, feugiat in eros. Aliquam molestie felis + lorem, eget tristique nisi mollis lobortis. + + Suspendisse aliquam vitae purus nec mollis. Quisque et urna nec nunc + porttitor blandit quis a magna. Maecenas porta est velit, in volutpat + felis suscipit eget. Vivamus porta semper ipsum, et sodales nibh molestie + eget. Vivamus tincidunt quam id ante efficitur tincidunt. Suspendisse + potenti. Integer posuere felis ut semper feugiat. Vivamus id dui quam. + + Pellentesque accumsan quam non est pretium faucibus. Donec vel euismod + magna, ac scelerisque mauris. Nullam vitae varius diam, hendrerit semper + velit. Vestibulum et nisl felis. Orci varius natoque penatibus et magnis + dis parturient montes, nascetur ridiculus mus. Cras elementum auctor lacus, + vel tempor erat lobortis sed. Suspendisse sed felis ut mi condimentum + eleifend. Proin et arcu cursus, fermentum arcu non, tristique nulla. + Suspendisse tristique volutpat elit, et blandit metus aliquet id. Nunc non + dapibus dui. Nam sagittis justo magna. Nulla pharetra ex nec sem porta + consequat. + + Nam sit amet luctus ante, nec eleifend nunc. Phasellus lobortis lorem a + auctor ornare. Sed venenatis fermentum arcu, ut tincidunt turpis auctor + at. Praesent felis mi, tincidunt a sem et, luctus condimentum libero. + Phasellus egestas ac lectus vitae tincidunt. Etiam eu lobortis felis. + Nulla semper est ac nisl placerat, vitae sollicitudin diam lobortis. + Cras pellentesque semper purus at rutrum. Suspendisse a pellentesque + orci, ac tincidunt libero. Integer ex augue, ultrices sit amet aliquam + eget, laoreet eget elit. + """; + + interface FramesFactory { + public T create(int offset, String payload, boolean fin); + public int length(T frame); + public long offset(T frame); + public String getPayload(T frame); + public OrderedFlow flow(); + public Class frameType(); + public Comparator comparator(); + } + + static class StreamFrameFactory implements FramesFactory { + final long streamId = RANDOM.nextInt(0, Integer.MAX_VALUE); + + @Override + public StreamFrame create(int offset, String payload, boolean fin) { + byte[] bytes = payload.getBytes(StandardCharsets.UTF_8); + int length = RANDOM.nextBoolean() ? bytes.length : -1; + return new StreamFrame(streamId, offset, length, + fin, ByteBuffer.wrap(bytes)); + } + + @Override + public int length(StreamFrame frame) { + return frame.dataLength(); + } + + @Override + public long offset(StreamFrame frame) { + return frame.offset(); + } + + @Override + public String getPayload(StreamFrame frame) { + int length = frame.dataLength(); + byte[] bytes = new byte[length]; + frame.payload().get(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + @Override + public OrderedFlow flow() { + return new OrderedFlow.StreamDataFlow(); + } + + @Override + public Class frameType() { + return StreamFrame.class; + } + + @Override + public Comparator comparator() { + return StreamFrame::compareOffsets; + } + } + + static class CryptoFrameFactory implements FramesFactory { + final long streamId = RANDOM.nextInt(0, Integer.MAX_VALUE); + + @Override + public CryptoFrame create(int offset, String payload, boolean fin) { + byte[] bytes = payload.getBytes(StandardCharsets.UTF_8); + int length = bytes.length; + return new CryptoFrame(offset, length, ByteBuffer.wrap(bytes)); + } + + @Override + public int length(CryptoFrame frame) { + return frame.length(); + } + + @Override + public long offset(CryptoFrame frame) { + return frame.offset(); + } + + @Override + public String getPayload(CryptoFrame frame) { + int length = frame.length(); + byte[] bytes = new byte[length]; + frame.payload().get(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + @Override + public OrderedFlow flow() { + return new OrderedFlow.CryptoDataFlow(); + } + + @Override + public Class frameType() { + return CryptoFrame.class; + } + + @Override + public Comparator comparator() { + return CryptoFrame::compareOffsets; + } + } + + static TestData generateData(FramesFactory factory, int options) { + int length = LOREM.length(); + int chunks = length/20; + int offset = 0; + int remaining = length; + List frames = new ArrayList<>(); + T first = null; + T second = null; + boolean duplicates = (options & WITH_DUPS) == WITH_DUPS; + boolean overlaps = (options & WITH_OVERLAPS) == WITH_OVERLAPS; + while (remaining > 0) { + int len = remaining < 20 + ? remaining + : RANDOM.nextInt(Math.min(19, remaining - 1), Math.min(chunks, remaining)); + remaining -= len; + String data = LOREM.substring(offset, offset + len); + T frame; + if (overlaps && len > 4) { + int start = RANDOM.nextInt(0, len/4); + int end = RANDOM.nextInt(3*len/4, len); + for (int i=start; i < end; i+=2) { + frame = factory.create(offset+i, data.substring(i, i+1), + i == len-1 && remaining == 0); + frames.add(frame); + } + } + frame = factory.create(offset, data, remaining == 0); + frames.add(frame); + if (first == null) first = frame; + else if (second == null) second = frame; + if (duplicates && RANDOM.nextInt(1, 5) > 3) { + frames.add(factory.create(offset, data, remaining == 0)); + } else if (overlaps && RANDOM.nextInt(1, 5) > 3 && second != null && len > 1) { + // next frame will overlap with this one. + offset -= len / 2; remaining += len / 2; + } + offset += len; + } + if (duplicates) frames.add(first); + if (overlaps && frames.size() > 1) { + if (factory.length(first) > 0) { + frames.remove(second); + String firstpayload = factory.getPayload(first); + String secondpayload = factory.getPayload(second); + String newpayload = firstpayload.charAt(firstpayload.length() - 1) + + secondpayload; + long newoffset = factory.offset(second) - 1; + assert newoffset >= 0; + frames.add(1, factory.create((int) newoffset, newpayload, frames.size() == 2)); + } + } + return new TestData<>(factory.frameType(), factory::flow, + factory::getPayload, factory.comparator(), + List.copyOf(frames), LOREM, + duplicates, false); + } + + // Returns a new data set where all frames have been shuffled randomly. + // This should help flush bugs with buffering of frames that come out of order. + static TestData shuffle(TestData data) { + List shuffled = new ArrayList<>(data.frames()); + Collections.shuffle(shuffled, RANDOM); + return new TestData<>(data.frameType(),data.flowSupplier(), data.payloadAccessor(), + data.framesComparator(), List.copyOf(shuffled), data.expectedResult(), + data.duplicates(), true); + } + + // Returns a new data set where all frames have been sorted in reverse + // order: largest offset first. This is the worst case scenario for + // buffering. This should help checking that the amount of data buffered + // never exceeds the length of the stream, as duplicates and overlaps should + // not be buffered. + static TestData reversed(TestData data) { + List sorted = new ArrayList<>(data.frames()); + Collections.sort(sorted, data.framesComparator().reversed()); + return new TestData<>(data.frameType(),data.flowSupplier(), data.payloadAccessor(), + data.framesComparator(), List.copyOf(sorted), data.expectedResult(), + data.duplicates(), true); + } + + static List> generateData(FramesFactory factory) { + List> result = new ArrayList<>(); + TestData data = generateData(factory, 0); + TestData withdups = generateData(factory, WITH_DUPS); + TestData withoverlaps = generateData(factory, WITH_OVERLAPS); + TestData withall = generateData(factory, WITH_DUPS | WITH_OVERLAPS); + result.add(data); + result.add(withdups); + result.add(withoverlaps); + result.add(withall); + result.add(reversed(data)); + result.add(reversed(withdups)); + result.add(reversed(withoverlaps)); + result.add(reversed(withall)); + for (int i=0; i<5; i++) { + result.add(shuffle(data)); + result.add(shuffle(withdups)); + result.add(shuffle(withoverlaps)); + result.add(shuffle(withall)); + } + return List.copyOf(result); + } + + @DataProvider(name="CryptoFrame") + Object[][] generateCryptoFrames() { + return generateData(new CryptoFrameFactory()) + .stream() + .map(List::of) + .map(List::toArray) + .toArray(Object[][]::new); + } + + @DataProvider(name="StreamFrame") + Object[][] generateStreanFrames() { + return generateData(new StreamFrameFactory()) + .stream() + .map(List::of) + .map(List::toArray) + .toArray(Object[][]::new); + } + + private void testOrderedFlow(TestData testData, ToLongFunction offset) { + System.out.println("\n ---------------- " + + testData.frameType().getName() + + " ---------------- \n"); + System.out.println("testOrderedFlow: " + testData); + String offsets = testData.frames().stream().mapToLong(offset) + .mapToObj(Long::toString).collect(Collectors.joining(", ")); + System.out.println("offsets: " + offsets); + + // we should not have empty frames, but maybe we do? + // if we do - should we make allowance for that? + var hasEmptyFrames = testData.hasEmptyFrames(); + assertFalse(hasEmptyFrames, "generated data has empty frames"); + + var flow = testData.flowSupplier().get(); + var size = LOREM.length(); + StringBuilder result = new StringBuilder(size); + long maxBuffered = 0; + for (var f : testData.frames()) { + T received = flow.receive(f); + var buffered = flow.buffered(); + maxBuffered = Math.max(buffered, maxBuffered); + assertTrue(buffered < size, + "buffered data %s exceeds or equals payload size %s".formatted(buffered, size)); + while (received != null) { + var payload = testData.payloadAccessor.apply(received); + assertNotEquals(payload, "", "empty frames not expected: " + received); + result.append(payload); + received = flow.poll(); + } + } + assertEquals(result.toString(), testData.expectedResult); + } + + @Test(dataProvider = "CryptoFrame") + public void testCryptoFlow(TestData testData) { + testOrderedFlow(testData, CryptoFrame::offset); + } + + @Test(dataProvider = "StreamFrame") + public void testStreamFlow(TestData testData) { + testOrderedFlow(testData, StreamFrame::offset); + } + +} diff --git a/test/jdk/java/net/httpclient/quic/PacketEncodingTest.java b/test/jdk/java/net/httpclient/quic/PacketEncodingTest.java new file mode 100644 index 00000000000..495e8d41403 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/PacketEncodingTest.java @@ -0,0 +1,1440 @@ +/* + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.quic.CodingContext; +import jdk.internal.net.http.quic.PeerConnectionId; +import jdk.internal.net.http.quic.QuicConnectionIdFactory; +import jdk.internal.net.http.quic.packets.LongHeader; +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicOneRttContext; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.internal.net.http.quic.frames.CryptoFrame; +import jdk.internal.net.http.quic.frames.PaddingFrame; +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.http.quic.packets.HandshakePacket; +import jdk.internal.net.http.quic.packets.InitialPacket; +import jdk.internal.net.http.quic.packets.LongHeaderPacket; +import jdk.internal.net.http.quic.packets.OneRttPacket; +import jdk.internal.net.http.quic.packets.QuicPacket; +import jdk.internal.net.http.quic.packets.QuicPacket.HeadersType; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketNumberSpace; +import jdk.internal.net.http.quic.packets.QuicPacket.PacketType; +import jdk.internal.net.http.quic.packets.QuicPacketDecoder; +import jdk.internal.net.http.quic.packets.QuicPacketEncoder; +import jdk.internal.net.http.quic.packets.QuicPacketNumbers; +import jdk.internal.net.http.quic.packets.RetryPacket; +import jdk.internal.net.http.quic.packets.ShortHeaderPacket; +import jdk.internal.net.http.quic.packets.VersionNegotiationPacket; +import jdk.internal.net.http.quic.packets.ZeroRttPacket; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportParametersConsumer; +import jdk.internal.net.http.quic.VariableLengthEncoder; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.crypto.AEADBadTagException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.IntFunction; +import java.util.stream.Collectors; + +import static jdk.internal.net.http.quic.packets.QuicPacketNumbers.computePacketNumberLength; +import static org.testng.Assert.*; + +/** + * @test + * @library /test/lib + * @summary test packet encoding and decoding in unencrypted form and without + * any network involvement. + * @run testng/othervm -Dseed=2646683818688275736 PacketEncodingTest + * @run testng/othervm -Dseed=-3723256402256409075 PacketEncodingTest + * @run testng/othervm -Dseed=-3689060484817342283 PacketEncodingTest + * @run testng/othervm -Dseed=2425718686525936108 PacketEncodingTest + * @run testng/othervm -Dseed=-2996954753243104355 PacketEncodingTest + * @run testng/othervm -Dseed=8750823652999067800 PacketEncodingTest + * @run testng/othervm -Dseed=2906555779406889127 PacketEncodingTest + * @run testng/othervm -Dseed=902801756808168822 PacketEncodingTest + * @run testng/othervm -Dseed=5643545543196691308 PacketEncodingTest + * @run testng/othervm -Dseed=2646683818688275736 PacketEncodingTest + * @run testng/othervm -Djdk.internal.httpclient.debug=true PacketEncodingTest + */ +public class PacketEncodingTest { + + @DataProvider + public Object[][] longHeaderPacketProvider() { + final QuicVersion[] quicVersions = QuicVersion.values(); + final List params = new ArrayList<>(); + for (final QuicVersion version : quicVersions) { + final var p = new Object[][] { + // quic-version, srcIdLen, dstIdLen, pn, largestAck + new Object[] {version, 20, 20, 0L, -1L}, + new Object[] {version, 10, 20, 1L, 0L}, + new Object[] {version, 10, 20, 255L, 0L}, + new Object[] {version, 12, 15, 0xFFFFL, 0L}, + new Object[] {version, 9, 8, 0x7FFFFFFFL, 255L}, + new Object[] {version, 13, 11, 0x8FFFFFFFL, 0x10000000L}, + new Object[] {version, 19, 6, 0xFFFFFFFFL, 0xFFFFFFFEL}, + new Object[] {version, 6, 17, 0xFFFFFFFFFFL, 0xFFFFFFFF00L}, + new Object[] {version, 15, 14, 0x7FFFFFFFFFFFL, 0x7FFFFFFFFF00L}, + new Object[] {version, 7, 9, 0xa82f9b32L, 0xa82f30eaL}, + new Object[] {version, 18, 16, 0xace8feL, 0xabe8b3L}, + new Object[] {version, 16, 19, 0xac5c02L, 0xabe8b3L} + }; + params.addAll(Arrays.asList(p)); + } + return params.toArray(Object[][]::new); + } + + @DataProvider + public Object[][] shortHeaderPacketProvider() { + final QuicVersion[] quicVersions = QuicVersion.values(); + final List params = new ArrayList<>(); + for (final QuicVersion version : quicVersions) { + final var p = new Object[][] { + new Object[] {version, 20, 0L, -1L}, + new Object[] {version, 17, 1L, 0L}, + new Object[] {version, 10, 255L, 0L}, + new Object[] {version, 12, 0xFFFFL, 0L}, + new Object[] {version, 9, 0x7FFFFFFFL, 255L}, + new Object[] {version, 13, 0x8FFFFFFFL, 0x10000000L}, + new Object[] {version, 19, 0xFFFFFFFFL, 0xFFFFFFFEL}, + new Object[] {version, 6, 0xFFFFFFFFFFL, 0xFFFFFFFF00L}, + new Object[] {version, 15, 0x7FFFFFFFFFFFL, 0x7FFFFFFFFF00L}, + new Object[] {version, 7, 0xa82f9b32L, 0xa82f30eaL}, + new Object[] {version, 18, 0xace8feL, 0xabe8b3L}, + new Object[] {version, 16, 0xac5c02L, 0xabe8b3L}, + }; + params.addAll(Arrays.asList(p)); + } + return params.toArray(Object[][]::new); + } + + @DataProvider + public Object[][] versionAndRetryProvider() { + final QuicVersion[] quicVersions = QuicVersion.values(); + final List params = new ArrayList<>(); + for (final QuicVersion version : quicVersions) { + final var p = new Object[][] { + // quic-version, srcIdLen, dstIdLen, pn, largestAck + new Object[] {version, 20, 20}, + new Object[] {version, 10, 20}, + new Object[] {version, 12, 15}, + new Object[] {version, 9, 8}, + new Object[] {version, 13, 11}, + new Object[] {version, 19, 6}, + new Object[] {version, 6, 17}, + new Object[] {version, 15, 14}, + new Object[] {version, 7, 9}, + new Object[] {version, 18, 16}, + new Object[] {version, 16, 19}, + }; + params.addAll(Arrays.asList(p)); + } + return params.toArray(Object[][]::new); + } + + private static final AtomicLong IDS = new AtomicLong(); + private static final Random RANDOM = jdk.test.lib.RandomFactory.getRandom(); + private static final int MAX_DATAGRAM_IPV6 = 65527; + + byte[] randomIdBytes(int connectionLength) { + byte[] bytes = new byte[connectionLength]; + RANDOM.nextBytes(bytes); + return bytes; + } + + private static class DummyQuicTLSEngine implements QuicTLSEngine { + @Override + public HandshakeState getHandshakeState() { + throw new AssertionError("should not come here!"); + } + + @Override + public boolean isTLSHandshakeComplete() { + return true; + } + + @Override + public KeySpace getCurrentSendKeySpace() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean keysAvailable(KeySpace keySpace) { + return true; + } + + @Override + public void discardKeys(KeySpace keySpace) { + // no-op + } + + @Override + public void setLocalQuicTransportParameters(ByteBuffer params) { + throw new AssertionError("should not come here!"); + } + + @Override + public void restartHandshake() throws IOException { + throw new AssertionError("should not come here!"); + } + + @Override + public void setRemoteQuicTransportParametersConsumer(QuicTransportParametersConsumer consumer) { + throw new AssertionError("should not come here!"); + } + @Override + public void deriveInitialKeys(QuicVersion version, ByteBuffer connectionId) { } + @Override + public int getHeaderProtectionSampleSize(KeySpace keySpace) { + return 0; + } + @Override + public ByteBuffer computeHeaderProtectionMask(KeySpace keySpace, boolean incoming, ByteBuffer sample) { + return ByteBuffer.allocate(5); + } + + @Override + public int getAuthTagSize() { + return 0; + } + + @Override + public void encryptPacket(KeySpace keySpace, long packetNumber, + IntFunction headerGenerator, + ByteBuffer packetPayload, ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException { + // this dummy QUIC TLS engine doesn't do any encryption. + // we just copy over the raw packet payload into the output buffer + output.put(packetPayload); + } + + @Override + public void decryptPacket(KeySpace keySpace, long packetNumber, int keyPhase, + ByteBuffer packet, int headerLength, ByteBuffer output) { + packet.position(packet.position() + headerLength); + output.put(packet); + } + + @Override + public void signRetryPacket(QuicVersion version, + ByteBuffer originalConnectionId, ByteBuffer packet, ByteBuffer output) { + output.put(ByteBuffer.allocate(16)); + } + @Override + public void verifyRetryPacket(QuicVersion version, + ByteBuffer originalConnectionId, ByteBuffer packet) throws AEADBadTagException { + } + @Override + public ByteBuffer getHandshakeBytes(KeySpace keySpace) { + throw new AssertionError("should not come here!"); + } + @Override + public void consumeHandshakeBytes(KeySpace keySpace, ByteBuffer payload) { + throw new AssertionError("should not come here!"); + } + @Override + public Runnable getDelegatedTask() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean tryMarkHandshakeDone() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean tryReceiveHandshakeDone() { + throw new AssertionError("should not come here!"); + } + + @Override + public Set getSupportedQuicVersions() { + return Set.of(QuicVersion.QUIC_V1); + } + + @Override + public void setUseClientMode(boolean mode) { + throw new AssertionError("should not come here!"); + } + + @Override + public boolean getUseClientMode() { + throw new AssertionError("should not come here!"); + } + + @Override + public SSLParameters getSSLParameters() { + throw new AssertionError("should not come here!"); + } + + @Override + public void setSSLParameters(SSLParameters sslParameters) { + throw new AssertionError("should not come here!"); + } + + @Override + public String getApplicationProtocol() { + return null; + } + + @Override + public SSLSession getSession() { + throw new AssertionError("should not come here!"); + } + + @Override + public SSLSession getHandshakeSession() { + throw new AssertionError("should not come here!"); + } + + @Override + public void versionNegotiated(QuicVersion quicVersion) { + // no-op + } + + @Override + public void setOneRttContext(QuicOneRttContext ctx) { + // no-op + } + } + + private static final QuicTLSEngine TLS_ENGINE = new DummyQuicTLSEngine(); + private static abstract class TestCodingContext implements CodingContext { + TestCodingContext() { } + @Override + public int writePacket(QuicPacket packet, ByteBuffer buffer) { + throw new AssertionError("should not come here!"); + } + @Override + public QuicPacket parsePacket(ByteBuffer src) throws IOException { + throw new AssertionError("should not come here!"); + } + @Override + public boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + return true; + } + @Override + public QuicConnectionId originalServerConnId() { + throw new AssertionError("should not come here!"); + } + + @Override + public QuicTLSEngine getTLSEngine() { + return TLS_ENGINE; + } + + @Override + public int minShortPacketPayloadSize(int destConnectionIdLength) { + return 100 - (destConnectionIdLength - connectionIdLength()); + } + } + + private void checkLongHeaderPacket(LongHeaderPacket packet, + PacketType packetType, + int versionNumber, + PacketNumberSpace packetNumberSpace, + long packetNumber, + QuicConnectionId srcConnectionId, + QuicConnectionId destConnectionId, + List payload, + int padding) { + List expected; + if (padding == 0) { + expected = payload; + } else if (payload.get(0) instanceof PaddingFrame pf) { + expected = new ArrayList<>(payload); + expected.set(0, new PaddingFrame(padding + pf.size())); + } else { + expected = new ArrayList<>(payload.size()+1); + expected.add(new PaddingFrame(padding)); + expected.addAll(payload); + } + checkLongHeaderPacket(packet, packetType, versionNumber, packetNumberSpace, packetNumber, + srcConnectionId, destConnectionId, expected); + + } + + private void checkLongHeaderPacket(LongHeaderPacket packet, + PacketType packetType, + int versionNumber, + PacketNumberSpace packetNumberSpace, + long packetNumber, + QuicConnectionId srcConnectionId, + QuicConnectionId destConnectionId, + List payload) { + // Check created packet + assertEquals(packet.headersType(), HeadersType.LONG); + assertEquals(packet.packetType(), packetType); + boolean hasLength = switch (packetType) { + case VERSIONS, RETRY -> false; + default -> true; + }; + assertEquals(packet.hasLength(), hasLength); + assertEquals(packet.numberSpace(), packetNumberSpace); + if (payload == null) { + assertTrue(packet.frames().isEmpty()); + } else { + assertEquals(getBuffers(packet.frames()), getBuffers(payload)); + } + assertEquals(packet.version(), versionNumber); + assertEquals(packet.packetNumber(), packetNumber); + assertEquals(packet.sourceId(), srcConnectionId); + assertEquals(packet.destinationId(), destConnectionId); + + } + + private static ByteBuffer encodeFrame(QuicFrame frame) { + ByteBuffer result = ByteBuffer.allocate(frame.size()); + frame.encode(result); + return result; + } + private static List getBuffers(List payload) { + return payload.stream().map(PacketEncodingTest::encodeFrame).toList(); + } + private static List getBuffers(List payload, int minSize) { + int payloadSize = payload.stream().mapToInt(QuicFrame::size).sum(); + if (payloadSize < minSize) { + payload = new ArrayList<>(payload); + payload.add(0, new PaddingFrame(minSize - payloadSize)); + } + return payload.stream().map(PacketEncodingTest::encodeFrame).toList(); + } + + private static String toHex(ByteBuffer buffer) { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes, 0, buffer.remaining()); + return HexFormat.of().formatHex(bytes); + } + + private static String toHex(List byteBuffers) { + return "0x" + byteBuffers.stream() + .map(PacketEncodingTest::toHex) + .collect(Collectors.joining(":")); + } + + private void checkLongHeaderPacketAt(ByteBuffer datagram, int offset, + PacketType packetType, int versionNumber, + QuicConnectionId srcConnectionId, + QuicConnectionId destConnectionId) { + assertEquals(QuicPacketDecoder.peekHeaderType(datagram, offset), HeadersType.LONG); + assertEquals(QuicPacketDecoder.of(datagram, offset).peekPacketType(datagram, offset), packetType); + LongHeader header = QuicPacketDecoder.peekLongHeader(datagram, offset); + assertNotNull(header, "Could not parse packet header"); + assertEquals(header.version(), versionNumber); + assertTrue(header.destinationId() + .matches(destConnectionId.asReadOnlyBuffer()), "Destination ID doesn't match"); + assertTrue(header.sourceId() + .matches(srcConnectionId.asReadOnlyBuffer()), "Source ID doesn't match"); + } + + private List frames(byte[] payload) throws IOException { + return frames(payload, false); + } + + private List frames(byte[] payload, boolean insert) throws IOException { + int payloadSize = payload.length; + ByteBuffer buf = ByteBuffer.wrap(payload); + List frames = new ArrayList<>(); + int remaining = payloadSize; + while (remaining > 7) { + int size = RANDOM.nextInt(1, remaining - 6); + byte[] data = new byte[size]; + RANDOM.nextBytes(data); + QuicFrame frame = new CryptoFrame(0, size, ByteBuffer.wrap(data)); + int encoded = frame.size(); + assertTrue(encoded > 0, String.valueOf(encoded)); + assertTrue(encoded <= remaining, String.valueOf(encoded)); + if (insert) { + frames.add(0, frame); + buf.position(remaining - encoded); + } else { + frames.add(frame); + } + frame.encode(buf); + remaining -= encoded; + } + if (remaining > 0) { + var padding = new PaddingFrame(remaining); + if (insert) { + frames.add(0, padding); + buf.position(0); + } else { + frames.add(padding); + } + padding.encode(buf); + } + if (insert) { + assertEquals(buf.position(), remaining); + assertEquals(buf.remaining(), payloadSize - remaining); + } else { + assertEquals(buf.remaining(), 0); + } + return List.copyOf(frames); + } + + private ByteBuffer toByteBuffer(QuicPacketEncoder encoder, QuicPacket outgoingQuicPacket, CodingContext context) + throws Exception { + int size = outgoingQuicPacket.size(); + ByteBuffer buffer = ByteBuffer.allocate(size); + encoder.encode(outgoingQuicPacket, buffer, context); + assertEquals(buffer.position(), size, " for " + outgoingQuicPacket); + buffer.flip(); + return buffer; + } + + private void checkShortHeaderPacket(ShortHeaderPacket packet, + PacketType packetType, + PacketNumberSpace packetNumberSpace, + long packetNumber, + QuicConnectionId destConnectionId, + List payload, + int minSize) { + // Check created packet + assertEquals(packet.headersType(), HeadersType.SHORT); + assertEquals(packet.packetType(), packetType); + assertEquals(packet.hasLength(), false); + assertEquals(packet.numberSpace(), packetNumberSpace); + assertEquals(getBuffers(packet.frames()), getBuffers(payload, minSize)); + assertEquals(packet.packetNumber(), packetNumber); + assertEquals(packet.destinationId(), destConnectionId); + } + + private void checkShortHeaderPacketAt(ByteBuffer datagram, int offset, + PacketType packetType, + QuicConnectionId destConnectionId, + CodingContext context) { + assertEquals(QuicPacketDecoder.peekHeaderType(datagram, offset), HeadersType.SHORT); + assertEquals(QuicPacketDecoder.of(QuicVersion.QUIC_V1).peekPacketType(datagram, offset), packetType); + assertEquals(QuicPacketDecoder.peekVersion(datagram, offset), 0); + int pos = datagram.position(); + if (pos != offset) datagram.position(offset); + try { + assertEquals(QuicPacketDecoder.peekShortConnectionId(datagram, destConnectionId.length()) + .mismatch(destConnectionId.asReadOnlyBuffer()), -1); + } finally { + if (pos != offset) datagram.position(pos); + } + } + + @Test(dataProvider = "longHeaderPacketProvider") + public void testInitialPacket(QuicVersion quicVersion, int srcIdLength, int destIdLength, + long packetNumber, long largestAcked) throws Exception { + System.out.printf("%ntestInitialPacket(qv:%s, scid:%d, dcid:%d, pn:%d, ack:%d)%n", + quicVersion, srcIdLength, destIdLength, packetNumber, largestAcked); + QuicPacketEncoder encoder = QuicPacketEncoder.of(quicVersion); + QuicPacketDecoder decoder = QuicPacketDecoder.of(quicVersion); + byte[] destid = QuicConnectionIdFactory.getClient() + .newConnectionId(destIdLength, IDS.incrementAndGet()); + assert destid.length <= 20; + final QuicConnectionId destConnectionId = new PeerConnectionId(destid); + assertEquals(destid.length, destConnectionId.length(), "dcid length"); + destIdLength = destid.length; + + byte[] srcid = randomIdBytes(srcIdLength); + final QuicConnectionId srcConnectionId = new PeerConnectionId(srcid); + assertEquals(srcid.length, srcConnectionId.length(), "scid length"); + + int bound = MAX_DATAGRAM_IPV6 - srcIdLength - destid.length - 7 + - QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked) + - VariableLengthEncoder.getEncodedSize(MAX_DATAGRAM_IPV6); + + // ensure that bound - tokenLength - 1 > 0 + assert bound > 4; + int tokenLength = RANDOM.nextInt(bound - 4); + + byte[] token = tokenLength == 0 ? null : new byte[tokenLength]; + if (token != null) RANDOM.nextBytes(token); + int packetNumberLength = + QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked); + int payloadSize = Math.max(RANDOM.nextInt(bound - tokenLength - 1) + 1, 4 - packetNumberLength); + System.out.printf("testInitialPacket.encode(scid:%s, dcid:%s, token:%d, payload:%d)%n", + srcIdLength, destIdLength, tokenLength, payloadSize); + + CodingContext context = new TestCodingContext() { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.INITIAL ? largestAcked : -1; + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.INITIAL ? largestAcked : -1; + } + @Override public int connectionIdLength() { + return srcIdLength; + } + }; + int minsize = encoder.computeMaxInitialPayloadSize(context, + computePacketNumberLength(packetNumber, + context.largestAckedPN(PacketNumberSpace.INITIAL)), + tokenLength, srcIdLength, + destIdLength, 1200); + int padding = (payloadSize < minsize) ? minsize - payloadSize : 0; + System.out.println("testInitialPacket: available=%s, payload=%s, padding=%s" + .formatted(minsize, payloadSize, padding)); + + + byte[] payload = new byte[payloadSize]; + List frames = frames(payload, padding != 0); + assertEquals(frames.stream().mapToInt(QuicFrame::size) + .reduce(0, Math::addExact), payloadSize); + + + // Create an initial packet + var packet = encoder.newInitialPacket(srcConnectionId, + destConnectionId, + token, + packetNumber, + largestAcked, + frames, + context); + + if (padding > 0) { + var frames2 = new ArrayList<>(frames); + frames2.add(0, new PaddingFrame(padding)); + var packet2 = encoder.newInitialPacket(srcConnectionId, + destConnectionId, + token, + packetNumber, + largestAcked, + frames2, + context); + assertEquals(padding, padding + (1200 - packet2.size())); + } + + // Check created packet + assertTrue(packet instanceof InitialPacket); + var initialPacket = (InitialPacket) packet; + System.out.printf("%s: pn:%s, tklen:%s, payloadSize:%s, padding:%s, packet::size:%s, " + + "\n\tinputFrames: %s, " + + "\n\tencodedFrames:%s%n", + PacketType.INITIAL, packetNumber, tokenLength, payload.length, padding, + packet.size(), frames, packet.frames()); + checkLongHeaderPacket(initialPacket, PacketType.INITIAL, quicVersion.versionNumber(), + PacketNumberSpace.INITIAL, packetNumber, + srcConnectionId, destConnectionId, frames, padding); + assertEquals(initialPacket.tokenLength(), tokenLength); + assertEquals(initialPacket.token(), token); + assertEquals(initialPacket.hasLength(), true); + assertEquals(initialPacket.length(), packetNumberLength + payloadSize + padding); + + // Check that peeking at the encoded packet returns correct information + + // Decode the two packets in the datagram + ByteBuffer encoded = toByteBuffer(encoder, packet, context); + checkLongHeaderPacketAt(encoded, 0, PacketType.INITIAL, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // coalesce two packets in a single datagram and check + // the peek methods again + int offset = RANDOM.nextInt(256); + int second = offset + encoded.limit(); + System.out.printf("testInitialPacket.encode(offset:%d, second:%d)%n", + offset, second); + ByteBuffer datagram = ByteBuffer.allocate(encoded.limit() * 2 + offset * 2); + datagram.position(offset); + datagram.put(encoded); + encoded.flip(); + datagram.put(encoded); + encoded.flip(); + datagram.flip(); + + // check header, type and version of both packets + System.out.printf("datagram(offset:%d, second:%d, position:%d, limit:%d)%n", + offset, second, datagram.position(), datagram.limit()); + System.out.printf("reading first datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, offset, PacketType.INITIAL, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + System.out.printf("reading second datagram(offset:%d, position:%d, limit:%d)%n", + second, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, second, PacketType.INITIAL, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // check that skip packet can skip both packets + datagram.position(0); + datagram.limit(datagram.capacity()); + decoder.skipPacket(datagram, offset); + assertEquals(datagram.position(), second); + decoder.skipPacket(datagram, second); + assertEquals(datagram.remaining(), offset); + + datagram.position(offset); + int size = second - offset; + for (int i=0; i<2; i++) { + int pos = datagram.position(); + System.out.printf("Decoding packet: %d at %d%n", (i+1), pos); + var decodedPacket = decoder.decode(datagram, context); + assertEquals(datagram.position(), pos + size); + assertTrue(decodedPacket instanceof InitialPacket, "decoded: " + decodedPacket); + InitialPacket initialDecoded = InitialPacket.class.cast(decodedPacket); + checkLongHeaderPacket(initialDecoded, PacketType.INITIAL, quicVersion.versionNumber(), + PacketNumberSpace.INITIAL, packetNumber, + srcConnectionId, destConnectionId, frames, padding); + assertEquals(decodedPacket.size(), packet.size()); + assertEquals(decodedPacket.size(), size); + assertEquals(initialDecoded.tokenLength(), tokenLength); + assertEquals(initialDecoded.token(), token); + assertEquals(initialDecoded.length(), initialPacket.length()); + assertEquals(initialDecoded.length(), packetNumberLength + payloadSize + padding); + } + assertEquals(datagram.position(), second + second - offset); + } + + @Test(dataProvider = "longHeaderPacketProvider") + public void testHandshakePacket(QuicVersion quicVersion, int srcIdLength, int destIdLength, + long packetNumber, long largestAcked) throws Exception { + System.out.printf("%ntestHandshakePacket(qv:%s, scid:%d, dcid:%d, pn:%d, ack:%d)%n", + quicVersion, srcIdLength, destIdLength, packetNumber, largestAcked); + QuicPacketEncoder encoder = QuicPacketEncoder.of(quicVersion); + QuicPacketDecoder decoder = QuicPacketDecoder.of(quicVersion); + byte[] destid = QuicConnectionIdFactory.getClient() + .newConnectionId(destIdLength, IDS.incrementAndGet()); + assert destid.length <= 20; + QuicConnectionId destConnectionId = new PeerConnectionId(destid); + byte[] srcid = randomIdBytes(srcIdLength); + QuicConnectionId srcConnectionId = new PeerConnectionId(srcid); + int bound = MAX_DATAGRAM_IPV6 - srcIdLength - destid.length - 7 + - QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked) + - VariableLengthEncoder.getEncodedSize(MAX_DATAGRAM_IPV6); + + int packetNumberLength = + QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked); + int payloadSize = Math.max(RANDOM.nextInt(bound - 1) + 1, 4 - packetNumberLength); + byte[] payload = new byte[payloadSize]; + var frames = frames(payload); + System.out.printf("testHandshakePacket.encode(payload:%d)%n", payloadSize); + + CodingContext context = new TestCodingContext() { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.HANDSHAKE ? largestAcked : -1; + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.HANDSHAKE ? largestAcked : -1; + } + @Override public int connectionIdLength() { + return srcIdLength; + } + }; + // Create an initial packet + var packet = encoder.newHandshakePacket(srcConnectionId, + destConnectionId, + packetNumber, + largestAcked, + frames, + context); + + // Check created packet + assertTrue(packet instanceof HandshakePacket); + var handshakePacket = (HandshakePacket) packet; + checkLongHeaderPacket(handshakePacket, PacketType.HANDSHAKE, quicVersion.versionNumber(), + PacketNumberSpace.HANDSHAKE, packetNumber, + srcConnectionId, destConnectionId, frames); + assertEquals(handshakePacket.hasLength(), true); + assertEquals(handshakePacket.length(), packetNumberLength + payloadSize); + + // Decode the two packets in the datagram + // Check that peeking at the encoded packet returns correct information + ByteBuffer encoded = toByteBuffer(encoder, packet, context); + checkLongHeaderPacketAt(encoded, 0, PacketType.HANDSHAKE, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // coalesce two packets in a single datagram and check + // the peek methods again + int offset = RANDOM.nextInt(256); + int second = offset + encoded.limit(); + System.out.printf("testHandshakePacket.encode(offset:%d, second:%d)%n", + offset, second); + ByteBuffer datagram = ByteBuffer.allocate(encoded.limit() * 2 + offset * 2); + datagram.position(offset); + datagram.put(encoded); + encoded.flip(); + datagram.put(encoded); + encoded.flip(); + datagram.flip(); + + // check header, type and version of both packets + System.out.printf("datagram(offset:%d, second:%d, position:%d, limit:%d)%n", + offset, second, datagram.position(), datagram.limit()); + // set position to first packet to check connection ids + System.out.printf("reading first datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, offset, PacketType.HANDSHAKE, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + System.out.printf("reading second datagram(offset:%d, position:%d, limit:%d)%n", + second, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, second, PacketType.HANDSHAKE, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // check that skip packet can skip both packets + datagram.position(0); + datagram.limit(datagram.capacity()); + decoder.skipPacket(datagram, offset); + assertEquals(datagram.position(), second); + decoder.skipPacket(datagram, second); + assertEquals(datagram.remaining(), offset); + + datagram.position(offset); + int size = second - offset; + for (int i=0; i<2; i++) { + int pos = datagram.position(); + System.out.printf("Decoding packet: %d at %d%n", (i+1), pos); + var decodedPacket = decoder.decode(datagram, context); + assertEquals(datagram.position(), pos + size); + assertTrue(decodedPacket instanceof HandshakePacket, "decoded: " + decodedPacket); + HandshakePacket handshakeDecoded = HandshakePacket.class.cast(decodedPacket); + checkLongHeaderPacket(handshakeDecoded, PacketType.HANDSHAKE, quicVersion.versionNumber(), + PacketNumberSpace.HANDSHAKE, packetNumber, + srcConnectionId, destConnectionId, frames); + assertEquals(decodedPacket.size(), packet.size()); + assertEquals(decodedPacket.size(), size); + assertEquals(handshakeDecoded.length(), handshakePacket.length()); + assertEquals(handshakeDecoded.length(), packetNumberLength + payloadSize); + } + assertEquals(datagram.position(), second + second - offset); + } + + @Test(dataProvider = "longHeaderPacketProvider") + public void testZeroRTTPacket(QuicVersion quicVersion, int srcIdLength, int destIdLength, + long packetNumber, long largestAcked) throws Exception { + System.out.printf("%ntestZeroRTTPacket(qv:%s, scid:%d, dcid:%d, pn:%d, ack:%d)%n", + quicVersion, srcIdLength, destIdLength, packetNumber, largestAcked); + QuicPacketEncoder encoder = QuicPacketEncoder.of(quicVersion); + QuicPacketDecoder decoder = QuicPacketDecoder.of(quicVersion); + byte[] destid = QuicConnectionIdFactory.getClient() + .newConnectionId(destIdLength, IDS.incrementAndGet()); + assert destid.length <= 20; + QuicConnectionId destConnectionId = new PeerConnectionId(destid); + byte[] srcid = randomIdBytes(srcIdLength); + QuicConnectionId srcConnectionId = new PeerConnectionId(srcid); + int bound = MAX_DATAGRAM_IPV6 - srcIdLength - destid.length - 7 + - QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked) + - VariableLengthEncoder.getEncodedSize(MAX_DATAGRAM_IPV6); + + int packetNumberLength = + QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked); + int payloadSize = Math.max(RANDOM.nextInt(bound - 1) + 1, 4 - packetNumberLength); + byte[] payload = new byte[payloadSize]; + var frames = frames(payload); + System.out.printf("testZeroRTTPacket.encode(payload:%d)%n", payloadSize); + + CodingContext context = new TestCodingContext() { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.APPLICATION ? largestAcked : -1; + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.APPLICATION ? largestAcked : -1; + } + @Override public int connectionIdLength() { + return srcIdLength; + } + }; + // Create an initial packet + var packet = encoder.newZeroRttPacket(srcConnectionId, + destConnectionId, + packetNumber, + largestAcked, + frames, + context); + + // Check created packet + assertTrue(packet instanceof ZeroRttPacket); + var zeroRttPacket = (ZeroRttPacket) packet; + checkLongHeaderPacket(zeroRttPacket, PacketType.ZERORTT, quicVersion.versionNumber(), + PacketNumberSpace.APPLICATION, packetNumber, + srcConnectionId, destConnectionId, frames); + assertEquals(zeroRttPacket.hasLength(), true); + assertEquals(zeroRttPacket.length(), packetNumberLength + payloadSize); + + // Check that peeking at the encoded packet returns correct information + ByteBuffer encoded = toByteBuffer(encoder, packet, context); + checkLongHeaderPacketAt(encoded, 0, PacketType.ZERORTT, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // coalesce two packets in a single datagram and check + // the peek methods again + int offset = RANDOM.nextInt(256); + int second = offset + encoded.limit(); + System.out.printf("testZeroRTTPacket.encode(offset:%d, second:%d)%n", + offset, second); + ByteBuffer datagram = ByteBuffer.allocate(encoded.limit() * 2 + offset * 2); + datagram.position(offset); + datagram.put(encoded); + encoded.flip(); + datagram.put(encoded); + encoded.flip(); + datagram.flip(); + + // check header, type and version of both packets + System.out.printf("datagram(offset:%d, second:%d, position:%d, limit:%d)%n", + offset, second, datagram.position(), datagram.limit()); + // set position to first packet to check connection ids + System.out.printf("reading first datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, offset, PacketType.ZERORTT, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + System.out.printf("reading second datagram(offset:%d, position:%d, limit:%d)%n", + second, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, second, PacketType.ZERORTT, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // check that skip packet can skip both packets + datagram.position(0); + datagram.limit(datagram.capacity()); + decoder.skipPacket(datagram, offset); + assertEquals(datagram.position(), second); + decoder.skipPacket(datagram, second); + assertEquals(datagram.remaining(), offset); + + // Decode the two packets in the datagram + datagram.position(offset); + int size = second - offset; + for (int i=0; i<2; i++) { + int pos = datagram.position(); + System.out.printf("Decoding packet: %d at %d%n", (i+1), pos); + var decodedPacket = decoder.decode(datagram, context); + assertEquals(datagram.position(), pos + size); + assertTrue(decodedPacket instanceof ZeroRttPacket, "decoded: " + decodedPacket); + ZeroRttPacket zeroRttDecoded = ZeroRttPacket.class.cast(decodedPacket); + checkLongHeaderPacket(zeroRttDecoded, PacketType.ZERORTT, quicVersion.versionNumber(), + PacketNumberSpace.APPLICATION, packetNumber, + srcConnectionId, destConnectionId, frames); + assertEquals(decodedPacket.size(), packet.size()); + assertEquals(decodedPacket.size(), size); + assertEquals(zeroRttDecoded.length(), zeroRttPacket.length()); + assertEquals(zeroRttDecoded.length(), packetNumberLength + payloadSize); + } + assertEquals(datagram.position(), second + second - offset); + } + + @Test(dataProvider = "versionAndRetryProvider") + public void testVersionNegotiationPacket(QuicVersion quicVersion, int srcIdLength, int destIdLength) + throws Exception { + System.out.printf("%ntestVersionNegotiationPacket(qv:%s, scid:%d, dcid:%d, pn:%d, ack:%d)%n", + quicVersion, srcIdLength, destIdLength, -1, -1); + QuicPacketEncoder encoder = QuicPacketEncoder.of(quicVersion); + QuicPacketDecoder decoder = QuicPacketDecoder.of(quicVersion); + byte[] destid = QuicConnectionIdFactory.getClient() + .newConnectionId(destIdLength, IDS.incrementAndGet()); + assert destid.length <= 20; + QuicConnectionId destConnectionId = new PeerConnectionId(destid); + byte[] srcid = randomIdBytes(srcIdLength); + QuicConnectionId srcConnectionId = new PeerConnectionId(srcid); + + final List versionList = new ArrayList<>(); + for (final QuicVersion qv : QuicVersion.values()) { + versionList.add(qv.versionNumber()); + } + System.out.printf("testVersionNegotiationPacket.encode(versions:%d)%n", versionList.size()); + + // Create an initial packet + var packet = QuicPacketEncoder.newVersionNegotiationPacket(srcConnectionId, + destConnectionId, + versionList.stream().mapToInt(Integer::intValue).toArray()); + + // Check created packet + assertTrue(packet instanceof VersionNegotiationPacket); + var versionPacket = (VersionNegotiationPacket) packet; + checkLongHeaderPacket(versionPacket, PacketType.VERSIONS, 0, + PacketNumberSpace.NONE, -1, + srcConnectionId, destConnectionId, null); + assertEquals(versionPacket.hasLength(), false); + assertEquals(versionPacket.supportedVersions(), + versionList.stream().mapToInt(Integer::intValue).toArray()); + + CodingContext context = new TestCodingContext() { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return -1; + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return -1; + } + @Override public int connectionIdLength() { + return srcIdLength; + } + }; + // Check that peeking at the encoded packet returns correct information + ByteBuffer encoded = toByteBuffer(encoder, packet, context); + checkLongHeaderPacketAt(encoded, 0, PacketType.VERSIONS, 0, + srcConnectionId, destConnectionId); + + // version negotiation packets can't be coalesced + int offset = RANDOM.nextInt(256); + int end = offset + encoded.limit(); + System.out.printf("testVersionNegotiationPacket.encode(offset:%d, end:%d)%n", + offset, end); + ByteBuffer datagram = ByteBuffer.allocate(encoded.limit() + offset); + datagram.position(offset); + datagram.put(encoded); + encoded.flip(); + datagram.flip(); + + // check header, type and version of both packets + System.out.printf("datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + // set position to first packet to check connection ids + System.out.printf("reading datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, offset, PacketType.VERSIONS, 0, + srcConnectionId, destConnectionId); + + // check that skip packet can skip packet + datagram.position(0); + datagram.limit(datagram.capacity()); + decoder.skipPacket(datagram, offset); + assertEquals(datagram.position(), end); + assertEquals(datagram.remaining(), 0); + + // Decode the two packets in the datagram + datagram.position(offset); + int size = end - offset; + for (int i=0; i<1; i++) { + int pos = datagram.position(); + System.out.printf("Decoding packet: %d at %d%n", (i+1), pos); + var decodedPacket = decoder.decode(datagram, context); + assertEquals(datagram.position(), pos + size); + assertTrue(decodedPacket instanceof VersionNegotiationPacket, "decoded: " + decodedPacket); + VersionNegotiationPacket decodedVersion = VersionNegotiationPacket.class.cast(decodedPacket); + checkLongHeaderPacket(decodedVersion, PacketType.VERSIONS, 0, + PacketNumberSpace.NONE, -1, + srcConnectionId, destConnectionId, null); + assertEquals(decodedPacket.size(), packet.size()); + assertEquals(decodedPacket.size(), size); + assertEquals(decodedVersion.supportedVersions(), + versionList.stream().mapToInt(Integer::intValue).toArray()); + } + assertEquals(datagram.position(), end); + } + + @Test(dataProvider = "versionAndRetryProvider") + public void testRetryPacket(QuicVersion quicVersion, int srcIdLength, int destIdLength) + throws Exception { + System.out.printf("%ntestRetryPacket(qv:%s, scid:%d, dcid:%d, pn:%d, ack:%d)%n", + quicVersion, srcIdLength, destIdLength, -1, -1); + QuicPacketEncoder encoder = QuicPacketEncoder.of(quicVersion); + QuicPacketDecoder decoder = QuicPacketDecoder.of(quicVersion); + byte[] destid = QuicConnectionIdFactory.getClient() + .newConnectionId(destIdLength, IDS.incrementAndGet()); + assert destid.length <= 20; + QuicConnectionId destConnectionId = new PeerConnectionId(destid); + byte[] srcid = randomIdBytes(srcIdLength); + QuicConnectionId srcConnectionId = new PeerConnectionId(srcid); + byte[] origId = randomIdBytes(destIdLength); + QuicConnectionId origConnectionId = new PeerConnectionId(origId); + int bound = (MAX_DATAGRAM_IPV6 - srcIdLength - destid.length - 7); + + int retryTokenLength = RANDOM.nextInt(bound - 16) + 1; + byte[] retryToken = new byte[retryTokenLength]; + RANDOM.nextBytes(retryToken); + System.out.printf("testRetryPacket.encode(token:%d)%n", retryTokenLength); + int expectedSize = 7 + 16 + destid.length + srcIdLength + retryTokenLength; + + // Create an initial packet + var packet = encoder.newRetryPacket(srcConnectionId, + destConnectionId, + retryToken); + + // Check created packet + assertTrue(packet instanceof RetryPacket); + var retryPacket = (RetryPacket) packet; + checkLongHeaderPacket(retryPacket, PacketType.RETRY, quicVersion.versionNumber(), + PacketNumberSpace.NONE, -1, + srcConnectionId, destConnectionId, null); + assertEquals(retryPacket.hasLength(), false); + assertEquals(retryPacket.retryToken(), retryToken); + assertEquals(retryPacket.size(), expectedSize); + + CodingContext context = new TestCodingContext() { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return -1; + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return -1; + } + @Override public int connectionIdLength() { + return srcIdLength; + } + @Override public QuicConnectionId originalServerConnId() { return origConnectionId; } + }; + // Check that peeking at the encoded packet returns correct information + ByteBuffer encoded = toByteBuffer(encoder, packet, context); + checkLongHeaderPacketAt(encoded, 0, PacketType.RETRY, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // version negotiation packets can't be coalesced + int offset = RANDOM.nextInt(256); + int end = offset + encoded.limit(); + System.out.printf("testRetryPacket.encode(offset:%d, end:%d)%n", + offset, end); + ByteBuffer datagram = ByteBuffer.allocate(encoded.limit() + offset); + datagram.position(offset); + datagram.put(encoded); + encoded.flip(); + datagram.flip(); + + // check header, type and version of both packets + System.out.printf("datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + // set position to first packet to check connection ids + System.out.printf("reading datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + checkLongHeaderPacketAt(datagram, offset, PacketType.RETRY, quicVersion.versionNumber(), + srcConnectionId, destConnectionId); + + // check that skip packet can skip packet + datagram.position(0); + datagram.limit(datagram.capacity()); + decoder.skipPacket(datagram, offset); + assertEquals(datagram.position(), end); + assertEquals(datagram.remaining(), 0); + + // Decode the two packets in the datagram + datagram.position(offset); + int size = end - offset; + for (int i=0; i<1; i++) { + int pos = datagram.position(); + System.out.printf("Decoding packet: %d at %d%n", (i+1), pos); + var decodedPacket = decoder.decode(datagram, context); + assertEquals(datagram.position(), pos + size); + assertTrue(decodedPacket instanceof RetryPacket, "decoded: " + decodedPacket); + RetryPacket decodedRetry = RetryPacket.class.cast(decodedPacket); + checkLongHeaderPacket(decodedRetry, PacketType.RETRY, quicVersion.versionNumber(), + PacketNumberSpace.NONE, -1, + srcConnectionId, destConnectionId, null); + assertEquals(decodedPacket.size(), packet.size()); + assertEquals(decodedPacket.size(), size); + assertEquals(decodedPacket.size(), expectedSize); + assertEquals(decodedRetry.retryToken(), retryToken); + } + assertEquals(datagram.position(), end); + } + + @Test(dataProvider = "shortHeaderPacketProvider") + public void testOneRTTPacket(QuicVersion quicVersion, int destIdLength, + long packetNumber, long largestAcked) throws Exception { + System.out.printf("%ntestOneRTTPacket(qv:%s, dcid:%d, pn:%d, ack:%d)%n", + quicVersion, destIdLength, packetNumber, largestAcked); + QuicPacketEncoder encoder = QuicPacketEncoder.of(quicVersion); + QuicPacketDecoder decoder = QuicPacketDecoder.of(quicVersion); + byte[] destid = QuicConnectionIdFactory.getClient() + .newConnectionId(destIdLength, IDS.incrementAndGet()); + assert destid.length <= 20; + QuicConnectionId destConnectionId = new PeerConnectionId(destid); + int bound = MAX_DATAGRAM_IPV6 - destid.length - 7 + - QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked) + - VariableLengthEncoder.getEncodedSize(MAX_DATAGRAM_IPV6); + + int packetNumberLength = + QuicPacketNumbers.computePacketNumberLength(packetNumber, largestAcked); + int payloadSize = Math.max(RANDOM.nextInt(bound - 1) + 1, 4 - packetNumberLength); + byte[] payload = new byte[payloadSize]; + var frames = frames(payload); + + CodingContext context = new TestCodingContext() { + @Override public long largestProcessedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.APPLICATION ? largestAcked : -1; + } + @Override public long largestAckedPN(PacketNumberSpace packetSpace) { + return packetSpace == PacketNumberSpace.APPLICATION ? largestAcked : -1; + } + // since we're going to decode the short packet, we need to return + // the same length that was used as destination cid in the packet + @Override public int connectionIdLength() { + return destid.length; + } + }; + + int paddedPayLoadSize = Math.max(payloadSize + packetNumberLength, context.minShortPacketPayloadSize(destid.length)); + System.out.printf("testOneRTTPacket.encode(payload:%d, padded:%d, destid.length: %d)%n", + payloadSize, paddedPayLoadSize, destid.length); + int expectedSize = 1 + destid.length + paddedPayLoadSize; + // Create an 1-RTT packet + OneRttPacket packet = encoder.newOneRttPacket(destConnectionId, + packetNumber, + largestAcked, + frames, + context); + + int minPayloadSize = context.minShortPacketPayloadSize(destConnectionId.length()) - packetNumberLength; + checkShortHeaderPacket(packet, PacketType.ONERTT, + PacketNumberSpace.APPLICATION, packetNumber, + destConnectionId, frames, minPayloadSize); + assertEquals(packet.hasLength(), false); + assertEquals(packet.size(), expectedSize); + + // Check that peeking at the encoded packet returns correct information + ByteBuffer encoded = toByteBuffer(encoder, packet, context); + checkShortHeaderPacketAt(encoded, 0, PacketType.ONERTT, + destConnectionId, context); + + // write packet at an offset in the datagram to simulate + // short packet coalesced after long packet and check + // the peek methods again + int offset = RANDOM.nextInt(256); + int end = offset + encoded.limit(); + System.out.printf("testOneRTTPacket.encode(offset:%d, end:%d)%n", + offset, end); + ByteBuffer datagram = ByteBuffer.allocate(encoded.limit() + offset * 2); + datagram.position(offset); + datagram.put(encoded); + encoded.flip(); + datagram.flip(); + assert datagram.limit() == offset + encoded.remaining(); + + // set position to first packet to check connection ids + System.out.printf("reading datagram(offset:%d, position:%d, limit:%d)%n", + offset, datagram.position(), datagram.limit()); + checkShortHeaderPacketAt(datagram, offset, PacketType.ONERTT, + destConnectionId, context); + + // check that skip packet can skip packet at offset + datagram.position(0); + datagram.limit(end); + decoder.skipPacket(datagram, offset); + assertEquals(datagram.position(), offset + expectedSize); + assertEquals(datagram.position(), datagram.limit()); + assertEquals(datagram.position(), datagram.capacity() - offset); + + + // Decode the packet in the datagram + datagram.position(offset); + int size = expectedSize; + for (int i=0; i<1; i++) { + int pos = datagram.position(); + System.out.printf("Decoding packet: %d at %d%n", (i+1), pos); + var decodedPacket = decoder.decode(datagram, context); + assertEquals(datagram.position(), pos + size); + assertTrue(decodedPacket instanceof OneRttPacket, "decoded: " + decodedPacket); + OneRttPacket oneRttDecoded = OneRttPacket.class.cast(decodedPacket); + List expectedFrames = frames; + if (frames.size() > 0 && frames.get(0) instanceof PaddingFrame) { + // The first frame should be a crypto frame, except if payloadSize + // was less than 7. + int frameSizes = frames.stream().mapToInt(QuicFrame::size).sum(); + assert frameSizes == payloadSize; + assert frameSizes <= 7; + // decoder will coalesce padding frames. So instead of finding + // two padding frames in the decoded packet we will find just one. + // To make the check pass, we should expect a bigger padding frame. + if (minPayloadSize > frameSizes) { + // replace the first frame with a bigger padding frame + expectedFrames = new ArrayList<>(frames); + var first = frames.get(0); + // replace the first frame with a bigger padding frame that + // coalesce the first padding payload frame with the padding that + // should have been added by the encoder. + // We will then be able to check that the decoded packet contains + // that single bigger padding frame. + expectedFrames.set(0, new PaddingFrame(minPayloadSize - frameSizes + first.size())); + } + } + checkShortHeaderPacket(oneRttDecoded, PacketType.ONERTT, + PacketNumberSpace.APPLICATION, packetNumber, + destConnectionId, expectedFrames, minPayloadSize); + assertEquals(decodedPacket.size(), packet.size()); + assertEquals(decodedPacket.size(), size); + } + assertEquals(datagram.position(), offset + size); + assertEquals(datagram.remaining(), 0); + assertEquals(datagram.limit(), end); + + } + + @Test + public void testNoMismatch() { + List match1 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4}), + ByteBuffer.wrap(new byte[] {5, 6}), + ByteBuffer.wrap(new byte[] {7, 8}), + ByteBuffer.wrap(new byte[] {9}), + ByteBuffer.wrap(new byte[] {10, 11, 12}), + ByteBuffer.wrap(new byte[] {13, 14, 15, 16}), + ByteBuffer.wrap(new byte[] {17, 18, 19, 20}) + ); + List match2 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4, 5}), + ByteBuffer.wrap(new byte[] {6}), + ByteBuffer.wrap(new byte[] {7}), + ByteBuffer.wrap(new byte[] {8, 9}), + ByteBuffer.wrap(new byte[] {10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20}) + ); + List match3 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}), + ByteBuffer.wrap(new byte[] {9, 10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20}) + ); + assertEquals(Utils.mismatch(match1, match1), -1); + assertEquals(Utils.mismatch(match2, match2), -1); + assertEquals(Utils.mismatch(match3, match3), -1); + assertEquals(Utils.mismatch(match1, match2), -1); + assertEquals(Utils.mismatch(match2, match1), -1); + assertEquals(Utils.mismatch(match1, match3), -1); + assertEquals(Utils.mismatch(match3, match1), -1); + assertEquals(Utils.mismatch(match2, match3), -1); + assertEquals(Utils.mismatch(match3, match2), -1); + } + + @Test + public void testMismatch() { + // match1, match2, match3 match with each others + List match1 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4}), + ByteBuffer.wrap(new byte[] {5, 6}), + ByteBuffer.wrap(new byte[] {7, 8}), + ByteBuffer.wrap(new byte[] {9}), + ByteBuffer.wrap(new byte[] {10, 11, 12}), + ByteBuffer.wrap(new byte[] {13, 14, 15, 16}), + ByteBuffer.wrap(new byte[] {17, 18, 19, 20}) + ); + List match2 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4, 5}), + ByteBuffer.wrap(new byte[] {6}), + ByteBuffer.wrap(new byte[] {7}), + ByteBuffer.wrap(new byte[] {8, 9}), + ByteBuffer.wrap(new byte[] {10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20}) + ); + List match3 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}), + ByteBuffer.wrap(new byte[] {9, 10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20}) + ); + // nomatch0, nomatch10, nomatch19 differ from the previous + // list at some index in [0..20[ + // nomatch0 mismatches at index 0 + List nomatch0 = List.of( + ByteBuffer.wrap(new byte[] {21, 2, 3}), + ByteBuffer.wrap(new byte[] {4}), + ByteBuffer.wrap(new byte[] {5, 6}), + ByteBuffer.wrap(new byte[] {7, 8}), + ByteBuffer.wrap(new byte[] {9}), + ByteBuffer.wrap(new byte[] {10, 11, 12}), + ByteBuffer.wrap(new byte[] {13, 14, 15, 16}), + ByteBuffer.wrap(new byte[] {17, 18, 19, 20}) + ); + // nomatch10 mismatches at index 10 + List nomatch10 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4, 5}), + ByteBuffer.wrap(new byte[] {6}), + ByteBuffer.wrap(new byte[] {7}), + ByteBuffer.wrap(new byte[] {8, 9}), + ByteBuffer.wrap(new byte[] {10, 31}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20}) + ); + // nomatch19 mismatches at index 19 + List nomatch19 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}), + ByteBuffer.wrap(new byte[] {9, 10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 40}) + ); + // morematch1 has one more byte at the end + List morematch1 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4}), + ByteBuffer.wrap(new byte[] {5, 6}), + ByteBuffer.wrap(new byte[] {7, 8}), + ByteBuffer.wrap(new byte[] {9}), + ByteBuffer.wrap(new byte[] {10, 11, 12}), + ByteBuffer.wrap(new byte[] {13, 14, 15, 16}), + ByteBuffer.wrap(new byte[] {17, 18, 19, 20, 41}) + ); + // morematch2 and morematch3 have the same 3 additional + // bytes at the end + List morematch2 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3}), + ByteBuffer.wrap(new byte[] {4, 5}), + ByteBuffer.wrap(new byte[] {6}), + ByteBuffer.wrap(new byte[] {7}), + ByteBuffer.wrap(new byte[] {8, 9}), + ByteBuffer.wrap(new byte[] {10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20}), + ByteBuffer.wrap(new byte[] {41, 42, 43}) + ); + List morematch3 = List.of( + ByteBuffer.wrap(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}), + ByteBuffer.wrap(new byte[] {9, 10, 11}), + ByteBuffer.wrap(new byte[] {12, 13, 14}), + ByteBuffer.wrap(new byte[] {15}), + ByteBuffer.wrap(new byte[] {16, 17}), + ByteBuffer.wrap(new byte[] {18, 19, 20, 41, 42, 43}) + ); + + assertEquals(Utils.mismatch(nomatch0, nomatch0), -1L); + assertEquals(Utils.mismatch(nomatch10, nomatch10), -1L); + assertEquals(Utils.mismatch(nomatch19, nomatch19), -1L); + assertEquals(Utils.mismatch(morematch1, morematch1), -1L); + assertEquals(Utils.mismatch(morematch2, morematch2), -1L); + assertEquals(Utils.mismatch(morematch3, morematch3), -1L); + assertEquals(Utils.mismatch(morematch2, morematch3), -1L); + assertEquals(Utils.mismatch(morematch3, morematch2), -1L); + + for (var match : List.of(match1, match2, match3)) { + assertEquals(Utils.mismatch(match, nomatch0), 0L); + assertEquals(Utils.mismatch(match, nomatch10), 10L); + assertEquals(Utils.mismatch(match, nomatch19), 19L); + assertEquals(Utils.mismatch(nomatch0, match), 0L); + assertEquals(Utils.mismatch(nomatch10, match), 10L); + assertEquals(Utils.mismatch(nomatch19, match), 19L); + for (var morematch : List.of(morematch1, morematch2, morematch3)) { + assertEquals(Utils.mismatch(match, morematch), 20L); + assertEquals(Utils.mismatch(morematch, match), 20L); + } + + } + } + +} diff --git a/test/jdk/java/net/httpclient/quic/PacketLossTest.java b/test/jdk/java/net/httpclient/quic/PacketLossTest.java new file mode 100644 index 00000000000..1298e0977de --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/PacketLossTest.java @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.TestUtil; +import jdk.httpclient.test.lib.quic.ClientConnection; +import jdk.httpclient.test.lib.quic.ConnectedBidiStream; +import jdk.httpclient.test.lib.quic.DatagramDeliveryPolicy; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicServerHandler; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/* + * @test + * @summary Verifies QUIC client interaction against servers which exhibit packet loss + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.quic.QuicStandaloneServer + * jdk.httpclient.test.lib.quic.ClientConnection + * jdk.httpclient.test.lib.common.TestUtil + * jdk.test.lib.net.SimpleSSLContext + * @run junit/othervm/timeout=240 -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.quic.minPtoBackoffTime=60 + * -Djdk.httpclient.quic.maxPtoBackoffTime=10 + * -Djdk.httpclient.quic.maxPtoBackoff=9 + * -Djdk.httpclient.HttpClient.log=quic,errors PacketLossTest + */ +public class PacketLossTest { + + private static SSLContext sslContext; + private static ExecutorService executor; + + private static final byte[] HELLO_MSG = "Hello Quic".getBytes(StandardCharsets.UTF_8); + + @BeforeAll + public static void beforeAll() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + executor = Executors.newCachedThreadPool(); + } + + @AfterAll + public static void afterAll() throws Exception { + if (executor != null) { + executor.close(); + } + } + + private QuicClient createClient() { + var versions = List.of(QuicVersion.QUIC_V1); + var context = new QuicTLSContext(sslContext); + var params = new SSLParameters(); + return new QuicClient.Builder() + .availableVersions(versions) + .tlsContext(context) + .sslParameters(params) + .executor(executor) + .bindAddress(TestUtil.chooseClientBindAddress().orElse(null)) + .build(); + } + + sealed interface DropPolicy { + final record DropRandomly() implements DropPolicy {} + final record DropEveryNth(int n) implements DropPolicy {} + final record DropNone() implements DropPolicy {} + final record DropAll() implements DropPolicy {} + default DatagramDeliveryPolicy policy() { + if (this instanceof DropRandomly) { + return DatagramDeliveryPolicy.dropRandomly(); + } else if (this instanceof DropEveryNth en) { + return DatagramDeliveryPolicy.dropEveryNth(en.n()); + } else if (this instanceof DropNone) { + return DatagramDeliveryPolicy.alwaysDeliver(); + } else if (this instanceof DropAll) { + return DatagramDeliveryPolicy.neverDeliver(); + } + throw new IllegalStateException("Unknown policy: " + this); + } + } + + record DropServer(String name, QuicStandaloneServer server) implements AutoCloseable { + @Override + public void close() throws IOException { + server.close(); + } + public InetSocketAddress getAddress() { + return server.getAddress(); + } + } + + DropServer of(QuicServer.Builder builder, DropPolicy incomimg, DropPolicy outgoing) + throws IOException { + QuicStandaloneServer server = builder + .incomingDeliveryPolicy(incomimg.policy()) + .outgoingDeliveryPolicy(outgoing.policy()) + .build(); + String name = "DropServer(%s, in: %s, out: %s)".formatted(server.name(), incomimg, outgoing); + return new DropServer(name, server); + } + + DropPolicy dropEveryNth(int n) { + return new DropPolicy.DropEveryNth(n); + } + DropPolicy dropRandomly() { + return new DropPolicy.DropRandomly(); + } + + // returns a List of unstarted Quic servers configured with different incoming/outgoing + // datagram delivery policies + private List unstartedServers() throws Exception { + final QuicServer.Builder builder = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext); + final List servers = new ArrayList<>(); + servers.add(of(builder, dropEveryNth(3), dropEveryNth(7))); + servers.add(of(builder, dropRandomly(), dropRandomly())); + servers.add(of(builder, dropEveryNth(5), dropRandomly())); + return servers; + } + + private static void startServer(final QuicStandaloneServer server) throws IOException { + // add a handler which deals with incoming connections + server.addHandler(new EchoHandler(HELLO_MSG.length)); + server.start(); + System.out.println("Server " + server.name() + " started at " + server.getAddress()); + } + + /** + * Uses {@link QuicClient} to pass data and expect back the data to/from Quic servers which + * might drop incoming/outgoing packets. + */ + @Test + public void testDataTransfer() throws Exception { + for (final DropServer server : unstartedServers()) { + startServer(server.server()); + try (server) { + System.out.printf("%n%n===== %s =====%n%n", server.name()); + System.err.printf("%n%n===== %s =====%n%n", server.name()); + final int numTimes = 20; + try (final QuicClient client = createClient()) { + final InetSocketAddress serverAddr = server.getAddress(); + // create a QUIC connection to the server + final ClientConnection conn = ClientConnection.establishConnection(client, serverAddr); + for (int i = 1; i <= numTimes; i++) { + System.out.println("iteration " + i + " against server: " + server.name() + + ", server addr: " + serverAddr); + // open a bidi stream + final ConnectedBidiStream bidiStream = conn.initiateNewBidiStream(); + // write data on the stream + try (final OutputStream os = bidiStream.outputStream()) { + os.write(HELLO_MSG); + System.out.println("client: Client wrote message to bidi stream's output stream"); + } + // wait for response + try (final InputStream is = bidiStream.inputStream()) { + System.out.println("client: reading from bidi stream's input stream"); + final byte[] data = is.readAllBytes(); + System.out.println("client: Received response of size " + data.length); + final String response = new String(data, StandardCharsets.UTF_8); + // verify response + System.out.println("client: Response: " + response); + if (!Arrays.equals(response.getBytes(StandardCharsets.UTF_8), HELLO_MSG)) { + throw new AssertionError("Unexpected response: " + response); + } + } finally { + System.err.println("client: Closing bidi stream from test"); + bidiStream.close(); + } + } + } + } + } + } + + /** + * Reads data from incoming client initiated bidirectional stream of a Quic connection + * and writes back a response which is same as the read data + */ + private static final class EchoHandler implements QuicServerHandler { + + private final int numBytesToRead; + + private EchoHandler(final int numBytesToRead) { + this.numBytesToRead = numBytesToRead; + } + + @Override + public void handleBidiStream(final QuicServerConnection conn, + final ConnectedBidiStream bidiStream) throws IOException { + System.out.println("Handling incoming bidi stream " + bidiStream + + " on connection " + conn); + final byte[] data; + // read the request content + try (final InputStream is = bidiStream.inputStream()) { + System.out.println("Handler reading data from bidi stream's inputstream " + is); + data = is.readAllBytes(); + System.out.println("Handler read " + data.length + " bytes of data"); + } + if (data.length != numBytesToRead) { + throw new IOException("Expected to read " + numBytesToRead + + " bytes but read only " + data.length + " bytes"); + } + // write response + try (final OutputStream os = bidiStream.outputStream()) { + System.out.println("Handler writing data to bidi stream's outputstream " + os); + os.write(data); + } + System.out.println("Handler invocation complete"); + } + } +} diff --git a/test/jdk/java/net/httpclient/quic/PacketNumbersTest.java b/test/jdk/java/net/httpclient/quic/PacketNumbersTest.java new file mode 100644 index 00000000000..3a4237d0847 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/PacketNumbersTest.java @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import jdk.internal.net.http.quic.packets.QuicPacketNumbers; +import org.testng.SkipException; +import org.testng.annotations.Test; +import org.testng.annotations.DataProvider; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.expectThrows; + +/** + * @test + * @run testng PacketNumbersTest + */ +public class PacketNumbersTest { + + record EncodeResult(int expected, boolean assertion, Class failure) { + public static EncodeResult illegal() { + return new EncodeResult(-1, true, IllegalArgumentException.class); + } + public static EncodeResult asserting() { + return new EncodeResult(-1, true, AssertionError.class); + } + public static EncodeResult success(int result) { + return new EncodeResult(result, false, null); + } + public static EncodeResult fail(Class failure) { + return new EncodeResult(-1, false, failure); + } + public boolean fail() { + return failure() != null; + } + + @Override + public String toString() { + return fail() ? failure.getSimpleName() + : String.valueOf(expected); + } + } + record TestCase(String desc, long fullPN, long largestAck, EncodeResult result) { + static AtomicInteger count = new AtomicInteger(); + TestCase { + desc = count.incrementAndGet() + " - expecting " + desc; + } + public byte[] encode() { + return QuicPacketNumbers.encodePacketNumber(fullPN(), largestAck()); + } + + public long decode() { + byte[] encoded = encode(); + var largestProcessed = largestAck(); + return QuicPacketNumbers.decodePacketNumber(largestProcessed, + ByteBuffer.wrap(encoded), encoded.length); + } + + @Override + public String toString() { + return "%s: (%d, %d) -> %s".formatted(desc, fullPN, largestAck, result); + } + } + + @DataProvider + public Object[][] encode() { + return List.of( + // these first three test cases are extracted from RFC 9000, appendix A.2 and A.3 + new TestCase("success", 0xa82f9b32L, 0xa82f30eaL, EncodeResult.success(0x9b32)), + new TestCase("success", 0xace8feL, 0xabe8b3L, EncodeResult.success(0xace8fe & 0xFFFFFF)), + new TestCase("success", 0xac5c02L, 0xabe8b3L, EncodeResult.success(0xac5c02 & 0xFFFF)), + // additional test cases - these have been obtained empirically to test at the limits + new TestCase("success", 0x7FFFFFFFFFFFL, 0x7FFFFFFFFF00L, EncodeResult.success(0x0000FFFF)), + new TestCase("success", 0xFFFFFFFFFFL, 0xFFFFFFFF00L, EncodeResult.success(0x0000FFFF)), + new TestCase("success", 0xFFFFFFFFL, 0xFFFFFFFEL, EncodeResult.success(0x000000FF)), + new TestCase("success", 0xFFFFFFFFL, 0xFFFFFF00L, EncodeResult.success(0x0000FFFF)), + new TestCase("success", 0xFFFFFFFFL, 0xFFFF0000L, EncodeResult.success(0x00FFFFFF)), + new TestCase("success", 0xFFFFFFFFL, 0xFF000000L, EncodeResult.success(0xFFFFFFFF)), + new TestCase("success", 0xFFFFFFFFL, 0xF0000000L, EncodeResult.success(0xFFFFFFFF)), + new TestCase("success", 0xFFFFFFFFL, 0x80000000L, EncodeResult.success(0xFFFFFFFF)), + new TestCase("illegal(5)",0xFFFFFFFFL, 0x7FFFFFFFL, EncodeResult.illegal()), + new TestCase("success", 0x8FFFFFFFL, 0x10000000L, EncodeResult.success(0x8FFFFFFF)), + new TestCase("illegal(5)",0x8FFFFFFFL, 0x0FFFFFFFL, EncodeResult.illegal()), + new TestCase("illegal(5)",0x8FFFFFFFL, 256L, EncodeResult.illegal()), + new TestCase("success", 0x7FFFFFFFL, 255L, EncodeResult.success(0x7FFFFFFF)), + new TestCase("success", 0x7FFFFFFFL, 0L, EncodeResult.success(0x7FFFFFFF)), + new TestCase("illegal(5)",0x7FFFFFFFL, -1L, EncodeResult.illegal()), + new TestCase("success", 0x6FFFFFFFL, 0L, EncodeResult.success(0x6FFFFFFF)), + new TestCase("success", 0xFFFFFFL, 0L, EncodeResult.success(0xFFFFFF)), + new TestCase("success", 0xFFFFL, 0L, EncodeResult.success(0xFFFF)), + new TestCase("success", 255L, 0L, EncodeResult.success(255)), + new TestCase("success", 1L, 0L, EncodeResult.success(1)), + new TestCase("success", 0x6FFFFFFFL, -1L, EncodeResult.success(0x6FFFFFFF)), + new TestCase("success", 0xFFFFFFL, -1L, EncodeResult.success(0xFFFFFF)), + new TestCase("success", 0xFFFFL, -1L, EncodeResult.success(0xFFFF)), + new TestCase("success", 255L, -1L, EncodeResult.success(255)), + new TestCase("success", 1L, -1L, EncodeResult.success(1)), + new TestCase("success", 0L, -1L, EncodeResult.success(0)), + new TestCase("assert", 0L, 1L, EncodeResult.asserting()), + new TestCase("assert", 0L, 0L, EncodeResult.asserting()), + new TestCase("assert", 1L, 1L, EncodeResult.asserting()) + ).stream().map(Stream::of) + .map(Stream::toArray) + .toArray(Object[][]::new); + } + + @Test(dataProvider = "encode") + public void testEncodePacketNumber(TestCase test) { + System.out.println(test); + if (test.result().assertion()) { + if (!QuicPacketNumbers.class.desiredAssertionStatus()) { + throw new SkipException("needs assertion enabled (-esa)"); + } + Throwable t = expectThrows(test.result().failure(), test::encode); + System.out.println("Got expected assertion: " + t); + return; + } + if (test.result().fail()) { + Throwable t = expectThrows(test.result().failure(), test::encode); + System.out.println("Got expected exception: " + t); + return; + + } + byte[] res = test.encode(); + int truncated = 0; + for (int i=0; i initial; + case HANDSHAKE -> handshake; + case APPLICATION -> app; + case NONE -> throw new AssertionError("invalid number space: " + packetNumberSpace); + }; + if (result == null) { + throw new AssertionError("invalid number space: " + packetNumberSpace); + } + return result; + } + } + + private static class DummyQuicTLSEngine implements QuicTLSEngine { + @Override + public HandshakeState getHandshakeState() { + throw new AssertionError("should not come here!"); + } + + @Override + public boolean isTLSHandshakeComplete() { + return true; + } + + @Override + public KeySpace getCurrentSendKeySpace() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean keysAvailable(KeySpace keySpace) { + return true; + } + + @Override + public void discardKeys(KeySpace keySpace) { + // no-op + } + + @Override + public void setLocalQuicTransportParameters(ByteBuffer params) { + throw new AssertionError("should not come here!"); + } + + @Override + public void restartHandshake() throws IOException { + throw new AssertionError("should not come here!"); + } + + @Override + public void setRemoteQuicTransportParametersConsumer(QuicTransportParametersConsumer consumer) { + throw new AssertionError("should not come here!"); + } + @Override + public void deriveInitialKeys(QuicVersion version, ByteBuffer connectionId) { } + @Override + public int getHeaderProtectionSampleSize(KeySpace keySpace) { + return 0; + } + @Override + public ByteBuffer computeHeaderProtectionMask(KeySpace keySpace, boolean incoming, ByteBuffer sample) { + return ByteBuffer.allocate(5); + } + + @Override + public int getAuthTagSize() { + return 0; + } + + @Override + public void encryptPacket(KeySpace keySpace, long packetNumber, + IntFunction headerGenerator, + ByteBuffer packetPayload, ByteBuffer output) + throws QuicKeyUnavailableException, QuicTransportException { + // this dummy QUIC TLS engine doesn't do any encryption. + // we just copy over the raw packet payload into the output buffer + output.put(packetPayload); + } + + @Override + public void decryptPacket(KeySpace keySpace, long packetNumber, int keyPhase, + ByteBuffer packet, int headerLength, ByteBuffer output) { + packet.position(packet.position() + headerLength); + output.put(packet); + } + + @Override + public void signRetryPacket(QuicVersion version, + ByteBuffer originalConnectionId, ByteBuffer packet, ByteBuffer output) { + throw new AssertionError("should not come here!"); + } + @Override + public void verifyRetryPacket(QuicVersion version, + ByteBuffer originalConnectionId, ByteBuffer packet) throws AEADBadTagException { + throw new AssertionError("should not come here!"); + } + @Override + public ByteBuffer getHandshakeBytes(KeySpace keySpace) { + throw new AssertionError("should not come here!"); + } + @Override + public void consumeHandshakeBytes(KeySpace keySpace, ByteBuffer payload) { + throw new AssertionError("should not come here!"); + } + @Override + public Runnable getDelegatedTask() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean tryMarkHandshakeDone() { + throw new AssertionError("should not come here!"); + } + @Override + public boolean tryReceiveHandshakeDone() { + throw new AssertionError("should not come here!"); + } + + @Override + public Set getSupportedQuicVersions() { + return Set.of(QuicVersion.QUIC_V1); + } + + @Override + public void setUseClientMode(boolean mode) { + throw new AssertionError("should not come here!"); + } + + @Override + public boolean getUseClientMode() { + throw new AssertionError("should not come here!"); + } + + @Override + public SSLParameters getSSLParameters() { + throw new AssertionError("should not come here!"); + } + + @Override + public void setSSLParameters(SSLParameters sslParameters) { + throw new AssertionError("should not come here!"); + } + + @Override + public String getApplicationProtocol() { + return null; + } + + @Override + public SSLSession getSession() { + throw new AssertionError("should not come here!"); + } + + @Override + public SSLSession getHandshakeSession() { + throw new AssertionError("should not come here!"); + } + + @Override + public void versionNegotiated(QuicVersion quicVersion) { + // no-op + } + + @Override + public void setOneRttContext(QuicOneRttContext ctx) { + // no-op + } + } + + private static final QuicTLSEngine TLS_ENGINE = new DummyQuicTLSEngine(); + private static class TestCodingContext implements CodingContext { + final QuicPacketEncoder encoder; + final QuicPacketDecoder decoder; + final PacketSpaces spaces; + TestCodingContext(PacketSpaces spaces) { + this.spaces = spaces; + this.encoder = QuicPacketEncoder.of(QuicVersion.QUIC_V1); + this.decoder = QuicPacketDecoder.of(QuicVersion.QUIC_V1); + } + + @Override + public long largestProcessedPN(PacketNumberSpace packetNumberSpace) { + return spaces.get(packetNumberSpace).getLargestProcessedPN(); + } + + @Override + public long largestAckedPN(PacketNumberSpace packetNumberSpace) { + return spaces.get(packetNumberSpace).getLargestPeerAckedPN(); + } + + @Override + public int connectionIdLength() { + return CIDLEN; + } + + @Override + public int writePacket(QuicPacket packet, ByteBuffer buffer) + throws QuicKeyUnavailableException, QuicTransportException { + int pos = buffer.position(); + encoder.encode(packet, buffer, this); + return buffer.position() - pos; + } + @Override + public QuicPacket parsePacket(ByteBuffer src) + throws IOException, QuicKeyUnavailableException, QuicTransportException { + return decoder.decode(src, this); + } + @Override + public boolean verifyToken(QuicConnectionId destinationID, byte[] token) { + return true; + } + @Override + public QuicConnectionId originalServerConnId() { + return null; + } + + @Override + public QuicTLSEngine getTLSEngine() { + return TLS_ENGINE; + } + } + + /** + * An acknowledgement range, where acknowledged packets are [first..last]. + * For instance [9,9] acknowledges only packet 9. + * @param first the first packet acknowledged, inclusive + * @param last the last packet acknowledged, inclusive + */ + public static record Acknowledged(long first, long last) { + public Acknowledged { + assert first >= 0 && first <= last; + } + public boolean contains(long packet) { + return first <= packet && last >= packet; + } + public static List of(long... numbers) { + if (numbers == null || numbers.length == 0) return List.of(); + if (numbers.length % 2 != 0) throw new IllegalArgumentException(); + List res = new ArrayList<>(numbers.length/2); + for (int i = 0; i < numbers.length; i += 2) { + res.add(new Acknowledged(numbers[i], numbers[i+1])); + } + return List.copyOf(res); + } + } + + /** + * A packet to be emitted, followed by a pouse of {@code delay} in milliseconds. + * @param packetNumber the packet number of the packet to send + * @param delay a delay before the next packet should be emitted + */ + public static record Packet(long packetNumber, long delay) { + Packet(long packetNumber) { + this(packetNumber, RANDOM.nextLong(1, 255)); + } + static List ofAcks(List acks) { + return packets(acks); + } + static List of(long... numbers) { + return LongStream.of(numbers).mapToObj(Packet::new).toList(); + } + static final Comparator COMPARE_NUMBERS = Comparator.comparingLong(Packet::packetNumber); + } + + /** + * A test case. Composed of a list of acknowledgements, a list of packets, + * and a list of AckFrames. The list of packets is built from the list of + * acknowledgement - that is - every packet emitted should eventually be + * acknowledged. The list of AckFrame is built from the list of Packets, + * by randomly selecting a few consecutive packets in the packet list + * to acknowledge. The list of AckFrame is sorted by increasing + * largestAcknowledged. The list of Packets can be shuffled, which + * result on having AckFrames with gaps. + * @param acks A list of acknowledgement ranges + * @param packets A list of packets generated from the acknowledgement ranges. + * The list can be shuffled. + * @param ackframes A list of AckFrames, derived from the possibly shuffled + * list of packets. The list of AckFrame is sorted by increasing + * largestAcknowledged (since a packet can't be acknowledged + * before it's been emitted). + * @param shuffled whether the list of packets is shuffled. + */ + public static record TestCase(List acks, + List packets, + List ackframes, + boolean shuffled) { + public TestCase(List acks, List packets) { + this(acks, packets, ackFrames(packets),false); + } + public TestCase(List acks, List packets, boolean shuffled) { + this(acks, packets, ackFrames(packets), shuffled); + } + public TestCase(List acks) { + this(acks, Packet.ofAcks(acks)); + } + public TestCase shuffle() { + List shuffled = new ArrayList<>(packets); + Collections.shuffle(shuffled, RANDOM); + return new TestCase(acks, List.copyOf(shuffled), true); + } + } + + /** + * Construct a list of AckFrames from the possibly shuffled list + * of Packets. + * @param packets a list of packets + * @return a sorted list of AckFrames + */ + private static List ackFrames(List packets) { + List result = new ArrayList<>(); + int remaining = packets.size(); + int i = 0; + while (remaining > 0) { + int ackCount = Math.min(RANDOM.nextInt(1, 5), remaining); + AckFrameBuilder builder = new AckFrameBuilder(); + for (int j=0; j < ackCount; j++) { + builder.addAck(packets.get(i + j).packetNumber); + } + result.add(builder.build()); + i += ackCount; + remaining -= ackCount; + } + result.sort(Comparator.comparingLong(AckFrame::largestAcknowledged)); + return List.copyOf(result); + } + + /** + * Generates test cases - by concatenating a list of simple test case, + * a list of special testcases, and a list of random testcases. + * @return A list of TestCases to test. + */ + List generateTests() { + List tests = new ArrayList<>(); + List simples = List.of( + new TestCase(List.of(new Acknowledged(5,5))), + new TestCase(List.of(new Acknowledged(5,7))), + new TestCase(List.of(new Acknowledged(3, 5), new Acknowledged(7,9))), + new TestCase(List.of(new Acknowledged(3, 5), new Acknowledged(7,7))), + new TestCase(List.of(new Acknowledged(3,3), new Acknowledged(5,7))) + ); + tests.addAll(simples); + List specials = List.of( + new TestCase(Acknowledged.of(5,5,7,7), Packet.of(5,7), false), + new TestCase(Acknowledged.of(5,7), Packet.of(5,7,6), true), + new TestCase(Acknowledged.of(6,7), Packet.of(6,7), false), + new TestCase(Acknowledged.of(5,7), Packet.of(6,7,5), true), + new TestCase(Acknowledged.of(5,7), Packet.of(5,6,7), true), + new TestCase(Acknowledged.of(5,5,7,8), Packet.of(5, 7, 8), true), + new TestCase(Acknowledged.of(5,5,8,8), Packet.of(8, 5), true), + new TestCase(Acknowledged.of(5,5,7,8), Packet.of(8, 5, 7), true), + new TestCase(Acknowledged.of(3,5,7,9), Packet.of(8,5,7,4,9,3), true), + new TestCase(Acknowledged.of(27,27,31,31), + Packet.of(27, 31), true), + new TestCase(Acknowledged.of(27,27,29,29,31,31), + Packet.of(27, 31, 29), true), + new TestCase(Acknowledged.of(3,5,7,7,9,9,22,22,27,27,29,29,31,31), + Packet.of(4,22,27,31,9,29,7,5,3), true) + ); + tests.addAll(specials); + for (int i=0; i < 5; i++) { + List acks = generateAcks(); + List packets = packets(acks); + TestCase test = new TestCase(acks, List.copyOf(packets), false); + tests.add(test); + for (int j = 0; j < 5; j++) { + tests.add(test.shuffle()); + } + } + return tests; + } + + /** + * Generate a random list of increasing acknowledgement ranges. + * A packet should only be present once. + * @return a random list of increasing acknowledgement ranges. + */ + List generateAcks() { + int count = RANDOM.nextInt(3, 10); + List acks = new ArrayList<>(count); + long prev = -1; + for (int i=0; i packets(List acks) { + List res = new ArrayList<>(); + for (Acknowledged ack : acks) { + for (long i = ack.first() ; i<= ack.last() ; i++) { + var packet = new Packet(i); + assert !res.contains(packet); + res.add(packet); + } + } + return res; + } + + @DataProvider(name = "tests") + public Object[][] tests() { + return generateTests().stream() + .map(List::of) + .map(List::toArray) + .toArray(Object[][]::new); + } + + // TODO: + // 1. TestCase should have an ordered list of packets. + // 2. packets will be emitted in order - that is - + // PacketSpaceManager::packetSent will be called for each packet + // in order. + // 3. acknowledgements of packet should arrive in random order. + // a selection of packets should be acknowledged in bunch... + // However, a packet shouldn't be acknowledged before it is emitted. + // 4. some packets should not be acknowledged in time, causing + // them to be retransmitted. + // 5. all of retransmitted packets should eventually be acknowledged, + // but which packet number (among the list of numbers under which + // a packet is retransmitted should be random). + // 6. packets that are acknowledged should no longer be retransmitted + // 7. code an AsynchronousTestDriver that uses the EXECUTOR + // this is more difficult as it makes it more difficult to guess + // at when exactly a packet will be retransmitted... + + /** + * A synchronous test driver to drive a TestCase. + * The method {@link #run()} drives the test. + */ + static class SynchronousTestDriver implements PacketEmitter { + final TestCase test; + final long timeline; + final QuicTimerQueue timerQueue; + final PriorityBlockingQueue packetQueue; + final PriorityBlockingQueue framesQueue; + final PacketSpaceManager manager; + final PacketNumberSpace space; + final Executor executor = this::execute; + final Logger debug = TestLoggerUtil.getErrOutLogger(this::toString); + final TestCodingContext codingContext; + final ConcurrentLinkedQueue emittedAckPackets; + final QuicRttEstimator rttEstimator; + final QuicCongestionController congestionController; + final QuicConnectionId localId; + final QuicConnectionId peerId; + final TimeSource timeSource = new TimeSource(); + final long maxPacketNumber; + final AckFrameBuilder allAcks = new AckFrameBuilder(); + + SynchronousTestDriver(TestCase test) { + this.space = PacketNumberSpace.INITIAL; + this.test = test; + + localId = newId(); + peerId = newId(); + timeline = test.packets().stream() + .mapToLong(Packet::delay) + .reduce(0, Math::addExact); + timerQueue = new QuicTimerQueue(this::notifyQueue, debug); + packetQueue = new PriorityBlockingQueue<>(test.packets.size(), Packet.COMPARE_NUMBERS); + packetQueue.addAll(test.packets); + framesQueue = new PriorityBlockingQueue<>(test.ackframes.size(), + Comparator.comparingLong(AckFrame::largestAcknowledged)); + framesQueue.addAll(test.ackframes); + emittedAckPackets = new ConcurrentLinkedQueue<>(); + rttEstimator = new QuicRttEstimator() { + @Override + public synchronized Duration getLossThreshold() { + return Duration.ofMillis(250); + } + + @Override + public synchronized Duration getBasePtoDuration() { + return Duration.ofMillis(250); + } + }; + congestionController = new QuicCongestionController() { + @Override + public boolean canSendPacket() { + return true; + } + @Override + public void updateMaxDatagramSize(int newSize) { } + @Override + public void packetSent(int packetBytes) { } + @Override + public void packetAcked(int packetBytes, Deadline sentTime) { } + @Override + public void packetLost(Collection lostPackets, Deadline sentTime, boolean persistent) { } + @Override + public void packetDiscarded(Collection discardedPackets) { } + }; + manager = new PacketSpaceManager(space, this, timeSource, + rttEstimator, congestionController, new DummyQuicTLSEngine(), + this::toString); + maxPacketNumber = test.packets().stream().mapToLong(Packet::packetNumber) + .max().getAsLong(); + manager.getNextPN().set(maxPacketNumber + 1); + codingContext = new TestCodingContext(new PacketSpaces(manager, null, null)); + } + + static class TimeSource implements TimeLine { + final Deadline first = jdk.internal.net.http.common.TimeSource.now(); + volatile Deadline current = first; + public synchronized Deadline advance(long duration, TemporalUnit unit) { + return current = current.plus(duration, unit); + } + public Deadline advanceMillis(long millis) { + return advance(millis, ChronoUnit.MILLIS); + } + @Override + public Deadline instant() { + return current; + } + } + + void notifyQueue() { + timerQueue.processEventsAndReturnNextDeadline(now(), executor); + } + + @Override + public QuicTimerQueue timer() { return timerQueue;} + + @Override + public void retransmit(PacketSpace packetSpaceManager, QuicPacket packet, int attempts) { + if (!(packet instanceof InitialPacket initial)) + throw new AssertionError("unexpected packet type: " + packet); + long newPacketNumber = packetSpaceManager.allocateNextPN(); + debug.log("Retransmitting packet %d as %d (%d attempts)", + packet.packetNumber(), newPacketNumber, attempts); + assert attempts >= 0; + QuicPacket newPacket = codingContext.encoder + .newInitialPacket(initial.sourceId(), initial.destinationId(), + initial.token(), newPacketNumber, + packetSpaceManager.getLargestPeerAckedPN(), + initial.frames(), codingContext); + long number = initial.packetNumber(); + Deadline now = now(); + manager.packetSent(newPacket, number, newPacket.packetNumber()); + retransmissions.add(new Retransmission(initial.packetNumber(), now, + AckFrame.largestAcknowledgedInPacket(newPacket))); + expectedRetransmissions.stream() + .filter(r -> r.isFor(number)) + .forEach(r -> { + assertTrue(r.isDue(now) || + packetSpaceManager.getLargestPeerAckedPN() - 3 > number, + "retransmitted packet %d is not yet due".formatted(number)); + successfulExpectations.add(r); + }); + boolean removed = expectedRetransmissions.removeIf(r -> r.isFor(number)); + if (number <= maxPacketNumber) { + assertTrue(removed, "retransmission of packet %d was not expected" + .formatted(number)); + } + } + + @Override + public long emitAckPacket(PacketSpace packetSpaceManager, + AckFrame ackFrame, boolean sendPing) { + long newPacketNumber = packetSpaceManager.allocateNextPN(); + debug.log("Emitting ack packet %d for %s (sendPing: %s)", + newPacketNumber, ackFrame, sendPing); + List frames; + if (ackFrame != null) { + frames = sendPing + ? List.of(new PingFrame(), ackFrame) + : List.of(ackFrame); + } else { + assert sendPing; + frames = List.of(new PingFrame()); + } + QuicPacket newPacket = codingContext.encoder + .newInitialPacket(localId, peerId, + null, newPacketNumber, + packetSpaceManager.getLargestPeerAckedPN(), + frames, codingContext); + packetSpaceManager.packetSent(newPacket, -1, newPacketNumber); + emittedAckPackets.offer(newPacket); + return newPacket.packetNumber(); + } + + @Override + public void acknowledged(QuicPacket packet) { + // TODO: nothing to do? + } + + @Override + public boolean sendData(PacketNumberSpace packetNumberSpace) { + return false; + } + + @Override + public Executor executor() { + return this::execute; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void checkAbort(PacketNumberSpace packetNumberSpace) { } + + final CopyOnWriteArrayList expectedRetransmissions = new CopyOnWriteArrayList<>(); + final CopyOnWriteArrayList retransmissions = new CopyOnWriteArrayList<>(); + final CopyOnWriteArrayList successfulExpectations = new CopyOnWriteArrayList<>(); + + static record Retransmission(long packetNumber, Deadline atOrAfter, long largestAckSent) { + boolean isFor(long number) { + return number == packetNumber; + } + boolean isDue(Deadline now) { + return !atOrAfter.isAfter(now); + } + } + + /** + * Drives the test by pretending to emit each packet in order, + * then pretending to receive ack frames (as soon as possible + * given the largest packet number emitted). + * The timeline is advanced by chunks as instructed by + * the test. + * This method checks that the retransmission logic works as + * expected. + * @throws Exception + */ + // TODO: in the end we need to check that everything that was + // expected to happen happened. What is missing is to + // check the generation of ACK packets... Also a + // retransmitted packet may need to itself retransmitted + // again and we have no test for that. + public void run() throws Exception { + long timeline = 0; + long serverPacketNumbers = 0; + long maxAck; + Packet packet; + debug.log("Packets: %s", test.packets.stream().mapToLong(Packet::packetNumber) + .mapToObj(String::valueOf) + .collect(Collectors.joining(", ", "[", "]"))); + debug.log("Frames: %s", test.ackframes.stream().mapToLong(AckFrame::largestAcknowledged) + .mapToObj(String::valueOf) + .collect(Collectors.joining(", ", "[", "]"))); + long maxRetransmissionDelay = 250; + long maxAckDelay = manager.getMaxAckDelay(); + + Deadline start = now(); + Deadline nextSendAckDeadline = Deadline.MAX; + long firstAckPaket = -1, lastAckPacket = -1; + long largestAckAcked = -1; + boolean previousAckEliciting = false; + // simulate sending each packet, ordered by their packet number + while ((packet = packetQueue.poll()) != null) { + long offset = packet.packetNumber; + AckFrame nextAck = framesQueue.peek(); + maxAck = nextAck == null ? Long.MAX_VALUE : nextAck.largestAcknowledged(); + long largestReceivedAckedPN = manager.getLargestPeerAckedPN(); + debug.log("timeline: at %dms", timeline); + debug.log("sending packet: %d, largest ACK received: %d", + packet.packetNumber, largestReceivedAckedPN); + + // randomly decide whether we should attempt to include an ack frame + // with the next packet we send out... + boolean sendAck = RANDOM.nextBoolean(); + AckFrame ackFrameToSend = sendAck ? manager.getNextAckFrame(false) : null; + long largestAckSent = -1; + if (ackFrameToSend != null) { + previousAckEliciting = false; + debug.log("including ACK frame: " + ackFrameToSend); + nextSendAckDeadline = Deadline.MAX; + // assertFalse used on purpose here to make + // sure the stack trace can't be confused with one that + // originate in another similar lambda that use assertTrue below. + LongStream.range(firstAckPaket, lastAckPacket + 1).sequential() + .forEach(p -> assertFalse(!ackFrameToSend.isAcknowledging(p), + "frame %s should acknowledge %d" + .formatted(ackFrameToSend, p))); + largestAckSent = ackFrameToSend.largestAcknowledged(); + debug.log("largestAckSent is: " + largestAckAcked); + } + + // add a crypto frame and build the packet + CryptoFrame crypto = new CryptoFrame(offset, 1, + ByteBuffer.wrap(new byte[] {nextByte(offset)})); + List frames = ackFrameToSend == null ? + List.of(crypto) : List.of(crypto, ackFrameToSend); + QuicPacket newPacket = codingContext.encoder + .newInitialPacket(localId, peerId, + null, + packet.packetNumber, + largestReceivedAckedPN, + frames, codingContext); + // pretend that we sent a packet + manager.packetSent(newPacket, -1, packet.packetNumber); + + // compute next deadline + var nextDeadline = timerQueue.nextDeadline(); + var nextScheduledDeadline = manager.nextScheduledDeadline(); + var nextComputedDeadline = manager.computeNextDeadline(); + var now = now(); + debugDeadline("nextDeadline", start, now, nextDeadline); + debugDeadline("nextScheduledDeadline", start, now, nextScheduledDeadline); + debugDeadline("nextComputedDeadline", start, now, nextComputedDeadline); + assertFalse(nextDeadline.isAfter(now.plusMillis(maxRetransmissionDelay * rttEstimator.getPtoBackoff())), + "nextDeadline should not be after %dms from now!" + .formatted(maxRetransmissionDelay)); + expectedRetransmissions.add(new Retransmission(packet.packetNumber, + now.plus(maxRetransmissionDelay, ChronoUnit.MILLIS), largestAckSent)); + + List pending = manager.pendingAcknowledgements((s) ->s.boxed().toList()); + debug.log("pending ack: %s", pending); + assertContains(Assertion.TRUE, pending, packet.packetNumber, + "pending ack"); + + pending = manager.pendingRetransmission((s) ->s.boxed().toList()); + debug.log("pending retransmission: %s", pending); + assertContains(Assertion.FALSE, pending, packet.packetNumber, + "pending retransmission"); + + pending = manager.triggeredForRetransmission((s) ->s.boxed().toList()); + debug.log("triggered for retransmission: %s", pending); + assertContains(Assertion.FALSE, pending, packet.packetNumber, + "triggered for retransmission"); + + if (!nextDeadline.isAfter(now) || !nextSendAckDeadline.isAfter(now)) { + var nextI = timerQueue + .processEventsAndReturnNextDeadline(now, executor); + // this might have triggered sending an ACK packet, unless we + // already sent the ack with the initial packet just above. + debugDeadline("new deadline after events", start, now, nextI); + } + + // check generated ack packets, if any should have been generated... + if (!nextSendAckDeadline.isAfter(now)) { + debug.log("checking emitted ack packets: emitted %d", emittedAckPackets.size()); + var ackPacket = emittedAckPackets.poll(); + assertNotNull(ackPacket); + List ackFrames = ackPacket.frames() + .stream().filter(AckFrame.class::isInstance) + .map(AckFrame.class::cast) + .toList(); + assertEquals(frames.size(), 1, + "unexpected ack frames: " + frames); + AckFrame ackFrame = ackFrames.get(0); + LongStream.range(firstAckPaket, lastAckPacket + 1) + .forEach(p -> assertTrue(ackFrame.isAcknowledging(p), + "frame %s should acknowledge %d" + .formatted(ackFrame, p))); + assertNull(emittedAckPackets.peek(), + "emitted ackPacket queue not empty: " + emittedAckPackets); + debug.log("Got expected ackFrame for emitted ack packet: %s", ackFrame); + previousAckEliciting = false; + nextSendAckDeadline = Deadline.MAX; + } + + // advance the timeline by the instructed delay... + Deadline next = timeSource.advanceMillis(packet.delay); + timeline = timeline + packet.delay; + debug.log("advance deadline by %dms at %dms", packet.delay, timeline); + // note: beyond this point now > now(); packets between now and now() + // may not be retransmitted yet + + // Do not pretend to receive acknowledgement for + // packets that we haven't sent yet. + if (packet.packetNumber >= maxAck) { + // pretend to be receiving the next ack frame... + nextAck = framesQueue.poll(); + debug.log("Receiving acks for " + nextAck); + long spn = serverPacketNumbers++; + boolean isAckEliciting = RANDOM.nextBoolean(); + manager.packetReceived(PacketType.INITIAL, spn, isAckEliciting); + + // calculate if and when we should send out the ack frame for + // the ACK packet we just received + if (firstAckPaket == -1) firstAckPaket = spn; + lastAckPacket = spn; + debug.log("next sent ack should acknowledge [%d..%d]", + firstAckPaket, lastAckPacket); + if (isAckEliciting) { + debug.log("prevEliciting: %s", + previousAckEliciting); + if (previousAckEliciting) { + nextSendAckDeadline = min(now, nextSendAckDeadline); + } else { + nextSendAckDeadline = min(nextSendAckDeadline, next.plusMillis(maxAckDelay)); + } + debugDeadline("next ack deadline", start, next, nextSendAckDeadline); + } + previousAckEliciting |= isAckEliciting; + + // process the ack frame we just received + assertNotNull(nextAck); + manager.processAckFrame(nextAck); + firstAckPaket = Math.max(firstAckPaket, manager.getMinPNThreshold() + 1); + + // Here we can compute which packets will not be acknowledged yet, + // and which packets will be retransmitted. + long largestAckAckedBefore = largestAckAcked; + for (long number : acknowledgePackets(nextAck)) { + allAcks.addAck(number); + pending = manager.pendingAcknowledgements((s) -> s.boxed().toList()); + assertContains(Assertion.FALSE, pending, number, + "pending ack"); + + pending = manager.pendingRetransmission((s) -> s.boxed().toList()); + assertContains(Assertion.FALSE, pending, number, + "pending retransmission"); + + pending = manager.triggeredForRetransmission((s) -> s.boxed().toList()); + assertContains(Assertion.FALSE, pending, number, + "triggered for retransmission"); + // TODO check if we only retransmitted the expected packets + // ...need to replicate the logic + + /*expectedRetransmissions.stream() + .filter(r -> r.isFor(number)) + .forEach(r -> assertFalse(r.isDue(now), + "due packet %d was not retransmitted".formatted(number))); + for (Retransmission r : expectedRetransmissions) { + if (r.isFor(number)) { + largestAckAcked = Math.max(largestAckAcked, r.largestAckSent); + } + }*/ + for (Retransmission r : successfulExpectations) { + if (r.isFor(number)) { + largestAckAcked = Math.max(largestAckAcked, r.largestAckSent); + } + } + if (largestAckAcked != largestAckAckedBefore) { + debug.log("largestAckAcked is now %d", largestAckAcked); + largestAckAckedBefore = largestAckAcked; + } + if (largestAckAcked > -1) { + boolean changed = false; + if (firstAckPaket <= largestAckAcked) { + changed = true; + if (lastAckPacket > largestAckAcked) { + firstAckPaket = largestAckAcked + 1; + } else firstAckPaket = -1; + } + if (lastAckPacket <= largestAckAcked) { + changed = true; + lastAckPacket = -1; + } + if (changed) { + debug.log("next sent ack should now be [%d..%d]", + firstAckPaket, lastAckPacket); + } + } + boolean removed = expectedRetransmissions.removeIf(r -> r.isFor(number)); + successfulExpectations.stream() + .filter(r -> r.isFor(number)) + .forEach( r -> assertTrue(r.isDue(now()) || + manager.getLargestPeerAckedPN() - 3 > r.packetNumber, + "packet %d was retransmitted too early (deadline was %d)" + .formatted(number, start + .until(r.atOrAfter, ChronoUnit.MILLIS)))); + retransmissions.stream().filter(r -> r.isFor(number)).forEach( r -> { + assertTrue(r.isDue(now()), + "packet %d was retransmitted too early at %d" + .formatted(number, start + .until(r.atOrAfter, ChronoUnit.MILLIS))); + assertFalse(removed, "packet %d was in both lists" + .formatted(number)); + }); + } + } + } + } + + Deadline min(Deadline one, Deadline two) { + return one.isAfter(two) ? two : one; + } + + /** + * Should be called after {@link #run()}. + */ + void check() { + assertFalse(now().isBefore(timeSource.first.plusMillis(timeline))); + assertTrue(expectedRetransmissions.isEmpty()); + assertEquals(retransmissions.stream() + .map(Retransmission::packetNumber) + .filter(pn -> pn <= maxPacketNumber).toList(), + successfulExpectations.stream().map(Retransmission::packetNumber) + .filter(pn -> pn <= maxPacketNumber).toList()); + for (Retransmission r : retransmissions) { + if (r.packetNumber > maxPacketNumber || + manager.getLargestPeerAckedPN() - 3 > r.packetNumber) continue; + List succesful = successfulExpectations.stream() + .filter(s -> s.isFor(r.packetNumber)) + .toList(); + assertEquals(succesful.size(), 1); + succesful.forEach(s -> assertFalse(s.atOrAfter.isAfter(r.atOrAfter))); + } + + List acknowledged = new ArrayList<>(acknowledgePackets(allAcks.build())); + Collections.sort(acknowledged); + assertEquals(acknowledged, test.packets.stream() + .map(Packet::packetNumber).sorted().toList()); + } + + // TODO: add a LongStream acknowledged() to AckFrame - write a spliterator + // for that. + List acknowledgePackets(AckFrame frame) { + List list = new ArrayList<>(); + long largest = frame.largestAcknowledged(); + long smallest = largest + 2; + for (AckRange range : frame.ackRanges()) { + largest = smallest - range.gap() -2; + smallest = largest - range.range(); + for (long i = largest; i >= smallest; i--) { + assert frame.isAcknowledging(i) + : "%s is not acknowledging %d".formatted(frame, i); + list.add(i); + } + } + return list; + } + + interface Assertion { + void check(boolean result, String message); + default String negation() { + return (this == FALSE) ? "doesn't " : ""; + } + Assertion TRUE = Assert::assertTrue; + Assertion FALSE = Assert::assertFalse; + } + static void assertContains(Assertion assertion, List list, long number, String desc) { + assertion.check(list.contains(number), + "%s: %s %scontains %d".formatted(desc, list, assertion.negation(), number)); + } + + void debugDeadline(String desc, Deadline start, Deadline now, Deadline deadline) { + long nextMs = deadline.equals(Deadline.MAX) ? 0 : + now.until(deadline, ChronoUnit.MILLIS); + long at = deadline.equals(Deadline.MAX) ? 0 : + start.until(deadline, ChronoUnit.MILLIS); + String when = deadline.equals(Deadline.MAX) ? "never" + : nextMs >= 0 ? ("at %d in %dms".formatted(at, nextMs)) + : ("at %d due by %dms".formatted(at, (-nextMs))); + debug.log("%s: %s", desc, when); + } + + byte nextByte(long offset) { + long start = 'a'; + int len = 'z' - 'a' + 1; + long res = start + offset % len; + assert res >= 'a' && res <= 'z'; + return (byte) res; + } + + private void execute(Runnable runnable) { + runnable.run(); + } + + Deadline now() { + return timeSource.instant(); + } + + static QuicConnectionId newId() { + byte[] idbites = new byte[CIDLEN]; + RANDOM.nextBytes(idbites); + return new PeerConnectionId(idbites); + } + } + + @Test(dataProvider = "tests") + public void testPacketSpaceManager(TestCase testCase) throws Exception { + System.out.printf("%n ------- testPacketSpaceManager ------- %n"); + SynchronousTestDriver driver = new SynchronousTestDriver(testCase); + driver.run(); + driver.check(); + } + + } diff --git a/test/jdk/java/net/httpclient/quic/QuicFramesDecoderTest.java b/test/jdk/java/net/httpclient/quic/QuicFramesDecoderTest.java new file mode 100644 index 00000000000..f0d4b316e46 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/QuicFramesDecoderTest.java @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.quic.frames.QuicFrame; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.nio.ByteBuffer; +import java.util.HexFormat; + +import static org.testng.Assert.*; + + +/* + * @test + * @library /test/lib + * @summary Tests to check QUIC frame decoding errors are handled correctly + * @run testng/othervm QuicFramesDecoderTest + */ +public class QuicFramesDecoderTest { + + + // correct frames. Single byte frames (padding, ping, handshake_done) omitted. + // ACK without ECN, acked = 0, 2; delay = 2 + private static final byte[] ACK_BASE = HexFormat.of().parseHex("02020201000000"); + private static final byte[] ACK_BASE_B = HexFormat.of().parseHex("0202020100004000"); + + // ACK with ECN, acked = 0, 2; delay = 2, ECN=(3,4,5) + private static final byte[] ACK_ECN = HexFormat.of().parseHex("03020201000000030405"); + private static final byte[] ACK_ECN_B = HexFormat.of().parseHex("0302020100000003044005"); + + // RESET_STREAM, stream 3, error 2, final size 1 + private static final byte[] RESET_STREAM = HexFormat.of().parseHex("04030201"); + private static final byte[] RESET_STREAM_B = HexFormat.of().parseHex("0403024001"); + + // STOP_SENDING, stream 4, error 3 + private static final byte[] STOP_SENDING = HexFormat.of().parseHex("050403"); + private static final byte[] STOP_SENDING_B = HexFormat.of().parseHex("05044003"); + + // CRYPTO, offset 5, length 4, data + private static final byte[] CRYPTO = HexFormat.of().parseHex("06050403020100"); + + // NEW_TOKEN, length 6, data + private static final byte[] NEW_TOKEN = HexFormat.of().parseHex("0706050403020100"); + + // STREAM-o-l-f, stream 7, no data + private static final byte[] STREAM = HexFormat.of().parseHex("0807"); + + // STREAM-o-l+f, stream 8, no data + private static final byte[] STREAM_F = HexFormat.of().parseHex("0908"); + + // STREAM-o+l-f, stream 9, length 8 + private static final byte[] STREAM_L = HexFormat.of().parseHex("0a09080706050403020100"); + + // STREAM-o+l+f, stream 10, length 9 + private static final byte[] STREAM_LF = HexFormat.of().parseHex("0b0a09080706050403020100"); + + // STREAM+o-l-f, stream 11, offset 0, no data + private static final byte[] STREAM_O = HexFormat.of().parseHex("0c000a"); + + // STREAM+o-l+f, stream 12, offset 0, no data + private static final byte[] STREAM_OF = HexFormat.of().parseHex("0d000b"); + + // STREAM+o+l-f, stream 13, offset 0, length 3 + private static final byte[] STREAM_OL = HexFormat.of().parseHex("0e000c03020100"); + + // STREAM+o+l+f, stream 14, offset 0, length 3 + private static final byte[] STREAM_OLF = HexFormat.of().parseHex("0f000d03020100"); + + // MAX_DATA, max=15 + private static final byte[] MAX_DATA = HexFormat.of().parseHex("100f"); + private static final byte[] MAX_DATA_B = HexFormat.of().parseHex("10400f"); + + // MAX_STREAM_DATA, stream = 16 max=15 + private static final byte[] MAX_STREAM_DATA = HexFormat.of().parseHex("11100f"); + private static final byte[] MAX_STREAM_DATA_B = HexFormat.of().parseHex("1110400f"); + + // MAX_STREAMS, bidi, streams = 2^60 + private static final byte[] MAX_STREAMS_B = HexFormat.of().parseHex("12d000000000000000"); + + // MAX_STREAMS, uni, streams = 2^60 + private static final byte[] MAX_STREAMS_U = HexFormat.of().parseHex("13d000000000000000"); + + // DATA_BLOCKED, max=19 + private static final byte[] DATA_BLOCKED = HexFormat.of().parseHex("1413"); + private static final byte[] DATA_BLOCKED_B = HexFormat.of().parseHex("144013"); + + // STREAM_DATA_BLOCKED, stream = 20 max=19 + private static final byte[] STREAM_DATA_BLOCKED = HexFormat.of().parseHex("151413"); + private static final byte[] STREAM_DATA_BLOCKED_B = HexFormat.of().parseHex("15144013"); + + // STREAMS_BLOCKED, bidi, streams = 2^60 + private static final byte[] STREAMS_BLOCKED_B = HexFormat.of().parseHex("16d000000000000000"); + + // STREAMS_BLOCKED, uni, streams = 2^60 + private static final byte[] STREAMS_BLOCKED_U = HexFormat.of().parseHex("17d000000000000000"); + + // NEW_CONNECTION_ID, seq=23, retire=22, len = 5 + private static final byte[] NEW_CONNECTION_ID = HexFormat.of().parseHex("181716"+ + "051413121110"+"0f0e0d0c0b0a09080706050403020100"); + + // RETIRE_CONNECTION_ID, seq=24 + private static final byte[] RETIRE_CONNECTION_ID = HexFormat.of().parseHex("1918"); + private static final byte[] RETIRE_CONNECTION_ID_B = HexFormat.of().parseHex("194018"); + + // PATH_CHALLENGE + private static final byte[] PATH_CHALLENGE = HexFormat.of().parseHex("1a0706050403020100"); + + // PATH_RESPONSE + private static final byte[] PATH_RESPONSE = HexFormat.of().parseHex("1b0706050403020100"); + + // CONNECTION_CLOSE, quic, error 27, frame type 26, reason='\0' + private static final byte[] CONNECTION_CLOSE_Q = HexFormat.of().parseHex("1c1b1a0100"); + // CONNECTION_CLOSE, quic, error 27, frame type 26, reason= + // efbfbf (U+FFFF) - "not a valid unicode character + // edb080 (U+DC00) - low surrogate, prohibited in UTF8 (RFC3629), must be preceded by high surrogate otherwise + // eda080 (U+D800) - high surrogate, prohibited in UTF8 (RFC3629), must be followed by low surrogate otherwise + // 80 - not a first byte of UTF8 sequence + // c0d0e0f0ff - not a valid UTF8 sequence + private static final byte[] CONNECTION_CLOSE_Q_BAD_REASON = HexFormat.of().parseHex("1c1b1a0fefbfbfedb080eda08080c0d0e0f0ff"); + + // CONNECTION_CLOSE, app, error 28, reason='\0' + private static final byte[] CONNECTION_CLOSE_A = HexFormat.of().parseHex("1d1c0100"); + // CONNECTION_CLOSE, app, error 28, reason= same as CONNECTION_CLOSE_Q_BAD_REASON + private static final byte[] CONNECTION_CLOSE_A_BAD_REASON = HexFormat.of().parseHex("1d1c0fefbfbfedb080eda08080c0d0e0f0ff"); + // end of correct frames + + // malformed frames other than truncated + // ACK acknowledging negative packet + // ACK without ECN, acked = -1, 1; delay = 2 + private static final byte[] ACK_NEG_BASE = HexFormat.of().parseHex("02010201000000"); + + // ACK with ECN, acked = -1, 1; delay = 2, ECN=(3,4,5) + private static final byte[] ACK_NEG_ECN = HexFormat.of().parseHex("03010201000000030405"); + + // ACK without ECN, acked = -1, 0; delay = 2 + private static final byte[] ACK_NEG_BASE_2 = HexFormat.of().parseHex("0200020001"); + + // CRYPTO out of range: offset MAX_VL_INT, len=1 + private static final byte[] CRYPTO_OOR = HexFormat.of().parseHex("06ffffffffffffffff0100"); + + // NEW_TOKEN empty + private static final byte[] NEW_TOKEN_EMPTY = HexFormat.of().parseHex("0700"); + + // MAX_STREAMS out of range + // MAX_STREAMS, bidi, streams = 2^60+1 + private static final byte[] MAX_STREAMS_B_OOR = HexFormat.of().parseHex("12d000000000000001"); + // MAX_STREAMS, uni, streams = 2^60+1 + private static final byte[] MAX_STREAMS_U_OOR = HexFormat.of().parseHex("13d000000000000001"); + + // STREAMS_BLOCKED out of range + // STREAMS_BLOCKED, bidi, streams = 2^60+1 + private static final byte[] STREAMS_BLOCKED_B_OOR = HexFormat.of().parseHex("16d000000000000001"); + + // STREAMS_BLOCKED, uni, streams = 2^60+1 + private static final byte[] STREAMS_BLOCKED_U_OOR = HexFormat.of().parseHex("17d000000000000001"); + + // NEW_CONNECTION_ID, seq=23, retire=22, len = 0 + private static final byte[] NEW_CONNECTION_ID_ZERO = HexFormat.of().parseHex("181716"+ + "00"+"0f0e0d0c0b0a09080706050403020100"); + + @DataProvider + public static Object[][] goodFrames() { + return new Object[][]{ + new Object[]{"ack without ecn", ACK_BASE, false}, + new Object[]{"ack without ecn", ACK_BASE_B, true}, + new Object[]{"ack with ecn", ACK_ECN, false}, + new Object[]{"ack with ecn", ACK_ECN_B, true}, + new Object[]{"RESET_STREAM", RESET_STREAM, false}, + new Object[]{"RESET_STREAM", RESET_STREAM_B, true}, + new Object[]{"STOP_SENDING", STOP_SENDING, false}, + new Object[]{"STOP_SENDING", STOP_SENDING_B, true}, + new Object[]{"CRYPTO", CRYPTO, false}, + new Object[]{"NEW_TOKEN", NEW_TOKEN, false}, + new Object[]{"STREAM-o-l-f", STREAM, false}, + new Object[]{"STREAM-o-l+f", STREAM_F, false}, + new Object[]{"STREAM-o+l-f", STREAM_L, false}, + new Object[]{"STREAM-o+l+f", STREAM_LF, false}, + new Object[]{"STREAM+o-l-f", STREAM_O, false}, + new Object[]{"STREAM+o-l+f", STREAM_OF, false}, + new Object[]{"STREAM+o+l-f", STREAM_OL, false}, + new Object[]{"STREAM+o+l+f", STREAM_OLF, false}, + new Object[]{"MAX_DATA", MAX_DATA, false}, + new Object[]{"MAX_DATA", MAX_DATA_B, true}, + new Object[]{"MAX_STREAM_DATA", MAX_STREAM_DATA, false}, + new Object[]{"MAX_STREAM_DATA", MAX_STREAM_DATA_B, true}, + new Object[]{"MAX_STREAMS bidi", MAX_STREAMS_B, false}, + new Object[]{"MAX_STREAMS uni", MAX_STREAMS_U, false}, + new Object[]{"DATA_BLOCKED", DATA_BLOCKED, false}, + new Object[]{"DATA_BLOCKED", DATA_BLOCKED_B, true}, + new Object[]{"STREAM_DATA_BLOCKED", STREAM_DATA_BLOCKED, false}, + new Object[]{"STREAM_DATA_BLOCKED", STREAM_DATA_BLOCKED_B, true}, + new Object[]{"STREAMS_BLOCKED bidi", STREAMS_BLOCKED_B, false}, + new Object[]{"STREAMS_BLOCKED uni", STREAMS_BLOCKED_U, false}, + new Object[]{"NEW_CONNECTION_ID", NEW_CONNECTION_ID, false}, + new Object[]{"RETIRE_CONNECTION_ID", RETIRE_CONNECTION_ID, false}, + new Object[]{"RETIRE_CONNECTION_ID", RETIRE_CONNECTION_ID_B, true}, + new Object[]{"PATH_CHALLENGE", PATH_CHALLENGE, false}, + new Object[]{"PATH_RESPONSE", PATH_RESPONSE, false}, + new Object[]{"CONNECTION_CLOSE QUIC", CONNECTION_CLOSE_Q, false}, + new Object[]{"CONNECTION_CLOSE QUIC non-utf8 reason", CONNECTION_CLOSE_Q_BAD_REASON, false}, + new Object[]{"CONNECTION_CLOSE app", CONNECTION_CLOSE_A, false}, + new Object[]{"CONNECTION_CLOSE app non-utf8 reason", CONNECTION_CLOSE_A_BAD_REASON, false}, + }; + } + + @DataProvider + public static Object[][] badFrames() { + return new Object[][]{ + new Object[]{"ack without ecn, negative pn", ACK_NEG_BASE}, + new Object[]{"ack without ecn, negative pn, v2", ACK_NEG_BASE_2}, + new Object[]{"ack with ecn, negative pn", ACK_NEG_ECN}, + new Object[]{"CRYPTO out of range", CRYPTO_OOR}, + new Object[]{"NEW_TOKEN empty", NEW_TOKEN_EMPTY}, + new Object[]{"MAX_STREAMS bidi out of range", MAX_STREAMS_B_OOR}, + new Object[]{"MAX_STREAMS uni out of range", MAX_STREAMS_U_OOR}, + new Object[]{"STREAMS_BLOCKED bidi out of range", STREAMS_BLOCKED_B_OOR}, + new Object[]{"STREAMS_BLOCKED uni out of range", STREAMS_BLOCKED_U_OOR}, + new Object[]{"NEW_CONNECTION_ID zero length", NEW_CONNECTION_ID_ZERO}, + }; + } + + @Test(dataProvider = "goodFrames") + public void testReencode(String desc, byte[] frame, boolean bloated) throws Exception { + // check if the goodFrames provider indeed contains good frames + ByteBuffer buf = ByteBuffer.wrap(frame); + var qf = QuicFrame.decode(buf); + assertFalse(buf.hasRemaining(), buf.remaining() + " bytes left in buffer after parsing"); + // some frames deliberately use suboptimal encoding, skip them + if (bloated) return; + assertEquals(qf.size(), frame.length, "Frame size mismatch"); + buf.clear(); + ByteBuffer encoded = ByteBuffer.allocate(frame.length); + qf.encode(encoded); + assertFalse(encoded.hasRemaining(), "Actual frame length mismatch"); + encoded.flip(); + assertEquals(buf, encoded, "Encoded buffer is different from the original one"); + } + + @Test(dataProvider = "goodFrames") + public void testToString(String desc, byte[] frame, boolean bloated) throws Exception { + // check if the goodFrames provider indeed contains good frames + ByteBuffer buf = ByteBuffer.wrap(frame); + var qf = QuicFrame.decode(buf); + assertFalse(buf.hasRemaining(), buf.remaining() + " bytes left in buffer after parsing"); + System.out.println(qf); // should not throw + } + + @Test(dataProvider = "goodFrames") + public void testTruncatedFrame(String desc, byte[] frame, boolean bloated) throws Exception { + // check if parsing a truncated frame throws the right error + ByteBuffer buf = ByteBuffer.wrap(frame); + for (int i = 1; i < buf.capacity(); i++) { + buf.position(0); + buf.limit(i); + try { + var qf = QuicFrame.decode(buf); + fail("Expected the decoder to throw on length " + i + ", got: " + qf); + } catch (QuicTransportException e) { + assertEquals(e.getErrorCode(), QuicTransportErrors.FRAME_ENCODING_ERROR.code()); + } + } + } + + @Test(dataProvider = "badFrames") + public void testBadFrame(String desc, byte[] frame) throws Exception { + // check if parsing a bad frame throws the right error + ByteBuffer buf = ByteBuffer.wrap(frame); + try { + var qf = QuicFrame.decode(buf); + fail("Expected the decoder to throw, got: "+qf); + } catch (QuicTransportException e) { + assertEquals(e.getErrorCode(), QuicTransportErrors.FRAME_ENCODING_ERROR.code()); + } + } +} diff --git a/test/jdk/java/net/httpclient/quic/QuicRequestResponseTest.java b/test/jdk/java/net/httpclient/quic/QuicRequestResponseTest.java new file mode 100644 index 00000000000..2c5844dc4c4 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/QuicRequestResponseTest.java @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.TestUtil; +import jdk.httpclient.test.lib.quic.ClientConnection; +import jdk.httpclient.test.lib.quic.ConnectedBidiStream; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicServerHandler; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +/* + * @test + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.quic.QuicStandaloneServer + * jdk.httpclient.test.lib.quic.ClientConnection + * jdk.httpclient.test.lib.common.TestUtil + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.internal.httpclient.debug=true QuicRequestResponseTest + */ +public class QuicRequestResponseTest { + + private QuicStandaloneServer server; + private SSLContext sslContext; + private ExecutorService executor; + + private static final byte[] HELLO_MSG = "Hello Quic".getBytes(StandardCharsets.UTF_8); + + @BeforeClass + public void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + executor = Executors.newCachedThreadPool(); + server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext) + .build(); + // add a handler which deals with incoming connections + server.addHandler(new EchoHandler(HELLO_MSG.length)); + server.start(); + System.out.println("Server started at " + server.getAddress()); + } + + @AfterClass + public void afterClass() throws Exception { + if (server != null) { + System.out.println("Stopping server " + server.getAddress()); + server.close(); + } + if (executor != null) executor.close(); + } + + private QuicClient createClient() { + var versions = List.of(QuicVersion.QUIC_V1); + var context = new QuicTLSContext(sslContext); + var params = new SSLParameters(); + return new QuicClient.Builder() + .availableVersions(versions) + .tlsContext(context) + .sslParameters(params) + .executor(executor) + .bindAddress(TestUtil.chooseClientBindAddress().orElse(null)) + .build(); + } + + @Test + public void test() throws Exception { + try (final QuicClient client = createClient()) { + // create a QUIC connection to the server + final ClientConnection conn = ClientConnection.establishConnection(client, server.getAddress()); + // open a bidi stream + final ConnectedBidiStream bidiStream = conn.initiateNewBidiStream(); + // write data on the stream + try (final OutputStream os = bidiStream.outputStream()) { + os.write(HELLO_MSG); + System.out.println("client: Client wrote message to bidi stream's output stream"); + } + // wait for response + try (final InputStream is = bidiStream.inputStream()) { + System.out.println("client: reading from bidi stream's input stream"); + final byte[] data = is.readAllBytes(); + System.out.println("client: Received response of size " + data.length); + final String response = new String(data, StandardCharsets.UTF_8); + // verify response + System.out.println("client: Response: " + response); + if (!Arrays.equals(response.getBytes(StandardCharsets.UTF_8), HELLO_MSG)) { + throw new AssertionError("Unexpected response: " + response); + } + } finally { + System.err.println("client: Closing bidi stream from test"); + bidiStream.close(); + } + } + } + + /** + * Reads data from incoming client initiated bidirectional stream of a Quic connection + * and writes back a response which is same as the read data + */ + private static final class EchoHandler implements QuicServerHandler { + + private final int numBytesToRead; + + private EchoHandler(final int numBytesToRead) { + this.numBytesToRead = numBytesToRead; + } + + @Override + public void handleBidiStream(final QuicServerConnection conn, + final ConnectedBidiStream bidiStream) throws IOException { + System.out.println("Handling incoming bidi stream " + bidiStream + + " on connection " + conn); + final byte[] data; + // read the request content + try (final InputStream is = bidiStream.inputStream()) { + System.out.println("Handler reading data from bidi stream's inputstream " + is); + data = is.readAllBytes(); + System.out.println("Handler read " + data.length + " bytes of data"); + } + if (data.length != numBytesToRead) { + throw new IOException("Expected to read " + numBytesToRead + + " bytes but read only " + data.length + " bytes"); + } + // write response + try (final OutputStream os = bidiStream.outputStream()) { + System.out.println("Handler writing data to bidi stream's outputstream " + os); + os.write(data); + } + System.out.println("Handler invocation complete"); + } + } +} diff --git a/test/jdk/java/net/httpclient/quic/StatelessResetReceiptTest.java b/test/jdk/java/net/httpclient/quic/StatelessResetReceiptTest.java new file mode 100644 index 00000000000..8eb7a3663d8 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/StatelessResetReceiptTest.java @@ -0,0 +1,303 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.DatagramChannel; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.TestUtil; +import jdk.httpclient.test.lib.quic.ClientConnection; +import jdk.httpclient.test.lib.quic.ConnectedBidiStream; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicServerHandler; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.common.MinimalFuture; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.http.quic.QuicConnectionId; +import jdk.internal.net.http.quic.QuicConnectionImpl; +import jdk.internal.net.http.quic.TerminationCause; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import static jdk.internal.net.http.quic.TerminationCause.forTransportError; +import static jdk.internal.net.quic.QuicTransportErrors.NO_VIABLE_PATH; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/* + * @test + * @summary verify that when a QUIC (client) connection receives a stateless reset + * from the peer, then the connection is properly terminated + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.quic.QuicStandaloneServer + * jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.common.TestUtil + * @run junit/othervm -Djdk.internal.httpclient.debug=true StatelessResetReceiptTest + */ +public class StatelessResetReceiptTest { + + private static QuicStandaloneServer server; + private static SSLContext sslContext; + private static ExecutorService executor; + + @BeforeAll + static void beforeAll() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + executor = Executors.newCachedThreadPool(); + server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{QuicVersion.QUIC_V1}) + .sslContext(sslContext) + .build(); + server.start(); + System.out.println("Server started at " + server.getAddress()); + } + + @AfterAll + static void afterAll() throws Exception { + if (server != null) { + System.out.println("Stopping server " + server.getAddress()); + server.close(); + } + if (executor != null) { + executor.close(); + } + } + + private QuicClient createClient() { + var versions = List.of(QuicVersion.QUIC_V1); + var context = new QuicTLSContext(sslContext); + var params = new SSLParameters(); + return new QuicClient.Builder() + .availableVersions(versions) + .tlsContext(context) + .sslParameters(params) + .executor(executor) + .bindAddress(TestUtil.chooseClientBindAddress().orElse(null)) + .build(); + } + + /** + * Initiates a connection between client and server. When the connection is still + * active, this test initiates a stateless reset from the server connection against + * the client connection. The test then expects that the client connection, which was active + * until then, is terminated due to this stateless reset. + */ + @Test + public void testActiveConnection() throws Exception { + final CompletableFuture serverConnCF = new MinimalFuture<>(); + final NotifyingHandler handler = new NotifyingHandler(serverConnCF); + server.addHandler(handler); + try (final QuicClient client = createClient()) { + // create a QUIC connection to the server + final ClientConnection conn = ClientConnection.establishConnection(client, + server.getAddress()); + // write data on the stream + try (final ConnectedBidiStream bidiStream = conn.initiateNewBidiStream(); + final OutputStream os = bidiStream.outputStream()) { + os.write("foobar".getBytes(StandardCharsets.UTF_8)); + System.out.println("client: Client wrote message to bidi stream's output stream"); + } + // wait for the handler on the server's connection to be invoked + System.out.println("waiting for the request to be handled by the server connection"); + final QuicServerConnection serverConn = serverConnCF.get(); + System.out.println("request handled by the server connection " + serverConn); + // verify the connection is still open + assertTrue(conn.underlyingQuicConnection().isOpen(), "QUIC connection is not open"); + sendStatelessResetFrom(serverConn); + // now expect the (active) client connection to be terminated + assertStatelessResetTermination(conn); + } + } + + /** + * Initiates a connection between client and server. The client connection is then + * closed and it thus moves to the closing state. The test then initiates a stateless reset + * from the server connection against this closing client connection. The test then verifies + * that the client connection, which was in closing state, has been completely removed from the + * endpoint upon receiving this stateless reset. + */ + @Test + public void testClosingConnection() throws Exception { + final CompletableFuture serverConnCF = new MinimalFuture<>(); + final NotifyingHandler handler = new NotifyingHandler(serverConnCF); + server.addHandler(handler); + try (final QuicClient client = createClient()) { + // create a QUIC connection to the server + final ClientConnection conn = ClientConnection.establishConnection(client, + server.getAddress()); + // write data on the stream + try (final ConnectedBidiStream bidiStream = conn.initiateNewBidiStream(); + final OutputStream os = bidiStream.outputStream()) { + os.write("foobar".getBytes(StandardCharsets.UTF_8)); + System.out.println("client: Client wrote message to bidi stream's output stream"); + } + // wait for the handler on the server's connection to be invoked + System.out.println("waiting for the request to be handled by the server connection"); + final QuicServerConnection serverConn = serverConnCF.get(); + System.out.println("request handled by the server connection " + serverConn); + // now close the client/local connection so that it transitions to closing state + System.out.println("closing client connection " + conn); + conn.close(); + // verify connection is no longer open + assertFalse(conn.underlyingQuicConnection().isOpen(), "QUIC connection is still open"); + // now send a stateless reset from the server connection + sendStatelessResetFrom(serverConn); + // wait for the stateless reset to be processed + final Instant waitEnd = Instant.now().plus(Duration.ofSeconds(2)); + while (Instant.now().isBefore(waitEnd)) { + if (conn.endpoint().connectionCount() != 0) { + // wait for a while + Thread.sleep(10); + } + } + // now expect the endpoint to have removed the client connection. + // this isn't a fool proof verification because the connection could have been + // moved from the closing state to draining and then removed, without having processed + // the stateless reset, but we don't have any other credible way of verifying this + assertEquals(0, conn.endpoint().connectionCount(), "unexpected number of connections" + + " known to QUIC endpoint"); + } + } + + /** + * Initiates a connection between client and server. The server connection is then + * closed and the client connection thus moves to the draining state. The test then initiates + * a stateless reset from the server connection against this draining client connection. The + * test then verifies that the client connection, which was in draining state, has been + * completely removed from the endpoint upon receiving this stateless reset. + */ + @Test + public void testDrainingConnection() throws Exception { + final CompletableFuture serverConnCF = new MinimalFuture<>(); + final NotifyingHandler handler = new NotifyingHandler(serverConnCF); + server.addHandler(handler); + try (final QuicClient client = createClient()) { + // create a QUIC connection to the server + final ClientConnection conn = ClientConnection.establishConnection(client, + server.getAddress()); + // write data on the stream + try (final ConnectedBidiStream bidiStream = conn.initiateNewBidiStream(); + final OutputStream os = bidiStream.outputStream()) { + os.write("foobar".getBytes(StandardCharsets.UTF_8)); + System.out.println("client: Client wrote message to bidi stream's output stream"); + } + // wait for the handler on the server's connection to be invoked + System.out.println("waiting for the request to be handled by the server connection"); + final QuicServerConnection serverConn = serverConnCF.get(); + System.out.println("request handled by the server connection " + serverConn); + // now close the server connection so that the client conn transitions to draining state + System.out.println("closing server connection " + serverConn); + // intentionally use a "unique" error to confidently verify the termination cause + final TerminationCause tc = forTransportError(NO_VIABLE_PATH) + .loggedAs("intentionally closed by server to initiate draining state" + + " on client connection"); + serverConn.connectionTerminator().terminate(tc); + // wait for client conn to terminate + final TerminationCause clientTC = ((QuicConnectionImpl) conn.underlyingQuicConnection()) + .futureTerminationCause().get(); + // verify connection closed for the right reason + assertEquals(NO_VIABLE_PATH.code(), clientTC.getCloseCode(), + "unexpected termination cause"); + // now send a stateless reset from the server connection + sendStatelessResetFrom(serverConn); + // wait for the stateless reset to be processed + final Instant waitEnd = Instant.now().plus(Duration.ofSeconds(2)); + while (Instant.now().isBefore(waitEnd)) { + if (conn.endpoint().connectionCount() != 0) { + // wait for a while + Thread.sleep(10); + } + } + // now expect the endpoint to have removed the client connection. + // this isn't a fool proof verification because the connection could have been + // removed after moving out from the draining state, without having processed the + // stateless reset, but we don't have any other credible way of verifying this + assertEquals(0, conn.endpoint().connectionCount(), "unexpected number of connections" + + " known to QUIC endpoint"); + } + } + + private static void sendStatelessResetFrom(final QuicServerConnection serverConn) + throws IOException { + final QuicConnectionId localConnId = serverConn.localConnectionId(); + final ByteBuffer resetDatagram = serverConn.endpoint().idFactory().statelessReset( + localConnId.asReadOnlyBuffer(), 43); + final InetSocketAddress targetAddr = serverConn.peerAddress(); + ((DatagramChannel) serverConn.channel()).send(resetDatagram, targetAddr); + System.out.println("sent stateless reset from server conn " + serverConn + " to " + + targetAddr); + } + + private static void assertStatelessResetTermination(final ClientConnection conn) + throws Exception { + final CompletableFuture cf = + ((QuicConnectionImpl) conn.underlyingQuicConnection()).futureTerminationCause(); + final TerminationCause tc = cf.get(); + System.out.println("got termination cause " + tc.getCloseCause() + " - " + tc.getLogMsg()); + final IOException closeCause = tc.getCloseCause(); + assertNotNull(closeCause, "close cause IOException is null"); + final String expectedMsg = "stateless reset from peer"; + if (closeCause.getMessage() != null && closeCause.getMessage().contains(expectedMsg)) { + // got expected IOException + return; + } + // unexpected IOException. throw it back + throw closeCause; + } + + private static final class NotifyingHandler implements QuicServerHandler { + + private final CompletableFuture serverConnCF; + + private NotifyingHandler(final CompletableFuture serverConnCF) { + this.serverConnCF = serverConnCF; + } + + @Override + public void handleBidiStream(final QuicServerConnection conn, + final ConnectedBidiStream bidiStream) { + System.out.println("Handling incoming bidi stream " + bidiStream + + " on connection " + conn); + this.serverConnCF.complete(conn); + } + } +} diff --git a/test/jdk/java/net/httpclient/quic/VariableLengthTest.java b/test/jdk/java/net/httpclient/quic/VariableLengthTest.java new file mode 100644 index 00000000000..aec3e999812 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/VariableLengthTest.java @@ -0,0 +1,348 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +import jdk.internal.net.http.quic.VariableLengthEncoder; +import jtreg.SkippedException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +/* + * @test + * @library /test/lib + * @modules java.net.http/jdk.internal.net.http.quic + * @run testng/othervm VariableLengthTest + * @summary Tests to check quic/util methods encode/decodeVariableLength methods + * work as expected. + */ +public class VariableLengthTest { + static final Class IAE = IllegalArgumentException.class; + + @DataProvider(name = "decode invariants") + public Object[][] decodeInvariants() { + return new Object[][] + { + { new byte[]{7}, 7, 1 }, // 00 + { new byte[]{65, 11}, 267, 2 }, // 01 + { new byte[]{-65, 11, 22, 33}, 1057691169, 4 }, // 10 + { new byte[]{-1, 11, 22, 33, 44, 55, 66, 77}, 4542748980864827981L, 8 }, // 11 + { new byte[]{-1, -11, -22, -33, -44, -55, -66, -77}, 4608848040752168627L, 8 }, + { new byte[]{}, -1, 0 }, + { new byte[]{-65}, -1, 0 }, + }; + } + @DataProvider(name = "encode invariants") + public Object[][] encodeInvariants() { + return new Object[][] + { + { 7, 1, null }, // 00 + { 267, 2, null }, // 01 + { 1057691169, 4, null }, // 10 + { 4542748980864827981L, 8, null }, // 11 + { Long.MAX_VALUE, 0, IAE }, + { -1, 0, IAE }, + }; + } + @DataProvider(name = "prefix invariants") + public Object[][] prefixInvariants() { + return new Object[][] + { + { Long.MAX_VALUE, 0, IAE }, + { 4611686018427387903L+1, 0, IAE }, + { 4611686018427387903L, 3, null }, + { 4611686018427387903L-1, 3, null }, + { 1073741823+1, 3, null }, + { 1073741823, 2, null }, // (length > (1L << 30)-1) + { 1073741823-1, 2, null }, + { 16383+1, 2, null }, + { 16383, 1, null }, // (length > (1L << 14)-1 + { 16383-1, 1, null }, + { 63+1, 1, null }, + { 63 , 0, null }, // (length > (1L << 6)-1 + { 63-1, 0, null }, + { 100, 1, null }, + { 10, 0, null }, + { 1, 0, null }, + { 0, 0, null }, // (length >= 0) + { -1, 0, IAE }, + { -10, 0, IAE }, + { -100, 0, IAE }, + { Long.MIN_VALUE, 0, IAE }, + { -4611686018427387903L-1, 0, IAE }, + { -4611686018427387903L, 0, IAE }, + { -4611686018427387903L+1, 0, IAE }, + { -1073741823-1, 0, IAE }, + { -1073741823, 0, IAE }, // (length > (1L << 30)-1) + { -1073741823+1, 0, IAE }, + { -16383-1, 0, IAE }, + { -16383, 0, IAE }, // (length > (1L << 14)-1 + { -16383+1, 0, IAE }, + { -63-1, 0, IAE }, + { -63 , 0, IAE }, // (length > (1L << 6)-1 + { -63+1, 0, IAE }, + }; + } + + @Test(dataProvider = "decode invariants") + public void testDecode(byte[] values, long expectedLength, int expectedPosition) { + ByteBuffer bb = ByteBuffer.wrap(values); + var actualLength = VariableLengthEncoder.decode(bb); + assertEquals(actualLength, expectedLength); + + var actualPosition = bb.position(); + assertEquals(actualPosition, expectedPosition); + } + + @Test(dataProvider = "decode invariants") + public void testPeek(byte[] values, long expectedLength, int expectedPosition) { + ByteBuffer bb = ByteBuffer.wrap(values); + var actualLength = VariableLengthEncoder.peekEncodedValue(bb, 0); + assertEquals(actualLength, expectedLength); + + var actualPosition = bb.position(); + assertEquals(actualPosition, 0); + } + + @Test(dataProvider = "encode invariants") + public void testEncode(long length, int capacity, Class exception) throws IOException { + var actualBuffer = ByteBuffer.allocate(capacity); + var expectedBuffer = getTestBuffer(length, capacity); + + if (exception != null) { + assertThrows(exception, () -> VariableLengthEncoder.encode(actualBuffer, length)); + // if method fails ensure that position hasn't changed + var actualPosition = actualBuffer.position(); + assertEquals(actualPosition, capacity); + } else { + VariableLengthEncoder.encode(actualBuffer, length); + var actualPosition = actualBuffer.position(); + assertEquals(actualPosition, capacity); + + // check length prefix + int firstByte = actualBuffer.get(0) & 0xFF; + int lengthPrefix = firstByte & 0xC0; + lengthPrefix >>= 6; + int expectedValue = (int)(Math.log(capacity) / Math.log(2)); + assertEquals(lengthPrefix, expectedValue); + + // check length encoded in buffer correctly + int b = firstByte & 0x3F; + actualBuffer.put(0, (byte) b); + assertEquals(actualBuffer.compareTo(expectedBuffer), 0); + } + } + + @Test(dataProvider = "prefix invariants") + public void testLengthPrefix(long length, int expectedPrefix, Class exception) { + if (exception != null) { + assertThrows(exception, () -> VariableLengthEncoder.getVariableLengthPrefix(length)); + } else { + var actualValue = VariableLengthEncoder.getVariableLengthPrefix(length); + assertEquals(actualValue, expectedPrefix); + } + } + + // Encode the given length and then decodes it and compares + // the results, asserting various invariants along the way. + @Test(dataProvider = "prefix invariants") + public void testEncodeDecode(long length, int expectedPrefix, Class exception) { + if (exception != null) { + assertThrows(exception, () -> VariableLengthEncoder.getEncodedSize(length)); + assertThrows(exception, () -> VariableLengthEncoder.encode(ByteBuffer.allocate(16), length)); + } else { + var actualSize = VariableLengthEncoder.getEncodedSize(length); + assertEquals(actualSize, 1 << expectedPrefix); + assertTrue(actualSize > 0, "length is negative or zero: " + actualSize); + assertTrue(actualSize < 9, "length is too big: " + actualSize); + + // Use different offsets for the position at which to encode/decode + for (int offset : List.of(0, 10)) { + System.out.printf("Encode/Decode %s on %s bytes with offset %s%n", + length, actualSize, offset); + + // allocate buffers: one exact, one too short, one too long + ByteBuffer exact = ByteBuffer.allocate(actualSize + offset); + exact.position(offset); + ByteBuffer shorter = ByteBuffer.allocate(actualSize - 1 + offset); + shorter.position(offset); + ByteBuffer shorterref = ByteBuffer.allocate(actualSize - 1 + offset); + shorterref.position(offset); + ByteBuffer longer = ByteBuffer.allocate(actualSize + 10 + offset); + longer.position(offset); + + // attempt to encode with a buffer too short + expectThrows(IAE, () -> VariableLengthEncoder.encode(shorter, length)); + assertEquals(shorter.position(), offset); + assertEquals(shorter.limit(), shorter.capacity()); + + assertEquals(shorter.mismatch(shorterref), -1); + assertEquals(shorterref.mismatch(shorter), -1); + + // attempt to encode with a buffer that has the exact size + var exactres = VariableLengthEncoder.encode(exact, length); + assertEquals(exactres, actualSize); + assertEquals(exact.position(), actualSize + offset); + assertFalse(exact.hasRemaining()); + + // attempt to encode with a buffer that has more bytes + var longres = VariableLengthEncoder.encode(longer, length); + assertEquals(longres, actualSize); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.limit(), longer.capacity()); + assertEquals(longer.remaining(), 10); + + // compare encodings + + // first reset buffer positions for reading. + exact.position(offset); + longer.position(offset); + assertEquals(longer.mismatch(exact), actualSize); + assertEquals(exact.mismatch(longer), actualSize); + + // decode with a buffer that is missing the last + // byte... + var shortSlice = exact.duplicate(); + shortSlice.position(offset); + shortSlice.limit(offset + actualSize -1); + var actualLength = VariableLengthEncoder.decode(shortSlice); + assertEquals(actualLength, -1L); + assertEquals(shortSlice.position(), offset); + assertEquals(shortSlice.limit(), offset + actualSize - 1); + + // decode with the exact buffer + actualLength = VariableLengthEncoder.decode(exact); + assertEquals(actualLength, length); + assertEquals(exact.position(), offset + actualSize); + assertFalse(exact.hasRemaining()); + + // decode with the longer buffer + actualLength = VariableLengthEncoder.decode(longer); + assertEquals(actualLength, length); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.remaining(), 10); + } + + } + } + + // Encode the given length and then peeks it and compares + // the results, asserting various invariants along the way. + @Test(dataProvider = "prefix invariants") + public void testEncodePeek(long length, int expectedPrefix, Class exception) { + if (exception != null) { + assertThrows(exception, () -> VariableLengthEncoder.getEncodedSize(length)); + assertThrows(exception, () -> VariableLengthEncoder.encode(ByteBuffer.allocate(16), length)); + return; + } + + var actualSize = VariableLengthEncoder.getEncodedSize(length); + assertEquals(actualSize, 1 << expectedPrefix); + assertTrue(actualSize > 0, "length is negative or zero: " + actualSize); + assertTrue(actualSize < 9, "length is too big: " + actualSize); + + // Use different offsets for the position at which to encode/decode + for (int offset : List.of(0, 10)) { + System.out.printf("Encode/Peek %s on %s bytes with offset %s%n", + length, actualSize, offset); + + // allocate buffers: one exact, one too long + ByteBuffer exact = ByteBuffer.allocate(actualSize + offset); + exact.position(offset); + ByteBuffer longer = ByteBuffer.allocate(actualSize + 10 + offset); + longer.position(offset); + + // attempt to encode with a buffer that has the exact size + var exactres = VariableLengthEncoder.encode(exact, length); + assertEquals(exactres, actualSize); + assertEquals(exact.position(), actualSize + offset); + assertFalse(exact.hasRemaining()); + + // attempt to encode with a buffer that has more bytes + var longres = VariableLengthEncoder.encode(longer, length); + assertEquals(longres, actualSize); + assertEquals(longer.position(), offset + actualSize); + assertEquals(longer.limit(), longer.capacity()); + assertEquals(longer.remaining(), 10); + + // compare encodings + + // first reset buffer positions for reading. + exact.position(offset); + longer.position(offset); + assertEquals(longer.mismatch(exact), actualSize); + assertEquals(exact.mismatch(longer), actualSize); + exact.position(0); + longer.position(0); + exact.limit(exact.capacity()); + longer.limit(longer.capacity()); + + // decode with a buffer that is missing the last + // byte... + var shortSlice = exact.duplicate(); + shortSlice.position(0); + shortSlice.limit(offset + actualSize - 1); + // need at least one byte to decode the size len... + var expectedSize = shortSlice.limit() <= offset ? -1 : actualSize; + assertEquals(VariableLengthEncoder.peekEncodedValueSize(shortSlice, offset), expectedSize); + var actualLength = VariableLengthEncoder.peekEncodedValue(shortSlice, offset); + assertEquals(actualLength, -1L); + assertEquals(shortSlice.position(), 0); + assertEquals(shortSlice.limit(), offset + actualSize - 1); + + // decode with the exact buffer + assertEquals(VariableLengthEncoder.peekEncodedValueSize(exact, offset), actualSize); + actualLength = VariableLengthEncoder.peekEncodedValue(exact, offset); + assertEquals(actualLength, length); + assertEquals(exact.position(), 0); + assertEquals(exact.limit(), exact.capacity()); + + // decode with the longer buffer + assertEquals(VariableLengthEncoder.peekEncodedValueSize(longer, offset), actualSize); + actualLength = VariableLengthEncoder.peekEncodedValue(longer, offset); + assertEquals(actualLength, length); + assertEquals(longer.position(), 0); + assertEquals(longer.limit(), longer.capacity()); + } + + } + + + private ByteBuffer getTestBuffer(long length, int capacity) { + return switch (capacity) { + case 0 -> ByteBuffer.allocate(1).put((byte) length); + case 1 -> ByteBuffer.allocate(capacity).put((byte) length); + case 2 -> ByteBuffer.allocate(capacity).putShort((short) length); + case 4 -> ByteBuffer.allocate(capacity).putInt((int) length); + case 8 -> ByteBuffer.allocate(capacity).putLong(length); + default -> throw new SkippedException("bad value used for capacity"); + }; + } +} diff --git a/test/jdk/java/net/httpclient/quic/VersionNegotiationTest.java b/test/jdk/java/net/httpclient/quic/VersionNegotiationTest.java new file mode 100644 index 00000000000..d0de5ce9a60 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/VersionNegotiationTest.java @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLParameters; + +import jdk.httpclient.test.lib.common.TestUtil; +import jdk.httpclient.test.lib.quic.ClientConnection; +import jdk.httpclient.test.lib.quic.ConnectedBidiStream; +import jdk.httpclient.test.lib.quic.QuicServer; +import jdk.httpclient.test.lib.quic.QuicServerConnection; +import jdk.httpclient.test.lib.quic.QuicServerHandler; +import jdk.httpclient.test.lib.quic.QuicStandaloneServer; +import jdk.internal.net.http.quic.QuicClient; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +/* + * @test + * @summary Test the version negotiation semantics of Quic + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.httpclient.test.lib.quic.QuicStandaloneServer + * jdk.httpclient.test.lib.common.TestUtil + * jdk.httpclient.test.lib.quic.ClientConnection + * jdk.test.lib.net.SimpleSSLContext + * @run testng/othervm -Djdk.internal.httpclient.debug=true VersionNegotiationTest + */ +public class VersionNegotiationTest { + + private static SSLContext sslContext; + private static ExecutorService executor; + + @BeforeClass + public static void beforeClass() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) { + throw new AssertionError("Unexpected null sslContext"); + } + executor = Executors.newCachedThreadPool(); + } + + @AfterClass + public static void afterClass() throws Exception { + if (executor != null) executor.shutdown(); + } + + private QuicClient createClient() { + return new QuicClient.Builder() + .availableVersions(List.of(QuicVersion.QUIC_V1)) + .tlsContext(new QuicTLSContext(sslContext)) + .sslParameters(new SSLParameters()) + .executor(executor) + .bindAddress(TestUtil.chooseClientBindAddress().orElse(null)) + .build(); + } + + /** + * Uses a Quic client which is enabled for a specific version and a Quic server + * which is enabled for a different version. Verifies that the connection attempt fails + * as noted in RFC-9000, section 6.2 + */ + @Test + public void testUnsupportedClientVersion() throws Exception { + try (final var client = createClient()) { + final QuicVersion serverVersion = QuicVersion.QUIC_V2; + try (final QuicServer server = createAndStartServer(serverVersion)) { + System.out.println("Attempting to connect " + client.getAvailableVersions() + + " client to a " + server.getAvailableVersions() + " server"); + final IOException thrown = expectThrows(IOException.class, + () -> ClientConnection.establishConnection(client, server.getAddress())); + // a version negotiation failure (since it happens during a QUIC connection + // handshake) gets thrown as a SSLHandshakeException + if (!(thrown.getCause() instanceof SSLHandshakeException sslhe)) { + throw thrown; + } + System.out.println("Received (potentially expected) exception: " + sslhe); + // additional check to make sure it was thrown for the right reason + assertEquals(sslhe.getMessage(), "QUIC connection establishment failed"); + // underlying cause of SSLHandshakeException should be version negotiation failure + final Throwable underlyingCause = sslhe.getCause(); + assertNotNull(underlyingCause, "missing cause in SSLHandshakeException"); + assertNotNull(underlyingCause.getMessage(), "missing message in " + underlyingCause); + assertTrue(underlyingCause.getMessage().contains("No support for any of the" + + " QUIC versions being negotiated")); + } + } + } + + /** + * Creates a server which supports only the specified Quic version + */ + private static QuicServer createAndStartServer(final QuicVersion version) throws IOException { + final QuicStandaloneServer server = QuicStandaloneServer.newBuilder() + .availableVersions(new QuicVersion[]{version}) + .sslContext(sslContext) + .build(); + server.addHandler(new ExceptionThrowingHandler()); + server.start(); + System.out.println("Quic server with version " + version + " started at " + server.getAddress()); + return server; + } + + private static final class ExceptionThrowingHandler implements QuicServerHandler { + + @Override + public void handleBidiStream(final QuicServerConnection conn, + final ConnectedBidiStream bidiStream) throws IOException { + throw new AssertionError("Handler shouldn't have been called for " + + bidiStream + " on connection " + conn); + } + } +} diff --git a/test/jdk/java/net/httpclient/quic/quic-tls-keylimits-java.security b/test/jdk/java/net/httpclient/quic/quic-tls-keylimits-java.security new file mode 100644 index 00000000000..0d66b3a9e2f --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/quic-tls-keylimits-java.security @@ -0,0 +1,4 @@ +# meant to override the jdk.quic.tls.keyLimits security property +# in the KeyUpdateTest +jdk.quic.tls.keyLimits=AES/GCM/NoPadding 10, \ + ChaCha20-Poly1305 -1 diff --git a/test/jdk/java/net/httpclient/quic/tls/PacketEncryptionTest.java b/test/jdk/java/net/httpclient/quic/tls/PacketEncryptionTest.java new file mode 100644 index 00000000000..ed1cbc08de0 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/tls/PacketEncryptionTest.java @@ -0,0 +1,457 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; +import org.testng.annotations.Test; +import sun.security.ssl.QuicTLSEngineImpl; +import sun.security.ssl.QuicTLSEngineImplAccessor; + +import javax.crypto.AEADBadTagException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.SecretKeySpec; +import javax.net.ssl.SSLContext; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.util.HexFormat; +import java.util.function.IntFunction; + +import static jdk.internal.net.quic.QuicVersion.QUIC_V1; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +/** + * @test + * @library /test/lib + * @modules java.base/sun.security.ssl + * java.base/jdk.internal.net.quic + * @build java.base/sun.security.ssl.QuicTLSEngineImplAccessor + * @summary known-answer test for packet encryption and decryption + * @run testng/othervm PacketEncryptionTest + */ +public class PacketEncryptionTest { + + // RFC 9001, appendix A + private static final String INITIAL_DCID = "8394c8f03e515708"; + // section A.2 + // header includes 4-byte packet number 2 + private static final String INITIAL_C_HEADER = "c300000001088394c8f03e5157080000449e00000002"; + private static final int INITIAL_C_PAYLOAD_OFFSET = INITIAL_C_HEADER.length() / 2; + private static final int INITIAL_C_PN_OFFSET = INITIAL_C_PAYLOAD_OFFSET - 4; + private static final int INITIAL_C_PN = 2; + // payload is zero-padded to 1162 bytes, not shown here + private static final String INITIAL_C_PAYLOAD = + "060040f1010000ed0303ebf8fa56f129"+"39b9584a3896472ec40bb863cfd3e868" + + "04fe3a47f06a2b69484c000004130113"+"02010000c000000010000e00000b6578" + + "616d706c652e636f6dff01000100000a"+"00080006001d00170018001000070005" + + "04616c706e0005000501000000000033"+"00260024001d00209370b2c9caa47fba" + + "baf4559fedba753de171fa71f50f1ce1"+"5d43e994ec74d748002b000302030400" + + "0d0010000e0403050306030203080408"+"050806002d00020101001c0002400100" + + "3900320408ffffffffffffffff050480"+"00ffff07048000ffff08011001048000" + + "75300901100f088394c8f03e51570806"+"048000ffff"; + private static final int INITIAL_C_PAYLOAD_LENGTH = 1162; + private static final String ENCRYPTED_C_PAYLOAD = + "c000000001088394c8f03e5157080000"+"449e7b9aec34d1b1c98dd7689fb8ec11" + + "d242b123dc9bd8bab936b47d92ec356c"+"0bab7df5976d27cd449f63300099f399" + + "1c260ec4c60d17b31f8429157bb35a12"+"82a643a8d2262cad67500cadb8e7378c" + + "8eb7539ec4d4905fed1bee1fc8aafba1"+"7c750e2c7ace01e6005f80fcb7df6212" + + "30c83711b39343fa028cea7f7fb5ff89"+"eac2308249a02252155e2347b63d58c5" + + "457afd84d05dfffdb20392844ae81215"+"4682e9cf012f9021a6f0be17ddd0c208" + + "4dce25ff9b06cde535d0f920a2db1bf3"+"62c23e596d11a4f5a6cf3948838a3aec" + + "4e15daf8500a6ef69ec4e3feb6b1d98e"+"610ac8b7ec3faf6ad760b7bad1db4ba3" + + "485e8a94dc250ae3fdb41ed15fb6a8e5"+"eba0fc3dd60bc8e30c5c4287e53805db" + + "059ae0648db2f64264ed5e39be2e20d8"+"2df566da8dd5998ccabdae053060ae6c" + + "7b4378e846d29f37ed7b4ea9ec5d82e7"+"961b7f25a9323851f681d582363aa5f8" + + "9937f5a67258bf63ad6f1a0b1d96dbd4"+"faddfcefc5266ba6611722395c906556" + + "be52afe3f565636ad1b17d508b73d874"+"3eeb524be22b3dcbc2c7468d54119c74" + + "68449a13d8e3b95811a198f3491de3e7"+"fe942b330407abf82a4ed7c1b311663a" + + "c69890f4157015853d91e923037c227a"+"33cdd5ec281ca3f79c44546b9d90ca00" + + "f064c99e3dd97911d39fe9c5d0b23a22"+"9a234cb36186c4819e8b9c5927726632" + + "291d6a418211cc2962e20fe47feb3edf"+"330f2c603a9d48c0fcb5699dbfe58964" + + "25c5bac4aee82e57a85aaf4e2513e4f0"+"5796b07ba2ee47d80506f8d2c25e50fd" + + "14de71e6c418559302f939b0e1abd576"+"f279c4b2e0feb85c1f28ff18f58891ff" + + "ef132eef2fa09346aee33c28eb130ff2"+"8f5b766953334113211996d20011a198" + + "e3fc433f9f2541010ae17c1bf202580f"+"6047472fb36857fe843b19f5984009dd" + + "c324044e847a4f4a0ab34f719595de37"+"252d6235365e9b84392b061085349d73" + + "203a4a13e96f5432ec0fd4a1ee65accd"+"d5e3904df54c1da510b0ff20dcc0c77f" + + "cb2c0e0eb605cb0504db87632cf3d8b4"+"dae6e705769d1de354270123cb11450e" + + "fc60ac47683d7b8d0f811365565fd98c"+"4c8eb936bcab8d069fc33bd801b03ade" + + "a2e1fbc5aa463d08ca19896d2bf59a07"+"1b851e6c239052172f296bfb5e724047" + + "90a2181014f3b94a4e97d117b4381303"+"68cc39dbb2d198065ae3986547926cd2" + + "162f40a29f0c3c8745c0f50fba3852e5"+"66d44575c29d39a03f0cda721984b6f4" + + "40591f355e12d439ff150aab7613499d"+"bd49adabc8676eef023b15b65bfc5ca0" + + "6948109f23f350db82123535eb8a7433"+"bdabcb909271a6ecbcb58b936a88cd4e" + + "8f2e6ff5800175f113253d8fa9ca8885"+"c2f552e657dc603f252e1a8e308f76f0" + + "be79e2fb8f5d5fbbe2e30ecadd220723"+"c8c0aea8078cdfcb3868263ff8f09400" + + "54da48781893a7e49ad5aff4af300cd8"+"04a6b6279ab3ff3afb64491c85194aab" + + "760d58a606654f9f4400e8b38591356f"+"bf6425aca26dc85244259ff2b19c41b9" + + "f96f3ca9ec1dde434da7d2d392b905dd"+"f3d1f9af93d1af5950bd493f5aa731b4" + + "056df31bd267b6b90a079831aaf579be"+"0a39013137aac6d404f518cfd4684064" + + "7e78bfe706ca4cf5e9c5453e9f7cfd2b"+"8b4c8d169a44e55c88d4a9a7f9474241" + + "e221af44860018ab0856972e194cd934"; + // section A.3 + // header includes 2-byte packet number 1 + private static final String INITIAL_S_HEADER = "c1000000010008f067a5502a4262b50040750001"; + private static final int INITIAL_S_PAYLOAD_OFFSET = INITIAL_S_HEADER.length() / 2; + private static final int INITIAL_S_PN_OFFSET = INITIAL_S_PAYLOAD_OFFSET - 2; + private static final int INITIAL_S_PN = 1; + // complete packet, no padding + private static final String INITIAL_S_PAYLOAD = + "02000000000600405a020000560303ee"+"fce7f7b37ba1d1632e96677825ddf739" + + "88cfc79825df566dc5430b9a045a1200"+"130100002e00330024001d00209d3c94" + + "0d89690b84d08a60993c144eca684d10"+"81287c834d5311bcf32bb9da1a002b00" + + "020304"; + private static final int INITIAL_S_PAYLOAD_LENGTH = INITIAL_S_PAYLOAD.length() / 2; + private static final String ENCRYPTED_S_PAYLOAD = + "cf000000010008f067a5502a4262b500"+"4075c0d95a482cd0991cd25b0aac406a" + + "5816b6394100f37a1c69797554780bb3"+"8cc5a99f5ede4cf73c3ec2493a1839b3" + + "dbcba3f6ea46c5b7684df3548e7ddeb9"+"c3bf9c73cc3f3bded74b562bfb19fb84" + + "022f8ef4cdd93795d77d06edbb7aaf2f"+"58891850abbdca3d20398c276456cbc4" + + "2158407dd074ee"; + + // section A.4 + private static final String SIGNED_RETRY = + "ff000000010008f067a5502a4262b574"+"6f6b656e04a265ba2eff4d829058fb3f" + + "0f2496ba"; + + // section A.5 + public static final String ONERTT_SECRET = "9ac312a7f877468ebe69422748ad00a1" + + "5443f18203a07d6060f688f30f21632b"; + private static final String ONERTT_HEADER = "4200bff4"; + private static final int ONERTT_PAYLOAD_OFFSET = ONERTT_HEADER.length() / 2; + private static final int ONERTT_PN_OFFSET = 1; + private static final int ONERTT_PN = 654360564; + // payload is zero-padded to 1162 bytes, not shown here + private static final String ONERTT_PAYLOAD = + "01"; + private static final int ONERTT_PAYLOAD_LENGTH = + ONERTT_PAYLOAD.length() / 2; + private static final String ENCRYPTED_ONERTT_PAYLOAD = + "4cfe4189655e5cd55c41f69080575d7999c25a5bfb"; + + private static final class FixedHeaderContent implements IntFunction { + private final ByteBuffer header; + private FixedHeaderContent(ByteBuffer header) { + this.header = header; + } + + @Override + public ByteBuffer apply(final int keyphase) { + // ignore keyphase + return this.header; + } + } + + @Test + public void testEncryptClientInitialPacket() throws Exception { + QuicTLSEngine clientEngine = getQuicV1Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + clientEngine.deriveInitialKeys(QUIC_V1, dcid); + + final int packetLen = INITIAL_C_PAYLOAD_OFFSET + INITIAL_C_PAYLOAD_LENGTH + 16; + final ByteBuffer packet = ByteBuffer.allocate(packetLen); + packet.put(HexFormat.of().parseHex(INITIAL_C_HEADER)); + packet.put(HexFormat.of().parseHex(INITIAL_C_PAYLOAD)); + + final ByteBuffer header = packet.slice(0, INITIAL_C_PAYLOAD_OFFSET).asReadOnlyBuffer(); + final ByteBuffer payload = packet.slice(INITIAL_C_PAYLOAD_OFFSET, INITIAL_C_PAYLOAD_LENGTH).asReadOnlyBuffer(); + + packet.position(INITIAL_C_PAYLOAD_OFFSET); + clientEngine.encryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_C_PN, new FixedHeaderContent(header), payload, packet); + protect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_C_PN_OFFSET, INITIAL_C_PAYLOAD_OFFSET - INITIAL_C_PN_OFFSET, clientEngine, 0x0f); + + assertEquals(HexFormat.of().formatHex(packet.array()), ENCRYPTED_C_PAYLOAD); + } + + @Test + public void testDecryptClientInitialPacket() throws Exception { + QuicTLSEngine serverEngine = getQuicV1Engine(SSLContext.getDefault(), false); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + serverEngine.deriveInitialKeys(QUIC_V1, dcid); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_C_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_C_PN_OFFSET, INITIAL_C_PAYLOAD_OFFSET - INITIAL_C_PN_OFFSET, serverEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_C_PAYLOAD_OFFSET); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_C_PN, -1, + src, INITIAL_C_PAYLOAD_OFFSET, packet); + + String expectedContents = INITIAL_C_HEADER + INITIAL_C_PAYLOAD; + + assertEquals(HexFormat.of().formatHex(packet.array()).substring(0, expectedContents.length()), expectedContents); + } + + @Test(expectedExceptions = AEADBadTagException.class) + public void testDecryptClientInitialPacketBadTag() throws Exception { + QuicTLSEngine serverEngine = getQuicV1Engine(SSLContext.getDefault(), false); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + serverEngine.deriveInitialKeys(QUIC_V1, dcid); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_C_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_C_PN_OFFSET, INITIAL_C_PAYLOAD_OFFSET - INITIAL_C_PN_OFFSET, serverEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_C_PAYLOAD_OFFSET); + + // change one byte of AEAD tag + packet.put(packet.limit() - 1, (byte)0); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_C_PN, -1, + src, INITIAL_C_PAYLOAD_OFFSET, packet); + fail("Decryption should have failed"); + } + + @Test + public void testEncryptServerInitialPacket() throws Exception { + QuicTLSEngine serverEngine = getQuicV1Engine(SSLContext.getDefault(), false); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + serverEngine.deriveInitialKeys(QUIC_V1, dcid); + + final int packetLen = INITIAL_S_PAYLOAD_OFFSET + INITIAL_S_PAYLOAD_LENGTH + 16; + final ByteBuffer packet = ByteBuffer.allocate(packetLen); + packet.put(HexFormat.of().parseHex(INITIAL_S_HEADER)); + packet.put(HexFormat.of().parseHex(INITIAL_S_PAYLOAD)); + + final ByteBuffer header = packet.slice(0, INITIAL_S_PAYLOAD_OFFSET).asReadOnlyBuffer(); + final ByteBuffer payload = packet.slice(INITIAL_S_PAYLOAD_OFFSET, INITIAL_S_PAYLOAD_LENGTH).asReadOnlyBuffer(); + + packet.position(INITIAL_S_PAYLOAD_OFFSET); + serverEngine.encryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, new FixedHeaderContent(header), payload, packet); + protect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, serverEngine, 0x0f); + + assertEquals(HexFormat.of().formatHex(packet.array()), ENCRYPTED_S_PAYLOAD); + } + + @Test + public void testDecryptServerInitialPacket() throws Exception { + QuicTLSEngine clientEngine = getQuicV1Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + clientEngine.deriveInitialKeys(QUIC_V1, dcid); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_S_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, clientEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_S_PAYLOAD_OFFSET); + + clientEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, -1, + src, INITIAL_S_PAYLOAD_OFFSET, packet); + + String expectedContents = INITIAL_S_HEADER + INITIAL_S_PAYLOAD; + + assertEquals(HexFormat.of().formatHex(packet.array()).substring(0, expectedContents.length()), expectedContents); + } + + @Test + public void testDecryptServerInitialPacketTwice() throws Exception { + // verify that decrypting the same packet twice does not throw + QuicTLSEngine clientEngine = getQuicV1Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + clientEngine.deriveInitialKeys(QUIC_V1, dcid); + + // attempt 1 + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_S_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, clientEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_S_PAYLOAD_OFFSET); + clientEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, -1, src, INITIAL_S_PAYLOAD_OFFSET, packet); + + // attempt 2 + packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_S_PAYLOAD)); + // must not throw + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, clientEngine, 0x0f); + src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_S_PAYLOAD_OFFSET); + // must not throw + clientEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, -1, src, INITIAL_S_PAYLOAD_OFFSET, packet); + } + + @Test + public void testSignRetry() throws NoSuchAlgorithmException, ShortBufferException, QuicTransportException { + QuicTLSEngine clientEngine = getQuicV1Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + + ByteBuffer packet = ByteBuffer.allocate(SIGNED_RETRY.length() / 2); + packet.put(HexFormat.of().parseHex(SIGNED_RETRY), 0, SIGNED_RETRY.length() / 2 - 16); + + ByteBuffer src = packet.asReadOnlyBuffer(); + src.limit(src.position()); + src.position(0); + + clientEngine.signRetryPacket(QUIC_V1, dcid, src, packet); + + assertEquals(HexFormat.of().formatHex(packet.array()), SIGNED_RETRY); + } + + @Test + public void testVerifyRetry() throws NoSuchAlgorithmException, AEADBadTagException, QuicTransportException { + QuicTLSEngine clientEngine = getQuicV1Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(SIGNED_RETRY)); + + clientEngine.verifyRetryPacket(QUIC_V1, dcid, packet); + } + + @Test(expectedExceptions = AEADBadTagException.class) + public void testVerifyBadRetry() throws NoSuchAlgorithmException, AEADBadTagException, QuicTransportException { + QuicTLSEngine clientEngine = getQuicV1Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(SIGNED_RETRY)); + + // change one byte of AEAD tag + packet.put(packet.limit() - 1, (byte)0); + clientEngine.verifyRetryPacket(QUIC_V1, dcid, packet); + fail("Verification should have failed"); + } + + @Test + public void testEncryptChaCha() throws Exception { + QuicTLSEngineImpl clientEngine = (QuicTLSEngineImpl) getQuicV1Engine(SSLContext.getDefault(), true); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V1, clientEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", true); + + final int packetLen = ONERTT_PAYLOAD_OFFSET + ONERTT_PAYLOAD_LENGTH + 16; + final ByteBuffer packet = ByteBuffer.allocate(packetLen); + packet.put(HexFormat.of().parseHex(ONERTT_HEADER)); + packet.put(HexFormat.of().parseHex(ONERTT_PAYLOAD)); + + final ByteBuffer header = packet.slice(0, ONERTT_PAYLOAD_OFFSET).asReadOnlyBuffer(); + final ByteBuffer payload = packet.slice(ONERTT_PAYLOAD_OFFSET, ONERTT_PAYLOAD_LENGTH).asReadOnlyBuffer(); + + packet.position(ONERTT_PAYLOAD_OFFSET); + clientEngine.encryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN , new FixedHeaderContent(header), payload, packet); + protect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, clientEngine, 0x1f); + + assertEquals(HexFormat.of().formatHex(packet.array()), ENCRYPTED_ONERTT_PAYLOAD); + } + + @Test + public void testDecryptChaCha() throws Exception { + QuicTLSEngineImpl serverEngine = (QuicTLSEngineImpl) getQuicV1Engine(SSLContext.getDefault(), false); + // mark the TLS handshake as FINISHED + QuicTLSEngineImplAccessor.completeHandshake(serverEngine); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V1, serverEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", false); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, (byte) 0, + src, ONERTT_PAYLOAD_OFFSET, packet); + + String expectedContents = ONERTT_HEADER + ONERTT_PAYLOAD; + + assertEquals(HexFormat.of().formatHex(packet.array()).substring(0, expectedContents.length()), expectedContents); + } + + @Test + public void testDecryptChaChaTwice() throws Exception { + // verify that decrypting the same packet twice does not throw + QuicTLSEngineImpl serverEngine = (QuicTLSEngineImpl) getQuicV1Engine(SSLContext.getDefault(), false); + // mark the TLS handshake as FINISHED + QuicTLSEngineImplAccessor.completeHandshake(serverEngine); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V1, serverEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", false); + + final int keyPhase = 0; + // attempt 1 + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, keyPhase, src, ONERTT_PAYLOAD_OFFSET, packet); + + // attempt 2 + packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + // must not throw + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + // must not throw + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, keyPhase, src, ONERTT_PAYLOAD_OFFSET, packet); + } + + @Test(expectedExceptions = AEADBadTagException.class) + public void testDecryptChaChaBadTag() throws Exception { + QuicTLSEngineImpl serverEngine = (QuicTLSEngineImpl) getQuicV1Engine(SSLContext.getDefault(), false); + // mark the TLS handshake as FINISHED + QuicTLSEngineImplAccessor.completeHandshake(serverEngine); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V1, serverEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", false); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + + // change one byte of AEAD tag + packet.put(packet.limit() - 1, (byte)0); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, (byte) 0, + src, ONERTT_PAYLOAD_OFFSET, packet); + fail("Decryption should have failed"); + } + + + private void protect(QuicTLSEngine.KeySpace space, ByteBuffer buffer, + int packetNumberStart, int packetNumberLength, QuicTLSEngine tlsEngine, + int headersMask) throws QuicKeyUnavailableException, QuicTransportException { + ByteBuffer sample = buffer.slice(packetNumberStart + 4, 16); + ByteBuffer encryptedSample = tlsEngine.computeHeaderProtectionMask(space, false, sample); + byte headers = buffer.get(0); + headers ^= encryptedSample.get() & headersMask; + buffer.put(0, headers); + maskPacketNumber(buffer, packetNumberStart, packetNumberLength, encryptedSample); + } + + private void unprotect(QuicTLSEngine.KeySpace keySpace, ByteBuffer buffer, + int packetNumberStart, int packetNumberLength, + QuicTLSEngine tlsEngine, int headersMask) throws QuicKeyUnavailableException, QuicTransportException { + ByteBuffer sample = buffer.slice(packetNumberStart + 4, 16); + ByteBuffer encryptedSample = tlsEngine.computeHeaderProtectionMask(keySpace, true, sample); + byte headers = buffer.get(0); + headers ^= encryptedSample.get() & headersMask; + buffer.put(0, headers); + maskPacketNumber(buffer, packetNumberStart, packetNumberLength, encryptedSample); + } + + private void maskPacketNumber(ByteBuffer buffer, int packetNumberStart, int packetNumberLength, ByteBuffer mask) { + for (int i = 0; i < packetNumberLength; i++) { + buffer.put(packetNumberStart + i, (byte)(buffer.get(packetNumberStart + i) ^ mask.get())); + } + } + + // returns a QuicTLSEngine with only Quic version 1 enabled + private QuicTLSEngine getQuicV1Engine(SSLContext context, boolean mode) { + final QuicTLSContext quicTLSContext = new QuicTLSContext(context); + final QuicTLSEngine engine = quicTLSContext.createEngine(); + engine.setUseClientMode(mode); + return engine; + } +} diff --git a/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineBadParametersTest.java b/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineBadParametersTest.java new file mode 100644 index 00000000000..e6a99caebad --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineBadParametersTest.java @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportErrors; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @test + * @library /test/lib + * @modules java.base/sun.security.ssl + * java.base/jdk.internal.net.quic + * @build jdk.test.lib.net.SimpleSSLContext + * @summary Verify that QuicTransportExceptions thrown by transport parameter consumer + * are propagated to the QuicTLSEngine user + * @run junit/othervm QuicTLSEngineBadParametersTest + */ +public class QuicTLSEngineBadParametersTest { + + @Test + void testServerConsumerExceptionPropagated() throws IOException { + SSLContext ctx = SimpleSSLContext.getContext("TLSv1.3"); + QuicTLSContext qctx = new QuicTLSContext(ctx); + QuicTLSEngine clientEngine = createClientEngine(qctx); + QuicTLSEngine serverEngine = createServerEngine(qctx); + QuicTransportException ex = + new QuicTransportException("", null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + serverEngine.setRemoteQuicTransportParametersConsumer(p -> { throw ex; }); + ByteBuffer cTOs = clientEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + try { + serverEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, cTOs); + fail("Expected exception not thrown"); + } catch (QuicTransportException e) { + assertSame(ex, e, "Incorrect exception caught"); + } + } + + @Test + void testClientConsumerExceptionPropagated() throws IOException, QuicTransportException { + SSLContext ctx = SimpleSSLContext.getContext("TLSv1.3"); + QuicTLSContext qctx = new QuicTLSContext(ctx); + QuicTLSEngine clientEngine = createClientEngine(qctx); + QuicTLSEngine serverEngine = createServerEngine(qctx); + QuicTransportException ex = + new QuicTransportException("", null, 0, QuicTransportErrors.TRANSPORT_PARAMETER_ERROR); + clientEngine.setRemoteQuicTransportParametersConsumer(p -> { throw ex; }); + + // client hello + ByteBuffer cTOs = clientEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + serverEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, cTOs); + assertFalse(cTOs.hasRemaining()); + // server hello + ByteBuffer sTOc = serverEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + clientEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, sTOc); + assertFalse(sTOc.hasRemaining()); + // encrypted extensions + sTOc = serverEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.HANDSHAKE); + try { + clientEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.HANDSHAKE, sTOc); + fail("Expected exception not thrown"); + } catch (QuicTransportException e) { + assertSame(ex, e, "Incorrect exception caught"); + } + + } + + private static QuicTLSEngine createServerEngine(QuicTLSContext qctx) { + QuicTLSEngine engine = qctx.createEngine(); + engine.setUseClientMode(false); + SSLParameters params = engine.getSSLParameters(); + params.setApplicationProtocols(new String[] { "test" }); + engine.setSSLParameters(params); + engine.setLocalQuicTransportParameters(ByteBuffer.allocate(0)); + engine.setRemoteQuicTransportParametersConsumer(p -> { }); + engine.versionNegotiated(QuicVersion.QUIC_V1); + return engine; + } + + private static QuicTLSEngine createClientEngine(QuicTLSContext qctx) { + QuicTLSEngine engine = qctx.createEngine("localhost", 1234); + engine.setUseClientMode(true); + SSLParameters params = engine.getSSLParameters(); + params.setApplicationProtocols(new String[] { "test" }); + engine.setSSLParameters(params); + engine.setLocalQuicTransportParameters(ByteBuffer.allocate(0)); + engine.setRemoteQuicTransportParametersConsumer(p -> { }); + engine.versionNegotiated(QuicVersion.QUIC_V1); + return engine; + } + +} diff --git a/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineFailedALPNTest.java b/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineFailedALPNTest.java new file mode 100644 index 00000000000..fa7df3bc5c3 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineFailedALPNTest.java @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @test + * @library /test/lib + * @modules java.base/sun.security.ssl + * java.base/jdk.internal.net.quic + * @build jdk.test.lib.net.SimpleSSLContext + * @summary Verify that a missing ALPN extension results in no_application_protocol alert + * @run junit/othervm QuicTLSEngineFailedALPNTest + */ +public class QuicTLSEngineFailedALPNTest { + + @Test + void testServerRequiresALPN() throws IOException { + SSLContext ctx = SimpleSSLContext.getContext("TLSv1.3"); + QuicTLSContext qctx = new QuicTLSContext(ctx); + QuicTLSEngine clientEngine = createClientEngine(qctx); + QuicTLSEngine serverEngine = createServerEngine(qctx); + ByteBuffer cTOs = clientEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + try { + serverEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, cTOs); + fail("Expected exception not thrown"); + } catch (QuicTransportException e) { + assertEquals(0x0178, e.getErrorCode(), "Unexpected error code"); + } + } + + @Test + void testClientRequiresALPN() throws IOException, QuicTransportException { + SSLContext ctx = SimpleSSLContext.getContext("TLSv1.3"); + QuicTLSContext qctx = new QuicTLSContext(ctx); + QuicTLSEngine clientEngine = createClientEngine(qctx); + QuicTLSEngine serverEngine = createServerEngine(qctx); + SSLParameters params = clientEngine.getSSLParameters(); + params.setApplicationProtocols(new String[] { "test" }); + clientEngine.setSSLParameters(params); + // client hello + ByteBuffer cTOs = clientEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + serverEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, cTOs); + assertFalse(cTOs.hasRemaining()); + // server hello + ByteBuffer sTOc = serverEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + clientEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, sTOc); + assertFalse(sTOc.hasRemaining()); + // encrypted extensions + sTOc = serverEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.HANDSHAKE); + try { + clientEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.HANDSHAKE, sTOc); + fail("Expected exception not thrown"); + } catch (QuicTransportException e) { + assertEquals(0x0178, e.getErrorCode(), "Unexpected error code"); + } + + } + + private static QuicTLSEngine createServerEngine(QuicTLSContext qctx) { + QuicTLSEngine engine = qctx.createEngine(); + engine.setUseClientMode(false); + engine.setLocalQuicTransportParameters(ByteBuffer.allocate(0)); + engine.setRemoteQuicTransportParametersConsumer(params-> { }); + engine.versionNegotiated(QuicVersion.QUIC_V1); + return engine; + } + + private static QuicTLSEngine createClientEngine(QuicTLSContext qctx) { + QuicTLSEngine engine = qctx.createEngine("localhost", 1234); + engine.setUseClientMode(true); + engine.setLocalQuicTransportParameters(ByteBuffer.allocate(0)); + engine.setRemoteQuicTransportParametersConsumer(params-> { }); + engine.versionNegotiated(QuicVersion.QUIC_V1); + return engine; + } + +} diff --git a/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineMissingParametersTest.java b/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineMissingParametersTest.java new file mode 100644 index 00000000000..da9f06785a2 --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/tls/QuicTLSEngineMissingParametersTest.java @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; +import jdk.internal.net.quic.QuicVersion; +import jdk.test.lib.net.SimpleSSLContext; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @test + * @library /test/lib + * @modules java.base/sun.security.ssl + * java.base/jdk.internal.net.quic + * @build jdk.test.lib.net.SimpleSSLContext + * @summary Verify that a missing transport parameters extension results in missing_extension alert + * @run junit/othervm QuicTLSEngineMissingParametersTest + */ +public class QuicTLSEngineMissingParametersTest { + + @Test + void testServerRequiresTransportParameters() throws IOException { + SSLContext ctx = SimpleSSLContext.getContext("TLSv1.3"); + QuicTLSContext qctx = new QuicTLSContext(ctx); + QuicTLSEngine clientEngine = createClientEngine(qctx); + QuicTLSEngine serverEngine = createServerEngine(qctx); + ByteBuffer cTOs = clientEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + try { + serverEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, cTOs); + fail("Expected exception not thrown"); + } catch (QuicTransportException e) { + assertEquals(0x016d, e.getErrorCode(), "Unexpected error code"); + } + } + + @Test + void testClientRequiresTransportParameters() throws IOException, QuicTransportException { + SSLContext ctx = SimpleSSLContext.getContext("TLSv1.3"); + QuicTLSContext qctx = new QuicTLSContext(ctx); + QuicTLSEngine clientEngine = createClientEngine(qctx); + QuicTLSEngine serverEngine = createServerEngine(qctx); + clientEngine.setLocalQuicTransportParameters(ByteBuffer.allocate(0)); + // client hello + ByteBuffer cTOs = clientEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + serverEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, cTOs); + assertFalse(cTOs.hasRemaining()); + // server hello + ByteBuffer sTOc = serverEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL); + clientEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.INITIAL, sTOc); + assertFalse(sTOc.hasRemaining()); + // encrypted extensions + sTOc = serverEngine.getHandshakeBytes(QuicTLSEngine.KeySpace.HANDSHAKE); + try { + clientEngine.consumeHandshakeBytes(QuicTLSEngine.KeySpace.HANDSHAKE, sTOc); + fail("Expected exception not thrown"); + } catch (QuicTransportException e) { + assertEquals(0x016d, e.getErrorCode(), "Unexpected error code"); + } + + } + + private static QuicTLSEngine createServerEngine(QuicTLSContext qctx) { + QuicTLSEngine engine = qctx.createEngine(); + engine.setUseClientMode(false); + SSLParameters params = engine.getSSLParameters(); + params.setApplicationProtocols(new String[] { "test" }); + engine.setSSLParameters(params); + engine.setRemoteQuicTransportParametersConsumer(p -> { }); + engine.versionNegotiated(QuicVersion.QUIC_V1); + return engine; + } + + private static QuicTLSEngine createClientEngine(QuicTLSContext qctx) { + QuicTLSEngine engine = qctx.createEngine("localhost", 1234); + engine.setUseClientMode(true); + SSLParameters params = engine.getSSLParameters(); + params.setApplicationProtocols(new String[] { "test" }); + engine.setSSLParameters(params); + engine.setRemoteQuicTransportParametersConsumer(p -> { }); + engine.versionNegotiated(QuicVersion.QUIC_V1); + return engine; + } + +} diff --git a/test/jdk/java/net/httpclient/quic/tls/Quicv2PacketEncryptionTest.java b/test/jdk/java/net/httpclient/quic/tls/Quicv2PacketEncryptionTest.java new file mode 100644 index 00000000000..a506495ed5e --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/tls/Quicv2PacketEncryptionTest.java @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.quic.QuicKeyUnavailableException; +import jdk.internal.net.quic.QuicTLSContext; +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicTransportException; +import org.testng.annotations.Test; +import sun.security.ssl.QuicTLSEngineImpl; +import sun.security.ssl.QuicTLSEngineImplAccessor; + +import javax.crypto.AEADBadTagException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.SecretKeySpec; +import javax.net.ssl.SSLContext; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.util.HexFormat; +import java.util.function.IntFunction; + +import static jdk.internal.net.quic.QuicVersion.QUIC_V2; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +/** + * @test + * @library /test/lib + * @modules java.base/sun.security.ssl + * java.base/jdk.internal.net.quic + * @build java.base/sun.security.ssl.QuicTLSEngineImplAccessor + * @summary known-answer test for packet encryption and decryption with Quic v2 + * @run testng/othervm Quicv2PacketEncryptionTest + */ +public class Quicv2PacketEncryptionTest { + + // RFC 9369, appendix A + private static final String INITIAL_DCID = "8394c8f03e515708"; + // section A.2 + // header includes 4-byte packet number 2 + private static final String INITIAL_C_HEADER = "d36b3343cf088394c8f03e5157080000449e00000002"; + private static final int INITIAL_C_PAYLOAD_OFFSET = INITIAL_C_HEADER.length() / 2; + private static final int INITIAL_C_PN_OFFSET = INITIAL_C_PAYLOAD_OFFSET - 4; + private static final int INITIAL_C_PN = 2; + // payload is zero-padded to 1162 bytes, not shown here + private static final String INITIAL_C_PAYLOAD = + "060040f1010000ed0303ebf8fa56f129"+"39b9584a3896472ec40bb863cfd3e868" + + "04fe3a47f06a2b69484c000004130113"+"02010000c000000010000e00000b6578" + + "616d706c652e636f6dff01000100000a"+"00080006001d00170018001000070005" + + "04616c706e0005000501000000000033"+"00260024001d00209370b2c9caa47fba" + + "baf4559fedba753de171fa71f50f1ce1"+"5d43e994ec74d748002b000302030400" + + "0d0010000e0403050306030203080408"+"050806002d00020101001c0002400100" + + "3900320408ffffffffffffffff050480"+"00ffff07048000ffff08011001048000" + + "75300901100f088394c8f03e51570806"+"048000ffff"; + private static final int INITIAL_C_PAYLOAD_LENGTH = 1162; + private static final String ENCRYPTED_C_PAYLOAD = + "d76b3343cf088394c8f03e5157080000"+"449ea0c95e82ffe67b6abcdb4298b485" + + "dd04de806071bf03dceebfa162e75d6c"+"96058bdbfb127cdfcbf903388e99ad04" + + "9f9a3dd4425ae4d0992cfff18ecf0fdb"+"5a842d09747052f17ac2053d21f57c5d" + + "250f2c4f0e0202b70785b7946e992e58"+"a59ac52dea6774d4f03b55545243cf1a" + + "12834e3f249a78d395e0d18f4d766004"+"f1a2674802a747eaa901c3f10cda5500" + + "cb9122faa9f1df66c392079a1b40f0de"+"1c6054196a11cbea40afb6ef5253cd68" + + "18f6625efce3b6def6ba7e4b37a40f77"+"32e093daa7d52190935b8da58976ff33" + + "12ae50b187c1433c0f028edcc4c2838b"+"6a9bfc226ca4b4530e7a4ccee1bfa2a3" + + "d396ae5a3fb512384b2fdd851f784a65"+"e03f2c4fbe11a53c7777c023462239dd" + + "6f7521a3f6c7d5dd3ec9b3f233773d4b"+"46d23cc375eb198c63301c21801f6520" + + "bcfb7966fc49b393f0061d974a2706df"+"8c4a9449f11d7f3d2dcbb90c6b877045" + + "636e7c0c0fe4eb0f697545460c806910"+"d2c355f1d253bc9d2452aaa549e27a1f" + + "ac7cf4ed77f322e8fa894b6a83810a34"+"b361901751a6f5eb65a0326e07de7c12" + + "16ccce2d0193f958bb3850a833f7ae43"+"2b65bc5a53975c155aa4bcb4f7b2c4e5" + + "4df16efaf6ddea94e2c50b4cd1dfe060"+"17e0e9d02900cffe1935e0491d77ffb4" + + "fdf85290fdd893d577b1131a610ef6a5"+"c32b2ee0293617a37cbb08b847741c3b" + + "8017c25ca9052ca1079d8b78aebd4787"+"6d330a30f6a8c6d61dd1ab5589329de7" + + "14d19d61370f8149748c72f132f0fc99"+"f34d766c6938597040d8f9e2bb522ff9" + + "9c63a344d6a2ae8aa8e51b7b90a4a806"+"105fcbca31506c446151adfeceb51b91" + + "abfe43960977c87471cf9ad4074d30e1"+"0d6a7f03c63bd5d4317f68ff325ba3bd" + + "80bf4dc8b52a0ba031758022eb025cdd"+"770b44d6d6cf0670f4e990b22347a7db" + + "848265e3e5eb72dfe8299ad7481a4083"+"22cac55786e52f633b2fb6b614eaed18" + + "d703dd84045a274ae8bfa73379661388"+"d6991fe39b0d93debb41700b41f90a15" + + "c4d526250235ddcd6776fc77bc97e7a4"+"17ebcb31600d01e57f32162a8560cacc" + + "7e27a096d37a1a86952ec71bd89a3e9a"+"30a2a26162984d7740f81193e8238e61" + + "f6b5b984d4d3dfa033c1bb7e4f0037fe"+"bf406d91c0dccf32acf423cfa1e70710" + + "10d3f270121b493ce85054ef58bada42"+"310138fe081adb04e2bd901f2f13458b" + + "3d6758158197107c14ebb193230cd115"+"7380aa79cae1374a7c1e5bbcb80ee23e" + + "06ebfde206bfb0fcbc0edc4ebec30966"+"1bdd908d532eb0c6adc38b7ca7331dce" + + "8dfce39ab71e7c32d318d136b6100671"+"a1ae6a6600e3899f31f0eed19e3417d1" + + "34b90c9058f8632c798d4490da498730"+"7cba922d61c39805d072b589bd52fdf1" + + "e86215c2d54e6670e07383a27bbffb5a"+"ddf47d66aa85a0c6f9f32e59d85a44dd" + + "5d3b22dc2be80919b490437ae4f36a0a"+"e55edf1d0b5cb4e9a3ecabee93dfc6e3" + + "8d209d0fa6536d27a5d6fbb17641cde2"+"7525d61093f1b28072d111b2b4ae5f89" + + "d5974ee12e5cf7d5da4d6a31123041f3"+"3e61407e76cffcdcfd7e19ba58cf4b53" + + "6f4c4938ae79324dc402894b44faf8af"+"bab35282ab659d13c93f70412e85cb19" + + "9a37ddec600545473cfb5a05e08d0b20"+"9973b2172b4d21fb69745a262ccde96b" + + "a18b2faa745b6fe189cf772a9f84cbfc"; + // section A.3 + // header includes 2-byte packet number 1 + private static final String INITIAL_S_HEADER = "d16b3343cf0008f067a5502a4262b50040750001"; + private static final int INITIAL_S_PAYLOAD_OFFSET = INITIAL_S_HEADER.length() / 2; + private static final int INITIAL_S_PN_OFFSET = INITIAL_S_PAYLOAD_OFFSET - 2; + private static final int INITIAL_S_PN = 1; + // complete packet, no padding + private static final String INITIAL_S_PAYLOAD = + "02000000000600405a020000560303ee"+"fce7f7b37ba1d1632e96677825ddf739" + + "88cfc79825df566dc5430b9a045a1200"+"130100002e00330024001d00209d3c94" + + "0d89690b84d08a60993c144eca684d10"+"81287c834d5311bcf32bb9da1a002b00" + + "020304"; + private static final int INITIAL_S_PAYLOAD_LENGTH = INITIAL_S_PAYLOAD.length() / 2; + private static final String ENCRYPTED_S_PAYLOAD = + "dc6b3343cf0008f067a5502a4262b500"+"4075d92faaf16f05d8a4398c47089698" + + "baeea26b91eb761d9b89237bbf872630"+"17915358230035f7fd3945d88965cf17" + + "f9af6e16886c61bfc703106fbaf3cb4c"+"fa52382dd16a393e42757507698075b2" + + "c984c707f0a0812d8cd5a6881eaf21ce"+"da98f4bd23f6fe1a3e2c43edd9ce7ca8" + + "4bed8521e2e140"; + + // section A.4 + private static final String SIGNED_RETRY = + "cf6b3343cf0008f067a5502a4262b574"+"6f6b656ec8646ce8bfe33952d9555436" + + "65dcc7b6"; + + // section A.5 + public static final String ONERTT_SECRET = "9ac312a7f877468ebe69422748ad00a1" + + "5443f18203a07d6060f688f30f21632b"; + private static final String ONERTT_HEADER = "4200bff4"; + private static final int ONERTT_PAYLOAD_OFFSET = ONERTT_HEADER.length() / 2; + private static final int ONERTT_PN_OFFSET = 1; + private static final int ONERTT_PN = 654360564; + // payload is zero-padded to 1162 bytes, not shown here + private static final String ONERTT_PAYLOAD = + "01"; + private static final int ONERTT_PAYLOAD_LENGTH = + ONERTT_PAYLOAD.length() / 2; + private static final String ENCRYPTED_ONERTT_PAYLOAD = + "5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"; + + private static final class FixedHeaderContent implements IntFunction { + private final ByteBuffer header; + private FixedHeaderContent(ByteBuffer header) { + this.header = header; + } + + @Override + public ByteBuffer apply(final int keyphase) { + // ignore keyphase + return this.header; + } + } + + @Test + public void testEncryptClientInitialPacket() throws Exception { + QuicTLSEngine clientEngine = getQuicV2Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + clientEngine.deriveInitialKeys(QUIC_V2, dcid); + final int packetLen = INITIAL_C_PAYLOAD_OFFSET + INITIAL_C_PAYLOAD_LENGTH + 16; + final ByteBuffer packet = ByteBuffer.allocate(packetLen); + packet.put(HexFormat.of().parseHex(INITIAL_C_HEADER)); + packet.put(HexFormat.of().parseHex(INITIAL_C_PAYLOAD)); + + final ByteBuffer header = packet.slice(0, INITIAL_C_PAYLOAD_OFFSET).asReadOnlyBuffer(); + final ByteBuffer payload = packet.slice(INITIAL_C_PAYLOAD_OFFSET, INITIAL_C_PAYLOAD_LENGTH).asReadOnlyBuffer(); + + packet.position(INITIAL_C_PAYLOAD_OFFSET); + clientEngine.encryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_C_PN, new FixedHeaderContent(header), payload, packet); + protect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_C_PN_OFFSET, INITIAL_C_PAYLOAD_OFFSET - INITIAL_C_PN_OFFSET, clientEngine, 0x0f); + + assertEquals(HexFormat.of().formatHex(packet.array()), ENCRYPTED_C_PAYLOAD); + } + + @Test + public void testDecryptClientInitialPacket() throws Exception { + QuicTLSEngine serverEngine = getQuicV2Engine(SSLContext.getDefault(), false); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + serverEngine.deriveInitialKeys(QUIC_V2, dcid); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_C_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_C_PN_OFFSET, INITIAL_C_PAYLOAD_OFFSET - INITIAL_C_PN_OFFSET, serverEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_C_PAYLOAD_OFFSET); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_C_PN, -1, src, INITIAL_C_PAYLOAD_OFFSET, packet); + + String expectedContents = INITIAL_C_HEADER + INITIAL_C_PAYLOAD; + + assertEquals(HexFormat.of().formatHex(packet.array()).substring(0, expectedContents.length()), expectedContents); + } + + @Test(expectedExceptions = AEADBadTagException.class) + public void testDecryptClientInitialPacketBadTag() throws Exception { + QuicTLSEngine serverEngine = getQuicV2Engine(SSLContext.getDefault(), false); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + serverEngine.deriveInitialKeys(QUIC_V2, dcid); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_C_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_C_PN_OFFSET, INITIAL_C_PAYLOAD_OFFSET - INITIAL_C_PN_OFFSET, serverEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_C_PAYLOAD_OFFSET); + + // change one byte of AEAD tag + packet.put(packet.limit() - 1, (byte)0); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_C_PN, -1, src, INITIAL_C_PAYLOAD_OFFSET, packet); + fail("Decryption should have failed"); + } + + @Test + public void testEncryptServerInitialPacket() throws Exception { + QuicTLSEngine serverEngine = getQuicV2Engine(SSLContext.getDefault(), false); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + serverEngine.deriveInitialKeys(QUIC_V2, dcid); + + final int packetLen = INITIAL_S_PAYLOAD_OFFSET + INITIAL_S_PAYLOAD_LENGTH + 16; + final ByteBuffer packet = ByteBuffer.allocate(packetLen); + packet.put(HexFormat.of().parseHex(INITIAL_S_HEADER)); + packet.put(HexFormat.of().parseHex(INITIAL_S_PAYLOAD)); + + final ByteBuffer header = packet.slice(0, INITIAL_S_PAYLOAD_OFFSET).asReadOnlyBuffer(); + final ByteBuffer payload = packet.slice(INITIAL_S_PAYLOAD_OFFSET, INITIAL_S_PAYLOAD_LENGTH).asReadOnlyBuffer(); + + packet.position(INITIAL_S_PAYLOAD_OFFSET); + serverEngine.encryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, new FixedHeaderContent(header), payload, packet); + protect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, serverEngine, 0x0f); + + assertEquals(HexFormat.of().formatHex(packet.array()), ENCRYPTED_S_PAYLOAD); + } + + @Test + public void testDecryptServerInitialPacket() throws Exception { + QuicTLSEngine clientEngine = getQuicV2Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + clientEngine.deriveInitialKeys(QUIC_V2, dcid); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_S_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, clientEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_S_PAYLOAD_OFFSET); + + clientEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, -1, src, INITIAL_S_PAYLOAD_OFFSET, packet); + + String expectedContents = INITIAL_S_HEADER + INITIAL_S_PAYLOAD; + + assertEquals(HexFormat.of().formatHex(packet.array()).substring(0, expectedContents.length()), expectedContents); + } + + @Test + public void testDecryptServerInitialPacketTwice() throws Exception { + // verify that decrypting the same packet twice does not throw + QuicTLSEngine clientEngine = getQuicV2Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + clientEngine.deriveInitialKeys(QUIC_V2, dcid); + + // attempt 1 + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_S_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, clientEngine, 0x0f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_S_PAYLOAD_OFFSET); + clientEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, -1, src, INITIAL_S_PAYLOAD_OFFSET, packet); + + // attempt 2 + packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_S_PAYLOAD)); + // must not throw + unprotect(QuicTLSEngine.KeySpace.INITIAL, packet, INITIAL_S_PN_OFFSET, INITIAL_S_PAYLOAD_OFFSET - INITIAL_S_PN_OFFSET, clientEngine, 0x0f); + src = packet.asReadOnlyBuffer(); + packet.position(INITIAL_S_PAYLOAD_OFFSET); + // must not throw + clientEngine.decryptPacket(QuicTLSEngine.KeySpace.INITIAL, INITIAL_S_PN, -1, src, INITIAL_S_PAYLOAD_OFFSET, packet); + } + + @Test + public void testSignRetry() throws NoSuchAlgorithmException, ShortBufferException, QuicTransportException { + QuicTLSEngine clientEngine = getQuicV2Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + + ByteBuffer packet = ByteBuffer.allocate(SIGNED_RETRY.length() / 2); + packet.put(HexFormat.of().parseHex(SIGNED_RETRY), 0, SIGNED_RETRY.length() / 2 - 16); + + ByteBuffer src = packet.asReadOnlyBuffer(); + src.limit(src.position()); + src.position(0); + + clientEngine.signRetryPacket(QUIC_V2, dcid, src, packet); + + assertEquals(HexFormat.of().formatHex(packet.array()), SIGNED_RETRY); + } + + @Test + public void testVerifyRetry() throws NoSuchAlgorithmException, AEADBadTagException, QuicTransportException { + QuicTLSEngine clientEngine = getQuicV2Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(SIGNED_RETRY)); + + clientEngine.verifyRetryPacket(QUIC_V2, dcid, packet); + } + + @Test(expectedExceptions = AEADBadTagException.class) + public void testVerifyBadRetry() throws NoSuchAlgorithmException, AEADBadTagException, QuicTransportException { + QuicTLSEngine clientEngine = getQuicV2Engine(SSLContext.getDefault(), true); + ByteBuffer dcid = ByteBuffer.wrap(HexFormat.of().parseHex(INITIAL_DCID)); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(SIGNED_RETRY)); + + // change one byte of AEAD tag + packet.put(packet.limit() - 1, (byte)0); + clientEngine.verifyRetryPacket(QUIC_V2, dcid, packet); + fail("Verification should have failed"); + } + + @Test + public void testEncryptChaCha() throws Exception { + QuicTLSEngineImpl clientEngine = (QuicTLSEngineImpl) getQuicV2Engine(SSLContext.getDefault(), true); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V2, clientEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", true); + + final int packetLen = ONERTT_PAYLOAD_OFFSET + ONERTT_PAYLOAD_LENGTH + 16; + final ByteBuffer packet = ByteBuffer.allocate(packetLen); + packet.put(HexFormat.of().parseHex(ONERTT_HEADER)); + packet.put(HexFormat.of().parseHex(ONERTT_PAYLOAD)); + + final ByteBuffer header = packet.slice(0, ONERTT_PAYLOAD_OFFSET).asReadOnlyBuffer(); + final ByteBuffer payload = packet.slice(ONERTT_PAYLOAD_OFFSET, ONERTT_PAYLOAD_LENGTH).asReadOnlyBuffer(); + + packet.position(ONERTT_PAYLOAD_OFFSET); + clientEngine.encryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN , new FixedHeaderContent(header), payload, packet); + protect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, clientEngine, 0x1f); + + assertEquals(HexFormat.of().formatHex(packet.array()), ENCRYPTED_ONERTT_PAYLOAD); + } + + @Test + public void testDecryptChaCha() throws Exception { + QuicTLSEngineImpl serverEngine = (QuicTLSEngineImpl) getQuicV2Engine(SSLContext.getDefault(), false); + // mark the TLS handshake as FINISHED + QuicTLSEngineImplAccessor.completeHandshake(serverEngine); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V2, serverEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", false); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, 0, src, ONERTT_PAYLOAD_OFFSET, packet); + + String expectedContents = ONERTT_HEADER + ONERTT_PAYLOAD; + + assertEquals(HexFormat.of().formatHex(packet.array()).substring(0, expectedContents.length()), expectedContents); + } + + @Test + public void testDecryptChaChaTwice() throws Exception { + // verify that decrypting the same packet twice does not throw + QuicTLSEngineImpl serverEngine = (QuicTLSEngineImpl) getQuicV2Engine(SSLContext.getDefault(), false); + // mark the TLS handshake as FINISHED + QuicTLSEngineImplAccessor.completeHandshake(serverEngine); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V2, serverEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", false); + + // attempt 1 + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + final int keyphase = 0; + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, keyphase, src, ONERTT_PAYLOAD_OFFSET, packet); + + // attempt 2 + packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + // must not throw + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + // must not throw + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, keyphase, src, ONERTT_PAYLOAD_OFFSET, packet); + } + + @Test(expectedExceptions = AEADBadTagException.class) + public void testDecryptChaChaBadTag() throws Exception { + QuicTLSEngineImpl serverEngine = (QuicTLSEngineImpl) getQuicV2Engine(SSLContext.getDefault(), false); + // mark the TLS handshake as FINISHED + QuicTLSEngineImplAccessor.completeHandshake(serverEngine); + SecretKey key = new SecretKeySpec(HexFormat.of().parseHex(ONERTT_SECRET), 0, 32, "ChaCha20-Poly1305"); + QuicTLSEngineImplAccessor.testDeriveOneRTTKeys(QUIC_V2, serverEngine, key, key, "TLS_CHACHA20_POLY1305_SHA256", false); + + ByteBuffer packet = ByteBuffer.wrap(HexFormat.of().parseHex(ENCRYPTED_ONERTT_PAYLOAD)); + unprotect(QuicTLSEngine.KeySpace.ONE_RTT, packet, ONERTT_PN_OFFSET, ONERTT_PAYLOAD_OFFSET - ONERTT_PN_OFFSET, serverEngine, 0x1f); + ByteBuffer src = packet.asReadOnlyBuffer(); + packet.position(ONERTT_PAYLOAD_OFFSET); + + // change one byte of AEAD tag + packet.put(packet.limit() - 1, (byte)0); + + serverEngine.decryptPacket(QuicTLSEngine.KeySpace.ONE_RTT, ONERTT_PN, 0, src, ONERTT_PAYLOAD_OFFSET, packet); + fail("Decryption should have failed"); + } + + + private void protect(QuicTLSEngine.KeySpace space, ByteBuffer buffer, int packetNumberStart, + int packetNumberLength, QuicTLSEngine tlsEngine, int headersMask) + throws QuicKeyUnavailableException, QuicTransportException { + ByteBuffer sample = buffer.slice(packetNumberStart + 4, 16); + ByteBuffer encryptedSample = tlsEngine.computeHeaderProtectionMask(space, false, sample); + byte headers = buffer.get(0); + headers ^= encryptedSample.get() & headersMask; + buffer.put(0, headers); + maskPacketNumber(buffer, packetNumberStart, packetNumberLength, encryptedSample); + } + + private void unprotect(QuicTLSEngine.KeySpace keySpace, ByteBuffer buffer, int packetNumberStart, + int packetNumberLength, QuicTLSEngine tlsEngine, int headersMask) + throws QuicKeyUnavailableException, QuicTransportException { + ByteBuffer sample = buffer.slice(packetNumberStart + 4, 16); + ByteBuffer encryptedSample = tlsEngine.computeHeaderProtectionMask(keySpace, true, sample); + byte headers = buffer.get(0); + headers ^= encryptedSample.get() & headersMask; + buffer.put(0, headers); + maskPacketNumber(buffer, packetNumberStart, packetNumberLength, encryptedSample); + } + + private void maskPacketNumber(ByteBuffer buffer, int packetNumberStart, int packetNumberLength, ByteBuffer mask) { + for (int i = 0; i < packetNumberLength; i++) { + buffer.put(packetNumberStart + i, (byte)(buffer.get(packetNumberStart + i) ^ mask.get())); + } + } + + // returns a QuicTLSEngine with only Quic version 2 enabled + private QuicTLSEngine getQuicV2Engine(SSLContext context, boolean mode) { + final QuicTLSContext quicTLSContext = new QuicTLSContext(context); + final QuicTLSEngine engine = quicTLSContext.createEngine(); + engine.setUseClientMode(mode); + return engine; + } +} diff --git a/test/jdk/java/net/httpclient/quic/tls/java.base/sun/security/ssl/QuicTLSEngineImplAccessor.java b/test/jdk/java/net/httpclient/quic/tls/java.base/sun/security/ssl/QuicTLSEngineImplAccessor.java new file mode 100644 index 00000000000..e9a57310e0b --- /dev/null +++ b/test/jdk/java/net/httpclient/quic/tls/java.base/sun/security/ssl/QuicTLSEngineImplAccessor.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package sun.security.ssl; + +import javax.crypto.SecretKey; +import java.io.IOException; +import java.lang.reflect.Field; +import java.security.GeneralSecurityException; + +import jdk.internal.net.quic.QuicTLSEngine; +import jdk.internal.net.quic.QuicVersion; + +public final class QuicTLSEngineImplAccessor { + // visible for testing + public static void testDeriveOneRTTKeys(QuicVersion version, + QuicTLSEngineImpl engine, + SecretKey client_application_traffic_secret_0, + SecretKey server_application_traffic_secret_0, + String negotiatedCipherSuite, + boolean clientMode) + throws IOException, GeneralSecurityException + { + engine.deriveOneRTTKeys(version, client_application_traffic_secret_0, + server_application_traffic_secret_0, + CipherSuite.valueOf(negotiatedCipherSuite), + clientMode); + } + + // visible for testing + public static void completeHandshake(final QuicTLSEngineImpl engine) { + try { + final Field f = QuicTLSEngineImpl.class.getDeclaredField("handshakeState"); + f.setAccessible(true); + f.set(engine, QuicTLSEngine.HandshakeState.HANDSHAKE_CONFIRMED); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/test/jdk/java/net/httpclient/ssltest/CertificateTest.java b/test/jdk/java/net/httpclient/ssltest/CertificateTest.java index 96ba5c05ed3..1d0502c9ce3 100644 --- a/test/jdk/java/net/httpclient/ssltest/CertificateTest.java +++ b/test/jdk/java/net/httpclient/ssltest/CertificateTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -39,8 +39,9 @@ import jdk.test.lib.security.SSLContextBuilder; /* * @test - * @library /test/lib - * @build Server CertificateTest + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build Server CertificateTest jdk.httpclient.test.lib.common.TestServerConfigurator + * @modules java.net.http/jdk.internal.net.http.common * @run main/othervm CertificateTest GOOD_CERT expectSuccess * @run main/othervm CertificateTest BAD_CERT expectFailure * @run main/othervm diff --git a/test/jdk/java/net/httpclient/ssltest/Server.java b/test/jdk/java/net/httpclient/ssltest/Server.java index fc954016991..d2d98268e14 100644 --- a/test/jdk/java/net/httpclient/ssltest/Server.java +++ b/test/jdk/java/net/httpclient/ssltest/Server.java @@ -22,6 +22,7 @@ */ import com.sun.net.httpserver.*; +import jdk.httpclient.test.lib.common.TestServerConfigurator; import java.io.*; import java.net.InetAddress; @@ -44,9 +45,9 @@ public class Server { // response with a short text string. public Server(SSLContext ctx) throws Exception { initLogger(); - Configurator cfg = new Configurator(ctx); InetSocketAddress addr = new InetSocketAddress( InetAddress.getLoopbackAddress(), 0); + Configurator cfg = new Configurator(addr.getAddress(), ctx); server = HttpsServer.create(addr, 10); server.setHttpsConfigurator(cfg); server.createContext("/", new MyHandler()); @@ -94,15 +95,20 @@ public class Server { } class Configurator extends HttpsConfigurator { - public Configurator(SSLContext ctx) throws Exception { + private final InetAddress serverAddr; + + public Configurator(InetAddress addr, SSLContext ctx) throws Exception { super(ctx); + this.serverAddr = addr; } + @Override public void configure(HttpsParameters params) { SSLParameters p = getSSLContext().getDefaultSSLParameters(); for (String cipher : p.getCipherSuites()) System.out.println("Cipher: " + cipher); System.err.println("Params = " + p); + TestServerConfigurator.addSNIMatcher(this.serverAddr, p); params.setSSLParameters(p); } } diff --git a/test/jdk/java/net/httpclient/ssltest/TlsVersionTest.java b/test/jdk/java/net/httpclient/ssltest/TlsVersionTest.java index 322a3e3789b..2e364461793 100644 --- a/test/jdk/java/net/httpclient/ssltest/TlsVersionTest.java +++ b/test/jdk/java/net/httpclient/ssltest/TlsVersionTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -32,14 +32,19 @@ import jdk.test.lib.net.URIBuilder; import jdk.test.lib.security.KeyEntry; import jdk.test.lib.security.KeyStoreUtils; import jdk.test.lib.security.SSLContextBuilder; + +import static java.net.http.HttpClient.Version.HTTP_1_1; +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.net.http.HttpClient.Version.HTTP_3; import static java.net.http.HttpResponse.BodyHandlers.ofString; import static java.net.http.HttpClient.Builder.NO_PROXY; /* * @test * @bug 8239594 8239595 - * @library /test/lib - * @build Server TlsVersionTest + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build Server TlsVersionTest jdk.httpclient.test.lib.common.TestServerConfigurator + * @modules java.net.http/jdk.internal.net.http.common * @run main/othervm * -Djdk.internal.httpclient.disableHostnameVerification * TlsVersionTest false @@ -121,11 +126,12 @@ public class TlsVersionTest { System.out.println("Making request to " + serverURI.getPath()); SSLContext ctx = getClientSSLContext(cert); HttpClient client = HttpClient.newBuilder() + .version(HTTP_2) .proxy(NO_PROXY) .sslContext(ctx) .build(); - for (var version : List.of(HttpClient.Version.HTTP_2, HttpClient.Version.HTTP_1_1)) { + for (var version : List.of(HTTP_3, HTTP_2, HTTP_1_1)) { HttpRequest request = HttpRequest.newBuilder(serverURI) .version(version) .GET() diff --git a/test/jdk/java/net/httpclient/websocket/HandshakeUrlEncodingTest.java b/test/jdk/java/net/httpclient/websocket/HandshakeUrlEncodingTest.java index 5065563c9ec..6d242a596e7 100644 --- a/test/jdk/java/net/httpclient/websocket/HandshakeUrlEncodingTest.java +++ b/test/jdk/java/net/httpclient/websocket/HandshakeUrlEncodingTest.java @@ -205,4 +205,3 @@ public class HandshakeUrlEncodingTest { } } } - diff --git a/test/jdk/java/net/httpclient/websocket/ReaderDriver.java b/test/jdk/java/net/httpclient/websocket/ReaderDriver.java index 61829a5a010..0b874792b4f 100644 --- a/test/jdk/java/net/httpclient/websocket/ReaderDriver.java +++ b/test/jdk/java/net/httpclient/websocket/ReaderDriver.java @@ -25,6 +25,6 @@ * @test * @bug 8159053 * @modules java.net.http/jdk.internal.net.http.websocket:open - * @run testng/othervm --add-reads java.net.http=ALL-UNNAMED java.net.http/jdk.internal.net.http.websocket.ReaderTest + * @run testng/othervm/timeout=240 --add-reads java.net.http=ALL-UNNAMED java.net.http/jdk.internal.net.http.websocket.ReaderTest */ public final class ReaderDriver { } diff --git a/test/jdk/java/net/httpclient/whitebox/AltSvcFrameTest.java b/test/jdk/java/net/httpclient/whitebox/AltSvcFrameTest.java new file mode 100644 index 00000000000..949c950fd76 --- /dev/null +++ b/test/jdk/java/net/httpclient/whitebox/AltSvcFrameTest.java @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.List; +import java.util.Optional; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; + +import jdk.httpclient.test.lib.http2.BodyOutputStream; +import jdk.httpclient.test.lib.http2.Http2Handler; +import jdk.httpclient.test.lib.http2.Http2TestExchange; +import jdk.httpclient.test.lib.http2.Http2TestExchangeImpl; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import jdk.httpclient.test.lib.http2.Http2TestServerConnection; +import jdk.internal.net.http.AltServicesRegistry; +import jdk.internal.net.http.HttpClientAccess; +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.frame.AltSvcFrame; +import jdk.test.lib.net.SimpleSSLContext; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; +import static java.net.http.HttpResponse.BodyHandlers.ofString; +import static jdk.internal.net.http.AltServicesRegistry.AltService; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +/* + * @test + * @summary This test verifies alt-svc registry updation for frames + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build java.net.http/jdk.internal.net.http.HttpClientAccess + * jdk.httpclient.test.lib.http2.Http2TestServer + * @modules java.net.http/jdk.internal.net.http + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.frame + * java.net.http/jdk.internal.net.http.hpack + * java.base/jdk.internal.util + * java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.packets + * java.net.http/jdk.internal.net.http.quic.frames + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.logging + * java.base/sun.net.www.http + * java.base/sun.net.www + * java.base/sun.net + * @run testng/othervm + * -Dtest.requiresHost=true + * -Djdk.httpclient.HttpClient.log=headers + * -Djdk.internal.httpclient.disableHostnameVerification + * -Djdk.internal.httpclient.debug=true + * AltSvcFrameTest + */ + + +public class AltSvcFrameTest { + + private static final String IGNORED_HOST = "www.should-be-ignored.com"; + private static final String ALT_SVC_TO_IGNORE = "h3=\"" + IGNORED_HOST + ":443\""; + + private static final String ACCEPTED_HOST = "www.example.com"; + private static final String ALT_SVC_TO_ACCEPT = "h3=\"" + ACCEPTED_HOST + ":443\""; + + private static final String STREAM_0_ACCEPTED_HOST = "jdk.java.net"; + private static final String STREAM_0_ALT_SVC_TO_ACCEPT = "h3=\"" + STREAM_0_ACCEPTED_HOST + ":443\""; + + private static final String FOO_BAR_ORIGIN = "https://www.foo-bar.hello-world:443"; + + static Http2TestServer https2Server; + static String https2URI; + static HttpClient client; + static SSLContext server; + + @BeforeTest + public void setUp() throws Exception { + server = SimpleSSLContext.getContext("TLS"); + getRegistry(); + https2Server = new Http2TestServer("localhost", true, server); + https2Server.addHandler(new AltSvcFrameTestHandler(), "/"); + https2Server.setExchangeSupplier(AltSvcFrameTest.CFTHttp2TestExchange::new); + https2Server.start(); + https2URI = "https://" + https2Server.serverAuthority() + "/"; + + + } + + static AltServicesRegistry getRegistry() { + client = HttpClient.newBuilder() + .sslContext(server) + .version(HttpClient.Version.HTTP_2) + .build(); + return HttpClientAccess.getRegistry(client); + } + + /* + * Verify handling of alt-svc frame on a stream other than stream 0 + */ + @Test + public void testNonStream0AltSvcFrame() throws URISyntaxException, IOException, InterruptedException { + AltServicesRegistry registry = getRegistry(); + HttpRequest request = HttpRequest.newBuilder(new URI(https2URI)) + .GET() + .build(); + HttpResponse response = client.send(request, ofString()); + assertEquals(response.statusCode(), 200, "unexpected response code"); + final List services = registry.lookup(URI.create(https2URI), "h3").toList(); + System.out.println("Alt services in registry for " + https2URI + " = " + services); + final boolean hasExpectedAltSvc = services.stream().anyMatch( + alt -> alt.alpn().equals("h3") + && alt.host().contains(ACCEPTED_HOST)); + assertTrue(hasExpectedAltSvc, "missing entry in alt service registry for origin: " + https2URI); + } + + /* + * Verify handling of alt-svc frame on stream 0 of the connection + */ + @Test + public void testStream0AltSvcFrame() throws URISyntaxException, IOException, InterruptedException { + AltServicesRegistry registry = getRegistry(); + HttpRequest request = HttpRequest.newBuilder(new URI(https2URI + "?altsvc-on-stream-0")) + .GET() + .build(); + HttpResponse response = client.send(request, ofString()); + assertEquals(response.statusCode(), 200, "unexpected response code"); + final List services = registry.lookup( + URI.create(FOO_BAR_ORIGIN), "h3").toList(); + System.out.println("Alt services in registry for " + FOO_BAR_ORIGIN + " = " + services); + final boolean containsIgnoredHost = services.stream().anyMatch( + alt -> alt.alpn().equals("h3") + && alt.host().contains(IGNORED_HOST)); + assertFalse(containsIgnoredHost, "unexpected alt service in the registry for origin: " + + FOO_BAR_ORIGIN); + + final List svcs = registry.lookup(URI.create(https2URI), "h3").toList(); + System.out.println("Alt services in registry for " + https2URI + " = " + svcs); + final boolean hasExpectedAltSvc = svcs.stream().anyMatch( + alt -> alt.alpn().equals("h3") + && alt.host().contains(STREAM_0_ACCEPTED_HOST)); + assertTrue(hasExpectedAltSvc, "missing entry in alt service registry for origin: " + https2URI); + } + + static class AltSvcFrameTestHandler implements Http2Handler { + + @Override + public void handle(Http2TestExchange t) throws IOException { + try (InputStream is = t.getRequestBody(); + OutputStream os = t.getResponseBody()) { + byte[] bytes = is.readAllBytes(); + t.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } + } + } + + // A custom Http2TestExchangeImpl that overrides sendResponseHeaders to + // allow headers to be sent with AltSvcFrame. + static class CFTHttp2TestExchange extends Http2TestExchangeImpl { + + CFTHttp2TestExchange(int streamid, String method, HttpHeaders reqheaders, + HttpHeadersBuilder rspheadersBuilder, URI uri, InputStream is, + SSLSession sslSession, BodyOutputStream os, + Http2TestServerConnection conn, boolean pushAllowed) { + super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession, + os, conn, pushAllowed); + + } + + @Override + public void sendResponseHeaders(int rCode, long responseLength) throws IOException { + final String reqQuery = getRequestURI().getQuery(); + if (reqQuery != null && reqQuery.contains("altsvc-on-stream-0")) { + final InetSocketAddress addr = this.getLocalAddress(); + final String connectionOrigin = "https://" + addr.getAddress().getHostAddress() + + ":" + addr.getPort(); + // send one alt-svc on stream 0 with the same Origin as the connection's Origin + enqueueAltSvcFrame(new AltSvcFrame(0, 0, + // the Origin for which the alt-svc is being advertised + Optional.of(connectionOrigin), + STREAM_0_ALT_SVC_TO_ACCEPT)); + + // send one alt-svc on stream 0 with a different Origin than the connection's Origin + enqueueAltSvcFrame(new AltSvcFrame(0, 0, + // the Origin for which the alt-svc is being advertised + Optional.of(FOO_BAR_ORIGIN), + ALT_SVC_TO_IGNORE)); + } else { + // send alt-svc on non-zero stream id. + // for non-zero stream id, as per spec, the origin is inferred from the stream's origin + // by the HTTP client + enqueueAltSvcFrame(new AltSvcFrame(streamid, 0, Optional.empty(), ALT_SVC_TO_ACCEPT)); + } + super.sendResponseHeaders(rCode, responseLength); + System.out.println("Sent response headers " + rCode); + } + + private void enqueueAltSvcFrame(final AltSvcFrame frame) throws IOException { + System.out.println("enqueueing Alt-Svc frame: " + frame); + conn.addToOutputQ(frame); + } + } +} diff --git a/test/jdk/java/net/httpclient/whitebox/AltSvcRegistryTest.java b/test/jdk/java/net/httpclient/whitebox/AltSvcRegistryTest.java new file mode 100644 index 00000000000..47e69338715 --- /dev/null +++ b/test/jdk/java/net/httpclient/whitebox/AltSvcRegistryTest.java @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.internal.net.http.HttpClientAccess; +import jdk.internal.net.http.AltServicesRegistry; +import jdk.test.lib.net.SimpleSSLContext; +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import jdk.httpclient.test.lib.http2.Http2TestServer; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + + +import static jdk.internal.net.http.AltServicesRegistry.AltService; +import static java.net.http.HttpResponse.BodyHandlers.ofString; + +/* + * @test + * @summary This test verifies alt-svc registry updates + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build java.net.http/jdk.internal.net.http.HttpClientAccess + * jdk.httpclient.test.lib.common.HttpServerAdapters + * jdk.httpclient.test.lib.http2.Http2TestServer + * @modules java.net.http/jdk.internal.net.http + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.frame + * java.net.http/jdk.internal.net.http.hpack + * java.base/jdk.internal.net.quic + * java.net.http/jdk.internal.net.http.quic + * java.net.http/jdk.internal.net.http.quic.packets + * java.net.http/jdk.internal.net.http.quic.frames + * java.net.http/jdk.internal.net.http.quic.streams + * java.net.http/jdk.internal.net.http.http3.streams + * java.net.http/jdk.internal.net.http.http3.frames + * java.net.http/jdk.internal.net.http.http3 + * java.net.http/jdk.internal.net.http.qpack + * java.net.http/jdk.internal.net.http.qpack.readers + * java.net.http/jdk.internal.net.http.qpack.writers + * java.logging + * java.base/sun.net.www.http + * java.base/sun.net.www + * java.base/sun.net + * java.base/jdk.internal.util + * @run testng/othervm + * -Dtest.requiresHost=true + * -Djdk.httpclient.HttpClient.log=headers + * -Djdk.internal.httpclient.disableHostnameVerification + * -Djdk.internal.httpclient.debug=true + * AltSvcRegistryTest + */ + + +public class AltSvcRegistryTest implements HttpServerAdapters { + + static HttpTestServer https2Server; + static String https2URI; + static HttpClient client; + static SSLContext server; + + @BeforeTest + public void setUp() throws Exception { + server = SimpleSSLContext.getContext("TLS"); + getRegistry(); + final ExecutorService executor = Executors.newCachedThreadPool(); + https2Server = HttpServerAdapters.HttpTestServer.of( + new Http2TestServer("localhost", true, 0, executor, 50, null, server, true)); + https2Server.addHandler(new AltSvcRegistryTestHandler("https", https2Server), "/"); + https2Server.start(); + https2URI = "https://" + https2Server.serverAuthority() + "/"; + + + } + + static AltServicesRegistry getRegistry() { + client = HttpClient.newBuilder() + .sslContext(server) + .version(HttpClient.Version.HTTP_2) + .build(); + return HttpClientAccess.getRegistry(client); + } + @Test + public void testAltSvcRegistry() throws URISyntaxException, IOException, InterruptedException { + AltServicesRegistry registry = getRegistry(); + HttpRequest request = HttpRequest.newBuilder(new URI(https2URI)) + .GET() + .build(); + HttpResponse response = client.send(request, ofString()); + assert response.statusCode() == 200; + List h3service = registry.lookup(URI.create(https2URI), "h3") + .toList(); + System.out.println("h3 services: " + h3service); + assert h3service.stream().anyMatch( alt -> alt.alpn().equals("h3") + && alt.host().equals("www.example.com") + && alt.port() == 443 + && alt.isPersist()); + assert h3service.stream().anyMatch( alt -> alt.alpn().equals("h3") + && alt.host().equals(request.uri().getHost()) + && alt.port() == 4567 + && !alt.isPersist()); + + List h34service = registry.lookup(URI.create(https2URI), "h3-34") + .toList(); + System.out.println("h3-34 services: " + h34service); + assert h34service.stream().noneMatch( alt -> alt.alpn().equals("h3-34")); + } + + static class AltSvcRegistryTestHandler implements HttpTestHandler { + final String scheme; + final HttpTestServer server; + + AltSvcRegistryTestHandler(String scheme, HttpTestServer server) { + this.scheme = scheme; + this.server = server; + } + + @Override + public void handle(HttpTestExchange t) throws IOException { + try (InputStream is = t.getRequestBody(); + OutputStream os = t.getResponseBody()) { + var altsvc = """ + h3-34=":5678", h3="www.example.com:443"; persist=1, h3=":4567"; persist=0""" ; + t.getResponseHeaders().addHeader("alt-svc", altsvc.trim()); + byte[] bytes = is.readAllBytes(); + t.sendResponseHeaders(200, 10); + os.write(bytes); + } + } + } +} diff --git a/test/jdk/java/net/httpclient/whitebox/java.net.http/jdk/internal/net/http/HttpClientAccess.java b/test/jdk/java/net/httpclient/whitebox/java.net.http/jdk/internal/net/http/HttpClientAccess.java new file mode 100644 index 00000000000..93cf04dbcc4 --- /dev/null +++ b/test/jdk/java/net/httpclient/whitebox/java.net.http/jdk/internal/net/http/HttpClientAccess.java @@ -0,0 +1,10 @@ +package jdk.internal.net.http; + +import java.net.http.HttpClient; + +public class HttpClientAccess { + + public static AltServicesRegistry getRegistry(HttpClient client) { + return ((HttpClientFacade)client).impl.registry(); + } +} diff --git a/test/jdk/jdk/internal/net/http/quic/packets/QuicPacketNumbersTest.java b/test/jdk/jdk/internal/net/http/quic/packets/QuicPacketNumbersTest.java new file mode 100644 index 00000000000..3ec16bbc51c --- /dev/null +++ b/test/jdk/jdk/internal/net/http/quic/packets/QuicPacketNumbersTest.java @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.nio.ByteBuffer; +import jdk.internal.net.http.quic.packets.*; +import java.util.Arrays; + +/** + * @test + * @summary Unit encoding/decoding tests for the QUICPacketNumbers. + * @modules java.net.http/jdk.internal.net.http.quic.packets + */ +public class QuicPacketNumbersTest { + + public static void main(String[] args) throws Exception { + + // Test encoding logic. + // Test values from the upcoming QUIC RFC, Appendix A.2. + checkEncodePacketNumber(0xac5c02L, 0xabe8b3L, + new byte[]{(byte) 0x5c, (byte) 0x02}); + checkEncodePacketNumber(0xace8feL, 0xabe8bcL, + new byte[]{(byte) 0xac, (byte) 0xe8, (byte) 0xfe} + ); + + // Various checks for "None" encodings. + checkEncodePacketNumber(0x00L, -1L, + new byte[]{(byte) 0x00}); + checkEncodePacketNumber(0x05L, -1L, + new byte[]{(byte) 0x05}); + checkEncodePacketNumber(0x7eL, -1L, + new byte[]{(byte) 0x7E}); + checkEncodePacketNumber(0x7fL, -1L, + new byte[]{(byte) 0x00, (byte) 0x7f}); + checkEncodePacketNumber(0x80L, -1L, + new byte[]{(byte) 0x00, (byte) 0x80}); + checkEncodePacketNumber(0xffL, -1L, + new byte[]{(byte) 0x00, (byte) 0xff}); + checkEncodePacketNumber(0x100L, -1L, + new byte[]{(byte) 0x01, (byte) 0x00}); + + // Various checks for a packet 0 that has been ack'd. + checkEncodePacketNumber(0x7fL, 0L, + new byte[]{(byte) 0x7f}); + checkEncodePacketNumber(0x80L, 0L, + new byte[]{(byte) 0x00, (byte) 0x80}); + checkEncodePacketNumber(0xffL, 0L, + new byte[]{(byte) 0x00, (byte) 0xff}); + checkEncodePacketNumber(0x100L, 0L, + new byte[]{(byte) 0x01, (byte) 0x00}); + checkEncodePacketNumber(0x7FFFL, 0L, + new byte[]{(byte) 0x7f, (byte) 0xFF}); + checkEncodePacketNumber(0x8000L, 0L, + new byte[]{(byte) 0x00, (byte) 0x80, (byte) 0x00}); + checkEncodePacketNumber(0x7FFFFFL, 0L, + new byte[]{(byte) 0x7f, (byte) 0xff, (byte) 0xFF}); + checkEncodePacketNumber(0x800000L, 0L, + new byte[]{(byte) 0x00, (byte) 0x80, (byte) 0x00, (byte) 0x00}); + checkEncodePacketNumber(0x7FFFFFFFL, 0L, + new byte[]{(byte) 0x7f, (byte) 0xff, (byte) 0xff, (byte) 0xFF}); + + // Check some similar truncations. + checkEncodePacketNumber(0x0101L, 0x82L, + new byte[]{(byte) 0x01}); + checkEncodePacketNumber(0x0101L, 0x81L, + new byte[]{(byte) 0x01, (byte) 0x01}); + checkEncodePacketNumber(0x10001L, 0xFF82, + new byte[]{(byte) 0x01}); + checkEncodePacketNumber(0x10001L, 0xFF81, + new byte[]{(byte) 0x00, (byte) 0x01}); + checkEncodePacketNumber(0x1000001L, 0xFFFF82, + new byte[]{(byte) 0x01}); + checkEncodePacketNumber(0x1000001L, 0xFFFF81, + new byte[]{(byte) 0x00, (byte) 0x01}); + + // Check that > 4 bytes are not generated. + try { + checkEncodePacketNumber(0x80000000L, 0L, + new byte[]{(byte) 0x00, (byte) 0x80, (byte) 0x00, + (byte) 0x00, (byte) 0x00}); + throw new Exception("Shouldn't encode"); + } catch (RuntimeException e) { + System.out.println("Caught the right exception!"); + } + + // Test decoding logic. + + // Test values from the upcoming QUIC RFC, Appendix A.3. + checkDecodePacketNumber(0xa82f30eaL, + ByteBuffer.wrap(new byte[]{(byte) 0x9b, (byte) 0x32}), + 2, 0xa82f9b32L); + + // TBD: More test values + } + + public static void checkEncodePacketNumber( + long fullPN, long largestAcked, byte[] bytes) throws Exception { + + byte[] answer = QuicPacketNumbers.encodePacketNumber( + fullPN, largestAcked); + + if (!Arrays.equals(answer, bytes)) { + throw new Exception("Encoding Problem"); + } + } + + public static void checkDecodePacketNumber( + long largestPN, ByteBuffer buf, int headerNumBytes, + long answer) throws Exception { + + long result = QuicPacketNumbers.decodePacketNumber( + largestPN, buf, headerNumBytes); + + if (result != answer) { + throw new Exception("Decoding Problem"); + } + } +}