diff --git a/src/main/java/io/appform/dropwizard/sharding/DBShardingBundleBase.java b/src/main/java/io/appform/dropwizard/sharding/DBShardingBundleBase.java index f4387484..c303889e 100644 --- a/src/main/java/io/appform/dropwizard/sharding/DBShardingBundleBase.java +++ b/src/main/java/io/appform/dropwizard/sharding/DBShardingBundleBase.java @@ -45,14 +45,12 @@ import lombok.Getter; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.hibernate.Interceptor; import org.hibernate.SessionFactory; import org.reflections.Reflections; import javax.persistence.Entity; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -86,7 +84,17 @@ protected DBShardingBundleBase( Class... entities) { this.dbNamespace = dbNamespace; val inEntities = ImmutableList.>builder().add(entity).add(entities).build(); - init(inEntities); + init(inEntities, null); + } + + protected DBShardingBundleBase( + String dbNamespace, + SessionFactoryFactory sessionFactoryFactory, + Class entity, + Class... entities) { + this.dbNamespace = dbNamespace; + val inEntities = ImmutableList.>builder().add(entity).add(entities).build(); + init(inEntities, sessionFactoryFactory); } protected DBShardingBundleBase(String dbNamespace, List classPathPrefixList) { @@ -94,11 +102,11 @@ protected DBShardingBundleBase(String dbNamespace, List classPathPrefixL Set> entities = new Reflections(classPathPrefixList).getTypesAnnotatedWith(Entity.class); Preconditions.checkArgument(!entities.isEmpty(), String.format("No entity class found at %s", String.join(",", classPathPrefixList))); val inEntities = ImmutableList.>builder().addAll(entities).build(); - init(inEntities); + init(inEntities, null); } protected DBShardingBundleBase(Class entity, Class... entities) { - this(DEFAULT_NAMESPACE, entity, entities); + this(DEFAULT_NAMESPACE, null, entity, entities); } protected DBShardingBundleBase(String... classPathPrefixes) { @@ -107,17 +115,17 @@ protected DBShardingBundleBase(String... classPathPrefixes) { protected abstract ShardManager createShardManager(int numShards, ShardBlacklistingStore blacklistingStore); - private void init(final ImmutableList> inEntities) { + private void init(final ImmutableList> inEntities, SessionFactoryFactory sessionFactoryFactoryy) { String numShardsEnv = System.getProperty(String.join(".", dbNamespace, DEFAULT_NAMESPACE), System.getProperty(SHARD_ENV, DEFAULT_SHARDS)); - + SessionFactoryFactory sessionFactoryFactory = sessionFactoryFactoryy != null ? sessionFactoryFactoryy : new SessionFactoryFactory(); this.numShards = Integer.parseInt(numShardsEnv); val blacklistingStore = getBlacklistingStore(); this.shardManager = createShardManager(numShards, blacklistingStore); this.shardInfoProvider = new ShardInfoProvider(dbNamespace); this.healthCheckManager = new HealthCheckManager(dbNamespace, shardInfoProvider, blacklistingStore, shardManager); IntStream.range(0, numShards).forEach( - shard -> shardBundles.add(new HibernateBundle(inEntities, new SessionFactoryFactory()) { + shard -> shardBundles.add(new HibernateBundle(inEntities, sessionFactoryFactory) { @Override protected String name() { return shardInfoProvider.shardName(shard); @@ -233,7 +241,6 @@ CacheableRelationalDao createRelatedObjectDao(Class claz return new CacheableRelationalDao<>(this.sessionFactories, clazz, new ShardCalculator<>(this.shardManager, bucketIdExtractor), cacheManager); } - public , T extends Configuration> WrapperDao createWrapperDao(Class daoTypeClass) { return new WrapperDao<>(this.sessionFactories,