From 5ca687c9a6b5df8d24a2a638879a7142c99bebe9 Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Tue, 15 Jun 2021 16:48:04 -0700 Subject: [PATCH] Polish 'Make livereload websocket headers case insensitive' See gh-26813 Closes gh-26813 --- .../boot/devtools/livereload/Connection.java | 20 ++-- .../livereload/LiveReloadServerTests.java | 103 +++++++++++++++++- 2 files changed, 113 insertions(+), 10 deletions(-) diff --git a/spring-boot-project/spring-boot-devtools/src/main/java/org/springframework/boot/devtools/livereload/Connection.java b/spring-boot-project/spring-boot-devtools/src/main/java/org/springframework/boot/devtools/livereload/Connection.java index 0a62bd03778e..070dca8a4f5e 100644 --- a/spring-boot-project/spring-boot-devtools/src/main/java/org/springframework/boot/devtools/livereload/Connection.java +++ b/spring-boot-project/spring-boot-devtools/src/main/java/org/springframework/boot/devtools/livereload/Connection.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2019 the original author or authors. + * Copyright 2012-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.net.SocketTimeoutException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.Locale; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -30,18 +31,20 @@ import org.apache.commons.logging.LogFactory; import org.springframework.core.log.LogMessage; +import org.springframework.util.Assert; import org.springframework.util.Base64Utils; /** * A {@link LiveReloadServer} connection. * * @author Phillip Webb + * @author Francis Lavoie */ class Connection { private static final Log logger = LogFactory.getLog(Connection.class); - private static final Pattern WEBSOCKET_KEY_PATTERN = Pattern.compile("^Sec-WebSocket-Key:(.*)$", Pattern.MULTILINE | Pattern.CASE_INSENSITIVE); + private static final Pattern WEBSOCKET_KEY_PATTERN = Pattern.compile("^sec-websocket-key:(.*)$", Pattern.MULTILINE); public static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -68,8 +71,9 @@ class Connection { this.socket = socket; this.inputStream = new ConnectionInputStream(inputStream); this.outputStream = new ConnectionOutputStream(outputStream); - this.header = this.inputStream.readHeader(); - logger.debug(LogMessage.format("Established livereload connection [%s]", this.header)); + String header = this.inputStream.readHeader(); + logger.debug(LogMessage.format("Established livereload connection [%s]", header)); + this.header = header.toLowerCase(Locale.ENGLISH); } /** @@ -77,10 +81,10 @@ class Connection { * @throws Exception in case of errors */ void run() throws Exception { - if (this.header.contains("Upgrade: websocket") && this.header.toLowerCase().contains("sec-websocket-version: 13")) { + if (this.header.contains("upgrade: websocket") && this.header.contains("sec-websocket-version: 13")) { runWebSocket(); } - if (this.header.contains("GET /livereload.js")) { + if (this.header.contains("get /livereload.js")) { this.outputStream.writeHttp(getClass().getResourceAsStream("livereload.js"), "text/javascript"); } } @@ -140,9 +144,7 @@ private void writeWebSocketFrame(Frame frame) throws IOException { private String getWebsocketAcceptResponse() throws NoSuchAlgorithmException { Matcher matcher = WEBSOCKET_KEY_PATTERN.matcher(this.header); - if (!matcher.find()) { - throw new IllegalStateException("No Sec-WebSocket-Key"); - } + Assert.state(matcher.find(), "No Sec-WebSocket-Key"); String response = matcher.group(1).trim() + WEBSOCKET_GUID; MessageDigest messageDigest = MessageDigest.getInstance("SHA-1"); messageDigest.update(response.getBytes(), 0, response.length()); diff --git a/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/livereload/LiveReloadServerTests.java b/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/livereload/LiveReloadServerTests.java index 8b15e1e15274..ef9bfb786c28 100644 --- a/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/livereload/LiveReloadServerTests.java +++ b/spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/livereload/LiveReloadServerTests.java @@ -19,13 +19,27 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.URI; +import java.net.UnknownHostException; import java.time.Duration; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; + +import javax.websocket.ClientEndpointConfig; +import javax.websocket.ClientEndpointConfig.Configurator; +import javax.websocket.Endpoint; +import javax.websocket.HandshakeResponse; +import javax.websocket.WebSocketContainer; import org.apache.tomcat.websocket.WsWebSocketContainer; import org.awaitility.Awaitility; @@ -34,13 +48,20 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.util.concurrent.ListenableFuture; import org.springframework.web.client.RestTemplate; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter; +import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; +import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter; import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.handler.TextWebSocketHandler; @@ -94,7 +115,16 @@ void triggerReload() throws Exception { (msgs) -> msgs.size() == 2); assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7"); assertThat(messages.get(1)).contains("command\":\"reload\""); + } + @Test // gh-26813 + void triggerReloadWithUppercaseHeaders() throws Exception { + LiveReloadWebSocketHandler handler = connect(UppercaseWebSocketClient::new); + this.server.triggerReload(); + List messages = await().atMost(Duration.ofSeconds(10)).until(handler::getMessages, + (msgs) -> msgs.size() == 2); + assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7"); + assertThat(messages.get(1)).contains("command\":\"reload\""); } @Test @@ -126,7 +156,13 @@ void serverClose() throws Exception { } private LiveReloadWebSocketHandler connect() throws Exception { - WebSocketClient client = new StandardWebSocketClient(new WsWebSocketContainer()); + return connect(StandardWebSocketClient::new); + } + + private LiveReloadWebSocketHandler connect(Function clientFactory) + throws Exception { + WsWebSocketContainer webSocketContainer = new WsWebSocketContainer(); + WebSocketClient client = clientFactory.apply(webSocketContainer); LiveReloadWebSocketHandler handler = new LiveReloadWebSocketHandler(); client.doHandshake(handler, "ws://localhost:" + this.port + "/livereload"); handler.awaitHello(); @@ -246,4 +282,69 @@ CloseStatus getCloseStatus() { } + static class UppercaseWebSocketClient extends StandardWebSocketClient { + + private final WebSocketContainer webSocketContainer; + + UppercaseWebSocketClient(WebSocketContainer webSocketContainer) { + super(webSocketContainer); + this.webSocketContainer = webSocketContainer; + } + + @Override + protected ListenableFuture doHandshakeInternal(WebSocketHandler webSocketHandler, + HttpHeaders headers, URI uri, List protocols, List extensions, + Map attributes) { + InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), uri.getPort()); + InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), uri.getPort()); + StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddress, + remoteAddress); + ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create() + .configurator(new UppercaseWebSocketClientConfigurator(headers)).preferredSubprotocols(protocols) + .extensions(extensions.stream().map(WebSocketToStandardExtensionAdapter::new) + .collect(Collectors.toList())) + .build(); + endpointConfig.getUserProperties().putAll(getUserProperties()); + Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); + Callable connectTask = () -> { + this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri); + return session; + }; + return getTaskExecutor().submitListenable(connectTask); + } + + private InetAddress getLocalHost() { + try { + return InetAddress.getLocalHost(); + } + catch (UnknownHostException ex) { + return InetAddress.getLoopbackAddress(); + } + } + + } + + private static class UppercaseWebSocketClientConfigurator extends Configurator { + + private final HttpHeaders headers; + + UppercaseWebSocketClientConfigurator(HttpHeaders headers) { + this.headers = headers; + } + + @Override + public void beforeRequest(Map> requestHeaders) { + Map> uppercaseRequestHeaders = new LinkedHashMap<>(); + requestHeaders.forEach((key, value) -> uppercaseRequestHeaders.put(key.toUpperCase(), value)); + requestHeaders.clear(); + requestHeaders.putAll(uppercaseRequestHeaders); + requestHeaders.putAll(this.headers); + } + + @Override + public void afterResponse(HandshakeResponse response) { + } + + } + }