Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new audit handler for action responses (elastic#63708)
Browse files Browse the repository at this point in the history
This adds a new method to the AuditTrail that intercepts the
responses of transport-level actions. This new method is unlike all
the other existing audit methods because it is called after the
action has been run (so that it has access to the response).
After careful deliberation, the new method is called for the
responses of actions that are intercepted by the
`SecurityActionFilter` only, and not by the transport filter.

In order to facilitate the "linking" of the new audit event with the
other existing events, the audit method receives the requestId
as well as the authentication as arguments (in addition to the
request itself and the response).

This is labeled non-issue because it is only the foundation
upon which later features that actually print out (some) responses
can be built upon.

Related elastic#63221
albertzaharovits committed Dec 17, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d79ae0e commit 3732ea1
Showing 9 changed files with 585 additions and 147 deletions.
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ public void testApply_ActionAllowedInUpgradeMode() {

public void testOrder_UpgradeFilterIsExecutedAfterSecurityFilter() {
MlUpgradeModeActionFilter upgradeModeFilter = new MlUpgradeModeActionFilter(clusterService);
SecurityActionFilter securityFilter = new SecurityActionFilter(null, null, null, mock(ThreadPool.class), null, null);
SecurityActionFilter securityFilter = new SecurityActionFilter(null, null, null, null, mock(ThreadPool.class), null, null);

ActionFilter[] actionFiltersInOrderOfExecution = new ActionFilters(Sets.newHashSet(upgradeModeFilter, securityFilter)).filters();
assertThat(actionFiltersInOrderOfExecution, is(arrayContaining(securityFilter, upgradeModeFilter)));
Original file line number Diff line number Diff line change
@@ -560,7 +560,7 @@ auditTrailService, failureHandler, threadPool, anonymousUser, getAuthorizationEn
securityInterceptor.set(new SecurityServerTransportInterceptor(settings, threadPool, authcService.get(),
authzService, getLicenseState(), getSslService(), securityContext.get(), destructiveOperations, clusterService));

securityActionFilter.set(new SecurityActionFilter(authcService.get(), authzService, getLicenseState(),
securityActionFilter.set(new SecurityActionFilter(authcService.get(), authzService, auditTrailService, getLicenseState(),
threadPool, securityContext.get(), destructiveOperations));

return components;
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.DestructiveOperations;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
@@ -33,11 +34,12 @@
import org.elasticsearch.xpack.core.security.support.Automatons;
import org.elasticsearch.xpack.core.security.user.SystemUser;
import org.elasticsearch.xpack.security.action.SecurityActionMapper;
import org.elasticsearch.xpack.security.audit.AuditTrailService;
import org.elasticsearch.xpack.security.audit.AuditUtil;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService;
import org.elasticsearch.xpack.security.authz.AuthorizationUtils;

import java.io.IOException;
import java.util.function.Predicate;

public class SecurityActionFilter implements ActionFilter {
@@ -48,17 +50,19 @@ public class SecurityActionFilter implements ActionFilter {

private final AuthenticationService authcService;
private final AuthorizationService authzService;
private final AuditTrailService auditTrailService;
private final SecurityActionMapper actionMapper = new SecurityActionMapper();
private final XPackLicenseState licenseState;
private final ThreadContext threadContext;
private final SecurityContext securityContext;
private final DestructiveOperations destructiveOperations;

public SecurityActionFilter(AuthenticationService authcService, AuthorizationService authzService,
XPackLicenseState licenseState, ThreadPool threadPool,
AuditTrailService auditTrailService, XPackLicenseState licenseState, ThreadPool threadPool,
SecurityContext securityContext, DestructiveOperations destructiveOperations) {
this.authcService = authcService;
this.authzService = authzService;
this.auditTrailService = auditTrailService;
this.licenseState = licenseState;
this.threadContext = threadPool.getThreadContext();
this.securityContext = securityContext;
@@ -83,29 +87,19 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
if (licenseState.isSecurityEnabled()) {
final ActionListener<Response> contextPreservingListener =
ContextPreservingActionListener.wrapPreservingContext(listener, threadContext);
ActionListener<Void> authenticatedListener = ActionListener.wrap(
(aVoid) -> chain.proceed(task, action, request, contextPreservingListener), contextPreservingListener::onFailure);
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
try {
if (useSystemUser) {
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> {
try {
applyInternal(action, request, authenticatedListener);
} catch (IOException e) {
listener.onFailure(e);
}
applyInternal(task, chain, action, request, contextPreservingListener);
}, Version.CURRENT);
} else if (AuthorizationUtils.shouldSetUserBasedOnActionOrigin(threadContext)) {
AuthorizationUtils.switchUserBasedOnActionOriginAndExecute(threadContext, securityContext, (original) -> {
try {
applyInternal(action, request, authenticatedListener);
} catch (IOException e) {
listener.onFailure(e);
}
applyInternal(task, chain, action, request, contextPreservingListener);
});
} else {
try (ThreadContext.StoredContext ignore = threadContext.newStoredContext(true)) {
applyInternal(action, request, authenticatedListener);
applyInternal(task, chain, action, request, contextPreservingListener);
}
}
} catch (Exception e) {
@@ -130,13 +124,13 @@ public int order() {
return Integer.MIN_VALUE;
}

private <Request extends ActionRequest> void applyInternal(String action, Request request,
ActionListener<Void> listener) throws IOException {
private <Request extends ActionRequest, Response extends ActionResponse> void applyInternal(Task task,
ActionFilterChain<Request, Response> chain, String action, Request request, ActionListener<Response> listener) {
if (CloseIndexAction.NAME.equals(action) || OpenIndexAction.NAME.equals(action) || DeleteIndexAction.NAME.equals(action)) {
IndicesRequest indicesRequest = (IndicesRequest) request;
try {
destructiveOperations.failDestructive(indicesRequest.indices());
} catch(IllegalArgumentException e) {
} catch (IllegalArgumentException e) {
listener.onFailure(e);
return;
}
@@ -156,7 +150,17 @@ it to the action without an associated user (not via REST or transport - this is
authcService.authenticate(securityAction, request, SystemUser.INSTANCE,
ActionListener.wrap((authc) -> {
if (authc != null) {
authorizeRequest(authc, securityAction, request, listener);
final String requestId = AuditUtil.extractRequestId(threadContext);
assert Strings.hasText(requestId);
authorizeRequest(authc, securityAction, request, ActionListener.delegateFailure(listener,
(ignore, aVoid) -> {
chain.proceed(task, action, request, ActionListener.delegateFailure(listener,
(ignore2, response) -> {
auditTrailService.get().coordinatingActionResponse(requestId, authc, action, request,
response);
listener.onResponse(response);
}));
}));
} else if (licenseState.isSecurityEnabled() == false) {
listener.onResponse(null);
} else {
@@ -166,12 +170,11 @@ it to the action without an associated user (not via REST or transport - this is
}

private <Request extends ActionRequest> void authorizeRequest(Authentication authentication, String securityAction, Request request,
ActionListener<Void> listener) {
ActionListener<Void> listener) {
if (authentication == null) {
listener.onFailure(new IllegalArgumentException("authentication must be non null for authorization"));
} else {
authzService.authorize(authentication, securityAction, request, ActionListener.wrap(ignore -> listener.onResponse(null),
listener::onFailure));
authzService.authorize(authentication, securityAction, request, listener);
}
}
}
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.AuthorizationInfo;
@@ -81,4 +82,9 @@ void runAsDenied(String requestId, Authentication authentication, RestRequest re
void explicitIndexAccessEvent(String requestId, AuditLevel eventType, Authentication authentication, String action, String indices,
String requestName, TransportAddress remoteAddress, AuthorizationInfo authorizationInfo);

// this is the only audit method that is called *after* the action executed, when the response is available
// it is however *only called for coordinating actions*, which are the actions that a client invokes as opposed to
// the actions that a node invokes in order to service a client request
void coordinatingActionResponse(String requestId, Authentication authentication, String action, TransportRequest transportRequest,
TransportResponse transportResponse);
}
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import org.elasticsearch.license.XPackLicenseState.Feature;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.AuthorizationInfo;
@@ -147,6 +148,11 @@ public void runAsDenied(String requestId, Authentication authentication, RestReq
public void explicitIndexAccessEvent(String requestId, AuditLevel eventType, Authentication authentication,
String action, String indices, String requestName, TransportAddress remoteAddress,
AuthorizationInfo authorizationInfo) {}

@Override
public void coordinatingActionResponse(String requestId, Authentication authentication, String action,
TransportRequest transportRequest,
TransportResponse transportResponse) { }
}

private static class CompositeAuditTrail implements AuditTrail {
@@ -254,6 +260,15 @@ public void accessDenied(String requestId, Authentication authentication, String
}
}

@Override
public void coordinatingActionResponse(String requestId, Authentication authentication, String action,
TransportRequest transportRequest,
TransportResponse transportResponse) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.coordinatingActionResponse(requestId, authentication, action, transportRequest, transportResponse);
}
}

@Override
public void tamperedRequest(String requestId, RestRequest request) {
for (AuditTrail auditTrail : auditTrails) {
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.xpack.core.security.action.CreateApiKeyAction;
import org.elasticsearch.xpack.core.security.action.CreateApiKeyRequest;
import org.elasticsearch.xpack.core.security.action.GrantApiKeyAction;
@@ -815,6 +816,13 @@ public void runAsDenied(String requestId, Authentication authentication, RestReq
}
}

@Override
public void coordinatingActionResponse(String requestId, Authentication authentication, String action,
TransportRequest transportRequest,
TransportResponse transportResponse) {
// not implemented yet
}

private LogEntryBuilder securityChangeLogEntryBuilder(String requestId) {
return new LogEntryBuilder(false)
.with(EVENT_TYPE_FIELD_NAME, SECURITY_CHANGE_ORIGIN_FIELD_VALUE)
Original file line number Diff line number Diff line change
@@ -61,11 +61,8 @@
import org.elasticsearch.xpack.core.security.authz.privilege.ClusterPrivilegeResolver;
import org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege;
import org.elasticsearch.xpack.core.security.user.AnonymousUser;
import org.elasticsearch.xpack.core.security.user.AsyncSearchUser;
import org.elasticsearch.xpack.core.security.user.SystemUser;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.core.security.user.XPackSecurityUser;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.security.audit.AuditLevel;
import org.elasticsearch.xpack.security.audit.AuditTrail;
import org.elasticsearch.xpack.security.audit.AuditTrailService;
@@ -95,6 +92,7 @@
import static org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField.INDICES_PERMISSIONS_KEY;
import static org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField.ORIGINATING_ACTION_KEY;
import static org.elasticsearch.xpack.core.security.support.Exceptions.authorizationError;
import static org.elasticsearch.xpack.core.security.user.User.isInternal;
import static org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrail.PRINCIPAL_ROLES_FIELD_NAME;

public class AuthorizationService {
@@ -190,14 +188,15 @@ public void authorize(final Authentication authentication, final String action,
if (auditId == null) {
// We would like to assert that there is an existing request-id, but if this is a system action, then that might not be
// true because the request-id is generated during authentication
if (isInternalUser(authentication.getUser()) != false) {
if (isInternal(authentication.getUser())) {
auditId = AuditUtil.getOrGenerateRequestId(threadContext);
} else {
auditTrailService.get().tamperedRequest(null, authentication, action, originalRequest);
final String message = "Attempt to authorize action [" + action + "] for [" + authentication.getUser().principal()
+ "] without an existing request-id";
assert false : message;
listener.onFailure(new ElasticsearchSecurityException(message));
return;
}
}

@@ -397,7 +396,7 @@ AuthorizationEngine getAuthorizationEngine(final Authentication authentication)
private AuthorizationEngine getAuthorizationEngineForUser(final User user) {
if (rbacEngine != authorizationEngine && licenseState.isSecurityEnabled() &&
licenseState.checkFeature(Feature.SECURITY_AUTHORIZATION_ENGINE)) {
if (ClientReservedRealm.isReserved(user.principal(), settings) || isInternalUser(user)) {
if (ClientReservedRealm.isReserved(user.principal(), settings) || isInternal(user)) {
return rbacEngine;
} else {
return authorizationEngine;
@@ -448,10 +447,6 @@ private TransportRequest maybeUnwrapRequest(Authentication authentication, Trans
return request;
}

private boolean isInternalUser(User user) {
return SystemUser.is(user) || XPackUser.is(user) || XPackSecurityUser.is(user) || AsyncSearchUser.is(user);
}

private void authorizeRunAs(final RequestInfo requestInfo, final AuthorizationInfo authzInfo,
final ActionListener<AuthorizationResult> listener) {
final Authentication authentication = requestInfo.getAuthentication();
Original file line number Diff line number Diff line change
@@ -9,17 +9,18 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.MockIndicesRequest;
import org.elasticsearch.action.admin.indices.close.CloseIndexAction;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexAction;
import org.elasticsearch.action.admin.indices.open.OpenIndexAction;
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.DestructiveOperations;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -37,6 +38,9 @@
import org.elasticsearch.xpack.core.security.authz.accesscontrol.IndicesAccessControl;
import org.elasticsearch.xpack.core.security.user.SystemUser;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.security.audit.AuditTrail;
import org.elasticsearch.xpack.security.audit.AuditTrailService;
import org.elasticsearch.xpack.security.audit.AuditUtil;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService;
import org.junit.Before;
@@ -45,6 +49,7 @@

import static org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField.INDICES_PERMISSIONS_KEY;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
@@ -60,6 +65,9 @@
public class SecurityActionFilterTests extends ESTestCase {
private AuthenticationService authcService;
private AuthorizationService authzService;
private AuditTrailService auditTrailService;
private AuditTrail auditTrail;
private ActionFilterChain chain;
private XPackLicenseState licenseState;
private SecurityActionFilter filter;
private ThreadContext threadContext;
@@ -69,6 +77,10 @@ public class SecurityActionFilterTests extends ESTestCase {
public void init() throws Exception {
authcService = mock(AuthenticationService.class);
authzService = mock(AuthorizationService.class);
auditTrailService = mock(AuditTrailService.class);
auditTrail = mock(AuditTrail.class);
when(auditTrailService.get()).thenReturn(auditTrail);
chain = mock(ActionFilterChain.class);
licenseState = mock(XPackLicenseState.class);
when(licenseState.isSecurityEnabled()).thenReturn(true);
when(licenseState.checkFeature(Feature.SECURITY_STATS_AND_HEALTH)).thenReturn(true);
@@ -88,32 +100,37 @@ public void init() throws Exception {
when(state.nodes()).thenReturn(nodes);

SecurityContext securityContext = new SecurityContext(settings, threadContext);
filter = new SecurityActionFilter(authcService, authzService, licenseState, threadPool, securityContext, destructiveOperations);
filter = new SecurityActionFilter(authcService, authzService, auditTrailService, licenseState, threadPool,
securityContext, destructiveOperations);
}

public void testApply() throws Exception {
ActionRequest request = mock(ActionRequest.class);
ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
User user = new User("username", "r1", "r2");
Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null);
mockAuthentication(request, authentication);
String requestId = UUIDs.randomBase64UUID();
mockAuthentication(request, authentication, requestId);
mockAuthorize();
ActionResponse actionResponse = mock(ActionResponse.class);
mockChain(task, "_action", request, actionResponse);
filter.apply(task, "_action", request, listener, chain);
verify(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class));
verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class));
verify(auditTrail).coordinatingActionResponse(eq(requestId), eq(authentication), eq("_action"), eq(request), eq(actionResponse));
}

public void testApplyRestoresThreadContext() throws Exception {
ActionRequest request = mock(ActionRequest.class);
ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
User user = new User("username", "r1", "r2");
Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null);
mockAuthentication(request, authentication);
String requestId = UUIDs.randomBase64UUID();
mockAuthentication(request, authentication, requestId);
mockAuthorize();
ActionResponse actionResponse = mock(ActionResponse.class);
mockChain(task, "_action", request, actionResponse);
assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
assertNull(threadContext.getTransient(INDICES_PERMISSIONS_KEY));

@@ -122,7 +139,7 @@ public void testApplyRestoresThreadContext() throws Exception {
assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
assertNull(threadContext.getTransient(INDICES_PERMISSIONS_KEY));
verify(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class));
verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ContextPreservingActionListener.class));
verify(auditTrail).coordinatingActionResponse(eq(requestId), eq(authentication), eq("_action"), eq(request), eq(actionResponse));
}

public void testApplyAsSystemUser() throws Exception {
@@ -132,28 +149,34 @@ public void testApplyAsSystemUser() throws Exception {
Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null);
SetOnce<Authentication> authenticationSetOnce = new SetOnce<>();
SetOnce<IndicesAccessControl> accessControlSetOnce = new SetOnce<>();
SetOnce<String> requestIdOnActionHandler = new SetOnce<>();
ActionFilterChain chain = (task, action, request1, listener1) -> {
authenticationSetOnce.set(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
accessControlSetOnce.set(threadContext.getTransient(INDICES_PERMISSIONS_KEY));
requestIdOnActionHandler.set(AuditUtil.extractRequestId(threadContext));
};
Task task = mock(Task.class);
final boolean hasExistingAuthentication = randomBoolean();
final boolean hasExistingAccessControl = randomBoolean();
final String action = "internal:foo";
if (hasExistingAuthentication) {
AuditUtil.generateRequestId(threadContext);
threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication);
threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, "foo");
threadContext.putTransient(AuthorizationServiceField.ORIGINATING_ACTION_KEY, "indices:foo");
if (hasExistingAccessControl) {
threadContext.putTransient(INDICES_PERMISSIONS_KEY, IndicesAccessControl.ALLOW_NO_INDICES);
}
} else {
assertNull(AuditUtil.extractRequestId(threadContext));
assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
}
SetOnce<String> requestIdFromAuthn = new SetOnce<>();
doAnswer(i -> {
final Object[] args = i.getArguments();
assertThat(args, arrayWithSize(4));
ActionListener callback = (ActionListener) args[args.length - 1];
requestIdFromAuthn.set(AuditUtil.generateRequestId(threadContext));
callback.onResponse(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
return Void.TYPE;
}).when(authcService).authenticate(eq(action), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class));
@@ -174,6 +197,7 @@ public void testApplyAsSystemUser() throws Exception {
assertNotEquals(authentication, authenticationSetOnce.get());
assertEquals(SystemUser.INSTANCE, authenticationSetOnce.get().getUser());
assertThat(accessControlSetOnce.get(), sameInstance(authzAccessControl));
assertThat(requestIdOnActionHandler.get(), is(requestIdFromAuthn.get()));
}

public void testApplyDestructiveOperations() throws Exception {
@@ -182,14 +206,19 @@ public void testApplyDestructiveOperations() throws Exception {
randomFrom("*", "_all", "test*"));
String action = randomFrom(CloseIndexAction.NAME, OpenIndexAction.NAME, DeleteIndexAction.NAME);
ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
User user = new User("username", "r1", "r2");
Authentication authentication = new Authentication(user, new RealmRef("test", "test", "foo"), null);
ActionResponse actionResponse = mock(ActionResponse.class);
mockChain(task, action, request, actionResponse);
SetOnce<String> requestIdFromAuthn = new SetOnce<>();
doAnswer(i -> {
final Object[] args = i.getArguments();
assertThat(args, arrayWithSize(4));
ActionListener callback = (ActionListener) args[args.length - 1];
requestIdFromAuthn.set(AuditUtil.generateRequestId(threadContext));
threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication);
threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, authentication.encode());
callback.onResponse(authentication);
return Void.TYPE;
}).when(authcService).authenticate(eq(action), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class));
@@ -202,10 +231,12 @@ public void testApplyDestructiveOperations() throws Exception {
filter.apply(task, action, request, listener, chain);
if (failDestructiveOperations) {
verify(listener).onFailure(isA(IllegalArgumentException.class));
verifyNoMoreInteractions(authzService, chain);
verifyNoMoreInteractions(authzService, chain, auditTrailService, auditTrail);
} else {
verify(authzService).authorize(eq(authentication), eq(action), eq(request), any(ActionListener.class));
verify(chain).proceed(eq(task), eq(action), eq(request), isA(ContextPreservingActionListener.class));
verify(chain).proceed(eq(task), eq(action), eq(request), any(ActionListener.class));
verify(auditTrail).coordinatingActionResponse(eq(requestIdFromAuthn.get()), eq(authentication), eq(action), eq(request),
eq(actionResponse));
}
}

@@ -221,10 +252,21 @@ public void testActionProcessException() throws Exception {
final Object[] args = i.getArguments();
assertThat(args, arrayWithSize(4));
ActionListener callback = (ActionListener) args[args.length - 1];
assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
AuditUtil.generateRequestId(threadContext);
callback.onResponse(authentication);
return Void.TYPE;
}).when(authcService).authenticate(eq("_action"), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class));
doThrow(exception).when(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class));
if (randomBoolean()) {
doThrow(exception).when(authzService).authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class));
} else {
doAnswer((i) -> {
ActionListener<Void> callback = (ActionListener<Void>) i.getArguments()[3];
callback.onFailure(exception);
return Void.TYPE;
}).when(authzService)
.authorize(eq(authentication), eq("_action"), eq(request), any(ActionListener.class));
}
filter.apply(task, "_action", request, listener, chain);
verify(listener).onFailure(exception);
verifyNoMoreInteractions(chain);
@@ -242,13 +284,15 @@ public void testApplyUnlicensed() throws Exception {
verify(chain).proceed(eq(task), eq("_action"), eq(request), eq(listener));
}

private void mockAuthentication(ActionRequest request, Authentication authentication) {
private void mockAuthentication(ActionRequest request, Authentication authentication, String requestId) {
doAnswer(i -> {
final Object[] args = i.getArguments();
assertThat(args, arrayWithSize(4));
ActionListener callback = (ActionListener) args[args.length - 1];
assertNull(threadContext.getTransient(AuthenticationField.AUTHENTICATION_KEY));
threadContext.putTransient(AuthenticationField.AUTHENTICATION_KEY, authentication);
threadContext.putHeader(AuthenticationField.AUTHENTICATION_KEY, authentication.encode());
threadContext.putHeader("_xpack_audit_request_id", requestId);
callback.onResponse(authentication);
return Void.TYPE;
}).when(authcService).authenticate(eq("_action"), eq(request), eq(SystemUser.INSTANCE), any(ActionListener.class));
@@ -271,4 +315,13 @@ private void mockAuthorize(IndicesAccessControl indicesAccessControl) {
.authorize(any(Authentication.class), any(String.class), any(TransportRequest.class), any(ActionListener.class));
}

private void mockChain(Task task, String action, ActionRequest request, ActionResponse actionResponse) {
doAnswer(i -> {
final Object[] args = i.getArguments();
assertThat(args, arrayWithSize(4));
ActionListener callback = (ActionListener) args[args.length - 1];
callback.onResponse(actionResponse);
return Void.TYPE;
}).when(chain).proceed(eq(task), eq(action), eq(request), any(ActionListener.class));
}
}

Large diffs are not rendered by default.

0 comments on commit 3732ea1

Please sign in to comment.