diff --git a/ratpack-cassandra-migrate/src/main/java/smartthings/ratpack/cassandra/CassandraMigrationService.java b/ratpack-cassandra-migrate/src/main/java/smartthings/ratpack/cassandra/CassandraMigrationService.java index 0ac9bdd..2783744 100644 --- a/ratpack-cassandra-migrate/src/main/java/smartthings/ratpack/cassandra/CassandraMigrationService.java +++ b/ratpack-cassandra-migrate/src/main/java/smartthings/ratpack/cassandra/CassandraMigrationService.java @@ -25,7 +25,7 @@ public CassandraMigrationService(CassandraModule.Config config) { @Override public void onStart(StartEvent event) throws Exception { if (config.autoMigrate) { - Session session = event.getRegistry().get(CassandraService.class).getSession(); + Session session = event.getRegistry().get(Session.class); logger.info("Auto Migrating Cassandra"); MigrationRunner migrationRunner = new MigrationRunner(); diff --git a/ratpack-cassandra-migrate/src/test/groovy/smartthings/ratpack/cassandra/MigrationIntegrationSpec.groovy b/ratpack-cassandra-migrate/src/test/groovy/smartthings/ratpack/cassandra/MigrationIntegrationSpec.groovy index 56d1dd0..bc7e736 100644 --- a/ratpack-cassandra-migrate/src/test/groovy/smartthings/ratpack/cassandra/MigrationIntegrationSpec.groovy +++ b/ratpack-cassandra-migrate/src/test/groovy/smartthings/ratpack/cassandra/MigrationIntegrationSpec.groovy @@ -1,6 +1,8 @@ package smartthings.ratpack.cassandra +import com.datastax.driver.core.Cluster import com.datastax.driver.core.Host +import com.datastax.driver.core.Session import org.cassandraunit.CassandraCQLUnit import org.cassandraunit.dataset.cql.ClassPathCQLDataSet import org.junit.Rule @@ -19,12 +21,15 @@ class MigrationIntegrationSpec extends Specification { CassandraModule.Config config = new CassandraModule.Config(keyspace: 'test', migrationFile: 'changelog.txt', seeds: seeds, autoMigrate: true) CassandraMigrationService service = new CassandraMigrationService(config) StartEvent mockStartEvent = Mock() + CassandraModule cassandraModule = new CassandraModule() + Cluster cluster = cassandraModule.cluster(config) and: - CassandraService cassandraService = new CassandraService(config) + CassandraService cassandraService = new CassandraService(cluster, cluster.connect()) and: def registry = new SimpleMutableRegistry() + registry.add(Session, cassandraService.session) registry.add(CassandraService, cassandraService) and: diff --git a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraHealthCheck.java b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraHealthCheck.java index a377b08..7cae482 100644 --- a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraHealthCheck.java +++ b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraHealthCheck.java @@ -1,5 +1,6 @@ package smartthings.ratpack.cassandra; +import com.datastax.driver.core.Session; import com.google.inject.Inject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -8,15 +9,15 @@ import ratpack.registry.Registry; public class CassandraHealthCheck implements HealthCheck { - private final CassandraService cassandraService; + private final Session session; private final String validationQuery; Logger logger = LoggerFactory.getLogger(CassandraHealthCheck.class); @Inject - public CassandraHealthCheck(CassandraModule.Config cassandraConfig, CassandraService cassandraService) { + public CassandraHealthCheck(CassandraModule.Config cassandraConfig, Session session) { this.validationQuery = cassandraConfig.getValidationQuery(); - this.cassandraService = cassandraService; + this.session = session; } @Override @@ -28,7 +29,7 @@ public String getName() { public Promise check(Registry registry) throws Exception { return Promise.async(upstream -> { try { - cassandraService.getSession().execute(validationQuery); + session.execute(validationQuery); upstream.success(Result.healthy()); } catch (Exception ex) { logger.error("Cassandra connection is unhealthy", ex); diff --git a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraModule.java b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraModule.java index 71e945f..7ab78bd 100644 --- a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraModule.java +++ b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraModule.java @@ -1,15 +1,35 @@ package smartthings.ratpack.cassandra; +import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.JdkSSLOptions; +import com.datastax.driver.core.PerHostPercentileTracker; +import com.datastax.driver.core.Session; +import com.datastax.driver.core.SocketOptions; +import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy; +import com.datastax.driver.core.policies.EC2MultiRegionAddressTranslator; +import com.datastax.driver.core.policies.PercentileSpeculativeExecutionPolicy; +import com.datastax.driver.core.policies.TokenAwarePolicy; +import com.google.inject.Provides; import com.google.inject.Scopes; -import ratpack.guice.ConfigurableModule; - +import java.io.FileInputStream; +import java.security.KeyStore; +import java.security.SecureRandom; +import java.util.Arrays; import java.util.List; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import ratpack.guice.ConfigurableModule; /** * Supports Cassandra for Ratpack. */ public class CassandraModule extends ConfigurableModule { + private static final Logger logger = LoggerFactory.getLogger(CassandraModule.class); + public static class Config { public Config() { @@ -29,6 +49,8 @@ public Config() { String migrationFile = "/migrations/cql.changelog"; Boolean autoMigrate = false; + List cipherSuites = Arrays.asList("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"); + List seeds; public JKSConfig getTruststore() { @@ -139,8 +161,93 @@ public void setPassword(String password) { @Override protected void configure() { + bind(Session.class).toProvider(CassandraSessionFactory.class).in(Scopes.SINGLETON); bind(CassandraService.class).in(Scopes.SINGLETON); bind(CassandraHealthCheck.class).in(Scopes.SINGLETON); } + @Provides + public Cluster cluster(final Config config) { + return createCluster(config); + } + + /** + * Extension point for overriding {@link Cluster} creation. + * + * @param config + * @return + */ + protected Cluster createCluster(final Config config) { + //Set the highest tracking to just above the socket timeout for the read. + PerHostPercentileTracker + tracker = PerHostPercentileTracker.builder(SocketOptions.DEFAULT_READ_TIMEOUT_MILLIS + 500).build(); + + DCAwareRoundRobinPolicy + dcAwareRoundRobinPolicy = DCAwareRoundRobinPolicy.builder().withUsedHostsPerRemoteDc(1).build(); + + Cluster.Builder builder = Cluster.builder() + .withLoadBalancingPolicy(new TokenAwarePolicy(dcAwareRoundRobinPolicy)) + .withSpeculativeExecutionPolicy(new PercentileSpeculativeExecutionPolicy(tracker, 0.99, 3)); + + if (config.getShareEventLoopGroup()) { + builder.withNettyOptions(new RatpackCassandraNettyOptions()); + } + + for (String seed : config.seeds) { + if (seed.contains(":")) { + String[] tokens = seed.split(":"); + builder.addContactPoint(tokens[0]).withPort(Integer.parseInt(tokens[1])); + } else { + builder.addContactPoint(seed); + } + } + + builder.withAddressTranslator(new EC2MultiRegionAddressTranslator()); + + if (config.truststore != null) { + try { + SSLContext sslContext = createSSLContext(config); + builder.withSSL( + JdkSSLOptions.builder() + .withSSLContext(sslContext) + .withCipherSuites(config.cipherSuites.toArray(new String[0])) + .build()); + } catch (Exception e) { + logger.error("Couldn't add SSL to the cluster builder.", e); + } + } + + if (config.user != null) { + builder.withCredentials(config.user, config.password); + } + + return builder.build(); + } + + /** + * Extension point for overriding {@link SSLContext} creation. + * + * @param config + * @return + */ + protected SSLContext createSSLContext(final Config config) throws Exception { + FileInputStream tsf = new FileInputStream(config.truststore.path); + FileInputStream ksf = new FileInputStream(config.keystore.path); + SSLContext ctx = SSLContext.getInstance("SSL"); + + KeyStore ts = KeyStore.getInstance("JKS"); + ts.load(tsf, config.truststore.password.toCharArray()); + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ts); + + KeyStore ks = KeyStore.getInstance("JKS"); + ks.load(ksf, config.keystore.password.toCharArray()); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + + kmf.init(ks, config.keystore.password.toCharArray()); + + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom()); + return ctx; + } + } diff --git a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraService.java b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraService.java index 99e6b17..919b449 100644 --- a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraService.java +++ b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraService.java @@ -1,104 +1,24 @@ package smartthings.ratpack.cassandra; -import com.datastax.driver.core.*; -import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy; -import com.datastax.driver.core.policies.EC2MultiRegionAddressTranslator; -import com.datastax.driver.core.policies.PercentileSpeculativeExecutionPolicy; -import com.datastax.driver.core.policies.TokenAwarePolicy; +import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.ResultSet; +import com.datastax.driver.core.ResultSetFuture; +import com.datastax.driver.core.Session; +import com.datastax.driver.core.Statement; import com.google.inject.Inject; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import ratpack.exec.Promise; import ratpack.service.Service; -import ratpack.service.StartEvent; import ratpack.service.StopEvent; -import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; -import java.io.FileInputStream; -import java.security.KeyStore; -import java.security.SecureRandom; - public class CassandraService implements Service { - private Cluster cluster; private Session session; - private final String[] cipherSuites = new String[]{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"}; - private final CassandraModule.Config cassandraConfig; - - private Logger logger = LoggerFactory.getLogger(CassandraService.class); + private Cluster cluster; @Inject - public CassandraService(CassandraModule.Config cassandraConfig) { - this.cassandraConfig = cassandraConfig; - } - - private void connect() { - //Set the highest tracking to just above the socket timeout for the read. - PerHostPercentileTracker tracker = PerHostPercentileTracker.builder(SocketOptions.DEFAULT_READ_TIMEOUT_MILLIS + 500).build(); - - DCAwareRoundRobinPolicy dcAwareRoundRobinPolicy = DCAwareRoundRobinPolicy.builder().withUsedHostsPerRemoteDc(1).build(); - - Cluster.Builder builder = Cluster.builder() - .withLoadBalancingPolicy(new TokenAwarePolicy(dcAwareRoundRobinPolicy)) - .withSpeculativeExecutionPolicy(new PercentileSpeculativeExecutionPolicy(tracker, 0.99, 3)); - - if (cassandraConfig.getShareEventLoopGroup()) { - builder.withNettyOptions(new RatpackCassandraNettyOptions()); - } - - for (String seed : cassandraConfig.seeds) { - if (seed.contains(":")) { - String[] tokens = seed.split(":"); - builder.addContactPoint(tokens[0]).withPort(Integer.parseInt(tokens[1])); - } else { - builder.addContactPoint(seed); - } - } - - builder.withAddressTranslator(new EC2MultiRegionAddressTranslator()); - - if (cassandraConfig.truststore != null) { - try { - SSLContext sslContext = getSSLContext(cassandraConfig.truststore.path, cassandraConfig.truststore.password, cassandraConfig.keystore.path, cassandraConfig.keystore.password); - builder.withSSL(JdkSSLOptions.builder().withSSLContext(sslContext).withCipherSuites(cipherSuites).build()); - } catch (Exception e) { - logger.error("Couldn't add SSL to the cluster builder.", e); - } - } - - if (cassandraConfig.user != null) { - builder.withCredentials(cassandraConfig.user, cassandraConfig.password); - } - - cluster = builder.build(); - - if (cassandraConfig.keyspace != null) { - session = cluster.connect(cassandraConfig.keyspace); - } else { - session = cluster.connect(); - } - } - - private static SSLContext getSSLContext(String truststorePath, String truststorePassword, String keystorePath, String keystorePassword) throws Exception { - FileInputStream tsf = new FileInputStream(truststorePath); - FileInputStream ksf = new FileInputStream(keystorePath); - SSLContext ctx = SSLContext.getInstance("SSL"); - - KeyStore ts = KeyStore.getInstance("JKS"); - ts.load(tsf, truststorePassword.toCharArray()); - TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - tmf.init(ts); - - KeyStore ks = KeyStore.getInstance("JKS"); - ks.load(ksf, keystorePassword.toCharArray()); - KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); - - kmf.init(ks, keystorePassword.toCharArray()); - - ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom()); - return ctx; + public CassandraService(Cluster cluster, Session session) { + this.session = session; + this.cluster = cluster; } public Promise execute(Statement statement) { @@ -108,11 +28,6 @@ public Promise execute(Statement statement) { }); } - @Override - public void onStart(StartEvent event) throws Exception { - connect(); - } - @Override public void onStop(StopEvent event) throws Exception { session.closeAsync(); diff --git a/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraSessionFactory.java b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraSessionFactory.java new file mode 100644 index 0000000..30bfe2b --- /dev/null +++ b/ratpack-cassandra/src/main/java/smartthings/ratpack/cassandra/CassandraSessionFactory.java @@ -0,0 +1,24 @@ +package smartthings.ratpack.cassandra; + +import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.Session; +import com.google.inject.Inject; +import com.google.inject.Provider; + +public final class CassandraSessionFactory implements Provider { + + private final Cluster cluster; + private final CassandraModule.Config config; + + @Inject + CassandraSessionFactory(Cluster cluster, CassandraModule.Config config) { + this.cluster = cluster; + this.config = config; + } + + @Override public Session get() { + return (config.keyspace != null && !"".equals(config.keyspace)) ? + cluster.connect(config.keyspace) : + cluster.connect(); + } +} diff --git a/ratpack-cassandra/src/test/groovy/smartthings/ratpack/cassandra/CassandraServiceSpec.groovy b/ratpack-cassandra/src/test/groovy/smartthings/ratpack/cassandra/CassandraServiceSpec.groovy index 32eae01..9505fa5 100644 --- a/ratpack-cassandra/src/test/groovy/smartthings/ratpack/cassandra/CassandraServiceSpec.groovy +++ b/ratpack-cassandra/src/test/groovy/smartthings/ratpack/cassandra/CassandraServiceSpec.groovy @@ -1,5 +1,6 @@ package smartthings.ratpack.cassandra +import com.datastax.driver.core.Cluster import com.datastax.driver.core.exceptions.NoHostAvailableException import org.cassandraunit.CassandraCQLUnit import org.cassandraunit.dataset.CQLDataSet @@ -41,9 +42,12 @@ class CassandraServiceSpec extends Specification { cassConfig.setSeeds([TEST_SEED]) CassandraService service + CassandraModule module = new CassandraModule() + Cluster cluster = module.cluster(cassConfig) + when: harness.run { - service = new CassandraService(cassConfig) + service = new CassandraService(cluster, cluster.connect()) service.onStart(new StartEvent() { @Override Registry getRegistry() { @@ -68,9 +72,12 @@ class CassandraServiceSpec extends Specification { cassConfig.setSeeds(["localhost:1111"]) CassandraService service + CassandraModule module = new CassandraModule() + Cluster cluster = module.cluster(cassConfig) + when: harness.run { - service = new CassandraService(cassConfig) + service = new CassandraService(cluster, cluster.connect()) service.onStart(new StartEvent() { @Override Registry getRegistry() {