diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java index c1a087061f994..2db295644553e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityExtension.java @@ -115,4 +115,8 @@ default AuthenticationFailureHandler getAuthenticationFailureHandler(SecurityCom default AuthorizationEngine getAuthorizationEngine(Settings settings) { return null; } + + default String extensionName() { + return getClass().getName(); + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index d01d6ca46e462..26e42fa565d91 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -43,6 +43,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.env.Environment; @@ -597,7 +598,7 @@ Collection createComponents(Client client, ThreadPool threadPool, Cluste extensionComponents ); if (providers != null && providers.isEmpty() == false) { - customRoleProviders.put(extension.toString(), providers); + customRoleProviders.put(extension.extensionName(), providers); } } @@ -695,37 +696,15 @@ auditTrailService, failureHandler, threadPool, anonymousUser, getAuthorizationEn } private AuthorizationEngine getAuthorizationEngine() { - AuthorizationEngine authorizationEngine = null; - String extensionName = null; - for (SecurityExtension extension : securityExtensions) { - final AuthorizationEngine extensionEngine = extension.getAuthorizationEngine(settings); - if (extensionEngine != null && authorizationEngine != null) { - throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] " - + "both set an authorization engine"); - } - authorizationEngine = extensionEngine; - extensionName = extension.toString(); - } - - if (authorizationEngine != null) { - logger.debug("Using authorization engine from extension [" + extensionName + "]"); - } - return authorizationEngine; + return findValueFromExtensions("authorization engine", extension -> extension.getAuthorizationEngine(settings)); } private AuthenticationFailureHandler createAuthenticationFailureHandler(final Realms realms, final SecurityExtension.SecurityComponents components) { - AuthenticationFailureHandler failureHandler = null; - String extensionName = null; - for (SecurityExtension extension : securityExtensions) { - AuthenticationFailureHandler extensionFailureHandler = extension.getAuthenticationFailureHandler(components); - if (extensionFailureHandler != null && failureHandler != null) { - throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] " - + "both set an authentication failure handler"); - } - failureHandler = extensionFailureHandler; - extensionName = extension.toString(); - } + AuthenticationFailureHandler failureHandler = findValueFromExtensions( + "authentication failure handler", + extension -> extension.getAuthenticationFailureHandler(components) + ); if (failureHandler == null) { logger.debug("Using default authentication failure handler"); Supplier>> headersSupplier = () -> { @@ -762,12 +741,48 @@ private AuthenticationFailureHandler createAuthenticationFailureHandler(final Re getLicenseState().addListener(() -> { finalDefaultFailureHandler.setHeaders(headersSupplier.get()); }); - } else { - logger.debug("Using authentication failure handler from extension [" + extensionName + "]"); } return failureHandler; } + /** + * Calls the provided function for each configured extension and return the value that was generated by the extensions. + * If multiple extensions provide a value, throws {@link IllegalStateException}. + * If no extensions provide a value (or if there are no extensions) returns {@code null}. + */ + @Nullable + private T findValueFromExtensions(String valueType, Function method) { + T foundValue = null; + String fromExtension = null; + for (SecurityExtension extension : securityExtensions) { + final T extensionValue = method.apply(extension); + if (extensionValue == null) { + continue; + } + if (foundValue == null) { + foundValue = extensionValue; + fromExtension = extension.extensionName(); + } else { + throw new IllegalStateException( + "Extensions [" + + fromExtension + + "] and [" + + extension.extensionName() + + "] " + + " both attempted to provide a value for [" + + valueType + + "]" + ); + } + } + if (foundValue == null) { + return null; + } else { + logger.debug("Using [{}] [{}] from extension [{}]", valueType, foundValue, fromExtension); + return foundValue; + } + } + @Override public Settings additionalSettings() { return additionalSettings(settings, enabled, transportClientMode); diff --git a/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java b/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java index 4efa89e840a0d..8f7f32b88e2bc 100644 --- a/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java +++ b/x-pack/qa/security-example-spi-extension/src/main/java/org/elasticsearch/example/ExampleSecurityExtension.java @@ -42,6 +42,11 @@ public class ExampleSecurityExtension implements SecurityExtension { }); } + @Override + public String extensionName() { + return "example"; + } + @Override public Map getRealms(SecurityComponents components) { final Map map = new HashMap<>();