diff --git a/extensions/grpc/deployment/pom.xml b/extensions/grpc/deployment/pom.xml index 13d7a0ca2b066..3644bf237a396 100644 --- a/extensions/grpc/deployment/pom.xml +++ b/extensions/grpc/deployment/pom.xml @@ -95,6 +95,11 @@ rest-assured test + + io.quarkus + quarkus-elytron-security-properties-file-deployment + test + diff --git a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java index a6d3fcc5770df..80d734ff7aa81 100644 --- a/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java +++ b/extensions/grpc/deployment/src/main/java/io/quarkus/grpc/deployment/GrpcServerProcessor.java @@ -51,6 +51,8 @@ import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.deployment.ApplicationArchive; +import io.quarkus.deployment.Capabilities; +import io.quarkus.deployment.Capability; import io.quarkus.deployment.IsDevelopment; import io.quarkus.deployment.IsNormal; import io.quarkus.deployment.annotations.BuildProducer; @@ -71,6 +73,8 @@ import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; import io.quarkus.grpc.GrpcService; +import io.quarkus.grpc.auth.DefaultAuthExceptionHandlerProvider; +import io.quarkus.grpc.auth.GrpcSecurityInterceptor; import io.quarkus.grpc.deployment.devmode.FieldDefinalizingVisitor; import io.quarkus.grpc.protoc.plugin.MutinyGrpcGenerator; import io.quarkus.grpc.runtime.GrpcContainer; @@ -334,6 +338,7 @@ KubernetesPortBuildItem registerGrpcServiceInKubernetes(List beans, + Capabilities capabilities, List bindables, BuildProducer features) { // @GrpcService is a CDI qualifier beans.produce(new AdditionalBeanBuildItem(GrpcService.class)); @@ -345,15 +350,25 @@ void registerBeans(BuildProducer beans, // Global interceptors are invoked before any of the per-service interceptors beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcRequestContextGrpcInterceptor.class)); features.produce(new FeatureBuildItem(GRPC_SERVER)); + + if (capabilities.isPresent(Capability.SECURITY)) { + beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcSecurityInterceptor.class)); + beans.produce(AdditionalBeanBuildItem.unremovableOf(DefaultAuthExceptionHandlerProvider.class)); + } } else { log.debug("Unable to find beans exposing the `BindableService` interface - not starting the gRPC server"); } } @BuildStep - void registerAdditionalInterceptors(BuildProducer additionalInterceptors) { + void registerAdditionalInterceptors(BuildProducer additionalInterceptors, + Capabilities capabilities) { additionalInterceptors .produce(new AdditionalGlobalInterceptorBuildItem(GrpcRequestContextGrpcInterceptor.class.getName())); + if (capabilities.isPresent(Capability.SECURITY)) { + additionalInterceptors + .produce(new AdditionalGlobalInterceptorBuildItem(GrpcSecurityInterceptor.class.getName())); + } } @BuildStep diff --git a/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/auth/GrpcAuthTest.java b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/auth/GrpcAuthTest.java new file mode 100644 index 0000000000000..c995eb501b7c0 --- /dev/null +++ b/extensions/grpc/deployment/src/test/java/io/quarkus/grpc/auth/GrpcAuthTest.java @@ -0,0 +1,168 @@ +package io.quarkus.grpc.auth; + +import static com.example.security.Security.ThreadInfo.newBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import javax.annotation.security.RolesAllowed; +import javax.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.example.security.SecuredService; +import com.example.security.Security; + +import io.grpc.Metadata; +import io.quarkus.grpc.GrpcClient; +import io.quarkus.grpc.GrpcClientUtils; +import io.quarkus.grpc.GrpcService; +import io.quarkus.security.credential.PasswordCredential; +import io.quarkus.security.identity.request.AuthenticationRequest; +import io.quarkus.security.identity.request.UsernamePasswordAuthenticationRequest; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; + +public class GrpcAuthTest { + + public static final Metadata.Key AUTHORIZATION = Metadata.Key.of("Authorization", + Metadata.ASCII_STRING_MARSHALLER); + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class) + .addClasses(Service.class, BasicGrpcSecurityMechanism.class) + .addPackage(SecuredService.class.getPackage()) + .add(new StringAsset("quarkus.security.users.embedded.enabled=true\n" + + "quarkus.security.users.embedded.users.john=john\n" + + "quarkus.security.users.embedded.roles.john=employees\n" + + "quarkus.security.users.embedded.users.paul=paul\n" + + "quarkus.security.users.embedded.roles.paul=interns\n" + + "quarkus.security.users.embedded.plain-text=true\n" + + "quarkus.http.auth.basic=true"), "application.properties")); + public static final String JOHN_BASIC_CREDS = "am9objpqb2hu"; + public static final String PAUL_BASIC_CREDS = "cGF1bDpwYXVs"; + + @GrpcClient + SecuredService securityClient; + + @Test + void shouldSecureUniEndpoint() { + Metadata headers = new Metadata(); + headers.put(AUTHORIZATION, "Basic " + JOHN_BASIC_CREDS); + SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers); + AtomicInteger resultCount = new AtomicInteger(); + client.unaryCall(Security.Container.newBuilder().setText("woo-hoo").build()) + .subscribe().with(e -> resultCount.incrementAndGet()); + + await().atMost(5, TimeUnit.SECONDS) + .until(() -> resultCount.get() == 1); + } + + @Test + void shouldSecureMultiEndpoint() { + Metadata headers = new Metadata(); + headers.put(AUTHORIZATION, "Basic " + PAUL_BASIC_CREDS); + SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers); + List results = new CopyOnWriteArrayList<>(); + client.streamCall(Multi.createBy().repeating() + .supplier(() -> (Security.Container.newBuilder().setText("woo-hoo").build())).atMost(4)) + .subscribe().with(e -> results.add(e.getIsOnEventLoop())); + + await().atMost(5, TimeUnit.SECONDS) + .until(() -> results.size() == 5); + + assertThat(results.stream().filter(e -> !e)).isEmpty(); + } + + @Test + void shouldFailWithInvalidCredentials() { + Metadata headers = new Metadata(); + headers.put(AUTHORIZATION, "Basic invalid creds"); + SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers); + + AtomicReference error = new AtomicReference<>(); + + AtomicInteger resultCount = new AtomicInteger(); + client.unaryCall(Security.Container.newBuilder().setText("woo-hoo").build()) + .onFailure().invoke(error::set) + .subscribe().with(e -> resultCount.incrementAndGet()); + + await().atMost(5, TimeUnit.SECONDS) + .until(() -> error.get() != null); + } + + @Test + void shouldFailWithInvalidInsufficientRole() { + Metadata headers = new Metadata(); + headers.put(AUTHORIZATION, PAUL_BASIC_CREDS); + SecuredService client = GrpcClientUtils.attachHeaders(securityClient, headers); + + AtomicReference error = new AtomicReference<>(); + + AtomicInteger resultCount = new AtomicInteger(); + client.unaryCall(Security.Container.newBuilder().setText("woo-hoo").build()) + .onFailure().invoke(error::set) + .subscribe().with(e -> resultCount.incrementAndGet()); + + await().atMost(5, TimeUnit.SECONDS) + .until(() -> error.get() != null); + } + + @GrpcService + public static class Service implements SecuredService { + @Override + @RolesAllowed("employees") + public Uni unaryCall(Security.Container request) { + return Uni.createFrom() + .item(newBuilder().setIsOnEventLoop(Context.isOnEventLoopThread()).build()); + } + + @Override + @RolesAllowed("interns") + public Multi streamCall(Multi request) { + return Multi.createBy() + .repeating().supplier(() -> newBuilder().setIsOnEventLoop(Context.isOnEventLoopThread()).build()) + .atMost(5); + } + + } + + @Singleton + public static class BasicGrpcSecurityMechanism implements GrpcSecurityMechanism { + @Override + public boolean handles(Metadata metadata) { + String authString = metadata.get(AUTHORIZATION); + return authString != null && authString.startsWith("Basic "); + } + + @Override + public AuthenticationRequest createAuthenticationRequest(Metadata metadata) { + String authString = metadata.get(AUTHORIZATION); + authString = authString.substring("Basic ".length()); + byte[] decode = Base64.getDecoder().decode(authString); + String plainChallenge = new String(decode, StandardCharsets.UTF_8); + int colonPos; + if ((colonPos = plainChallenge.indexOf(':')) > -1) { + String userName = plainChallenge.substring(0, colonPos); + char[] password = plainChallenge.substring(colonPos + 1).toCharArray(); + return new UsernamePasswordAuthenticationRequest(userName, new PasswordCredential(password)); + } else { + return null; + } + } + } +} diff --git a/extensions/grpc/deployment/src/test/proto/security.proto b/extensions/grpc/deployment/src/test/proto/security.proto new file mode 100644 index 0000000000000..ae36347fde721 --- /dev/null +++ b/extensions/grpc/deployment/src/test/proto/security.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package security; +option java_package = "com.example.security"; + +service SecuredService { + rpc unaryCall(Container) returns (ThreadInfo); + rpc streamCall(stream Container) returns (stream ThreadInfo); +} + +message ThreadInfo { + bool isOnEventLoop = 1; +} + +message Container { + string text = 1; +} \ No newline at end of file diff --git a/extensions/grpc/runtime/pom.xml b/extensions/grpc/runtime/pom.xml index 8814ac15e363d..9b51e339776c6 100644 --- a/extensions/grpc/runtime/pom.xml +++ b/extensions/grpc/runtime/pom.xml @@ -29,6 +29,10 @@ io.quarkus quarkus-arc + + io.quarkus.security + quarkus-security + io.quarkus diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/AuthExceptionHandler.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/AuthExceptionHandler.java new file mode 100644 index 0000000000000..e764c1c93a35c --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/AuthExceptionHandler.java @@ -0,0 +1,78 @@ +package io.quarkus.grpc.auth; + +import javax.enterprise.inject.spi.Prioritized; + +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.Status; +import io.quarkus.security.AuthenticationFailedException; + +/** + * Exception mapper for authentication and authorization exceptions + * + * To alter mapping exceptions, create a subclass of this handler and create an appropriate + * {@link AuthExceptionHandlerProvider} + */ +public class AuthExceptionHandler + extends ForwardingServerCallListener.SimpleForwardingServerCallListener implements Prioritized { + + private final ServerCall serverCall; + private final Metadata metadata; + + public AuthExceptionHandler(ServerCall.Listener listener, ServerCall serverCall, + Metadata metadata) { + super(listener); + this.metadata = metadata; + this.serverCall = serverCall; + } + + @Override + public void onMessage(ReqT message) { + try { + super.onMessage(message); + } catch (RuntimeException e) { + handleException(e, serverCall, metadata); + } + } + + @Override + public void onHalfClose() { + try { + super.onHalfClose(); + } catch (RuntimeException e) { + handleException(e, serverCall, metadata); + } + } + + @Override + public void onReady() { + try { + super.onReady(); + } catch (RuntimeException e) { + handleException(e, serverCall, metadata); + } + } + + /** + * Maps exception to a gRPC error. Override this method to customize the mapping + * + * @param exception exception thrown + * @param serverCall server call to close with error + * @param metadata call metadata + */ + protected void handleException(RuntimeException exception, ServerCall serverCall, Metadata metadata) { + if (exception instanceof AuthenticationFailedException) { + serverCall.close(Status.UNAUTHENTICATED.withDescription(exception.getMessage()), metadata); + } else if (exception instanceof SecurityException) { + serverCall.close(Status.PERMISSION_DENIED.withDescription(exception.getMessage()), metadata); + } else { + throw exception; + } + } + + @Override + public int getPriority() { + return 0; + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/AuthExceptionHandlerProvider.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/AuthExceptionHandlerProvider.java new file mode 100644 index 0000000000000..2d4e0f4792ad3 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/AuthExceptionHandlerProvider.java @@ -0,0 +1,20 @@ +package io.quarkus.grpc.auth; + +import javax.enterprise.inject.spi.Prioritized; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; + +/** + * Provider for AuthExceptionHandler. + * + * To use a custom AuthExceptionHandler, extend {@link AuthExceptionHandler} and implement + * an {@link AuthExceptionHandlerProvider} with priority greater than the default one. + */ +public interface AuthExceptionHandlerProvider extends Prioritized { + int DEFAULT_PRIORITY = 0; + + AuthExceptionHandler createHandler(Listener listener, + ServerCall serverCall, Metadata metadata); +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/DefaultAuthExceptionHandlerProvider.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/DefaultAuthExceptionHandlerProvider.java new file mode 100644 index 0000000000000..db0600c2a1565 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/DefaultAuthExceptionHandlerProvider.java @@ -0,0 +1,21 @@ +package io.quarkus.grpc.auth; + +import javax.inject.Singleton; + +import io.grpc.Metadata; +import io.grpc.ServerCall; + +@Singleton +public class DefaultAuthExceptionHandlerProvider implements AuthExceptionHandlerProvider { + + @Override + public int getPriority() { + return DEFAULT_PRIORITY; + } + + @Override + public AuthExceptionHandler createHandler(ServerCall.Listener listener, + ServerCall serverCall, Metadata metadata) { + return new AuthExceptionHandler<>(listener, serverCall, metadata); + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/GrpcSecurityInterceptor.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/GrpcSecurityInterceptor.java new file mode 100644 index 0000000000000..259ef106ac967 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/GrpcSecurityInterceptor.java @@ -0,0 +1,122 @@ +package io.quarkus.grpc.auth; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Executor; + +import javax.enterprise.inject.Instance; +import javax.enterprise.inject.spi.Prioritized; +import javax.inject.Inject; +import javax.inject.Singleton; + +import org.jboss.logging.Logger; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.quarkus.grpc.GlobalInterceptor; +import io.quarkus.security.AuthenticationFailedException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.identity.IdentityProviderManager; +import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.security.identity.request.AuthenticationRequest; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; +import io.vertx.core.Handler; +import io.vertx.core.Vertx; + +/** + * Security interceptor invoking {@link GrpcSecurityMechanism} implementations + */ +@GlobalInterceptor +@Singleton +public final class GrpcSecurityInterceptor implements ServerInterceptor, Prioritized { + + private static final Logger log = Logger.getLogger(GrpcSecurityInterceptor.class); + + private final IdentityProviderManager identityProviderManager; + private final CurrentIdentityAssociation identityAssociation; + + private final AuthExceptionHandlerProvider exceptionHandlerProvider; + private final List securityMechanisms; + + @Inject + public GrpcSecurityInterceptor( + CurrentIdentityAssociation identityAssociation, + IdentityProviderManager identityProviderManager, + Instance securityMechanisms, + Instance exceptionHandlers) { + this.identityAssociation = identityAssociation; + this.identityProviderManager = identityProviderManager; + + AuthExceptionHandlerProvider maxPrioHandlerProvider = null; + + for (AuthExceptionHandlerProvider handler : exceptionHandlers) { + if (maxPrioHandlerProvider == null || maxPrioHandlerProvider.getPriority() < handler.getPriority()) { + maxPrioHandlerProvider = handler; + } + } + this.exceptionHandlerProvider = maxPrioHandlerProvider; + + List mechanisms = new ArrayList<>(); + for (GrpcSecurityMechanism securityMechanism : securityMechanisms) { + mechanisms.add(securityMechanism); + } + mechanisms.sort(Comparator.comparing(GrpcSecurityMechanism::getPriority)); + this.securityMechanisms = mechanisms; + } + + @Override + public ServerCall.Listener interceptCall(ServerCall serverCall, + Metadata metadata, ServerCallHandler serverCallHandler) { + Exception error = null; + for (GrpcSecurityMechanism securityMechanism : securityMechanisms) { + if (securityMechanism.handles(metadata)) { + try { + AuthenticationRequest authenticationRequest = securityMechanism.createAuthenticationRequest(metadata); + Context context = Vertx.currentContext(); + boolean onEventLoopThread = Context.isOnEventLoopThread(); + + if (authenticationRequest != null) { + Uni auth = identityProviderManager + .authenticate(authenticationRequest) + .emitOn(new Executor() { + @Override + public void execute(Runnable command) { + if (onEventLoopThread) { + context.runOnContext(new Handler<>() { + @Override + public void handle(Void event) { + command.run(); + } + }); + } else { + command.run(); + } + } + }); + identityAssociation.setIdentity(auth); + error = null; + break; + } + } catch (Exception e) { + error = e; + log.warn("Failed to prepare AuthenticationRequest for a gRPC call", e); + } + } + } + if (error != null) { // if parsing for all security mechanisms failed, let's propagate the last exception + identityAssociation.setIdentity(Uni.createFrom() + .failure(new AuthenticationFailedException("Failed to parse authentication data", error))); + } + ServerCall.Listener listener = serverCallHandler.startCall(serverCall, metadata); + return exceptionHandlerProvider.createHandler(listener, serverCall, metadata); + } + + @Override + public int getPriority() { + return Integer.MAX_VALUE - 100; + } +} diff --git a/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/GrpcSecurityMechanism.java b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/GrpcSecurityMechanism.java new file mode 100644 index 0000000000000..c1b336aab1621 --- /dev/null +++ b/extensions/grpc/runtime/src/main/java/io/quarkus/grpc/auth/GrpcSecurityMechanism.java @@ -0,0 +1,36 @@ +package io.quarkus.grpc.auth; + +import io.grpc.Metadata; +import io.quarkus.security.identity.request.AuthenticationRequest; + +/** + * gRPC security mechanism based on gRPC call metadata + * + * To secure your gRPC endpoints, create a CDI bean implementing this interface. + * + * Make sure that an {@link io.quarkus.security.identity.IdentityProvider} for the {@link AuthenticationRequest} + * returned by {@code createAuthenticationRequest} is available by adding a suitable extension to your application. + * + */ +public interface GrpcSecurityMechanism { + int DEFAULT_PRIORITY = 1000; + + /** + * + * @param metadata metadata of the gRPC call + * @return true if and only if the interceptor should handle security for this metadata. An interceptor may decide + * it should not be triggered for a call e.g. if some header is missing in metadata. + */ + boolean handles(Metadata metadata); + + /** + * + * @param metadata metadata of the gRPC call + * @return authentication request based on the metadata + */ + AuthenticationRequest createAuthenticationRequest(Metadata metadata); + + default int getPriority() { + return DEFAULT_PRIORITY; + } +} diff --git a/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityConstrainer.java b/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityConstrainer.java index 75ba078df3060..3438f80a27181 100644 --- a/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityConstrainer.java +++ b/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityConstrainer.java @@ -5,9 +5,11 @@ import javax.inject.Inject; import javax.inject.Singleton; -import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.runtime.BlockingOperationNotAllowedException; +import io.quarkus.security.runtime.SecurityIdentityAssociation; import io.quarkus.security.spi.runtime.SecurityCheck; import io.quarkus.security.spi.runtime.SecurityCheckStorage; +import io.smallrye.mutiny.Uni; /** * @author Michal Szynkiewicz, michal.l.szynkiewicz@gmail.com @@ -15,8 +17,9 @@ @Singleton public class SecurityConstrainer { + public static final Uni CHECK_OK = Uni.createFrom().item(new Object()); @Inject - SecurityIdentity identity; + SecurityIdentityAssociation identity; @Inject SecurityCheckStorage storage; @@ -25,7 +28,26 @@ public void check(Method method, Object[] parameters) { SecurityCheck securityCheck = storage.getSecurityCheck(method); if (securityCheck != null) { - securityCheck.apply(identity, method, parameters); + try { + securityCheck.apply(identity.getIdentity(), method, parameters); + } catch (BlockingOperationNotAllowedException blockingException) { + throw new BlockingOperationNotAllowedException( + "Blocking security check attempted in code running on the event loop. " + + "Make the secured method return an async type, i.e. Uni, Multi or CompletionStage, or " + + "use an authentication mechanism that sets the SecurityIdentity in a blocking manner " + + "prior to delegating the call", + blockingException); + } } } + + public Uni nonBlockingCheck(Method method, Object[] parameters) { + SecurityCheck securityCheck = storage.getSecurityCheck(method); + if (securityCheck != null) { + return identity.getDeferredIdentity() + .onItem() + .invoke(identity -> securityCheck.apply(identity, method, parameters)); + } + return CHECK_OK; + } } diff --git a/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityHandler.java b/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityHandler.java index 11d9bd6ebd4d4..0a58e6fed976f 100644 --- a/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityHandler.java +++ b/extensions/security/runtime/src/main/java/io/quarkus/security/runtime/interceptor/SecurityHandler.java @@ -1,9 +1,16 @@ package io.quarkus.security.runtime.interceptor; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.function.Function; + import javax.inject.Inject; import javax.inject.Singleton; import javax.interceptor.InvocationContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; + /** * @author Michal Szynkiewicz, michal.l.szynkiewicz@gmail.com */ @@ -20,11 +27,75 @@ public Object handle(InvocationContext ic) throws Exception { if (alreadyHandled(ic)) { return ic.proceed(); } - constrainer.check(ic.getMethod(), ic.getParameters()); - return ic.proceed(); + Class returnType = ic.getMethod().getReturnType(); + if (Uni.class.isAssignableFrom(returnType)) { + return constrainer.nonBlockingCheck(ic.getMethod(), ic.getParameters()) + .onItem().transformToUni(new UniContinuation(ic)); + } else if (CompletionStage.class.isAssignableFrom(returnType)) { + return constrainer.nonBlockingCheck(ic.getMethod(), ic.getParameters()) + .subscribeAsCompletionStage() + .thenApply(new CompletionStageContinuation(ic)); + } else if (Multi.class.isAssignableFrom(returnType)) { + return constrainer.nonBlockingCheck(ic.getMethod(), ic.getParameters()) + .onItem().transformToMulti(new MultiContinuation(ic)); + } else { + constrainer.check(ic.getMethod(), ic.getParameters()); + return ic.proceed(); + } } private boolean alreadyHandled(InvocationContext ic) { return ic.getContextData().put(HANDLER_NAME, EXECUTED) != null; } + + private static class UniContinuation implements Function> { + private final InvocationContext ic; + + UniContinuation(InvocationContext invocationContext) { + ic = invocationContext; + } + + @Override + public Uni apply(Object o) { + try { + return (Uni) ic.proceed(); + } catch (Exception e) { + return Uni.createFrom().failure(e); + } + } + } + + private static class CompletionStageContinuation implements Function> { + private final InvocationContext ic; + + CompletionStageContinuation(InvocationContext invocationContext) { + ic = invocationContext; + } + + @Override + public CompletionStage apply(Object o) { + try { + return (CompletionStage) ic.proceed(); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + } + + private static class MultiContinuation implements Function> { + private final InvocationContext ic; + + public MultiContinuation(InvocationContext invocationContext) { + ic = invocationContext; + } + + @Override + public Multi apply(Object o) { + try { + return (Multi) ic.proceed(); + } catch (Exception e) { + return Multi.createFrom().failure(e); + } + } + } } diff --git a/extensions/security/test-utils/src/main/java/io/quarkus/security/test/utils/IdentityMock.java b/extensions/security/test-utils/src/main/java/io/quarkus/security/test/utils/IdentityMock.java index a7fffb3e75c83..7788d521bfcc1 100644 --- a/extensions/security/test-utils/src/main/java/io/quarkus/security/test/utils/IdentityMock.java +++ b/extensions/security/test-utils/src/main/java/io/quarkus/security/test/utils/IdentityMock.java @@ -9,9 +9,11 @@ import javax.annotation.Priority; import javax.enterprise.context.ApplicationScoped; import javax.enterprise.inject.Alternative; +import javax.inject.Inject; import io.quarkus.security.credential.Credential; import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.security.runtime.SecurityIdentityAssociation; import io.smallrye.mutiny.Uni; /** @@ -89,4 +91,21 @@ public Uni checkPermission(Permission permission) { return null; } + @Alternative + @ApplicationScoped + @Priority(1) + public static class IdentityAssociationMock extends SecurityIdentityAssociation { + @Inject + IdentityMock identity; + + @Override + public Uni getDeferredIdentity() { + return Uni.createFrom().item(identity); + } + + @Override + public SecurityIdentity getIdentity() { + return identity; + } + } }