Skip to content

Commit

Permalink
Fix concurrent refresh of tokens (elastic#55114) (elastic#55733)
Browse files Browse the repository at this point in the history
Our handling for concurrent refresh of access tokens suffered from
a race condition where:

1. Thread A has just finished updating the existing token
document, but hasn't yet stored the new tokens in a new
document.
2. Thread B attempts to refresh the same token and, since the
original token document is marked as refreshed, it decrypts and
gets the new access token and refresh token and returns them to
the caller of the API.
3. The caller attempts to use the newly refreshed access token
immediately and gets an authentication error since thread A still
hasn't finished writing the document.

This commit changes the behavior so that Thread B first tries
to do a Get request (with exponential backoff) for the token
document where it expects the access token it decrypted to be
stored, and does not respond until it can verify that it reads it
in the tokens index. That ensures that we only ever return tokens
in a response if they are already valid and can be used immediately.

It also adjusts TokenAuthIntegTests
to test authenticating with the tokens each thread receives,
which would fail without the fix.

Resolves: elastic#54289
  • Loading branch information
jkakavas committed Apr 27, 2020
1 parent f2f1296 commit 3a19671
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ private void invalidateAllTokens(Collection<UserToken> userTokens, ActionListene
}

/**
* Invalidates access and/or refresh tokens associated to a user token (coexisting in the same token document)
* Invalidates access and/or refresh tokens associated to a user token (coexisting in the same token document)
*/
private void indexInvalidation(Collection<UserToken> userTokens, Iterator<TimeValue> backoff, String srcPrefix,
@Nullable TokensInvalidationResult previousResult, ActionListener<TokensInvalidationResult> listener) {
Expand Down Expand Up @@ -969,10 +969,11 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
return;
}
final RefreshTokenStatus refreshTokenStatus = checkRefreshResult.v1();
final SecurityIndexManager refreshedTokenIndex = getTokensIndexForVersion(refreshTokenStatus.getVersion());
if (refreshTokenStatus.isRefreshed()) {
logger.debug("Token document [{}] was recently refreshed, when a new token document was generated. Reusing that result.",
tokenDocId);
decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, listener);
decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, refreshedTokenIndex, listener);
} else {
final String newAccessTokenString = UUIDs.randomBase64UUID();
final String newRefreshTokenString = UUIDs.randomBase64UUID();
Expand All @@ -996,7 +997,6 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
}
assert seqNo != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number";
assert primaryTerm != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term";
final SecurityIndexManager refreshedTokenIndex = getTokensIndexForVersion(refreshTokenStatus.getVersion());
final UpdateRequestBuilder updateRequest = client
.prepareUpdate(refreshedTokenIndex.aliasName(), SINGLE_MAPPING_NAME, tokenDocId)
.setDoc("refresh_token", updateMap)
Expand Down Expand Up @@ -1033,7 +1033,7 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
if (cause instanceof VersionConflictEngineException) {
// The document has been updated by another thread, get it again.
logger.debug("version conflict while updating document [{}], attempting to get it again", tokenDocId);
getTokenDocAsync(tokenDocId, refreshedTokenIndex, new ActionListener<GetResponse>() {
getTokenDocAsync(tokenDocId, refreshedTokenIndex, true, new ActionListener<GetResponse>() {
@Override
public void onResponse(GetResponse response) {
if (response.isExists()) {
Expand All @@ -1051,7 +1051,7 @@ public void onFailure(Exception e) {
if (backoff.hasNext()) {
logger.info("could not get token document [{}] for refresh, retrying", tokenDocId);
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
.preserveContext(() -> getTokenDocAsync(tokenDocId, refreshedTokenIndex, this));
.preserveContext(() -> getTokenDocAsync(tokenDocId, refreshedTokenIndex, true, this));
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
} else {
logger.warn("could not get token document [{}] for refresh after all retries", tokenDocId);
Expand Down Expand Up @@ -1081,17 +1081,20 @@ public void onFailure(Exception e) {
}

/**
* Decrypts the values of the superseding access token and the refresh token, using a key derived from the superseded refresh token. It
* Decrypts the values of the superseding access token and the refresh token, using a key derived from the superseded refresh token.
* It verifies that the token document for the access token it decrypted exists first, before calling the listener. It
* encodes the version and serializes the tokens before calling the listener, in the same manner as {@link #createOAuth2Tokens } does.
*
* @param refreshToken The refresh token that the user sent in the request, used to derive the decryption key
* @param refreshToken The refresh token that the user sent in the request, used to derive the decryption key
* @param refreshTokenStatus The {@link RefreshTokenStatus} containing information about the superseding tokens as retrieved from the
* index
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* index
* @param tokensIndex the manager for the index where the tokens are stored
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
*/
void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus refreshTokenStatus,
void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus refreshTokenStatus, SecurityIndexManager tokensIndex,
ActionListener<Tuple<String, String>> listener) {

final byte[] iv = Base64.getDecoder().decode(refreshTokenStatus.getIv());
final byte[] salt = Base64.getDecoder().decode(refreshTokenStatus.getSalt());
final byte[] encryptedSupersedingTokens = Base64.getDecoder().decode(refreshTokenStatus.getSupersedingTokens());
Expand All @@ -1103,10 +1106,51 @@ void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus r
logger.warn("Decrypted tokens string is not correctly formatted");
listener.onFailure(invalidGrantException("could not refresh the requested token"));
} else {
listener.onResponse(new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(), decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1])));
// We expect this to protect against race conditions that manifest within few ms
final Iterator<TimeValue> backoff = BackoffPolicy.exponentialBackoff(TimeValue.timeValueMillis(10), 8).iterator();
final String tokenDocId = getTokenDocumentId(hashTokenString(decryptedTokens[0]));
final Consumer<Exception> onFailure = ex ->
listener.onFailure(traceLog("decrypt and get superseding token", tokenDocId, ex));
final Consumer<ActionListener<GetResponse>> maybeRetryGet = actionListener -> {
if (backoff.hasNext()) {
logger.info("could not get token document [{}] that should have been created, retrying", tokenDocId);
client.threadPool().schedule(
() -> getTokenDocAsync(tokenDocId, tokensIndex, false, actionListener),
backoff.next(), GENERIC);
} else {
logger.warn("could not get token document [{}] that should have been created after all retries",
tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
};
getTokenDocAsync(tokenDocId, tokensIndex, false, new ActionListener<GetResponse>() {
@Override
public void onResponse(GetResponse response) {
if (response.isExists()) {
try {
listener.onResponse(
new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(), decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1])));
} catch (GeneralSecurityException | IOException e) {
logger.warn("Could not format stored superseding token values", e);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else {
maybeRetryGet.accept(this);
}
}

@Override
public void onFailure(Exception e) {
if (isShardNotAvailableException(e)) {
maybeRetryGet.accept(this);
} else {
onFailure.accept(e);
}
}
});
}
} catch (GeneralSecurityException | IOException e) {
} catch (GeneralSecurityException e) {
logger.warn("Could not get stored superseding token values", e);
listener.onFailure(invalidGrantException("could not refresh the requested token"));
}
Expand All @@ -1124,11 +1168,14 @@ String encryptSupersedingTokens(String supersedingAccessToken, String supersedin
return Base64.getEncoder().encodeToString(cipher.doFinal(supersedingTokens.getBytes(StandardCharsets.UTF_8)));
}

private void getTokenDocAsync(String tokenDocId, SecurityIndexManager tokensIndex, ActionListener<GetResponse> listener) {
final GetRequest getRequest = client.prepareGet(tokensIndex.aliasName(), SINGLE_MAPPING_NAME, tokenDocId).request();
private void getTokenDocAsync(String tokenDocId, SecurityIndexManager tokensIndex,
boolean fetchSource, ActionListener<GetResponse> listener) {
final GetRequest getRequest = client.prepareGet(tokensIndex.aliasName(), SINGLE_MAPPING_NAME, tokenDocId)
.setFetchSource(fetchSource)
.request();
tokensIndex.checkIndexVersionThenExecute(
ex -> listener.onFailure(traceLog("prepare tokens index [" + tokensIndex.aliasName() + "]", tokenDocId, ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get));
ex -> listener.onFailure(traceLog("prepare tokens index [" + tokensIndex.aliasName() + "]", tokenDocId, ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get));
}

Version getTokenVersionCompatibility() {
Expand Down Expand Up @@ -1222,7 +1269,7 @@ private static Map<String, Object> getUserTokenSourceMap(Map<String, Object> sou
/**
* Checks if the token can be refreshed once more. If a token has previously been refreshed, it can only be refreshed again inside a
* short span of time (30 s).
*
*
* @return An {@code Optional} containing the exception in case this refresh token cannot be reused, or an empty <b>Optional</b> if
* refreshing is allowed.
*/
Expand Down Expand Up @@ -1370,7 +1417,7 @@ private void sourceIndicesWithTokensAndRun(ActionListener<List<String>> listener
}
final SecurityIndexManager frozenMainIndex = securityMainIndex.freeze();
if (frozenMainIndex.indexExists()) {
// main security index _might_ contain tokens if the tokens index has been created recently
// main security index _might_ contain tokens if the tokens index has been created recently
if (false == frozenTokensIndex.indexExists() || frozenTokensIndex.getCreationTime()
.isAfter(clock.instant().minus(ExpiredTokenRemover.MAXIMUM_TOKEN_LIFETIME_HOURS, ChronoUnit.HOURS))) {
if (false == frozenMainIndex.isAvailable()) {
Expand Down
Loading

0 comments on commit 3a19671

Please sign in to comment.