diff --git a/atlasdb-cassandra-integration-tests/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolIntegrationTest.java b/atlasdb-cassandra-integration-tests/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolIntegrationTest.java index d64ad1c72ae..3f0f1181144 100644 --- a/atlasdb-cassandra-integration-tests/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolIntegrationTest.java +++ b/atlasdb-cassandra-integration-tests/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolIntegrationTest.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Range; import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceConfig; import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceRuntimeConfig; @@ -71,12 +72,12 @@ public void setUp() { @Test public void testTokenMapping() { - Map, List> mapOfRanges = + Map, ImmutableSet> mapOfRanges = clientPool.getTokenMap().asMapOfRanges(); assertThat(mapOfRanges).isNotEmpty(); - for (Map.Entry, List> entry : mapOfRanges.entrySet()) { + for (Map.Entry, ImmutableSet> entry : mapOfRanges.entrySet()) { Range tokenRange = entry.getKey(); - List hosts = entry.getValue(); + ImmutableSet hosts = entry.getValue(); clientPool.getRandomServerForKey("A".getBytes(StandardCharsets.UTF_8)); diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/cassandra/CassandraServersConfigs.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/cassandra/CassandraServersConfigs.java index cc6ec918706..bc7ea3f7d6a 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/cassandra/CassandraServersConfigs.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/cassandra/CassandraServersConfigs.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import com.google.common.collect.ImmutableSet; import com.palantir.logsafe.Preconditions; import com.palantir.logsafe.SafeArg; import com.palantir.logsafe.exceptions.SafeIllegalStateException; @@ -51,16 +52,17 @@ public interface Visitor { T visit(CqlCapableConfig cqlCapableConfig); } - public static final class ThriftHostsExtractingVisitor implements Visitor> { + public enum ThriftHostsExtractingVisitor implements Visitor> { + INSTANCE; @Override - public Set visit(DefaultConfig defaultConfig) { - return defaultConfig.thriftHosts(); + public ImmutableSet visit(DefaultConfig defaultConfig) { + return ImmutableSet.copyOf(defaultConfig.thriftHosts()); } @Override - public Set visit(CqlCapableConfig cqlCapableConfig) { - return cqlCapableConfig.thriftHosts(); + public ImmutableSet visit(CqlCapableConfig cqlCapableConfig) { + return ImmutableSet.copyOf(cqlCapableConfig.thriftHosts()); } } diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/Blacklist.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/Blacklist.java index f5fe94279af..0414b37bb19 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/Blacklist.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/Blacklist.java @@ -26,7 +26,6 @@ import com.palantir.logsafe.logger.SafeLoggerFactory; import com.palantir.refreshable.Refreshable; import java.time.Clock; -import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -108,8 +107,8 @@ private boolean isHostHealthy(CassandraClientPoolingContainer container) { } } - public Set filterBlacklistedHostsFrom(Collection potentialHosts) { - return Sets.difference(ImmutableSet.copyOf(potentialHosts), blacklist.keySet()); + public Set filterBlacklistedHostsFrom(ImmutableSet potentialHosts) { + return Sets.difference(potentialHosts, blacklist.keySet()); } boolean contains(CassandraServer cassandraServer) { diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraAbsentHostTracker.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraAbsentHostTracker.java index 6067dc47eff..84e1111564c 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraAbsentHostTracker.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraAbsentHostTracker.java @@ -26,13 +26,15 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; +import javax.annotation.concurrent.GuardedBy; import org.immutables.value.Value; public final class CassandraAbsentHostTracker { private static final SafeLogger log = SafeLoggerFactory.get(CassandraAbsentHostTracker.class); private final int absenceLimit; + + @GuardedBy("this") private final Map absentCassandraServers; public CassandraAbsentHostTracker(int absenceLimit) { @@ -59,20 +61,20 @@ public synchronized void shutDown() { absentCassandraServers.clear(); } - private Set cleanupAbsentServer(Set absentServersSnapshot) { + private ImmutableSet cleanupAbsentServer(ImmutableSet absentServersSnapshot) { absentServersSnapshot.forEach(this::incrementAbsenceCountIfPresent); return absentServersSnapshot.stream() .map(this::removeIfAbsenceThresholdReached) .flatMap(Optional::stream) - .collect(Collectors.toSet()); + .collect(ImmutableSet.toImmutableSet()); } - private void incrementAbsenceCountIfPresent(CassandraServer cassandraServer) { + private synchronized void incrementAbsenceCountIfPresent(CassandraServer cassandraServer) { absentCassandraServers.computeIfPresent( cassandraServer, (_host, poolAndCount) -> poolAndCount.incrementCount()); } - private Optional removeIfAbsenceThresholdReached(CassandraServer cassandraServer) { + private synchronized Optional removeIfAbsenceThresholdReached(CassandraServer cassandraServer) { if (absentCassandraServers.get(cassandraServer).timesAbsent() <= absenceLimit) { return Optional.empty(); } else { diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolImpl.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolImpl.java index 1b144d8ed4d..84353470fa0 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolImpl.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolImpl.java @@ -23,6 +23,7 @@ import com.google.common.collect.Multimap; import com.google.common.collect.RangeMap; import com.google.common.collect.Sets; +import com.google.common.collect.Sets.SetView; import com.palantir.async.initializer.AsyncInitializer; import com.palantir.atlasdb.AtlasDbConstants; import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceConfig; @@ -36,7 +37,6 @@ import com.palantir.common.base.FunctionCheckedException; import com.palantir.common.concurrent.InitializeableScheduledExecutorServiceSupplier; import com.palantir.common.concurrent.NamedThreadFactory; -import com.palantir.common.streams.KeyedStream; import com.palantir.logsafe.SafeArg; import com.palantir.logsafe.UnsafeArg; import com.palantir.logsafe.exceptions.SafeIllegalStateException; @@ -44,7 +44,6 @@ import com.palantir.logsafe.logger.SafeLoggerFactory; import com.palantir.refreshable.Refreshable; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -302,7 +301,7 @@ public Map getCurrentPools() { } @VisibleForTesting - RangeMap> getTokenMap() { + RangeMap> getTokenMap() { return cassandra.getTokenMap(); } @@ -324,15 +323,16 @@ private synchronized void refreshPool() { } @VisibleForTesting - void setServersInPoolTo(Set desiredServers) { - Set cachedServers = getCachedServers(); - Set serversToAdd = ImmutableSet.copyOf(Sets.difference(desiredServers, cachedServers)); - Set absentServers = ImmutableSet.copyOf(Sets.difference(cachedServers, desiredServers)); + void setServersInPoolTo(ImmutableSet desiredServers) { + ImmutableSet cachedServers = getCachedServers(); + SetView serversToAdd = Sets.difference(desiredServers, cachedServers); + SetView absentServers = Sets.difference(cachedServers, desiredServers); serversToAdd.forEach(server -> cassandra.returnOrCreatePool(server, absentHostTracker.returnPool(server))); - Map containersForAbsentHosts = - KeyedStream.of(absentServers).map(cassandra::removePool).collectToMap(); - containersForAbsentHosts.forEach(absentHostTracker::trackAbsentCassandraServer); + absentServers.forEach(cassandraServer -> { + CassandraClientPoolingContainer container = cassandra.removePool(cassandraServer); + absentHostTracker.trackAbsentCassandraServer(cassandraServer, container); + }); Set serversToShutdown = absentHostTracker.incrementAbsenceAndRemove(); @@ -359,8 +359,8 @@ private static void logRefreshedHosts( } } - private Set getCachedServers() { - return cassandra.getPools().keySet(); + private ImmutableSet getCachedServers() { + return ImmutableSet.copyOf(cassandra.getPools().keySet()); } @Override diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraLogHelper.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraLogHelper.java index e3a59833f58..3afb5e00b96 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraLogHelper.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraLogHelper.java @@ -49,7 +49,7 @@ static List tokenRangesToServer(Multimap, CassandraServe .collect(Collectors.toList()); } - public static List tokenMap(RangeMap> tokenMap) { + public static List tokenMap(RangeMap> tokenMap) { return tokenMap.asMapOfRanges().entrySet().stream() .map(rangeListToHostEntry -> String.format( "range from %s to %s is on host %s", diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraVerifier.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraVerifier.java index a6e58a6b46c..da12f6998aa 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraVerifier.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraVerifier.java @@ -215,9 +215,8 @@ private static void createKeyspace(CassandraVerifierConfig verifierConfig) throw } private static boolean attemptToCreateKeyspace(CassandraVerifierConfig verifierConfig) { - Set thriftHosts = verifierConfig.servers().accept(new ThriftHostsExtractingVisitor()); - - return thriftHosts.stream().anyMatch(host -> attemptToCreateIfNotExists(host, verifierConfig)); + return verifierConfig.servers().accept(ThriftHostsExtractingVisitor.INSTANCE).stream() + .anyMatch(host -> attemptToCreateIfNotExists(host, verifierConfig)); } private static boolean attemptToCreateIfNotExists(InetSocketAddress host, CassandraVerifierConfig verifierConfig) { diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraService.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraService.java index ff10886bc4d..84577bbc518 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraService.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraService.java @@ -27,6 +27,7 @@ import com.google.common.collect.Range; import com.google.common.collect.RangeMap; import com.google.common.collect.Sets; +import com.google.common.collect.Sets.SetView; import com.google.common.io.BaseEncoding; import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceConfig; import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceRuntimeConfig; @@ -51,14 +52,12 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; import java.util.Random; @@ -75,8 +74,8 @@ public class CassandraService implements AutoCloseable { private static final SafeLogger log = SafeLoggerFactory.get(CassandraService.class); - private static final Interner>> tokensInterner = - Interners.newWeakInterner(); + private static final Interner>> + tokensInterner = Interners.newWeakInterner(); private final MetricsManager metricsManager; private final CassandraKeyValueServiceConfig config; @@ -84,13 +83,14 @@ public class CassandraService implements AutoCloseable { private final CassandraClientPoolMetrics poolMetrics; private final Refreshable runtimeConfig; - private volatile RangeMap> tokenMap = ImmutableRangeMap.of(); + private volatile ImmutableRangeMap> tokenMap = + ImmutableRangeMap.of(); private final Map currentPools = new ConcurrentHashMap<>(); - private volatile Map hostToDatacenter = ImmutableMap.of(); + private volatile ImmutableMap hostToDatacenter = ImmutableMap.of(); private List cassandraHosts; - private volatile Set localHosts = ImmutableSet.of(); + private volatile ImmutableSet localHosts = ImmutableSet.of(); private final Supplier> myLocationSupplier; private final Supplier> hostnameByIpSupplier; @@ -131,12 +131,12 @@ public static CassandraService createInitialized( @Override public void close() {} - public Set refreshTokenRangesAndGetServers() { - Set servers = new HashSet<>(); - Map hostToDatacentersThisRefresh = new HashMap<>(); + public ImmutableSet refreshTokenRangesAndGetServers() { + ImmutableSet.Builder servers = ImmutableSet.builder(); + ImmutableMap.Builder hostToDatacentersThisRefresh = ImmutableMap.builder(); try { - ImmutableRangeMap.Builder> newTokenRing = + ImmutableRangeMap.Builder> newTokenRing = ImmutableRangeMap.builder(); // grab latest token ring view from a random node in the cluster and update local hosts @@ -148,7 +148,7 @@ public Set refreshTokenRangesAndGetServers() { EndpointDetails onlyEndpoint = Iterables.getOnlyElement( Iterables.getOnlyElement(tokenRanges).getEndpoint_details()); CassandraServer onlyHost = getAddressForHost(onlyEndpoint.getHost()); - newTokenRing.put(Range.all(), ImmutableList.of(onlyHost)); + newTokenRing.put(Range.all(), ImmutableSet.of(onlyHost)); servers.add(onlyHost); hostToDatacentersThisRefresh.put(onlyHost, onlyEndpoint.getDatacenter()); } else { // normal case, large cluster with many vnodes @@ -160,7 +160,8 @@ public Set refreshTokenRangesAndGetServers() { .map(EndpointDetails::getDatacenter) .collectToMap(); - List hosts = new ArrayList<>(hostToDatacentersOnThisTokenRange.keySet()); + ImmutableSet hosts = + ImmutableSet.copyOf(hostToDatacentersOnThisTokenRange.keySet()); servers.addAll(hosts); hostToDatacentersThisRefresh.putAll(hostToDatacentersOnThisTokenRange); @@ -178,9 +179,9 @@ public Set refreshTokenRangesAndGetServers() { } } tokenMap = tokensInterner.intern(newTokenRing.build()); - logHostToDatacenterMapping(hostToDatacentersThisRefresh); - hostToDatacenter = hostToDatacentersThisRefresh; - return servers; + hostToDatacenter = hostToDatacentersThisRefresh.build(); + logHostToDatacenterMapping(hostToDatacenter); + return servers.build(); } catch (Exception e) { log.info( "Couldn't grab new token ranges for token aware cassandra mapping. We will retry in {} seconds.", @@ -189,13 +190,13 @@ public Set refreshTokenRangesAndGetServers() { // Attempt to re-resolve addresses from the configuration; this is important owing to certain race // conditions where the entire pool becomes invalid between refreshes. - Set resolvedConfigAddresses = getCurrentServerListFromConfig(); + ImmutableSet resolvedConfigAddresses = getCurrentServerListFromConfig(); - Set lastKnownAddresses = tokenMap.asMapOfRanges().values().stream() + ImmutableSet lastKnownAddresses = tokenMap.asMapOfRanges().values().stream() .flatMap(Collection::stream) - .collect(Collectors.toSet()); + .collect(ImmutableSet.toImmutableSet()); - return Sets.union(resolvedConfigAddresses, lastKnownAddresses); + return Sets.union(resolvedConfigAddresses, lastKnownAddresses).immutableCopy(); } } @@ -203,15 +204,14 @@ public Set refreshTokenRangesAndGetServers() { * It is expected that config provides list of servers that are directly reachable and do not require special IP * resolution. * */ - public Set getCurrentServerListFromConfig() { - Set inetSocketAddresses = getServersSocketAddressesFromConfig(); - return inetSocketAddresses.stream() + public ImmutableSet getCurrentServerListFromConfig() { + return getServersSocketAddressesFromConfig().stream() .map(cassandraHost -> CassandraServer.of(cassandraHost.getHostString(), cassandraHost)) - .collect(Collectors.toSet()); + .collect(ImmutableSet.toImmutableSet()); } - private Set getServersSocketAddressesFromConfig() { - return runtimeConfig.get().servers().accept(new ThriftHostsExtractingVisitor()); + private ImmutableSet getServersSocketAddressesFromConfig() { + return runtimeConfig.get().servers().accept(ThriftHostsExtractingVisitor.INSTANCE); } private void logHostToDatacenterMapping(Map hostToDatacentersThisRefresh) { @@ -239,20 +239,20 @@ private String getSnitch() { } } - private Set refreshLocalHosts(List tokenRanges) { + private ImmutableSet refreshLocalHosts(List tokenRanges) { Optional myLocation = myLocationSupplier.get(); - if (!myLocation.isPresent()) { + if (myLocation.isEmpty()) { return ImmutableSet.of(); } - Set newLocalHosts = tokenRanges.stream() + ImmutableSet newLocalHosts = tokenRanges.stream() .map(TokenRange::getEndpoint_details) .flatMap(Collection::stream) .filter(details -> isHostLocal(details, myLocation.get())) .map(EndpointDetails::getHost) .map(this::getAddressForHostThrowUnchecked) - .collect(Collectors.toSet()); + .collect(ImmutableSet.toImmutableSet()); if (newLocalHosts.isEmpty()) { log.warn("No local hosts found"); @@ -269,7 +269,7 @@ private static boolean isHostLocal(EndpointDetails details, HostLocation myLocat } @VisibleForTesting - void setLocalHosts(Set localHosts) { + void setLocalHosts(ImmutableSet localHosts) { this.localHosts = localHosts; } @@ -292,20 +292,19 @@ CassandraServer getAddressForHost(String inputHost) throws UnknownHostException return CassandraServer.of(cassandraHostName, getReachableProxies(cassandraHostName)); } - private Set getReachableProxies(String inputHost) throws UnknownHostException { + private ImmutableSet getReachableProxies(String inputHost) throws UnknownHostException { InetAddress[] resolvedHosts = InetAddress.getAllByName(inputHost); int knownPort = getKnownPort(); // It is okay to have reachable proxies that do not have a hostname return Stream.of(resolvedHosts) .map(inetAddr -> new InetSocketAddress(inetAddr, knownPort)) - .collect(Collectors.toSet()); + .collect(ImmutableSet.toImmutableSet()); } private int getKnownPort() throws UnknownHostException { - Set allKnownHosts = getAllKnownHosts(); - Set allKnownPorts = - allKnownHosts.stream().map(InetSocketAddress::getPort).collect(Collectors.toSet()); + ImmutableSet allKnownPorts = + getAllKnownHosts().stream().map(InetSocketAddress::getPort).collect(ImmutableSet.toImmutableSet()); if (allKnownPorts.size() == 1) { // if everyone is on one port, try and use that return Iterables.getOnlyElement(allKnownPorts); @@ -314,15 +313,15 @@ private int getKnownPort() throws UnknownHostException { } } - private Set getAllKnownHosts() { - return ImmutableSet.copyOf(Sets.union(getProxiesFromCurrentPool(), getServersSocketAddressesFromConfig())); + private SetView getAllKnownHosts() { + return Sets.union(getProxiesFromCurrentPool(), getServersSocketAddressesFromConfig()); } - private Set getProxiesFromCurrentPool() { - return currentPools.keySet().stream().map(CassandraServer::proxy).collect(Collectors.toSet()); + private ImmutableSet getProxiesFromCurrentPool() { + return currentPools.keySet().stream().map(CassandraServer::proxy).collect(ImmutableSet.toImmutableSet()); } - private List getHostsFor(byte[] key) { + private ImmutableSet getHostsFor(byte[] key) { return tokenMap.get(new LightweightOppToken(key)); } @@ -333,12 +332,11 @@ public Optional getRandomGoodHostForPredicate( public Optional getRandomGoodHostForPredicate( Predicate predicate, Set triedNodes) { - Map pools = currentPools; - - Set hostsMatchingPredicate = - pools.keySet().stream().filter(predicate).collect(Collectors.toSet()); + ImmutableSet hostsMatchingPredicate = + currentPools.keySet().stream().filter(predicate).collect(ImmutableSet.toImmutableSet()); + ImmutableMap hostToDatacenterSnapshot = hostToDatacenter; // volatile read Map triedDatacenters = triedNodes.stream() - .map(hostToDatacenter::get) + .map(hostToDatacenterSnapshot::get) .filter(Objects::nonNull) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); Optional maximumAttemptsPerDatacenter = @@ -351,13 +349,13 @@ public Optional getRandomGoodHostForPredicate( .keys() .collect(Collectors.toSet()); - Set hostsInPermittedDatacenters = hostsMatchingPredicate.stream() + ImmutableSet hostsInPermittedDatacenters = hostsMatchingPredicate.stream() .filter(pool -> { - String datacenter = hostToDatacenter.get(pool); + String datacenter = hostToDatacenterSnapshot.get(pool); return datacenter == null || !maximallyAttemptedDatacenters.contains(datacenter); }) - .collect(Collectors.toSet()); - Set filteredHosts = + .collect(ImmutableSet.toImmutableSet()); + ImmutableSet filteredHosts = hostsInPermittedDatacenters.isEmpty() ? hostsMatchingPredicate : hostsInPermittedDatacenters; if (filteredHosts.isEmpty()) { @@ -373,7 +371,7 @@ public Optional getRandomGoodHostForPredicate( } Optional randomLivingHost = getRandomHostByActiveConnections(livingHosts); - return randomLivingHost.map(pools::get); + return randomLivingHost.map(currentPools::get); } public List> getAllNonBlacklistedHosts() { @@ -393,7 +391,7 @@ private String getRingViewDescription() { return CassandraLogHelper.tokenMap(tokenMap).toString(); } - public RangeMap> getTokenMap() { + public RangeMap> getTokenMap() { return tokenMap; } @@ -413,18 +411,15 @@ Set maybeFilterLocalHosts(Set hosts) { return hosts; } - private Optional getRandomHostByActiveConnections(Set desiredHosts) { - + @VisibleForTesting + Optional getRandomHostByActiveConnections(Set desiredHosts) { Set localFilteredHosts = maybeFilterLocalHosts(desiredHosts); - - Map matchingPools = KeyedStream.stream( - ImmutableMap.copyOf(currentPools)) - .filterKeys(localFilteredHosts::contains) - .collectToMap(); + ImmutableMap matchingPools = currentPools.entrySet().stream() + .filter(e -> localFilteredHosts.contains(e.getKey())) + .collect(ImmutableMap.toImmutableMap(Entry::getKey, Entry::getValue)); if (matchingPools.isEmpty()) { return Optional.empty(); } - return Optional.of(WeightedServers.create(matchingPools).getRandomServer()); } @@ -449,7 +444,7 @@ public void debugLogStateOfPool() { } public CassandraServer getRandomCassandraNodeForKey(byte[] key) { - List hostsForKey = getHostsFor(key); + ImmutableSet hostsForKey = getHostsFor(key); if (hostsForKey == null) { if (config.autoRefreshNodes()) { @@ -517,7 +512,7 @@ public CassandraClientPoolingContainer removePool(CassandraServer removedServerA } public void cacheInitialCassandraHosts() { - Set thriftSocket = getCurrentServerListFromConfig(); + ImmutableSet thriftSocket = getCurrentServerListFromConfig(); cassandraHosts = thriftSocket.stream() .sorted(Comparator.comparing(CassandraServer::cassandraHostName)) @@ -530,7 +525,7 @@ public void clearInitialCassandraHosts() { } @VisibleForTesting - void overrideHostToDatacenterMapping(Map hostToDatacenterOverride) { + void overrideHostToDatacenterMapping(ImmutableMap hostToDatacenterOverride) { this.hostToDatacenter = hostToDatacenterOverride; } } diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/WeightedServers.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/WeightedServers.java index 05e79e4cf1a..94db0f8035d 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/WeightedServers.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/pool/WeightedServers.java @@ -16,6 +16,7 @@ package com.palantir.atlasdb.keyvalue.cassandra.pool; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.palantir.atlasdb.keyvalue.cassandra.CassandraClientPoolingContainer; import com.palantir.logsafe.Preconditions; @@ -37,7 +38,7 @@ private WeightedServers(NavigableMap hosts) { public static WeightedServers create(Map pools) { Preconditions.checkArgument(!pools.isEmpty(), "pools should be non-empty"); - return new WeightedServers(buildHostsWeightedByActiveConnections(pools)); + return new WeightedServers(buildHostsWeightedByActiveConnections(ImmutableMap.copyOf(pools))); } /** @@ -50,7 +51,7 @@ public static WeightedServers create(Map buildHostsWeightedByActiveConnections( - Map pools) { + ImmutableMap pools) { Map openRequestsByHost = Maps.newHashMapWithExpectedSize(pools.size()); int totalOpenRequests = 0; diff --git a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolTest.java b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolTest.java index fbc030f6e33..2d075e0e29c 100644 --- a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolTest.java +++ b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/CassandraClientPoolTest.java @@ -110,11 +110,9 @@ public void setup() { when(config.getKeyspaceOrThrow()).thenReturn("ks"); blacklist = new Blacklist(config, Refreshable.only(runtimeConfig.unresponsiveHostBackoffTimeSeconds())); - doAnswer(invocation -> { - Set inetSocketAddresses = - runtimeConfig.servers().accept(new ThriftHostsExtractingVisitor()); - return inetSocketAddresses.stream().map(CassandraServer::of).collect(Collectors.toSet()); - }) + doAnswer(invocation -> runtimeConfig.servers().accept(ThriftHostsExtractingVisitor.INSTANCE).stream() + .map(CassandraServer::of) + .collect(ImmutableSet.toImmutableSet())) .when(cassandra) .getCurrentServerListFromConfig(); doAnswer(invocation -> poolServers.add(getInvocationAddress(invocation))) @@ -429,7 +427,7 @@ private CassandraServer getInvocationAddress(InvocationOnMock invocation) { private void setCassandraServersTo(CassandraServer... servers) { when(cassandra.refreshTokenRangesAndGetServers()) - .thenReturn(Arrays.stream(servers).collect(Collectors.toSet())); + .thenReturn(Arrays.stream(servers).collect(ImmutableSet.toImmutableSet())); } private CassandraClientPoolImpl createClientPool() { diff --git a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraServiceTest.java b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraServiceTest.java index 7d4f468189d..29c9df246df 100644 --- a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraServiceTest.java +++ b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/pool/CassandraServiceTest.java @@ -35,6 +35,7 @@ import java.net.UnknownHostException; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.junit.Test; @@ -62,8 +63,8 @@ public class CassandraServiceTest { @Test public void shouldOnlyReturnLocalHosts() { - Set hosts = ImmutableSet.of(SERVER_1, SERVER_2); - Set localHosts = ImmutableSet.of(SERVER_1); + ImmutableSet hosts = ImmutableSet.of(SERVER_1, SERVER_2); + ImmutableSet localHosts = ImmutableSet.of(SERVER_1); CassandraService cassandra = clientPoolWithServersAndParams(hosts, 1.0); @@ -245,6 +246,29 @@ public void selectsFromAllHostsIfDatacenterMappingNotAvailable() { assertThat(suggestedHosts).containsExactlyInAnyOrderElementsOf(allHosts); } + @Test + public void getRandomHostByActiveConnectionsReturnsDesiredHost() { + ImmutableSet servers = IntStream.range(0, 24) + .mapToObj(i1 -> CassandraServer.of(InetSocketAddress.createUnresolved("10.0.0." + i1, DEFAULT_PORT))) + .collect(ImmutableSet.toImmutableSet()); + try (CassandraService service = clientPoolWithParams(servers, servers, 1.0)) { + service.setLocalHosts(servers.stream().limit(8).collect(ImmutableSet.toImmutableSet())); + for (int i = 0; i < 500_000; i++) { + // select some random nodes + ImmutableSet desired = IntStream.generate( + () -> ThreadLocalRandom.current().nextInt(servers.size())) + .limit(3) + .mapToObj(i1 -> servers.asList().get(i1)) + .collect(ImmutableSet.toImmutableSet()); + assertThat(service.getRandomHostByActiveConnections(desired)) + .describedAs("Iteration %i - Expecting a node selected from desired: %s", i, desired) + .isPresent() + .get() + .satisfies(server -> assertThat(desired).contains(server)); + } + } + } + private Set getRecommendedHostsFromAThousandTrials( CassandraService cassandra, Set hosts) { return IntStream.range(0, 1_000) diff --git a/atlasdb-ete-tests/src/test/java/com/palantir/atlasdb/ete/CassandraRepairEteTest.java b/atlasdb-ete-tests/src/test/java/com/palantir/atlasdb/ete/CassandraRepairEteTest.java index 9c09c9e2e15..b4a93dc6e79 100644 --- a/atlasdb-ete-tests/src/test/java/com/palantir/atlasdb/ete/CassandraRepairEteTest.java +++ b/atlasdb-ete-tests/src/test/java/com/palantir/atlasdb/ete/CassandraRepairEteTest.java @@ -23,6 +23,7 @@ import com.datastax.driver.core.policies.DefaultRetryPolicy; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Range; import com.google.common.collect.RangeMap; import com.google.common.collect.RangeSet; @@ -278,7 +279,7 @@ private Map>> getFullTokenMap() @SuppressWarnings("UnstableApiUsage") private Map>> invert( - RangeMap> tokenMap) { + RangeMap> tokenMap) { Map>> invertedMap = new HashMap<>(); tokenMap.asMapOfRanges() .forEach((range, addresses) -> addresses.forEach(address -> { diff --git a/changelog/@unreleased/pr-6074.v2.yml b/changelog/@unreleased/pr-6074.v2.yml new file mode 100644 index 00000000000..5c3d259d6af --- /dev/null +++ b/changelog/@unreleased/pr-6074.v2.yml @@ -0,0 +1,9 @@ +type: improvement +improvement: + description: |- + Optimize CassandraService node selection. + + Explicitly use ImmutableSet and SetView to identify snapshots and views of cluster state + both for clarity and to avoid excessive intermediate collection copies on hot code paths. + links: + - https://github.com/palantir/atlasdb/pull/6074