Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change security annotations to use recorders #11112

Merged
merged 1 commit into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package io.quarkus.security.deployment;

import java.util.function.Function;

import org.jboss.jandex.MethodInfo;

import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.security.runtime.interceptor.check.SecurityCheck;

/**
* Used as an integration point when extensions need to customize the security behavior of a bean
Expand All @@ -15,18 +12,18 @@
public final class AdditionalSecurityCheckBuildItem extends MultiBuildItem {

private final MethodInfo methodInfo;
private final Function<BytecodeCreator, ResultHandle> function;
private final SecurityCheck securityCheck;

public AdditionalSecurityCheckBuildItem(MethodInfo methodInfo, Function<BytecodeCreator, ResultHandle> function) {
public AdditionalSecurityCheckBuildItem(MethodInfo methodInfo, SecurityCheck securityCheck) {
this.methodInfo = methodInfo;
this.function = function;
this.securityCheck = securityCheck;
}

public MethodInfo getMethodInfo() {
return methodInfo;
}

public Function<BytecodeCreator, ResultHandle> getSecurityCheckResultHandleCreator() {
return function;
public SecurityCheck getSecurityCheck() {
return securityCheck;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
package io.quarkus.security.deployment;

import static io.quarkus.security.deployment.SecurityCheckInstantiationUtil.authenticatedSecurityCheck;
import static io.quarkus.security.deployment.SecurityCheckInstantiationUtil.denyAllSecurityCheck;
import static io.quarkus.security.deployment.SecurityCheckInstantiationUtil.permitAllSecurityCheck;
import static io.quarkus.security.deployment.SecurityCheckInstantiationUtil.rolesAllowedSecurityCheck;

import java.lang.reflect.Modifier;
import java.security.Provider;
import java.security.Security;
Expand All @@ -25,7 +20,6 @@
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.Type;
import org.jboss.logging.Logger;

import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
Expand All @@ -40,16 +34,18 @@
import io.quarkus.deployment.Feature;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.ApplicationClassPredicateBuildItem;
import io.quarkus.deployment.builditem.CapabilityBuildItem;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.RuntimeValue;
import io.quarkus.security.runtime.IdentityProviderManagerCreator;
import io.quarkus.security.runtime.SecurityBuildTimeConfig;
import io.quarkus.security.runtime.SecurityCheckRecorder;
import io.quarkus.security.runtime.SecurityIdentityAssociation;
import io.quarkus.security.runtime.SecurityIdentityProxy;
import io.quarkus.security.runtime.X509IdentityProvider;
Expand Down Expand Up @@ -141,10 +137,12 @@ void transformSecurityAnnotations(BuildProducer<AnnotationsTransformerBuildItem>
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
void gatherSecurityChecks(BuildProducer<BeanRegistrarBuildItem> beanRegistrars,
BeanArchiveIndexBuildItem beanArchiveBuildItem,
BuildProducer<ApplicationClassPredicateBuildItem> classPredicate,
List<AdditionalSecuredClassesBuildIem> additionalSecuredClasses,
SecurityCheckRecorder recorder,
List<AdditionalSecurityCheckBuildItem> additionalSecurityChecks, SecurityBuildTimeConfig config) {
classPredicate.produce(new ApplicationClassPredicateBuildItem(new SecurityCheckStorage.AppPredicate()));

Expand All @@ -157,82 +155,64 @@ void gatherSecurityChecks(BuildProducer<BeanRegistrarBuildItem> beanRegistrars,
});
}

IndexView index = beanArchiveBuildItem.getIndex();
Map<MethodInfo, SecurityCheck> securityChecks = gatherSecurityAnnotations(
index, additionalSecured, config.denyUnannotated, recorder);
for (AdditionalSecurityCheckBuildItem additionalSecurityCheck : additionalSecurityChecks) {
securityChecks.put(additionalSecurityCheck.getMethodInfo(),
additionalSecurityCheck.getSecurityCheck());
}

RuntimeValue<SecurityCheckStorageBuilder> builder = recorder.newBuilder();
for (Map.Entry<MethodInfo, SecurityCheck> methodEntry : securityChecks
.entrySet()) {
MethodInfo method = methodEntry.getKey();
String[] params = new String[method.parameters().size()];
for (int i = 0; i < method.parameters().size(); ++i) {
params[i] = method.parameters().get(i).name().toString();
}
recorder.addMethod(builder, method.declaringClass().name().toString(), method.name(), params,
methodEntry.getValue());
}
recorder.create(builder);

beanRegistrars.produce(new BeanRegistrarBuildItem(new BeanRegistrar() {

@Override
public void register(RegistrationContext registrationContext) {
IndexView index = beanArchiveBuildItem.getIndex();
Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> securityChecks = gatherSecurityAnnotations(
index, additionalSecured, config.denyUnannotated);
for (AdditionalSecurityCheckBuildItem additionalSecurityCheck : additionalSecurityChecks) {
securityChecks.put(additionalSecurityCheck.getMethodInfo(),
additionalSecurityCheck.getSecurityCheckResultHandleCreator());
}

DotName name = DotName.createSimple(SecurityCheckStorage.class.getName());

BeanConfigurator<Object> configurator = registrationContext.configure(name);
configurator.addType(name);
configurator.scope(BuiltinScope.APPLICATION.getInfo());
configurator.creator(creator -> {
ResultHandle storageBuilder = creator
.newInstance(MethodDescriptor.ofConstructor(SecurityCheckStorageBuilder.class));
for (Map.Entry<MethodInfo, Function<BytecodeCreator, ResultHandle>> methodEntry : securityChecks
.entrySet()) {
registerSecuredMethod(storageBuilder, creator, methodEntry);
}
ResultHandle ret = creator.invokeVirtualMethod(
MethodDescriptor.ofMethod(SecurityCheckStorageBuilder.class, "create",
SecurityCheckStorage.class),
storageBuilder);
ResultHandle ret = creator.invokeStaticMethod(
MethodDescriptor.ofMethod(SecurityCheckRecorder.class, "getStorage",
SecurityCheckStorage.class));
creator.returnValue(ret);
});
configurator.done();
}
}));
}

private void registerSecuredMethod(ResultHandle checkStorage,
MethodCreator methodCreator,
Map.Entry<MethodInfo, Function<BytecodeCreator, ResultHandle>> methodEntry) {
MethodInfo methodInfo = methodEntry.getKey();
ResultHandle declaringClass = methodCreator.load(methodInfo.declaringClass().name().toString());
ResultHandle methodName = methodCreator.load(methodInfo.name());
ResultHandle methodParamTypes = paramTypes(methodCreator, methodInfo.parameters());

methodCreator.invokeVirtualMethod(
MethodDescriptor.ofMethod(SecurityCheckStorageBuilder.class, "registerCheck", void.class, String.class,
String.class, String[].class, SecurityCheck.class),
checkStorage,
declaringClass, methodName, methodParamTypes, methodEntry.getValue().apply(methodCreator));
}

private ResultHandle paramTypes(MethodCreator ctor, List<Type> parameters) {
ResultHandle result = ctor.newArray(String.class, parameters.size());

for (int i = 0; i < parameters.size(); i++) {
ctor.writeArrayValue(result, i, ctor.load(parameters.get(i).name().toString()));
}

return result;
}

private Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> gatherSecurityAnnotations(
private Map<MethodInfo, SecurityCheck> gatherSecurityAnnotations(
IndexView index,
Map<DotName, ClassInfo> additionalSecuredClasses, boolean denyUnannotated) {
Map<DotName, ClassInfo> additionalSecuredClasses, boolean denyUnannotated, SecurityCheckRecorder recorder) {

Map<MethodInfo, AnnotationInstance> methodToInstanceCollector = new HashMap<>();
Map<ClassInfo, AnnotationInstance> classAnnotations = new HashMap<>();
Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> result = new HashMap<>(gatherSecurityAnnotations(
Map<MethodInfo, SecurityCheck> result = new HashMap<>(gatherSecurityAnnotations(
index, DotNames.ROLES_ALLOWED, methodToInstanceCollector, classAnnotations,
(instance -> rolesAllowedSecurityCheck(instance.value().asStringArray()))));
(instance -> recorder.rolesAllowed(instance.value().asStringArray()))));
result.putAll(gatherSecurityAnnotations(index, DotNames.PERMIT_ALL, methodToInstanceCollector, classAnnotations,
(instance -> permitAllSecurityCheck())));
(instance -> recorder.permitAll())));
result.putAll(gatherSecurityAnnotations(index, DotNames.AUTHENTICATED, methodToInstanceCollector, classAnnotations,
(instance -> authenticatedSecurityCheck())));
(instance -> recorder.authenticated())));

result.putAll(gatherSecurityAnnotations(index, DotNames.DENY_ALL, methodToInstanceCollector, classAnnotations,
(instance -> denyAllSecurityCheck())));
(instance -> recorder.denyAll())));

/*
* Handle additional secured classes by adding the denyAll check to all public non-static methods
Expand All @@ -245,7 +225,7 @@ private Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> gatherSecurityA
}
AnnotationInstance alreadyExistingInstance = methodToInstanceCollector.get(methodInfo);
if ((alreadyExistingInstance == null)) {
result.put(methodInfo, denyAllSecurityCheck());
result.put(methodInfo, recorder.denyAll());
} else if (alreadyExistingInstance.target().kind() == AnnotationTarget.Kind.CLASS) {
throw new IllegalStateException("Class " + methodInfo.declaringClass()
+ " should not have been added as an additional secured class");
Expand All @@ -270,7 +250,7 @@ private Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> gatherSecurityA
if (methodToInstanceCollector.containsKey(methodInfo)) { // the method already has a security check
continue;
}
result.put(methodInfo, denyAllSecurityCheck());
result.put(methodInfo, recorder.denyAll());
}
}
}
Expand All @@ -283,13 +263,13 @@ private boolean isPublicNonStaticNonConstructor(MethodInfo methodInfo) {
&& !"<init>".equals(methodInfo.name());
}

private Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> gatherSecurityAnnotations(
private Map<MethodInfo, SecurityCheck> gatherSecurityAnnotations(
IndexView index, DotName dotName,
Map<MethodInfo, AnnotationInstance> alreadyCheckedMethods,
Map<ClassInfo, AnnotationInstance> classLevelAnnotations,
Function<AnnotationInstance, Function<BytecodeCreator, ResultHandle>> securityCheckInstanceCreator) {
Function<AnnotationInstance, SecurityCheck> securityCheckInstanceCreator) {

Map<MethodInfo, Function<BytecodeCreator, ResultHandle>> result = new HashMap<>();
Map<MethodInfo, SecurityCheck> result = new HashMap<>();

Collection<AnnotationInstance> instances = index.getAnnotations(dotName);
// make sure we process annotations on methods first
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.quarkus.security.runtime;

import io.quarkus.runtime.RuntimeValue;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.security.runtime.interceptor.SecurityCheckStorage;
import io.quarkus.security.runtime.interceptor.SecurityCheckStorageBuilder;
import io.quarkus.security.runtime.interceptor.check.AuthenticatedCheck;
import io.quarkus.security.runtime.interceptor.check.DenyAllCheck;
import io.quarkus.security.runtime.interceptor.check.PermitAllCheck;
import io.quarkus.security.runtime.interceptor.check.RolesAllowedCheck;
import io.quarkus.security.runtime.interceptor.check.SecurityCheck;

@Recorder
public class SecurityCheckRecorder {

private static volatile SecurityCheckStorage storage;

public static SecurityCheckStorage getStorage() {
return storage;
}

public SecurityCheck denyAll() {
return DenyAllCheck.INSTANCE;
}

public SecurityCheck permitAll() {
return PermitAllCheck.INSTANCE;
}

public SecurityCheck rolesAllowed(String... roles) {
return RolesAllowedCheck.of(roles);
}

public SecurityCheck authenticated() {
return AuthenticatedCheck.INSTANCE;
}

public RuntimeValue<SecurityCheckStorageBuilder> newBuilder() {
return new RuntimeValue<>(new SecurityCheckStorageBuilder());
}

public void addMethod(RuntimeValue<SecurityCheckStorageBuilder> builder, String className,
String methodName,
String[] parameterTypes,
SecurityCheck securityCheck) {
builder.getValue().registerCheck(className, methodName, parameterTypes, securityCheck);
}

public void create(RuntimeValue<SecurityCheckStorageBuilder> builder) {
storage = builder.getValue().create();
}

}
Loading