Skip to content

Commit

Permalink
Add webSocketClient in ServerExtension (#5765)
Browse files Browse the repository at this point in the history
Motivation:

Added `webSocketClient` which directly connects to `ServerExtension` 

Modifications:

- Added `configureWebSocketClient()` to:
  - `ServerExtension` for JUnit 5
  - `ServerRule` for JUnit 4
  - `ServerSuite` for ScalaTest
  - `ServerRuleDelegate`
- Added `webSocketClient` property to:
  - `ServerExtension` for JUnit 5
  - `ServerRule` for JUnit 4
  - `ServerRuleDelegate`

Result:

- Closes #5538

---------

Co-authored-by: minux <[email protected]>
  • Loading branch information
seonWKim and minwoox authored Jun 20, 2024
1 parent 222131a commit 071a342
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import com.linecorp.armeria.client.RestClient;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.client.WebClientBuilder;
import com.linecorp.armeria.client.websocket.WebSocketClient;
import com.linecorp.armeria.client.websocket.WebSocketClientBuilder;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.UnstableApi;
Expand Down Expand Up @@ -68,6 +70,12 @@ public void configure(ServerBuilder sb) throws Exception {
public void configureWebClient(WebClientBuilder wcb) throws Exception {
ServerRule.this.configureWebClient(wcb);
}

@Override
public void configureWebSocketClient(WebSocketClientBuilder wscb)
throws Exception {
ServerRule.this.configureWebSocketClient(wscb);
}
};
}

Expand Down Expand Up @@ -109,6 +117,12 @@ public Server start() {
*/
protected void configureWebClient(WebClientBuilder webClientBuilder) throws Exception {}

/**
* Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}.
* You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}.
*/
protected void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) throws Exception {}

/**
* Stops the {@link Server} asynchronously.
*
Expand Down Expand Up @@ -344,4 +358,13 @@ public RestClient restClient(Consumer<WebClientBuilder> webClientCustomizer) {
requireNonNull(webClientCustomizer, "webClientCustomizer");
return delegate.restClient(webClientCustomizer);
}

/**
* Returns the {@link WebSocketClient} configured
* by {@link #configureWebSocketClient(WebSocketClientBuilder)}.
*/
@UnstableApi
public WebSocketClient webSocketClient() {
return delegate.webSocketClient();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import com.linecorp.armeria.client.RestClient;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.client.WebClientBuilder;
import com.linecorp.armeria.client.websocket.WebSocketClient;
import com.linecorp.armeria.client.websocket.WebSocketClientBuilder;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
Expand All @@ -49,6 +51,7 @@ public abstract class ServerRuleDelegate {
private final boolean autoStart;

private final AtomicReference<WebClient> webClient = new AtomicReference<>();
private final AtomicReference<WebSocketClient> webSocketClient = new AtomicReference<>();

/**
* Creates a new instance.
Expand Down Expand Up @@ -114,6 +117,14 @@ public Server start() {
*/
public abstract void configureWebClient(WebClientBuilder webClientBuilder) throws Exception;

/**
* Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}.
* You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}.
*/
@UnstableApi
public abstract void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder)
throws Exception;

/**
* Stops the {@link Server} asynchronously.
*
Expand Down Expand Up @@ -404,6 +415,25 @@ public RestClient restClient(Consumer<WebClientBuilder> webClientCustomizer) {
return webClient(webClientCustomizer).asRestClient();
}

/**
* Returns the {@link WebSocketClient} configured
* by {@link #configureWebSocketClient(WebSocketClientBuilder)}.
*/
@UnstableApi
public WebSocketClient webSocketClient() {
final WebSocketClient webSocketClient = this.webSocketClient.get();
if (webSocketClient != null) {
return webSocketClient;
}

final WebSocketClient newWebSocketClient = webSocketClientBuilder().build();
if (this.webSocketClient.compareAndSet(null, newWebSocketClient)) {
return newWebSocketClient;
} else {
return this.webSocketClient.get();
}
}

private void ensureStarted() {
// This will ensure that the server has started.
server();
Expand All @@ -422,4 +452,21 @@ private WebClientBuilder webClientBuilder() {
}
return webClientBuilder;
}

private WebSocketClientBuilder webSocketClientBuilder() {
final boolean hasHttps = hasHttps();
final String hostAndPort = hasHttps ? "wss://" + httpsUri().getAuthority()
: "ws://" + httpUri().getAuthority();
final WebSocketClientBuilder webSocketClientBuilder = WebSocketClient.builder(hostAndPort);
if (hasHttps) {
webSocketClientBuilder.factory(ClientFactory.insecure());
}

try {
configureWebSocketClient(webSocketClientBuilder);
} catch (Exception e) {
throw new IllegalStateException("failed to configure a WebSocketClient", e);
}
return webSocketClientBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import com.linecorp.armeria.client.RestClient;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.client.WebClientBuilder;
import com.linecorp.armeria.client.websocket.WebSocketClient;
import com.linecorp.armeria.client.websocket.WebSocketClientBuilder;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.UnstableApi;
Expand Down Expand Up @@ -75,6 +77,12 @@ public void configure(ServerBuilder sb) throws Exception {
public void configureWebClient(WebClientBuilder wcb) throws Exception {
ServerExtension.this.configureWebClient(wcb);
}

@Override
public void configureWebSocketClient(WebSocketClientBuilder wscb)
throws Exception {
ServerExtension.this.configureWebSocketClient(wscb);
}
};
}

Expand Down Expand Up @@ -126,6 +134,12 @@ public Server start() {
*/
protected void configureWebClient(WebClientBuilder webClientBuilder) throws Exception {}

/**
* Configures the {@link WebSocketClient} with the given {@link WebSocketClientBuilder}.
* You can get the configured {@link WebSocketClient} using {@link #webSocketClient()}.
*/
protected void configureWebSocketClient(WebSocketClientBuilder webSocketClientBuilder) throws Exception {}

/**
* Stops the {@link Server} asynchronously.
*
Expand Down Expand Up @@ -370,6 +384,15 @@ public RestClient restClient(Consumer<WebClientBuilder> webClientCustomizer) {
return delegate.restClient(webClientCustomizer);
}

/**
* Returns the {@link WebSocketClient} configured
* by {@link #configureWebSocketClient(WebSocketClientBuilder)}.
*/
@UnstableApi
public WebSocketClient webSocketClient() {
return delegate.webSocketClient();
}

/**
* Determines whether the {@link ServiceRequestContext} should be captured or not.
* This method returns {@code true} by default. Override it to capture the contexts
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.testing.junit5.server;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.linecorp.armeria.client.websocket.WebSocketClient;
import com.linecorp.armeria.client.websocket.WebSocketSession;
import com.linecorp.armeria.common.websocket.WebSocket;
import com.linecorp.armeria.common.websocket.WebSocketFrame;
import com.linecorp.armeria.common.websocket.WebSocketWriter;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.websocket.WebSocketService;
import com.linecorp.armeria.server.websocket.WebSocketServiceHandler;

class ServerExtensionWithWebSocketClientTest {

@RegisterExtension
static ServerExtension wsServer = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
sb.service("/chat", WebSocketService.builder(new WebSocketEchoHandler())
.build());
}
};

@RegisterExtension
static ServerExtension wssServer = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
sb.tlsSelfSigned();
sb.service("/chat", WebSocketService.builder(new WebSocketEchoHandler())
.build());
}
};

@CsvSource({ "true", "false" })
@ParameterizedTest
void webSocketClient(boolean useTls) {
final WebSocketClient webSocketClient = useTls ? wssServer.webSocketClient()
: wsServer.webSocketClient();
final WebSocketSession wsSession = webSocketClient.connect("/chat").join();
assertThat(wsSession).isNotNull();
final WebSocketWriter outbound = wsSession.outbound();
outbound.write("hello");
final String message = useTls ? "wss" : "ws";
outbound.write(message);
outbound.close();
final List<String> responses = wsSession.inbound().collect().join().stream().map(WebSocketFrame::text)
.collect(toImmutableList());
assertThat(responses).contains("hello", message);
}

static final class WebSocketEchoHandler implements WebSocketServiceHandler {

@Override
public WebSocket handle(ServiceRequestContext ctx, WebSocket in) {
final WebSocketWriter writer = WebSocket.streaming();
in.subscribe(new Subscriber<WebSocketFrame>() {
@Override
public void onSubscribe(Subscription s) {
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(WebSocketFrame webSocketFrame) {
writer.write(webSocketFrame);
}

@Override
public void onError(Throwable t) {
writer.close(t);
}

@Override
public void onComplete() {
writer.close();
}
});
return writer;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package com.linecorp.armeria.server.sangria

import com.linecorp.armeria.client.{WebClient, WebClientBuilder}
import com.linecorp.armeria.client.logging.LoggingClient
import com.linecorp.armeria.client.WebClientBuilder
import com.linecorp.armeria.client.websocket.WebSocketClientBuilder
import com.linecorp.armeria.internal.testing.ServerRuleDelegate
import com.linecorp.armeria.server.ServerBuilder
import munit.Suite
Expand All @@ -33,6 +33,8 @@ trait ServerSuite {

protected def configureWebClient: WebClientBuilder => Unit = _ => ()

protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => ()

protected def server: ServerRuleDelegate = delegate

/**
Expand All @@ -46,6 +48,9 @@ trait ServerSuite {
override def configure(sb: ServerBuilder): Unit = configureServer(sb)

override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb)

override def configureWebSocketClient(webSocketClientBuilder: WebSocketClientBuilder): Unit =
self.configureWebSocketClient(webSocketClientBuilder)
}

if (!runServerForEachTest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.linecorp.armeria.server

import com.linecorp.armeria.client.WebClientBuilder
import com.linecorp.armeria.client.websocket.WebSocketClientBuilder
import com.linecorp.armeria.internal.testing.ServerRuleDelegate
import munit.Suite

Expand All @@ -29,6 +30,8 @@ trait ServerSuite {

protected def configureWebClient: WebClientBuilder => Unit = _ => ()

protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => ()

protected def server: ServerRuleDelegate = delegate

/**
Expand All @@ -42,6 +45,9 @@ trait ServerSuite {
override def configure(sb: ServerBuilder): Unit = configureServer(sb)

override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb)

override def configureWebSocketClient(wscb: WebSocketClientBuilder): Unit =
self.configureWebSocketClient(wscb)
}

if (!runServerForEachTest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.linecorp.armeria.server.scalapb

import com.linecorp.armeria.client.WebClientBuilder
import com.linecorp.armeria.client.websocket.WebSocketClientBuilder
import com.linecorp.armeria.internal.testing.ServerRuleDelegate
import com.linecorp.armeria.server.ServerBuilder
import munit.Suite
Expand All @@ -30,6 +31,8 @@ trait ServerSuite {

protected def configureWebClient: WebClientBuilder => Unit = _ => ()

protected def configureWebSocketClient: WebSocketClientBuilder => Unit = _ => ()

protected def server: ServerRuleDelegate = delegate

/**
Expand All @@ -43,6 +46,9 @@ trait ServerSuite {
override def configure(sb: ServerBuilder): Unit = configureServer(sb)

override def configureWebClient(wcb: WebClientBuilder): Unit = self.configureWebClient(wcb)

override def configureWebSocketClient(webSocketClientBuilder: WebSocketClientBuilder): Unit =
self.configureWebSocketClient(webSocketClientBuilder)
}

if (!runServerForEachTest) {
Expand Down

0 comments on commit 071a342

Please sign in to comment.