diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e5479e4e..ad8f13179 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,3 +38,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) * Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) * Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957) +* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 4ced38b38..73f43d3d1 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -24,6 +24,7 @@ import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto; import org.opensearch.knn.index.util.IndexHyperParametersUtil; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.monitor.jvm.JvmInfo; import org.opensearch.monitor.os.OsProbe; @@ -88,6 +89,8 @@ public class KNNSettings { * for native engines. */ public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled"; + public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit"; + public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes"; /** * Default setting values @@ -106,6 +109,11 @@ public class KNNSettings { public static final String KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT = "1%"; public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1; + public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 5; // By default, set aside 5% of the JVM for + // the limit + public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed + // 10% of the JVM heap + public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; /** * Settings Definition @@ -272,6 +280,44 @@ public class KNNSettings { NodeScope ); + /* + * Quantization state cache settings + */ + public static final Setting QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING = new Setting( + QUANTIZATION_STATE_CACHE_SIZE_LIMIT, + percentageAsString(KNN_DEFAULT_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE), + (s) -> { + ByteSizeValue userDefinedLimit = parseBytesSizeValueOrHeapRatio(s, QUANTIZATION_STATE_CACHE_SIZE_LIMIT); + + // parseBytesSizeValueOrHeapRatio will make sure that the value entered falls between 0 and 100% of the + // JVM heap. However, we want the maximum percentage of the heap to be much smaller. So, we add + // some additional validation here before returning + ByteSizeValue jvmHeapSize = JvmInfo.jvmInfo().getMem().getHeapMax(); + if ((userDefinedLimit.getKbFrac() / jvmHeapSize.getKbFrac()) > percentageAsFraction( + KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE + )) { + throw new OpenSearchParseException( + "{} ({} KB) cannot exceed {}% of the heap ({} KB).", + QUANTIZATION_STATE_CACHE_SIZE_LIMIT, + userDefinedLimit.getKb(), + KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE, + jvmHeapSize.getKb() + ); + } + + return userDefinedLimit; + }, + NodeScope, + Dynamic + ); + + public static final Setting QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING = Setting.positiveTimeSetting( + QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES, + TimeValue.timeValueMinutes(KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES), + NodeScope, + Dynamic + ); + /** * Dynamic settings */ @@ -349,6 +395,13 @@ private void setSettingsUpdateConsumers() { NativeMemoryCacheManager.getInstance().rebuildCache(builder.build()); }, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList())); + clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> { + QuantizationStateCache.getInstance().setMaxCacheSizeInKB(it.getKb()); + QuantizationStateCache.getInstance().rebuildCache(); + }); + clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> { + QuantizationStateCache.getInstance().rebuildCache(); + }); } /** @@ -400,6 +453,14 @@ private Setting getSetting(String key) { return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING; } + if (QUANTIZATION_STATE_CACHE_SIZE_LIMIT.equals(key)) { + return QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING; + } + + if (QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES.equals(key)) { + return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -419,7 +480,9 @@ public List> getSettings() { ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, KNN_FAISS_AVX2_DISABLED_SETTING, KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, - KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING, + QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, + QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java new file mode 100644 index 000000000..ba26d517d --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalCause; +import com.google.common.cache.RemovalNotification; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; + +import java.io.IOException; +import java.time.Instant; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT; + +/** + * A thread-safe singleton cache that contains quantization states. + */ +@Log4j2 +public class QuantizationStateCache { + + private static volatile QuantizationStateCache instance; + private Cache cache; + @Getter + @Setter + private long maxCacheSizeInKB; + @Getter + private Instant evictedDueToSizeAt; + + @VisibleForTesting + QuantizationStateCache() { + maxCacheSizeInKB = ((ByteSizeValue) KNNSettings.state().getSettingValue(QUANTIZATION_STATE_CACHE_SIZE_LIMIT)).getKb(); + buildCache(); + } + + /** + * Gets the singleton instance of the cache. + * @return QuantizationStateCache + */ + public static QuantizationStateCache getInstance() { + if (instance == null) { + synchronized (QuantizationStateCache.class) { + if (instance == null) { + instance = new QuantizationStateCache(); + } + } + } + return instance; + } + + private void buildCache() { + this.cache = CacheBuilder.newBuilder().concurrencyLevel(1).maximumWeight(maxCacheSizeInKB).weigher((k, v) -> { + try { + return ((QuantizationState) v).toByteArray().length; + } catch (IOException e) { + throw new RuntimeException(e); + } + }) + .expireAfterAccess( + ((TimeValue) KNNSettings.state().getSettingValue(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES)).getMinutes(), + TimeUnit.MINUTES + ) + .removalListener(this::onRemoval) + .build(); + } + + public synchronized void rebuildCache() { + clear(); + buildCache(); + } + + /** + * Retrieves the quantization state associated with a given field name. + * @param fieldName The name of the field. + * @return The associated QuantizationState, or null if not present. + */ + public QuantizationState getQuantizationState(String fieldName) { + return cache.getIfPresent(fieldName); + } + + /** + * Adds or updates a quantization state in the cache. + * @param fieldName The name of the field. + * @param quantizationState The quantization state to store. + */ + public void addQuantizationState(String fieldName, QuantizationState quantizationState) { + cache.put(fieldName, quantizationState); + } + + /** + * Removes the quantization state associated with a given field name. + * @param fieldName The name of the field. + */ + public void evict(String fieldName) { + cache.invalidate(fieldName); + } + + private void onRemoval(RemovalNotification removalNotification) { + if (RemovalCause.SIZE == removalNotification.getCause()) { + updateEvictedDueToSizeAt(); + } + } + + private void updateEvictedDueToSizeAt() { + evictedDueToSizeAt = Instant.now(); + } + + /** + * Clears all entries from the cache. + */ + public void clear() { + cache.invalidateAll(); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java new file mode 100644 index 000000000..e5381aec7 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -0,0 +1,450 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import com.google.common.collect.ImmutableSet; +import lombok.SneakyThrows; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING; +import static org.opensearch.knn.quantization.enums.ScalarQuantizationType.ONE_BIT; + +public class QuantizationStateCacheTests extends KNNTestCase { + + @SneakyThrows + public void testSingleThreadedAddAndRetrieve() { + String fieldName = "singleThreadField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + + String cacheSize = "10%"; + TimeValue expiry = TimeValue.timeValueMinutes(30); + + Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + clusterService.getClusterSettings().applySettings(settings); + + // Add state + cache.addQuantizationState(fieldName, state); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNotNull("State should be retrieved successfully", retrievedState); + assertSame("Retrieved state should be the same instance as the one added", state, retrievedState); + } + + @SneakyThrows + public void testMultiThreadedAddAndRetrieve() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "multiThreadField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + String cacheSize = "10%"; + TimeValue expiry = TimeValue.timeValueMinutes(30); + + Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + clusterService.getClusterSettings().applySettings(settings); + + // Add state from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + cache.addQuantizationState(fieldName, state); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNotNull("State should be retrieved successfully", retrievedState); + assertSame("Retrieved state should be the same instance as the one added", state, retrievedState); + } + + @SneakyThrows + public void testMultiThreadedEvict() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "multiThreadEvictField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + String cacheSize = "10%"; + TimeValue expiry = TimeValue.timeValueMinutes(30); + + Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + + clusterService.getClusterSettings().applySettings(settings); + + cache.addQuantizationState(fieldName, state); + + // Evict state from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + cache.evict(fieldName); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNull("State should be null", retrievedState); + } + + @SneakyThrows + public void testConcurrentAddAndEvict() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "concurrentAddEvictField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + String cacheSize = "10%"; + TimeValue expiry = TimeValue.timeValueMinutes(30); + + Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + clusterService.getClusterSettings().applySettings(settings); + + // Concurrently add and evict state from multiple threads + for (int i = 0; i < threadCount; i++) { + if (i % 2 == 0) { + executorService.submit(() -> { + try { + cache.addQuantizationState(fieldName, state); + } finally { + latch.countDown(); + } + }); + } else { + executorService.submit(() -> { + try { + cache.evict(fieldName); + } finally { + latch.countDown(); + } + }); + } + + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + // Since operations are concurrent, we can't be sure of the final state, but we can assert that the cache handles it gracefully + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertTrue("Final state should be either null or the added state", retrievedState == null || retrievedState == state); + } + + @SneakyThrows + public void testMultipleThreadedCacheClear() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "multiThreadField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + String cacheSize = "10%"; + TimeValue expiry = TimeValue.timeValueMinutes(30); + + Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + clusterService.getClusterSettings().applySettings(settings); + cache.addQuantizationState(fieldName, state); + + // Clear cache from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + cache.clear(); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNull("State should be null", retrievedState); + } + + @SneakyThrows + public void testRebuild() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "rebuildField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + String cacheSize = "10%"; + TimeValue expiry = TimeValue.timeValueMinutes(30); + + Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.addQuantizationState(fieldName, state); + + // Rebuild cache from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + cache.rebuildCache(); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNull("State should be null", retrievedState); + } + + @SneakyThrows + public void testRebuildOnCacheSizeSettingsChange() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "rebuildField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + + Settings settings = Settings.builder().build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + Client client = mock(Client.class); + + KNNSettings.state().initialize(client, clusterService); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.rebuildCache(); + long maxCacheSizeInKB = cache.getMaxCacheSizeInKB(); + cache.addQuantizationState(fieldName, state); + + String newCacheSize = "10%"; + + Settings newSettings = Settings.builder().put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), newCacheSize).build(); + + // Rebuild cache from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + clusterService.getClusterSettings().applySettings(newSettings); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNull("State should be null", retrievedState); + assertNotEquals(maxCacheSizeInKB, cache.getMaxCacheSizeInKB()); + } + + @SneakyThrows + public void testRebuildOnTimeExpirySettingsChange() { + int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + CountDownLatch latch = new CountDownLatch(threadCount); + String fieldName = "rebuildField"; + QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + + Settings settings = Settings.builder().build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + Client client = mock(Client.class); + + KNNSettings.state().initialize(client, clusterService); + + QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.addQuantizationState(fieldName, state); + + TimeValue newExpiry = TimeValue.timeValueMinutes(30); + + Settings newSettings = Settings.builder().put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), newExpiry).build(); + + // Rebuild cache from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + clusterService.getClusterSettings().applySettings(newSettings); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + + QuantizationState retrievedState = cache.getQuantizationState(fieldName); + assertNull("State should be null", retrievedState); + } + + public void testCacheEvictionDueToSize() { + String fieldName = "evictionField"; + // States have size of slightly over 500 bytes so that adding two will reach the max size of 1 kb for the cache + int arrayLength = 112; + float[] arr = new float[arrayLength]; + float[] arr2 = new float[arrayLength]; + for (int i = 0; i < arrayLength; i++) { + arr[i] = i; + arr[i] = i + 1; + } + QuantizationState state = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), arr); + QuantizationState state2 = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), arr2); + long cacheSize = 1; + Settings settings = Settings.builder().build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + QuantizationStateCache cache = new QuantizationStateCache(); + cache.setMaxCacheSizeInKB(cacheSize); + cache.rebuildCache(); + cache.addQuantizationState(fieldName, state); + cache.addQuantizationState(fieldName, state2); + cache.clear(); + assertNotNull(cache.getEvictedDueToSizeAt()); + } +}