Skip to content

Commit

Permalink
Make Authentication/Authorization Stacks Shallower/Simpler (elastic#7…
Browse files Browse the repository at this point in the history
…5662)

Same as elastic#75252 pretty much just continuing to make this logic a little
simpler for easier profiling and (very) maybe performance through
saving some allocations/indirection.
  • Loading branch information
original-brownbear authored and ywangd committed Jul 30, 2021
1 parent 1776cec commit 10d361e
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ private ClusterPermission(final Set<ClusterPrivilege> clusterPrivileges,
* @return {@code true} if the access is allowed else returns {@code false}
*/
public boolean check(final String action, final TransportRequest request, final Authentication authentication) {
return checks.stream().anyMatch(permission -> permission.check(action, request, authentication));
for (PermissionCheck permission : checks) {
if (permission.check(action, request, authentication)) {
return true;
}
}
return false;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authz.privilege.HealthAndStatsPrivilege;
import org.elasticsearch.xpack.core.security.support.Automatons;
import org.elasticsearch.xpack.core.security.user.SystemUser;
Expand Down Expand Up @@ -156,7 +155,7 @@ it to the action without an associated user (not via REST or transport - this is
if (authc != null) {
final String requestId = AuditUtil.extractRequestId(threadContext);
assert Strings.hasText(requestId);
authorizeRequest(authc, securityAction, request, listener.delegateFailure(
authzService.authorize(authc, securityAction, request, listener.delegateFailure(
(ll, aVoid) -> chain.proceed(task, action, request, ll.delegateFailure((l, response) -> {
auditTrailService.get().coordinatingActionResponse(requestId, authc, action, request,
response);
Expand All @@ -169,13 +168,4 @@ it to the action without an associated user (not via REST or transport - this is
}
}, listener::onFailure));
}

private <Request extends ActionRequest> void authorizeRequest(Authentication authentication, String securityAction, Request request,
ActionListener<Void> listener) {
if (authentication == null) {
listener.onFailure(new IllegalArgumentException("authentication must be non null for authorization"));
} else {
authzService.authorize(authentication, securityAction, request, listener);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public void authenticate(String action, TransportRequest transportRequest, boole
*/
public void authenticate(String action, TransportRequest transportRequest,
AuthenticationToken token, ActionListener<Authentication> listener) {
new Authenticator(action, transportRequest, shouldFallbackToAnonymous(true), listener).authenticateToken(token);
new Authenticator(action, transportRequest, shouldFallbackToAnonymous(true), listener).consumeToken(token);
}

public void expire(String principal) {
Expand Down Expand Up @@ -314,9 +314,9 @@ private Authenticator(AuditableRequest auditableRequest, User fallbackUser, bool
* these operations are:
*
* <ol>
* <li>look for existing authentication {@link #lookForExistingAuthentication(Consumer)}</li>
* <li>look for existing authentication {@link #lookForExistingAuthentication()}</li>
* <li>look for a user token</li>
* <li>token extraction {@link #extractToken(Consumer)}</li>
* <li>token extraction {@link #extractToken()}</li>
* <li>token authentication {@link #consumeToken(AuthenticationToken)}</li>
* <li>user lookup for run as if necessary {@link #consumeUser(User, Map)} and
* {@link #lookupRunAsUser(User, String, Consumer)}</li>
Expand All @@ -330,14 +330,19 @@ private void authenticateAsync() {
logger.debug("No realms available, failing authentication");
listener.onResponse(null);
} else {
lookForExistingAuthentication((authentication) -> {
if (authentication != null) {
logger.trace("Found existing authentication [{}] in request [{}]", authentication, request);
listener.onResponse(authentication);
} else {
checkForBearerToken();
}
});
final Authentication authentication;
try {
authentication = lookForExistingAuthentication();
} catch (Exception e) {
listener.onFailure(e);
return;
}
if (authentication != null) {
logger.trace("Found existing authentication [{}] in request [{}]", authentication, request);
listener.onResponse(authentication);
} else {
checkForBearerToken();
}
}
}

Expand Down Expand Up @@ -393,7 +398,14 @@ private void checkForApiKey() {
logger.warn("Authentication using apikey failed - {}", authResult.getMessage());
}
}
extractToken(this::consumeToken);
final AuthenticationToken token;
try {
token = extractToken();
} catch (Exception e) {
listener.onFailure(e);
return;
}
consumeToken(token);
}
},
e -> listener.onFailure(request.exceptionProcessingRequest(e, null))));
Expand All @@ -404,25 +416,20 @@ private void checkForApiKey() {
* consumer is called if no exception was thrown while trying to read the authentication and may be called with a {@code null}
* value
*/
private void lookForExistingAuthentication(Consumer<Authentication> authenticationConsumer) {
Runnable action;
private Authentication lookForExistingAuthentication() {
final Authentication authentication;
try {
final Authentication authentication = authenticationSerializer.readFromContext(threadContext);
if (authentication != null && request instanceof AuditableRestRequest) {
action = () -> listener.onFailure(request.tamperedRequest());
} else {
action = () -> authenticationConsumer.accept(authentication);
}
authentication = authenticationSerializer.readFromContext(threadContext);
} catch (Exception e) {
logger.error((Supplier<?>)
() -> new ParameterizedMessage("caught exception while trying to read authentication from request [{}]", request),
e);
action = () -> listener.onFailure(request.tamperedRequest());
throw request.tamperedRequest();
}

// While we could place this call in the try block, the issue is that we catch all exceptions and could catch exceptions that
// have nothing to do with a tampered request.
action.run();
if (authentication != null && request instanceof AuditableRestRequest) {
throw request.tamperedRequest();
}
return authentication;
}

/**
Expand All @@ -431,28 +438,26 @@ private void lookForExistingAuthentication(Consumer<Authentication> authenticati
* no exception was caught during the extraction process and may be called with a {@code null} token.
*/
// pkg-private accessor testing token extraction with a consumer
void extractToken(Consumer<AuthenticationToken> consumer) {
Runnable action = () -> consumer.accept(null);
AuthenticationToken extractToken() {
try {
if (authenticationToken != null) {
action = () -> consumer.accept(authenticationToken);
return authenticationToken;
} else {
for (Realm realm : defaultOrderedRealmList) {
final AuthenticationToken token = realm.token(threadContext);
if (token != null) {
logger.trace("Found authentication credentials [{}] for principal [{}] in request [{}]",
token.getClass().getName(), token.principal(), request);
action = () -> consumer.accept(token);
break;
return token;
}
}
}
} catch (Exception e) {
logger.warn("An exception occurred while attempting to find authentication credentials", e);
action = () -> listener.onFailure(request.exceptionProcessingRequest(e, null));
throw request.exceptionProcessingRequest(e, null);
}

action.run();
return null;
}

/**
Expand Down Expand Up @@ -606,19 +611,12 @@ void handleNullToken() {
authentication = null;
}

Runnable action;
if (authentication != null) {
action = () -> writeAuthToContext(authentication);
writeAuthToContext(authentication);
} else {
action = () -> {
logger.debug("No valid credentials found in request [{}], rejecting", request);
listener.onFailure(request.anonymousAccessDenied());
};
logger.debug("No valid credentials found in request [{}], rejecting", request);
listener.onFailure(request.anonymousAccessDenied());
}

// we assign the listener call to an action to avoid calling the listener within a try block and auditing the wrong thing when
// an exception bubbles up even after successful authentication
action.run();
}

/**
Expand Down Expand Up @@ -712,32 +710,23 @@ void finishAuthentication(User finalUser) {
* successful
*/
void writeAuthToContext(Authentication authentication) {
Runnable action = () -> {
logger.trace("Established authentication [{}] for request [{}]", authentication, request);
listener.onResponse(authentication);
};
try {
authenticationSerializer.writeToContext(authentication, threadContext);
request.authenticationSuccess(authentication);
// Header for operator privileges will only be written if authentication actually happens,
// i.e. not read from either header or transient header
operatorPrivilegesService.maybeMarkOperatorUser(authentication, threadContext);
} catch (Exception e) {
action = () -> {
logger.debug(
logger.debug(
new ParameterizedMessage("Failed to store authentication [{}] for request [{}]", authentication, request), e);
listener.onFailure(request.exceptionProcessingRequest(e, authenticationToken));
};
listener.onFailure(request.exceptionProcessingRequest(e, authenticationToken));
return;
}

// we assign the listener call to an action to avoid calling the listener within a try block and auditing the wrong thing
// when an exception bubbles up even after successful authentication
action.run();
logger.trace("Established authentication [{}] for request [{}]", authentication, request);
listener.onResponse(authentication);
}

private void authenticateToken(AuthenticationToken token) {
this.consumeToken(token);
}
}

abstract static class AuditableRequest {
Expand Down
Loading

0 comments on commit 10d361e

Please sign in to comment.