diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java index a0b378efee33..aac3649add9b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.context.ServletContextAware; @@ -63,15 +64,18 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv private static final ThreadLocal containerHolder = new NamedThreadLocal<>("WebSocketHandlerContainer"); - + @Nullable private WebSocketPolicy policy; - private WebSocketServerFactory factory; + @Nullable + private volatile WebSocketServerFactory factory; + @Nullable private ServletContext servletContext; private volatile boolean running = false; + @Nullable private volatile List supportedExtensions; @@ -114,17 +118,20 @@ public void start() { if (!isRunning()) { this.running = true; try { - if (this.factory == null) { - this.factory = new WebSocketServerFactory(servletContext, this.policy); + WebSocketServerFactory factory = this.factory; + if (factory == null) { + Assert.state(this.servletContext != null, "No ServletContext set"); + factory = new WebSocketServerFactory(this.servletContext, this.policy); + this.factory = factory; } - this.factory.setCreator((request, response) -> { - WebSocketHandlerContainer container = containerHolder.get(); - Assert.state(container != null, "Expected WebSocketHandlerContainer"); - response.setAcceptedSubProtocol(container.getSelectedProtocol()); - response.setExtensions(container.getExtensionConfigs()); - return container.getHandler(); - }); - this.factory.start(); + factory.setCreator((request, response) -> { + WebSocketHandlerContainer container = containerHolder.get(); + Assert.state(container != null, "Expected WebSocketHandlerContainer"); + response.setAcceptedSubProtocol(container.getSelectedProtocol()); + response.setExtensions(container.getExtensionConfigs()); + return container.getHandler(); + }); + factory.start(); } catch (Throwable ex) { throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex); @@ -136,9 +143,10 @@ public void start() { public void stop() { if (isRunning()) { this.running = false; - if (this.factory != null) { + WebSocketServerFactory factory = this.factory; + if (factory != null) { try { - this.factory.stop(); + factory.stop(); } catch (Throwable ex) { throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex); @@ -160,14 +168,18 @@ public String[] getSupportedVersions() { @Override public List getSupportedExtensions(ServerHttpRequest request) { - if (this.supportedExtensions == null) { - this.supportedExtensions = buildWebSocketExtensions(); + List extensions = this.supportedExtensions; + if (extensions == null) { + extensions = buildWebSocketExtensions(); + this.supportedExtensions = extensions; } - return this.supportedExtensions; + return extensions; } private List buildWebSocketExtensions() { - Set names = this.factory.getExtensionFactory().getExtensionNames(); + WebSocketServerFactory factory = this.factory; + Assert.state(factory != null, "No WebSocketServerFactory available"); + Set names = factory.getExtensionFactory().getExtensionNames(); List result = new ArrayList<>(names.size()); for (String name : names) { result.add(new WebSocketExtension(name)); @@ -177,7 +189,7 @@ private List buildWebSocketExtensions() { @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, List selectedExtensions, Principal user, + @Nullable String selectedProtocol, List selectedExtensions, @Nullable Principal user, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { Assert.isInstanceOf(ServletServerHttpRequest.class, request, "ServletServerHttpRequest required"); @@ -186,7 +198,9 @@ public void upgrade(ServerHttpRequest request, ServerHttpResponse response, Assert.isInstanceOf(ServletServerHttpResponse.class, response, "ServletServerHttpResponse required"); HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse(); - Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake"); + WebSocketServerFactory factory = this.factory; + Assert.state(factory != null, "No WebSocketServerFactory available"); + Assert.isTrue(factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake"); JettyWebSocketSession session = new JettyWebSocketSession(attributes, user); JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session); @@ -196,7 +210,7 @@ public void upgrade(ServerHttpRequest request, ServerHttpResponse response, try { containerHolder.set(container); - this.factory.acceptWebSocket(servletRequest, servletResponse); + factory.acceptWebSocket(servletRequest, servletResponse); } catch (IOException ex) { throw new HandshakeFailureException( @@ -212,12 +226,13 @@ private static class WebSocketHandlerContainer { private final JettyWebSocketHandlerAdapter handler; + @Nullable private final String selectedProtocol; private final List extensionConfigs; - public WebSocketHandlerContainer( - JettyWebSocketHandlerAdapter handler, String protocol, List extensions) { + public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler, + @Nullable String protocol, List extensions) { this.handler = handler; this.selectedProtocol = protocol; @@ -236,6 +251,7 @@ public JettyWebSocketHandlerAdapter getHandler() { return this.handler; } + @Nullable public String getSelectedProtocol() { return this.selectedProtocol; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/package-info.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/package-info.java index 2a3627773a10..20a6fa642b5d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/package-info.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/package-info.java @@ -1,4 +1,9 @@ /** * Server-side support for the Jetty 9+ WebSocket API. */ +@NonNullApi +@NonNullFields package org.springframework.web.socket.server.jetty; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields;