Skip to content

Commit

Permalink
SOLR-17574: Move host allow list cache to AllowListUrlChecker.
Browse files Browse the repository at this point in the history
  • Loading branch information
bruno-roustant committed Dec 2, 2024
1 parent cebdb2d commit 04c3c97
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 41 deletions.
2 changes: 2 additions & 0 deletions solr/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ Bug Fixes

* SOLR-17575: Fixed broken backwards compatibility with the legacy "langid.whitelist" config in Solr Langid. (Jan Høydahl, Alexander Zagniotov)

* SOLR-17574: Move host allow list cache to AllowListUrlChecker (Bruno Roustant, David Smiley)

Dependency Upgrades
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import org.apache.solr.common.SolrException;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.core.NodeConfig;
Expand Down Expand Up @@ -84,6 +86,8 @@ public String toString() {

/** Allow list of hosts. Elements in the list will be host:port (no protocol or context). */
private final Set<String> hostAllowList;
private volatile Set<String> liveHostUrlsCache;
private volatile Set<String> liveNodesCache;

/**
* @param urlAllowList List of allowed URLs. URLs must be well-formed, missing protocol is
Expand Down Expand Up @@ -136,11 +140,10 @@ public void checkAllowList(List<String> urls) throws MalformedURLException {
*/
public void checkAllowList(List<String> urls, ClusterState clusterState)
throws MalformedURLException {
Set<String> clusterHostAllowList =
clusterState == null ? Collections.emptySet() : clusterState.getHostAllowList();
Set<String> liveHostUrls = getLiveHostUrls(clusterState);
for (String url : urls) {
String hostPort = parseHostPort(url);
if (clusterHostAllowList.stream().noneMatch(hostPort::equalsIgnoreCase)
if (liveHostUrls.stream().noneMatch(hostPort::equalsIgnoreCase)
&& hostAllowList.stream().noneMatch(hostPort::equalsIgnoreCase)) {
throw new SolrException(
SolrException.ErrorCode.FORBIDDEN,
Expand All @@ -154,6 +157,29 @@ public void checkAllowList(List<String> urls, ClusterState clusterState)
}
}

/**
 * Gets the set of live host URLs (host:port) built from the set of live nodes.
 * The set is cached and reused until the live nodes set instance changes.
 *
 * @param clusterState the cluster state; may be null, in which case an empty set is returned
 * @return the set of live host URLs, never null
 */
private Set<String> getLiveHostUrls(ClusterState clusterState) {
  if (clusterState == null) {
    return Set.of();
  }
  Set<String> liveNodes = clusterState.getLiveNodes();
  // Rebuild only when the live nodes set instance changes (identity check is
  // intentional: ClusterState swaps in a new set when nodes change).
  if (liveHostUrlsCache == null || liveNodes != liveNodesCache) {
    // Double-checked locking keeps the two volatile cache fields consistent with
    // each other. Without it, two concurrent callers holding different live-nodes
    // sets could interleave the writes and pair the URLs of one set with the
    // identity of the other, serving a stale allow list until the next change.
    synchronized (this) {
      if (liveHostUrlsCache == null || liveNodes != liveNodesCache) {
        liveHostUrlsCache = buildLiveHostUrls(liveNodes);
        liveNodesCache = liveNodes;
      }
    }
  }
  return liveHostUrlsCache;
}

/**
 * Builds the set of live host URLs (host:port) from the given live node names.
 * A node name has the form {@code host:port_context}; everything before the first
 * underscore is the host URL.
 */
@VisibleForTesting
Set<String> buildLiveHostUrls(Set<String> liveNodes) {
  return liveNodes.stream()
      .map(
          nodeName -> {
            // Keep only the "host:port" prefix of the node name.
            int separatorIndex = nodeName.indexOf('_');
            return nodeName.substring(0, separatorIndex);
          })
      .collect(Collectors.toSet());
}

/** Whether this checker has been created with a non-empty allow-list of URLs. */
public boolean hasExplicitAllowList() {
return !hostAllowList.isEmpty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,16 @@
package org.apache.solr.handler.component;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.client.solrj.impl.LBSolrClient;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.ShardParams;
Expand Down Expand Up @@ -155,18 +150,6 @@ public void getShardsAllowList() {
}
}

/** Verifies that live node names are converted to host:port entries in the host allow list. */
@Test
public void testLiveNodesToHostUrl() {
  Set<String> nodeNames =
      new HashSet<>(
          Arrays.asList("1.2.3.4:8983_solr", "1.2.3.4:9000_", "1.2.3.4:9001_solr-2"));
  ClusterState clusterState = new ClusterState(nodeNames, new HashMap<>());
  Set<String> allowList = clusterState.getHostAllowList();
  // One host:port entry per node, with the context suffix stripped.
  assertThat(allowList.size(), is(3));
  assertThat(allowList, hasItem("1.2.3.4:8983"));
  assertThat(allowList, hasItem("1.2.3.4:9000"));
  assertThat(allowList, hasItem("1.2.3.4:9001"));
}

@Test
public void testXML() {
Path home = TEST_PATH();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@
import java.net.MalformedURLException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.cloud.ClusterState;
import org.junit.Test;

/** Tests {@link AllowListUrlChecker}. */
Expand Down Expand Up @@ -196,6 +200,46 @@ public void testHostParsingNoProtocol() throws Exception {
equalTo(AllowListUrlChecker.parseHostPorts(urls("https://abc-1.com:8983/solr"))));
}

/**
 * Verifies that {@link AllowListUrlChecker} builds the live host URL set only once per
 * distinct live-nodes set, and rebuilds it exactly once when the live nodes change.
 */
@Test
public void testLiveNodesToHostUrlCache() throws Exception {
  // Given some live nodes defined in the cluster state.
  Set<String> initialNodes =
      new HashSet<>(Arrays.asList("1.2.3.4:8983_solr", "1.2.3.4:9000_", "1.2.3.4:9001_solr-2"));
  ClusterState stateBefore = new ClusterState(initialNodes, new HashMap<>());

  // When we call the AllowListUrlChecker.checkAllowList method on both valid and invalid urls.
  AtomicInteger buildCount = new AtomicInteger();
  AllowListUrlChecker checker =
      new AllowListUrlChecker(List.of()) {
        @Override
        Set<String> buildLiveHostUrls(Set<String> liveNodes) {
          // Count how many times the live host urls are (re)built.
          buildCount.incrementAndGet();
          return super.buildLiveHostUrls(liveNodes);
        }
      };
  for (int attempt = 0; attempt < 3; attempt++) {
    checker.checkAllowList(
        List.of("1.2.3.4:8983", "1.2.3.4:9000", "1.2.3.4:9001"), stateBefore);
    SolrException forbidden =
        expectThrows(
            SolrException.class,
            () -> checker.checkAllowList(List.of("1.1.3.4:8983"), stateBefore));
    assertThat(forbidden.code(), equalTo(SolrException.ErrorCode.FORBIDDEN.code));
  }
  // Then the checker caches the live host urls and only builds them once.
  assertThat(buildCount.get(), equalTo(1));

  // And when the ClusterState live nodes change.
  Set<String> updatedNodes =
      new HashSet<>(Arrays.asList("2.3.4.5:8983_solr", "2.3.4.5:9000_", "2.3.4.5:9001_solr-2"));
  ClusterState stateAfter = new ClusterState(updatedNodes, new HashMap<>());
  for (int attempt = 0; attempt < 3; attempt++) {
    checker.checkAllowList(
        List.of("2.3.4.5:8983", "2.3.4.5:9000", "2.3.4.5:9001"), stateAfter);
    SolrException forbidden =
        expectThrows(
            SolrException.class,
            () -> checker.checkAllowList(List.of("1.1.3.4:8983"), stateAfter));
    assertThat(forbidden.code(), equalTo(SolrException.ErrorCode.FORBIDDEN.code));
  }
  // Then the checker rebuilds the cache of live host urls.
  assertThat(buildCount.get(), equalTo(2));
}

/** Convenience helper: wraps the given varargs into a fixed-size list backed by the array. */
private static List<String> urls(String... urlValues) {
  return Arrays.asList(urlValues);
}
Expand Down
27 changes: 6 additions & 21 deletions solr/solrj/src/java/org/apache/solr/common/cloud/ClusterState.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.solr.common.MapWriter;
import org.apache.solr.common.SolrException;
Expand All @@ -52,6 +51,8 @@
/**
* Immutable state of the cloud. Normally you can get the state by using {@code
* ZkStateReader#getClusterState()}.
* <p>
* However, the {@link #setLiveNodes list of live nodes} is updated when nodes go up and down.
*
* @lucene.experimental
*/
Expand All @@ -63,8 +64,7 @@ public class ClusterState implements MapWriter {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private final Map<String, CollectionRef> collectionStates, immutableCollectionStates;
private Set<String> liveNodes;
private Set<String> hostAllowList;
private volatile Set<String> liveNodes;

/** Use this constr when ClusterState is meant for consumption. */
public ClusterState(Set<String> liveNodes, Map<String, DocCollection> collectionStates) {
Expand All @@ -85,8 +85,7 @@ private static Map<String, CollectionRef> getRefMap(Map<String, DocCollection> c
* loaded (parameter order different from constructor above to have different erasures)
*/
public ClusterState(Map<String, CollectionRef> collectionStates, Set<String> liveNodes) {
  // Store the live nodes through setLiveNodes so every code path snapshots them
  // the same way. The previous direct HashSet copy here was dead code: it was
  // immediately overwritten by the setLiveNodes call below.
  setLiveNodes(liveNodes);
  this.collectionStates = new LinkedHashMap<>(collectionStates);
  this.immutableCollectionStates = Collections.unmodifiableMap(this.collectionStates);
}
Expand Down Expand Up @@ -189,7 +188,7 @@ public Map<String, DocCollection> getCollectionsMap() {

/** Get names of the currently live nodes. Returns an unmodifiable view; callers must not mutate. */
public Set<String> getLiveNodes() {
  // Defensive wrap; the second, unreachable return statement that followed was
  // dead code and has been removed.
  return Collections.unmodifiableSet(liveNodes);
}

@Deprecated
Expand Down Expand Up @@ -387,7 +386,7 @@ public boolean equals(Object obj) {

/**
 * Internal API used only by ZkStateReader.
 *
 * @param liveNodes names of the currently live nodes; snapshotted into an immutable set so
 *     later mutation by the caller cannot affect this ClusterState
 */
void setLiveNodes(Set<String> liveNodes) {
  // The previous first assignment of the caller's set was dead code (immediately
  // overwritten) and would have exposed our state to caller-side mutation.
  // Set.copyOf produces the same immutable snapshot as Set.of(liveNodes.toArray(...))
  // without the intermediate array.
  this.liveNodes = Set.copyOf(liveNodes);
}

/**
Expand All @@ -401,20 +400,6 @@ public Map<String, CollectionRef> getCollectionStates() {
return immutableCollectionStates;
}

/**
 * Gets the set of allowed hosts (host:port) derived from the live nodes. Built lazily on the
 * first call and cached for reuse on subsequent calls.
 */
public Set<String> getHostAllowList() {
  // Single read of the (lazily initialized) cache field.
  Set<String> cachedAllowList = hostAllowList;
  if (cachedAllowList == null) {
    cachedAllowList =
        getLiveNodes().stream()
            // Node names look like "host:port_context"; keep only "host:port".
            .map(nodeName -> nodeName.substring(0, nodeName.indexOf('_')))
            .collect(Collectors.toSet());
    hostAllowList = cachedAllowList;
  }
  return cachedAllowList;
}

/**
* Streams the resolved {@link DocCollection}s, which will often fetch from ZooKeeper for each one
* for a many-collection scenario. Use this sparingly; some users have thousands of collections!
Expand Down

0 comments on commit 04c3c97

Please sign in to comment.