Skip to content

Commit

Permalink
Add proper identity propagation to WebSockets
Browse files Browse the repository at this point in the history
Fixes #16602
  • Loading branch information
stuartwdouglas committed Sep 15, 2021
1 parent db2f6b5 commit 6f265d6
Show file tree
Hide file tree
Showing 24 changed files with 292 additions and 35 deletions.
2 changes: 1 addition & 1 deletion bom/application/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
<opentelemetry.version>1.5.0</opentelemetry.version>
<opentelemetry-alpha.version>1.5.0-alpha</opentelemetry-alpha.version>
<jaeger.version>1.6.0</jaeger.version>
<quarkus-http.version>4.1.1</quarkus-http.version>
<quarkus-http.version>4.1.2</quarkus-http.version>
<micrometer.version>1.7.3</micrometer.version>
<google-auth.version>0.22.0</google-auth.version>
<microprofile-config-api.version>2.0</microprofile-config-api.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.ConfigDescriptionBuildItem;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
Expand All @@ -62,6 +63,7 @@
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.annotations.ConfigPhase;
import io.quarkus.runtime.metrics.MetricsFactory;
import io.quarkus.runtime.util.HashUtil;
import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem;
Expand Down Expand Up @@ -142,7 +144,8 @@ void collectComponents(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
BuildProducer<MediatorBuildItem> mediatorMethods,
BuildProducer<EmitterBuildItem> emitters,
BuildProducer<ChannelBuildItem> channels,
BuildProducer<ValidationErrorBuildItem> validationErrors) {
BuildProducer<ValidationErrorBuildItem> validationErrors,
BuildProducer<ConfigDescriptionBuildItem> configDescriptionBuildItemBuildProducer) {

// We need to collect all business methods annotated with @Incoming/@Outgoing first
for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) {
Expand All @@ -162,10 +165,20 @@ void collectComponents(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
validationErrors.produce(new ValidationErrorBuildItem(
new DeploymentException("Empty @Incoming annotation on method " + method)));
}
if (incoming != null) {
configDescriptionBuildItemBuildProducer.produce(new ConfigDescriptionBuildItem(
"mp.messaging.incoming." + incoming.value().asString() + ".connector", String.class, null,
"The connector to use", null, null, ConfigPhase.BUILD_TIME));
}
if (outgoing != null && outgoing.value().asString().isEmpty()) {
validationErrors.produce(new ValidationErrorBuildItem(
new DeploymentException("Empty @Outgoing annotation on method " + method)));
}
if (outgoing != null) {
configDescriptionBuildItemBuildProducer.produce(new ConfigDescriptionBuildItem(
"mp.messaging.outgoing." + outgoing.value().asString() + ".connector", String.class, null,
"The connector to use", null, null, ConfigPhase.BUILD_TIME));
}
if (isSynthetic(method.flags())) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.MultiBuildItem;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.undertow.websockets.client.runtime.ServerWebSocketContainerFactory;
import io.quarkus.websockets.client.runtime.ServerWebSocketContainerFactory;

public final class ServerWebSocketContainerFactoryBuildItem extends SimpleBuildItem {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import java.lang.reflect.Modifier;
import java.util.Collection;
Expand Down Expand Up @@ -37,7 +37,7 @@
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.undertow.deployment.ServletContextAttributeBuildItem;
import io.quarkus.undertow.websockets.client.runtime.WebsocketCoreRecorder;
import io.quarkus.websockets.client.runtime.WebsocketCoreRecorder;
import io.undertow.websockets.DefaultContainerConfigurator;
import io.undertow.websockets.ServerWebSocketContainer;
import io.undertow.websockets.UndertowContainerProvider;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.runtime.annotations.ConfigItem;
import io.quarkus.runtime.annotations.ConfigPhase;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
Expand Down
4 changes: 4 additions & 0 deletions extensions/websockets/client/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus.security</groupId>
<artifactId>quarkus-security</artifactId>
</dependency>
<dependency>
<groupId>jakarta.websocket</groupId>
<artifactId>jakarta.websocket-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.quarkus.websockets;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.Decoder;
import javax.websocket.Encoder;
import javax.websocket.Extension;

import io.vertx.core.http.HttpHeaders;

public class BearerTokenClientEndpointConfigurator implements ClientEndpointConfig {

final String token;

public BearerTokenClientEndpointConfigurator(String token) {
this.token = token;
}

@Override
public List<String> getPreferredSubprotocols() {
return Collections.emptyList();
}

@Override
public List<Extension> getExtensions() {
return Collections.emptyList();
}

@Override
public Configurator getConfigurator() {
return new Configurator() {
@Override
public void beforeRequest(Map<String, List<String>> headers) {
headers.put(HttpHeaders.AUTHORIZATION.toString(), Collections.singletonList("Bearer " + token));
}
};
}

@Override
public List<Class<? extends Encoder>> getEncoders() {
return Collections.emptyList();
}

@Override
public List<Class<? extends Decoder>> getDecoders() {
return Collections.emptyList();
}

@Override
public Map<String, Object> getUserProperties() {
return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.runtime;
package io.quarkus.websockets.client.runtime;

import java.util.concurrent.Executor;
import java.util.function.Supplier;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkus.undertow.websockets.client.runtime;
package io.quarkus.websockets.client.runtime;

import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.Supplier;
Expand All @@ -17,5 +18,6 @@ public interface ServerWebSocketContainerFactory {
ServerWebSocketContainer create(ObjectIntrospecter objectIntrospecter, ClassLoader classLoader,
Supplier<EventLoopGroup> eventLoopSupplier, List<ContextSetupHandler> contextSetupHandlers,
boolean dispatchToWorker, InetSocketAddress clientBindAddress, WebSocketReconnectHandler reconnectHandler,
Supplier<Executor> executorSupplier, List<Extension> installedExtensions, int maxFrameSize);
Supplier<Executor> executorSupplier, List<Extension> installedExtensions, int maxFrameSize,
Supplier<Principal> currentUserSupplier);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkus.websockets.client.runtime;

import java.security.Principal;

import io.quarkus.security.identity.SecurityIdentity;

public class WebSocketPrincipal implements Principal {

final SecurityIdentity securityIdentity;

public WebSocketPrincipal(SecurityIdentity securityIdentity) {
this.securityIdentity = securityIdentity;
}

@Override
public String getName() {
return securityIdentity.getPrincipal().getName();
}

public SecurityIdentity getSecurityIdentity() {
return securityIdentity;
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package io.quarkus.undertow.websockets.client.runtime;
package io.quarkus.websockets.client.runtime;

import java.security.Principal;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.Supplier;

import javax.enterprise.inject.Instance;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerApplicationConfig;
Expand All @@ -19,9 +21,11 @@
import io.quarkus.arc.runtime.BeanContainer;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.vertx.core.runtime.VertxCoreRecorder;
import io.undertow.websockets.ServerWebSocketContainer;
import io.undertow.websockets.UndertowContainerProvider;
import io.undertow.websockets.UndertowSession;
import io.undertow.websockets.WebSocketDeploymentInfo;
import io.undertow.websockets.util.ContextSetupHandler;
import io.undertow.websockets.util.ObjectFactory;
Expand Down Expand Up @@ -115,6 +119,9 @@ public RuntimeValue<ServerWebSocketContainer> createServerContainer(BeanContaine
if (serverContainerFactory == null) {
serverContainerFactory = ServerWebSocketContainer::new;
}
Instance<CurrentIdentityAssociation> currentIdentityAssociation = Arc.container()
.select(CurrentIdentityAssociation.class);

ServerWebSocketContainer container = serverContainerFactory.create(new ObjectIntrospecter() {
@Override
public <T> ObjectFactory<T> createInstanceFactory(Class<T> clazz) {
Expand Down Expand Up @@ -147,16 +154,31 @@ public EventLoopGroup get() {
@Override
public <T, C> Action<T, C> create(Action<T, C> action) {
return new Action<T, C>() {

CurrentIdentityAssociation getCurrentIdentityAssociation() {
if (currentIdentityAssociation.isResolvable()) {
return currentIdentityAssociation.get();
}
return null;
}

@Override
public T call(C context) throws Exception {
public T call(C context, UndertowSession session) throws Exception {
ClassLoader old = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(cl);
boolean required = !requestContext.isActive();
if (required) {
requestContext.activate();
Principal p = session.getUserPrincipal();
if (p instanceof WebSocketPrincipal) {
var current = getCurrentIdentityAssociation();
if (current != null) {
current.setIdentity(((WebSocketPrincipal) p).getSecurityIdentity());
}
}
}
try {
return action.call(context);
return action.call(context, session);
} finally {
try {
if (required) {
Expand All @@ -175,7 +197,15 @@ public T call(C context) throws Exception {
null,
info.getExecutor(),
Collections.emptyList(),
info.getMaxFrameSize());
info.getMaxFrameSize(), new Supplier<Principal>() {
@Override
public Principal get() {
if (currentIdentityAssociation.isResolvable()) {
return new WebSocketPrincipal(currentIdentityAssociation.get().getIdentity());
}
return null;
}
});
for (Class<?> i : info.getAnnotatedEndpoints()) {
container.addEndpoint(i);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.deployment;
package io.quarkus.websockets.deployment;

import java.lang.reflect.Modifier;
import java.util.Collection;
Expand All @@ -20,13 +20,13 @@
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.undertow.websockets.client.deployment.AnnotatedWebsocketEndpointBuildItem;
import io.quarkus.undertow.websockets.client.deployment.ServerWebSocketContainerBuildItem;
import io.quarkus.undertow.websockets.client.deployment.ServerWebSocketContainerFactoryBuildItem;
import io.quarkus.undertow.websockets.client.deployment.WebSocketDeploymentInfoBuildItem;
import io.quarkus.undertow.websockets.client.deployment.WebsocketClientProcessor;
import io.quarkus.undertow.websockets.runtime.WebsocketServerRecorder;
import io.quarkus.vertx.http.deployment.FilterBuildItem;
import io.quarkus.websockets.client.deployment.AnnotatedWebsocketEndpointBuildItem;
import io.quarkus.websockets.client.deployment.ServerWebSocketContainerBuildItem;
import io.quarkus.websockets.client.deployment.ServerWebSocketContainerFactoryBuildItem;
import io.quarkus.websockets.client.deployment.WebSocketDeploymentInfoBuildItem;
import io.quarkus.websockets.client.deployment.WebsocketClientProcessor;
import io.quarkus.websockets.runtime.WebsocketServerRecorder;

public class ServerWebSocketProcessor {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import javax.enterprise.context.RequestScoped;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import javax.inject.Inject;
import javax.websocket.OnMessage;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import java.net.URI;
import java.util.concurrent.LinkedBlockingDeque;
Expand Down
Loading

0 comments on commit 6f265d6

Please sign in to comment.