Skip to content

Commit

Permalink
Introduce credentials provider (#3224)
Browse files Browse the repository at this point in the history
References:

1. #1602 and related PRs. Current PR is probably better than handling in JedisFactory 
2. redis/redis-py#2261 - main reason of this PR 
3. redis/lettuce#1774 
4. #632 

---

* Introduce credentials provider

* use volatile

* Test in Sentineled mode

* Support CharSequence in DefaultRedisCredentials

* Added doc for prepare() and cleanUp()

* Test the provider interface

* Added example

* Removed deprecations
  • Loading branch information
sazzad16 authored Feb 14, 2023
1 parent 1d898f1 commit d4644da
Show file tree
Hide file tree
Showing 15 changed files with 439 additions and 92 deletions.
7 changes: 3 additions & 4 deletions src/main/java/redis/clients/jedis/CommandArguments.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ public ProtocolCommand getCommand() {
}

public CommandArguments add(Object arg) {
if (arg instanceof Rawable) {
if (arg == null) {
throw new IllegalArgumentException("null is not a valid argument.");
} else if (arg instanceof Rawable) {
args.add((Rawable) arg);
} else if (arg instanceof byte[]) {
args.add(RawableFactory.from((byte[]) arg));
Expand All @@ -37,9 +39,6 @@ public CommandArguments add(Object arg) {
} else if (arg instanceof Boolean) {
args.add(RawableFactory.from(Integer.toString((Boolean) arg ? 1 : 0)));
} else {
if (arg == null) {
throw new IllegalArgumentException("null is not a valid argument.");
}
args.add(RawableFactory.from(String.valueOf(arg)));
}
return this;
Expand Down
51 changes: 35 additions & 16 deletions src/main/java/redis/clients/jedis/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import java.io.IOException;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;

import redis.clients.jedis.args.Rawable;
import redis.clients.jedis.commands.ProtocolCommand;
Expand Down Expand Up @@ -336,15 +340,16 @@ public List<Object> getMany(final int count) {
private void initializeFromClientConfig(JedisClientConfig config) {
try {
connect();
String password = config.getPassword();
if (password != null) {
String user = config.getUser();
if (user != null) {
auth(user, password);
} else {
auth(password);
}

Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
if (credentialsProvider instanceof RedisCredentialsProvider) {
((RedisCredentialsProvider) credentialsProvider).prepare();
auth(credentialsProvider);
((RedisCredentialsProvider) credentialsProvider).cleanUp();
} else {
auth(credentialsProvider);
}

int dbIndex = config.getDatabase();
if (dbIndex > 0) {
select(dbIndex);
Expand All @@ -354,27 +359,41 @@ private void initializeFromClientConfig(JedisClientConfig config) {
// TODO: need to figure out something without encoding
clientSetname(clientName);
}

} catch (JedisException je) {
try {
if (isConnected()) {
quit();
}
disconnect();
} catch (Exception e) {
//
// the first exception 'je' will be thrown
}
throw je;
}
}

private String auth(final String password) {
sendCommand(Protocol.Command.AUTH, password);
return getStatusCodeReply();
}
private void auth(final Supplier<RedisCredentials> credentialsProvider) {
RedisCredentials credentials = credentialsProvider.get();
if (credentials == null || credentials.getPassword() == null) return;

private String auth(final String user, final String password) {
sendCommand(Protocol.Command.AUTH, user, password);
return getStatusCodeReply();
// Source: https://stackoverflow.com/a/9670279/4021802
ByteBuffer passBuf = Protocol.CHARSET.encode(CharBuffer.wrap(credentials.getPassword()));
byte[] rawPass = Arrays.copyOfRange(passBuf.array(), passBuf.position(), passBuf.limit());
Arrays.fill(passBuf.array(), (byte) 0); // clear sensitive data

if (credentials.getUser() != null) {
sendCommand(Protocol.Command.AUTH, SafeEncoder.encode(credentials.getUser()), rawPass);
} else {
sendCommand(Protocol.Command.AUTH, rawPass);
}

Arrays.fill(rawPass, (byte) 0); // clear sensitive data

// clearing 'char[] credentials.getPassword()' should be
// handled in RedisCredentialsProvider.cleanUp()

getStatusCodeReply(); // OK
}

public String select(final int index) {
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/redis/clients/jedis/ConnectionFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final Jedi
this.jedisSocketFactory = jedisSocketFactory;
}

/**
* @deprecated Use {@link RedisCredentialsProvider} through
* {@link JedisClientConfig#getCredentialsProvider()}.
*/
@Deprecated
public void setPassword(final String password) {
this.clientConfig.updatePassword(password);
}
Expand Down
64 changes: 43 additions & 21 deletions src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package redis.clients.jedis;

import java.util.Objects;
import java.util.function.Supplier;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
Expand All @@ -11,8 +11,7 @@ public final class DefaultJedisClientConfig implements JedisClientConfig {
private final int socketTimeoutMillis;
private final int blockingSocketTimeoutMillis;

private final String user;
private volatile String password;
private volatile Supplier<RedisCredentials> credentialsProvider;
private final int database;
private final String clientName;

Expand All @@ -24,14 +23,13 @@ public final class DefaultJedisClientConfig implements JedisClientConfig {
private final HostAndPortMapper hostAndPortMapper;

private DefaultJedisClientConfig(int connectionTimeoutMillis, int soTimeoutMillis,
int blockingSocketTimeoutMillis, String user, String password, int database, String clientName,
boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
int blockingSocketTimeoutMillis, Supplier<RedisCredentials> credentialsProvider, int database,
String clientName, boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper) {
this.connectionTimeoutMillis = connectionTimeoutMillis;
this.socketTimeoutMillis = soTimeoutMillis;
this.blockingSocketTimeoutMillis = blockingSocketTimeoutMillis;
this.user = user;
this.password = password;
this.credentialsProvider = credentialsProvider;
this.database = database;
this.clientName = clientName;
this.ssl = ssl;
Expand All @@ -58,19 +56,25 @@ public int getBlockingSocketTimeoutMillis() {

@Override
public String getUser() {
return user;
return credentialsProvider.get().getUser();
}

@Override
public String getPassword() {
return password;
char[] password = credentialsProvider.get().getPassword();
return password == null ? null : new String(password);
}

@Override
public Supplier<RedisCredentials> getCredentialsProvider() {
return credentialsProvider;
}

@Override
@Deprecated
public synchronized void updatePassword(String password) {
if (!Objects.equals(this.password, password)) {
this.password = password;
}
((DefaultRedisCredentialsProvider) this.credentialsProvider)
.setCredentials(new DefaultRedisCredentials(getUser(), password));
}

@Override
Expand Down Expand Up @@ -120,6 +124,7 @@ public static class Builder {

private String user = null;
private String password = null;
private Supplier<RedisCredentials> credentialsProvider;
private int database = Protocol.DEFAULT_DATABASE;
private String clientName = null;

Expand All @@ -134,9 +139,14 @@ private Builder() {
}

public DefaultJedisClientConfig build() {
if (credentialsProvider == null) {
credentialsProvider = new DefaultRedisCredentialsProvider(
new DefaultRedisCredentials(user, password));
}

return new DefaultJedisClientConfig(connectionTimeoutMillis, socketTimeoutMillis,
blockingSocketTimeoutMillis, user, password, database, clientName, ssl, sslSocketFactory,
sslParameters, hostnameVerifier, hostAndPortMapper);
blockingSocketTimeoutMillis, credentialsProvider, database, clientName, ssl,
sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper);
}

public Builder timeoutMillis(int timeoutMillis) {
Expand Down Expand Up @@ -170,6 +180,16 @@ public Builder password(String password) {
return this;
}

public Builder credentials(RedisCredentials credentials) {
this.credentialsProvider = new DefaultRedisCredentialsProvider(credentials);
return this;
}

public Builder credentialsProvider(Supplier<RedisCredentials> credentials) {
this.credentialsProvider = credentials;
return this;
}

public Builder database(int database) {
this.database = database;
return this;
Expand Down Expand Up @@ -210,16 +230,18 @@ public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int s
int blockingSocketTimeoutMillis, String user, String password, int database, String clientName,
boolean ssl, SSLSocketFactory sslSocketFactory, SSLParameters sslParameters,
HostnameVerifier hostnameVerifier, HostAndPortMapper hostAndPortMapper) {
return new DefaultJedisClientConfig(connectionTimeoutMillis, soTimeoutMillis,
blockingSocketTimeoutMillis, user, password, database, clientName, ssl,
sslSocketFactory, sslParameters, hostnameVerifier, hostAndPortMapper);
return new DefaultJedisClientConfig(
connectionTimeoutMillis, soTimeoutMillis, blockingSocketTimeoutMillis,
new DefaultRedisCredentialsProvider(new DefaultRedisCredentials(user, password)),
database, clientName, ssl, sslSocketFactory, sslParameters,
hostnameVerifier, hostAndPortMapper);
}

public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) {
return new DefaultJedisClientConfig(copy.getConnectionTimeoutMillis(),
copy.getSocketTimeoutMillis(), copy.getBlockingSocketTimeoutMillis(), copy.getUser(),
copy.getPassword(), copy.getDatabase(), copy.getClientName(), copy.isSsl(),
copy.getSslSocketFactory(), copy.getSslParameters(), copy.getHostnameVerifier(),
copy.getHostAndPortMapper());
copy.getSocketTimeoutMillis(), copy.getBlockingSocketTimeoutMillis(),
copy.getCredentialsProvider(), copy.getDatabase(), copy.getClientName(),
copy.isSsl(), copy.getSslSocketFactory(), copy.getSslParameters(),
copy.getHostnameVerifier(), copy.getHostAndPortMapper());
}
}
38 changes: 38 additions & 0 deletions src/main/java/redis/clients/jedis/DefaultRedisCredentials.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package redis.clients.jedis;

public final class DefaultRedisCredentials implements RedisCredentials {

private final String user;
private final char[] password;

public DefaultRedisCredentials(String user, char[] password) {
this.user = user;
this.password = password;
}

public DefaultRedisCredentials(String user, CharSequence password) {
this.user = user;
this.password = password == null ? null
: password instanceof String ? ((String) password).toCharArray()
: toCharArray(password);
}

@Override
public String getUser() {
return user;
}

@Override
public char[] getPassword() {
return password;
}

private static char[] toCharArray(CharSequence seq) {
final int len = seq.length();
char[] arr = new char[len];
for (int i = 0; i < len; i++) {
arr[i] = seq.charAt(i);
}
return arr;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package redis.clients.jedis;

public final class DefaultRedisCredentialsProvider implements RedisCredentialsProvider {

private volatile RedisCredentials credentials;

public DefaultRedisCredentialsProvider(RedisCredentials credentials) {
this.credentials = credentials;
}

public void setCredentials(RedisCredentials credentials) {
this.credentials = credentials;
}

@Override
public RedisCredentials get() {
return this.credentials;
}
}
7 changes: 7 additions & 0 deletions src/main/java/redis/clients/jedis/JedisClientConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package redis.clients.jedis;

import java.util.function.Supplier;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
Expand Down Expand Up @@ -39,9 +40,15 @@ default String getPassword() {
return null;
}

@Deprecated
default void updatePassword(String password) {
}

default Supplier<RedisCredentials> getCredentialsProvider() {
return new DefaultRedisCredentialsProvider(
new DefaultRedisCredentials(getUser(), getPassword()));
}

default int getDatabase() {
return Protocol.DEFAULT_DATABASE;
}
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/redis/clients/jedis/JedisFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ void setHostAndPort(final HostAndPort hostAndPort) {
((DefaultJedisSocketFactory) jedisSocketFactory).updateHostAndPort(hostAndPort);
}

/**
* @deprecated Use {@link RedisCredentialsProvider} through
* {@link JedisClientConfig#getCredentialsProvider()}.
*/
@Deprecated
public void setPassword(final String password) {
this.clientConfig.updatePassword(password);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/redis/clients/jedis/Protocol.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public final class Protocol {
private static final String NOPERM_PREFIX = "NOPERM";

private Protocol() {
// this prevent the class from instantiation
throw new InstantiationError("Must not instantiate this class");
}

public static void sendCommand(final RedisOutputStream os, CommandArguments args) {
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/redis/clients/jedis/RedisCredentials.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package redis.clients.jedis;

public interface RedisCredentials {

/**
* @return Redis ACL user
*/
default String getUser() {
return null;
}

default char[] getPassword() {
return null;
}
}
Loading

0 comments on commit d4644da

Please sign in to comment.