Skip to content

Commit

Permalink
Allow customization of C* cluster and session creation.
Browse files Browse the repository at this point in the history
- Expose C* cluster and session instances directly through Guice
- Allow overriding cluster creation through module constructor methods
- Uses a factory pattern for the session to allow overriding how the session is exposed to Guice
  • Loading branch information
llinder committed Feb 2, 2017
1 parent e731a09 commit e554bc8
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public CassandraMigrationService(CassandraModule.Config config) {
@Override
public void onStart(StartEvent event) throws Exception {
if (config.autoMigrate) {
Session session = event.getRegistry().get(CassandraService.class).getSession();
Session session = event.getRegistry().get(Session.class);
logger.info("Auto Migrating Cassandra");
MigrationRunner migrationRunner = new MigrationRunner();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package smartthings.ratpack.cassandra

import com.datastax.driver.core.Cluster
import com.datastax.driver.core.Host
import com.datastax.driver.core.Session
import org.cassandraunit.CassandraCQLUnit
import org.cassandraunit.dataset.cql.ClassPathCQLDataSet
import org.junit.Rule
Expand All @@ -19,12 +21,15 @@ class MigrationIntegrationSpec extends Specification {
CassandraModule.Config config = new CassandraModule.Config(keyspace: 'test', migrationFile: 'changelog.txt', seeds: seeds, autoMigrate: true)
CassandraMigrationService service = new CassandraMigrationService(config)
StartEvent mockStartEvent = Mock()
CassandraModule cassandraModule = new CassandraModule()
Cluster cluster = cassandraModule.cluster(config)

and:
CassandraService cassandraService = new CassandraService(config)
CassandraService cassandraService = new CassandraService(cluster, cluster.connect())

and:
def registry = new SimpleMutableRegistry()
registry.add(Session, cassandraService.session)
registry.add(CassandraService, cassandraService)

and:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package smartthings.ratpack.cassandra;

import com.datastax.driver.core.Session;
import com.google.inject.Inject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -8,15 +9,15 @@
import ratpack.registry.Registry;

public class CassandraHealthCheck implements HealthCheck {
private final CassandraService cassandraService;
private final Session session;
private final String validationQuery;

Logger logger = LoggerFactory.getLogger(CassandraHealthCheck.class);

@Inject
public CassandraHealthCheck(CassandraModule.Config cassandraConfig, CassandraService cassandraService) {
public CassandraHealthCheck(CassandraModule.Config cassandraConfig, Session session) {
this.validationQuery = cassandraConfig.getValidationQuery();
this.cassandraService = cassandraService;
this.session = session;
}

@Override
Expand All @@ -28,7 +29,7 @@ public String getName() {
public Promise<Result> check(Registry registry) throws Exception {
return Promise.async(upstream -> {
try {
cassandraService.getSession().execute(validationQuery);
session.execute(validationQuery);
upstream.success(Result.healthy());
} catch (Exception ex) {
logger.error("Cassandra connection is unhealthy", ex);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,35 @@
package smartthings.ratpack.cassandra;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.JdkSSLOptions;
import com.datastax.driver.core.PerHostPercentileTracker;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.SocketOptions;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
import com.datastax.driver.core.policies.EC2MultiRegionAddressTranslator;
import com.datastax.driver.core.policies.PercentileSpeculativeExecutionPolicy;
import com.datastax.driver.core.policies.TokenAwarePolicy;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import ratpack.guice.ConfigurableModule;

import java.io.FileInputStream;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.List;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ratpack.guice.ConfigurableModule;

/**
* Supports Cassandra for Ratpack.
*/
public class CassandraModule extends ConfigurableModule<CassandraModule.Config> {

private static final Logger logger = LoggerFactory.getLogger(CassandraModule.class);

public static class Config {

public Config() {
Expand All @@ -29,6 +49,8 @@ public Config() {
String migrationFile = "/migrations/cql.changelog";
Boolean autoMigrate = false;

List<String> cipherSuites = Arrays.asList("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA");

List<String> seeds;

public JKSConfig getTruststore() {
Expand Down Expand Up @@ -139,8 +161,93 @@ public void setPassword(String password) {

@Override
protected void configure() {
bind(Session.class).toProvider(CassandraSessionFactory.class).in(Scopes.SINGLETON);
bind(CassandraService.class).in(Scopes.SINGLETON);
bind(CassandraHealthCheck.class).in(Scopes.SINGLETON);
}

@Provides
public Cluster cluster(final Config config) {
return createCluster(config);
}

/**
* Extension point for overriding {@link Cluster} creation.
*
* @param config
* @return
*/
protected Cluster createCluster(final Config config) {
//Set the highest tracking to just above the socket timeout for the read.
PerHostPercentileTracker
tracker = PerHostPercentileTracker.builder(SocketOptions.DEFAULT_READ_TIMEOUT_MILLIS + 500).build();

DCAwareRoundRobinPolicy
dcAwareRoundRobinPolicy = DCAwareRoundRobinPolicy.builder().withUsedHostsPerRemoteDc(1).build();

Cluster.Builder builder = Cluster.builder()
.withLoadBalancingPolicy(new TokenAwarePolicy(dcAwareRoundRobinPolicy))
.withSpeculativeExecutionPolicy(new PercentileSpeculativeExecutionPolicy(tracker, 0.99, 3));

if (config.getShareEventLoopGroup()) {
builder.withNettyOptions(new RatpackCassandraNettyOptions());
}

for (String seed : config.seeds) {
if (seed.contains(":")) {
String[] tokens = seed.split(":");
builder.addContactPoint(tokens[0]).withPort(Integer.parseInt(tokens[1]));
} else {
builder.addContactPoint(seed);
}
}

builder.withAddressTranslator(new EC2MultiRegionAddressTranslator());

if (config.truststore != null) {
try {
SSLContext sslContext = createSSLContext(config);
builder.withSSL(
JdkSSLOptions.builder()
.withSSLContext(sslContext)
.withCipherSuites(config.cipherSuites.toArray(new String[0]))
.build());
} catch (Exception e) {
logger.error("Couldn't add SSL to the cluster builder.", e);
}
}

if (config.user != null) {
builder.withCredentials(config.user, config.password);
}

return builder.build();
}

/**
* Extension point for overriding {@link SSLContext} creation.
*
* @param config
* @return
*/
protected SSLContext createSSLContext(final Config config) throws Exception {
FileInputStream tsf = new FileInputStream(config.truststore.path);
FileInputStream ksf = new FileInputStream(config.keystore.path);
SSLContext ctx = SSLContext.getInstance("SSL");

KeyStore ts = KeyStore.getInstance("JKS");
ts.load(tsf, config.truststore.password.toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ts);

KeyStore ks = KeyStore.getInstance("JKS");
ks.load(ksf, config.keystore.password.toCharArray());
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());

kmf.init(ks, config.keystore.password.toCharArray());

ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
return ctx;
}

}
Original file line number Diff line number Diff line change
@@ -1,104 +1,24 @@
package smartthings.ratpack.cassandra;

import com.datastax.driver.core.*;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
import com.datastax.driver.core.policies.EC2MultiRegionAddressTranslator;
import com.datastax.driver.core.policies.PercentileSpeculativeExecutionPolicy;
import com.datastax.driver.core.policies.TokenAwarePolicy;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.ResultSetFuture;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.Statement;
import com.google.inject.Inject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ratpack.exec.Promise;
import ratpack.service.Service;
import ratpack.service.StartEvent;
import ratpack.service.StopEvent;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import java.io.FileInputStream;
import java.security.KeyStore;
import java.security.SecureRandom;

public class CassandraService implements Service {

private Cluster cluster;
private Session session;
private final String[] cipherSuites = new String[]{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"};
private final CassandraModule.Config cassandraConfig;

private Logger logger = LoggerFactory.getLogger(CassandraService.class);
private Cluster cluster;

@Inject
public CassandraService(CassandraModule.Config cassandraConfig) {
this.cassandraConfig = cassandraConfig;
}

private void connect() {
//Set the highest tracking to just above the socket timeout for the read.
PerHostPercentileTracker tracker = PerHostPercentileTracker.builder(SocketOptions.DEFAULT_READ_TIMEOUT_MILLIS + 500).build();

DCAwareRoundRobinPolicy dcAwareRoundRobinPolicy = DCAwareRoundRobinPolicy.builder().withUsedHostsPerRemoteDc(1).build();

Cluster.Builder builder = Cluster.builder()
.withLoadBalancingPolicy(new TokenAwarePolicy(dcAwareRoundRobinPolicy))
.withSpeculativeExecutionPolicy(new PercentileSpeculativeExecutionPolicy(tracker, 0.99, 3));

if (cassandraConfig.getShareEventLoopGroup()) {
builder.withNettyOptions(new RatpackCassandraNettyOptions());
}

for (String seed : cassandraConfig.seeds) {
if (seed.contains(":")) {
String[] tokens = seed.split(":");
builder.addContactPoint(tokens[0]).withPort(Integer.parseInt(tokens[1]));
} else {
builder.addContactPoint(seed);
}
}

builder.withAddressTranslator(new EC2MultiRegionAddressTranslator());

if (cassandraConfig.truststore != null) {
try {
SSLContext sslContext = getSSLContext(cassandraConfig.truststore.path, cassandraConfig.truststore.password, cassandraConfig.keystore.path, cassandraConfig.keystore.password);
builder.withSSL(JdkSSLOptions.builder().withSSLContext(sslContext).withCipherSuites(cipherSuites).build());
} catch (Exception e) {
logger.error("Couldn't add SSL to the cluster builder.", e);
}
}

if (cassandraConfig.user != null) {
builder.withCredentials(cassandraConfig.user, cassandraConfig.password);
}

cluster = builder.build();

if (cassandraConfig.keyspace != null) {
session = cluster.connect(cassandraConfig.keyspace);
} else {
session = cluster.connect();
}
}

private static SSLContext getSSLContext(String truststorePath, String truststorePassword, String keystorePath, String keystorePassword) throws Exception {
FileInputStream tsf = new FileInputStream(truststorePath);
FileInputStream ksf = new FileInputStream(keystorePath);
SSLContext ctx = SSLContext.getInstance("SSL");

KeyStore ts = KeyStore.getInstance("JKS");
ts.load(tsf, truststorePassword.toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ts);

KeyStore ks = KeyStore.getInstance("JKS");
ks.load(ksf, keystorePassword.toCharArray());
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());

kmf.init(ks, keystorePassword.toCharArray());

ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
return ctx;
public CassandraService(Cluster cluster, Session session) {
this.session = session;
this.cluster = cluster;
}

public Promise<ResultSet> execute(Statement statement) {
Expand All @@ -108,11 +28,6 @@ public Promise<ResultSet> execute(Statement statement) {
});
}

@Override
public void onStart(StartEvent event) throws Exception {
connect();
}

@Override
public void onStop(StopEvent event) throws Exception {
session.closeAsync();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package smartthings.ratpack.cassandra;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.Session;
import com.google.inject.Inject;
import com.google.inject.Provider;

public final class CassandraSessionFactory implements Provider<Session> {

private final Cluster cluster;
private final CassandraModule.Config config;

@Inject
CassandraSessionFactory(Cluster cluster, CassandraModule.Config config) {
this.cluster = cluster;
this.config = config;
}

@Override public Session get() {
return (config.keyspace != null && !"".equals(config.keyspace)) ?
cluster.connect(config.keyspace) :
cluster.connect();
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package smartthings.ratpack.cassandra

import com.datastax.driver.core.Cluster
import com.datastax.driver.core.exceptions.NoHostAvailableException
import org.cassandraunit.CassandraCQLUnit
import org.cassandraunit.dataset.CQLDataSet
Expand Down Expand Up @@ -41,9 +42,12 @@ class CassandraServiceSpec extends Specification {
cassConfig.setSeeds([TEST_SEED])
CassandraService service

CassandraModule module = new CassandraModule()
Cluster cluster = module.cluster(cassConfig)

when:
harness.run {
service = new CassandraService(cassConfig)
service = new CassandraService(cluster, cluster.connect())
service.onStart(new StartEvent() {
@Override
Registry getRegistry() {
Expand All @@ -68,9 +72,12 @@ class CassandraServiceSpec extends Specification {
cassConfig.setSeeds(["localhost:1111"])
CassandraService service

CassandraModule module = new CassandraModule()
Cluster cluster = module.cluster(cassConfig)

when:
harness.run {
service = new CassandraService(cassConfig)
service = new CassandraService(cluster, cluster.connect())
service.onStart(new StartEvent() {
@Override
Registry getRegistry() {
Expand Down

0 comments on commit e554bc8

Please sign in to comment.