diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java index ca30c23c77d0..b43b3291ad93 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,8 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock private final List allowedOrigins = new ArrayList<>(); + private final List allowedOriginPatterns = new ArrayList<>(); + @Nullable private SockJsServiceRegistration sockJsServiceRegistration; @@ -94,6 +96,15 @@ public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) return this; } + @Override + public WebSocketHandlerRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) { + this.allowedOriginPatterns.clear(); + if (!ObjectUtils.isEmpty(allowedOriginPatterns)) { + this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns)); + } + return this; + } + @Override public SockJsServiceRegistration withSockJS() { this.sockJsServiceRegistration = new SockJsServiceRegistration(); @@ -108,13 +119,21 @@ public SockJsServiceRegistration withSockJS() { if (!this.allowedOrigins.isEmpty()) { this.sockJsServiceRegistration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins)); } + if (!this.allowedOriginPatterns.isEmpty()) { + this.sockJsServiceRegistration.setAllowedOriginPatterns( + StringUtils.toStringArray(this.allowedOriginPatterns)); + } return this.sockJsServiceRegistration; } protected HandshakeInterceptor[] getInterceptors() { List interceptors = new ArrayList<>(this.interceptors.size() + 1); interceptors.addAll(this.interceptors); - interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins)); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(this.allowedOrigins); + if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) { + interceptor.setAllowedOriginPatterns(this.allowedOriginPatterns); + } + interceptors.add(interceptor); return interceptors.toArray(new HandshakeInterceptor[0]); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java index 72b484a98924..48642a305bdf 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,6 +63,15 @@ public interface WebSocketHandlerRegistration { */ WebSocketHandlerRegistration setAllowedOrigins(String... origins); + /** + * A variant of {@link #setAllowedOrigins(String...)} that accepts flexible + * domain patterns, e.g. {@code "https://*.domain1.com"}. Furthermore it + * always sets the {@code Access-Control-Allow-Origin} response header to + * the matched origin and never to {@code "*"}, nor to any other pattern. + * @since 5.3.5 + */ + WebSocketHandlerRegistration setAllowedOriginPatterns(String... originPatterns); + /** * Enable SockJS fallback options. */ diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java index 398c7afe134e..f7dae4c8cb02 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -115,7 +115,10 @@ public void interceptorsWithAllowedOrigins() { WebSocketHandler handler = new TextWebSocketHandler(); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("https://mydomain1.example"); + this.registration.addHandler(handler, "/foo") + .addInterceptors(interceptor) + .setAllowedOrigins("https://mydomain1.example") + .setAllowedOriginPatterns("https://*.abc.com"); List mappings = this.registration.getMappings(); assertThat(mappings.size()).isEqualTo(1); @@ -126,7 +129,10 @@ public void interceptorsWithAllowedOrigins() { assertThat(mapping.interceptors).isNotNull(); assertThat(mapping.interceptors.length).isEqualTo(2); assertThat(mapping.interceptors[0]).isEqualTo(interceptor); - assertThat(mapping.interceptors[1].getClass()).isEqualTo(OriginHandshakeInterceptor.class); + + OriginHandshakeInterceptor originInterceptor = (OriginHandshakeInterceptor) mapping.interceptors[1]; + assertThat(originInterceptor.getAllowedOrigins()).containsExactly("https://mydomain1.example"); + assertThat(originInterceptor.getAllowedOriginPatterns()).containsExactly("https://*.abc.com"); } @Test @@ -137,6 +143,7 @@ public void interceptorsPassedToSockJsRegistration() { this.registration.addHandler(handler, "/foo") .addInterceptors(interceptor) .setAllowedOrigins("https://mydomain1.example") + .setAllowedOriginPatterns("https://*.abc.com") .withSockJS(); this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler); @@ -151,7 +158,10 @@ public void interceptorsPassedToSockJsRegistration() { assertThat(mapping.sockJsService.getAllowedOrigins().contains("https://mydomain1.example")).isTrue(); List interceptors = mapping.sockJsService.getHandshakeInterceptors(); assertThat(interceptors.get(0)).isEqualTo(interceptor); - assertThat(interceptors.get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class); + + OriginHandshakeInterceptor originInterceptor = (OriginHandshakeInterceptor) interceptors.get(1); + assertThat(originInterceptor.getAllowedOrigins()).containsExactly("https://mydomain1.example"); + assertThat(originInterceptor.getAllowedOriginPatterns()).containsExactly("https://*.abc.com"); } @Test