diff --git a/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java b/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java index 778186f9962..3eb826ba27c 100644 --- a/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java +++ b/junit4/src/main/java/com/linecorp/armeria/testing/junit4/server/ServerRule.java @@ -31,6 +31,8 @@ import com.linecorp.armeria.client.RestClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.WebClientBuilder; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -68,6 +70,12 @@ public void configure(ServerBuilder sb) throws Exception { public void configureWebClient(WebClientBuilder wcb) throws Exception { ServerRule.this.configureWebClient(wcb); } + + @Override + public void configureWebSocketClient(WebSocketClientBuilder wscb) + throws Exception { + ServerRule.this.configureWebSocketClient(wscb); + } }; } @@ -109,6 +117,12 @@ public Server start() { */ protected void configureWebClient(WebClientBuilder webClientBuilder) throws Exception {} + /** + * Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}. + * You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}. + */ + protected void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) throws Exception {} + /** * Stops the {@link Server} asynchronously. * @@ -344,4 +358,13 @@ public RestClient restClient(Consumer webClientCustomizer) { requireNonNull(webClientCustomizer, "webClientCustomizer"); return delegate.restClient(webClientCustomizer); } + + /** + * Returns the {@link WebSocketClient} configured + * by {@link #configureWebSocketClient(WebSocketClientBuilder)}. + */ + @UnstableApi + public WebSocketClient webSocketClient() { + return delegate.webSocketClient(); + } } diff --git a/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java b/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java index 0ea684d795e..3b6e661d309 100644 --- a/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java +++ b/junit5/src/main/java/com/linecorp/armeria/internal/testing/ServerRuleDelegate.java @@ -32,6 +32,8 @@ import com.linecorp.armeria.client.RestClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.WebClientBuilder; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; @@ -49,6 +51,7 @@ public abstract class ServerRuleDelegate { private final boolean autoStart; private final AtomicReference webClient = new AtomicReference<>(); + private final AtomicReference webSocketClient = new AtomicReference<>(); /** * Creates a new instance. @@ -114,6 +117,14 @@ public Server start() { */ public abstract void configureWebClient(WebClientBuilder webClientBuilder) throws Exception; + /** + * Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}. + * You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}. + */ + @UnstableApi + public abstract void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) + throws Exception; + /** * Stops the {@link Server} asynchronously. * @@ -404,6 +415,25 @@ public RestClient restClient(Consumer webClientCustomizer) { return webClient(webClientCustomizer).asRestClient(); } + /** + * Returns the {@link WebSocketClient} configured + * by {@link #configureWebSocketClient(WebSocketClientBuilder)}. + */ + @UnstableApi + public WebSocketClient webSocketClient() { + final WebSocketClient webSocketClient = this.webSocketClient.get(); + if (webSocketClient != null) { + return webSocketClient; + } + + final WebSocketClient newWebSocketClient = webSocketClientBuilder().build(); + if (this.webSocketClient.compareAndSet(null, newWebSocketClient)) { + return newWebSocketClient; + } else { + return this.webSocketClient.get(); + } + } + private void ensureStarted() { // This will ensure that the server has started. server(); @@ -422,4 +452,21 @@ private WebClientBuilder webClientBuilder() { } return webClientBuilder; } + + private WebSocketClientBuilder webSocketClientBuilder() { + final boolean hasHttps = hasHttps(); + final String hostAndPort = hasHttps ? "wss://" + httpsUri().getAuthority() + : "ws://" + httpUri().getAuthority(); + final WebSocketClientBuilder webSocketClientBuilder = WebSocketClient.builder(hostAndPort); + if (hasHttps) { + webSocketClientBuilder.factory(ClientFactory.insecure()); + } + + try { + configureWebSocketClient(webSocketClientBuilder); + } catch (Exception e) { + throw new IllegalStateException("failed to configure a WebSocketClient", e); + } + return webSocketClientBuilder; + } } diff --git a/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java b/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java index 8ba44a31911..1c44cdf221a 100644 --- a/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java +++ b/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/ServerExtension.java @@ -31,6 +31,8 @@ import com.linecorp.armeria.client.RestClient; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.WebClientBuilder; +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -75,6 +77,12 @@ public void configure(ServerBuilder sb) throws Exception { public void configureWebClient(WebClientBuilder wcb) throws Exception { ServerExtension.this.configureWebClient(wcb); } + + @Override + public void configureWebSocketClient(WebSocketClientBuilder wscb) + throws Exception { + ServerExtension.this.configureWebSocketClient(wscb); + } }; } @@ -126,6 +134,12 @@ public Server start() { */ protected void configureWebClient(WebClientBuilder webClientBuilder) throws Exception {} + /** + * Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}. + * You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}. + */ + protected void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) throws Exception {} + /** * Stops the {@link Server} asynchronously. * @@ -370,6 +384,15 @@ public RestClient restClient(Consumer webClientCustomizer) { return delegate.restClient(webClientCustomizer); } + /** + * Returns the {@link WebSocketClient} configured + * by {@link #configureWebSocketClient(WebSocketClientBuilder)}. + */ + @UnstableApi + public WebSocketClient webSocketClient() { + return delegate.webSocketClient(); + } + /** * Determines whether the {@link ServiceRequestContext} should be captured or not. * This method returns {@code true} by default. Override it to capture the contexts diff --git a/junit5/src/test/java/com/linecorp/armeria/testing/junit5/server/ServerExtensionWithWebSocketClientTest.java b/junit5/src/test/java/com/linecorp/armeria/testing/junit5/server/ServerExtensionWithWebSocketClientTest.java new file mode 100644 index 00000000000..2ee442877d6 --- /dev/null +++ b/junit5/src/test/java/com/linecorp/armeria/testing/junit5/server/ServerExtensionWithWebSocketClientTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.testing.junit5.server; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.websocket.WebSocketService; +import com.linecorp.armeria.server.websocket.WebSocketServiceHandler; + +class ServerExtensionWithWebSocketClientTest { + + @RegisterExtension + static ServerExtension wsServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.service("/chat", WebSocketService.builder(new WebSocketEchoHandler()) + .build()); + } + }; + + @RegisterExtension + static ServerExtension wssServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.tlsSelfSigned(); + sb.service("/chat", WebSocketService.builder(new WebSocketEchoHandler()) + .build()); + } + }; + + @CsvSource({ "true", "false" }) + @ParameterizedTest + void webSocketClient(boolean useTls) { + final WebSocketClient webSocketClient = useTls ? wssServer.webSocketClient() + : wsServer.webSocketClient(); + final WebSocketSession wsSession = webSocketClient.connect("/chat").join(); + assertThat(wsSession).isNotNull(); + final WebSocketWriter outbound = wsSession.outbound(); + outbound.write("hello"); + final String message = useTls ? "wss" : "ws"; + outbound.write(message); + outbound.close(); + final List responses = wsSession.inbound().collect().join().stream().map(WebSocketFrame::text) + .collect(toImmutableList()); + assertThat(responses).contains("hello", message); + } + + static final class WebSocketEchoHandler implements WebSocketServiceHandler { + + @Override + public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { + final WebSocketWriter writer = WebSocket.streaming(); + in.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + writer.write(webSocketFrame); + } + + @Override + public void onError(Throwable t) { + writer.close(t); + } + + @Override + public void onComplete() { + writer.close(); + } + }); + return writer; + } + } +} diff --git a/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala b/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala index fd5139128df..87f192d696f 100644 --- a/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala +++ b/sangria/sangria_2.13/src/test/scala/com/linecorp/armeria/server/sangria/ServerSuite.scala @@ -16,8 +16,8 @@ package com.linecorp.armeria.server.sangria -import com.linecorp.armeria.client.{WebClient, WebClientBuilder} -import com.linecorp.armeria.client.logging.LoggingClient +import com.linecorp.armeria.client.WebClientBuilder +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder import com.linecorp.armeria.internal.testing.ServerRuleDelegate import com.linecorp.armeria.server.ServerBuilder import munit.Suite @@ -33,6 +33,8 @@ trait ServerSuite { protected def configureWebClient: WebClientBuilder => Unit = _ => () + protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => () + protected def server: ServerRuleDelegate = delegate /** @@ -46,6 +48,9 @@ trait ServerSuite { override def configure(sb: ServerBuilder): Unit = configureServer(sb) override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb) + + override def configureWebSocketClient(webSocketClientBuilder: WebSocketClientBuilder): Unit = + self.configureWebSocketClient(webSocketClientBuilder) } if (!runServerForEachTest) { diff --git a/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala b/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala index 5f95891f4e2..ca78316f58d 100644 --- a/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala +++ b/scala/scala_2.13/src/test/scala/com/linecorp/armeria/server/ServerSuite.scala @@ -17,6 +17,7 @@ package com.linecorp.armeria.server import com.linecorp.armeria.client.WebClientBuilder +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder import com.linecorp.armeria.internal.testing.ServerRuleDelegate import munit.Suite @@ -29,6 +30,8 @@ trait ServerSuite { protected def configureWebClient: WebClientBuilder => Unit = _ => () + protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => () + protected def server: ServerRuleDelegate = delegate /** @@ -42,6 +45,9 @@ trait ServerSuite { override def configure(sb: ServerBuilder): Unit = configureServer(sb) override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb) + + override def configureWebSocketClient(wscb: WebSocketClientBuilder): Unit = + self.configureWebSocketClient(wscb) } if (!runServerForEachTest) { diff --git a/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala b/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala index 509e0a4c0e7..d6965bcc07a 100644 --- a/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala +++ b/scalapb/scalapb_2.13/src/test/scala/com/linecorp/armeria/server/scalapb/ServerSuite.scala @@ -17,6 +17,7 @@ package com.linecorp.armeria.server.scalapb import com.linecorp.armeria.client.WebClientBuilder +import com.linecorp.armeria.client.websocket.WebSocketClientBuilder import com.linecorp.armeria.internal.testing.ServerRuleDelegate import com.linecorp.armeria.server.ServerBuilder import munit.Suite @@ -30,6 +31,8 @@ trait ServerSuite { protected def configureWebClient: WebClientBuilder => Unit = _ => () + protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => () + protected def server: ServerRuleDelegate = delegate /** @@ -43,6 +46,9 @@ trait ServerSuite { override def configure(sb: ServerBuilder): Unit = configureServer(sb) override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb) + + override def configureWebSocketClient(webSocketClientBuilder: WebSocketClientBuilder): Unit = + self.configureWebSocketClient(webSocketClientBuilder) } if (!runServerForEachTest) {