Skip to content

Commit

Permalink
Merge pull request quarkusio#40857 from michalvavrik/feature/websocke…
Browse files Browse the repository at this point in the history
…ts-close-onauth-expired

WebSocket NEXT: automatically close connection when OIDC extension provides SecurityIdentity and token expires
  • Loading branch information
sberyozkin authored May 28, 2024
2 parents 95d0eae + 514c426 commit 22ec638
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/src/main/asciidoc/websockets-next-reference.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ quarkus.http.auth.permission.secured.policy=authenticated

Other options for securing HTTP upgrade requests, such as using the security annotations, will be explored in the future.

NOTE: When OpenID Connect extension is used and token expires, Quarkus automatically closes connection.

[[websocket-next-configuration-reference]]
== Configuration reference

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package io.quarkus.websockets.next.test.security;

import static io.quarkus.websockets.next.test.security.SecurityTestBase.basicAuth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicReference;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.awaitility.Awaitility;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.security.Authenticated;
import io.quarkus.security.identity.AuthenticationRequestContext;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.security.identity.SecurityIdentityAugmentor;
import io.quarkus.security.runtime.QuarkusSecurityIdentity;
import io.quarkus.security.test.utils.TestIdentityController;
import io.quarkus.security.test.utils.TestIdentityProvider;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.CloseReason;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;

public class AuthenticationExpiredTest {

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@BeforeAll
public static void setupUsers() {
TestIdentityController.resetRoles()
.add("admin", "admin", "admin")
.add("user", "user", "user");
}

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(Endpoint.class, TestIdentityProvider.class,
TestIdentityController.class, WSClient.class, ExpiredIdentityAugmentor.class, SecurityTestBase.class));

@Test
public void testConnectionClosedWhenAuthExpires() {
try (WSClient client = new WSClient(vertx)) {
client.connect(basicAuth("admin", "admin"), endUri);

long threeSecondsFromNow = Duration.ofMillis(System.currentTimeMillis()).plusSeconds(3).toMillis();
for (int i = 1; true; i++) {
if (client.isClosed()) {
break;
} else if (System.currentTimeMillis() > threeSecondsFromNow) {
Assertions.fail("Authentication expired, therefore connection should had been closed");
}
client.sendAndAwaitReply("Hello #" + i + " from ");
}

var receivedMessages = client.getMessages().stream().map(Buffer::toString).toList();
assertTrue(receivedMessages.size() > 2, receivedMessages.toString());
assertTrue(receivedMessages.contains("Hello #1 from admin"), receivedMessages.toString());
assertTrue(receivedMessages.contains("Hello #2 from admin"), receivedMessages.toString());
assertEquals(1008, client.closeStatusCode(), "Expected close status 1008, but got " + client.closeStatusCode());

Awaitility
.await()
.atMost(Duration.ofSeconds(1))
.untilAsserted(() -> assertTrue(Endpoint.CLOSED_MESSAGE.get()
.startsWith("Connection closed with reason 'Authentication expired'")));
}
}

@Singleton
public static class ExpiredIdentityAugmentor implements SecurityIdentityAugmentor {

@Override
public Uni<SecurityIdentity> augment(SecurityIdentity securityIdentity,
AuthenticationRequestContext authenticationRequestContext) {
return Uni
.createFrom()
.item(QuarkusSecurityIdentity
.builder(securityIdentity)
.addAttribute("quarkus.identity.expire-time", expireIn2Seconds())
.build());
}

private static long expireIn2Seconds() {
return Duration.ofMillis(System.currentTimeMillis())
.plusSeconds(2)
.toSeconds();
}
}

@WebSocket(path = "/end")
public static class Endpoint {

static final AtomicReference<String> CLOSED_MESSAGE = new AtomicReference<>();

@Inject
SecurityIdentity currentIdentity;

@Authenticated
@OnTextMessage
String echo(String message) {
return message + currentIdentity.getPrincipal().getName();
}

@OnClose
void close(CloseReason reason, WebSocketConnection connection) {
CLOSED_MESSAGE.set("Connection closed with reason '%s': %s".formatted(reason.getMessage(), connection));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ public void handle(Void event) {
handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnClose callback",
connection);
}
securitySupport.onClose();
onClose.run();
if (timerId != null) {
vertx.cancelTimer(timerId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
package io.quarkus.websockets.next.runtime;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import jakarta.enterprise.inject.Instance;

import org.jboss.logging.Logger;

import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.websockets.next.CloseReason;
import io.vertx.core.Vertx;

public class SecuritySupport {

static final SecuritySupport NOOP = new SecuritySupport(null, null);
private static final Logger LOG = Logger.getLogger(SecuritySupport.class);
static final SecuritySupport NOOP = new SecuritySupport(null, null, null, null);

private final Instance<CurrentIdentityAssociation> currentIdentity;
private final SecurityIdentity identity;
private final Runnable onClose;

SecuritySupport(Instance<CurrentIdentityAssociation> currentIdentity, SecurityIdentity identity) {
SecuritySupport(Instance<CurrentIdentityAssociation> currentIdentity, SecurityIdentity identity, Vertx vertx,
WebSocketConnectionImpl connection) {
this.currentIdentity = currentIdentity;
this.identity = currentIdentity != null ? Objects.requireNonNull(identity) : identity;
if (this.currentIdentity != null) {
this.identity = Objects.requireNonNull(identity);
this.onClose = closeConnectionWhenIdentityExpired(vertx, connection, this.identity);
} else {
this.identity = null;
this.onClose = null;
}
}

/**
Expand All @@ -29,4 +43,25 @@ void start() {
}
}

void onClose() {
if (onClose != null) {
onClose.run();
}
}

private static Runnable closeConnectionWhenIdentityExpired(Vertx vertx, WebSocketConnectionImpl connection,
SecurityIdentity identity) {
if (identity.getAttribute("quarkus.identity.expire-time") instanceof Long expireAt) {
long timerId = vertx.setTimer(TimeUnit.SECONDS.toMillis(expireAt) - System.currentTimeMillis(),
ignored -> connection
.close(new CloseReason(1008, "Authentication expired"))
.subscribe()
.with(
v -> LOG.tracef("Closed connection due to expired authentication: %s", connection),
e -> LOG.errorf("Unable to close connection [%s] after authentication "
+ "expired due to unhandled failure: %s", connection, e)));
return () -> vertx.cancelTimer(timerId);
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ public Handler<RoutingContext> createEndpointHandler(String generatedEndpointCla

@Override
public void handle(RoutingContext ctx) {
SecuritySupport securitySupport = initializeSecuritySupport(container, ctx);

Future<ServerWebSocket> future = ctx.request().toWebSocket();
future.onSuccess(ws -> {
Vertx vertx = VertxCoreRecorder.getVertx().get();
Expand All @@ -101,6 +99,8 @@ public void handle(RoutingContext ctx) {
connectionManager.add(generatedEndpointClass, connection);
LOG.debugf("Connection created: %s", connection);

SecuritySupport securitySupport = initializeSecuritySupport(container, ctx, vertx, connection);

Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass,
config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(),
() -> connectionManager.remove(generatedEndpointClass, connection));
Expand All @@ -109,14 +109,15 @@ public void handle(RoutingContext ctx) {
};
}

SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx) {
SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx, Vertx vertx,
WebSocketConnectionImpl connection) {
Instance<CurrentIdentityAssociation> currentIdentityAssociation = container.select(CurrentIdentityAssociation.class);
if (currentIdentityAssociation.isResolvable()) {
// Security extension is present
// Obtain the current security identity from the handshake request
QuarkusHttpUser user = (QuarkusHttpUser) ctx.user();
if (user != null) {
return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity());
return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity(), vertx, connection);
}
}
return SecuritySupport.NOOP;
Expand Down

0 comments on commit 22ec638

Please sign in to comment.