diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheIT.java index b23aac08702df..44e4d02620805 100644 --- a/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/indices/IndicesRequestCacheIT.java @@ -34,6 +34,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; import org.opensearch.action.admin.cluster.node.stats.NodeStats; import org.opensearch.action.admin.cluster.node.stats.NodesStatsResponse; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; @@ -43,11 +44,17 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchType; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.routing.allocation.command.MoveAllocationCommand; +import org.opensearch.cluster.routing.allocation.decider.EnableAllocationDecider; import org.opensearch.common.settings.Settings; import org.opensearch.common.time.DateFormatter; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.FeatureFlags; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.env.NodeEnvironment; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.cache.request.RequestCacheStats; import org.opensearch.index.query.QueryBuilders; @@ -59,6 +66,8 @@ import org.opensearch.test.ParameterizedStaticSettingsOpenSearchIntegTestCase; import org.opensearch.test.hamcrest.OpenSearchAssertions; +import java.nio.file.Files; +import java.nio.file.Path; import java.time.ZoneId; import java.time.ZoneOffset; import java.time.ZonedDateTime; @@ -70,6 +79,7 @@ import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.opensearch.cluster.routing.allocation.decider.EnableAllocationDecider.CLUSTER_ROUTING_ALLOCATION_ENABLE_SETTING; import static org.opensearch.indices.IndicesRequestCache.INDICES_REQUEST_CACHE_STALENESS_THRESHOLD_SETTING; import static org.opensearch.indices.IndicesService.INDICES_CACHE_CLEANUP_INTERVAL_SETTING_KEY; import static org.opensearch.search.SearchService.CLUSTER_CONCURRENT_SEGMENT_SEARCH_SETTING; @@ -1240,6 +1250,101 @@ public void testStaleKeysCleanupWithMultipleIndices() throws Exception { }, cacheCleanIntervalInMillis * 2, TimeUnit.MILLISECONDS); } + public void testDeleteAndCreateSameIndexShardOnSameNode() throws Exception { + String node_1 = internalCluster().startNode(Settings.builder().build()); + Client client = client(node_1); + + logger.info("Starting a node"); + + assertThat(cluster().size(), equalTo(1)); + ClusterHealthResponse healthResponse = client().admin().cluster().prepareHealth().setWaitForNodes("1").execute().actionGet(); + assertThat(healthResponse.isTimedOut(), equalTo(false)); + + String indexName = "test"; + + logger.info("Creating an index: {} with 2 shards", indexName); + createIndex( + indexName, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ); + + ensureGreen(indexName); + + logger.info("Writing few docs and searching which will cache items"); + indexRandom(true, client.prepareIndex(indexName).setSource("k", "hello")); + indexRandom(true, client.prepareIndex(indexName).setSource("y", "hello again")); + SearchResponse resp = client.prepareSearch(indexName).setRequestCache(true).setQuery(QueryBuilders.termQuery("k", "hello")).get(); + assertSearchResponse(resp); + resp = client.prepareSearch(indexName).setRequestCache(true).setQuery(QueryBuilders.termQuery("y", "hello")).get(); + + RequestCacheStats stats = getNodeCacheStats(client); + assertTrue(stats.getMemorySizeInBytes() > 0); + + logger.info("Disabling allocation"); + Settings newSettings = Settings.builder() + .put(CLUSTER_ROUTING_ALLOCATION_ENABLE_SETTING.getKey(), EnableAllocationDecider.Allocation.NONE.name()) + .build(); + client().admin().cluster().prepareUpdateSettings().setTransientSettings(newSettings).execute().actionGet(); + + logger.info("Starting a second node"); + String node_2 = internalCluster().startDataOnlyNode(Settings.builder().build()); + assertThat(cluster().size(), equalTo(2)); + healthResponse = client().admin().cluster().prepareHealth().setWaitForNodes("2").execute().actionGet(); + assertThat(healthResponse.isTimedOut(), equalTo(false)); + + logger.info("Moving the shard:{} from node:{} to node:{}", indexName + "#0", node_1, node_2); + MoveAllocationCommand cmd = new MoveAllocationCommand(indexName, 0, node_1, node_2); + internalCluster().client().admin().cluster().prepareReroute().add(cmd).get(); + ClusterHealthResponse clusterHealth = client().admin() + .cluster() + .prepareHealth() + .setWaitForNoRelocatingShards(true) + .setWaitForNoInitializingShards(true) + .get(); + assertThat(clusterHealth.isTimedOut(), equalTo(false)); + + ClusterState state = client().admin().cluster().prepareState().get().getState(); + final Index index = state.metadata().index(indexName).getIndex(); + + assertBusy(() -> { + assertThat(Files.exists(shardDirectory(node_1, index, 0)), equalTo(false)); + assertThat(Files.exists(shardDirectory(node_2, index, 0)), equalTo(true)); + }); + + logger.info("Moving the shard: {} again from node:{} to node:{}", indexName + "#0", node_2, node_1); + cmd = new MoveAllocationCommand(indexName, 0, node_2, node_1); + internalCluster().client().admin().cluster().prepareReroute().add(cmd).get(); + clusterHealth = client().admin() + .cluster() + .prepareHealth() + .setWaitForNoRelocatingShards(true) + .setWaitForNoInitializingShards(true) + .get(); + assertThat(clusterHealth.isTimedOut(), equalTo(false)); + assertThat(Files.exists(shardDirectory(node_1, index, 0)), equalTo(true)); + + assertBusy(() -> { + assertThat(Files.exists(shardDirectory(node_1, index, 0)), equalTo(true)); + assertThat(Files.exists(shardDirectory(node_2, index, 0)), equalTo(false)); + }); + + logger.info("Clearing the cache for index:{}. And verify the request stats doesn't go negative", indexName); + ClearIndicesCacheRequest clearIndicesCacheRequest = new ClearIndicesCacheRequest(indexName); + client.admin().indices().clearCache(clearIndicesCacheRequest).actionGet(); + + stats = getNodeCacheStats(client(node_1)); + assertTrue(stats.getMemorySizeInBytes() == 0); + stats = getNodeCacheStats(client(node_2)); + assertTrue(stats.getMemorySizeInBytes() == 0); + } + + private Path shardDirectory(String server, Index index, int shard) { + NodeEnvironment env = internalCluster().getInstance(NodeEnvironment.class, server); + final Path[] paths = env.availableShardPaths(new ShardId(index, shard)); + assert paths.length == 1; + return paths[0]; + } + private void setupIndex(Client client, String index) throws Exception { assertAcked( client.admin() diff --git a/server/src/main/java/org/opensearch/index/cache/request/ShardRequestCache.java b/server/src/main/java/org/opensearch/index/cache/request/ShardRequestCache.java index 502eae55df83e..906c1c5f86ef3 100644 --- a/server/src/main/java/org/opensearch/index/cache/request/ShardRequestCache.java +++ b/server/src/main/java/org/opensearch/index/cache/request/ShardRequestCache.java @@ -32,10 +32,13 @@ package org.opensearch.index.cache.request; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.util.Accountable; import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.metrics.CounterMetric; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.indices.IndicesRequestCache; /** * Tracks the portion of the request cache in use for a particular shard. @@ -45,6 +48,7 @@ @PublicApi(since = "1.0.0") public final class ShardRequestCache { + private static final Logger logger = LogManager.getLogger(ShardRequestCache.class); final CounterMetric evictionsMetric = new CounterMetric(); final CounterMetric totalMetric = new CounterMetric(); final CounterMetric hitCount = new CounterMetric(); @@ -75,7 +79,12 @@ public void onRemoval(long keyRamBytesUsed, BytesReference value, boolean evicte if (value != null) { dec += value.ramBytesUsed(); } - totalMetric.dec(dec); + if ((totalMetric.count() - dec) < 0) { + logger.warn("Ignoring the operation to deduct memory: {} from RequestStats memory_size metric as it will " + + "go negative. Current memory: {}. This is a bug.", dec, totalMetric.count()); + } else { + totalMetric.dec(dec); + } } // Old functions which increment size by passing in an Accountable. Functional but no longer used. @@ -94,5 +103,11 @@ public void onRemoval(Accountable key, BytesReference value, boolean evicted) { if (value != null) { dec += value.ramBytesUsed(); } + if ((totalMetric.count() - dec) < 0) { + logger.warn("Ignoring the operation to deduct memory: {} from RequestStats memory_size metric as it will " + + "go negative. Current memory: {}. This is a bug.", dec, totalMetric.count()); + } else { + totalMetric.dec(dec); + } } } diff --git a/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java b/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java index 35826d45f969f..f6010282f3d9b 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java +++ b/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java @@ -205,6 +205,11 @@ public final class IndicesRequestCache implements RemovalListener, BytesReference> notifi // shards as part of request cache. // Pass a new removal notification containing Key rather than ICacheKey to the CacheEntity for backwards compatibility. Key key = notification.getKey().key; - cacheEntityLookup.apply(key.shardId).ifPresent(entity -> entity.onRemoval(notification)); - CleanupKey cleanupKey = new CleanupKey(cacheEntityLookup.apply(key.shardId).orElse(null), key.readerCacheKeyId); + IndicesService.IndexShardCacheEntity indexShardCacheEntity = (IndicesService.IndexShardCacheEntity) cacheEntityLookup.apply( + key.shardId + ).orElse(null); + if (indexShardCacheEntity != null) { + // Here we match the hashcode to avoid scenario where we deduct stats of older IndexShard(with same + // shardId) from current IndexShard. + if (key.indexShardHashCode == indexShardCacheEntity.getCacheIdentity().hashCode()) { + indexShardCacheEntity.onRemoval(notification); + } + } + CleanupKey cleanupKey = new CleanupKey(indexShardCacheEntity, key.readerCacheKeyId); cacheCleanupManager.updateStaleCountOnEntryRemoval(cleanupKey, notification); } @@ -266,7 +280,8 @@ BytesReference getOrCompute( .getReaderCacheHelper(); String readerCacheKeyId = delegatingCacheHelper.getDelegatingCacheKey().getId(); assert readerCacheKeyId != null; - final Key key = new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId); + IndexShard indexShard = ((IndexShard) cacheEntity.getCacheIdentity()); + final Key key = new Key(indexShard.shardId(), cacheKey, readerCacheKeyId, indexShard.hashCode()); Loader cacheLoader = new Loader(cacheEntity, loader); BytesReference value = cache.computeIfAbsent(getICacheKey(key), cacheLoader); if (cacheLoader.isLoaded()) { @@ -299,7 +314,8 @@ void invalidate(IndicesService.IndexShardCacheEntity cacheEntity, DirectoryReade IndexReader.CacheHelper cacheHelper = ((OpenSearchDirectoryReader) reader).getDelegatingCacheHelper(); readerCacheKeyId = ((OpenSearchDirectoryReader.DelegatingCacheHelper) cacheHelper).getDelegatingCacheKey().getId(); } - cache.invalidate(getICacheKey(new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId))); + IndexShard indexShard = (IndexShard) cacheEntity.getCacheIdentity(); + cache.invalidate(getICacheKey(new Key(indexShard.shardId(), cacheKey, readerCacheKeyId, indexShard.hashCode()))); } /** @@ -377,19 +393,24 @@ interface CacheEntity extends Accountable { */ static class Key implements Accountable, Writeable { public final ShardId shardId; // use as identity equality + public final int indexShardHashCode; // While ShardId is usually sufficient to uniquely identify an + // indexShard but in case where the same indexShard is deleted and reallocated on same node, we need the + // hashcode(default) to identify the older indexShard but with same shardId. public final String readerCacheKeyId; public final BytesReference value; - Key(ShardId shardId, BytesReference value, String readerCacheKeyId) { + Key(ShardId shardId, BytesReference value, String readerCacheKeyId, int indexShardHashCode) { this.shardId = shardId; this.value = value; this.readerCacheKeyId = Objects.requireNonNull(readerCacheKeyId); + this.indexShardHashCode = indexShardHashCode; } Key(StreamInput in) throws IOException { this.shardId = in.readOptionalWriteable(ShardId::new); this.readerCacheKeyId = in.readOptionalString(); this.value = in.readBytesReference(); + this.indexShardHashCode = in.readInt(); } @Override @@ -411,6 +432,7 @@ public boolean equals(Object o) { if (!Objects.equals(readerCacheKeyId, key.readerCacheKeyId)) return false; if (!shardId.equals(key.shardId)) return false; if (!value.equals(key.value)) return false; + if (indexShardHashCode != key.indexShardHashCode) return false; return true; } @@ -419,6 +441,7 @@ public int hashCode() { int result = shardId.hashCode(); result = 31 * result + readerCacheKeyId.hashCode(); result = 31 * result + value.hashCode(); + result = 31 * result + indexShardHashCode; return result; } @@ -427,6 +450,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(shardId); out.writeOptionalString(readerCacheKeyId); out.writeBytesReference(value); + out.writeInt(indexShardHashCode); } } diff --git a/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java b/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java index af657dadd7a1a..a5014675ce0ed 100644 --- a/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java +++ b/server/src/test/java/org/opensearch/indices/IRCKeyWriteableSerializerTests.java @@ -45,6 +45,7 @@ private IndicesRequestCache.Key getRandomIRCKey(int valueLength, Random random, value[i] = (byte) (random.nextInt(126 - 32) + 32); } BytesReference keyValue = new BytesArray(value); - return new IndicesRequestCache.Key(shard, keyValue, UUID.randomUUID().toString()); // same UUID source as used in real key + return new IndicesRequestCache.Key(shard, keyValue, UUID.randomUUID().toString(), shard.hashCode()); // same UUID + // source as used in real key } } diff --git a/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java b/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java index bbf2867a0087c..d95bb4e659ac7 100644 --- a/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java +++ b/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java @@ -82,14 +82,22 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Random; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicInteger; import static org.opensearch.indices.IndicesRequestCache.INDEX_DIMENSION_NAME; +import static org.opensearch.indices.IndicesRequestCache.INDICES_CACHE_QUERY_SIZE; import static org.opensearch.indices.IndicesRequestCache.INDICES_REQUEST_CACHE_STALENESS_THRESHOLD_SETTING; import static org.opensearch.indices.IndicesRequestCache.SHARD_ID_DIMENSION_NAME; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; @@ -460,7 +468,12 @@ public void testStaleCount_OnRemovalNotificationOfStaleKey_DecrementsStaleCount( // cache count should not be affected assertEquals(2, cache.count()); - IndicesRequestCache.Key key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(reader)); + IndicesRequestCache.Key key = new IndicesRequestCache.Key( + indexShard.shardId(), + getTermBytes(), + getReaderCacheKeyId(reader), + indexShard.hashCode() + ); // test the mapping ConcurrentMap> cleanupKeyToCountMap = cache.cacheCleanupManager.getCleanupKeyToCountMap(); // shard id should exist @@ -517,7 +530,12 @@ public void testStaleCount_OnRemovalNotificationOfNonStaleKey_DoesNotDecrementsS assertEquals(2, cache.count()); // evict entry from second reader (this reader is not closed) - IndicesRequestCache.Key key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(secondReader)); + IndicesRequestCache.Key key = new IndicesRequestCache.Key( + indexShard.shardId(), + getTermBytes(), + getReaderCacheKeyId(secondReader), + indexShard.hashCode() + ); // test the mapping ConcurrentMap> cleanupKeyToCountMap = cache.cacheCleanupManager.getCleanupKeyToCountMap(); @@ -567,7 +585,12 @@ public void testStaleCount_WithoutReaderClosing_DecrementsStaleCount() throws Ex // no keys are stale assertEquals(0, cache.cacheCleanupManager.getStaleKeysCount().get()); // create notification for removal of non-stale entry - IndicesRequestCache.Key key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(reader)); + IndicesRequestCache.Key key = new IndicesRequestCache.Key( + indexShard.shardId(), + getTermBytes(), + getReaderCacheKeyId(reader), + indexShard.hashCode() + ); cache.onRemoval( new RemovalNotification, BytesReference>( new ICacheKey<>(key), @@ -610,11 +633,8 @@ public void testStaleCount_OnRemovalNotifications() throws Exception { assertEquals(totalKeys, cache.cacheCleanupManager.getStaleKeysCount().get()); String readerCacheKeyId = getReaderCacheKeyId(reader); - IndicesRequestCache.Key key = new IndicesRequestCache.Key( - ((IndexShard) entity.getCacheIdentity()).shardId(), - termBytes, - readerCacheKeyId - ); + IndexShard indexShard = (IndexShard) entity.getCacheIdentity(); + IndicesRequestCache.Key key = new IndicesRequestCache.Key(indexShard.shardId(), termBytes, readerCacheKeyId, indexShard.hashCode()); int staleCount = cache.cacheCleanupManager.getStaleKeysCount().get(); // Notification for Replaced should not deduct the staleCount @@ -709,7 +729,12 @@ public void testCleanupKeyToCountMapAreSetAppropriately() throws Exception { // second reader's mapping should not be affected assertEquals(2, (int) cleanupKeyToCountMap.get(shardId).get(getReaderCacheKeyId(secondReader))); // send removal notification for first reader - IndicesRequestCache.Key key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(reader)); + IndicesRequestCache.Key key = new IndicesRequestCache.Key( + indexShard.shardId(), + getTermBytes(), + getReaderCacheKeyId(reader), + indexShard.hashCode() + ); cache.onRemoval( new RemovalNotification, BytesReference>( new ICacheKey<>(key), @@ -725,7 +750,7 @@ public void testCleanupKeyToCountMapAreSetAppropriately() throws Exception { assertEquals(2, (int) cleanupKeyToCountMap.get(shardId).get(getReaderCacheKeyId(secondReader))); // Without closing the secondReader send removal notification of one of its key - key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(secondReader)); + key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(secondReader), indexShard.hashCode()); cache.onRemoval( new RemovalNotification, BytesReference>( new ICacheKey<>(key), @@ -738,7 +763,7 @@ public void testCleanupKeyToCountMapAreSetAppropriately() throws Exception { // secondReader's readerCacheKeyId count should be decremented by 1 assertEquals(1, (int) cleanupKeyToCountMap.get(shardId).get(getReaderCacheKeyId(secondReader))); // Without closing the secondReader send removal notification of its last key - key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(secondReader)); + key = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), getReaderCacheKeyId(secondReader), indexShard.hashCode()); cache.onRemoval( new RemovalNotification, BytesReference>( new ICacheKey<>(key), @@ -1152,11 +1177,11 @@ public void testEqualsKey() throws IOException { IOUtils.close(reader1, reader2, writer, dir); IndexShard indexShard = mock(IndexShard.class); when(indexShard.state()).thenReturn(IndexShardState.STARTED); - IndicesRequestCache.Key key1 = new IndicesRequestCache.Key(shardId, new TestBytesReference(1), rKey1); - IndicesRequestCache.Key key2 = new IndicesRequestCache.Key(shardId, new TestBytesReference(1), rKey1); - IndicesRequestCache.Key key3 = new IndicesRequestCache.Key(shardId1, new TestBytesReference(1), rKey1); - IndicesRequestCache.Key key4 = new IndicesRequestCache.Key(shardId, new TestBytesReference(1), rKey2); - IndicesRequestCache.Key key5 = new IndicesRequestCache.Key(shardId, new TestBytesReference(2), rKey2); + IndicesRequestCache.Key key1 = new IndicesRequestCache.Key(shardId, new TestBytesReference(1), rKey1, shardId.hashCode()); + IndicesRequestCache.Key key2 = new IndicesRequestCache.Key(shardId, new TestBytesReference(1), rKey1, shardId.hashCode()); + IndicesRequestCache.Key key3 = new IndicesRequestCache.Key(shardId1, new TestBytesReference(1), rKey1, shardId1.hashCode()); + IndicesRequestCache.Key key4 = new IndicesRequestCache.Key(shardId, new TestBytesReference(1), rKey2, shardId.hashCode()); + IndicesRequestCache.Key key5 = new IndicesRequestCache.Key(shardId, new TestBytesReference(2), rKey2, shardId.hashCode()); String s = "Some other random object"; assertEquals(key1, key1); assertEquals(key1, key2); @@ -1170,7 +1195,12 @@ public void testEqualsKey() throws IOException { public void testSerializationDeserializationOfCacheKey() throws Exception { IndicesService.IndexShardCacheEntity shardCacheEntity = new IndicesService.IndexShardCacheEntity(indexShard); String readerCacheKeyId = UUID.randomUUID().toString(); - IndicesRequestCache.Key key1 = new IndicesRequestCache.Key(indexShard.shardId(), getTermBytes(), readerCacheKeyId); + IndicesRequestCache.Key key1 = new IndicesRequestCache.Key( + indexShard.shardId(), + getTermBytes(), + readerCacheKeyId, + indexShard.hashCode() + ); BytesReference bytesReference = null; try (BytesStreamOutput out = new BytesStreamOutput()) { key1.writeTo(out); @@ -1185,6 +1215,99 @@ public void testSerializationDeserializationOfCacheKey() throws Exception { assertEquals(getTermBytes(), key2.value); } + public void testGetOrComputeConcurrentlyWithMultipleIndices() throws Exception { + threadPool = getThreadPool(); + int numberOfIndices = randomIntBetween(2, 5); + List indicesList = new ArrayList<>(); + List indexShardList = Collections.synchronizedList(new ArrayList<>()); + for (int i = 0; i < numberOfIndices; i++) { + String indexName = "test" + i; + indicesList.add(indexName); + IndexShard indexShard = createIndex( + indexName, + Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).build() + ).getShard(0); + indexShardList.add(indexShard); + } + // Create a cache with 2kb to cause evictions and test that flow as well. + IndicesRequestCache cache = getIndicesRequestCache(Settings.builder().put(INDICES_CACHE_QUERY_SIZE.getKey(), "2kb").build()); + Map readerMap = new ConcurrentHashMap<>(); + Map entityMap = new ConcurrentHashMap<>(); + Map writerMap = new ConcurrentHashMap<>(); + int numberOfItems = randomIntBetween(200, 400); + for (int i = 0; i < numberOfIndices; i++) { + IndexShard indexShard = indexShardList.get(i); + entityMap.put(indexShard, new IndicesService.IndexShardCacheEntity(indexShard)); + Directory dir = newDirectory(); + IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); + for (int j = 0; j < numberOfItems; j++) { + writer.addDocument(newDoc(j, generateString(randomIntBetween(4, 50)))); + } + writerMap.put(indexShard, writer); + DirectoryReader reader = OpenSearchDirectoryReader.wrap(DirectoryReader.open(writer), indexShard.shardId()); + readerMap.put(indexShard, reader); + } + + CountDownLatch latch = new CountDownLatch(numberOfItems); + ExecutorService executorService = Executors.newFixedThreadPool(5); + for (int i = 0; i < numberOfItems; i++) { + int finalI = i; + executorService.submit(() -> { + int randomIndexPosition = randomIntBetween(0, numberOfIndices - 1); + IndexShard indexShard = indexShardList.get(randomIndexPosition); + TermQueryBuilder termQuery = new TermQueryBuilder("id", generateString(randomIntBetween(4, 50))); + BytesReference termBytes = null; + try { + termBytes = XContentHelper.toXContent(termQuery, MediaTypeRegistry.JSON, false); + } catch (IOException e) { + throw new RuntimeException(e); + } + Loader loader = new Loader(readerMap.get(indexShard), finalI); + try { + cache.getOrCompute(entityMap.get(indexShard), loader, readerMap.get(indexShard), termBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + latch.countDown(); + }); + } + for (int i = 0; i < numberOfIndices; i++) { + IndexShard indexShard = indexShardList.get(i); + IndicesService.IndexShardCacheEntity entity = entityMap.get(indexShard); + RequestCacheStats stats = entity.stats().stats(); + assertTrue(stats.getMemorySizeInBytes() >= 0); + assertTrue(stats.getMissCount() >= 0); + assertTrue(stats.getEvictions() >= 0); + } + cache.invalidateAll(); + for (int i = 0; i < numberOfIndices; i++) { + IndexShard indexShard = indexShardList.get(i); + IndicesService.IndexShardCacheEntity entity = entityMap.get(indexShard); + RequestCacheStats stats = entity.stats().stats(); + assertEquals(0, stats.getMemorySizeInBytes()); + } + + for (int i = 0; i < numberOfIndices; i++) { + IndexShard indexShard = indexShardList.get(i); + readerMap.get(indexShard).close(); + writerMap.get(indexShard).close(); + writerMap.get(indexShard).getDirectory().close(); + } + IOUtils.close(cache); + executorService.shutdownNow(); + } + + public static String generateString(int length) { + String characters = "abcdefghijklmnopqrstuvwxyz"; + StringBuilder sb = new StringBuilder(length); + Random random = new Random(); + for (int i = 0; i < length; i++) { + int index = random.nextInt(characters.length()); + sb.append(characters.charAt(index)); + } + return sb.toString(); + } + private class TestBytesReference extends AbstractBytesReference { int dummyValue;