diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/Hasher.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/Hasher.java index 492622b2c519c..28f263748135f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/Hasher.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/Hasher.java @@ -351,6 +351,24 @@ public boolean verify(SecureString text, char[] hash) { return CharArrays.constantTimeEquals(computedHash, new String(saltAndHash, 12, saltAndHash.length - 12)); } }, + /* + * Unsalted SHA-256 , not suited for password storage. + */ + SHA256() { + @Override + public char[] hash(SecureString text) { + MessageDigest md = MessageDigests.sha256(); + md.update(CharArrays.toUtf8Bytes(text.getChars())); + return Base64.getEncoder().encodeToString(md.digest()).toCharArray(); + } + + @Override + public boolean verify(SecureString text, char[] hash) { + MessageDigest md = MessageDigests.sha256(); + md.update(CharArrays.toUtf8Bytes(text.getChars())); + return CharArrays.constantTimeEquals(Base64.getEncoder().encodeToString(md.digest()).toCharArray(), hash); + } + }, NOOP() { @Override diff --git a/x-pack/plugin/core/src/main/resources/security-index-template-7.json b/x-pack/plugin/core/src/main/resources/security-index-template-7.json index ebf6d073cd8a6..dae6462b7a6f0 100644 --- a/x-pack/plugin/core/src/main/resources/security-index-template-7.json +++ b/x-pack/plugin/core/src/main/resources/security-index-template-7.json @@ -213,8 +213,19 @@ "type": "date", "format": "epoch_millis" }, - "superseded_by": { - "type": "keyword" + "superseding": { + "type": "object", + "properties": { + "encrypted_tokens": { + "type": "binary" + }, + "encryption_iv": { + "type": "binary" + }, + "encryption_salt": { + "type": "binary" + } + } }, "invalidated" : { "type" : "boolean" diff --git a/x-pack/plugin/core/src/main/resources/security-tokens-index-template-7.json b/x-pack/plugin/core/src/main/resources/security-tokens-index-template-7.json index e7450d0be9c28..312d9ff9e3f58 100644 --- a/x-pack/plugin/core/src/main/resources/security-tokens-index-template-7.json +++ b/x-pack/plugin/core/src/main/resources/security-tokens-index-template-7.json @@ -35,8 +35,19 @@ "type": "date", "format": "epoch_millis" }, - "superseded_by": { - "type": "keyword" + "superseding": { + "type": "object", + "properties": { + "encrypted_tokens": { + "type": "binary" + }, + "encryption_iv": { + "type": "binary" + }, + "encryption_salt": { + "type": "binary" + } + } }, "invalidated" : { "type" : "boolean" diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectAuthenticateAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectAuthenticateAction.java index 1b4aff064a0c3..4bab16cf92115 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectAuthenticateAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectAuthenticateAction.java @@ -7,6 +7,8 @@ import com.nimbusds.oauth2.sdk.id.State; import com.nimbusds.openid.connect.sdk.Nonce; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; @@ -36,6 +38,7 @@ public class TransportOpenIdConnectAuthenticateAction private final ThreadPool threadPool; private final AuthenticationService authenticationService; private final TokenService tokenService; + private static final Logger logger = LogManager.getLogger(TransportOpenIdConnectAuthenticateAction.class); @Inject public TransportOpenIdConnectAuthenticateAction(ThreadPool threadPool, TransportService transportService, @@ -67,9 +70,8 @@ protected void doExecute(Task task, OpenIdConnectAuthenticateRequest request, .get(OpenIdConnectRealm.CONTEXT_TOKEN_DATA); tokenService.createOAuth2Tokens(authentication, originatingAuthentication, tokenMetadata, true, ActionListener.wrap(tuple -> { - final String tokenString = tokenService.getAccessTokenAsString(tuple.v1()); final TimeValue expiresIn = tokenService.getExpirationDelay(); - listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication.getUser().principal(), tokenString, + listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication.getUser().principal(), tuple.v1(), tuple.v2(), expiresIn)); }, listener::onFailure)); }, e -> { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java index 6b61742eed262..96eec7e8fd6c7 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java @@ -63,10 +63,9 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe final Map tokenMeta = (Map) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA); tokenService.createOAuth2Tokens(authentication, originatingAuthentication, tokenMeta, true, ActionListener.wrap(tuple -> { - final String tokenString = tokenService.getAccessTokenAsString(tuple.v1()); final TimeValue expiresIn = tokenService.getExpirationDelay(); listener.onResponse( - new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn)); + new SamlAuthenticateResponse(authentication.getUser().principal(), tuple.v1(), tuple.v2(), expiresIn)); }, listener::onFailure)); }, e -> { logger.debug(() -> new ParameterizedMessage("SamlToken [{}] could not be authenticated", saml), e); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java index 4b648d5ed4bc0..65456ccd2af51 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java @@ -88,9 +88,8 @@ private void createToken(CreateTokenRequest request, Authentication authenticati boolean includeRefreshToken, ActionListener listener) { tokenService.createOAuth2Tokens(authentication, originatingAuth, Collections.emptyMap(), includeRefreshToken, ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); - final CreateTokenResponse response = new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope, + final CreateTokenResponse response = new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope, tuple.v2()); listener.onResponse(response); }, listener::onFailure)); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java index 71aeb64bc4276..5c161d889cfb1 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java @@ -31,11 +31,9 @@ public TransportRefreshTokenAction(TransportService transportService, ActionFilt @Override protected void doExecute(Task task, CreateTokenRequest request, ActionListener listener) { tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); - final CreateTokenResponse response = - new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope, tuple.v2()); + new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope, tuple.v2()); listener.onResponse(response); }, listener::onFailure)); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index 6f96c9bf7dd88..ec5086201c68e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -86,6 +86,7 @@ import org.elasticsearch.xpack.core.security.authc.Authentication.AuthenticationType; import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp; import org.elasticsearch.xpack.core.security.authc.TokenMetaData; +import org.elasticsearch.xpack.core.security.authc.support.Hasher; import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.security.support.SecurityIndexManager; @@ -157,11 +158,12 @@ public final class TokenService { * Cheat Sheet and the * NIST Digital Identity Guidelines */ - private static final int ITERATIONS = 100000; + static final int TOKEN_SERVICE_KEY_ITERATIONS = 100000; + static final int TOKENS_ENCRYPTION_KEY_ITERATIONS = 1024; private static final String KDF_ALGORITHM = "PBKDF2withHMACSHA512"; - private static final int SALT_BYTES = 32; + static final int SALT_BYTES = 32; private static final int KEY_BYTES = 64; - private static final int IV_BYTES = 12; + static final int IV_BYTES = 12; private static final int VERSION_BYTES = 4; private static final String ENCRYPTION_CIPHER = "AES/GCM/NoPadding"; private static final String EXPIRED_TOKEN_WWW_AUTH_VALUE = "Bearer realm=\"" + XPackField.SECURITY + @@ -179,14 +181,18 @@ public final class TokenService { TimeValue.MINUS_ONE, Property.NodeScope); static final String TOKEN_DOC_TYPE = "token"; + private static final int HASHED_TOKEN_LENGTH = 44; + // UUIDs are 16 bytes encoded base64 without padding, therefore the length is (16 / 3) * 4 + ((16 % 3) * 8 + 5) / 6 chars + private static final int TOKEN_LENGTH = 22; private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_"; - static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; + static final int LEGACY_MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; + static final int MINIMUM_BYTES = VERSION_BYTES + TOKEN_LENGTH + 1; + static final int LEGACY_MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * LEGACY_MINIMUM_BYTES) / 3)).intValue(); static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); + static final Version VERSION_HASHED_TOKENS = Version.V_8_0_0; static final Version VERSION_TOKENS_INDEX_INTRODUCED = Version.V_7_2_0; static final Version VERSION_ACCESS_TOKENS_AS_UUIDS = Version.V_7_2_0; static final Version VERSION_MULTIPLE_CONCURRENT_REFRESHES = Version.V_7_2_0; - // UUIDs are 16 bytes encoded base64 without padding, therefore the length is (16 / 3) * 4 + ((16 % 3) * 8 + 5) / 6 chars - private static final int TOKEN_ID_LENGTH = 22; private static final Logger logger = LogManager.getLogger(TokenService.class); private final SecureRandom secureRandom = new SecureRandom(); @@ -235,31 +241,71 @@ public TokenService(Settings settings, Clock clock, Client client, XPackLicenseS } /** - * Creates an access token and optionally a refresh token as well, based on the provided authentication and metadata with an - * auto-generated token document id. The created tokens are stored in the security index. + * Creates an access token and optionally a refresh token as well, based on the provided authentication and metadata with + * auto-generated values. The created tokens are stored in the security index for versions up to + * {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a specific security tokens index for later versions. */ - public void createOAuth2Tokens(Authentication authentication, Authentication originatingClientAuth, - Map metadata, boolean includeRefreshToken, - ActionListener> listener) { + public void createOAuth2Tokens(Authentication authentication, Authentication originatingClientAuth, Map metadata, + boolean includeRefreshToken, ActionListener> listener) { // the created token is compatible with the oldest node version in the cluster final Version tokenVersion = getTokenVersionCompatibility(); // tokens moved to a separate index in newer versions final SecurityIndexManager tokensIndex = getTokensIndexForVersion(tokenVersion); // the id of the created tokens ought be unguessable - final String userTokenId = UUIDs.randomBase64UUID(); - createOAuth2Tokens(userTokenId, tokenVersion, tokensIndex, authentication, originatingClientAuth, metadata, includeRefreshToken, - listener); + final String accessToken = UUIDs.randomBase64UUID(); + final String refreshToken = includeRefreshToken ? UUIDs.randomBase64UUID() : null; + createOAuth2Tokens(accessToken, refreshToken, tokenVersion, tokensIndex, authentication, originatingClientAuth, metadata, listener); } /** - * Create an access token and optionally a refresh token as well, based on the provided authentication and metadata, with the given - * token document id. The created tokens are be stored in the security index. + * Creates an access token and optionally a refresh token as well from predefined values, based on the provided authentication and + * metadata. The created tokens are stored in the security index for versions up to {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a + * specific security tokens index for later versions. + */ + //public for testing + public void createOAuth2Tokens(String accessToken, String refreshToken, Authentication authentication, + Authentication originatingClientAuth, + Map metadata, ActionListener> listener) { + // the created token is compatible with the oldest node version in the cluster + final Version tokenVersion = getTokenVersionCompatibility(); + // tokens moved to a separate index in newer versions + final SecurityIndexManager tokensIndex = getTokensIndexForVersion(tokenVersion); + createOAuth2Tokens(accessToken, refreshToken, tokenVersion, tokensIndex, authentication, originatingClientAuth, metadata, listener); + } + + /** + * Create an access token and optionally a refresh token as well from predefined values, based on the provided authentication and + * metadata. + * + * @param accessToken The predefined seed value for the access token. This will then be + *
    + *
  • Encrypted before stored for versions before {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Hashed before stored for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Stored in the security index for versions up to {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Stored in a specific security tokens index for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Prepended with a version ID and encoded with Base64 before returned to the caller of the APIs
  • + *
+ * @param refreshToken The predefined seed value for the access token. This will then be + *
    + *
  • Hashed before stored for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Stored in the security index for versions up to {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Stored in a specific security tokens index for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Prepended with a version ID and encoded with Base64 before returned to the caller of the APIs for + * versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
+ * @param tokenVersion The version of the nodes with which these tokens will be compatible. + * @param tokensIndex The security tokens index + * @param authentication The authentication object representing the user for which the tokens are created + * @param originatingClientAuth The authentication object representing the client that called the related API + * @param metadata A map with metadata to be stored in the token document + * @param listener The listener to call upon completion with a {@link Tuple} containing the + * serialized access token and serialized refresh token as these will be returned to the client */ - private void createOAuth2Tokens(String userTokenId, Version tokenVersion, SecurityIndexManager tokensIndex, + private void createOAuth2Tokens(String accessToken, String refreshToken, Version tokenVersion, SecurityIndexManager tokensIndex, Authentication authentication, Authentication originatingClientAuth, Map metadata, - boolean includeRefreshToken, ActionListener> listener) { - assert userTokenId.length() == TOKEN_ID_LENGTH : "We assume token ids have a fixed length for nodes of a certain version." - + " When changing the token length, be careful that the inferences about its length still hold."; + ActionListener> listener) { + assert accessToken.length() == TOKEN_LENGTH : "We assume token ids have a fixed length for nodes of a certain version." + + " When changing the token length, be careful that the inferences about its length still hold."; ensureEnabled(); if (authentication == null) { listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided"))); @@ -269,10 +315,19 @@ private void createOAuth2Tokens(String userTokenId, Version tokenVersion, Securi } else { final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), authentication.getLookedUpBy(), tokenVersion, AuthenticationType.TOKEN, authentication.getMetadata()); - final UserToken userToken = new UserToken(userTokenId, tokenVersion, tokenAuth, getExpirationTime(), metadata); - final String plainRefreshToken = includeRefreshToken ? UUIDs.randomBase64UUID() : null; - final BytesReference tokenDocument = createTokenDocument(userToken, plainRefreshToken, originatingClientAuth); - final String documentId = getTokenDocumentId(userToken); + final String storedAccessToken; + final String storedRefreshToken; + if (tokenVersion.onOrAfter(VERSION_HASHED_TOKENS)) { + storedAccessToken = hashTokenString(accessToken); + storedRefreshToken = (null == refreshToken) ? null : hashTokenString(refreshToken); + } else { + storedAccessToken = accessToken; + storedRefreshToken = refreshToken; + } + final UserToken userToken = new UserToken(storedAccessToken, tokenVersion, tokenAuth, getExpirationTime(), metadata); + final BytesReference tokenDocument = createTokenDocument(userToken, storedRefreshToken, originatingClientAuth); + final String documentId = getTokenDocumentId(storedAccessToken); + final IndexRequest indexTokenRequest = client.prepareIndex(tokensIndex.aliasName(), SINGLE_MAPPING_NAME, documentId) .setOpType(OpType.CREATE) .setSource(tokenDocument, XContentType.JSON) @@ -283,15 +338,17 @@ private void createOAuth2Tokens(String userTokenId, Version tokenVersion, Securi () -> executeAsyncWithOrigin(client, SECURITY_ORIGIN, IndexAction.INSTANCE, indexTokenRequest, ActionListener.wrap(indexResponse -> { if (indexResponse.getResult() == Result.CREATED) { + final String versionedAccessToken = prependVersionAndEncodeAccessToken(tokenVersion, accessToken); if (tokenVersion.onOrAfter(VERSION_TOKENS_INDEX_INTRODUCED)) { - final String versionedRefreshToken = plainRefreshToken != null - ? prependVersionAndEncode(tokenVersion, plainRefreshToken) - : null; - listener.onResponse(new Tuple<>(userToken, versionedRefreshToken)); + final String versionedRefreshToken = refreshToken != null + ? prependVersionAndEncodeRefreshToken(tokenVersion, refreshToken) + : null; + listener.onResponse(new Tuple<>(versionedAccessToken, versionedRefreshToken)); } else { - // prior versions are not version-prepended, as nodes on those versions don't expect it. + // prior versions of the refresh token are not version-prepended, as nodes on those + // versions don't expect it. // Such nodes might exist in a mixed cluster during a rolling upgrade. - listener.onResponse(new Tuple<>(userToken, plainRefreshToken)); + listener.onResponse(new Tuple<>(versionedAccessToken, refreshToken)); } } else { listener.onFailure(traceLog("create token", @@ -301,6 +358,15 @@ private void createOAuth2Tokens(String userTokenId, Version tokenVersion, Securi } } + /** + * Hashes an access or refresh token String so that it can safely be persisted in the index. We don't salt + * the values as these are v4 UUIDs that have enough entropy by themselves. + */ + // public for testing + public static String hashTokenString(String accessTokenString) { + return new String(Hasher.SHA256.hash(new SecureString(accessTokenString.toCharArray()))); + } + /** * Looks in the context to see if the request provided a header with a user token and if so the * token is validated, which might include authenticated decryption and verification that the token @@ -406,13 +472,24 @@ void decodeToken(String token, ActionListener listener) { final Version version = Version.readVersion(in); in.setVersion(version); if (version.onOrAfter(VERSION_ACCESS_TOKENS_AS_UUIDS)) { - // The token was created in a > VERSION_ACCESS_TOKENS_UUIDS cluster so it contains the tokenId as a String - String usedTokenId = in.readString(); - getUserTokenFromId(usedTokenId, version, listener); + // The token was created in a > VERSION_ACCESS_TOKENS_UUIDS cluster + if (in.available() < MINIMUM_BYTES) { + logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BYTES); + listener.onResponse(null); + return; + } + final String accessToken = in.readString(); + // TODO Remove this conditional after backporting to 7.x + if (version.onOrAfter(VERSION_HASHED_TOKENS)) { + final String userTokenId = hashTokenString(accessToken); + getUserTokenFromId(userTokenId, version, listener); + } else { + getUserTokenFromId(accessToken, version, listener); + } } else { // The token was created in a < VERSION_ACCESS_TOKENS_UUIDS cluster so we need to decrypt it to get the tokenId - if (in.available() < MINIMUM_BASE64_BYTES) { - logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BASE64_BYTES); + if (in.available() < LEGACY_MINIMUM_BYTES) { + logger.debug("invalid token, smaller than [{}] bytes", LEGACY_MINIMUM_BYTES); listener.onResponse(null); return; } @@ -709,8 +786,12 @@ private void indexInvalidation(Collection tokenIds, SecurityIndexManager /** * Called by the transport action in order to start the process of refreshing a token. + * + * @param refreshToken The refresh token as provided by the client + * @param listener The listener to call upon completion with a {@link Tuple} containing the + * serialized access token and serialized refresh token as these will be returned to the client */ - public void refreshToken(String refreshToken, ActionListener> listener) { + public void refreshToken(String refreshToken, ActionListener> listener) { ensureEnabled(); final Instant refreshRequested = clock.instant(); final Iterator backoff = DEFAULT_BACKOFF.iterator(); @@ -718,36 +799,49 @@ public void refreshToken(String refreshToken, ActionListener { final Authentication clientAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); - innerRefresh(tokenDocHit.getId(), tokenDocHit.getSourceAsMap(), tokenDocHit.getSeqNo(), tokenDocHit.getPrimaryTerm(), - clientAuth, backoff, refreshRequested, listener); + innerRefresh(refreshToken, tokenDocHit.getId(), tokenDocHit.getSourceAsMap(), tokenDocHit.getSeqNo(), + tokenDocHit.getPrimaryTerm(), + clientAuth, backoff, refreshRequested, listener); }, listener::onFailure)); } /** - * Inferes the format and version of the passed in {@code refreshToken}. Delegates the actual search of the token document to + * Infers the format and version of the passed in {@code refreshToken}. Delegates the actual search of the token document to * {@code #findTokenFromRefreshToken(String, SecurityIndexManager, Iterator, ActionListener)} . */ private void findTokenFromRefreshToken(String refreshToken, Iterator backoff, ActionListener listener) { - if (refreshToken.length() == TOKEN_ID_LENGTH) { + if (refreshToken.length() == TOKEN_LENGTH) { // first check if token has the old format before the new version-prepended one logger.debug("Assuming an unversioned refresh token [{}], generated for node versions" - + " prior to the introduction of the version-header format.", refreshToken); + + " prior to the introduction of the version-header format.", refreshToken); findTokenFromRefreshToken(refreshToken, securityMainIndex, backoff, listener); } else { - try { - final Tuple versionAndRefreshTokenTuple = unpackVersionAndPayload(refreshToken); - final Version refreshTokenVersion = versionAndRefreshTokenTuple.v1(); - final String unencodedRefreshToken = versionAndRefreshTokenTuple.v2(); - if (false == refreshTokenVersion.onOrAfter(VERSION_TOKENS_INDEX_INTRODUCED) - || unencodedRefreshToken.length() != TOKEN_ID_LENGTH) { - logger.debug("Decoded refresh token [{}] with version [{}] is invalid.", unencodedRefreshToken, refreshTokenVersion); + if (refreshToken.length() == HASHED_TOKEN_LENGTH) { + logger.debug("Assuming a hashed refresh token [{}] retrieved from the tokens index", refreshToken); + findTokenFromRefreshToken(refreshToken, securityTokensIndex, backoff, listener); + } else { + logger.debug("Assuming a refresh token [{}] provided from a client", refreshToken); + try { + final Tuple versionAndRefreshTokenTuple = unpackVersionAndPayload(refreshToken); + final Version refreshTokenVersion = versionAndRefreshTokenTuple.v1(); + final String unencodedRefreshToken = versionAndRefreshTokenTuple.v2(); + if (refreshTokenVersion.before(VERSION_TOKENS_INDEX_INTRODUCED) || unencodedRefreshToken.length() != TOKEN_LENGTH) { + logger.debug("Decoded refresh token [{}] with version [{}] is invalid.", unencodedRefreshToken, + refreshTokenVersion); + listener.onFailure(malformedTokenException()); + } else { + // TODO Remove this conditional after backporting to 7.x + if (refreshTokenVersion.onOrAfter(VERSION_HASHED_TOKENS)) { + final String hashedRefreshToken = hashTokenString(unencodedRefreshToken); + findTokenFromRefreshToken(hashedRefreshToken, securityTokensIndex, backoff, listener); + } else { + findTokenFromRefreshToken(unencodedRefreshToken, securityTokensIndex, backoff, listener); + } + } + } catch (IOException e) { + logger.debug(() -> new ParameterizedMessage("Could not decode refresh token [{}].", refreshToken), e); listener.onFailure(malformedTokenException()); - } else { - findTokenFromRefreshToken(unencodedRefreshToken, securityTokensIndex, backoff, listener); } - } catch (IOException e) { - logger.debug("Could not decode refresh token [" + refreshToken + "].", e); - listener.onFailure(malformedTokenException()); } } } @@ -763,7 +857,7 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager final Consumer maybeRetryOnFailure = ex -> { if (backoff.hasNext()) { final TimeValue backofTimeValue = backoff.next(); - logger.debug("retrying after [" + backofTimeValue + "] back off"); + logger.debug("retrying after [{}] back off", backofTimeValue); final Runnable retryWithContextRunnable = client.threadPool().getThreadContext() .preserveContext(() -> findTokenFromRefreshToken(refreshToken, tokensIndexManager, backoff, listener)); client.threadPool().schedule(retryWithContextRunnable, backofTimeValue, GENERIC); @@ -821,13 +915,14 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager * supersedes this one. The new document that contains the new access token and refresh token is created and finally the new access * token and refresh token are returned to the listener. */ - private void innerRefresh(String tokenDocId, Map source, long seqNo, long primaryTerm, Authentication clientAuth, - Iterator backoff, Instant refreshRequested, ActionListener> listener) { + private void innerRefresh(String refreshToken, String tokenDocId, Map source, long seqNo, long primaryTerm, + Authentication clientAuth, Iterator backoff, Instant refreshRequested, + ActionListener> listener) { logger.debug("Attempting to refresh token stored in token document [{}]", tokenDocId); final Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); final Tuple> checkRefreshResult; try { - checkRefreshResult = checkTokenDocumentForRefresh(clock.instant(), clientAuth, source); + checkRefreshResult = checkTokenDocumentForRefresh(refreshRequested, clientAuth, source); } catch (DateTimeException | IllegalStateException e) { onFailure.accept(new ElasticsearchSecurityException("invalid token document", e)); return; @@ -838,23 +933,29 @@ private void innerRefresh(String tokenDocId, Map source, long se } final RefreshTokenStatus refreshTokenStatus = checkRefreshResult.v1(); if (refreshTokenStatus.isRefreshed()) { - logger.debug("Token document [{}] was recently refreshed, when a new token document [{}] was generated. Reusing that result.", - tokenDocId, refreshTokenStatus.getSupersededBy()); - getSupersedingTokenDocAsyncWithRetry(refreshTokenStatus, backoff, listener); + logger.debug("Token document [{}] was recently refreshed, when a new token document was generated. Reusing that result.", + tokenDocId); + decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, listener); } else { - final String newUserTokenId = UUIDs.randomBase64UUID(); + final String newAccessTokenString = UUIDs.randomBase64UUID(); + final String newRefreshTokenString = UUIDs.randomBase64UUID(); final Version newTokenVersion = getTokenVersionCompatibility(); final Map updateMap = new HashMap<>(); updateMap.put("refreshed", true); - updateMap.put("refresh_time", clock.instant().toEpochMilli()); - if (newTokenVersion.onOrAfter(VERSION_TOKENS_INDEX_INTRODUCED)) { - // the superseding token document reference is formated as "|"; - // for now, only the ".security-tokens|" is a valid reference format - updateMap.put("superseded_by", securityTokensIndex.aliasName() + "|" + getTokenDocumentId(newUserTokenId)); - } else { - // preservers the format of the reference (without the alias prefix) - // so that old nodes in a mixed cluster can still understand it - updateMap.put("superseded_by", getTokenDocumentId(newUserTokenId)); + if (newTokenVersion.onOrAfter(VERSION_MULTIPLE_CONCURRENT_REFRESHES)) { + updateMap.put("refresh_time", clock.instant().toEpochMilli()); + try { + final byte[] iv = getRandomBytes(IV_BYTES); + final byte[] salt = getRandomBytes(SALT_BYTES); + String encryptedAccessAndRefreshToken = encryptSupersedingTokens(newAccessTokenString, + newRefreshTokenString, refreshToken, iv, salt); + updateMap.put("superseding.encrypted_tokens", encryptedAccessAndRefreshToken); + updateMap.put("superseding.encryption_iv", Base64.getEncoder().encodeToString(iv)); + updateMap.put("superseding.encryption_salt", Base64.getEncoder().encodeToString(salt)); + } catch (GeneralSecurityException e) { + logger.warn("could not encrypt access token and refresh token string", e); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } } assert seqNo != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number"; assert primaryTerm != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term"; @@ -875,14 +976,15 @@ private void innerRefresh(String tokenDocId, Map source, long se updateResponse.getGetResult().sourceAsMap())); final Tuple parsedTokens = parseTokensFromDocument(source, null); final UserToken toRefreshUserToken = parsedTokens.v1(); - createOAuth2Tokens(newUserTokenId, newTokenVersion, getTokensIndexForVersion(newTokenVersion), - toRefreshUserToken.getAuthentication(), clientAuth, toRefreshUserToken.getMetadata(), true, listener); + createOAuth2Tokens(newAccessTokenString, newRefreshTokenString, newTokenVersion, + getTokensIndexForVersion(newTokenVersion), toRefreshUserToken.getAuthentication(), clientAuth, + toRefreshUserToken.getMetadata(), listener); } else if (backoff.hasNext()) { logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying", tokenDocId, updateResponse.getResult()); final Runnable retryWithContextRunnable = client.threadPool().getThreadContext() - .preserveContext(() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, backoff, - refreshRequested, listener)); + .preserveContext(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm, clientAuth, + backoff, refreshRequested, listener)); client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC); } else { logger.info("failed to update the original token document [{}] after all retries, the update result was [{}]. ", @@ -898,8 +1000,8 @@ private void innerRefresh(String tokenDocId, Map source, long se @Override public void onResponse(GetResponse response) { if (response.isExists()) { - innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(), response.getPrimaryTerm(), - clientAuth, backoff, refreshRequested, listener); + innerRefresh(refreshToken, tokenDocId, response.getSource(), response.getSeqNo(), + response.getPrimaryTerm(), clientAuth, backoff, refreshRequested, listener); } else { logger.warn("could not find token document [{}] for refresh", tokenDocId); onFailure.accept(invalidGrantException("could not refresh the requested token")); @@ -927,8 +1029,8 @@ public void onFailure(Exception e) { if (backoff.hasNext()) { logger.debug("failed to update the original token document [{}], retrying", tokenDocId); final Runnable retryWithContextRunnable = client.threadPool().getThreadContext() - .preserveContext(() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, backoff, - refreshRequested, listener)); + .preserveContext(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm, + clientAuth, backoff, refreshRequested, listener)); client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC); } else { logger.warn("failed to update the original token document [{}], after all retries", tokenDocId); @@ -941,72 +1043,47 @@ public void onFailure(Exception e) { } } - private void getSupersedingTokenDocAsyncWithRetry(RefreshTokenStatus refreshTokenStatus, Iterator backoff, - ActionListener> listener) { - final Consumer onFailure = ex -> listener - .onFailure(traceLog("get superseding token", refreshTokenStatus.getSupersededBy(), ex)); - getSupersedingTokenDocAsync(refreshTokenStatus, new ActionListener() { - private final Consumer maybeRetryOnFailure = ex -> { - if (backoff.hasNext()) { - final TimeValue backofTimeValue = backoff.next(); - logger.debug("retrying after [" + backofTimeValue + "] back off"); - final Runnable retryWithContextRunnable = client.threadPool().getThreadContext() - .preserveContext(() -> getSupersedingTokenDocAsync(refreshTokenStatus, this)); - client.threadPool().schedule(retryWithContextRunnable, backofTimeValue, GENERIC); - } else { - logger.warn("back off retries exhausted"); - onFailure.accept(ex); - } - }; - - @Override - public void onResponse(GetResponse response) { - if (response.isExists()) { - logger.debug("found superseding token document [{}] in index [{}] by following the [{}] reference", response.getId(), - response.getIndex(), refreshTokenStatus.getSupersededBy()); - final Tuple parsedTokens; - try { - parsedTokens = parseTokensFromDocument(response.getSource(), null); - } catch (IllegalStateException | DateTimeException e) { - logger.error("unable to decode existing user token", e); - listener.onFailure(new ElasticsearchSecurityException("could not refresh the requested token", e)); - return; - } - listener.onResponse(parsedTokens); - } else { - // We retry this since the creation of the superseding token document might already be in flight but not - // yet completed, triggered by a refresh request that came a few milliseconds ago - logger.info("could not find superseding token document from [{}] reference, retrying", - refreshTokenStatus.getSupersededBy()); - maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token")); - } - } - - @Override - public void onFailure(Exception e) { - if (isShardNotAvailableException(e)) { - logger.info("could not find superseding token document from reference [{}], retrying", - refreshTokenStatus.getSupersededBy()); - maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token")); - } else { - logger.warn("could not find superseding token document from reference [{}]", refreshTokenStatus.getSupersededBy()); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } + /** + * Decrypts the values of the superseding access token and the refresh token, using a key derived from the superseded refresh token. It + * encodes the version and serializes the tokens before calling the listener, in the same manner as {@link #createOAuth2Tokens } does. + * + * @param refreshToken The refresh token that the user sent in the request, used to derive the decryption key + * @param refreshTokenStatus The {@link RefreshTokenStatus} containing information about the superseding tokens as retrieved from the + * index + * @param listener The listener to call upon completion with a {@link Tuple} containing the + * serialized access token and serialized refresh token as these will be returned to the client + */ + void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus refreshTokenStatus, + ActionListener> listener) { + final byte[] iv = Base64.getDecoder().decode(refreshTokenStatus.getIv()); + final byte[] salt = Base64.getDecoder().decode(refreshTokenStatus.getSalt()); + final byte[] encryptedSupersedingTokens = Base64.getDecoder().decode(refreshTokenStatus.getSupersedingTokens()); + try { + Cipher cipher = getDecryptionCipher(iv, refreshToken, salt); + final String supersedingTokens = new String(cipher.doFinal(encryptedSupersedingTokens), StandardCharsets.UTF_8); + final String[] decryptedTokens = supersedingTokens.split("\\|"); + if (decryptedTokens.length != 2) { + logger.warn("Decrypted tokens string is not correctly formatted"); + listener.onFailure(invalidGrantException("could not refresh the requested token")); } - }); + listener.onResponse(new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(), decryptedTokens[0]), + prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1]))); + } catch (GeneralSecurityException | IOException e) { + logger.warn("Could not get stored superseding token values", e); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } } - private void getSupersedingTokenDocAsync(RefreshTokenStatus refreshTokenStatus, ActionListener listener) { - final String supersedingDocReference = refreshTokenStatus.getSupersededBy(); - if (supersedingDocReference.startsWith(securityTokensIndex.aliasName() + "|")) { - // superseding token doc is stored on the new tokens index, irrespective of where the superseded token doc resides - final String supersedingDocId = supersedingDocReference.substring(securityTokensIndex.aliasName().length() + 1); - getTokenDocAsync(supersedingDocId, securityTokensIndex, listener); - } else { - assert false == supersedingDocReference - .contains("|") : "The superseding doc reference appears to contain an alias name but should not"; - getTokenDocAsync(supersedingDocReference, securityMainIndex, listener); - } + /* + * Encrypts the values of the superseding access token and the refresh token, using a key derived from the superseded refresh token. + * The tokens are concatenated to a string separated with `|` before encryption so that we only perform one encryption operation + * and that we only need to store one field + */ + String encryptSupersedingTokens(String supersedingAccessToken, String supersedingRefreshToken, + String refreshToken, byte[] iv, byte[] salt) throws GeneralSecurityException { + Cipher cipher = getEncryptionCipher(iv, refreshToken, salt); + final String supersedingTokens = supersedingAccessToken + "|" + supersedingRefreshToken; + return Base64.getEncoder().encodeToString(cipher.doFinal(supersedingTokens.getBytes(StandardCharsets.UTF_8))); } private void getTokenDocAsync(String tokenDocId, SecurityIndexManager tokensIndex, ActionListener listener) { @@ -1016,7 +1093,7 @@ private void getTokenDocAsync(String tokenDocId, SecurityIndexManager tokensInde () -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get)); } - private Version getTokenVersionCompatibility() { + Version getTokenVersionCompatibility() { // newly minted tokens are compatible with the min node version in the cluster return clusterService.state().nodes().getMinNodeVersion(); } @@ -1029,13 +1106,13 @@ public static Boolean isTokenServiceEnabled(Settings settings) { * A refresh token has a fixed maximum lifetime of {@code ExpiredTokenRemover#MAXIMUM_TOKEN_LIFETIME_HOURS} hours. This checks if the * token document represents a valid token wrt this time interval. */ - private static Optional checkTokenDocumentExpired(Instant now, Map source) { - final Long creationEpochMilli = (Long) source.get("creation_time"); + private static Optional checkTokenDocumentExpired(Instant refreshRequested, Map src) { + final Long creationEpochMilli = (Long) src.get("creation_time"); if (creationEpochMilli == null) { throw new IllegalStateException("token document is missing creation time value"); } else { final Instant creationTime = Instant.ofEpochMilli(creationEpochMilli); - if (now.isAfter(creationTime.plus(ExpiredTokenRemover.MAXIMUM_TOKEN_LIFETIME_HOURS, ChronoUnit.HOURS))) { + if (refreshRequested.isAfter(creationTime.plus(ExpiredTokenRemover.MAXIMUM_TOKEN_LIFETIME_HOURS, ChronoUnit.HOURS))) { return Optional.of(invalidGrantException("token document has expired")); } else { return Optional.empty(); @@ -1048,17 +1125,17 @@ private static Optional checkTokenDocumentExpire * parsed {@code RefreshTokenStatus} together with an {@code Optional} validation exception that encapsulates the various logic about * when and by who a token can be refreshed. */ - private static Tuple> checkTokenDocumentForRefresh(Instant now, - Authentication clientAuth, Map source) throws IllegalStateException, DateTimeException { + private static Tuple> checkTokenDocumentForRefresh( + Instant refreshRequested, Authentication clientAuth, Map source) throws IllegalStateException, DateTimeException { final RefreshTokenStatus refreshTokenStatus = RefreshTokenStatus.fromSourceMap(getRefreshTokenSourceMap(source)); final UserToken userToken = UserToken.fromSourceMap(getUserTokenSourceMap(source)); refreshTokenStatus.setVersion(userToken.getVersion()); - final ElasticsearchSecurityException validationException = checkTokenDocumentExpired(now, source).orElseGet(() -> { + final ElasticsearchSecurityException validationException = checkTokenDocumentExpired(refreshRequested, source).orElseGet(() -> { if (refreshTokenStatus.isInvalidated()) { return invalidGrantException("token has been invalidated"); } else { return checkClientCanRefresh(refreshTokenStatus, clientAuth) - .orElse(checkMultipleRefreshes(now, refreshTokenStatus).orElse(null)); + .orElse(checkMultipleRefreshes(refreshRequested, refreshTokenStatus).orElse(null)); } }); return new Tuple<>(refreshTokenStatus, Optional.ofNullable(validationException)); @@ -1111,13 +1188,14 @@ private static Map getUserTokenSourceMap(Map sou * @return An {@code Optional} containing the exception in case this refresh token cannot be reused, or an empty Optional if * refreshing is allowed. */ - private static Optional checkMultipleRefreshes(Instant now, RefreshTokenStatus refreshTokenStatus) { + private static Optional checkMultipleRefreshes(Instant refreshRequested, + RefreshTokenStatus refreshTokenStatus) { if (refreshTokenStatus.isRefreshed()) { if (refreshTokenStatus.getVersion().onOrAfter(VERSION_MULTIPLE_CONCURRENT_REFRESHES)) { - if (now.isAfter(refreshTokenStatus.getRefreshInstant().plus(30L, ChronoUnit.SECONDS))) { + if (refreshRequested.isAfter(refreshTokenStatus.getRefreshInstant().plus(30L, ChronoUnit.SECONDS))) { return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past")); } - if (now.isBefore(refreshTokenStatus.getRefreshInstant().minus(30L, ChronoUnit.SECONDS))) { + if (refreshRequested.isBefore(refreshTokenStatus.getRefreshInstant().minus(30L, ChronoUnit.SECONDS))) { return Optional .of(invalidGrantException("token has been refreshed more than 30 seconds in the future, clock skew too great")); } @@ -1269,7 +1347,7 @@ private void sourceIndicesWithTokensAndRun(ActionListener> listener private BytesReference createTokenDocument(UserToken userToken, @Nullable String refreshToken, @Nullable Authentication originatingClientAuth) { assert refreshToken == null || originatingClientAuth != null : "non-null refresh token " + refreshToken - + " requires non-null client authn " + originatingClientAuth; + + " requires non-null client authn " + originatingClientAuth; try (XContentBuilder builder = XContentFactory.jsonBuilder()) { builder.startObject(); builder.field("doc_type", TOKEN_DOC_TYPE); @@ -1332,21 +1410,14 @@ private Tuple filterAndParseHit(SearchHit hit, @Nullable Pred */ private Tuple parseTokensFromDocument(Map source, @Nullable Predicate> filter) throws IllegalStateException, DateTimeException { - final String plainRefreshToken = (String) ((Map) source.get("refresh_token")).get("token"); + final String hashedRefreshToken = (String) ((Map) source.get("refresh_token")).get("token"); final Map userTokenSource = (Map) ((Map) source.get("access_token")).get("user_token"); if (null != filter && filter.test(userTokenSource) == false) { return null; } final UserToken userToken = UserToken.fromSourceMap(userTokenSource); - if (userToken.getVersion().onOrAfter(VERSION_TOKENS_INDEX_INTRODUCED)) { - final String versionedRefreshToken = plainRefreshToken != null ? - prependVersionAndEncode(userToken.getVersion(), plainRefreshToken) : null; - return new Tuple<>(userToken, versionedRefreshToken); - } else { - // do not prepend version to refresh token as the audience node version cannot deal with it - return new Tuple<>(userToken, plainRefreshToken); - } + return new Tuple<>(userToken, hashedRefreshToken); } private static String getTokenDocumentId(UserToken userToken) { @@ -1450,7 +1521,7 @@ public TimeValue getExpirationDelay() { return expirationDelay; } - private Instant getExpirationTime() { + Instant getExpirationTime() { return clock.instant().plusSeconds(expirationDelay.getSeconds()); } @@ -1478,38 +1549,34 @@ private String getFromHeader(ThreadContext threadContext) { return null; } - /** - * Serializes a token to a String containing the minimum compatible node version for decoding it back and either an encrypted - * representation of the token id for versions earlier to {@code #VERSION_ACCESS_TOKENS_UUIDS} or the token itself for versions after - * {@code #VERSION_ACCESS_TOKENS_UUIDS} - */ - public String getAccessTokenAsString(UserToken userToken) throws IOException, GeneralSecurityException { - if (userToken.getVersion().onOrAfter(VERSION_ACCESS_TOKENS_AS_UUIDS)) { + String prependVersionAndEncodeAccessToken(Version version, String accessToken) throws IOException, GeneralSecurityException { + if (version.onOrAfter(VERSION_ACCESS_TOKENS_AS_UUIDS)) { try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); OutputStream base64 = Base64.getEncoder().wrap(os); StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(userToken.getVersion()); - Version.writeVersion(userToken.getVersion(), out); - out.writeString(userToken.getId()); + out.setVersion(version); + Version.writeVersion(version, out); + out.writeString(accessToken); return new String(os.toByteArray(), StandardCharsets.UTF_8); } } else { // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly - try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); + try (ByteArrayOutputStream os = new ByteArrayOutputStream(LEGACY_MINIMUM_BASE64_BYTES); OutputStream base64 = Base64.getEncoder().wrap(os); StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(userToken.getVersion()); + out.setVersion(version); KeyAndCache keyAndCache = keyCache.activeKeyCache; - Version.writeVersion(userToken.getVersion(), out); + Version.writeVersion(version, out); out.writeByteArray(keyAndCache.getSalt().bytes); out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = getNewInitializationVector(); + final byte[] initializationVector = getRandomBytes(IV_BYTES); out.writeByteArray(initializationVector); try (CipherOutputStream encryptedOutput = - new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); + new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, version)); StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { - encryptedStreamOutput.setVersion(userToken.getVersion()); - encryptedStreamOutput.writeString(userToken.getId()); + encryptedStreamOutput.setVersion(version); + encryptedStreamOutput.writeString(accessToken); + // StreamOutput needs to be closed explicitly because it wraps CipherOutputStream encryptedStreamOutput.close(); return new String(os.toByteArray(), StandardCharsets.UTF_8); } @@ -1517,7 +1584,7 @@ public String getAccessTokenAsString(UserToken userToken) throws IOException, Ge } } - private static String prependVersionAndEncode(Version version, String payload) { + static String prependVersionAndEncodeRefreshToken(Version version, String payload) { try (ByteArrayOutputStream os = new ByteArrayOutputStream(); OutputStream base64 = Base64.getEncoder().wrap(os); StreamOutput out = new OutputStreamStreamOutput(base64)) { @@ -1563,6 +1630,17 @@ Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) return cipher; } + /** + * Initialize the encryption cipher using the provided password to derive the encryption key. + */ + Cipher getEncryptionCipher(byte[] iv, String password, byte[] salt) throws GeneralSecurityException { + SecretKey key = computeSecretKey(password.toCharArray(), salt, TOKENS_ENCRYPTION_KEY_ITERATIONS); + Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); + cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(128, iv), secureRandom); + cipher.updateAAD(salt); + return cipher; + } + private void getKeyAsync(BytesKey decodedSalt, KeyAndCache keyAndCache, ActionListener listener) { final SecretKey decodeKey = keyAndCache.getKey(decodedSalt); if (decodeKey != null) { @@ -1595,21 +1673,31 @@ private Cipher getDecryptionCipher(byte[] iv, SecretKey key, Version version, By return cipher; } - // Package private for testing - byte[] getNewInitializationVector() { - final byte[] initializationVector = new byte[IV_BYTES]; - secureRandom.nextBytes(initializationVector); - return initializationVector; + /** + * Initialize the decryption cipher using the provided password to derive the decryption key. + */ + private Cipher getDecryptionCipher(byte[] iv, String password, byte[] salt) throws GeneralSecurityException { + SecretKey key = computeSecretKey(password.toCharArray(), salt, TOKENS_ENCRYPTION_KEY_ITERATIONS); + Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); + cipher.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(128, iv), secureRandom); + cipher.updateAAD(salt); + return cipher; + } + + byte[] getRandomBytes(int length) { + final byte[] bytes = new byte[length]; + secureRandom.nextBytes(bytes); + return bytes; } /** * Generates a secret key based off of the provided password and salt. - * This method is computationally expensive. + * This method can be computationally expensive. */ - static SecretKey computeSecretKey(char[] rawPassword, byte[] salt) + static SecretKey computeSecretKey(char[] rawPassword, byte[] salt, int iterations) throws NoSuchAlgorithmException, InvalidKeySpecException { SecretKeyFactory secretKeyFactory = SecretKeyFactory.getInstance(KDF_ALGORITHM); - PBEKeySpec keySpec = new PBEKeySpec(rawPassword, salt, ITERATIONS, 128); + PBEKeySpec keySpec = new PBEKeySpec(rawPassword, salt, iterations, 128); SecretKey tmp = secretKeyFactory.generateSecret(keySpec); return new SecretKeySpec(tmp.getEncoded(), "AES"); } @@ -2003,7 +2091,7 @@ private KeyAndCache(KeyAndTimestamp keyAndTimestamp, BytesKey salt) { .setMaximumWeight(500L) .build(); try { - SecretKey secretKey = computeSecretKey(keyAndTimestamp.getKey().getChars(), salt.bytes); + SecretKey secretKey = computeSecretKey(keyAndTimestamp.getKey().getChars(), salt.bytes, TOKEN_SERVICE_KEY_ITERATIONS); keyCache.put(salt, secretKey); } catch (Exception e) { throw new IllegalStateException(e); @@ -2019,7 +2107,7 @@ private SecretKey getKey(BytesKey salt) { public SecretKey getOrComputeKey(BytesKey decodedSalt) throws ExecutionException { return keyCache.computeIfAbsent(decodedSalt, (salt) -> { try (SecureString closeableChars = keyAndTimestamp.getKey().clone()) { - return computeSecretKey(closeableChars.getChars(), salt.bytes); + return computeSecretKey(closeableChars.getChars(), salt.bytes, TOKEN_SERVICE_KEY_ITERATIONS); } }); } @@ -2074,24 +2162,32 @@ KeyAndCache get(BytesKey passphraseHash) { /** * Contains metadata associated with the refresh token that is used for validity checks, but does not contain the proper token string. */ - private static final class RefreshTokenStatus { + static final class RefreshTokenStatus { private final boolean invalidated; private final String associatedUser; private final String associatedRealm; private final boolean refreshed; @Nullable private final Instant refreshInstant; - @Nullable private final String supersededBy; + @Nullable + private final String supersedingTokens; + @Nullable + private final String iv; + @Nullable + private final String salt; private Version version; - private RefreshTokenStatus(boolean invalidated, String associatedUser, String associatedRealm, boolean refreshed, - Instant refreshInstant, String supersededBy) { + // pkg-private for testing + RefreshTokenStatus(boolean invalidated, String associatedUser, String associatedRealm, boolean refreshed, Instant refreshInstant, + String supersedingTokens, String iv, String salt) { this.invalidated = invalidated; this.associatedUser = associatedUser; this.associatedRealm = associatedRealm; this.refreshed = refreshed; this.refreshInstant = refreshInstant; - this.supersededBy = supersededBy; + this.supersedingTokens = supersedingTokens; + this.iv = iv; + this.salt = salt; } boolean isInvalidated() { @@ -2114,8 +2210,19 @@ boolean isRefreshed() { return refreshInstant; } - @Nullable String getSupersededBy() { - return supersededBy; + @Nullable + String getSupersedingTokens() { + return supersedingTokens; + } + + @Nullable + String getIv() { + return iv; + } + + @Nullable + String getSalt() { + return salt; } Version getVersion() { @@ -2149,8 +2256,11 @@ static RefreshTokenStatus fromSourceMap(Map refreshTokenSource) } final Long refreshEpochMilli = (Long) refreshTokenSource.get("refresh_time"); final Instant refreshInstant = refreshEpochMilli == null ? null : Instant.ofEpochMilli(refreshEpochMilli); - final String supersededBy = (String) refreshTokenSource.get("superseded_by"); - return new RefreshTokenStatus(invalidated, associatedUser, associatedRealm, refreshed, refreshInstant, supersededBy); + final String supersedingTokens = (String) refreshTokenSource.get("superseding.encrypted_tokens"); + final String iv = (String) refreshTokenSource.get("superseding.encryption_iv"); + final String salt = (String) refreshTokenSource.get("superseding.encryption_salt"); + return new RefreshTokenStatus(invalidated, associatedUser, associatedRealm, refreshed, refreshInstant, supersedingTokens, + iv, salt); } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java index 2bcf0849084bc..f46aa42a24450 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java @@ -50,7 +50,7 @@ public final class UserToken implements Writeable, ToXContentObject { /** * Create a new token with an autogenerated id */ - UserToken(Version version, Authentication authentication, Instant expirationTime, Map metadata) { + private UserToken(Version version, Authentication authentication, Instant expirationTime, Map metadata) { this(UUIDs.randomBase64UUID(), version, authentication, expirationTime, metadata); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java index 69cedf6389f7f..0ab3c96167c2c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/oidc/TransportOpenIdConnectLogoutActionTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -47,7 +48,6 @@ import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.security.authc.Realms; import org.elasticsearch.xpack.security.authc.TokenService; -import org.elasticsearch.xpack.security.authc.UserToken; import org.elasticsearch.xpack.security.authc.oidc.OpenIdConnectRealm; import org.elasticsearch.xpack.security.authc.oidc.OpenIdConnectTestCase; import org.elasticsearch.xpack.security.authc.support.UserRoleMapper; @@ -195,20 +195,21 @@ public void testLogoutInvalidatesTokens() throws Exception { final JWT signedIdToken = generateIdToken(subject, randomAlphaOfLength(8), randomAlphaOfLength(8)); final User user = new User("oidc-user", new String[]{"superuser"}, null, null, null, true); final Authentication.RealmRef realmRef = new Authentication.RealmRef(oidcRealm.name(), OpenIdConnectRealmSettings.TYPE, "node01"); - final Authentication authentication = new Authentication(user, realmRef, null); - final Map tokenMetadata = new HashMap<>(); tokenMetadata.put("id_token_hint", signedIdToken.serialize()); tokenMetadata.put("oidc_realm", REALM_NAME); + final Authentication authentication = new Authentication(user, realmRef, null, null, Authentication.AuthenticationType.REALM, + tokenMetadata); - final PlainActionFuture> future = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, tokenMetadata, true, future); - final UserToken userToken = future.actionGet().v1(); - mockGetTokenFromId(userToken, false, client); - final String tokenString = tokenService.getAccessTokenAsString(userToken); + final PlainActionFuture> future = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, tokenMetadata, future); + final String accessToken = future.actionGet().v1(); + mockGetTokenFromId(tokenService, userTokenId, authentication, false, client); final OpenIdConnectLogoutRequest request = new OpenIdConnectLogoutRequest(); - request.setToken(tokenString); + request.setToken(accessToken); final PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(mock(Task.class), request, listener); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java index 3f4ac8942089c..6a9c487bf2013 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.PathUtils; @@ -66,7 +67,6 @@ import org.elasticsearch.xpack.core.security.user.User; import org.elasticsearch.xpack.security.authc.Realms; import org.elasticsearch.xpack.security.authc.TokenService; -import org.elasticsearch.xpack.security.authc.UserToken; import org.elasticsearch.xpack.security.authc.saml.SamlLogoutRequestHandler; import org.elasticsearch.xpack.security.authc.saml.SamlNameId; import org.elasticsearch.xpack.security.authc.saml.SamlRealm; @@ -252,9 +252,14 @@ public void cleanup() { } public void testInvalidateCorrectTokensFromLogoutRequest() throws Exception { + final String userTokenId1 = UUIDs.randomBase64UUID(); + final String refreshToken1 = UUIDs.randomBase64UUID(); + final String userTokenId2 = UUIDs.randomBase64UUID(); + final String refreshToken2 = UUIDs.randomBase64UUID(); storeToken(logoutRequest.getNameId(), randomAlphaOfLength(10)); - final Tuple tokenToInvalidate1 = storeToken(logoutRequest.getNameId(), logoutRequest.getSession()); - final Tuple tokenToInvalidate2 = storeToken(logoutRequest.getNameId(), logoutRequest.getSession()); + final Tuple tokenToInvalidate1 = storeToken(userTokenId1, refreshToken1, logoutRequest.getNameId(), + logoutRequest.getSession()); + storeToken(userTokenId2, refreshToken2, logoutRequest.getNameId(), logoutRequest.getSession()); storeToken(new SamlNameId(NameID.PERSISTENT, randomAlphaOfLength(16), null, null, null), logoutRequest.getSession()); assertThat(indexRequests.size(), equalTo(4)); @@ -316,27 +321,27 @@ public void testInvalidateCorrectTokensFromLogoutRequest() throws Exception { assertThat(filter1.get(1), instanceOf(TermQueryBuilder.class)); assertThat(((TermQueryBuilder) filter1.get(1)).fieldName(), equalTo("refresh_token.token")); assertThat(((TermQueryBuilder) filter1.get(1)).value(), - equalTo(TokenService.unpackVersionAndPayload(tokenToInvalidate1.v2()).v2())); + equalTo(TokenService.hashTokenString(TokenService.unpackVersionAndPayload(tokenToInvalidate1.v2()).v2()))); assertThat(bulkRequests.size(), equalTo(4)); // 4 updates (refresh-token + access-token) // Invalidate refresh token 1 assertThat(bulkRequests.get(0).requests().get(0), instanceOf(UpdateRequest.class)); - assertThat(bulkRequests.get(0).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId())); + assertThat(bulkRequests.get(0).requests().get(0).id(), equalTo("token_" + TokenService.hashTokenString(userTokenId1))); UpdateRequest updateRequest1 = (UpdateRequest) bulkRequests.get(0).requests().get(0); assertThat(updateRequest1.toString().contains("refresh_token"), equalTo(true)); // Invalidate access token 1 assertThat(bulkRequests.get(1).requests().get(0), instanceOf(UpdateRequest.class)); - assertThat(bulkRequests.get(1).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId())); + assertThat(bulkRequests.get(1).requests().get(0).id(), equalTo("token_" + TokenService.hashTokenString(userTokenId1))); UpdateRequest updateRequest2 = (UpdateRequest) bulkRequests.get(1).requests().get(0); assertThat(updateRequest2.toString().contains("access_token"), equalTo(true)); // Invalidate refresh token 2 assertThat(bulkRequests.get(2).requests().get(0), instanceOf(UpdateRequest.class)); - assertThat(bulkRequests.get(2).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId())); + assertThat(bulkRequests.get(2).requests().get(0).id(), equalTo("token_" + TokenService.hashTokenString(userTokenId2))); UpdateRequest updateRequest3 = (UpdateRequest) bulkRequests.get(2).requests().get(0); assertThat(updateRequest3.toString().contains("refresh_token"), equalTo(true)); // Invalidate access token 2 assertThat(bulkRequests.get(3).requests().get(0), instanceOf(UpdateRequest.class)); - assertThat(bulkRequests.get(3).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId())); + assertThat(bulkRequests.get(3).requests().get(0).id(), equalTo("token_" + TokenService.hashTokenString(userTokenId2))); UpdateRequest updateRequest4 = (UpdateRequest) bulkRequests.get(3).requests().get(0); assertThat(updateRequest4.toString().contains("access_token"), equalTo(true)); } @@ -359,13 +364,19 @@ private Function findTokenByRefreshToken(SearchHit[] }; } - private Tuple storeToken(SamlNameId nameId, String session) throws IOException { + private Tuple storeToken(String userTokenId, String refreshToken, SamlNameId nameId, String session) { Authentication authentication = new Authentication(new User("bob"), new RealmRef("native", NativeRealmSettings.TYPE, "node01"), null); final Map metadata = samlRealm.createTokenMetadata(nameId, session); - final PlainActionFuture> future = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, metadata, true, future); + final PlainActionFuture> future = new PlainActionFuture<>(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, metadata, future); return future.actionGet(); } + private Tuple storeToken(SamlNameId nameId, String session) { + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + return storeToken(userTokenId, refreshToken, nameId, session); + } + } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java index 1652122bf6e80..9b9dc79a29cd4 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.PathUtils; @@ -55,7 +56,6 @@ import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.security.authc.Realms; import org.elasticsearch.xpack.security.authc.TokenService; -import org.elasticsearch.xpack.security.authc.UserToken; import org.elasticsearch.xpack.security.authc.saml.SamlNameId; import org.elasticsearch.xpack.security.authc.saml.SamlRealm; import org.elasticsearch.xpack.security.authc.saml.SamlRealmTests; @@ -236,19 +236,21 @@ public void testLogoutInvalidatesToken() throws Exception { .map(); final User user = new User("punisher", new String[]{"superuser"}, null, null, userMetaData, true); final Authentication.RealmRef realmRef = new Authentication.RealmRef(samlRealm.name(), SamlRealmSettings.TYPE, "node01"); - final Authentication authentication = new Authentication(user, realmRef, null); - final Map tokenMetaData = samlRealm.createTokenMetadata( - new SamlNameId(NameID.TRANSIENT, nameId, null, null, null), session); + new SamlNameId(NameID.TRANSIENT, nameId, null, null, null), session); + final Authentication authentication = new Authentication(user, realmRef, null, null, Authentication.AuthenticationType.REALM, + tokenMetaData); + - final PlainActionFuture> future = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, tokenMetaData, true, future); - final UserToken userToken = future.actionGet().v1(); - mockGetTokenFromId(userToken, false, client); - final String tokenString = tokenService.getAccessTokenAsString(userToken); + final PlainActionFuture> future = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, tokenMetaData, future); + final String accessToken = future.actionGet().v1(); + mockGetTokenFromId(tokenService, userTokenId, authentication, false, client); final SamlLogoutRequest request = new SamlLogoutRequest(); - request.setToken(tokenString); + request.setToken(accessToken); final PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(mock(Task.class), request, listener); final SamlLogoutResponse response = listener.get(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index c7994888a2631..67ce5ce2b27af 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -1108,14 +1108,16 @@ public void testAuthenticateWithToken() throws Exception { User user = new User("_username", "r1"); final AtomicBoolean completed = new AtomicBoolean(false); final Authentication expected = new Authentication(user, new RealmRef("realm", "custom", "node"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); - tokenService.createOAuth2Tokens(expected, originatingAuth, Collections.emptyMap(), true, tokenFuture); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, expected, originatingAuth, Collections.emptyMap(), tokenFuture); } - String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); + String token = tokenFuture.get().v1(); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); - mockGetTokenFromId(tokenFuture.get().v1(), false, client); + mockGetTokenFromId(tokenService, userTokenId, expected, false, client); when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.indexExists()).thenReturn(true); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { @@ -1191,13 +1193,15 @@ public void testExpiredToken() throws Exception { when(securityIndex.indexExists()).thenReturn(true); User user = new User("_username", "r1"); final Authentication expected = new Authentication(user, new RealmRef("realm", "custom", "node"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); - tokenService.createOAuth2Tokens(expected, originatingAuth, Collections.emptyMap(), true, tokenFuture); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, expected, originatingAuth, Collections.emptyMap(), tokenFuture); } - String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); - mockGetTokenFromId(tokenFuture.get().v1(), true, client); + String token = tokenFuture.get().v1(); + mockGetTokenFromId(tokenService, userTokenId, expected, true, client); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[1]).run(); return null; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 7f09444784c6d..42101b1f4ec97 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -28,8 +28,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.collect.Tuple; -import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -62,10 +60,7 @@ import org.junit.Before; import org.junit.BeforeClass; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.time.Clock; import java.time.Instant; @@ -75,7 +70,6 @@ import java.util.HashMap; import java.util.Map; -import javax.crypto.CipherOutputStream; import javax.crypto.SecretKey; import static java.time.Clock.systemUTC; @@ -169,15 +163,16 @@ public static void shutdownThreadpool() throws InterruptedException { public void testAttachAndGetToken() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - mockGetTokenFromId(token, false); - authentication = token.getAuthentication(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); + mockGetTokenFromId(tokenService, userTokenId, authentication, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getAccessTokenAsString(token)); + requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -214,16 +209,21 @@ public void testInvalidAuthorizationHeader() throws Exception { public void testRotateKey() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Key is actually used + if (null == oldNode) { + oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_6_7_0, Version.V_7_1_0)); + } Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - mockGetTokenFromId(token, false); - authentication = token.getAuthentication(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); + mockGetTokenFromId(tokenService, userTokenId, authentication, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, getDeprecatedAccessTokenString(tokenService, token)); + storeTokenHeader(requestContext, accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -240,15 +240,18 @@ public void testRotateKey() throws Exception { assertAuthentication(authentication, serialized.getAuthentication()); } - PlainActionFuture> newTokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, newTokenFuture); - final UserToken newToken = newTokenFuture.get().v1(); - assertNotNull(newToken); - assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); + PlainActionFuture> newTokenFuture = new PlainActionFuture<>(); + final String newUserTokenId = UUIDs.randomBase64UUID(); + final String newRefreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(newUserTokenId, newRefreshToken, authentication, authentication, Collections.emptyMap(), + newTokenFuture); + final String newAccessToken = newTokenFuture.get().v1(); + assertNotNull(newAccessToken); + assertNotEquals(newAccessToken, accessToken); requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, getDeprecatedAccessTokenString(tokenService, newToken)); - mockGetTokenFromId(newToken, false); + storeTokenHeader(requestContext, newAccessToken); + mockGetTokenFromId(tokenService, newUserTokenId, authentication, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -267,6 +270,10 @@ private void rotateKeys(TokenService tokenService) { public void testKeyExchange() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Key is actually used + if (null == oldNode) { + oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_6_7_0, Version.V_7_1_0)); + } int numRotations = randomIntBetween(1, 5); for (int i = 0; i < numRotations; i++) { rotateKeys(tokenService); @@ -274,20 +281,21 @@ public void testKeyExchange() throws Exception { TokenService otherTokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); otherTokenService.refreshMetaData(tokenService.getTokenMetaData()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - mockGetTokenFromId(token, false); - authentication = token.getAuthentication(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); + mockGetTokenFromId(tokenService, userTokenId, authentication, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, getDeprecatedAccessTokenString(tokenService, token)); + storeTokenHeader(requestContext, accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); otherTokenService.getAndValidateToken(requestContext, future); UserToken serialized = future.get(); - assertEquals(authentication, serialized.getAuthentication()); + assertAuthentication(serialized.getAuthentication(), authentication); } rotateKeys(tokenService); @@ -298,22 +306,27 @@ public void testKeyExchange() throws Exception { PlainActionFuture future = new PlainActionFuture<>(); otherTokenService.getAndValidateToken(requestContext, future); UserToken serialized = future.get(); - assertEquals(authentication, serialized.getAuthentication()); + assertAuthentication(serialized.getAuthentication(), authentication); } } public void testPruneKeys() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Key is actually used + if (null == oldNode) { + oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_6_7_0, Version.V_7_1_0)); + } Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - mockGetTokenFromId(token, false); - authentication = token.getAuthentication(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); + mockGetTokenFromId(tokenService, userTokenId, authentication, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, getDeprecatedAccessTokenString(tokenService, token)); + storeTokenHeader(requestContext, accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -336,11 +349,14 @@ public void testPruneKeys() throws Exception { assertAuthentication(authentication, serialized.getAuthentication()); } - PlainActionFuture> newTokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, newTokenFuture); - final UserToken newToken = newTokenFuture.get().v1(); - assertNotNull(newToken); - assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); + PlainActionFuture> newTokenFuture = new PlainActionFuture<>(); + final String newUserTokenId = UUIDs.randomBase64UUID(); + final String newRefreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(newUserTokenId, newRefreshToken, authentication, authentication, Collections.emptyMap(), + newTokenFuture); + final String newAccessToken = newTokenFuture.get().v1(); + assertNotNull(newAccessToken); + assertNotEquals(newAccessToken, accessToken); metaData = tokenService.pruneKeys(1); tokenService.refreshMetaData(metaData); @@ -353,8 +369,8 @@ public void testPruneKeys() throws Exception { } requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, getDeprecatedAccessTokenString(tokenService, newToken)); - mockGetTokenFromId(newToken, false); + storeTokenHeader(requestContext, newAccessToken); + mockGetTokenFromId(tokenService, newUserTokenId, authentication, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); tokenService.getAndValidateToken(requestContext, future); @@ -366,16 +382,21 @@ public void testPruneKeys() throws Exception { public void testPassphraseWorks() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + // This test only makes sense in mixed clusters with pre v7.1.0 nodes where the Key is actually used + if (null == oldNode) { + oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_6_7_0, Version.V_7_1_0)); + } Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - mockGetTokenFromId(token, false); - authentication = token.getAuthentication(); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); + mockGetTokenFromId(tokenService, userTokenId, authentication, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, getDeprecatedAccessTokenString(tokenService, token)); + storeTokenHeader(requestContext, accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -395,29 +416,40 @@ public void testPassphraseWorks() throws Exception { public void testGetTokenWhenKeyCacheHasExpired() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + // This test only makes sense in mixed clusters with pre v7.1.0 nodes where the Key is actually used + if (null == oldNode) { + oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_6_7_0, Version.V_7_1_0)); + } Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - UserToken token = tokenFuture.get().v1(); - assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + String accessToken = tokenFuture.get().v1(); + assertThat(accessToken, notNullValue()); tokenService.clearActiveKeyCache(); - assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); + + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + accessToken = tokenFuture.get().v1(); + assertThat(accessToken, notNullValue()); } public void testInvalidatedToken() throws Exception { when(securityMainIndex.indexExists()).thenReturn(true); TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - mockGetTokenFromId(token, true); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); + mockGetTokenFromId(tokenService, userTokenId, authentication, true); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, tokenService.getAccessTokenAsString(token)); + storeTokenHeader(requestContext, accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -436,8 +468,10 @@ private void storeTokenHeader(ThreadContext requestContext, String tokenString) public void testComputeSecretKeyIsConsistent() throws Exception { byte[] saltArr = new byte[32]; random().nextBytes(saltArr); - SecretKey key = TokenService.computeSecretKey("some random passphrase".toCharArray(), saltArr); - SecretKey key2 = TokenService.computeSecretKey("some random passphrase".toCharArray(), saltArr); + SecretKey key = + TokenService.computeSecretKey("some random passphrase".toCharArray(), saltArr, TokenService.TOKEN_SERVICE_KEY_ITERATIONS); + SecretKey key2 = + TokenService.computeSecretKey("some random passphrase".toCharArray(), saltArr, TokenService.TOKEN_SERVICE_KEY_ITERATIONS); assertArrayEquals(key.getEncoded(), key2.getEncoded()); } @@ -468,14 +502,15 @@ public void testTokenExpiry() throws Exception { ClockMock clock = ClockMock.frozen(); TokenService tokenService = createTokenService(tokenServiceEnabledSettings, clock); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - mockGetTokenFromId(token, false); - authentication = token.getAuthentication(); + final String userTokenId = UUIDs.randomBase64UUID(); + UserToken userToken = new UserToken(userTokenId, tokenService.getTokenVersionCompatibility(), authentication, + tokenService.getExpirationTime(), Collections.emptyMap()); + mockGetTokenFromId(userToken, false); + final String accessToken = tokenService.prependVersionAndEncodeAccessToken(tokenService.getTokenVersionCompatibility(), userTokenId + ); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, tokenService.getAccessTokenAsString(token)); + storeTokenHeader(requestContext, accessToken); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // the clock is still frozen, so the cookie should be valid @@ -519,7 +554,7 @@ public void testTokenServiceDisabled() throws Exception { TokenService tokenService = new TokenService(Settings.builder() .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), false) .build(), - Clock.systemUTC(), client, licenseState, securityMainIndex, securityTokensIndex, clusterService); + Clock.systemUTC(), client, licenseState, securityMainIndex, securityTokensIndex, clusterService); IllegalStateException e = expectThrows(IllegalStateException.class, () -> tokenService.createOAuth2Tokens(null, null, null, true, null)); assertEquals("security tokens are not enabled", e.getMessage()); @@ -577,14 +612,15 @@ public void testMalformedToken() throws Exception { public void testIndexNotAvailable() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - PlainActionFuture> tokenFuture = new PlainActionFuture<>(); - tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture); - final UserToken token = tokenFuture.get().v1(); - assertNotNull(token); - //mockGetTokenFromId(token, false); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String refreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + assertNotNull(accessToken); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, tokenService.getAccessTokenAsString(token)); + storeTokenHeader(requestContext, accessToken); doAnswer(invocationOnMock -> { ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -620,34 +656,64 @@ public void testIndexNotAvailable() throws Exception { when(tokensIndex.isAvailable()).thenReturn(true); when(tokensIndex.indexExists()).thenReturn(true); - mockGetTokenFromId(token, false); + mockGetTokenFromId(tokenService, userTokenId, authentication, false); future = new PlainActionFuture<>(); tokenService.getAndValidateToken(requestContext, future); - assertEquals(future.get().getAuthentication(), token.getAuthentication()); + assertAuthentication(future.get().getAuthentication(), authentication); } } public void testGetAuthenticationWorksWithExpiredUserToken() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, Clock.systemUTC()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS)); + final String userTokenId = UUIDs.randomBase64UUID(); + UserToken expired = new UserToken(userTokenId, tokenService.getTokenVersionCompatibility(), authentication, + Instant.now().minus(3L, ChronoUnit.DAYS), Collections.emptyMap()); mockGetTokenFromId(expired, false); - String userTokenString = tokenService.getAccessTokenAsString(expired); + final String accessToken = tokenService.prependVersionAndEncodeAccessToken(tokenService.getTokenVersionCompatibility(), userTokenId + ); PlainActionFuture>> authFuture = new PlainActionFuture<>(); - tokenService.getAuthenticationAndMetaData(userTokenString, authFuture); + tokenService.getAuthenticationAndMetaData(accessToken, authFuture); Authentication retrievedAuth = authFuture.actionGet().v1(); - assertEquals(authentication, retrievedAuth); + assertAuthentication(authentication, retrievedAuth); + } + + public void testSupercedingTokenEncryption() throws Exception { + TokenService tokenService = createTokenService(tokenServiceEnabledSettings, Clock.systemUTC()); + Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String refrehToken = UUIDs.randomBase64UUID(); + final String newAccessToken = UUIDs.randomBase64UUID(); + final String newRefreshToken = UUIDs.randomBase64UUID(); + final byte[] iv = tokenService.getRandomBytes(TokenService.IV_BYTES); + final byte[] salt = tokenService.getRandomBytes(TokenService.SALT_BYTES); + final Version version = tokenService.getTokenVersionCompatibility(); + String encryptedTokens = tokenService.encryptSupersedingTokens(newAccessToken, newRefreshToken, refrehToken, iv, + salt); + TokenService.RefreshTokenStatus refreshTokenStatus = new TokenService.RefreshTokenStatus(false, + authentication.getUser().principal(), authentication.getAuthenticatedBy().getName(), true, Instant.now().minusSeconds(5L), + encryptedTokens, Base64.getEncoder().encodeToString(iv), Base64.getEncoder().encodeToString(salt)); + refreshTokenStatus.setVersion(version); + tokenService.decryptAndReturnSupersedingTokens(refrehToken, refreshTokenStatus, tokenFuture); + if (version.onOrAfter(TokenService.VERSION_ACCESS_TOKENS_AS_UUIDS)) { + // previous versions serialized the access token encrypted and the cipher text was different each time (due to different IVs) + assertThat(tokenService.prependVersionAndEncodeAccessToken(version, newAccessToken), equalTo(tokenFuture.get().v1())); + } + assertThat(TokenService.prependVersionAndEncodeRefreshToken(version, newRefreshToken), equalTo(tokenFuture.get().v2())); } public void testCannotValidateTokenIfLicenseDoesNotAllowTokens() throws Exception { when(licenseState.isTokenServiceAllowed()).thenReturn(true); TokenService tokenService = createTokenService(tokenServiceEnabledSettings, Clock.systemUTC()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - UserToken token = new UserToken(authentication, Instant.now().plusSeconds(180)); + final String userTokenId = UUIDs.randomBase64UUID(); + UserToken token = new UserToken(userTokenId, tokenService.getTokenVersionCompatibility(), authentication, + Instant.now().plusSeconds(180), Collections.emptyMap()); mockGetTokenFromId(token, false); - + final String accessToken = tokenService.prependVersionAndEncodeAccessToken(tokenService.getTokenVersionCompatibility(), userTokenId + ); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(threadContext, tokenService.getAccessTokenAsString(token)); + storeTokenHeader(threadContext, tokenService.prependVersionAndEncodeAccessToken(token.getVersion(), accessToken)); PlainActionFuture authFuture = new PlainActionFuture<>(); when(licenseState.isTokenServiceAllowed()).thenReturn(false); @@ -660,18 +726,30 @@ private TokenService createTokenService(Settings settings, Clock clock) throws G return new TokenService(settings, clock, client, licenseState, securityMainIndex, securityTokensIndex, clusterService); } - private void mockGetTokenFromId(UserToken userToken, boolean isExpired) { - mockGetTokenFromId(userToken, isExpired, client); + private void mockGetTokenFromId(TokenService tokenService, String accessToken, Authentication authentication, boolean isExpired) { + mockGetTokenFromId(tokenService, accessToken, authentication, isExpired, client); } - public static void mockGetTokenFromId(UserToken userToken, boolean isExpired, Client client) { + public static void mockGetTokenFromId(TokenService tokenService, String userTokenId, Authentication authentication, boolean isExpired, + Client client) { doAnswer(invocationOnMock -> { GetRequest request = (GetRequest) invocationOnMock.getArguments()[0]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; GetResponse response = mock(GetResponse.class); - if (userToken.getId().equals(request.id().replace("token_", ""))) { + Version tokenVersion = tokenService.getTokenVersionCompatibility(); + final String possiblyHashedUserTokenId; + if (tokenVersion.onOrAfter(TokenService.VERSION_ACCESS_TOKENS_AS_UUIDS)) { + possiblyHashedUserTokenId = TokenService.hashTokenString(userTokenId); + } else { + possiblyHashedUserTokenId = userTokenId; + } + if (possiblyHashedUserTokenId.equals(request.id().replace("token_", ""))) { when(response.isExists()).thenReturn(true); Map sourceMap = new HashMap<>(); + final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), + authentication.getLookedUpBy(), tokenVersion, AuthenticationType.TOKEN, authentication.getMetadata()); + final UserToken userToken = new UserToken(possiblyHashedUserTokenId, tokenVersion, tokenAuth, + tokenService.getExpirationTime(), authentication.getMetadata()); try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) { userToken.toXContent(builder, ToXContent.EMPTY_PARAMS); Map accessTokenMap = new HashMap<>(); @@ -687,35 +765,42 @@ public static void mockGetTokenFromId(UserToken userToken, boolean isExpired, Cl }).when(client).get(any(GetRequest.class), any(ActionListener.class)); } + private void mockGetTokenFromId(UserToken userToken, boolean isExpired) { + doAnswer(invocationOnMock -> { + GetRequest request = (GetRequest) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + GetResponse response = mock(GetResponse.class); + final String possiblyHashedUserTokenId; + if (userToken.getVersion().onOrAfter(TokenService.VERSION_ACCESS_TOKENS_AS_UUIDS)) { + possiblyHashedUserTokenId = TokenService.hashTokenString(userToken.getId()); + } else { + possiblyHashedUserTokenId = userToken.getId(); + } + if (possiblyHashedUserTokenId.equals(request.id().replace("token_", ""))) { + when(response.isExists()).thenReturn(true); + Map sourceMap = new HashMap<>(); + try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) { + userToken.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map accessTokenMap = new HashMap<>(); + Map userTokenMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), + Strings.toString(builder), false); + userTokenMap.put("id", possiblyHashedUserTokenId); + accessTokenMap.put("user_token", userTokenMap); + accessTokenMap.put("invalidated", isExpired); + sourceMap.put("access_token", accessTokenMap); + } + when(response.getSource()).thenReturn(sourceMap); + } + listener.onResponse(response); + return Void.TYPE; + }).when(client).get(any(GetRequest.class), any(ActionListener.class)); + } + public static void assertAuthentication(Authentication result, Authentication expected) { assertEquals(expected.getUser(), result.getUser()); assertEquals(expected.getAuthenticatedBy(), result.getAuthenticatedBy()); assertEquals(expected.getLookedUpBy(), result.getLookedUpBy()); assertEquals(expected.getMetadata(), result.getMetadata()); - assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType()); - } - - protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException, - GeneralSecurityException { - try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(Version.V_7_0_0); - TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache(); - Version.writeVersion(Version.V_7_0_0, out); - out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = tokenService.getNewInitializationVector(); - out.writeByteArray(initializationVector); - try (CipherOutputStream encryptedOutput = - new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0)); - StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { - encryptedStreamOutput.setVersion(Version.V_7_0_0); - encryptedStreamOutput.writeString(userToken.getId()); - encryptedStreamOutput.close(); - return new String(os.toByteArray(), StandardCharsets.UTF_8); - } - } } private DiscoveryNode addAnotherDataNodeWithVersion(ClusterService clusterService, Version version) { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/HasherTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/HasherTests.java index 6086dc642d22f..e51945cd90418 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/HasherTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/HasherTests.java @@ -50,6 +50,10 @@ public void testSSHA256SelfGenerated() throws Exception { testHasherSelfGenerated(Hasher.SSHA256); } + public void testSHA256SelfGenerated() throws Exception { + testHasherSelfGenerated(Hasher.SHA256); + } + public void testNoopSelfGenerated() throws Exception { testHasherSelfGenerated(Hasher.NOOP); }