diff --git a/pom.xml b/pom.xml
index e90411a01..06954ae37 100644
--- a/pom.xml
+++ b/pom.xml
@@ -197,6 +197,12 @@
0.1.0-SNAPSHOT
test
+
+ io.github.cdimascio
+ dotenv-java
+ 2.2.0
+ test
+
@@ -210,6 +216,11 @@
redis-authx-entraid
test
+
+ io.github.cdimascio
+ dotenv-java
+ test
+
diff --git a/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java b/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java
index ca86891c1..4753010ca 100644
--- a/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java
+++ b/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java
@@ -10,7 +10,7 @@
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;
-public class TokenBasedRedisCredentialsProvider implements StreamingCredentialsProvider {
+public class TokenBasedRedisCredentialsProvider implements StreamingCredentialsProvider, AutoCloseable {
private final TokenManager tokenManager;
@@ -94,7 +94,8 @@ public Flux credentials() {
* This method stops the TokenManager and completes the credentials sink, ensuring that all resources are properly released.
* It should be called when the credentials provider is no longer needed.
*/
- public void shutdown() {
+ @Override
+ public void close() {
credentialsSink.tryEmitComplete();
tokenManager.stop();
}
diff --git a/src/test/java/io/lettuce/authx/EntraIdIntegrationTests.java b/src/test/java/io/lettuce/authx/EntraIdIntegrationTests.java
new file mode 100644
index 000000000..84100f577
--- /dev/null
+++ b/src/test/java/io/lettuce/authx/EntraIdIntegrationTests.java
@@ -0,0 +1,177 @@
+package io.lettuce.authx;
+
+import io.lettuce.core.ClientOptions;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisFuture;
+import io.lettuce.core.RedisURI;
+import io.lettuce.core.SocketOptions;
+import io.lettuce.core.TimeoutOptions;
+import io.lettuce.core.TransactionResult;
+import io.lettuce.core.api.StatefulRedisConnection;
+import io.lettuce.core.api.async.RedisAsyncCommands;
+import io.lettuce.core.cluster.ClusterClientOptions;
+import io.lettuce.core.cluster.RedisClusterClient;
+import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import redis.clients.authentication.core.TokenAuthConfig;
+import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;
+
+import java.time.Duration;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.fail;
+
+public class EntraIdIntegrationTests {
+
+ private static EntraIdTestContext testCtx = EntraIdTestContext.DEFAULT;;
+
+ @BeforeAll
+ public static void setup() {
+ Assumptions.assumeTrue(testCtx.host() != null && !testCtx.host().isEmpty(),
+ "Skipping EntraID tests. Redis host with enabled EntraId not provided!");
+ }
+
+ // T.1.1
+ // Verify authentication using Azure AD with service principals using Redis Standalone client
+ @Test
+ public void standaloneWithSecret_azureServicePrincipalIntegrationTest() throws ExecutionException, InterruptedException {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder().clientId(testCtx.getClientId())
+ .secret(testCtx.getClientSecret()).authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build();
+
+ // Configure timeout options to assure fast test failover
+ ClientOptions clientOptions = ClientOptions.builder()
+ .socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(1)).build())
+ .timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1)))
+ .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build();
+
+ try (TokenBasedRedisCredentialsProvider credentialsProvider = new TokenBasedRedisCredentialsProvider(tokenAuthConfig)) {
+ RedisURI uri = RedisURI.builder().withHost(testCtx.host()).withPort(testCtx.port())
+ .withAuthentication(credentialsProvider).build();
+
+ try (RedisClient client = RedisClient.create(uri)) {
+ client.setOptions(clientOptions);
+
+ try (StatefulRedisConnection connection = client.connect()) {
+ assertThat(connection.sync().aclWhoami()).isEqualTo(testCtx.getSpOID());
+ assertThat(connection.async().aclWhoami().get()).isEqualTo(testCtx.getSpOID());
+ assertThat(connection.reactive().aclWhoami().block()).isEqualTo(testCtx.getSpOID());
+ }
+ }
+ }
+ }
+
+ // T.1.1
+ // Verify authentication using Azure AD with service principals using Redis Cluster Client
+ @Test
+ public void clusterWithSecret_azureServicePrincipalIntegrationTest() throws ExecutionException, InterruptedException {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder().clientId(testCtx.getClientId())
+ .secret(testCtx.getClientSecret()).authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build();
+
+ // Configure timeout options to assure fast test failover
+ ClusterClientOptions clientOptions = ClusterClientOptions.builder()
+ .socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(1)).build())
+ .timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1)))
+ .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build();
+
+ try (TokenBasedRedisCredentialsProvider credentialsProvider = new TokenBasedRedisCredentialsProvider(tokenAuthConfig)) {
+ RedisURI uri = RedisURI.builder().withHost(testCtx.clusterHost().get(0)).withPort(testCtx.clusterPort())
+ .withAuthentication(credentialsProvider).build();
+
+ try (RedisClusterClient client = RedisClusterClient.create(uri)) {
+ client.setOptions(clientOptions);
+
+ try (StatefulRedisClusterConnection connection = client.connect()) {
+ assertThat(connection.sync().aclWhoami()).isEqualTo(testCtx.getSpOID());
+ assertThat(connection.async().aclWhoami().get()).isEqualTo(testCtx.getSpOID());
+ assertThat(connection.reactive().aclWhoami().block()).isEqualTo(testCtx.getSpOID());
+
+ connection.getPartitions().forEach((partition) -> {
+ try (StatefulRedisConnection, ?> nodeConnection = connection.getConnection(partition.getNodeId())) {
+ assertThat(nodeConnection.sync().aclWhoami()).isEqualTo(testCtx.getSpOID());
+ }
+ });
+ }
+ }
+ }
+ }
+
+ // T.2.2
+ // Test that the Redis client is not blocked/interrupted during token renewal.
+ @Test
+ public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException {
+ TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder().clientId(testCtx.getClientId())
+ .secret(testCtx.getClientSecret()).authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
+ .expirationRefreshRatio(0.000001F).build();
+
+ // Configure timeout options to assure fast test failover
+ ClientOptions clientOptions = ClientOptions.builder()
+ .socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(1)).build())
+ .timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1)))
+ .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build();
+
+ try (TokenBasedRedisCredentialsProvider credentialsProvider = new TokenBasedRedisCredentialsProvider(tokenAuthConfig)) {
+ RedisURI uri = RedisURI.builder().withHost(testCtx.host()).withPort(testCtx.port())
+ .withAuthentication(credentialsProvider).build();
+
+ try (RedisClient client = RedisClient.create(uri)) {
+ client.setOptions(clientOptions);
+
+ try (StatefulRedisConnection connection = client.connect()) {
+
+ // Counter to track the number of command cycles
+ AtomicInteger commandCycleCount = new AtomicInteger(0);
+
+ // Start a thread to continuously send Redis commands
+ Thread commandThread = new Thread(() -> {
+ try {
+ RedisAsyncCommands async = client.connect().async();
+ for (int i = 1; i <= 10; i++) {
+ // Start a transaction with SET and INCRBY commands
+ RedisFuture multi = async.multi();
+ RedisFuture set = async.set("key", "1");
+ RedisFuture incrby = async.incrby("key", 1);
+ RedisFuture exec = async.exec();
+ TransactionResult results = exec.get(1, TimeUnit.SECONDS);
+
+ // Increment the command cycle count after each execution
+ commandCycleCount.incrementAndGet();
+
+ // Verify the results from EXEC
+ assertThat(results).hasSize(2); // We expect 2 responses: SET and INCRBY
+
+ // Check the response from each command in the transaction
+ assertThat((String) results.get(0)).isEqualTo("OK"); // SET "key" = "1"
+ assertThat((Long) results.get(1)).isEqualTo(2L); // INCRBY "key" by 1, expected result is 2
+ }
+ } catch (Exception e) {
+ fail("Command execution failed during token refresh", e);
+ }
+ });
+
+ commandThread.start();
+
+ // Count token renewals directly within the main thread
+ AtomicInteger renewalCount = new AtomicInteger(0);
+ CountDownLatch latch = new CountDownLatch(10); // Wait for at least 10 token renewals
+
+ credentialsProvider.credentials().subscribe(cred -> {
+ latch.countDown(); // Signal each renewal as it's received
+ });
+
+ latch.await(1, TimeUnit.SECONDS); // Wait to reach 10 renewals
+ commandThread.join(); // Wait for the command thread to finish
+
+ // Verify that at least 10 command cycles were executed during the test
+ assertThat(commandCycleCount.get()).isGreaterThanOrEqualTo(10);
+ }
+ }
+ }
+ }
+
+}
diff --git a/src/test/java/io/lettuce/authx/EntraIdTestContext.java b/src/test/java/io/lettuce/authx/EntraIdTestContext.java
new file mode 100644
index 000000000..7abfac0fe
--- /dev/null
+++ b/src/test/java/io/lettuce/authx/EntraIdTestContext.java
@@ -0,0 +1,111 @@
+package io.lettuce.authx;
+
+import io.github.cdimascio.dotenv.Dotenv;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class EntraIdTestContext {
+
+ private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID";
+
+ private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET";
+
+ private static final String AZURE_SP_OID = "AZURE_SP_OID";
+
+ private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY";
+
+ private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES";
+
+ private static final String REDIS_AZURE_HOST = "REDIS_AZURE_HOST";
+
+ private static final String REDIS_AZURE_PORT = "REDIS_AZURE_PORT";
+
+ private static final String REDIS_AZURE_CLUSTER_HOST = "REDIS_AZURE_CLUSTER_HOST";
+
+ private static final String REDIS_AZURE_CLUSTER_PORT = "REDIS_AZURE_CLUSTER_PORT";
+
+ private static final String REDIS_AZURE_DB = "REDIS_AZURE_DB";
+
+ private final String clientId;
+
+ private final String authority;
+
+ private final String clientSecret;
+
+ private final String spOID;
+
+ private final Set redisScopes;
+
+ private final String redisHost;
+
+ private final int redisPort;
+
+ private final List redisClusterHost;
+
+ private final int redisClusterPort;
+
+ private static Dotenv dotenv;
+ static {
+ dotenv = Dotenv.configure().directory("src/test/resources").filename(".env.entraid").load();
+ }
+
+ public static final EntraIdTestContext DEFAULT = new EntraIdTestContext();
+
+ private EntraIdTestContext() {
+ // Using Dotenv directly here
+ clientId = dotenv.get(AZURE_CLIENT_ID, "");
+ clientSecret = dotenv.get(AZURE_CLIENT_SECRET, "");
+ spOID = dotenv.get(AZURE_SP_OID, "");
+ authority = dotenv.get(AZURE_AUTHORITY, "https://login.microsoftonline.com/your-tenant-id");
+ redisHost = dotenv.get(REDIS_AZURE_HOST);
+ redisPort = Integer.parseInt(dotenv.get(REDIS_AZURE_PORT, "6379"));
+ redisClusterHost = Arrays.asList(dotenv.get(REDIS_AZURE_CLUSTER_HOST, "").split(","));
+ redisClusterPort = Integer.parseInt(dotenv.get(REDIS_AZURE_CLUSTER_PORT, "6379"));
+ String redisScopesEnv = dotenv.get(AZURE_REDIS_SCOPES, "https://redis.azure.com/.default");
+ if (redisScopesEnv != null && !redisScopesEnv.isEmpty()) {
+ this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";")));
+ } else {
+ this.redisScopes = new HashSet<>();
+ }
+ }
+
+ public String host() {
+ return redisHost;
+ }
+
+ public int port() {
+ return redisPort;
+ }
+
+ public List clusterHost() {
+ return redisClusterHost;
+ }
+
+ public int clusterPort() {
+ return redisClusterPort;
+ }
+
+ public String getClientId() {
+ return clientId;
+ }
+
+ public String getSpOID() {
+ return spOID;
+ }
+
+ public String getAuthority() {
+ return authority;
+ }
+
+ public String getClientSecret() {
+ return clientSecret;
+ }
+
+ public Set getRedisScopes() {
+ return redisScopes;
+ }
+
+}
diff --git a/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java b/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java
index d6ff84648..da78dcfdd 100644
--- a/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java
+++ b/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java
@@ -102,7 +102,7 @@ public void shouldCompleteAllSubscribersOnStop() {
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
- credentialsProvider.shutdown();
+ credentialsProvider.close();
}).start();
StepVerifier.create(credentialsFlux1)
diff --git a/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java b/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java
index 618bb1a14..97e88218c 100644
--- a/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java
+++ b/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java
@@ -155,7 +155,7 @@ void tokenBasedCredentialProvider(RedisClient client) {
// verify that the connection is re-authenticated with the new user credentials
assertThat(connection.sync().aclWhoami()).isEqualTo("steave");
- credentialsProvider.shutdown();
+ credentialsProvider.close();
connection.close();
client.removeListener(listener);
client.setOptions(
diff --git a/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java b/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java
index ef1b51419..1c317b97a 100644
--- a/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java
+++ b/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java
@@ -129,9 +129,8 @@ public static void main(String[] args) throws Exception {
// Shutdown Redis client and close connections
redisClusterClient.shutdown();
} finally {
- credentialsUser1.shutdown();
- credentialsUser2.shutdown();
-
+ credentialsUser1.close();
+ credentialsUser2.close();
}
}
diff --git a/src/test/resources/.env.entraid b/src/test/resources/.env.entraid
new file mode 100644
index 000000000..016449e92
--- /dev/null
+++ b/src/test/resources/.env.entraid
@@ -0,0 +1,11 @@
+AZURE_SP_OID=
+AZURE_CLIENT_ID=
+AZURE_CLIENT_SECRET=
+AZURE_REDIS_SCOPES=https://redis.azure.com/.default
+AZURE_AUTHORITY=https://login.microsoftonline.com/
+# Redis standalone db with Azure enabled authentication
+REDIS_AZURE_HOST=
+REDIS_AZURE_PORT=6379
+# Redis cluster db with Azure enabled authentication & osscluster API enabled
+REDIS_AZURE_CLUSTER_HOST=
+REDIS_AZURE_CLUSTER_PORT=6379