Skip to content

Commit

Permalink
Polish 'Make livereload websocket headers case insensitive'
Browse files Browse the repository at this point in the history
See gh-26813

Closes gh-26813
  • Loading branch information
philwebb committed Jun 16, 2021
1 parent 8755512 commit 5ca687c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2019 the original author or authors.
* Copyright 2012-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,25 +23,28 @@
import java.net.SocketTimeoutException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.core.log.LogMessage;
import org.springframework.util.Assert;
import org.springframework.util.Base64Utils;

/**
* A {@link LiveReloadServer} connection.
*
* @author Phillip Webb
* @author Francis Lavoie
*/
class Connection {

private static final Log logger = LogFactory.getLog(Connection.class);

private static final Pattern WEBSOCKET_KEY_PATTERN = Pattern.compile("^Sec-WebSocket-Key:(.*)$", Pattern.MULTILINE | Pattern.CASE_INSENSITIVE);
private static final Pattern WEBSOCKET_KEY_PATTERN = Pattern.compile("^sec-websocket-key:(.*)$", Pattern.MULTILINE);

public static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

Expand All @@ -68,19 +71,20 @@ class Connection {
this.socket = socket;
this.inputStream = new ConnectionInputStream(inputStream);
this.outputStream = new ConnectionOutputStream(outputStream);
this.header = this.inputStream.readHeader();
logger.debug(LogMessage.format("Established livereload connection [%s]", this.header));
String header = this.inputStream.readHeader();
logger.debug(LogMessage.format("Established livereload connection [%s]", header));
this.header = header.toLowerCase(Locale.ENGLISH);
}

/**
* Run the connection.
* @throws Exception in case of errors
*/
void run() throws Exception {
if (this.header.contains("Upgrade: websocket") && this.header.toLowerCase().contains("sec-websocket-version: 13")) {
if (this.header.contains("upgrade: websocket") && this.header.contains("sec-websocket-version: 13")) {
runWebSocket();
}
if (this.header.contains("GET /livereload.js")) {
if (this.header.contains("get /livereload.js")) {
this.outputStream.writeHttp(getClass().getResourceAsStream("livereload.js"), "text/javascript");
}
}
Expand Down Expand Up @@ -140,9 +144,7 @@ private void writeWebSocketFrame(Frame frame) throws IOException {

private String getWebsocketAcceptResponse() throws NoSuchAlgorithmException {
Matcher matcher = WEBSOCKET_KEY_PATTERN.matcher(this.header);
if (!matcher.find()) {
throw new IllegalStateException("No Sec-WebSocket-Key");
}
Assert.state(matcher.find(), "No Sec-WebSocket-Key");
String response = matcher.group(1).trim() + WEBSOCKET_GUID;
MessageDigest messageDigest = MessageDigest.getInstance("SHA-1");
messageDigest.update(response.getBytes(), 0, response.length());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,27 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator;
import javax.websocket.Endpoint;
import javax.websocket.HandshakeResponse;
import javax.websocket.WebSocketContainer;

import org.apache.tomcat.websocket.WsWebSocketContainer;
import org.awaitility.Awaitility;
Expand All @@ -34,13 +48,20 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.http.HttpHeaders;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.TextWebSocketHandler;
Expand Down Expand Up @@ -94,7 +115,16 @@ void triggerReload() throws Exception {
(msgs) -> msgs.size() == 2);
assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7");
assertThat(messages.get(1)).contains("command\":\"reload\"");
}

@Test // gh-26813
void triggerReloadWithUppercaseHeaders() throws Exception {
LiveReloadWebSocketHandler handler = connect(UppercaseWebSocketClient::new);
this.server.triggerReload();
List<String> messages = await().atMost(Duration.ofSeconds(10)).until(handler::getMessages,
(msgs) -> msgs.size() == 2);
assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7");
assertThat(messages.get(1)).contains("command\":\"reload\"");
}

@Test
Expand Down Expand Up @@ -126,7 +156,13 @@ void serverClose() throws Exception {
}

private LiveReloadWebSocketHandler connect() throws Exception {
WebSocketClient client = new StandardWebSocketClient(new WsWebSocketContainer());
return connect(StandardWebSocketClient::new);
}

private LiveReloadWebSocketHandler connect(Function<WebSocketContainer, WebSocketClient> clientFactory)
throws Exception {
WsWebSocketContainer webSocketContainer = new WsWebSocketContainer();
WebSocketClient client = clientFactory.apply(webSocketContainer);
LiveReloadWebSocketHandler handler = new LiveReloadWebSocketHandler();
client.doHandshake(handler, "ws://localhost:" + this.port + "/livereload");
handler.awaitHello();
Expand Down Expand Up @@ -246,4 +282,69 @@ CloseStatus getCloseStatus() {

}

static class UppercaseWebSocketClient extends StandardWebSocketClient {

private final WebSocketContainer webSocketContainer;

UppercaseWebSocketClient(WebSocketContainer webSocketContainer) {
super(webSocketContainer);
this.webSocketContainer = webSocketContainer;
}

@Override
protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri, List<String> protocols, List<WebSocketExtension> extensions,
Map<String, Object> attributes) {
InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), uri.getPort());
InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), uri.getPort());
StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddress,
remoteAddress);
ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
.configurator(new UppercaseWebSocketClientConfigurator(headers)).preferredSubprotocols(protocols)
.extensions(extensions.stream().map(WebSocketToStandardExtensionAdapter::new)
.collect(Collectors.toList()))
.build();
endpointConfig.getUserProperties().putAll(getUserProperties());
Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
Callable<WebSocketSession> connectTask = () -> {
this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
return session;
};
return getTaskExecutor().submitListenable(connectTask);
}

private InetAddress getLocalHost() {
try {
return InetAddress.getLocalHost();
}
catch (UnknownHostException ex) {
return InetAddress.getLoopbackAddress();
}
}

}

private static class UppercaseWebSocketClientConfigurator extends Configurator {

private final HttpHeaders headers;

UppercaseWebSocketClientConfigurator(HttpHeaders headers) {
this.headers = headers;
}

@Override
public void beforeRequest(Map<String, List<String>> requestHeaders) {
Map<String, List<String>> uppercaseRequestHeaders = new LinkedHashMap<>();
requestHeaders.forEach((key, value) -> uppercaseRequestHeaders.put(key.toUpperCase(), value));
requestHeaders.clear();
requestHeaders.putAll(uppercaseRequestHeaders);
requestHeaders.putAll(this.headers);
}

@Override
public void afterResponse(HandshakeResponse response) {
}

}

}

0 comments on commit 5ca687c

Please sign in to comment.