Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add proper identity propagation to WebSockets #20157

Merged
merged 1 commit into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bom/application/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
<opentelemetry.version>1.6.0</opentelemetry.version>
<opentelemetry-alpha.version>1.6.0-alpha</opentelemetry-alpha.version>
<jaeger.version>1.6.0</jaeger.version>
<quarkus-http.version>4.1.1</quarkus-http.version>
<quarkus-http.version>4.1.2</quarkus-http.version>
<micrometer.version>1.7.4</micrometer.version>
<google-auth.version>0.22.0</google-auth.version>
<microprofile-config-api.version>2.0</microprofile-config-api.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public final class OidcUtils {
* ignoring those which are located inside a pair of the double quotes.
*/
private static final Pattern CLAIM_PATH_PATTERN = Pattern.compile("\\/(?=(?:(?:[^\"]*\"){2})*[^\"]*$)");
public static final String QUARKUS_IDENTITY_EXPIRE_TIME = "quarkus.identity.expire-time";

private OidcUtils() {

Expand Down Expand Up @@ -163,6 +164,7 @@ static QuarkusSecurityIdentity validateAndCreateIdentity(
} catch (InvalidJwtException e) {
throw new AuthenticationFailedException(e);
}
builder.addAttribute(QUARKUS_IDENTITY_EXPIRE_TIME, jwtPrincipal.getExpirationTime());
builder.setPrincipal(jwtPrincipal);
setSecurityIdentityRoles(builder, config, rolesJson);
setSecurityIdentityUserInfo(builder, userInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.ConfigDescriptionBuildItem;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
Expand All @@ -62,6 +63,7 @@
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.annotations.ConfigPhase;
import io.quarkus.runtime.metrics.MetricsFactory;
import io.quarkus.runtime.util.HashUtil;
import io.quarkus.smallrye.health.deployment.spi.HealthBuildItem;
Expand Down Expand Up @@ -142,7 +144,8 @@ void collectComponents(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
BuildProducer<MediatorBuildItem> mediatorMethods,
BuildProducer<EmitterBuildItem> emitters,
BuildProducer<ChannelBuildItem> channels,
BuildProducer<ValidationErrorBuildItem> validationErrors) {
BuildProducer<ValidationErrorBuildItem> validationErrors,
BuildProducer<ConfigDescriptionBuildItem> configDescriptionBuildItemBuildProducer) {

// We need to collect all business methods annotated with @Incoming/@Outgoing first
for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) {
Expand All @@ -162,10 +165,20 @@ void collectComponents(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
validationErrors.produce(new ValidationErrorBuildItem(
new DeploymentException("Empty @Incoming annotation on method " + method)));
}
if (incoming != null) {
configDescriptionBuildItemBuildProducer.produce(new ConfigDescriptionBuildItem(
"mp.messaging.incoming." + incoming.value().asString() + ".connector", String.class, null,
"The connector to use", null, null, ConfigPhase.BUILD_TIME));
}
if (outgoing != null && outgoing.value().asString().isEmpty()) {
validationErrors.produce(new ValidationErrorBuildItem(
new DeploymentException("Empty @Outgoing annotation on method " + method)));
}
if (outgoing != null) {
configDescriptionBuildItemBuildProducer.produce(new ConfigDescriptionBuildItem(
"mp.messaging.outgoing." + outgoing.value().asString() + ".connector", String.class, null,
"The connector to use", null, null, ConfigPhase.BUILD_TIME));
}
if (isSynthetic(method.flags())) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.MultiBuildItem;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.undertow.websockets.client.runtime.ServerWebSocketContainerFactory;
import io.quarkus.websockets.client.runtime.ServerWebSocketContainerFactory;

public final class ServerWebSocketContainerFactoryBuildItem extends SimpleBuildItem {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import java.lang.reflect.Modifier;
import java.util.Collection;
Expand Down Expand Up @@ -37,7 +37,7 @@
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.undertow.deployment.ServletContextAttributeBuildItem;
import io.quarkus.undertow.websockets.client.runtime.WebsocketCoreRecorder;
import io.quarkus.websockets.client.runtime.WebsocketCoreRecorder;
import io.undertow.websockets.DefaultContainerConfigurator;
import io.undertow.websockets.ServerWebSocketContainer;
import io.undertow.websockets.UndertowContainerProvider;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.deployment;
package io.quarkus.websockets.client.deployment;

import io.quarkus.runtime.annotations.ConfigItem;
import io.quarkus.runtime.annotations.ConfigPhase;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
Expand Down
4 changes: 4 additions & 0 deletions extensions/websockets/client/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus.security</groupId>
<artifactId>quarkus-security</artifactId>
</dependency>
<dependency>
<groupId>jakarta.websocket</groupId>
<artifactId>jakarta.websocket-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.quarkus.websockets;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.Decoder;
import javax.websocket.Encoder;
import javax.websocket.Extension;

import io.vertx.core.http.HttpHeaders;

public class BearerTokenClientEndpointConfigurator implements ClientEndpointConfig {

final String token;

public BearerTokenClientEndpointConfigurator(String token) {
this.token = token;
}

@Override
public List<String> getPreferredSubprotocols() {
return Collections.emptyList();
}

@Override
public List<Extension> getExtensions() {
return Collections.emptyList();
}

@Override
public Configurator getConfigurator() {
return new Configurator() {
@Override
public void beforeRequest(Map<String, List<String>> headers) {
headers.put(HttpHeaders.AUTHORIZATION.toString(), Collections.singletonList("Bearer " + token));
}
};
}

@Override
public List<Class<? extends Encoder>> getEncoders() {
return Collections.emptyList();
}

@Override
public List<Class<? extends Decoder>> getDecoders() {
return Collections.emptyList();
}

@Override
public Map<String, Object> getUserProperties() {
return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.client.runtime;
package io.quarkus.websockets.client.runtime;

import java.util.concurrent.Executor;
import java.util.function.Supplier;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkus.undertow.websockets.client.runtime;
package io.quarkus.websockets.client.runtime;

import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.function.Supplier;
Expand All @@ -17,5 +18,6 @@ public interface ServerWebSocketContainerFactory {
ServerWebSocketContainer create(ObjectIntrospecter objectIntrospecter, ClassLoader classLoader,
Supplier<EventLoopGroup> eventLoopSupplier, List<ContextSetupHandler> contextSetupHandlers,
boolean dispatchToWorker, InetSocketAddress clientBindAddress, WebSocketReconnectHandler reconnectHandler,
Supplier<Executor> executorSupplier, List<Extension> installedExtensions, int maxFrameSize);
Supplier<Executor> executorSupplier, List<Extension> installedExtensions, int maxFrameSize,
Supplier<Principal> currentUserSupplier);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkus.websockets.client.runtime;

import java.security.Principal;

import io.quarkus.security.identity.SecurityIdentity;

public class WebSocketPrincipal implements Principal {

final SecurityIdentity securityIdentity;

public WebSocketPrincipal(SecurityIdentity securityIdentity) {
this.securityIdentity = securityIdentity;
}

@Override
public String getName() {
return securityIdentity.getPrincipal().getName();
}

public SecurityIdentity getSecurityIdentity() {
return securityIdentity;
}

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package io.quarkus.undertow.websockets.client.runtime;
package io.quarkus.websockets.client.runtime;

import java.security.Principal;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.Supplier;

import javax.enterprise.inject.Instance;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerApplicationConfig;
Expand All @@ -19,9 +21,11 @@
import io.quarkus.arc.runtime.BeanContainer;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.vertx.core.runtime.VertxCoreRecorder;
import io.undertow.websockets.ServerWebSocketContainer;
import io.undertow.websockets.UndertowContainerProvider;
import io.undertow.websockets.UndertowSession;
import io.undertow.websockets.WebSocketDeploymentInfo;
import io.undertow.websockets.util.ContextSetupHandler;
import io.undertow.websockets.util.ObjectFactory;
Expand Down Expand Up @@ -115,6 +119,9 @@ public RuntimeValue<ServerWebSocketContainer> createServerContainer(BeanContaine
if (serverContainerFactory == null) {
serverContainerFactory = ServerWebSocketContainer::new;
}
Instance<CurrentIdentityAssociation> currentIdentityAssociation = Arc.container()
.select(CurrentIdentityAssociation.class);

ServerWebSocketContainer container = serverContainerFactory.create(new ObjectIntrospecter() {
@Override
public <T> ObjectFactory<T> createInstanceFactory(Class<T> clazz) {
Expand Down Expand Up @@ -147,16 +154,31 @@ public EventLoopGroup get() {
@Override
public <T, C> Action<T, C> create(Action<T, C> action) {
return new Action<T, C>() {

CurrentIdentityAssociation getCurrentIdentityAssociation() {
if (currentIdentityAssociation.isResolvable()) {
return currentIdentityAssociation.get();
}
return null;
}

@Override
public T call(C context) throws Exception {
public T call(C context, UndertowSession session) throws Exception {
ClassLoader old = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(cl);
boolean required = !requestContext.isActive();
if (required) {
requestContext.activate();
Principal p = session.getUserPrincipal();
if (p instanceof WebSocketPrincipal) {
var current = getCurrentIdentityAssociation();
if (current != null) {
current.setIdentity(((WebSocketPrincipal) p).getSecurityIdentity());
}
}
}
try {
return action.call(context);
return action.call(context, session);
} finally {
try {
if (required) {
Expand All @@ -175,7 +197,15 @@ public T call(C context) throws Exception {
null,
info.getExecutor(),
Collections.emptyList(),
info.getMaxFrameSize());
info.getMaxFrameSize(), new Supplier<Principal>() {
@Override
public Principal get() {
if (currentIdentityAssociation.isResolvable()) {
return new WebSocketPrincipal(currentIdentityAssociation.get().getIdentity());
}
return null;
}
});
for (Class<?> i : info.getAnnotatedEndpoints()) {
container.addEndpoint(i);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.deployment;
package io.quarkus.websockets.deployment;

import java.lang.reflect.Modifier;
import java.util.Collection;
Expand All @@ -20,13 +20,13 @@
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.undertow.websockets.client.deployment.AnnotatedWebsocketEndpointBuildItem;
import io.quarkus.undertow.websockets.client.deployment.ServerWebSocketContainerBuildItem;
import io.quarkus.undertow.websockets.client.deployment.ServerWebSocketContainerFactoryBuildItem;
import io.quarkus.undertow.websockets.client.deployment.WebSocketDeploymentInfoBuildItem;
import io.quarkus.undertow.websockets.client.deployment.WebsocketClientProcessor;
import io.quarkus.undertow.websockets.runtime.WebsocketServerRecorder;
import io.quarkus.vertx.http.deployment.FilterBuildItem;
import io.quarkus.websockets.client.deployment.AnnotatedWebsocketEndpointBuildItem;
import io.quarkus.websockets.client.deployment.ServerWebSocketContainerBuildItem;
import io.quarkus.websockets.client.deployment.ServerWebSocketContainerFactoryBuildItem;
import io.quarkus.websockets.client.deployment.WebSocketDeploymentInfoBuildItem;
import io.quarkus.websockets.client.deployment.WebsocketClientProcessor;
import io.quarkus.websockets.runtime.WebsocketServerRecorder;

public class ServerWebSocketProcessor {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import javax.enterprise.context.RequestScoped;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import javax.inject.Inject;
import javax.websocket.OnMessage;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkus.undertow.websockets.test;
package io.quarkus.websockets.test;

import java.net.URI;
import java.util.concurrent.LinkedBlockingDeque;
Expand Down
Loading