From b02154a69918c82a18077e740ebe61aac18d8953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20Vav=C5=99=C3=ADk?= Date: Tue, 4 Jun 2024 13:19:11 +0200 Subject: [PATCH] WebSockets Next: secure HTTP upgrade with annotation --- .../asciidoc/websockets-next-reference.adoc | 54 ++++- .../deployment/PermissionSecurityChecks.java | 114 +++++++---- .../RegisterClassSecurityCheckBuildItem.java | 21 ++ .../deployment/SecurityProcessor.java | 184 +++++++++++++----- ...ClassSecurityCheckAnnotationBuildItem.java | 32 +++ .../ClassSecurityCheckStorageBuildItem.java | 52 +++++ .../spi/SecurityTransformerUtils.java | 4 + .../next/deployment/WebSocketProcessor.java | 71 +++++++ .../HttpUpgradeAnnotationTransformerTest.java | 123 ++++++++++++ ...ttpUpgradeAuthenticatedAnnotationTest.java | 147 ++++++++++++++ .../HttpUpgradeDenyAllAnnotationTest.java | 87 +++++++++ ...gradePermissionsAllowedAnnotationTest.java | 142 ++++++++++++++ .../HttpUpgradePermitAllAnnotationTest.java | 86 ++++++++ .../HttpUpgradeRedirectOnFailureTest.java | 115 +++++++++++ ...HttpUpgradeRolesAllowedAnnotationTest.java | 104 ++++++++++ .../next/test/security/SecurityTestBase.java | 6 +- .../websockets/next/HttpUpgradeCheck.java | 7 +- .../next/WebSocketsServerRuntimeConfig.java | 16 ++ .../runtime/HttpUpgradeSecurityCheck.java | 53 +++++ .../next/runtime/WebSocketServerRecorder.java | 19 +- 20 files changed, 1332 insertions(+), 105 deletions(-) create mode 100644 extensions/security/deployment/src/main/java/io/quarkus/security/deployment/RegisterClassSecurityCheckBuildItem.java create mode 100644 extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckAnnotationBuildItem.java create mode 100644 extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckStorageBuildItem.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAnnotationTransformerTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAuthenticatedAnnotationTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeDenyAllAnnotationTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermissionsAllowedAnnotationTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermitAllAnnotationTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRedirectOnFailureTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRolesAllowedAnnotationTest.java create mode 100644 extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/HttpUpgradeSecurityCheck.java diff --git a/docs/src/main/asciidoc/websockets-next-reference.adoc b/docs/src/main/asciidoc/websockets-next-reference.adoc index dc0d890d721ca1..e8afeae5b50a55 100644 --- a/docs/src/main/asciidoc/websockets-next-reference.adoc +++ b/docs/src/main/asciidoc/websockets-next-reference.adoc @@ -655,18 +655,58 @@ public class Endpoint { `SecurityIdentity` is initially created during a secure HTTP upgrade and associated with the websocket connection. -Currently, for an HTTP upgrade be secured, users must configure an HTTP policy protecting the HTTP upgrade path. -For example, to secure the `open()` method in the above websocket endpoint, one can add the following authentication policy: +NOTE: When OpenID Connect extension is used and token expires, Quarkus automatically closes connection. -[source,properties] +== Secure HTTP upgrade + +An HTTP upgrade is secured when standard security annotation is placed on an endpoint class or an HTTP Security policy is defined. +The advantage of securing HTTP upgrade is less processing, the authorization is performed early and only once. +You should always prefer HTTP upgrade security unless, like in th example above, you need to perform action on error. + +.Use standard security annotation to secure an HTTP upgrade +[source, java] ---- -quarkus.http.auth.permission.secured.paths=/end -quarkus.http.auth.permission.secured.policy=authenticated +package io.quarkus.websockets.next.test.security; + +import io.quarkus.security.Authenticated; +import jakarta.inject.Inject; + +import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; + +@Authenticated <1> +@WebSocket(path = "/end") +public class Endpoint { + + @Inject + SecurityIdentity currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @OnTextMessage + String echo(String message) { + return message; + } +} ---- +<1> Initial HTTP handshake ends with the 401 status for anonymous users. +You can also redirect the handshake request on authorization failure with the `quarkus.websockets-next.server.security.auth-failure-redirect-url` configuration property. -Other options for securing HTTP upgrade requests, such as using the security annotations, will be explored in the future. +IMPORTANT: HTTP upgrade is only secured when a security annotation is declared on an endpoint class next to the `@WebSocket` annotation. +Placing a security annotation on an endpoint bean will not secure bean methods, only the HTTP upgrade. +You must always verify that your endpoint is secured as intended. -NOTE: When OpenID Connect extension is used and token expires, Quarkus automatically closes connection. +.Use HTTP Security policy to secure an HTTP upgrade +[source,properties] +---- +quarkus.http.auth.permission.http-upgrade.paths=/end +quarkus.http.auth.permission.http-upgrade.policy=authenticated +---- == Inspect and/or reject HTTP upgrade diff --git a/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/PermissionSecurityChecks.java b/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/PermissionSecurityChecks.java index 04f9bfa314eda1..a53a93a5bb5080 100644 --- a/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/PermissionSecurityChecks.java +++ b/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/PermissionSecurityChecks.java @@ -1,5 +1,6 @@ package io.quarkus.security.deployment; +import static io.quarkus.arc.processor.DotNames.CLASS; import static io.quarkus.arc.processor.DotNames.STRING; import static io.quarkus.security.PermissionsAllowed.AUTODETECTED; import static io.quarkus.security.PermissionsAllowed.PERMISSION_TO_ACTION_SEPARATOR; @@ -34,7 +35,9 @@ interface PermissionSecurityChecks { - Map get(); + Map getMethodSecurityChecks(); + + Map getClassNameSecurityChecks(); Set permissionClasses(); @@ -43,8 +46,8 @@ final class PermissionSecurityChecksBuilder { private static final DotName STRING_PERMISSION = DotName.createSimple(StringPermission.class); private static final DotName PERMISSIONS_ALLOWED_INTERCEPTOR = DotName .createSimple(PermissionsAllowedInterceptor.class); - private final Map>> methodToPermissionKeys = new HashMap<>(); - private final Map methodToPredicate = new HashMap<>(); + private final Map>> targetToPermissionKeys = new HashMap<>(); + private final Map targetToPredicate = new HashMap<>(); private final Map classSignatureToConstructor = new HashMap<>(); private final SecurityCheckRecorder recorder; @@ -53,22 +56,37 @@ public PermissionSecurityChecksBuilder(SecurityCheckRecorder recorder) { } PermissionSecurityChecks build() { + final Map cache = new HashMap<>(); + final Map methodToCheck = new HashMap<>(); + final Map classNameToCheck = new HashMap<>(); + for (var targetToPredicate : targetToPredicate.entrySet()) { + SecurityCheck check = cache.computeIfAbsent(targetToPredicate.getValue(), + new Function() { + @Override + public SecurityCheck apply(LogicalAndPermissionPredicate predicate) { + return createSecurityCheck(predicate); + } + }); + + var annotationTarget = targetToPredicate.getKey(); + if (annotationTarget.kind() == AnnotationTarget.Kind.CLASS) { + DotName className = annotationTarget.asClass().name(); + classNameToCheck.put(className, check); + } else { + MethodInfo securedMethod = annotationTarget.asMethod(); + methodToCheck.put(securedMethod, check); + } + } + return new PermissionSecurityChecks() { @Override - public Map get() { - final Map cache = new HashMap<>(); - final Map methodToCheck = new HashMap<>(); - for (var methodToPredicate : methodToPredicate.entrySet()) { - SecurityCheck check = cache.computeIfAbsent(methodToPredicate.getValue(), - new Function() { - @Override - public SecurityCheck apply(LogicalAndPermissionPredicate predicate) { - return createSecurityCheck(predicate); - } - }); - methodToCheck.put(methodToPredicate.getKey(), check); - } - return methodToCheck; + public Map getMethodSecurityChecks() { + return Map.copyOf(methodToCheck); + } + + @Override + public Map getClassNameSecurityChecks() { + return Map.copyOf(classNameToCheck); } @Override @@ -99,8 +117,8 @@ public Set permissionClasses() { */ PermissionSecurityChecksBuilder createPermissionPredicates() { Map permissionCache = new HashMap<>(); - for (Map.Entry>> entry : methodToPermissionKeys.entrySet()) { - final MethodInfo securedMethod = entry.getKey(); + for (var entry : targetToPermissionKeys.entrySet()) { + final AnnotationTarget securedTarget = entry.getKey(); final LogicalAndPermissionPredicate predicate = new LogicalAndPermissionPredicate(); // 'AND' operands @@ -113,7 +131,7 @@ PermissionSecurityChecksBuilder createPermissionPredicates() { // 'AND' operands for (PermissionKey permissionKey : permissionKeys) { - var permission = createPermission(permissionKey, securedMethod, permissionCache); + var permission = createPermission(permissionKey, securedTarget, permissionCache); if (permission.isComputed()) { predicate.markAsComputed(); } @@ -128,7 +146,7 @@ PermissionSecurityChecksBuilder createPermissionPredicates() { predicate.and(orPredicate); for (PermissionKey permissionKey : permissionKeys) { - var permission = createPermission(permissionKey, securedMethod, permissionCache); + var permission = createPermission(permissionKey, securedTarget, permissionCache); if (permission.isComputed()) { predicate.markAsComputed(); } @@ -136,7 +154,7 @@ PermissionSecurityChecksBuilder createPermissionPredicates() { } } } - methodToPredicate.put(securedMethod, predicate); + targetToPredicate.put(securedTarget, predicate); } return this; } @@ -153,7 +171,7 @@ private boolean isInclusive(List permissionKeys) { } PermissionSecurityChecksBuilder validatePermissionClasses(IndexView index) { - for (List> keyLists : methodToPermissionKeys.values()) { + for (List> keyLists : targetToPermissionKeys.values()) { for (List keyList : keyLists) { for (PermissionKey key : keyList) { if (!classSignatureToConstructor.containsKey(key.classSignature())) { @@ -187,7 +205,8 @@ PermissionSecurityChecksBuilder validatePermissionClasses(IndexView index) { PermissionSecurityChecksBuilder gatherPermissionsAllowedAnnotations(List instances, Map alreadyCheckedMethods, - Map alreadyCheckedClasses) { + Map alreadyCheckedClasses, + List additionalClassInstances) { // make sure we process annotations on methods first instances.sort(new Comparator() { @@ -217,7 +236,7 @@ public int compare(AnnotationInstance o1, AnnotationInstance o2) { methodInfo.name(), methodInfo.declaringClass())); } - gatherPermissionKeys(instance, methodInfo, cache, methodToPermissionKeys); + gatherPermissionKeys(instance, methodInfo, cache, targetToPermissionKeys); } else { // class annotation @@ -245,7 +264,7 @@ public int compare(AnnotationInstance o1, AnnotationInstance o2) { // ignore method annotated with other security annotation boolean noMethodLevelSecurityAnnotation = !alreadyCheckedMethods.containsKey(methodInfo); // ignore method annotated with method-level @PermissionsAllowed - boolean noMethodLevelPermissionsAllowed = !methodToPermissionKeys.containsKey(methodInfo); + boolean noMethodLevelPermissionsAllowed = !targetToPermissionKeys.containsKey(methodInfo); if (noMethodLevelSecurityAnnotation && noMethodLevelPermissionsAllowed) { gatherPermissionKeys(instance, methodInfo, cache, classMethodToPermissionKeys); @@ -261,12 +280,16 @@ public int compare(AnnotationInstance o1, AnnotationInstance o2) { } } } - methodToPermissionKeys.putAll(classMethodToPermissionKeys); + targetToPermissionKeys.putAll(classMethodToPermissionKeys); + for (var instance : additionalClassInstances) { + gatherPermissionKeys(instance, instance.target(), cache, targetToPermissionKeys); + } return this; } - private static void gatherPermissionKeys(AnnotationInstance instance, MethodInfo methodInfo, List cache, - Map>> methodToPermissionKeys) { + private static void gatherPermissionKeys(AnnotationInstance instance, T annotationTarget, + List cache, + Map>> targetToPermissionKeys) { // @PermissionsAllowed value is in format permission:action, permission2:action, permission:action2, permission3 // here we transform it to permission -> actions final var permissionToActions = new HashMap>(); @@ -299,9 +322,15 @@ private static void gatherPermissionKeys(AnnotationInstance instance, MethodInfo } if (permissionToActions.isEmpty()) { - throw new RuntimeException(String.format( - "Method '%s' was annotated with '@PermissionsAllowed', but no valid permission was provided", - methodInfo.name())); + if (annotationTarget.kind() == AnnotationTarget.Kind.METHOD) { + throw new RuntimeException(String.format( + "Method '%s' was annotated with '@PermissionsAllowed', but no valid permission was provided", + annotationTarget.asMethod().name())); + } else { + throw new RuntimeException(String.format( + "Class '%s' was annotated with '@PermissionsAllowed', but no valid permission was provided", + annotationTarget.asClass().name())); + } } // permissions specified via @PermissionsAllowed has 'one of' relation, therefore we put them in one list @@ -324,13 +353,8 @@ private static void gatherPermissionKeys(AnnotationInstance instance, MethodInfo } // store annotation value as permission keys - methodToPermissionKeys - .computeIfAbsent(methodInfo, new Function>>() { - @Override - public List> apply(MethodInfo methodInfo) { - return new ArrayList<>(); - } - }) + targetToPermissionKeys + .computeIfAbsent(annotationTarget, at -> new ArrayList<>()) .add(List.copyOf(orPermissions)); } @@ -378,10 +402,10 @@ private SecurityCheck createSecurityCheck(LogicalAndPermissionPredicate andPredi return securityCheck; } - private PermissionWrapper createPermission(PermissionKey permissionKey, MethodInfo securedMethod, + private PermissionWrapper createPermission(PermissionKey permissionKey, AnnotationTarget securedTarget, Map cache) { var constructor = classSignatureToConstructor.get(permissionKey.classSignature()); - return cache.computeIfAbsent(new PermissionCacheKey(permissionKey, securedMethod, constructor), + return cache.computeIfAbsent(new PermissionCacheKey(permissionKey, securedTarget, constructor), new Function() { @Override public PermissionWrapper apply(PermissionCacheKey permissionCacheKey) { @@ -568,8 +592,14 @@ private static final class PermissionCacheKey { private final boolean computed; private final boolean passActionsToConstructor; - private PermissionCacheKey(PermissionKey permissionKey, MethodInfo securedMethod, MethodInfo constructor) { + private PermissionCacheKey(PermissionKey permissionKey, AnnotationTarget securedTarget, MethodInfo constructor) { if (isComputed(permissionKey, constructor)) { + if (securedTarget.kind() != AnnotationTarget.Kind.METHOD) { + throw new IllegalArgumentException( + "@PermissionAllowed instance that accepts method arguments must be placed on a method"); + } + MethodInfo securedMethod = securedTarget.asMethod(); + // computed permission this.permissionKey = permissionKey; this.computed = true; diff --git a/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/RegisterClassSecurityCheckBuildItem.java b/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/RegisterClassSecurityCheckBuildItem.java new file mode 100644 index 00000000000000..8f9c2ec53915f0 --- /dev/null +++ b/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/RegisterClassSecurityCheckBuildItem.java @@ -0,0 +1,21 @@ +package io.quarkus.security.deployment; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.DotName; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * Registers security check against {@link io.quarkus.security.spi.ClassSecurityCheckStorageBuildItem} + * for security annotation instances passed in this build item. + */ +final class RegisterClassSecurityCheckBuildItem extends MultiBuildItem { + + final DotName className; + final AnnotationInstance securityAnnotationInstance; + + RegisterClassSecurityCheckBuildItem(DotName className, AnnotationInstance securityAnnotationInstance) { + this.className = className; + this.securityAnnotationInstance = securityAnnotationInstance; + } +} diff --git a/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/SecurityProcessor.java b/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/SecurityProcessor.java index 31dac288e85951..59211ca8ae8f67 100644 --- a/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/SecurityProcessor.java +++ b/extensions/security/deployment/src/main/java/io/quarkus/security/deployment/SecurityProcessor.java @@ -2,6 +2,7 @@ import static io.quarkus.arc.processor.DotNames.NO_CLASS_INTERCEPTORS; import static io.quarkus.gizmo.MethodDescriptor.ofMethod; +import static io.quarkus.security.deployment.DotNames.AUTHENTICATED; import static io.quarkus.security.deployment.DotNames.DENY_ALL; import static io.quarkus.security.deployment.DotNames.PERMISSIONS_ALLOWED; import static io.quarkus.security.deployment.DotNames.PERMIT_ALL; @@ -56,6 +57,7 @@ import io.quarkus.arc.processor.BuildExtension; import io.quarkus.arc.processor.ObserverInfo; import io.quarkus.builder.item.MultiBuildItem; +import io.quarkus.builder.item.SimpleBuildItem; import io.quarkus.deployment.Capabilities; import io.quarkus.deployment.Feature; import io.quarkus.deployment.annotations.BuildProducer; @@ -109,6 +111,9 @@ import io.quarkus.security.spi.AdditionalSecuredClassesBuildItem; import io.quarkus.security.spi.AdditionalSecuredMethodsBuildItem; import io.quarkus.security.spi.AdditionalSecurityConstrainerEventPropsBuildItem; +import io.quarkus.security.spi.ClassSecurityCheckAnnotationBuildItem; +import io.quarkus.security.spi.ClassSecurityCheckStorageBuildItem; +import io.quarkus.security.spi.ClassSecurityCheckStorageBuildItem.ClassStorageBuilder; import io.quarkus.security.spi.DefaultSecurityCheckBuildItem; import io.quarkus.security.spi.RolesAllowedConfigExpResolverBuildItem; import io.quarkus.security.spi.runtime.AuthorizationController; @@ -547,10 +552,9 @@ void transformSecurityAnnotations(BuildProducer } } - @Consume(Capabilities.class) // make sure extension combinations are validated before default security check @BuildStep @Record(ExecutionTime.STATIC_INIT) - void gatherSecurityChecks(BuildProducer syntheticBeans, + MethodSecurityChecks gatherSecurityChecks( BuildProducer configExpSecurityCheckProducer, List rolesAllowedConfigExpResolverBuildItems, BeanArchiveIndexBuildItem beanArchiveBuildItem, @@ -558,7 +562,8 @@ void gatherSecurityChecks(BuildProducer syntheticBeans, BuildProducer configBuilderProducer, List additionalSecuredMethods, SecurityCheckRecorder recorder, - List defaultSecurityCheckBuildItem, + BuildProducer classSecurityCheckStorageProducer, + List registerClassSecurityCheckBuildItems, BuildProducer reflectiveClassBuildItemBuildProducer, List additionalSecurityChecks, SecurityBuildTimeConfig config) { classPredicate.produce(new ApplicationClassPredicateBuildItem(new SecurityCheckStorageAppPredicate())); @@ -574,14 +579,27 @@ void gatherSecurityChecks(BuildProducer syntheticBeans, IndexView index = beanArchiveBuildItem.getIndex(); Map securityChecks = gatherSecurityAnnotations(index, configExpSecurityCheckProducer, additionalSecured.values(), config.denyUnannotated(), recorder, configBuilderProducer, - reflectiveClassBuildItemBuildProducer, rolesAllowedConfigExpResolverBuildItems); + reflectiveClassBuildItemBuildProducer, rolesAllowedConfigExpResolverBuildItems, + registerClassSecurityCheckBuildItems, classSecurityCheckStorageProducer); for (AdditionalSecurityCheckBuildItem additionalSecurityCheck : additionalSecurityChecks) { securityChecks.put(additionalSecurityCheck.getMethodInfo(), additionalSecurityCheck.getSecurityCheck()); } + return new MethodSecurityChecks(securityChecks); + } + + @Consume(Capabilities.class) // make sure extension combinations are validated before default security check + @BuildStep + @Record(ExecutionTime.STATIC_INIT) + void createSecurityCheckStorage(BuildProducer syntheticBeans, + BuildProducer classPredicate, + SecurityCheckRecorder recorder, MethodSecurityChecks securityChecksItem, + List defaultSecurityCheckBuildItem) { + classPredicate.produce(new ApplicationClassPredicateBuildItem(new SecurityCheckStorageAppPredicate())); + RuntimeValue builder = recorder.newBuilder(); - for (Map.Entry methodEntry : securityChecks + for (Map.Entry methodEntry : securityChecksItem.securityChecks .entrySet()) { MethodInfo method = methodEntry.getKey(); String[] params = new String[method.parametersCount()]; @@ -637,7 +655,9 @@ private static Map gatherSecurityAnnotations(IndexVie Collection additionalSecuredMethods, boolean denyUnannotated, SecurityCheckRecorder recorder, BuildProducer configBuilderProducer, BuildProducer reflectiveClassBuildItemBuildProducer, - List rolesAllowedConfigExpResolverBuildItems) { + List rolesAllowedConfigExpResolverBuildItems, + List registerClassSecurityCheckBuildItems, + BuildProducer classSecurityCheckStorageProducer) { Map methodToInstanceCollector = new HashMap<>(); Map classAnnotations = new HashMap<>(); @@ -693,31 +713,63 @@ private static Map gatherSecurityAnnotations(IndexVie final AtomicBoolean hasRolesAllowedCheckWithConfigExp = new AtomicBoolean(false); for (Map.Entry entry : methodToRoles.entrySet()) { final MethodInfo methodInfo = entry.getKey(); - final String[] allowedRoles = entry.getValue(); result.put(methodInfo, - cache.computeIfAbsent(getSetForKey(allowedRoles), new Function, SecurityCheck>() { - @Override - public SecurityCheck apply(Set allowedRolesSet) { - final int[] configExpressionPositions = configExpressionPositions(allowedRoles); - if (configExpressionPositions.length > 0) { - // we need to use supplier check as security checks are created during static init - // while config expressions are resolved during runtime - hasRolesAllowedCheckWithConfigExp.set(true); - - // we don't create security check for each method, therefore we need artificial keys - // we can safely use numbers as RolesAllowed config source prefix all keys - final int[] configKeys = new int[configExpressionPositions.length]; - for (int i = 0; i < configExpressionPositions.length; i++) { - // now we just collect artificial keys, but - // before we add the property to the Config system, we prefix it, e.g. - // @RolesAllowed("${admin}") -> QuarkusSecurityRolesAllowedConfigSource.property-0=${admin} - configKeys[i] = keyIndex.getAndIncrement(); - } - return recorder.rolesAllowedSupplier(allowedRoles, configExpressionPositions, configKeys); - } - return recorder.rolesAllowed(allowedRoles); - } - })); + computeRolesAllowedCheck(cache, hasRolesAllowedCheckWithConfigExp, keyIndex, recorder, entry.getValue())); + } + + final Map classNameToPermCheck; + List permissionInstances = new ArrayList<>( + index.getAnnotationsWithRepeatable(PERMISSIONS_ALLOWED, index)); + if (!permissionInstances.isEmpty()) { + var additionalClassInstances = registerClassSecurityCheckBuildItems + .stream() + .filter(i -> PERMISSIONS_ALLOWED.equals(i.securityAnnotationInstance.name())) + .map(i -> i.securityAnnotationInstance) + .toList(); + var securityChecks = new PermissionSecurityChecksBuilder(recorder) + .gatherPermissionsAllowedAnnotations(permissionInstances, methodToInstanceCollector, classAnnotations, + additionalClassInstances) + .validatePermissionClasses(index) + .createPermissionPredicates() + .build(); + result.putAll(securityChecks.getMethodSecurityChecks()); + classNameToPermCheck = securityChecks.getClassNameSecurityChecks(); + + // register used permission classes for reflection + for (String permissionClass : securityChecks.permissionClasses()) { + reflectiveClassBuildItemBuildProducer + .produce(ReflectiveClassBuildItem.builder(permissionClass).constructors().fields().methods().build()); + log.debugf("Register Permission class for reflection: %s", permissionClass); + } + } else { + classNameToPermCheck = Map.of(); + } + + if (!registerClassSecurityCheckBuildItems.isEmpty()) { + var classStorageBuilder = new ClassStorageBuilder(); + registerClassSecurityCheckBuildItems.forEach(item -> { + var securityAnnotationName = item.securityAnnotationInstance.name(); + + final SecurityCheck securityCheck; + if (DENY_ALL.equals(securityAnnotationName)) { + securityCheck = recorder.denyAll(); + } else if (PERMIT_ALL.equals(securityAnnotationName)) { + securityCheck = recorder.permitAll(); + } else if (AUTHENTICATED.equals(securityAnnotationName)) { + securityCheck = recorder.authenticated(); + } else if (ROLES_ALLOWED.equals(securityAnnotationName)) { + var allowedRoles = item.securityAnnotationInstance.value().asStringArray(); + securityCheck = computeRolesAllowedCheck(cache, hasRolesAllowedCheckWithConfigExp, keyIndex, recorder, + allowedRoles); + } else if (PERMISSIONS_ALLOWED.equals(securityAnnotationName)) { + securityCheck = Objects.requireNonNull(classNameToPermCheck.get(item.className)); + } else { + throw new IllegalStateException("Found unknown security annotation: " + securityAnnotationName); + } + + classStorageBuilder.addSecurityCheck(item.className, securityCheck); + }); + classSecurityCheckStorageProducer.produce(classStorageBuilder.build()); } final boolean registerRolesAllowedConfigSource; @@ -743,24 +795,6 @@ public SecurityCheck apply(Set allowedRolesSet) { .produce(new RunTimeConfigBuilderBuildItem(QuarkusSecurityRolesAllowedConfigBuilder.class.getName())); } - List permissionInstances = new ArrayList<>( - index.getAnnotationsWithRepeatable(PERMISSIONS_ALLOWED, index)); - if (!permissionInstances.isEmpty()) { - var securityChecks = new PermissionSecurityChecksBuilder(recorder) - .gatherPermissionsAllowedAnnotations(permissionInstances, methodToInstanceCollector, classAnnotations) - .validatePermissionClasses(index) - .createPermissionPredicates() - .build(); - result.putAll(securityChecks.get()); - - // register used permission classes for reflection - for (String permissionClass : securityChecks.permissionClasses()) { - reflectiveClassBuildItemBuildProducer - .produce(ReflectiveClassBuildItem.builder(permissionClass).constructors().fields().methods().build()); - log.debugf("Register Permission class for reflection: %s", permissionClass); - } - } - /* * If we need to add the denyAll security check to all unannotated methods, we simply go through all secured methods, * collect the declaring classes, then go through all methods of the classes and add the necessary check @@ -786,6 +820,34 @@ public SecurityCheck apply(Set allowedRolesSet) { return result; } + private static SecurityCheck computeRolesAllowedCheck(Map, SecurityCheck> cache, + AtomicBoolean hasRolesAllowedCheckWithConfigExp, AtomicInteger keyIndex, SecurityCheckRecorder recorder, + String[] allowedRoles) { + return cache.computeIfAbsent(getSetForKey(allowedRoles), new Function, SecurityCheck>() { + @Override + public SecurityCheck apply(Set allowedRolesSet) { + final int[] configExpressionPositions = configExpressionPositions(allowedRoles); + if (configExpressionPositions.length > 0) { + // we need to use supplier check as security checks are created during static init + // while config expressions are resolved during runtime + hasRolesAllowedCheckWithConfigExp.set(true); + + // we don't create security check for each method, therefore we need artificial keys + // we can safely use numbers as RolesAllowed config source prefix all keys + final int[] configKeys = new int[configExpressionPositions.length]; + for (int i = 0; i < configExpressionPositions.length; i++) { + // now we just collect artificial keys, but + // before we add the property to the Config system, we prefix it, e.g. + // @RolesAllowed("${admin}") -> QuarkusSecurityRolesAllowedConfigSource.property-0=${admin} + configKeys[i] = keyIndex.getAndIncrement(); + } + return recorder.rolesAllowedSupplier(allowedRoles, configExpressionPositions, configKeys); + } + return recorder.rolesAllowed(allowedRoles); + } + }); + } + public static int[] configExpressionPositions(String[] allowedRoles) { final Set expPositions = new HashSet<>(); for (int i = 0; i < allowedRoles.length; i++) { @@ -920,6 +982,25 @@ void validateStartUpObserversNotSecured(SynthesisFinishedBuildItem synthesisFini }); } + @BuildStep + void gatherClassSecurityChecks(BuildProducer producer, + BeanArchiveIndexBuildItem indexBuildItem, + List classAnnotationItems) { + if (!classAnnotationItems.isEmpty()) { + var index = indexBuildItem.getIndex(); + classAnnotationItems + .stream() + .map(ClassSecurityCheckAnnotationBuildItem::getClassAnnotation) + .map(index::getAnnotations) + .flatMap(Collection::stream) + .filter(ai -> ai.target().kind() == AnnotationTarget.Kind.CLASS) + .map(ai -> ai.target().asClass()) + .filter(SecurityTransformerUtils::hasStandardSecurityAnnotation) + .map(c -> new RegisterClassSecurityCheckBuildItem(c.name(), findFirstStandardSecurityAnnotation(c).get())) + .forEach(producer::produce); + } + } + private static boolean hasClassLevelStandardSecurityAnnotation(MethodInfo method, AnnotationStore annotationStore) { return applyClassLevenInterceptor(method, annotationStore) && hasStandardSecurityAnnotation(annotationStore.getAnnotations(method.declaringClass())); @@ -959,4 +1040,11 @@ public boolean test(String s) { } } + static final class MethodSecurityChecks extends SimpleBuildItem { + Map securityChecks; + + MethodSecurityChecks(Map securityChecks) { + this.securityChecks = securityChecks; + } + } } diff --git a/extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckAnnotationBuildItem.java b/extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckAnnotationBuildItem.java new file mode 100644 index 00000000000000..018beca15229b3 --- /dev/null +++ b/extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckAnnotationBuildItem.java @@ -0,0 +1,32 @@ +package io.quarkus.security.spi; + +import java.util.Objects; + +import org.jboss.jandex.DotName; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * Allows to create additional security checks for standard security annotations defined on a class level. + * We strongly recommended to secure CDI beans with {@link AdditionalSecuredMethodsBuildItem} + * if additional security is required. If you decide to use this build item, you must use + * class security check storage and apply checks manually. Thus, it's only suitable for very special cases. + */ +public final class ClassSecurityCheckAnnotationBuildItem extends MultiBuildItem { + + private final DotName classAnnotation; + + /** + * Quarkus will register security checks against {@link ClassSecurityCheckStorageBuildItem} for + * classes annotated with the {@code classAnnotation} that are secured with a standard security annotation. + * + * @param classAnnotation class-level annotation name + */ + public ClassSecurityCheckAnnotationBuildItem(DotName classAnnotation) { + this.classAnnotation = Objects.requireNonNull(classAnnotation); + } + + public DotName getClassAnnotation() { + return classAnnotation; + } +} diff --git a/extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckStorageBuildItem.java b/extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckStorageBuildItem.java new file mode 100644 index 00000000000000..234d02cb404bc4 --- /dev/null +++ b/extensions/security/spi/src/main/java/io/quarkus/security/spi/ClassSecurityCheckStorageBuildItem.java @@ -0,0 +1,52 @@ +package io.quarkus.security.spi; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import org.jboss.jandex.DotName; + +import io.quarkus.builder.item.SimpleBuildItem; + +/** + * Security check storage containing additional security checks created for secured classes + * matching one of the {@link ClassSecurityCheckAnnotationBuildItem} filters during the static init. + */ +public final class ClassSecurityCheckStorageBuildItem extends SimpleBuildItem { + + private final Map classNameToSecurityCheck; + + private ClassSecurityCheckStorageBuildItem(Map classNameToSecurityCheck) { + Objects.requireNonNull(classNameToSecurityCheck); + this.classNameToSecurityCheck = Map.copyOf(classNameToSecurityCheck); + } + + /** + * Returns additional security check created for classes annotated with standard + * security annotations based on the {@link ClassSecurityCheckAnnotationBuildItem} filter. + * + * @param className class name + * @return security check (see runtime Security SPI for respective class) + */ + public Object getSecurityCheck(DotName className) { + return classNameToSecurityCheck.get(className); + } + + public static final class ClassStorageBuilder { + + private final Map classNameToSecurityCheck; + + public ClassStorageBuilder() { + this.classNameToSecurityCheck = new HashMap<>(); + } + + public ClassStorageBuilder addSecurityCheck(DotName className, Object securityCheck) { + classNameToSecurityCheck.put(className, securityCheck); + return this; + } + + public ClassSecurityCheckStorageBuildItem build() { + return new ClassSecurityCheckStorageBuildItem(classNameToSecurityCheck); + } + } +} diff --git a/extensions/security/spi/src/main/java/io/quarkus/security/spi/SecurityTransformerUtils.java b/extensions/security/spi/src/main/java/io/quarkus/security/spi/SecurityTransformerUtils.java index 3504c4a35f6174..54562cd6dcaa61 100644 --- a/extensions/security/spi/src/main/java/io/quarkus/security/spi/SecurityTransformerUtils.java +++ b/extensions/security/spi/src/main/java/io/quarkus/security/spi/SecurityTransformerUtils.java @@ -44,4 +44,8 @@ public static boolean hasSecurityAnnotation(ClassInfo classInfo) { return false; } + + public static boolean isStandardSecurityAnnotation(AnnotationInstance annotationInstance) { + return SECURITY_ANNOTATIONS.contains(annotationInstance.name()); + } } diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index 8b690f688344ab..1ccb68627ae9b5 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -1,21 +1,26 @@ package io.quarkus.websockets.next.deployment; import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT; +import static io.quarkus.security.spi.SecurityTransformerUtils.hasSecurityAnnotation; +import static io.quarkus.websockets.next.runtime.HttpUpgradeSecurityCheck.BEAN_PRIORITY; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import jakarta.enterprise.context.SessionScoped; import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationTransformation; import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.ClassInfo; import org.jboss.jandex.ClassInfo.NestingType; @@ -27,6 +32,7 @@ import org.jboss.jandex.Type.Kind; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem; import io.quarkus.arc.deployment.AutoAddScopeBuildItem; import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem; import io.quarkus.arc.deployment.BeanDefiningAnnotationBuildItem; @@ -35,6 +41,7 @@ import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; import io.quarkus.arc.deployment.CustomScopeBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem; import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.arc.deployment.ValidationPhaseBuildItem; @@ -50,10 +57,12 @@ import io.quarkus.deployment.GeneratedClassGizmoAdaptor; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.Consume; import io.quarkus.deployment.annotations.Record; import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.FeatureBuildItem; import io.quarkus.deployment.builditem.GeneratedClassBuildItem; +import io.quarkus.deployment.builditem.RuntimeConfigSetupCompleteBuildItem; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; import io.quarkus.deployment.execannotations.ExecutionModelAnnotationsAllowedBuildItem; import io.quarkus.gizmo.BytecodeCreator; @@ -65,6 +74,10 @@ import io.quarkus.gizmo.MethodDescriptor; import io.quarkus.gizmo.ResultHandle; import io.quarkus.gizmo.TryBlock; +import io.quarkus.security.spi.ClassSecurityCheckAnnotationBuildItem; +import io.quarkus.security.spi.ClassSecurityCheckStorageBuildItem; +import io.quarkus.security.spi.SecurityTransformerUtils; +import io.quarkus.security.spi.runtime.SecurityCheck; import io.quarkus.vertx.http.deployment.HttpRootPathBuildItem; import io.quarkus.vertx.http.deployment.RouteBuildItem; import io.quarkus.vertx.http.runtime.HandlerType; @@ -83,6 +96,7 @@ import io.quarkus.websockets.next.runtime.Codecs; import io.quarkus.websockets.next.runtime.ConnectionManager; import io.quarkus.websockets.next.runtime.ContextSupport; +import io.quarkus.websockets.next.runtime.HttpUpgradeSecurityCheck; import io.quarkus.websockets.next.runtime.JsonTextMessageCodec; import io.quarkus.websockets.next.runtime.SecuritySupport; import io.quarkus.websockets.next.runtime.WebSocketClientRecorder; @@ -404,6 +418,8 @@ public String apply(String name) { } } + @Consume(RuntimeConfigSetupCompleteBuildItem.class) // HTTP Upgrade checks may need config during the initialization + @Consume(SyntheticBeansRuntimeInitBuildItem.class) // HTTP Upgrade Security check is runtime init due to runtime config @Record(RUNTIME_INIT) @BuildStep public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildItem httpRootPath, @@ -496,6 +512,61 @@ void clientSyntheticBeans(WebSocketClientRecorder recorder, List producer) { + if (capabilities.isPresent(Capability.SECURITY)) { + producer.produce(new ClassSecurityCheckAnnotationBuildItem(WebSocketDotNames.WEB_SOCKET)); + } + } + + @BuildStep + void preventRepeatedSecurityChecksForHttpUpgrade(Capabilities capabilities, + BuildProducer producer) { + if (capabilities.isPresent(Capability.SECURITY)) { + producer.produce(new AnnotationsTransformerBuildItem(AnnotationTransformation + .forClasses() + .whenAnyMatch(WebSocketDotNames.WEB_SOCKET) + .transform(ctx -> ctx.remove(SecurityTransformerUtils::isStandardSecurityAnnotation)))); + } + } + + @Record(RUNTIME_INIT) + @BuildStep + void createHttpUpgradeSecurityCheck(Capabilities capabilities, BuildProducer producer, + Optional storageItem, + BeanArchiveIndexBuildItem indexItem, + WebSocketServerRecorder recorder, List endpoints) { + if (capabilities.isPresent(Capability.SECURITY) && storageItem.isPresent()) { + var endpointIdToSecurityCheck = collectEndpointSecurityChecks(endpoints, storageItem.get(), indexItem.getIndex()); + if (!endpointIdToSecurityCheck.isEmpty()) { + producer.produce(SyntheticBeanBuildItem + .configure(HttpUpgradeCheck.class) + .scope(BuiltinScope.SINGLETON.getInfo()) + .priority(HttpUpgradeSecurityCheck.BEAN_PRIORITY) + .setRuntimeInit() + .supplier(recorder.createHttpUpgradeSecurityCheck(endpointIdToSecurityCheck)) + .done()); + } + } + } + + private static Map collectEndpointSecurityChecks(List endpoints, + ClassSecurityCheckStorageBuildItem storage, IndexView index) { + return endpoints + .stream().> mapMulti((endpoint, consumer) -> { + var beanName = endpoint.beanClassName(); + if (storage.getSecurityCheck(beanName) instanceof SecurityCheck check) { + consumer.accept(Map.entry(endpoint.id, check)); + } else if (hasSecurityAnnotation(index.getClassByName(beanName))) { + throw new IllegalStateException("WebSocket endpoint '%s' requires ".formatted(beanName) + + "secured HTTP upgrade but Quarkus did not configure security check " + + "correctly. Please open issue in Quarkus project"); + } + }) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + } + static String mergePath(String prefix, String path) { if (prefix.endsWith("/")) { prefix = prefix.substring(0, prefix.length() - 1); diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAnnotationTransformerTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAnnotationTransformerTest.java new file mode 100644 index 00000000000000..8188efbd7a1d72 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAnnotationTransformerTest.java @@ -0,0 +1,123 @@ +package io.quarkus.websockets.next.test.security; + +import static io.quarkus.security.test.utils.SecurityTestUtils.assertFailureFor; +import static io.quarkus.security.test.utils.SecurityTestUtils.assertSuccess; + +import java.util.Set; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.UnauthorizedException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.AuthData; +import io.quarkus.security.test.utils.IdentityMock; +import io.quarkus.security.test.utils.SecurityTestUtils; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; + +public class HttpUpgradeAnnotationTransformerTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(Endpoint.class, WSClient.class, SecurityTestUtils.class, IdentityMock.class, + CdiBeanSecurity.class, AdminEndpoint.class)); + + @Inject + CdiBeanSecurity cdiBeanSecurity; + + @Test + public void testSecurityChecksNotRepeated() { + // fact that HTTP Upgrade is secured is tested in HttpUpgradeRolesAllowedAnnotationTest + // this test class complements these tests but must stand separately as it relies on different auth + + // when HTTP Upgrade is secured, we should not perform over and over again + // same check @OnTextMessage + var admin = new AuthData(Set.of("admin"), false, "admin"); + var user = new AuthData(Set.of("user"), false, "user"); + var anonymous = new AuthData(Set.of(), true, "anonymous"); + + // both HTTP Upgrade and @OnTextMessage are secured + assertSuccess(cdiBeanSecurity::httpUpgradeAndCdiBeanSecurity, "hey", admin); + assertFailureFor(cdiBeanSecurity::httpUpgradeAndCdiBeanSecurity, ForbiddenException.class, user); + assertFailureFor(cdiBeanSecurity::httpUpgradeAndCdiBeanSecurity, UnauthorizedException.class, anonymous); + + // only HTTP Upgrade is secured -> no CDI bean security + assertSuccess(cdiBeanSecurity::httpUpgradeSecurity, "hey", admin); + assertSuccess(cdiBeanSecurity::httpUpgradeSecurity, "hey", user); + assertSuccess(cdiBeanSecurity::httpUpgradeSecurity, "hey", anonymous); + } + + @ApplicationScoped + public static class CdiBeanSecurity { + + @Inject + AdminEndpoint adminEndpoint; + + @Inject + Endpoint endpoint; + + public String httpUpgradeSecurity() { + return adminEndpoint.echo("hey"); + } + + public String httpUpgradeAndCdiBeanSecurity() { + return endpoint.echo("hey"); + } + + } + + @RolesAllowed("admin") + @WebSocket(path = "/admin-end") + public static class AdminEndpoint { + + @OnOpen + String open() { + return "ready"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + } + + @RolesAllowed({ "admin", "user" }) + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAuthenticatedAnnotationTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAuthenticatedAnnotationTest.java new file mode 100644 index 00000000000000..1c289eedd17962 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeAuthenticatedAnnotationTest.java @@ -0,0 +1,147 @@ +package io.quarkus.websockets.next.test.security; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CompletionException; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.security.Authenticated; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.http.UpgradeRejectedException; + +public class HttpUpgradeAuthenticatedAnnotationTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class, + PublicEndpoint.class, PublicEndpoint.SubEndpoint.class)); + + @TestHTTPResource("public-end") + URI publicEndUri; + + @TestHTTPResource("public-end/sub") + URI subEndUri; + + @Test + public void testSubEndpoint() { + try (WSClient client = new WSClient(vertx)) { + client.connect(publicEndUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("hello", client.getMessages().get(1).toString()); + } + + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, () -> client.connect(subEndUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("401"), root.getMessage()); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), subEndUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("sub-endpoint", client.getMessages().get(1).toString()); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("user", "user"), subEndUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("sub-endpoint:forbidden:user", client.getMessages().get(1).toString()); + } + } + + @WebSocket(path = "/public-end") + public static class PublicEndpoint { + + @OnOpen + String open() { + return "ready"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + @Authenticated + @WebSocket(path = "/sub") + public static class SubEndpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + return "sub-endpoint"; + } + + @OnError + String error(ForbiddenException t) { + return "sub-endpoint:forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + } + + } + + @Authenticated + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeDenyAllAnnotationTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeDenyAllAnnotationTest.java new file mode 100644 index 00000000000000..da50712668928b --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeDenyAllAnnotationTest.java @@ -0,0 +1,87 @@ +package io.quarkus.websockets.next.test.security; + +import static io.quarkus.websockets.next.test.security.SecurityTestBase.basicAuth; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CompletionException; + +import jakarta.annotation.security.DenyAll; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.http.UpgradeRejectedException; + +public class HttpUpgradeDenyAllAnnotationTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot(root -> root.addClasses(Endpoint.class, AdminService.class, UserService.class, + TestIdentityProvider.class, TestIdentityController.class, WSClient.class, SecurityTestBase.class)); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles().add("admin", "admin", "admin"); + } + + @Test + public void testEndpoint() { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, () -> client.connect(endUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertTrue(root instanceof UpgradeRejectedException); + assertTrue(root.getMessage().contains("401")); + } + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect(basicAuth("admin", "admin"), endUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertTrue(root instanceof UpgradeRejectedException); + assertTrue(root.getMessage().contains("403")); + } + } + + @DenyAll + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + UserService userService; + + @Inject + AdminService adminService; + + @OnTextMessage + String echo(String message) { + return message.equals("hello") ? adminService.ping() : userService.ping(); + } + + @OnError + String error(ForbiddenException t) { + return "forbidden"; + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermissionsAllowedAnnotationTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermissionsAllowedAnnotationTest.java new file mode 100644 index 00000000000000..5bb272cc0c414e --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermissionsAllowedAnnotationTest.java @@ -0,0 +1,142 @@ +package io.quarkus.websockets.next.test.security; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CompletionException; +import java.util.stream.Stream; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.PermissionsAllowed; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.http.UpgradeRejectedException; + +public class HttpUpgradePermissionsAllowedAnnotationTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class, + AdminEndpoint.class, InclusiveEndpoint.class)); + + @TestHTTPResource("admin-end") + URI adminEndpointUri; + + @TestHTTPResource("inclusive-end") + URI inclusiveEndpointUri; + + @Test + public void testInsufficientRights() { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect(basicAuth("user", "user"), adminEndpointUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("403")); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), adminEndpointUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("hello", client.getMessages().get(1).toString()); + } + } + + @Test + public void testInclusivePermissions() { + Stream.of("admin", "user").forEach(name -> { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect(basicAuth(name, name), inclusiveEndpointUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("403")); + } + }); + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("almighty", "almighty"), inclusiveEndpointUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("hello", client.getMessages().get(1).toString()); + } + } + + @PermissionsAllowed(value = { "perm1", "perm2" }, inclusive = true) + @WebSocket(path = "/inclusive-end") + public static class InclusiveEndpoint { + + @OnOpen + String open() { + return "ready"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + } + + @PermissionsAllowed("endpoint:read") + @WebSocket(path = "/admin-end") + public static class AdminEndpoint { + + @OnOpen + String open() { + return "ready"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + } + + @PermissionsAllowed(value = { "endpoint:connect", "endpoint:read" }) + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @PermissionsAllowed("endpoint:read") + @OnTextMessage + String echo(String message) { + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermitAllAnnotationTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermitAllAnnotationTest.java new file mode 100644 index 00000000000000..c9af97bbbaca63 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradePermitAllAnnotationTest.java @@ -0,0 +1,86 @@ +package io.quarkus.websockets.next.test.security; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.Set; + +import jakarta.annotation.security.PermitAll; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class HttpUpgradePermitAllAnnotationTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot(root -> root.addClasses(Endpoint.class, AdminService.class, UserService.class, + TestIdentityProvider.class, TestIdentityController.class, WSClient.class)); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add("admin", "admin", "admin") + .add("user", "user", "user"); + } + + @Test + public void testEndpoint() { + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), endUri); + client.sendAndAwait("hello"); // admin service + client.sendAndAwait("hi"); // forbidden + client.waitForMessages(2); + assertEquals(Set.of("24", "forbidden"), Set.copyOf(client.getMessages().stream().map(Object::toString).toList())); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("user", "user"), endUri); + client.sendAndAwait("hello"); // forbidden + client.sendAndAwait("hi"); // user service + client.waitForMessages(2); + assertEquals(Set.of("42", "forbidden"), Set.copyOf(client.getMessages().stream().map(Object::toString).toList())); + } + } + + @PermitAll + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + UserService userService; + + @Inject + AdminService adminService; + + @OnTextMessage + String echo(String message) { + return message.equals("hello") ? adminService.ping() : userService.ping(); + } + + @OnError + String error(ForbiddenException t) { + return "forbidden"; + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRedirectOnFailureTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRedirectOnFailureTest.java new file mode 100644 index 00000000000000..e266e0d97d1eda --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRedirectOnFailureTest.java @@ -0,0 +1,115 @@ +package io.quarkus.websockets.next.test.security; + +import static io.quarkus.websockets.next.test.security.SecurityTestBase.basicAuth; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.restassured.RestAssured; +import io.vertx.core.Vertx; + +public class HttpUpgradeRedirectOnFailureTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class, + SecurityTestBase.class) + .addAsResource( + new StringAsset( + "quarkus.websockets-next.server.security.auth-failure-redirect-url=https://quarkus.io\n"), + "application.properties")); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add("admin", "admin", "admin") + .add("user", "user", "user"); + } + + @Test + public void testRedirectOnFailure() { + // test redirected on failure + RestAssured + .given() + .redirects() + .follow(false) + .get(endUri) + .then() + .statusCode(302) + .header(HttpHeaderNames.LOCATION.toString(), "https://quarkus.io"); + + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), endUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("hello", client.getMessages().get(1).toString()); + } + + // no redirect as CDI interceptor secures @OnTextMessage + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("user", "user"), endUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("forbidden:user", client.getMessages().get(1).toString()); + } + } + + @RolesAllowed({ "admin", "user" }) + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRolesAllowedAnnotationTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRolesAllowedAnnotationTest.java new file mode 100644 index 00000000000000..ff687fcab57ac5 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/HttpUpgradeRolesAllowedAnnotationTest.java @@ -0,0 +1,104 @@ +package io.quarkus.websockets.next.test.security; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CompletionException; + +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.identity.CurrentIdentityAssociation; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.http.UpgradeRejectedException; + +public class HttpUpgradeRolesAllowedAnnotationTest extends SecurityTestBase { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(Endpoint.class, WSClient.class, TestIdentityProvider.class, TestIdentityController.class, + AdminEndpoint.class)); + + @TestHTTPResource("admin-end") + URI adminEndpointUri; + + @Test + public void testInsufficientRights() { + try (WSClient client = new WSClient(vertx)) { + CompletionException ce = assertThrows(CompletionException.class, + () -> client.connect(basicAuth("user", "user"), adminEndpointUri)); + Throwable root = ExceptionUtil.getRootCause(ce); + assertInstanceOf(UpgradeRejectedException.class, root); + assertTrue(root.getMessage().contains("403")); + } + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), adminEndpointUri); + client.waitForMessages(1); + assertEquals("ready", client.getMessages().get(0).toString()); + client.sendAndAwait("hello"); + client.waitForMessages(2); + assertEquals("hello", client.getMessages().get(1).toString()); + } + } + + @RolesAllowed("admin") + @WebSocket(path = "/admin-end") + public static class AdminEndpoint { + + @OnOpen + String open() { + return "ready"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + } + + @RolesAllowed({ "admin", "user" }) + @WebSocket(path = "/end") + public static class Endpoint { + + @Inject + CurrentIdentityAssociation currentIdentity; + + @OnOpen + String open() { + return "ready"; + } + + @RolesAllowed("admin") + @OnTextMessage + String echo(String message) { + if (!currentIdentity.getIdentity().hasRole("admin")) { + throw new IllegalStateException(); + } + return message; + } + + @OnError + String error(ForbiddenException t) { + return "forbidden:" + currentIdentity.getIdentity().getPrincipal().getName(); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java index a9c94143ae59bc..4aee3d1c093fc5 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/SecurityTestBase.java @@ -13,6 +13,7 @@ import org.junit.jupiter.api.Test; import io.quarkus.runtime.util.ExceptionUtil; +import io.quarkus.security.StringPermission; import io.quarkus.security.test.utils.TestIdentityController; import io.quarkus.test.common.http.TestHTTPResource; import io.quarkus.websockets.next.test.utils.WSClient; @@ -33,8 +34,9 @@ public abstract class SecurityTestBase { @BeforeAll public static void setupUsers() { TestIdentityController.resetRoles() - .add("admin", "admin", "admin") - .add("user", "user", "user"); + .add("admin", "admin", new StringPermission("endpoint", "read"), new StringPermission("perm1")) + .add("almighty", "almighty", new StringPermission("perm1"), new StringPermission("perm2")) + .add("user", "user", new StringPermission("endpoint", "connect"), new StringPermission("perm2")); } @Test diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java index 4697d7785dc9dd..937e0fb3190492 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/HttpUpgradeCheck.java @@ -42,8 +42,9 @@ default boolean appliesTo(String endpointId) { /** * @param httpRequest {@link HttpServerRequest}; the HTTP 1.X request employing the 'Upgrade' header * @param securityIdentity {@link SecurityIdentity}; the identity is null if the Quarkus Security extension is absent + * @param endpointId {@link WebSocket#endpointId()} */ - record HttpUpgradeContext(HttpServerRequest httpRequest, SecurityIdentity securityIdentity) { + record HttpUpgradeContext(HttpServerRequest httpRequest, SecurityIdentity securityIdentity, String endpointId) { } final class CheckResult { @@ -91,6 +92,10 @@ public static Uni rejectUpgrade(Integer httpResponseCode) { return rejectUpgrade(httpResponseCode, null); } + public static CheckResult rejectUpgradeSync(Integer httpResponseCode) { + return rejectUpgradeSync(httpResponseCode, null); + } + public static CheckResult rejectUpgradeSync(Integer httpResponseCode, Map> responseHeaders) { return new CheckResult(false, httpResponseCode, responseHeaders); } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java index 43beffda35600a..a6bd6679836f3e 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java @@ -54,4 +54,20 @@ public interface WebSocketsServerRuntimeConfig { @WithDefault("close") UnhandledFailureStrategy unhandledFailureStrategy(); + /** + * WebSockets-specific security configuration. + */ + Security security(); + + interface Security { + + /** + * Quarkus redirects HTTP handshake request to this URL if an HTTP upgrade is rejected due to the authorization + * failure. This configuration property takes effect when you secure endpoint with a standard security annotation. + * For example, the HTTP upgrade is secured if an endpoint class is annotated with the `@RolesAllowed` annotation. + */ + Optional authFailureRedirectUrl(); + + } + } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/HttpUpgradeSecurityCheck.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/HttpUpgradeSecurityCheck.java new file mode 100644 index 00000000000000..94cfa2c9416347 --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/HttpUpgradeSecurityCheck.java @@ -0,0 +1,53 @@ +package io.quarkus.websockets.next.runtime; + +import static io.netty.handler.codec.http.HttpHeaderNames.CACHE_CONTROL; +import static io.netty.handler.codec.http.HttpHeaderNames.LOCATION; + +import java.util.List; +import java.util.Map; + +import io.quarkus.security.ForbiddenException; +import io.quarkus.security.spi.runtime.MethodDescription; +import io.quarkus.security.spi.runtime.SecurityCheck; +import io.quarkus.websockets.next.HttpUpgradeCheck; +import io.smallrye.mutiny.Uni; + +public class HttpUpgradeSecurityCheck implements HttpUpgradeCheck { + + public static final int BEAN_PRIORITY = Integer.MAX_VALUE - 100; + + private final String redirectUrl; + private final Map endpointToCheck; + + HttpUpgradeSecurityCheck(String redirectUrl, Map endpointToCheck) { + this.redirectUrl = redirectUrl; + this.endpointToCheck = Map.copyOf(endpointToCheck); + } + + @Override + public Uni perform(HttpUpgradeContext context) { + return endpointToCheck + .get(context.endpointId()) + .nonBlockingApply(context.securityIdentity(), (MethodDescription) null, null) + .replaceWith(CheckResult::permitUpgradeSync) + .onFailure(SecurityException.class).recoverWithItem(this::rejectUpgrade); + } + + @Override + public boolean appliesTo(String endpointId) { + return endpointToCheck.containsKey(endpointId); + } + + private CheckResult rejectUpgrade(Throwable throwable) { + if (redirectUrl != null) { + return CheckResult.rejectUpgradeSync(302, + Map.of(LOCATION.toString(), List.of(redirectUrl), + CACHE_CONTROL.toString(), List.of("no-store"))); + } else if (throwable instanceof ForbiddenException) { + return CheckResult.rejectUpgradeSync(403); + } else { + return CheckResult.rejectUpgradeSync(401); + } + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index 0715daaf2114da..a9763540602e01 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -2,6 +2,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.Consumer; import java.util.function.Supplier; @@ -14,6 +15,7 @@ import io.quarkus.runtime.annotations.Recorder; import io.quarkus.security.identity.CurrentIdentityAssociation; import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.security.spi.runtime.SecurityCheck; import io.quarkus.vertx.core.runtime.VertxCoreRecorder; import io.quarkus.vertx.http.runtime.security.QuarkusHttpUser; import io.quarkus.websockets.next.HttpUpgradeCheck; @@ -91,14 +93,13 @@ public Handler createEndpointHandler(String generatedEndpointCla ArcContainer container = Arc.container(); ConnectionManager connectionManager = container.instance(ConnectionManager.class).get(); Codecs codecs = container.instance(Codecs.class).get(); + HttpUpgradeCheck[] httpUpgradeChecks = getHttpUpgradeChecks(endpointId, container); return new Handler() { - private final HttpUpgradeCheck[] httpUpgradeChecks = getHttpUpgradeChecks(endpointId, container); - @Override public void handle(RoutingContext ctx) { if (httpUpgradeChecks != null) { - checkHttpUpgrade(ctx).subscribe().with(result -> { + checkHttpUpgrade(ctx, endpointId).subscribe().with(result -> { if (!result.getResponseHeaders().isEmpty()) { result.getResponseHeaders().forEach((k, v) -> ctx.response().putHeader(k, v)); } @@ -132,9 +133,9 @@ private void httpUpgrade(RoutingContext ctx) { }); } - private Uni checkHttpUpgrade(RoutingContext ctx) { + private Uni checkHttpUpgrade(RoutingContext ctx, String endpointId) { SecurityIdentity identity = ctx.user() instanceof QuarkusHttpUser user ? user.getSecurityIdentity() : null; - return checkHttpUpgrade(new HttpUpgradeContext(ctx.request(), identity), httpUpgradeChecks, 0); + return checkHttpUpgrade(new HttpUpgradeContext(ctx.request(), identity, endpointId), httpUpgradeChecks, 0); } private static Uni checkHttpUpgrade(HttpUpgradeContext ctx, @@ -183,4 +184,12 @@ SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext return SecuritySupport.NOOP; } + public Supplier createHttpUpgradeSecurityCheck(Map endpointToCheck) { + return new Supplier() { + @Override + public HttpUpgradeCheck get() { + return new HttpUpgradeSecurityCheck(config.security().authFailureRedirectUrl().orElse(null), endpointToCheck); + } + }; + } }