diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlUpgradeModeActionFilterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlUpgradeModeActionFilterTests.java index 6058906de284e..f3c7f3afc67b8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlUpgradeModeActionFilterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlUpgradeModeActionFilterTests.java @@ -97,7 +97,7 @@ public void testApply_ActionAllowedInUpgradeMode() { public void testOrder_UpgradeFilterIsExecutedAfterSecurityFilter() { MlUpgradeModeActionFilter upgradeModeFilter = new MlUpgradeModeActionFilter(clusterService); - SecurityActionFilter securityFilter = new SecurityActionFilter(null, null, null, mock(ThreadPool.class), null, null); + SecurityActionFilter securityFilter = new SecurityActionFilter(null, null, null, null, mock(ThreadPool.class), null, null); ActionFilter[] actionFiltersInOrderOfExecution = new ActionFilters(Sets.newHashSet(upgradeModeFilter, securityFilter)).filters(); assertThat(actionFiltersInOrderOfExecution, is(arrayContaining(securityFilter, upgradeModeFilter))); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index d496756e19ee4..dee66ba170612 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -521,7 +521,7 @@ auditTrailService, failureHandler, threadPool, anonymousUser, getAuthorizationEn securityInterceptor.set(new SecurityServerTransportInterceptor(settings, threadPool, authcService.get(), authzService, getLicenseState(), getSslService(), securityContext.get(), destructiveOperations, clusterService)); - securityActionFilter.set(new SecurityActionFilter(authcService.get(), authzService, getLicenseState(), + securityActionFilter.set(new SecurityActionFilter(authcService.get(), authzService, auditTrailService, getLicenseState(), threadPool, securityContext.get(), destructiveOperations)); components.add(new SecurityUsageServices(realms, allRolesStore, nativeRoleMappingStore, ipFilter.get())); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java index 3d255b7690e3c..d1f08255b1eaf 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilter.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.DestructiveOperations; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -33,11 +34,12 @@ import org.elasticsearch.xpack.core.security.support.Automatons; import org.elasticsearch.xpack.core.security.user.SystemUser; import org.elasticsearch.xpack.security.action.SecurityActionMapper; +import org.elasticsearch.xpack.security.audit.AuditTrailService; +import org.elasticsearch.xpack.security.audit.AuditUtil; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.authz.AuthorizationUtils; -import java.io.IOException; import java.util.function.Predicate; public class SecurityActionFilter implements ActionFilter { @@ -48,6 +50,7 @@ public class SecurityActionFilter implements ActionFilter { private final AuthenticationService authcService; private final AuthorizationService authzService; + private final AuditTrailService auditTrailService; private final SecurityActionMapper actionMapper = new SecurityActionMapper(); private final XPackLicenseState licenseState; private final ThreadContext threadContext; @@ -55,10 +58,11 @@ public class SecurityActionFilter implements ActionFilter { private final DestructiveOperations destructiveOperations; public SecurityActionFilter(AuthenticationService authcService, AuthorizationService authzService, - XPackLicenseState licenseState, ThreadPool threadPool, + AuditTrailService auditTrailService, XPackLicenseState licenseState, ThreadPool threadPool, SecurityContext securityContext, DestructiveOperations destructiveOperations) { this.authcService = authcService; this.authzService = authzService; + this.auditTrailService = auditTrailService; this.licenseState = licenseState; this.threadContext = threadPool.getThreadContext(); this.securityContext = securityContext; @@ -83,29 +87,19 @@ public void app if (licenseState.isSecurityEnabled()) { final ActionListener contextPreservingListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadContext); - ActionListener authenticatedListener = ActionListener.wrap( - (aVoid) -> chain.proceed(task, action, request, contextPreservingListener), contextPreservingListener::onFailure); final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action); try { if (useSystemUser) { securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> { - try { - applyInternal(action, request, authenticatedListener); - } catch (IOException e) { - listener.onFailure(e); - } + applyInternal(task, chain, action, request, contextPreservingListener); }, Version.CURRENT); } else if (AuthorizationUtils.shouldSetUserBasedOnActionOrigin(threadContext)) { AuthorizationUtils.switchUserBasedOnActionOriginAndExecute(threadContext, securityContext, (original) -> { - try { - applyInternal(action, request, authenticatedListener); - } catch (IOException e) { - listener.onFailure(e); - } + applyInternal(task, chain, action, request, contextPreservingListener); }); } else { try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(true)) { - applyInternal(action, request, authenticatedListener); + applyInternal(task, chain, action, request, contextPreservingListener); } } } catch (Exception e) { @@ -130,13 +124,13 @@ public int order() { return Integer.MIN_VALUE; } - private void applyInternal(String action, Request request, - ActionListener listener) throws IOException { + private void applyInternal(Task task, + ActionFilterChain chain, String action, Request request, ActionListener listener) { if (CloseIndexAction.NAME.equals(action) || OpenIndexAction.NAME.equals(action) || DeleteIndexAction.NAME.equals(action)) { IndicesRequest indicesRequest = (IndicesRequest) request; try { destructiveOperations.failDestructive(indicesRequest.indices()); - } catch(IllegalArgumentException e) { + } catch (IllegalArgumentException e) { listener.onFailure(e); return; } @@ -156,7 +150,17 @@ it to the action without an associated user (not via REST or transport - this is authcService.authenticate(securityAction, request, SystemUser.INSTANCE, ActionListener.wrap((authc) -> { if (authc != null) { - authorizeRequest(authc, securityAction, request, listener); + final String requestId = AuditUtil.extractRequestId(threadContext); + assert Strings.hasText(requestId); + authorizeRequest(authc, securityAction, request, ActionListener.delegateFailure(listener, + (ignore, aVoid) -> { + chain.proceed(task, action, request, ActionListener.delegateFailure(listener, + (ignore2, response) -> { + auditTrailService.get().coordinatingActionResponse(requestId, authc, action, request, + response); + listener.onResponse(response); + })); + })); } else if (licenseState.isSecurityEnabled() == false) { listener.onResponse(null); } else { @@ -166,12 +170,11 @@ it to the action without an associated user (not via REST or transport - this is } private void authorizeRequest(Authentication authentication, String securityAction, Request request, - ActionListener listener) { + ActionListener listener) { if (authentication == null) { listener.onFailure(new IllegalArgumentException("authentication must be non null for authorization")); } else { - authzService.authorize(authentication, securityAction, request, ActionListener.wrap(ignore -> listener.onResponse(null), - listener::onFailure)); + authzService.authorize(authentication, securityAction, request, listener); } } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrail.java index 432b80578c76f..2aaa22725e708 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrail.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.AuthorizationInfo; @@ -81,4 +82,9 @@ void runAsDenied(String requestId, Authentication authentication, RestRequest re void explicitIndexAccessEvent(String requestId, AuditLevel eventType, Authentication authentication, String action, String indices, String requestName, TransportAddress remoteAddress, AuthorizationInfo authorizationInfo); + // this is the only audit method that is called *after* the action executed, when the response is available + // it is however *only called for coordinating actions*, which are the actions that a client invokes as opposed to + // the actions that a node invokes in order to service a client request + void coordinatingActionResponse(String requestId, Authentication authentication, String action, TransportRequest transportRequest, + TransportResponse transportResponse); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrailService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrailService.java index 8141e711d4e48..373d39eee7d4e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrailService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/AuditTrailService.java @@ -12,6 +12,7 @@ import org.elasticsearch.license.XPackLicenseState.Feature; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.AuthorizationInfo; @@ -147,6 +148,11 @@ public void runAsDenied(String requestId, Authentication authentication, RestReq public void explicitIndexAccessEvent(String requestId, AuditLevel eventType, Authentication authentication, String action, String indices, String requestName, TransportAddress remoteAddress, AuthorizationInfo authorizationInfo) {} + + @Override + public void coordinatingActionResponse(String requestId, Authentication authentication, String action, + TransportRequest transportRequest, + TransportResponse transportResponse) { } } private static class CompositeAuditTrail implements AuditTrail { @@ -254,6 +260,15 @@ public void accessDenied(String requestId, Authentication authentication, String } } + @Override + public void coordinatingActionResponse(String requestId, Authentication authentication, String action, + TransportRequest transportRequest, + TransportResponse transportResponse) { + for (AuditTrail auditTrail : auditTrails) { + auditTrail.coordinatingActionResponse(requestId, authentication, action, transportRequest, transportResponse); + } + } + @Override public void tamperedRequest(String requestId, RestRequest request) { for (AuditTrail auditTrail : auditTrails) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java index 514c5b2c6c649..ebd0059627ae8 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java @@ -39,6 +39,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportRequest; +import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.xpack.core.security.action.CreateApiKeyAction; import org.elasticsearch.xpack.core.security.action.CreateApiKeyRequest; import org.elasticsearch.xpack.core.security.action.GrantApiKeyAction; @@ -815,6 +816,13 @@ public void runAsDenied(String requestId, Authentication authentication, RestReq } } + @Override + public void coordinatingActionResponse(String requestId, Authentication authentication, String action, + TransportRequest transportRequest, + TransportResponse transportResponse) { + // not implemented yet + } + private LogEntryBuilder securityChangeLogEntryBuilder(String requestId) { return new LogEntryBuilder(false) .with(EVENT_TYPE_FIELD_NAME, SECURITY_CHANGE_ORIGIN_FIELD_VALUE) diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationService.java index 472d74744980b..ba110c185eeee 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationService.java @@ -61,11 +61,8 @@ import org.elasticsearch.xpack.core.security.authz.privilege.ClusterPrivilegeResolver; import org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege; import org.elasticsearch.xpack.core.security.user.AnonymousUser; -import org.elasticsearch.xpack.core.security.user.AsyncSearchUser; import org.elasticsearch.xpack.core.security.user.SystemUser; import org.elasticsearch.xpack.core.security.user.User; -import org.elasticsearch.xpack.core.security.user.XPackSecurityUser; -import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.security.audit.AuditLevel; import org.elasticsearch.xpack.security.audit.AuditTrail; import org.elasticsearch.xpack.security.audit.AuditTrailService; @@ -96,6 +93,7 @@ import static org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField.INDICES_PERMISSIONS_KEY; import static org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField.ORIGINATING_ACTION_KEY; import static org.elasticsearch.xpack.core.security.support.Exceptions.authorizationError; +import static org.elasticsearch.xpack.core.security.user.User.isInternal; import static org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrail.PRINCIPAL_ROLES_FIELD_NAME; public class AuthorizationService { @@ -191,7 +189,7 @@ public void authorize(final Authentication authentication, final String action, if (auditId == null) { // We would like to assert that there is an existing request-id, but if this is a system action, then that might not be // true because the request-id is generated during authentication - if (isInternalUser(authentication.getUser()) != false) { + if (isInternal(authentication.getUser())) { auditId = AuditUtil.getOrGenerateRequestId(threadContext); } else { auditTrailService.get().tamperedRequest(null, authentication, action, originalRequest); @@ -199,6 +197,7 @@ public void authorize(final Authentication authentication, final String action, + "] without an existing request-id"; assert false : message; listener.onFailure(new ElasticsearchSecurityException(message)); + return; } } @@ -398,7 +397,7 @@ AuthorizationEngine getAuthorizationEngine(final Authentication authentication) private AuthorizationEngine getAuthorizationEngineForUser(final User user) { if (rbacEngine != authorizationEngine && licenseState.isSecurityEnabled() && licenseState.checkFeature(Feature.SECURITY_AUTHORIZATION_ENGINE)) { - if (ClientReservedRealm.isReserved(user.principal(), settings) || isInternalUser(user)) { + if (ClientReservedRealm.isReserved(user.principal(), settings) || isInternal(user)) { return rbacEngine; } else { return authorizationEngine; @@ -449,10 +448,6 @@ private TransportRequest maybeUnwrapRequest(Authentication authentication, Trans return request; } - private boolean isInternalUser(User user) { - return SystemUser.is(user) || XPackUser.is(user) || XPackSecurityUser.is(user) || AsyncSearchUser.is(user); - } - private void authorizeRunAs(final RequestInfo requestInfo, final AuthorizationInfo authzInfo, final ActionListener listener) { final Authentication authentication = requestInfo.getAuthentication(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java index 8d341671be657..134101a1b2472 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/filter/SecurityActionFilterTests.java @@ -9,17 +9,18 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.MockIndicesRequest; import org.elasticsearch.action.admin.indices.close.CloseIndexAction; import org.elasticsearch.action.admin.indices.delete.DeleteIndexAction; import org.elasticsearch.action.admin.indices.open.OpenIndexAction; import org.elasticsearch.action.support.ActionFilterChain; -import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -37,6 +38,9 @@ import org.elasticsearch.xpack.core.security.authz.accesscontrol.IndicesAccessControl; import org.elasticsearch.xpack.core.security.user.SystemUser; import org.elasticsearch.xpack.core.security.user.User; +import org.elasticsearch.xpack.security.audit.AuditTrail; +import org.elasticsearch.xpack.security.audit.AuditTrailService; +import org.elasticsearch.xpack.security.audit.AuditUtil; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.junit.Before; @@ -45,6 +49,7 @@ import static org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField.INDICES_PERMISSIONS_KEY; import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; @@ -60,6 +65,9 @@ public class SecurityActionFilterTests extends ESTestCase { private AuthenticationService authcService; private AuthorizationService authzService; + private AuditTrailService auditTrailService; + private AuditTrail auditTrail; + private ActionFilterChain chain; private XPackLicenseState licenseState; private SecurityActionFilter filter; private ThreadContext threadContext; @@ -69,6 +77,10 @@ public class SecurityActionFilterTests extends ESTestCase { public void init() throws Exception { authcService = mock(AuthenticationService.class); authzService = mock(AuthorizationService.class); + auditTrailService = mock(AuditTrailService.class); + auditTrail = mock(AuditTrail.class); + when(auditTrailService.get()).thenReturn(auditTrail); + chain = mock(ActionFilterChain.class); licenseState = mock(XPackLicenseState.class); when(licenseState.isSecurityEnabled()).thenReturn(true); when(licenseState.checkFeature(Feature.SECURITY_STATS_AND_HEALTH)).thenReturn(true); @@ -88,32 +100,37 @@ public void init() throws Exception { when(state.nodes()).thenReturn(nodes); SecurityContext securityContext = new SecurityContext(settings, threadContext); - filter = new SecurityActionFilter(authcService, authzService, licenseState, threadPool, securityContext, destructiveOperations); + filter = new SecurityActionFilter(authcService, authzService, auditTrailService, licenseState, threadPool, + securityContext, destructiveOperations); } public void testApply() throws Exception { ActionRequest request = mock(ActionRequest.class); ActionListener listener = mock(ActionListener.class); - ActionFilterChain chain = mock(ActionFilterChain.class); Task task = mock(Task.class); User user = new User("username", "r1", "r2"); Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null); - mockAuthentication(request, authentication); + String requestId = UUIDs.randomBase64UUID(); + mockAuthentication(request, authentication, requestId); mockAuthorize(); + ActionResponse actionResponse = mock(ActionResponse.class); + mockChain(task, "_action", request, actionResponse); filter.apply(task, "_action", request, listener, chain); verify(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class)); - verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class)); + verify(auditTrail).coordinatingActionResponse(eq(requestId), eq(authentication), eq("_action"), eq(request), eq(actionResponse)); } public void testApplyRestoresThreadContext() throws Exception { ActionRequest request = mock(ActionRequest.class); ActionListener listener = mock(ActionListener.class); - ActionFilterChain chain = mock(ActionFilterChain.class); Task task = mock(Task.class); User user = new User("username", "r1", "r2"); Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null); - mockAuthentication(request, authentication); + String requestId = UUIDs.randomBase64UUID(); + mockAuthentication(request, authentication, requestId); mockAuthorize(); + ActionResponse actionResponse = mock(ActionResponse.class); + mockChain(task, "_action", request, actionResponse); assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); assertNull(threadContext.getTransient(INDICES_PERMISSIONS_KEY)); @@ -122,7 +139,7 @@ public void testApplyRestoresThreadContext() throws Exception { assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); assertNull(threadContext.getTransient(INDICES_PERMISSIONS_KEY)); verify(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class)); - verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class)); + verify(auditTrail).coordinatingActionResponse(eq(requestId), eq(authentication), eq("_action"), eq(request), eq(actionResponse)); } public void testApplyAsSystemUser() throws Exception { @@ -132,15 +149,18 @@ public void testApplyAsSystemUser() throws Exception { Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null); SetOnce authenticationSetOnce = new SetOnce<>(); SetOnce accessControlSetOnce = new SetOnce<>(); + SetOnce requestIdOnActionHandler = new SetOnce<>(); ActionFilterChain chain = (task, action, request1, listener1) -> { authenticationSetOnce.set(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); accessControlSetOnce.set(threadContext.getTransient(INDICES_PERMISSIONS_KEY)); + requestIdOnActionHandler.set(AuditUtil.extractRequestId(threadContext)); }; Task task = mock(Task.class); final boolean hasExistingAuthentication = randomBoolean(); final boolean hasExistingAccessControl = randomBoolean(); final String action = "internal:foo"; if (hasExistingAuthentication) { + AuditUtil.generateRequestId(threadContext); threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication); threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, "foo"); threadContext.putTransient(AuthorizationServiceField.ORIGINATING_ACTION_KEY, "indices:foo"); @@ -148,12 +168,15 @@ public void testApplyAsSystemUser() throws Exception { threadContext.putTransient(INDICES_PERMISSIONS_KEY, IndicesAccessControl.ALLOW_NO_INDICES); } } else { + assertNull(AuditUtil.extractRequestId(threadContext)); assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); } + SetOnce requestIdFromAuthn = new SetOnce<>(); doAnswer(i -> { final Object[] args = i.getArguments(); assertThat(args, arrayWithSize(4)); ActionListener callback = (ActionListener) args[args.length - 1]; + requestIdFromAuthn.set(AuditUtil.generateRequestId(threadContext)); callback.onResponse(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); return Void.TYPE; }).when(authcService).authenticate(eq(action), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class)); @@ -174,6 +197,7 @@ public void testApplyAsSystemUser() throws Exception { assertNotEquals(authentication, authenticationSetOnce.get()); assertEquals(SystemUser.INSTANCE, authenticationSetOnce.get().getUser()); assertThat(accessControlSetOnce.get(), sameInstance(authzAccessControl)); + assertThat(requestIdOnActionHandler.get(), is(requestIdFromAuthn.get())); } public void testApplyDestructiveOperations() throws Exception { @@ -182,14 +206,19 @@ public void testApplyDestructiveOperations() throws Exception { randomFrom("*", "_all", "test*")); String action = randomFrom(CloseIndexAction.NAME, OpenIndexAction.NAME, DeleteIndexAction.NAME); ActionListener listener = mock(ActionListener.class); - ActionFilterChain chain = mock(ActionFilterChain.class); Task task = mock(Task.class); User user = new User("username", "r1", "r2"); Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null); + ActionResponse actionResponse = mock(ActionResponse.class); + mockChain(task, action, request, actionResponse); + SetOnce requestIdFromAuthn = new SetOnce<>(); doAnswer(i -> { final Object[] args = i.getArguments(); assertThat(args, arrayWithSize(4)); ActionListener callback = (ActionListener) args[args.length - 1]; + requestIdFromAuthn.set(AuditUtil.generateRequestId(threadContext)); + threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication); + threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, authentication.encode()); callback.onResponse(authentication); return Void.TYPE; }).when(authcService).authenticate(eq(action), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class)); @@ -202,10 +231,12 @@ public void testApplyDestructiveOperations() throws Exception { filter.apply(task, action, request, listener, chain); if (failDestructiveOperations) { verify(listener).onFailure(isA(IllegalArgumentException.class)); - verifyNoMoreInteractions(authzService, chain); + verifyNoMoreInteractions(authzService, chain, auditTrailService, auditTrail); } else { verify(authzService).authorize(eq(authentication), eq(action), eq(request), any(ActionListener.class)); - verify(chain).proceed(eq(task), eq(action), eq(request), isA(ContextPreservingActionListener.class)); + verify(chain).proceed(eq(task), eq(action), eq(request), any(ActionListener.class)); + verify(auditTrail).coordinatingActionResponse(eq(requestIdFromAuthn.get()), eq(authentication), eq(action), eq(request), + eq(actionResponse)); } } @@ -221,10 +252,21 @@ public void testActionProcessException() throws Exception { final Object[] args = i.getArguments(); assertThat(args, arrayWithSize(4)); ActionListener callback = (ActionListener) args[args.length - 1]; + assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); + AuditUtil.generateRequestId(threadContext); callback.onResponse(authentication); return Void.TYPE; }).when(authcService).authenticate(eq("_action"), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class)); - doThrow(exception).when(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class)); + if (randomBoolean()) { + doThrow(exception).when(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class)); + } else { + doAnswer((i) -> { + ActionListener callback = (ActionListener) i.getArguments()[3]; + callback.onFailure(exception); + return Void.TYPE; + }).when(authzService) + .authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class)); + } filter.apply(task, "_action", request, listener, chain); verify(listener).onFailure(exception); verifyNoMoreInteractions(chain); @@ -242,13 +284,15 @@ public void testApplyUnlicensed() throws Exception { verify(chain).proceed(eq(task), eq("_action"), eq(request), eq(listener)); } - private void mockAuthentication(ActionRequest request, Authentication authentication) { + private void mockAuthentication(ActionRequest request, Authentication authentication, String requestId) { doAnswer(i -> { final Object[] args = i.getArguments(); assertThat(args, arrayWithSize(4)); ActionListener callback = (ActionListener) args[args.length - 1]; assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY)); threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication); + threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, authentication.encode()); + threadContext.putHeader("_xpack_audit_request_id", requestId); callback.onResponse(authentication); return Void.TYPE; }).when(authcService).authenticate(eq("_action"), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class)); @@ -271,4 +315,13 @@ private void mockAuthorize(IndicesAccessControl indicesAccessControl) { .authorize(any(Authentication.class), any(String.class), any(TransportRequest.class), any(ActionListener.class)); } + private void mockChain(Task task, String action, ActionRequest request, ActionResponse actionResponse) { + doAnswer(i -> { + final Object[] args = i.getArguments(); + assertThat(args, arrayWithSize(4)); + ActionListener callback = (ActionListener) args[args.length - 1]; + callback.onResponse(actionResponse); + return Void.TYPE; + }).when(chain).proceed(eq(task), eq(action), eq(request), any(ActionListener.class)); + } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index b75c26014ae06..93602c1a7b994 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -73,6 +74,7 @@ import org.elasticsearch.xpack.core.security.authc.DefaultAuthenticationFailureHandler; import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.Realm.Factory; +import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer; import org.elasticsearch.xpack.core.security.authc.support.Hasher; import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.EmptyAuthorizationInfo; @@ -301,17 +303,26 @@ public void testTokenMissing() throws Exception { Mockito.doReturn(List.of(secondRealm)).when(realms).getUnlicensedRealms(); Mockito.doReturn(List.of(firstRealm)).when(realms).asList(); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } PlainActionFuture future = new PlainActionFuture<>(); Authenticator authenticator = service.createAuthenticator("_action", transportRequest, true, future); authenticator.extractToken((token) -> { assertThat(token, nullValue()); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } authenticator.handleNullToken(); }); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> future.actionGet()); assertThat(e.getMessage(), containsString("missing authentication credentials")); - verify(auditTrail).anonymousAccessDenied(reqId, "_action", transportRequest); + verify(auditTrail).anonymousAccessDenied(reqId.get(), "_action", transportRequest); verifyNoMoreInteractions(auditTrail); mockAppender.assertAllExpectationsMatched(); } finally { @@ -331,10 +342,19 @@ public void testAuthenticateBothSupportSecondSucceeds() throws Exception { } else { when(secondRealm.token(threadContext)).thenReturn(token); } - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final AtomicBoolean completed = new AtomicBoolean(false); service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getLookedUpBy(), is(nullValue())); @@ -345,7 +365,7 @@ public void testAuthenticateBothSupportSecondSucceeds() throws Exception { verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); assertTrue(completed.get()); - verify(auditTrail).authenticationFailed(reqId, firstRealm.name(), token, "_action", transportRequest); + verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest); verify(realms).asList(); verifyNoMoreInteractions(realms); } @@ -357,11 +377,20 @@ public void testAuthenticateSmartRealmOrdering() { when(secondRealm.supports(token)).thenReturn(true); mockAuthenticate(secondRealm, token, user); when(secondRealm.token(threadContext)).thenReturn(token); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } // Authenticate against the normal chain. 1st Realm will be checked (and not pass) then 2nd realm will successfully authc final AtomicBoolean completed = new AtomicBoolean(false); service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getLookedUpBy(), is(nullValue())); @@ -369,7 +398,7 @@ public void testAuthenticateSmartRealmOrdering() { assertThat(result.getAuthenticatedBy().getName(), is(SECOND_REALM_NAME)); assertThat(result.getAuthenticatedBy().getType(), is(SECOND_REALM_TYPE)); assertThreadContextContainsAuthentication(result); - verify(auditTrail).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); @@ -381,6 +410,7 @@ public void testAuthenticateSmartRealmOrdering() { // "FirstRealm" will not be used Mockito.reset(operatorPrivilegesService); service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getLookedUpBy(), is(nullValue())); @@ -388,12 +418,12 @@ public void testAuthenticateSmartRealmOrdering() { assertThat(result.getAuthenticatedBy().getName(), is(SECOND_REALM_NAME)); assertThat(result.getAuthenticatedBy().getType(), is(SECOND_REALM_TYPE)); assertThreadContextContainsAuthentication(result); - verify(auditTrail, times(2)).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail, times(2)).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); - verify(auditTrail).authenticationFailed(reqId, firstRealm.name(), token, "_action", transportRequest); + verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest); verify(firstRealm, times(2)).name(); // used above one time verify(secondRealm, times(2)).name(); verify(secondRealm, times(2)).type(); // used to create realm ref @@ -414,6 +444,7 @@ public void testAuthenticateSmartRealmOrdering() { // "SecondRealm" will be at the top of the list but will no longer authenticate the user. // Then "FirstRealm" will be checked. service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getLookedUpBy(), is(nullValue())); @@ -421,12 +452,12 @@ public void testAuthenticateSmartRealmOrdering() { assertThat(result.getAuthenticatedBy().getName(), is(FIRST_REALM_NAME)); assertThat(result.getAuthenticatedBy().getType(), is(FIRST_REALM_TYPE)); assertThreadContextContainsAuthentication(result); - verify(auditTrail).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); - verify(auditTrail).authenticationFailed(reqId, SECOND_REALM_NAME, token, "_action", transportRequest); + verify(auditTrail).authenticationFailed(reqId.get(), SECOND_REALM_NAME, token, "_action", transportRequest); verify(secondRealm, times(3)).authenticate(eq(token), any(ActionListener.class)); // 2 from above + 1 more verify(firstRealm, times(2)).authenticate(eq(token), any(ActionListener.class)); // 1 from above + 1 more } @@ -480,16 +511,25 @@ public void testAuthenticateSmartRealmOrderingDisabled() { when(secondRealm.supports(token)).thenReturn(true); mockAuthenticate(secondRealm, token, user); when(secondRealm.token(threadContext)).thenReturn(token); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final AtomicBoolean completed = new AtomicBoolean(false); service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getLookedUpBy(), is(nullValue())); assertThat(result.getAuthenticatedBy().getName(), is(SECOND_REALM_NAME)); // TODO implement equals assertThreadContextContainsAuthentication(result); - verify(auditTrail).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); @@ -498,16 +538,17 @@ public void testAuthenticateSmartRealmOrderingDisabled() { completed.set(false); Mockito.reset(operatorPrivilegesService); service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getLookedUpBy(), is(nullValue())); assertThat(result.getAuthenticatedBy().getName(), is(SECOND_REALM_NAME)); // TODO implement equals assertThreadContextContainsAuthentication(result); - verify(auditTrail, times(2)).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail, times(2)).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); - verify(auditTrail, times(2)).authenticationFailed(reqId, firstRealm.name(), token, "_action", transportRequest); + verify(auditTrail, times(2)).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest); verify(firstRealm, times(3)).name(); // used above one time verify(secondRealm, times(2)).name(); verify(secondRealm, times(2)).type(); // used to create realm ref @@ -526,17 +567,26 @@ public void testAuthenticateFirstNotSupportingSecondSucceeds() throws Exception when(secondRealm.supports(token)).thenReturn(true); mockAuthenticate(secondRealm, token, user); when(secondRealm.token(threadContext)).thenReturn(token); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final AtomicBoolean completed = new AtomicBoolean(false); service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getAuthenticationType(), is(AuthenticationType.REALM)); assertThat(result.getAuthenticatedBy().getName(), is(secondRealm.name())); // TODO implement equals assertThat(result.getAuthenticationType(), is(AuthenticationType.REALM)); assertThreadContextContainsAuthentication(result); - verify(auditTrail).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); @@ -548,12 +598,21 @@ public void testAuthenticateFirstNotSupportingSecondSucceeds() throws Exception public void testAuthenticateCached() throws Exception { final Authentication authentication = new Authentication(new User("_username", "r1"), new RealmRef("test", "cached", "foo"), null); authentication.writeToContext(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } - Authentication result = authenticateBlocking("_action", transportRequest, null); + Tuple result = authenticateBlocking("_action", transportRequest, null); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } + assertThat(expectAuditRequestId(threadContext), is(result.v2())); assertThat(result, notNullValue()); - assertThat(result, is(authentication)); - assertThat(result.getAuthenticationType(), is(AuthenticationType.REALM)); + assertThat(result.v1(), is(authentication)); + assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.REALM)); verifyZeroInteractions(auditTrail); verifyZeroInteractions(firstRealm); verifyZeroInteractions(secondRealm); @@ -567,6 +626,7 @@ public void testAuthenticateNonExistentRestRequestUserThrowsAuthenticationExcept authenticateBlocking(restRequest); fail("Authentication was successful but should not"); } catch (ElasticsearchSecurityException e) { + expectAuditRequestId(threadContext); assertAuthenticationException(e, containsString("unable to authenticate user [idonotexist] for REST request [/]")); verifyZeroInteractions(operatorPrivilegesService); } @@ -578,6 +638,7 @@ public void testTokenRestMissing() throws Exception { Authenticator authenticator = service.createAuthenticator(restRequest, true, mock(ActionListener.class)); authenticator.extractToken((token) -> { + expectAuditRequestId(threadContext); assertThat(token, nullValue()); }); } @@ -587,8 +648,16 @@ public void testAuthenticationInContextAndHeader() throws Exception { when(firstRealm.token(threadContext)).thenReturn(token); when(firstRealm.supports(token)).thenReturn(true); mockAuthenticate(firstRealm, token, user); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); assertThat(result.getAuthenticationType(), is(AuthenticationType.REALM)); @@ -604,7 +673,11 @@ public void testAuthenticationInContextAndHeader() throws Exception { public void testAuthenticateTransportAnonymous() throws Exception { when(firstRealm.token(threadContext)).thenReturn(null); when(secondRealm.token(threadContext)).thenReturn(null); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } try { authenticateBlocking("_action", transportRequest, null); fail("expected an authentication exception when trying to authenticate an anonymous message"); @@ -613,7 +686,12 @@ public void testAuthenticateTransportAnonymous() throws Exception { assertAuthenticationException(e); verifyZeroInteractions(operatorPrivilegesService); } - verify(auditTrail).anonymousAccessDenied(reqId, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).anonymousAccessDenied(reqId.get(), "_action", transportRequest); } public void testAuthenticateRestAnonymous() throws Exception { @@ -627,25 +705,39 @@ public void testAuthenticateRestAnonymous() throws Exception { assertAuthenticationException(e); verifyZeroInteractions(operatorPrivilegesService); } - String reqId = expectAuditRequestId(); - verify(auditTrail).anonymousAccessDenied(reqId, restRequest); + verify(auditTrail).anonymousAccessDenied(expectAuditRequestId(threadContext), restRequest); } public void testAuthenticateTransportFallback() throws Exception { when(firstRealm.token(threadContext)).thenReturn(null); when(secondRealm.token(threadContext)).thenReturn(null); User user1 = new User("username", "r1", "r2"); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } - Authentication result = authenticateBlocking("_action", transportRequest, user1); + Tuple result = authenticateBlocking("_action", transportRequest, user1); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + assertThat(expectAuditRequestId(threadContext), is(result.v2())); assertThat(result, notNullValue()); - assertThat(result.getUser(), sameInstance(user1)); - assertThat(result.getAuthenticationType(), is(AuthenticationType.INTERNAL)); - assertThreadContextContainsAuthentication(result); - verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); + assertThat(result.v1().getUser(), sameInstance(user1)); + assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.INTERNAL)); + assertThreadContextContainsAuthentication(result.v1()); + verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result.v1()), eq(threadContext)); } public void testAuthenticateTransportDisabledUser() throws Exception { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } User user = new User("username", new String[] { "r1", "r2" }, null, null, Map.of(), false); User fallback = randomBoolean() ? SystemUser.INSTANCE : null; when(firstRealm.token(threadContext)).thenReturn(token); @@ -654,7 +746,12 @@ public void testAuthenticateTransportDisabledUser() throws Exception { ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, fallback)); - verify(auditTrail).authenticationFailed(reqId, token, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), token, "_action", transportRequest); verifyNoMoreInteractions(auditTrail); assertAuthenticationException(e); verifyZeroInteractions(operatorPrivilegesService); @@ -668,7 +765,7 @@ public void testAuthenticateRestDisabledUser() throws Exception { ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking(restRequest)); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationFailed(reqId, token, restRequest); verifyNoMoreInteractions(auditTrail); assertAuthenticationException(e); @@ -676,7 +773,11 @@ public void testAuthenticateRestDisabledUser() throws Exception { } public void testAuthenticateTransportSuccess() throws Exception { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final User user = new User("username", "r1", "r2"); final Consumer> authenticate; if (randomBoolean()) { @@ -690,12 +791,17 @@ public void testAuthenticateTransportSuccess() throws Exception { final AtomicBoolean completed = new AtomicBoolean(false); authenticate.accept(ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); assertThat(result.getUser(), sameInstance(user)); assertThat(result.getAuthenticationType(), is(AuthenticationType.REALM)); assertThat(result.getAuthenticatedBy().getName(), is(firstRealm.name())); // TODO implement equals assertThreadContextContainsAuthentication(result); - verify(auditTrail).authenticationSuccess(reqId, result, "_action", transportRequest); + verify(auditTrail).authenticationSuccess(reqId.get(), result, "_action", transportRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); }, this::logAndFail)); @@ -717,7 +823,7 @@ public void testAuthenticateRestSuccess() throws Exception { assertThat(authentication.getAuthenticationType(), is(AuthenticationType.REALM)); assertThat(authentication.getAuthenticatedBy().getName(), is(firstRealm.name())); // TODO implement equals assertThreadContextContainsAuthentication(authentication); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationSuccess(reqId, authentication, restRequest); setCompletedToTrue(completed); verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(authentication), eq(threadContext)); @@ -735,7 +841,17 @@ public void testAuthenticateTransportContextAndHeader() throws Exception { final SetOnce authRef = new SetOnce<>(); final SetOnce authHeaderRef = new SetOnce<>(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } service.authenticate("_action", transportRequest, SystemUser.INSTANCE, ActionListener.wrap(authentication -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(authentication, notNullValue()); assertThat(authentication.getUser(), sameInstance(user1)); assertThat(authentication.getAuthenticationType(), is(AuthenticationType.REALM)); @@ -758,10 +874,20 @@ public void testAuthenticateTransportContextAndHeader() throws Exception { service = new AuthenticationService(Settings.EMPTY, realms, auditTrailService, new DefaultAuthenticationFailureHandler(Collections.emptyMap()), threadPool1, new AnonymousUser(Settings.EMPTY), tokenService, apiKeyService, operatorPrivilegesService); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext1)); + } threadContext1.putTransient(AuthenticationField.AUTHENTICATION_KEY, authRef.get()); threadContext1.putHeader(AuthenticationField.AUTHENTICATION_KEY, authHeaderRef.get()); service.authenticate("_action", message1, SystemUser.INSTANCE, ActionListener.wrap(ctxAuth -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext1), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext1)); + } assertThat(ctxAuth, sameInstance(authRef.get())); assertThat(threadContext1.getHeader(AuthenticationField.AUTHENTICATION_KEY), sameInstance(authHeaderRef.get())); setCompletedToTrue(completed); @@ -779,6 +905,11 @@ public void testAuthenticateTransportContextAndHeader() throws Exception { Mockito.reset(operatorPrivilegesService); try { ThreadContext threadContext2 = threadPool2.getThreadContext(); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext2)); + } final String header; try (ThreadContext.StoredContext ignore = threadContext2.stashContext()) { service = new AuthenticationService(Settings.EMPTY, realms, auditTrailService, @@ -799,6 +930,11 @@ public void testAuthenticateTransportContextAndHeader() throws Exception { new DefaultAuthenticationFailureHandler(Collections.emptyMap()), threadPool2, new AnonymousUser(Settings.EMPTY), tokenService, apiKeyService, operatorPrivilegesService); service.authenticate("_action", new InternalRequest(), SystemUser.INSTANCE, ActionListener.wrap(result -> { + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadPool2.getThreadContext()), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadPool2.getThreadContext())); + } assertThat(result, notNullValue()); assertThat(result.getUser(), equalTo(user1)); assertThat(result.getAuthenticationType(), is(AuthenticationType.REALM)); @@ -815,13 +951,21 @@ public void testAuthenticateTransportContextAndHeader() throws Exception { public void testAuthenticateTamperedUser() throws Exception { InternalRequest message = new InternalRequest(); threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, "_signed_auth"); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); - + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } try { authenticateBlocking("_action", message, randomBoolean() ? SystemUser.INSTANCE : null); } catch (Exception e) { //expected - verify(auditTrail).tamperedRequest(reqId, "_action", message); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).tamperedRequest(reqId.get(), "_action", message); verifyNoMoreInteractions(auditTrail); verifyZeroInteractions(operatorPrivilegesService); } @@ -841,11 +985,20 @@ public void testWrongTokenDoesNotFallbackToAnonymous() { tokenService, apiKeyService, operatorPrivilegesService); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } threadContext.putHeader("Authorization", "Bearer thisisaninvalidtoken"); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, null)); - verify(auditTrail).anonymousAccessDenied(reqId, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).anonymousAccessDenied(reqId.get(), "_action", transportRequest); verifyNoMoreInteractions(auditTrail); assertAuthenticationException(e); verifyZeroInteractions(operatorPrivilegesService); @@ -873,11 +1026,20 @@ public void testWrongApiKeyDoesNotFallbackToAnonymous() { return Void.TYPE; }).when(client).get(any(GetRequest.class), any(ActionListener.class)); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } threadContext.putHeader("Authorization", "ApiKey dGhpc2lzYW5pbnZhbGlkaWQ6dGhpc2lzYW5pbnZhbGlkc2VjcmV0"); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, null)); - verify(auditTrail).anonymousAccessDenied(reqId, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).anonymousAccessDenied(reqId.get(), "_action", transportRequest); verifyNoMoreInteractions(auditTrail); assertAuthenticationException(e); verifyZeroInteractions(operatorPrivilegesService); @@ -898,16 +1060,16 @@ public void testAnonymousUserRest() throws Exception { threadPool, anonymousUser, tokenService, apiKeyService, operatorPrivilegesService); RestRequest request = new FakeRestRequest(); - Authentication result = authenticateBlocking(request); + Tuple result = authenticateBlocking(request); assertThat(result, notNullValue()); - assertThat(result.getUser(), sameInstance((Object) anonymousUser)); - assertThat(result.getAuthenticationType(), is(AuthenticationType.ANONYMOUS)); - assertThreadContextContainsAuthentication(result); - String reqId = expectAuditRequestId(); - verify(auditTrail).authenticationSuccess(reqId, result, request); + assertThat(result.v1().getUser(), sameInstance((Object) anonymousUser)); + assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.ANONYMOUS)); + assertThreadContextContainsAuthentication(result.v1()); + assertThat(expectAuditRequestId(threadContext), is(result.v2())); + verify(auditTrail).authenticationSuccess(result.v2(), result.v1(), request); verifyNoMoreInteractions(auditTrail); - verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); + verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result.v1()), eq(threadContext)); } public void testAuthenticateRestRequestDisallowAnonymous() throws Exception { @@ -933,7 +1095,7 @@ public void testAuthenticateRestRequestDisallowAnonymous() throws Exception { assertThat(ex, throwableWithMessage(containsString("missing authentication credentials for REST request"))); assertThat(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY), nullValue()); assertThat(threadContext.getHeader(AuthenticationField.AUTHENTICATION_KEY), nullValue()); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).anonymousAccessDenied(reqId, request); verifyNoMoreInteractions(auditTrail); verifyZeroInteractions(operatorPrivilegesService); @@ -948,13 +1110,24 @@ public void testAnonymousUserTransportNoDefaultUser() throws Exception { new DefaultAuthenticationFailureHandler(Collections.emptyMap()), threadPool, anonymousUser, tokenService, apiKeyService, operatorPrivilegesService); InternalRequest message = new InternalRequest(); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } - Authentication result = authenticateBlocking("_action", message, null); + Tuple result = authenticateBlocking("_action", message, null); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); - assertThat(result.getUser(), sameInstance(anonymousUser)); - assertThat(result.getAuthenticationType(), is(AuthenticationType.ANONYMOUS)); - assertThreadContextContainsAuthentication(result); - verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); + assertThat(expectAuditRequestId(threadContext), is(result.v2())); + assertThat(result.v1().getUser(), sameInstance(anonymousUser)); + assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.ANONYMOUS)); + assertThreadContextContainsAuthentication(result.v1()); + verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result.v1()), eq(threadContext)); } public void testAnonymousUserTransportWithDefaultUser() throws Exception { @@ -967,24 +1140,44 @@ public void testAnonymousUserTransportWithDefaultUser() throws Exception { threadPool, anonymousUser, tokenService, apiKeyService, operatorPrivilegesService); InternalRequest message = new InternalRequest(); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } - Authentication result = authenticateBlocking("_action", message, SystemUser.INSTANCE); + Tuple result = authenticateBlocking("_action", message, SystemUser.INSTANCE); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } assertThat(result, notNullValue()); - assertThat(result.getUser(), sameInstance(SystemUser.INSTANCE)); - assertThat(result.getAuthenticationType(), is(AuthenticationType.INTERNAL)); - assertThreadContextContainsAuthentication(result); - verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); + assertThat(expectAuditRequestId(threadContext), is(result.v2())); + assertThat(result.v1().getUser(), sameInstance(SystemUser.INSTANCE)); + assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.INTERNAL)); + assertThreadContextContainsAuthentication(result.v1()); + verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result.v1()), eq(threadContext)); } public void testRealmTokenThrowingException() throws Exception { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } when(firstRealm.token(threadContext)).thenThrow(authenticationError("realm doesn't like tokens")); try { authenticateBlocking("_action", transportRequest, null); fail("exception should bubble out"); } catch (ElasticsearchException e) { assertThat(e.getMessage(), is("realm doesn't like tokens")); - verify(auditTrail).authenticationFailed(reqId, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), "_action", transportRequest); verifyZeroInteractions(operatorPrivilegesService); } } @@ -996,7 +1189,7 @@ public void testRealmTokenThrowingExceptionRest() throws Exception { fail("exception should bubble out"); } catch (ElasticsearchException e) { assertThat(e.getMessage(), is("realm doesn't like tokens")); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationFailed(reqId, restRequest); verifyZeroInteractions(operatorPrivilegesService); } @@ -1028,14 +1221,18 @@ public void testRealmSupportsMethodThrowingExceptionRest() throws Exception { fail("exception should bubble out"); } catch (ElasticsearchException e) { assertThat(e.getMessage(), is("realm doesn't like supports")); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationFailed(reqId, token, restRequest); verifyZeroInteractions(operatorPrivilegesService); } } public void testRealmAuthenticateTerminateAuthenticationProcessWithException() { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final AuthenticationToken token = mock(AuthenticationToken.class); final String principal = randomAlphaOfLength(5); when(token.principal()).thenReturn(principal); @@ -1067,14 +1264,23 @@ public void testRealmAuthenticateTerminateAuthenticationProcessWithException() { assertThat(e.getMessage(), is("error attempting to authenticate request")); assertThat(e.getHeader("WWW-Authenticate"), contains(basicScheme)); } - verify(auditTrail).authenticationFailed(reqId, secondRealm.name(), token, "_action", transportRequest); - verify(auditTrail).authenticationFailed(reqId, token, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), secondRealm.name(), token, "_action", transportRequest); + verify(auditTrail).authenticationFailed(reqId.get(), token, "_action", transportRequest); verifyNoMoreInteractions(auditTrail); verifyZeroInteractions(operatorPrivilegesService); } public void testRealmAuthenticateGracefulTerminateAuthenticationProcess() { - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final AuthenticationToken token = mock(AuthenticationToken.class); final String principal = randomAlphaOfLength(5); when(token.principal()).thenReturn(principal); @@ -1087,8 +1293,13 @@ public void testRealmAuthenticateGracefulTerminateAuthenticationProcess() { expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, null)); assertThat(e.getMessage(), is("unable to authenticate user [" + principal + "] for action [_action]")); assertThat(e.getHeader("WWW-Authenticate"), contains(basicScheme)); - verify(auditTrail).authenticationFailed(reqId, firstRealm.name(), token, "_action", transportRequest); - verify(auditTrail).authenticationFailed(reqId, token, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest); + verify(auditTrail).authenticationFailed(reqId.get(), token, "_action", transportRequest); verifyNoMoreInteractions(auditTrail); verifyZeroInteractions(operatorPrivilegesService); } @@ -1100,13 +1311,22 @@ public void testRealmAuthenticateThrowingException() throws Exception { when(secondRealm.supports(token)).thenReturn(true); doThrow(authenticationError("realm doesn't like authenticate")) .when(secondRealm).authenticate(eq(token), any(ActionListener.class)); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } try { authenticateBlocking("_action", transportRequest, null); fail("exception should bubble out"); } catch (ElasticsearchException e) { assertThat(e.getMessage(), is("realm doesn't like authenticate")); - verify(auditTrail).authenticationFailed(reqId, token, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), token, "_action", transportRequest); verifyZeroInteractions(operatorPrivilegesService); } } @@ -1123,7 +1343,7 @@ public void testRealmAuthenticateThrowingExceptionRest() throws Exception { fail("exception should bubble out"); } catch (ElasticsearchSecurityException e) { assertThat(e.getMessage(), is("realm doesn't like authenticate")); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationFailed(reqId, token, restRequest); verifyZeroInteractions(operatorPrivilegesService); } @@ -1139,14 +1359,22 @@ public void testRealmLookupThrowingException() throws Exception { mockRealmLookupReturnsNull(firstRealm, "run_as"); doThrow(authenticationError("realm doesn't want to lookup")) .when(secondRealm).lookupUser(eq("run_as"), any(ActionListener.class)); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); - + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } try { authenticateBlocking("_action", transportRequest, null); fail("exception should bubble out"); } catch (ElasticsearchException e) { assertThat(e.getMessage(), is("realm doesn't want to lookup")); - verify(auditTrail).authenticationFailed(reqId, token, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), token, "_action", transportRequest); verifyZeroInteractions(operatorPrivilegesService); } } @@ -1166,13 +1394,19 @@ public void testRealmLookupThrowingExceptionRest() throws Exception { fail("exception should bubble out"); } catch (ElasticsearchException e) { assertThat(e.getMessage(), is("realm doesn't want to lookup")); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationFailed(reqId, token, restRequest); verifyZeroInteractions(operatorPrivilegesService); } } public void testRunAsLookupSameRealm() throws Exception { + boolean testTransportRequest = randomBoolean(); + boolean requestIdAlreadyPresent = randomBoolean() && testTransportRequest; + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } AuthenticationToken token = mock(AuthenticationToken.class); when(token.principal()).thenReturn(randomAlphaOfLength(5)); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); @@ -1207,12 +1441,17 @@ public void testRunAsLookupSameRealm() throws Exception { assertEquals(user.email(), authUser.email()); assertEquals(user.enabled(), authUser.enabled()); assertEquals(user.fullName(), authUser.fullName()); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + expectAuditRequestId(threadContext); + } verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); setCompletedToTrue(completed); }, this::logAndFail); // we do not actually go async - if (randomBoolean()) { + if (testTransportRequest) { service.authenticate("_action", transportRequest, true, listener); } else { service.authenticate(restRequest, listener); @@ -1222,6 +1461,12 @@ public void testRunAsLookupSameRealm() throws Exception { @SuppressWarnings("unchecked") public void testRunAsLookupDifferentRealm() throws Exception { + boolean testTransportRequest = randomBoolean(); + boolean requestIdAlreadyPresent = randomBoolean() && testTransportRequest; + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } AuthenticationToken token = mock(AuthenticationToken.class); when(token.principal()).thenReturn(randomAlphaOfLength(5)); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); @@ -1247,12 +1492,17 @@ public void testRunAsLookupDifferentRealm() throws Exception { assertThat(authenticated.principal(), is("looked up user")); assertThat(authenticated.roles(), arrayContaining("some role")); assertThreadContextContainsAuthentication(result); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + expectAuditRequestId(threadContext); + } verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); setCompletedToTrue(completed); }, this::logAndFail); // call service asynchronously but it doesn't actually go async - if (randomBoolean()) { + if (testTransportRequest) { service.authenticate("_action", transportRequest, true, listener); } else { service.authenticate(restRequest, listener); @@ -1273,7 +1523,7 @@ public void testRunAsWithEmptyRunAsUsernameRest() throws Exception { authenticateBlocking(restRequest); fail("exception should be thrown"); } catch (ElasticsearchException e) { - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).runAsDenied(eq(reqId), any(Authentication.class), eq(restRequest), eq(EmptyAuthorizationInfo.INSTANCE)); verifyNoMoreInteractions(auditTrail); verifyZeroInteractions(operatorPrivilegesService); @@ -1285,7 +1535,11 @@ public void testRunAsWithEmptyRunAsUsername() throws Exception { when(token.principal()).thenReturn(randomAlphaOfLength(5)); User user = new User("lookup user", new String[]{"user"}); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, ""); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.supports(token)).thenReturn(true); mockAuthenticate(secondRealm, token, user); @@ -1294,7 +1548,12 @@ public void testRunAsWithEmptyRunAsUsername() throws Exception { authenticateBlocking("_action", transportRequest, null); fail("exception should be thrown"); } catch (ElasticsearchException e) { - verify(auditTrail).runAsDenied(eq(reqId), any(Authentication.class), eq("_action"), eq(transportRequest), + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).runAsDenied(eq(reqId.get()), any(Authentication.class), eq("_action"), eq(transportRequest), eq(EmptyAuthorizationInfo.INSTANCE)); verifyNoMoreInteractions(auditTrail); verifyZeroInteractions(operatorPrivilegesService); @@ -1306,7 +1565,11 @@ public void testAuthenticateTransportDisabledRunAsUser() throws Exception { AuthenticationToken token = mock(AuthenticationToken.class); when(token.principal()).thenReturn(randomAlphaOfLength(5)); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); - final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.supports(token)).thenReturn(true); mockAuthenticate(secondRealm, token, new User("lookup user", new String[]{"user"})); @@ -1319,7 +1582,12 @@ public void testAuthenticateTransportDisabledRunAsUser() throws Exception { User fallback = randomBoolean() ? SystemUser.INSTANCE : null; ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, fallback)); - verify(auditTrail).authenticationFailed(reqId, token, "_action", transportRequest); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } + verify(auditTrail).authenticationFailed(reqId.get(), token, "_action", transportRequest); verifyNoMoreInteractions(auditTrail); assertAuthenticationException(e); verifyZeroInteractions(operatorPrivilegesService); @@ -1342,7 +1610,7 @@ public void testAuthenticateRestDisabledRunAsUser() throws Exception { ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking(restRequest)); - String reqId = expectAuditRequestId(); + String reqId = expectAuditRequestId(threadContext); verify(auditTrail).authenticationFailed(reqId, token, restRequest); verifyNoMoreInteractions(auditTrail); assertAuthenticationException(e); @@ -1368,6 +1636,11 @@ public void testAuthenticateWithToken() throws Exception { when(securityIndex.indexExists()).thenReturn(true); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("Authorization", "Bearer " + token); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); @@ -1375,9 +1648,14 @@ public void testAuthenticateWithToken() throws Exception { assertThat(result.getAuthenticatedBy(), is(notNullValue())); assertThat(result.getAuthenticatedBy().getName(), is("realm")); // TODO implement equals assertThat(result.getAuthenticationType(), is(AuthenticationType.TOKEN)); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); setCompletedToTrue(completed); - verify(auditTrail).authenticationSuccess(anyString(), eq(result), eq("_action"), same(transportRequest)); + verify(auditTrail).authenticationSuccess(eq(reqId.get()), eq(result), eq("_action"), same(transportRequest)); }, this::logAndFail)); } assertTrue(completed.get()); @@ -1395,8 +1673,13 @@ public void testInvalidToken() throws Exception { final CountDownLatch latch = new CountDownLatch(1); final Authentication expected = new Authentication(user, new RealmRef(firstRealm.name(), firstRealm.type(), "authc_test"), null); AtomicBoolean success = new AtomicBoolean(false); + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("Authorization", "Bearer " + Base64.getEncoder().encodeToString(randomBytes)); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } service.authenticate("_action", transportRequest, true, ActionListener.wrap(result -> { assertThat(result, notNullValue()); assertThat(result.getUser(), is(user)); @@ -1404,11 +1687,21 @@ public void testInvalidToken() throws Exception { assertThat(result.getAuthenticatedBy(), is(notNullValue())); assertThreadContextContainsAuthentication(result); assertEquals(expected, result); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result), eq(threadContext)); success.set(true); latch.countDown(); }, e -> { verifyZeroInteractions(operatorPrivilegesService); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } else { + reqId.set(expectAuditRequestId(threadContext)); + } if (e instanceof IllegalStateException) { assertThat(e.getMessage(), containsString("array length must be <= to " + ArrayUtil.MAX_ARRAY_LENGTH + " but was: ")); latch.countDown(); @@ -1433,7 +1726,7 @@ public void testInvalidToken() throws Exception { latch.await(); if (success.get()) { final String realmName = firstRealm.name(); - verify(auditTrail).authenticationSuccess(anyString(), eq(expected), eq("_action"), same(transportRequest)); + verify(auditTrail).authenticationSuccess(eq(reqId.get()), eq(expected), eq("_action"), same(transportRequest)); } verifyNoMoreInteractions(auditTrail); } @@ -1459,9 +1752,17 @@ public void testExpiredToken() throws Exception { }).when(securityIndex).prepareIndexIfNeededThenExecute(any(Consumer.class), any(Runnable.class)); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } threadContext.putHeader("Authorization", "Bearer " + token); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, null)); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } assertEquals(RestStatus.UNAUTHORIZED, e.status()); assertEquals("token expired", e.getMessage()); verifyZeroInteractions(operatorPrivilegesService); @@ -1470,10 +1771,18 @@ public void testExpiredToken() throws Exception { public void testApiKeyAuthInvalidHeader() { try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } final String invalidHeader = randomFrom("apikey", "apikey ", "apikey foo"); threadContext.putHeader("Authorization", invalidHeader); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, null)); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } assertEquals(RestStatus.UNAUTHORIZED, e.status()); assertThat(e.getMessage(), containsString("missing authentication credentials")); verifyZeroInteractions(operatorPrivilegesService); @@ -1519,13 +1828,22 @@ public void testApiKeyAuth() { }).when(client).get(any(GetRequest.class), any(ActionListener.class)); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } threadContext.putHeader("Authorization", headerValue); - final Authentication authentication = authenticateBlocking("_action", transportRequest, null); - assertThat(authentication.getUser().principal(), is("johndoe")); - assertThat(authentication.getUser().fullName(), is("john doe")); - assertThat(authentication.getUser().email(), is("john@doe.com")); - assertThat(authentication.getAuthenticationType(), is(AuthenticationType.API_KEY)); - verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(authentication), eq(threadContext)); + Tuple result = authenticateBlocking("_action", transportRequest, null); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } + assertThat(expectAuditRequestId(threadContext), is(result.v2())); + assertThat(result.v1().getUser().principal(), is("johndoe")); + assertThat(result.v1().getUser().fullName(), is("john doe")); + assertThat(result.v1().getUser().email(), is("john@doe.com")); + assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.API_KEY)); + verify(operatorPrivilegesService).maybeMarkOperatorUser(eq(result.v1()), eq(threadContext)); } } @@ -1562,9 +1880,17 @@ public void testExpiredApiKey() { }).when(client).get(any(GetRequest.class), any(ActionListener.class)); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + boolean requestIdAlreadyPresent = randomBoolean(); + SetOnce reqId = new SetOnce<>(); + if (requestIdAlreadyPresent) { + reqId.set(AuditUtil.getOrGenerateRequestId(threadContext)); + } threadContext.putHeader("Authorization", headerValue); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> authenticateBlocking("_action", transportRequest, null)); + if (requestIdAlreadyPresent) { + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + } assertEquals(RestStatus.UNAUTHORIZED, e.status()); verifyZeroInteractions(operatorPrivilegesService); } @@ -1618,23 +1944,55 @@ private void mockAuthenticate(Realm realm, AuthenticationToken token, Exception }).when(realm).authenticate(eq(token), any(ActionListener.class)); } - private Authentication authenticateBlocking(RestRequest restRequest) { - PlainActionFuture future = new PlainActionFuture<>(); + private Tuple authenticateBlocking(RestRequest restRequest) { + SetOnce reqId = new SetOnce<>(); + PlainActionFuture future = new PlainActionFuture<>() { + @Override + public void onResponse(Authentication result) { + reqId.set(expectAuditRequestId(threadContext)); + assertThat(new AuthenticationContextSerializer().getAuthentication(threadContext), is(result)); + super.onResponse(result); + } + + @Override + public void onFailure(Exception e) { + reqId.set(expectAuditRequestId(threadContext)); + super.onFailure(e); + } + }; service.authenticate(restRequest, future); - return future.actionGet(); + Authentication authentication = future.actionGet(); + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + return new Tuple<>(authentication, reqId.get()); } - private Authentication authenticateBlocking(String action, TransportRequest transportRequest, User fallbackUser) { - PlainActionFuture future = new PlainActionFuture<>(); + private Tuple authenticateBlocking(String action, TransportRequest transportRequest, User fallbackUser) { + SetOnce reqId = new SetOnce<>(); + PlainActionFuture future = new PlainActionFuture<>() { + @Override + public void onResponse(Authentication result) { + reqId.set(expectAuditRequestId(threadContext)); + assertThat(new AuthenticationContextSerializer().getAuthentication(threadContext), is(result)); + super.onResponse(result); + } + + @Override + public void onFailure(Exception e) { + reqId.set(expectAuditRequestId(threadContext)); + super.onFailure(e); + } + }; if (fallbackUser == null) { service.authenticate(action, transportRequest, true, future); } else { service.authenticate(action, transportRequest, fallbackUser, future); } - return future.actionGet(); + Authentication authentication = future.actionGet(); + assertThat(expectAuditRequestId(threadContext), is(reqId.get())); + return new Tuple<>(authentication, reqId.get()); } - private String expectAuditRequestId() { + private static String expectAuditRequestId(ThreadContext threadContext) { String reqId = AuditUtil.extractRequestId(threadContext); assertThat(reqId, is(not(emptyOrNullString()))); return reqId;