diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java b/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java index cd6b3c0b89f..4fa970ee0eb 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java @@ -69,6 +69,7 @@ import jdk.internal.net.http.common.MinimalFuture; import jdk.internal.net.http.common.SequentialScheduler; import jdk.internal.net.http.common.Utils; import jdk.internal.net.http.common.ValidatingHeadersConsumer; +import jdk.internal.net.http.common.ValidatingHeadersConsumer.Context; import jdk.internal.net.http.frame.ContinuationFrame; import jdk.internal.net.http.frame.DataFrame; import jdk.internal.net.http.frame.ErrorFrame; @@ -89,7 +90,6 @@ import jdk.internal.net.http.hpack.Decoder; import jdk.internal.net.http.hpack.DecodingCallback; import jdk.internal.net.http.hpack.Encoder; import static java.nio.charset.StandardCharsets.UTF_8; -import static jdk.internal.net.http.frame.SettingsFrame.DEFAULT_INITIAL_WINDOW_SIZE; import static jdk.internal.net.http.frame.SettingsFrame.ENABLE_PUSH; import static jdk.internal.net.http.frame.SettingsFrame.HEADER_TABLE_SIZE; import static jdk.internal.net.http.frame.SettingsFrame.INITIAL_CONNECTION_WINDOW_SIZE; @@ -340,6 +340,7 @@ class Http2Connection { final AtomicReference errorRef = new AtomicReference<>(); PushPromiseDecoder(int parentStreamId, int pushPromiseStreamId, Stream parent) { + super(Context.REQUEST); this.parentStreamId = parentStreamId; this.pushPromiseStreamId = pushPromiseStreamId; this.parent = parent; @@ -984,7 +985,10 @@ class Http2Connection { // always decode the headers as they may affect // connection-level HPACK decoding state if (orphanedConsumer == null || frame.getClass() != ContinuationFrame.class) { - orphanedConsumer = new ValidatingHeadersConsumer(); + orphanedConsumer = new ValidatingHeadersConsumer( + frame instanceof PushPromiseFrame ? + Context.REQUEST : + Context.RESPONSE); } DecodingCallback decoder = orphanedConsumer::onDecoded; try { diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java index 042d6c0ebfe..bfc4e75c021 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java @@ -1871,7 +1871,12 @@ class Stream extends ExchangeImpl { } } - private class HeadersConsumer extends ValidatingHeadersConsumer implements DecodingCallback { + private final class HeadersConsumer extends ValidatingHeadersConsumer + implements DecodingCallback { + + private HeadersConsumer() { + super(Context.RESPONSE); + } boolean maxHeaderListSizeReached; diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/HeaderDecoder.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/HeaderDecoder.java index 92423314553..d9f855ca891 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/HeaderDecoder.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/HeaderDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,7 +30,8 @@ public class HeaderDecoder extends ValidatingHeadersConsumer { private final HttpHeadersBuilder headersBuilder; - public HeaderDecoder() { + public HeaderDecoder(Context context) { + super(context); this.headersBuilder = new HttpHeadersBuilder(); } diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/common/ValidatingHeadersConsumer.java b/src/java.net.http/share/classes/jdk/internal/net/http/common/ValidatingHeadersConsumer.java index db873cdb05e..6c6cacace90 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/ValidatingHeadersConsumer.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/ValidatingHeadersConsumer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,6 +26,9 @@ package jdk.internal.net.http.common; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.ProtocolException; +import java.util.Map; +import java.util.Objects; import java.util.Set; /* @@ -33,8 +36,29 @@ import java.util.Set; */ public class ValidatingHeadersConsumer { - private static final Set PSEUDO_HEADERS = - Set.of(":authority", ":method", ":path", ":scheme", ":status"); + private final Context context; + + public ValidatingHeadersConsumer(Context context) { + this.context = Objects.requireNonNull(context); + } + + public enum Context { + REQUEST, + RESPONSE, + } + + // Map of permitted pseudo headers in requests and responses + private static final Map PSEUDO_HEADERS = + Map.of(":authority", Context.REQUEST, + ":method", Context.REQUEST, + ":path", Context.REQUEST, + ":scheme", Context.REQUEST, + ":status", Context.RESPONSE); + + // connection-specific, prohibited by RFC 9113 section 8.2.2 + private static final Set PROHIBITED_HEADERS = + Set.of("connection", "proxy-connection", "keep-alive", + "transfer-encoding", "upgrade"); /** Used to check that if there are pseudo-headers, they go first */ private boolean pseudoHeadersEnded; @@ -60,11 +84,25 @@ public class ValidatingHeadersConsumer { if (n.startsWith(":")) { if (pseudoHeadersEnded) { throw newException("Unexpected pseudo-header '%s'", n); - } else if (!PSEUDO_HEADERS.contains(n)) { - throw newException("Unknown pseudo-header '%s'", n); + } else { + Context expectedContext = PSEUDO_HEADERS.get(n); + if (expectedContext == null) { + throw newException("Unknown pseudo-header '%s'", n); + } else if (expectedContext != context) { + throw newException("Pseudo-header '%s' is not valid in context " + context, n); + } } } else { pseudoHeadersEnded = true; + // Check for prohibited connection-specific headers. + // Some servers echo request headers in push promises. + // If the request was a HTTP/1.1 upgrade, it included some prohibited headers. + // For compatibility, we ignore prohibited headers in push promises. + if (context != Context.REQUEST) { + if (PROHIBITED_HEADERS.contains(n)) { + throw newException("Prohibited header name '%s'", n); + } + } // RFC-9113, section 8.2.1 for HTTP/2 and RFC-9114, section 4.2 state that // header name MUST be lowercase (and allowed characters) if (!Utils.isValidLowerCaseName(n)) { @@ -84,6 +122,6 @@ public class ValidatingHeadersConsumer { protected UncheckedIOException newException(String message, String header) { return new UncheckedIOException( - new IOException(formatMessage(message, header))); + new ProtocolException(formatMessage(message, header))); } } diff --git a/test/jdk/java/net/httpclient/http2/BadHeadersTest.java b/test/jdk/java/net/httpclient/http2/BadHeadersTest.java index 062b4c89e09..35e1ea19950 100644 --- a/test/jdk/java/net/httpclient/http2/BadHeadersTest.java +++ b/test/jdk/java/net/httpclient/http2/BadHeadersTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -23,7 +23,10 @@ /* * @test - * @bug 8303965 + * @bug 8303965 8354276 + * @summary This test verifies the behaviour of the HttpClient when presented + * with a HEADERS frame followed by CONTINUATION frames, and when presented + * with bad header fields. * @library /test/lib /test/jdk/java/net/httpclient/lib * @build jdk.httpclient.test.lib.http2.Http2TestServer jdk.test.lib.net.SimpleSSLContext * @run testng/othervm -Djdk.internal.httpclient.debug=true BadHeadersTest @@ -44,6 +47,7 @@ import javax.net.ssl.SSLSession; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.net.ProtocolException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpHeaders; @@ -76,6 +80,8 @@ public class BadHeadersTest { of(entry(":status", "200"), entry("hell o", "value")), // Space in the name of(entry(":status", "200"), entry("hello", "line1\r\n line2\r\n")), // Multiline value of(entry(":status", "200"), entry("hello", "DE" + ((char) 0x7F) + "L")), // Bad byte in value + of(entry(":status", "200"), entry("connection", "close")), // Prohibited connection-specific header + of(entry(":status", "200"), entry(":scheme", "https")), // Request pseudo-header in response of(entry("hello", "world!"), entry(":status", "200")) // Pseudo header is not the first one ); @@ -86,7 +92,7 @@ public class BadHeadersTest { String https2URI; /** - * A function that returns a list of 1) a HEADERS frame ( with an empty + * A function that returns a list of 1) one HEADERS frame ( with an empty * payload ), and 2) a CONTINUATION frame with the actual headers. */ static BiFunction,List> oneContinuation = @@ -100,7 +106,7 @@ public class BadHeadersTest { }; /** - * A function that returns a list of a HEADERS frame followed by a number of + * A function that returns a list of one HEADERS frame followed by a number of * CONTINUATION frames. Each frame contains just a single byte of payload. */ static BiFunction,List> byteAtATime = @@ -189,12 +195,13 @@ public class BadHeadersTest { try { HttpResponse response = cc.sendAsync(request, BodyHandlers.ofString()).get(); fail("Expected exception, got :" + response + ", " + response.body()); - } catch (Throwable t0) { + } catch (Exception t0) { System.out.println("Got EXPECTED: " + t0); if (t0 instanceof ExecutionException) { - t0 = t0.getCause(); + t = t0.getCause(); + } else { + t = t0; } - t = t0; } assertDetailMessage(t, i); } @@ -204,15 +211,21 @@ public class BadHeadersTest { // sync with implementation. static void assertDetailMessage(Throwable throwable, int iterationIndex) { try { - assertTrue(throwable instanceof IOException, - "Expected IOException, got, " + throwable); + assertTrue(throwable instanceof ProtocolException, + "Expected ProtocolException, got " + throwable); assertTrue(throwable.getMessage().contains("malformed response"), "Expected \"malformed response\" in: " + throwable.getMessage()); if (iterationIndex == 0) { // unknown assertTrue(throwable.getMessage().contains("Unknown pseudo-header"), "Expected \"Unknown pseudo-header\" in: " + throwable.getMessage()); - } else if (iterationIndex == 4) { // unexpected + } else if (iterationIndex == 4) { // prohibited + assertTrue(throwable.getMessage().contains("Prohibited header name"), + "Expected \"Prohibited header name\" in: " + throwable.getMessage()); + } else if (iterationIndex == 5) { // unexpected type + assertTrue(throwable.getMessage().contains("not valid in context"), + "Expected \"not valid in context\" in: " + throwable.getMessage()); + } else if (iterationIndex == 6) { // unexpected sequence assertTrue(throwable.getMessage().contains(" Unexpected pseudo-header"), "Expected \" Unexpected pseudo-header\" in: " + throwable.getMessage()); } else { diff --git a/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java b/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java new file mode 100644 index 00000000000..73cc12ce478 --- /dev/null +++ b/test/jdk/java/net/httpclient/http2/BadPushPromiseTest.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8354276 + * @library /test/lib /test/jdk/java/net/httpclient/lib + * @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer + * @run testng/othervm + * -Djdk.internal.httpclient.debug=true + * -Djdk.httpclient.HttpClient.log=errors,requests,responses,trace + * BadPushPromiseTest + */ + +import jdk.httpclient.test.lib.common.HttpServerAdapters; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.ProtocolException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.List.of; +import static org.testng.Assert.*; + +public class BadPushPromiseTest { + + private static final List>> BAD_HEADERS = of( + Map.of(":hello", of("GET")), // Unknown pseudo-header + Map.of("hell o", of("value")), // Space in the name + Map.of("hello", of("line1\r\n line2\r\n")), // Multiline value + Map.of("hello", of("DE" + ((char) 0x7F) + "L")), // Bad byte in value + Map.of(":status", of("200")) // Response pseudo-header in request + ); + + static final String MAIN_RESPONSE_BODY = "the main response body"; + + HttpServerAdapters.HttpTestServer server; + URI uri; + + @BeforeTest + public void setup() throws Exception { + server = HttpServerAdapters.HttpTestServer.create(HTTP_2); + HttpServerAdapters.HttpTestHandler handler = new ServerPushHandler(MAIN_RESPONSE_BODY); + server.addHandler(handler, "/"); + server.start(); + String authority = server.serverAuthority(); + System.err.println("Server listening on address " + authority); + uri = new URI("http://" + authority + "/foo/a/b/c"); + } + + @AfterTest + public void teardown() { + server.stop(); + } + + /* + * Malformed push promise headers should kill the connection + */ + @Test + public void test() throws Exception { + HttpClient client = HttpClient.newHttpClient(); + + for (int i=0; i< BAD_HEADERS.size(); i++) { + URI uriWithQuery = URI.create(uri + "?BAD_HEADERS=" + i); + HttpRequest request = HttpRequest.newBuilder(uriWithQuery) + .build(); + System.out.println("\nSending request:" + uriWithQuery); + final HttpClient cc = client; + try { + ConcurrentMap>> promises + = new ConcurrentHashMap<>(); + PushPromiseHandler pph = PushPromiseHandler + .of((r) -> BodyHandlers.ofString(), promises); + HttpResponse response = cc.sendAsync(request, BodyHandlers.ofString(), pph).join(); + fail("Expected exception, got :" + response + ", " + response.body()); + } catch (CompletionException ce) { + System.out.println("Got EXPECTED: " + ce); + assertDetailMessage(ce.getCause(), i); + } + } + } + + // Assertions based on implementation specific detail messages. Keep in + // sync with implementation. + static void assertDetailMessage(Throwable throwable, int iterationIndex) { + try { + assertTrue(throwable instanceof ProtocolException, + "Expected ProtocolException, got " + throwable); + + if (iterationIndex == 0) { // unknown + assertTrue(throwable.getMessage().contains("Unknown pseudo-header"), + "Expected \"Unknown pseudo-header\" in: " + throwable.getMessage()); + } else if (iterationIndex == 4) { // unexpected type + assertTrue(throwable.getMessage().contains("not valid in context"), + "Expected \"not valid in context\" in: " + throwable.getMessage()); + } else { + assertTrue(throwable.getMessage().contains("Bad header"), + "Expected \"Bad header\" in: " + throwable.getMessage()); + } + } catch (AssertionError e) { + System.out.println("Exception does not match expectation: " + throwable); + throwable.printStackTrace(System.out); + throw e; + } + } + + // --- server push handler --- + static class ServerPushHandler implements HttpServerAdapters.HttpTestHandler { + + private final String mainResponseBody; + + public ServerPushHandler(String mainResponseBody) { + this.mainResponseBody = mainResponseBody; + } + + public void handle(HttpServerAdapters.HttpTestExchange exchange) throws IOException { + System.err.println("Server: handle " + exchange); + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + pushPromise(exchange); + + // response data for the main response + try (OutputStream os = exchange.getResponseBody()) { + byte[] bytes = mainResponseBody.getBytes(UTF_8); + exchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } + } + + private void pushPromise(HttpServerAdapters.HttpTestExchange exchange) { + URI requestURI = exchange.getRequestURI(); + String query = exchange.getRequestURI().getQuery(); + int badHeadersIndex = Integer.parseInt(query.substring(query.indexOf("=") + 1)); + URI uri = requestURI.resolve("/push/"+badHeadersIndex); + InputStream is = new ByteArrayInputStream(mainResponseBody.getBytes(UTF_8)); + HttpHeaders headers = HttpHeaders.of(BAD_HEADERS.get(badHeadersIndex), (x, y) -> true); + exchange.serverPush(uri, headers, is); + System.err.println("Server: push sent"); + } + } +}