diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index c5197b8b7d4..8bb802e9e70 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -47,7 +47,10 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static java.lang.String.format; import static java.util.Arrays.asList; @@ -282,6 +285,9 @@ public class ConnectionString { private static final Set ALLOWED_OPTIONS_IN_TXT_RECORD = new HashSet<>(asList("authsource", "replicaset", "loadbalanced")); private static final Logger LOGGER = Loggers.getLogger("uri"); + private static final List MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING = Stream.of(ALLOWED_HOSTS_KEY) + .map(k -> k.toLowerCase()) + .collect(Collectors.toList()); private final MongoCredential credential; private final boolean isSrvProtocol; @@ -917,6 +923,11 @@ private MongoCredential createCredentials(final Map> option } String key = mechanismPropertyKeyValue[0].trim().toLowerCase(); String value = mechanismPropertyKeyValue[1].trim(); + if (MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING.contains(key)) { + throw new IllegalArgumentException(format("The connection string contains disallowed mechanism properties. " + + "'%s' must be set on the credential programmatically.", key)); + } + if (key.equals("canonicalize_host_name")) { credential = credential.withMechanismProperty(key, Boolean.valueOf(value)); } else { diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index 4c10e1f640c..295803e55a4 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -25,6 +25,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -187,7 +188,8 @@ public final class MongoCredential { * The provider name. The value must be a string. *

* If this is provided, - * {@link MongoCredential#OIDC_CALLBACK_KEY} + * {@link MongoCredential#OIDC_CALLBACK_KEY} and + * {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) @@ -197,10 +199,13 @@ public final class MongoCredential { /** * This callback is invoked when the OIDC-based authenticator requests - * tokens from the identity provider. The type of the value must be - * {@link OidcRequestCallback}. + * a token. The type of the value must be {@link OidcCallback}. + * {@link IdpInfo} will not be supplied to the callback, + * and a {@linkplain OidcCallbackResult#getRefreshToken() refresh token} + * must not be returned by the callback. *

* If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * and {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) @@ -208,6 +213,46 @@ public final class MongoCredential { */ public static final String OIDC_CALLBACK_KEY = "OIDC_CALLBACK"; + /** + * This callback is invoked when the OIDC-based authenticator requests + * a token from the identity provider (IDP) using the IDP information + * from the MongoDB server. The type of the value must be + * {@link OidcCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * and {@link MongoCredential#OIDC_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String OIDC_HUMAN_CALLBACK_KEY = "OIDC_HUMAN_CALLBACK"; + + + /** + * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. + * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. + * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts + * the driver will raise an error. The type of the value must be {@code List}. + * + * @see MongoCredential#DEFAULT_ALLOWED_HOSTS + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; + + /** + * The list of allowed hosts that will be used if no + * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * The default allowed hosts are: + * {@code "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( + "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -365,6 +410,8 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam * @see #withMechanismProperty(String, Object) * @see #PROVIDER_NAME_KEY * @see #OIDC_CALLBACK_KEY + * @see #OIDC_HUMAN_CALLBACK_KEY + * @see #ALLOWED_HOSTS_KEY * @mongodb.server.release 7.0 */ public static MongoCredential createOidcCredential(@Nullable final String userName) { @@ -593,10 +640,15 @@ public String toString() { } /** - * The context for the {@link OidcRequestCallback#onRequest(OidcRequestContext) OIDC request callback}. + * The context for the {@link OidcCallback#onRequest(OidcCallbackContext) OIDC request callback}. */ @Evolving - public interface OidcRequestContext { + public interface OidcCallbackContext { + /** + * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + @Nullable + IdpInfo getIdpInfo(); /** * @return The timeout that this callback must complete within. @@ -607,6 +659,12 @@ public interface OidcRequestContext { * @return The OIDC callback API version. Currently, version 1. */ int getVersion(); + + /** + * @return The OIDC Refresh token supplied by a prior callback invocation. + */ + @Nullable + String getRefreshToken(); } /** @@ -616,27 +674,76 @@ public interface OidcRequestContext { * It does not have to be thread-safe, unless it is provided to multiple * MongoClients. */ - public interface OidcRequestCallback { + public interface OidcCallback { /** * @param context The context. * @return The response produced by an OIDC Identity Provider */ - RequestCallbackResult onRequest(OidcRequestContext context); + OidcCallbackResult onRequest(OidcCallbackContext context); + } + + /** + * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + @Evolving + public interface IdpInfo { + /** + * @return URL which describes the Authorization Server. This identifier is the + * iss of provided access tokens, and is viable for RFC8414 metadata + * discovery and RFC9207 identification. + */ + String getIssuer(); + + /** + * @return Unique client ID for this OIDC client. + */ + String getClientId(); + + /** + * @return Additional scopes to request from Identity Provider. Immutable. + */ + List getRequestScopes(); } /** * The response produced by an OIDC Identity Provider. */ - public static final class RequestCallbackResult { + public static final class OidcCallbackResult { private final String accessToken; + private final Duration expiresIn; + + @Nullable + private final String refreshToken; + + /** + * @param accessToken The OIDC access token. + * @param expiresIn Time until the access token expires. + * A {@linkplain Duration#isZero() zero-length} duration + * means that the access token does not expire. + */ + public OidcCallbackResult(final String accessToken, final Duration expiresIn) { + this(accessToken, expiresIn, null); + } + /** - * @param accessToken The OIDC access token + * @param accessToken The OIDC access token. + * @param expiresIn Time until the access token expires. + * A {@linkplain Duration#isZero() zero-length} duration + * means that the access token does not expire. + * @param refreshToken The refresh token. If null, refresh will not be attempted. */ - public RequestCallbackResult(final String accessToken) { + public OidcCallbackResult(final String accessToken, final Duration expiresIn, + @Nullable final String refreshToken) { notNull("accessToken", accessToken); + notNull("expiresIn", expiresIn); + if (expiresIn.isNegative()) { + throw new IllegalArgumentException("expiresIn must not be a negative value"); + } this.accessToken = accessToken; + this.expiresIn = expiresIn; + this.refreshToken = refreshToken; } /** @@ -645,5 +752,13 @@ public RequestCallbackResult(final String accessToken) { public String getAccessToken() { return accessToken; } + + /** + * @return The OIDC refresh token. If null, refresh will not be attempted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 70f9682476c..6b2362cbc1f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -21,7 +21,7 @@ import com.mongodb.MongoCommandException; import com.mongodb.MongoConfigurationException; import com.mongodb.MongoCredential; -import com.mongodb.MongoCredential.RequestCallbackResult; +import com.mongodb.MongoCredential.OidcCallbackResult; import com.mongodb.MongoException; import com.mongodb.MongoSecurityException; import com.mongodb.ServerAddress; @@ -34,6 +34,7 @@ import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; +import org.bson.RawBsonDocument; import javax.security.sasl.SaslClient; import java.io.IOException; @@ -42,12 +43,18 @@ import java.nio.file.Paths; import java.time.Duration; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; -import static com.mongodb.MongoCredential.OidcRequestCallback; -import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallback; +import static com.mongodb.MongoCredential.OidcCallbackContext; import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static com.mongodb.assertions.Assertions.assertFalse; @@ -69,6 +76,9 @@ public final class OidcAuthenticator extends SaslAuthenticator { public static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; private static final int CALLBACK_API_VERSION_NUMBER = 1; + @Nullable + private ServerAddress serverAddress; + @Nullable private String connectionLastAccessToken; @@ -94,6 +104,7 @@ public String getMechanismName() { @Override protected SaslClient createSaslClient(final ServerAddress serverAddress) { + this.serverAddress = assertNotNull(serverAddress); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); return new OidcSaslClient(mongoCredentialWithCache); } @@ -141,11 +152,25 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp speculativeAuthenticateResponse = response; } + private boolean isAutomaticAuthentication() { + return getOidcCallbackMechanismProperty(PROVIDER_NAME_KEY) == null; + } + + private boolean isHumanCallback() { + return getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY) != null; + } + @Nullable - private OidcRequestCallback getRequestCallback() { + private OidcCallback getOidcCallbackMechanismProperty(final String key) { return getMongoCredentialWithCache() .getCredential() - .getMechanismProperty(OIDC_CALLBACK_KEY, null); + .getMechanismProperty(key, null); + } + + @Nullable + private OidcCallback getRequestCallback() { + OidcCallback machine = getOidcCallbackMechanismProperty(OIDC_CALLBACK_KEY); + return machine != null ? machine : getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY); } @Override @@ -195,7 +220,7 @@ private void authenticationLoop(final InternalConnection connection, final Conne try { super.authenticate(connection, description); break; - } catch (MongoSecurityException e) { + } catch (Exception e) { if (triggersRetry(e) && shouldRetryHandler()) { continue; } @@ -219,17 +244,66 @@ private byte[] evaluate(final byte[] challenge) { } byte[][] jwt = new byte[1][]; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); + String cachedRefreshToken = oidcCacheEntry.getRefreshToken(); + IdpInfo cachedIdpInfo = oidcCacheEntry.getIdpInfo(); String cachedAccessToken = validatedCachedAccessToken(); + OidcCallback requestCallback = assertNotNull(getRequestCallback()); + boolean isHuman = isHumanCallback(); if (cachedAccessToken != null) { - jwt[0] = prepareTokenAsJwt(cachedAccessToken); fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + jwt[0] = prepareTokenAsJwt(cachedAccessToken); + } else if (cachedRefreshToken != null) { + // cached refresh token is only set when isHuman + // original IDP info will be present, if refresh token present + assertNotNull(cachedIdpInfo); + // Invoke Callback using cached Refresh Token + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, cachedIdpInfo, cachedRefreshToken)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); } else { // cache is empty - OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); - RequestCallbackResult result = requestCallback.onRequest(new OidcRequestContextImpl(CALLBACK_TIMEOUT)); - jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(result); - fallbackState = FallbackState.PHASE_2_CALLBACK_TOKEN; + + if (!isHuman) { + // no principal request + fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result); + if (result.getRefreshToken() != null) { + throw new MongoConfigurationException( + "Refresh token must only be provided in human workflow"); + } + } else { + /* + A check for present idp info short-circuits phase-3a. + If a challenge is present, it can only be a response to a + "principal-request", so the challenge must be the resulting + idp info. Such a request is made during speculative auth, + though the source is unimportant, as long as we detect and + use it here. + */ + boolean idpInfoNotPresent = challenge.length == 0; + /* + Checking that the fallback state is not phase-3a ensures that + this does not loop infinitely in the case of a bug. + */ + boolean alreadyTriedPrincipal = fallbackState == FallbackState.PHASE_3A_PRINCIPAL; + if (!alreadyTriedPrincipal && idpInfoNotPresent) { + // request for idp info, only in the human workflow + fallbackState = FallbackState.PHASE_3A_PRINCIPAL; + jwt[0] = prepareUsername(getMongoCredentialWithCache().getCredential().getUserName()); + } else { + IdpInfo idpInfo = toIdpInfo(challenge); + // there is no cached refresh token + fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, idpInfo, null)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); + } + } } }); return jwt[0]; @@ -255,19 +329,35 @@ private String validatedCachedAccessToken() { return cachedAccessToken; } - private boolean isAutomaticAuthentication() { - return getRequestCallback() == null; - } - private boolean clientIsComplete() { - return true; // all possibilities are 1-step + return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; } private boolean shouldRetryHandler() { + boolean[] result = new boolean[1]; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { - validatedCachedAccessToken(); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { + // a cached access token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken()); + result[0] = true; + } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { + // a refresh token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + result[0] = true; + } else { + // a clean-restart failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + result[0] = false; + } }); - return fallbackState == FallbackState.PHASE_1_CACHED_TOKEN; + return result[0]; } @Nullable @@ -280,24 +370,29 @@ private String getCachedAccessToken() { static final class OidcCacheEntry { @Nullable private final String accessToken; + @Nullable + private final String refreshToken; + @Nullable + private final IdpInfo idpInfo; @Override public String toString() { return "OidcCacheEntry{" + "\n accessToken=[omitted]" + + ",\n refreshToken=[omitted]" + + ",\n idpInfo=" + idpInfo + '}'; } - OidcCacheEntry(final RequestCallbackResult requestCallbackResult) { - this.accessToken = requestCallbackResult.getAccessToken(); - } - OidcCacheEntry() { - this((String) null); + this(null, null, null); } - private OidcCacheEntry(@Nullable final String accessToken) { + private OidcCacheEntry(@Nullable final String accessToken, + @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { this.accessToken = accessToken; + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; } @Nullable @@ -305,8 +400,28 @@ String getCachedAccessToken() { return accessToken; } + @Nullable + String getRefreshToken() { + return refreshToken; + } + + @Nullable + IdpInfo getIdpInfo() { + return idpInfo; + } + OidcCacheEntry clearAccessToken() { - return new OidcCacheEntry((String) null); + return new OidcCacheEntry( + null, + this.refreshToken, + this.idpInfo); + } + + OidcCacheEntry clearRefreshToken() { + return new OidcCacheEntry( + this.accessToken, + null, + null); } } @@ -343,13 +458,70 @@ private static String readAwsTokenFromFile() { } } - private byte[] populateCacheWithCallbackResultAndPrepareJwt(@Nullable final RequestCallbackResult requestCallbackResult) { - if (requestCallbackResult == null) { + private byte[] populateCacheWithCallbackResultAndPrepareJwt( + @Nullable final IdpInfo serverInfo, + @Nullable final OidcCallbackResult oidcCallbackResult) { + if (oidcCallbackResult == null) { throw new MongoConfigurationException("Result of callback must not be null"); } - OidcCacheEntry newEntry = new OidcCacheEntry(requestCallbackResult); + OidcCacheEntry newEntry = new OidcCacheEntry(oidcCallbackResult.getAccessToken(), + oidcCallbackResult.getRefreshToken(), serverInfo); getMongoCredentialWithCache().setOidcCacheEntry(newEntry); - return prepareTokenAsJwt(requestCallbackResult.getAccessToken()); + return prepareTokenAsJwt(oidcCallbackResult.getAccessToken()); + } + + private static byte[] prepareUsername(@Nullable final String username) { + BsonDocument document = new BsonDocument(); + if (username != null) { + document = document.append("n", new BsonString(username)); + } + return toBson(document); + } + + private IdpInfo toIdpInfo(final byte[] challenge) { + // validate here to prevent creating IdpInfo for unauthorized hosts + validateAllowedHosts(getMongoCredential()); + BsonDocument c = new RawBsonDocument(challenge); + String issuer = c.getString("issuer").getValue(); + String clientId = c.getString("clientId").getValue(); + return new IdpInfoImpl( + issuer, + clientId, + getStringArray(c, "requestScopes")); + } + + + @Nullable + private static List getStringArray(final BsonDocument document, final String key) { + if (!document.isArray(key)) { + return null; + } + return document.getArray(key).stream() + // ignore non-string values from server, rather than error + .filter(v -> v.isString()) + .map(v -> v.asString().getValue()) + .collect(Collectors.toList()); + } + + private void validateAllowedHosts(final MongoCredential credential) { + List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); + String host = assertNotNull(serverAddress).getHost(); + boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { + if (allowedHost.startsWith("*.")) { + String ending = allowedHost.substring(1); + return host.endsWith(ending); + } else if (allowedHost.contains("*")) { + throw new IllegalArgumentException( + "Allowed host " + allowedHost + " contains invalid wildcard"); + } else { + return host.equals(allowedHost); + } + }); + if (!permitted) { + throw new MongoSecurityException( + credential, "Host " + host + " not permitted by " + ALLOWED_HOSTS_KEY + + ", values: " + allowedHosts); + } } private byte[] prepareTokenAsJwt(final String accessToken) { @@ -400,32 +572,59 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) public static void validateBeforeUse(final MongoCredential credential) { String userName = credential.getUserName(); Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); - Object requestCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); + Object machineCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); + Object humanCallback = credential.getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, null); if (providerName == null) { // callback - if (requestCallback == null) { - throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " - + OIDC_CALLBACK_KEY + " must be specified"); + if (machineCallback == null && humanCallback == null) { + throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + + " or " + OIDC_CALLBACK_KEY + + " or " + OIDC_HUMAN_CALLBACK_KEY + + " must be specified"); + } + if (machineCallback != null && humanCallback != null) { + throw new IllegalArgumentException("Both " + OIDC_CALLBACK_KEY + + " and " + OIDC_HUMAN_CALLBACK_KEY + + " must not be specified"); } } else { - // automatic if (userName != null) { throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } - if (requestCallback != null) { + if (machineCallback != null) { throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } + if (humanCallback != null) { + throw new IllegalArgumentException(OIDC_HUMAN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } } } } @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) - static class OidcRequestContextImpl implements OidcRequestContext { + static class OidcCallbackContextImpl implements OidcCallbackContext { private final Duration timeout; + @Nullable + private final IdpInfo idpInfo; + @Nullable + private final String refreshToken; - OidcRequestContextImpl(final Duration timeout) { + OidcCallbackContextImpl(final Duration timeout) { this.timeout = assertNotNull(timeout); + this.idpInfo = null; + this.refreshToken = null; + } + + OidcCallbackContextImpl(final Duration timeout, final IdpInfo idpInfo, @Nullable final String refreshToken) { + this.timeout = assertNotNull(timeout); + this.idpInfo = assertNotNull(idpInfo); + this.refreshToken = refreshToken; + } + + @Override + public IdpInfo getIdpInfo() { + return idpInfo; } @Override @@ -437,6 +636,41 @@ public Duration getTimeout() { public int getVersion() { return CALLBACK_API_VERSION_NUMBER; } + + @Override + public String getRefreshToken() { + return refreshToken; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class IdpInfoImpl implements IdpInfo { + private final String issuer; + private final String clientId; + private final List requestScopes; + + IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + this.issuer = assertNotNull(issuer); + this.clientId = assertNotNull(clientId); + this.requestScopes = requestScopes == null + ? Collections.emptyList() + : Collections.unmodifiableList(requestScopes); + } + + @Override + public String getIssuer() { + return issuer; + } + + @Override + public String getClientId() { + return clientId; + } + + @Override + public List getRequestScopes() { + return requestScopes; + } } /** @@ -445,6 +679,8 @@ public int getVersion() { private enum FallbackState { INITIAL, PHASE_1_CACHED_TOKEN, - PHASE_2_CALLBACK_TOKEN + PHASE_2_REFRESH_CALLBACK_TOKEN, + PHASE_3A_PRINCIPAL, + PHASE_3B_CALLBACK_TOKEN } } diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index 4da83dc7d4f..cab5b0e0365 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -119,7 +119,7 @@ private MongoCredential getMongoCredential() { if ("oidcRequest".equals(string)) { credential = credential.withMechanismProperty( OIDC_CALLBACK_KEY, - (MongoCredential.OidcRequestCallback) (context) -> null); + (MongoCredential.OidcCallback) (context) -> null); } else { fail("Unsupported callback: " + string); } @@ -176,7 +176,7 @@ private void assertMechanismProperties(final MongoCredential credential) { } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); if (OIDC_CALLBACK_KEY.equals(key)) { - assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRequestCallback); + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcCallback); return; } assertNotNull(actualMechanismProperty); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 5f42066aada..7e83f802279 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -73,7 +73,6 @@ import org.bson.BsonDouble; import org.bson.BsonInt32; import org.bson.BsonInt64; -import org.bson.BsonNumber; import org.bson.BsonString; import org.bson.BsonValue; @@ -82,6 +81,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -542,7 +542,7 @@ private void initClient(final BsonDocument entity, final String id, if (isOidc && hasPlaceholder) { clientSettingsBuilder.credential(credential.withMechanismProperty( MongoCredential.OIDC_CALLBACK_KEY, - (MongoCredential.OidcRequestCallback) context -> { + (MongoCredential.OidcCallback) context -> { Path path = Paths.get(getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE)); String accessToken; try { @@ -550,7 +550,7 @@ private void initClient(final BsonDocument entity, final String id, } catch (IOException e) { throw new RuntimeException(e); } - return new MongoCredential.RequestCallbackResult(accessToken); + return new MongoCredential.OidcCallbackResult(accessToken, Duration.ZERO); })); break; } diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 66b6a305297..b5a87a51cef 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -16,13 +16,13 @@ package com.mongodb.internal.connection; -import com.mongodb.ClusterFixture; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCommandException; import com.mongodb.MongoConfigurationException; import com.mongodb.MongoCredential; -import com.mongodb.MongoCredential.RequestCallbackResult; +import com.mongodb.MongoSecurityException; +import com.mongodb.MongoSocketException; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.TestListener; @@ -49,21 +49,27 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; -import static com.mongodb.MongoCredential.OidcRequestCallback; -import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallbackResult; +import static com.mongodb.MongoCredential.OidcCallback; +import static com.mongodb.MongoCredential.OidcCallbackContext; import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static java.lang.System.getenv; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; import static util.ThreadTestHelpers.executeAll; @@ -81,15 +87,37 @@ public static boolean oidcTestsEnabled() { private String appName; protected static String getOidcUri() { - ConnectionString cs = ClusterFixture.getConnectionString(); - // remove username and password + ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_SINGLE")); + // remove any username and password return "mongodb+srv://" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; } + protected static String getOidcUri(final String username) { + ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_SINGLE")); + // set username + return "mongodb+srv://" + username + "@" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + } + + protected static String getOidcUriMulti(@Nullable final String username) { + ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_MULTI")); + // set username + String userPart = username == null ? "" : username + "@"; + return "mongodb+srv://" + userPart + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + } + private static String getAwsOidcUri() { return getOidcUri() + "&authMechanismProperties=PROVIDER_NAME:aws"; } + @NotNull + private static String oidcTokenDirectory() { + return getenv("OIDC_TOKEN_DIR"); + } + + private static String getAwsTokenFilePath() { + return getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE); + } + protected MongoClient createMongoClient(final MongoClientSettings settings) { return MongoClients.create(settings); } @@ -147,7 +175,7 @@ public void test2p1ValidCallbackInputs() { TestCallback onRequest = createCallback(); // #. Verify that the request callback was called with the appropriate // inputs, including the timeout parameter if possible. - OidcRequestCallback onRequest2 = (context) -> { + OidcCallback onRequest2 = (context) -> { assertEquals(expectedSeconds, context.getTimeout()); return onRequest.onRequest(context); }; @@ -162,7 +190,7 @@ public void test2p1ValidCallbackInputs() { @Test public void test2p2RequestCallbackReturnsNull() { //noinspection ConstantConditions - OidcRequestCallback onRequest = (context) -> null; + OidcCallback onRequest = (context) -> null; MongoClientSettings settings = this.createSettings(getOidcUri(), onRequest, null); performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); } @@ -171,9 +199,9 @@ public void test2p2RequestCallbackReturnsNull() { public void test2p3CallbackReturnsMissingData() { // #. Create a client with a request callback that returns data not // conforming to the OIDCRequestTokenResult with missing field(s). - OidcRequestCallback onRequest = (context) -> { + OidcCallback onRequest = (context) -> { //noinspection ConstantConditions - return new RequestCallbackResult(null); + return new OidcCallbackResult(null, Duration.ZERO); }; // we ensure that the error is propagated MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); @@ -205,8 +233,8 @@ public void test2p4InvalidClientConfigurationWithCallback() { public void test3p1AuthFailsWithCachedToken() throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException { TestCallback onRequestWrapped = createCallback(); CompletableFuture poisonToken = new CompletableFuture<>(); - OidcRequestCallback onRequest = (context) -> { - RequestCallbackResult result = onRequestWrapped.onRequest(context); + OidcCallback onRequest = (context) -> { + OidcCallbackResult result = onRequestWrapped.onRequest(context); String accessToken = result.getAccessToken(); if (!poisonToken.isDone()) { poisonToken.complete(accessToken); @@ -240,7 +268,7 @@ public void test3p1AuthFailsWithCachedToken() throws ExecutionException, Interru @Test public void test3p2AuthFailsWithoutCachedToken() { MongoClientSettings clientSettings = createSettings(getOidcUri(), - (x) -> new RequestCallbackResult("invalid_token"), null); + (x) -> new OidcCallbackResult("invalid_token", Duration.ZERO), null); try (MongoClient mongoClient = createMongoClient(clientSettings)) { try { performFind(mongoClient); @@ -252,7 +280,6 @@ public void test3p2AuthFailsWithoutCachedToken() { } } - @Test public void test4p1Reauthentication() { TestCallback onRequest = createCallback(); @@ -265,19 +292,328 @@ public void test4p1Reauthentication() { assertEquals(2, onRequest.invocations.get()); } + // Tests for human authentication ("testh", to preserve ordering) + + @Test + public void testh1p1SinglePrincipalImplicitUsername() { + // #. Create default OIDC client with authMechanism=MONGODB-OIDC. + String oidcUri = getOidcUri(); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + assertEquals(1, callback.invocations.get()); + } + + @Test + public void testh1p2SinglePrincipalExplicitUsername() { + // #. Create a client with MONGODB_URI_SINGLE, a username of test_user1, + // authMechanism=MONGODB-OIDC, and the OIDC human callback. + String oidcUri = getOidcUri("test_user1"); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p3MultiplePrincipalUser1() { + // #. Create a client with MONGODB_URI_MULTI, a username of test_user1, + // authMechanism=MONGODB-OIDC, and the OIDC human callback. + String oidcUri = getOidcUriMulti("test_user1"); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p4MultiplePrincipalUser2() { + //- Create a human callback that reads in the generated ``test_user2`` token file. + //- Create a client with ``MONGODB_URI_MULTI``, a username of ``test_user2``, + // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. + String oidcUri = getOidcUriMulti("test_user2"); + TestCallback callback = createHumanCallback() + .setPathSupplier(() -> tokenQueue("test_user2").remove()); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p5MultiplePrincipalNoUser() { + //- Create a client with ``MONGODB_URI_MULTI``, no username, + // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. + String oidcUri = getOidcUriMulti(null); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings, MongoCommandException.class, "Authentication failed"); + } + + @Test + public void testh1p6AllowedHostsBlocked() { + //- Create a default OIDC client, with an ``ALLOWED_HOSTS`` that is an empty list. + //- Assert that a ``find`` operation fails with a client-side error. + MongoClientSettings settings1 = createSettings( + getOidcUri(), + createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Collections.emptyList()); + performFind(settings1, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + + //- Create a client that uses the URL + // ``mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com``, a + // human callback, and an ``ALLOWED_HOSTS`` that contains ``["example.com"]``. + //- Assert that a ``find`` operation fails with a client-side error. + MongoClientSettings settings2 = createSettings( + getOidcUri() + "&ignored=example.com", + createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Arrays.asList("example.com")); + performFind(settings2, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + } + + // Not a prose test + @Test + public void testAllowedHostsDisallowedInConnectionString() { + String string = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:localhost"; + assertCause(IllegalArgumentException.class, + "connection string contains disallowed mechanism properties", + () -> new ConnectionString(string)); + } + + @Test + public void testh2p1ValidCallbackInputs() { + TestCallback onRequest = createHumanCallback(); + OidcCallback onRequest2 = (context) -> { + assertTrue(context.getIdpInfo().getClientId().startsWith("0oad")); + assertTrue(context.getIdpInfo().getIssuer().endsWith("mock-identity-config-oidc")); + assertEquals(Arrays.asList("fizz", "buzz"), context.getIdpInfo().getRequestScopes()); + assertEquals(Duration.ofMinutes(5), context.getTimeout()); + return onRequest.onRequest(context); + }; + MongoClientSettings clientSettings = createHumanSettings(getOidcUri(), onRequest2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // Ensure that callback was called + assertEquals(1, onRequest.getInvocations()); + } + } + + @Test + public void testh2p2HumanCallbackReturnsMissingData() { + //noinspection ConstantConditions + OidcCallback onRequestNull = (context) -> null; + performFind(createHumanSettings(getOidcUri(), onRequestNull, null), + MongoConfigurationException.class, + "Result of callback must not be null"); + + //noinspection ConstantConditions + OidcCallback onRequest = (context) -> new OidcCallbackResult(null, Duration.ZERO); + performFind(createHumanSettings(getOidcUri(), onRequest, null), + IllegalArgumentException.class, + "accessToken can not be null"); + + // additionally, check validation for refresh in machine workflow: + OidcCallback onRequestMachineRefresh = (context) -> new OidcCallbackResult("access", Duration.ZERO, "exists"); + performFind(createSettings(getOidcUri(), onRequestMachineRefresh, null), + MongoConfigurationException.class, + "Refresh token must only be provided in human workflow"); + } + + @Test + public void testh3p1UsesSpecAuthIfCachedToken() { + failCommandAndCloseConnection("find", 1); + MongoClientSettings settings = createHumanSettings(getOidcUri(), createHumanCallback(), null); + + try (MongoClient mongoClient = createMongoClient(settings)) { + assertCause(MongoSocketException.class, + "Prematurely reached end of stream", + () -> performFind(mongoClient)); + failCommand(20, 99, "saslStart"); + + performFind(mongoClient); + } + } + + @Test + public void testh3p2NoSpecAuthIfNoCachedToken() { + failCommand(20, 99, "saslStart"); + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + performFind(createHumanSettings(getOidcUri(), createHumanCallback(), commandListener), + MongoCommandException.class, + "Command failed with error 20"); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "saslStart started", + "saslStart failed" + ), listener.getEventStrings()); + listener.clear(); + } + + @Test + public void testh4p1Succeeds() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + TestCallback callback = createHumanCallback() + .setEventListener(listener); + MongoClientSettings settings = createHumanSettings(getOidcUri(), callback, commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + performFind(mongoClient); + listener.clear(); + assertEquals(1, callback.getInvocations()); + + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + // first find fails: + "find started", + "find failed", + "onRequest invoked (Refresh Token: present - IdpInfo: present)", + "read access token: test_user1", + "saslStart started", + "saslStart succeeded", + // second find succeeds: + "find started", + "find succeeded" + ), listener.getEventStrings()); + assertEquals(2, callback.getInvocations()); + } + } + + @Test + public void testh4p2SucceedsNoRefresh() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + TestCallback callback = createHumanCallback().setEventListener(listener); + MongoClientSettings settings = createHumanSettings(getOidcUri(), callback, commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + + performFind(mongoClient); + listener.clear(); + assertEquals(1, callback.getInvocations()); + + failCommand(391, 1, "find"); + performFind(mongoClient); + } + } + + + // TODO-OIDC awaiting spec updates, add 4.3 and 4.4 + + // Not a prose test + @Test + public void testErrorClearsCache() { + // #. Create a new client with a valid request callback that + // gives credentials that expire within 5 minutes and + // a refresh callback that gives invalid credentials. + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createHumanCallback() + .setRefreshToken("refresh") + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createHumanSettings(getOidcUri(), onRequest, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Ensure that a find operation adds a new entry to the cache. + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + // no speculative auth. Send principal request: + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1", + // the refresh token from the callback is cached here + // send jwt: + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that a subsequent find operation results in a 391 error. + failCommand(391, 1, "find"); + // ensure that the operation entirely fails, after attempting both potential fallback callbacks + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", // reauth 391; current access token is invalid + // fall back to refresh token, from prior find + "onRequest invoked (Refresh Token: present - IdpInfo: present)", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", // it is expired, fails immediately + // fall back to principal request, and non-refresh callback: + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1_expires", + "saslContinue started", + "saslContinue failed" // also fails due to 391 + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that the cache value cleared. + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + "find started", + "find failed", + // falling back to principal request, onRequest callback. + // this implies that the cache has been cleared during the + // preceding find operation. + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1_1", + "saslContinue started", + "saslContinue succeeded", + // auth has finished + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + } + } + public MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcRequestCallback onRequest) { + @Nullable final OidcCallback onRequest) { return createSettings(connectionString, onRequest, null); } private MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcCallback callback, @Nullable final CommandListener commandListener) { + return createSettings(connectionString, callback, commandListener, OIDC_CALLBACK_KEY); + } + + private MongoClientSettings createHumanSettings( + final String connectionString, + @Nullable final OidcCallback callback, + @Nullable final CommandListener commandListener) { + return createSettings(connectionString, callback, commandListener, OIDC_HUMAN_CALLBACK_KEY); + } + + @NotNull + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcCallback onRequest, + @Nullable final CommandListener commandListener, + final String oidcCallbackKey) { ConnectionString cs = new ConnectionString(connectionString); MongoCredential credential = cs.getCredential() - .withMechanismProperty(OIDC_CALLBACK_KEY, onRequest); + .withMechanismProperty(oidcCallbackKey, onRequest); MongoClientSettings.Builder builder = MongoClientSettings.builder() .applicationName(appName) .applyConnectionString(cs) @@ -289,6 +625,26 @@ private MongoClientSettings createSettings( return builder.build(); } + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcCallback onRequest, + @Nullable final CommandListener commandListener, + final String oidcCallbackKey, + @Nullable final List allowedHosts) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = cs.getCredential() + .withMechanismProperty(oidcCallbackKey, onRequest) + .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + private void performFind(final MongoClientSettings settings) { try (MongoClient mongoClient = createMongoClient(settings)) { performFind(mongoClient); @@ -333,7 +689,8 @@ private static void assertCause( } protected void delayNextFind() { - try (MongoClient client = createMongoClient(createSettings(getAwsOidcUri(), null, null))) { + try (MongoClient client = createMongoClient(createSettings( + getAwsOidcUri(), null, null))) { BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) .append("data", new BsonDocument() @@ -359,11 +716,27 @@ protected void failCommand(final int code, final int times, final String... comm } } - public static class TestCallback implements OidcRequestCallback { + private void failCommandAndCloseConnection(final String command, final int times) { + try (MongoClient mongoClient = createMongoClient(createSettings( + getAwsOidcUri(), null, null))) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("closeConnection", new BsonBoolean(true)) + .append("failCommands", new BsonArray(Arrays.asList(new BsonString(command)))) + ); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + } + + public static class TestCallback implements OidcCallback { private final AtomicInteger invocations = new AtomicInteger(); @Nullable private final Integer delayInMilliseconds; @Nullable + private final String refreshToken; + @Nullable private final AtomicInteger concurrentTracker; @Nullable private final TestListener testListener; @@ -371,14 +744,16 @@ public static class TestCallback implements OidcRequestCallback { private final Supplier pathSupplier; public TestCallback() { - this(null, new AtomicInteger(), null, null); + this(null, null, new AtomicInteger(), null, null); } public TestCallback( + @Nullable final String refreshToken, @Nullable final Integer delayInMilliseconds, @Nullable final AtomicInteger concurrentTracker, @Nullable final TestListener testListener, @Nullable final Supplier pathSupplier) { + this.refreshToken = refreshToken; this.delayInMilliseconds = delayInMilliseconds; this.concurrentTracker = concurrentTracker; this.testListener = testListener; @@ -390,15 +765,18 @@ public int getInvocations() { } @Override - public RequestCallbackResult onRequest(final OidcRequestContext context) { + public OidcCallbackResult onRequest(final OidcCallbackContext context) { if (testListener != null) { - testListener.add("onRequest invoked"); + testListener.add("onRequest invoked (" + + "Refresh Token: " + (context.getRefreshToken() == null ? "none" : "present") + + " - IdpInfo: " + (context.getIdpInfo() == null ? "none" : "present") + + ")"); } return callback(); } @NotNull - private RequestCallbackResult callback() { + private OidcCallbackResult callback() { if (concurrentTracker != null) { if (concurrentTracker.get() > 0) { throw new RuntimeException("Callbacks should not be invoked by multiple threads."); @@ -408,7 +786,7 @@ private RequestCallbackResult callback() { try { invocations.incrementAndGet(); Path path = Paths.get(pathSupplier == null - ? getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE) + ? getAwsTokenFilePath() : pathSupplier.get()); String accessToken; try { @@ -420,7 +798,7 @@ private RequestCallbackResult callback() { if (testListener != null) { testListener.add("read access token: " + path.getFileName()); } - return new RequestCallbackResult(accessToken); + return new OidcCallbackResult(accessToken, Duration.ZERO, refreshToken); } finally { if (concurrentTracker != null) { concurrentTracker.decrementAndGet(); @@ -436,6 +814,7 @@ private void simulateDelay() throws InterruptedException { public TestCallback setDelayMs(final int milliseconds) { return new TestCallback( + this.refreshToken, milliseconds, this.concurrentTracker, this.testListener, @@ -444,6 +823,7 @@ public TestCallback setDelayMs(final int milliseconds) { public TestCallback setConcurrentTracker(final AtomicInteger c) { return new TestCallback( + this.refreshToken, this.delayInMilliseconds, c, this.testListener, @@ -452,6 +832,7 @@ public TestCallback setConcurrentTracker(final AtomicInteger c) { public TestCallback setEventListener(final TestListener testListener) { return new TestCallback( + this.refreshToken, this.delayInMilliseconds, this.concurrentTracker, testListener, @@ -460,14 +841,38 @@ public TestCallback setEventListener(final TestListener testListener) { public TestCallback setPathSupplier(final Supplier pathSupplier) { return new TestCallback( + this.refreshToken, this.delayInMilliseconds, this.concurrentTracker, this.testListener, pathSupplier); } + public TestCallback setRefreshToken(final String token) { + return new TestCallback( + token, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + } + + @NotNull + private ConcurrentLinkedQueue tokenQueue(final String... queue) { + String tokenPath = oidcTokenDirectory(); + return java.util.stream.Stream + .of(queue) + .map(v -> tokenPath + v) + .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); } public TestCallback createCallback() { return new TestCallback(); } + + public TestCallback createHumanCallback() { + return new TestCallback() + .setPathSupplier(() -> oidcTokenDirectory() + "test_user1") + .setRefreshToken("refreshToken"); + } }