From eb99064e326834e4f4d437d1a3bbc0aac2617608 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 5 Sep 2024 10:29:20 -0700 Subject: [PATCH] Add model version to model metadata and change model metadata reads to be from cluster metadata (#2005) * Add model version to model metadata Signed-off-by: Ryan Bogan * Add model version to model metadata and change model metadata reads to be from cluster metadata Signed-off-by: Ryan Bogan * Add changelog entry Signed-off-by: Ryan Bogan * Set version from config context Signed-off-by: Ryan Bogan * Fix spotless Signed-off-by: Ryan Bogan * Update model index mappings Signed-off-by: Ryan Bogan * Change field mapper to read model version Signed-off-by: Ryan Bogan * Fix tests Signed-off-by: Ryan Bogan * remove println Signed-off-by: John Mazanec --------- Signed-off-by: Ryan Bogan Signed-off-by: John Mazanec Co-authored-by: John Mazanec (cherry picked from commit 6814c8f60707ff8e3be835558ab35ae5a9ea0c1a) --- .../opensearch-knn.release-notes-2.17.0.0.md | 1 + .../opensearch/knn/common/KNNConstants.java | 1 + .../knn/index/mapper/ModelFieldMapper.java | 2 +- .../opensearch/knn/index/util/IndexUtil.java | 2 + .../org/opensearch/knn/indices/ModelDao.java | 1 + .../opensearch/knn/indices/ModelMetadata.java | 56 ++++- .../org/opensearch/knn/indices/ModelUtil.java | 4 +- .../opensearch/knn/training/TrainingJob.java | 3 +- .../knn/training/TrainingJobRunner.java | 7 +- src/main/resources/mappings/model-index.json | 3 + .../index/KNNCreateIndexFromModelTests.java | 4 +- .../KNN80DocValuesConsumerTests.java | 161 +++++++------- .../knn/index/codec/KNNCodecTestCase.java | 156 ++++++------- .../mapper/KNNVectorFieldMapperTests.java | 7 +- .../knn/indices/ModelCacheTests.java | 37 +++- .../opensearch/knn/indices/ModelDaoTests.java | 43 ++-- .../knn/indices/ModelMetadataTests.java | 206 +++++++++++------- .../opensearch/knn/indices/ModelTests.java | 50 +++-- .../knn/indices/ModelUtilTests.java | 21 +- .../transport/GetModelResponseTests.java | 7 +- ...oveModelFromCacheTransportActionTests.java | 4 +- .../transport/TrainingModelRequestTests.java | 4 +- ...ateModelGraveyardTransportActionTests.java | 4 +- .../UpdateModelMetadataRequestTests.java | 10 +- ...dateModelMetadataTransportActionTests.java | 4 +- .../knn/training/TrainingJobRunnerTests.java | 2 +- .../knn/training/TrainingJobTests.java | 3 +- 27 files changed, 484 insertions(+), 319 deletions(-) diff --git a/release-notes/opensearch-knn.release-notes-2.17.0.0.md b/release-notes/opensearch-knn.release-notes-2.17.0.0.md index 9876b7b38..8b4aa8e95 100644 --- a/release-notes/opensearch-knn.release-notes-2.17.0.0.md +++ b/release-notes/opensearch-knn.release-notes-2.17.0.0.md @@ -11,6 +11,7 @@ Compatible with OpenSearch 2.17.0 * Add spaceType as a top level optional parameter while creating vector field. [#2044](https://github.com/opensearch-project/k-NN/pull/2044) ### Enhancements * Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950) +* Add model version to model metadata and change model metadata reads to be from cluster metadata [#2005](https://github.com/opensearch-project/k-NN/pull/2005) ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 11024076f..ed21d3005 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -77,6 +77,7 @@ public class KNNConstants { public static final String TOP_LEVEL_SPACE_TYPE_FEATURE = "top_level_space_type_feature"; public static final String RADIAL_SEARCH_KEY = "radial_search"; + public static final String MODEL_VERSION = "model_version"; public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate"; // Lucene specific constants diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index b7bbc5a0d..42a0c10c4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -290,7 +290,7 @@ private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata return KNNMethodConfigContext.builder() .vectorDataType(modelMetadata.getVectorDataType()) .dimension(modelMetadata.getDimension()) - .versionCreated(Version.V_2_14_0) + .versionCreated(modelMetadata.getModelVersion()) .mode(modelMetadata.getMode()) .compressionLevel(modelMetadata.getCompressionLevel()) .build(); diff --git a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java index 4a0118f58..02aa1e954 100644 --- a/src/main/java/org/opensearch/knn/index/util/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/util/IndexUtil.java @@ -54,6 +54,7 @@ public class IndexUtil { private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0; private static final Version MINIMAL_MODE_AND_COMPRESSION_FEATURE = Version.V_2_17_0; private static final Version MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE = Version.V_2_17_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION = Version.V_2_17_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); public static final Set VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE); @@ -394,6 +395,7 @@ private static Map initializeMinimalRequiredVersionMap() { put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE); put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE); put(KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE, MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE); + put(KNNConstants.MODEL_VERSION, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION); } }; diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 326d595a4..d0abe8612 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -301,6 +301,7 @@ private void putInternal(Model model, ActionListener listener, Do if (CompressionLevel.isConfigured(modelMetadata.getCompressionLevel())) { put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, modelMetadata.getCompressionLevel().getName()); } + put(KNNConstants.MODEL_VERSION, modelMetadata.getModelVersion()); MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index 620e520ba..17eed833e 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -15,6 +15,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; +import org.opensearch.Version; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -59,6 +60,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { private String error; @Getter private final CompressionLevel compressionLevel; + private final Version version; /** * Constructor @@ -66,7 +68,6 @@ public class ModelMetadata implements Writeable, ToXContentObject { * @param in Stream input */ public ModelMetadata(StreamInput in) throws IOException { - String tempTrainingNodeAssignment; this.knnEngine = KNNEngine.getEngine(in.readString()); this.spaceType = SpaceType.getSpace(in.readString()); this.dimension = in.readInt(); @@ -96,7 +97,6 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.vectorDataType = VectorDataType.DEFAULT; } - if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) { this.mode = Mode.fromName(in.readOptionalString()); this.compressionLevel = CompressionLevel.fromName(in.readOptionalString()); @@ -105,6 +105,11 @@ public ModelMetadata(StreamInput in) throws IOException { this.compressionLevel = CompressionLevel.NOT_CONFIGURED; } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VERSION)) { + this.version = Version.fromString(in.readString()); + } else { + this.version = Version.V_EMPTY; + } } /** @@ -133,7 +138,8 @@ public ModelMetadata( MethodComponentContext methodComponentContext, VectorDataType vectorDataType, Mode mode, - CompressionLevel compressionLevel + CompressionLevel compressionLevel, + Version version ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -159,6 +165,7 @@ public ModelMetadata( this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); this.mode = Objects.requireNonNull(mode, "Mode must not be null"); this.compressionLevel = Objects.requireNonNull(compressionLevel, "Compression level must not be null"); + this.version = Objects.requireNonNull(version, "model version must not be null"); } /** @@ -246,6 +253,14 @@ public VectorDataType getVectorDataType() { return vectorDataType; } + /** + * Getter for the model version + * @return version + */ + public Version getModelVersion() { + return version; + } + /** * setter for model's state * @@ -279,7 +294,8 @@ public String toString() { methodComponentContext.toClusterStateString(), vectorDataType.getValue(), mode.getName(), - compressionLevel.getName() + compressionLevel.getName(), + version.toString() ); } @@ -317,6 +333,7 @@ public int hashCode() { .append(getVectorDataType()) .append(getMode()) .append(getCompressionLevel()) + .append(getModelVersion()) .toHashCode(); } @@ -329,15 +346,15 @@ public int hashCode() { public static ModelMetadata fromString(String modelMetadataString) { String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1); int length = modelMetadataArray.length; - - if (length < 7 || length > 12) { + if (length < 7 || length > 13) { throw new IllegalArgumentException( "Illegal format for model metadata. Must be of the form " + "\",,,,,,\" or " + "\",,,,,,,\" or " + "\",,,,,,,,\" or " + "\",,,,,,,,,\". or " - + "\",,,,,,,,,,,\"." + + "\",,,,,,,,,,,\" or " + + "\",,,,,,,,,,,,\"." ); } @@ -357,6 +374,7 @@ public static ModelMetadata fromString(String modelMetadataString) { CompressionLevel compressionLevel = length > 11 ? CompressionLevel.fromName(modelMetadataArray[11]) : CompressionLevel.NOT_CONFIGURED; + Version version = length > 12 ? Version.fromString(modelMetadataArray[12]) : Version.V_EMPTY; log.debug(getLogMessage(length)); @@ -372,7 +390,8 @@ public static ModelMetadata fromString(String modelMetadataString) { methodComponentContext, vectorDataType, mode, - compressionLevel + compressionLevel, + version ); } @@ -386,9 +405,10 @@ private static String getLogMessage(int length) { return "Model metadata contains training node assignment and method context."; case 10: return "Model metadata contains training node assignment, method context and vector data type."; - case 11: case 12: return "Model metadata contains mode and compression level"; + case 13: + return "Model metadata contains training node assignment, method context, vector data type, and version"; default: throw new IllegalArgumentException("Unexpected metadata array length: " + length); } @@ -423,6 +443,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); Object mode = modelSourceMap.get(KNNConstants.MODE_PARAMETER); Object compressionLevel = modelSourceMap.get(KNNConstants.COMPRESSION_LEVEL_PARAMETER); + Object version = modelSourceMap.get(KNNConstants.MODEL_VERSION); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -447,6 +468,10 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m vectorDataType = VectorDataType.DEFAULT.getValue(); } + if (version == null) { + version = Version.V_EMPTY; + } + ModelMetadata modelMetadata = new ModelMetadata( KNNEngine.getEngine(objectToString(engine)), SpaceType.getSpace(objectToString(space)), @@ -459,7 +484,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m (MethodComponentContext) methodComponentContext, VectorDataType.get(objectToString(vectorDataType)), Mode.fromName(objectToString(mode)), - CompressionLevel.fromName(objectToString(compressionLevel)) + CompressionLevel.fromName(objectToString(compressionLevel)), + Version.fromString(version.toString()) ); return modelMetadata; } @@ -486,6 +512,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(mode.getName()); out.writeOptionalString(compressionLevel.getName()); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VERSION)) { + out.writeString(version.toString()); + } } @Override @@ -517,6 +546,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(KNNConstants.COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()); } } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VERSION)) { + String versionString = "unknown"; + if (version != Version.V_EMPTY) { + versionString = version.toString(); + } + builder.field(KNNConstants.MODEL_VERSION, versionString); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index ac0e4fb79..d63f02b2b 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -48,8 +48,8 @@ public static ModelMetadata getModelMetadata(final String modelId) { if (StringUtils.isEmpty(modelId)) { return null; } - final Model model = ModelCache.getInstance().get(modelId); - final ModelMetadata modelMetadata = model.getModelMetadata(); + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + final ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (isModelCreated(modelMetadata) == false) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 90b2762c2..b479192e8 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -99,7 +99,8 @@ public TrainingJob( knnMethodContext.getMethodComponentContext(), knnMethodConfigContext.getVectorDataType(), mode, - compressionLevel + compressionLevel, + knnMethodConfigContext.getVersionCreated() ), null, this.modelId diff --git a/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java b/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java index 8884f8102..5b2bb26b6 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java @@ -16,7 +16,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.action.index.IndexResponse; import org.opensearch.common.ValidationException; -import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -166,11 +165,11 @@ private void train(TrainingJob trainingJob) { private void serializeModel(TrainingJob trainingJob, ActionListener listener, boolean update) throws IOException, ExecutionException, InterruptedException { if (update) { - Model model = modelDao.get(trainingJob.getModelId()); - if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) { + ModelMetadata modelMetadata = modelDao.getMetadata(trainingJob.getModelId()); + if (modelMetadata.getState().equals(ModelState.TRAINING)) { modelDao.update(trainingJob.getModel(), listener); } else { - logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState()); + logger.info("Model state is {}. Skipping serialization of trained data", modelMetadata.getState()); } } else { modelDao.put(trainingJob.getModel(), listener); diff --git a/src/main/resources/mappings/model-index.json b/src/main/resources/mappings/model-index.json index e7879cced..8d16b98ab 100644 --- a/src/main/resources/mappings/model-index.json +++ b/src/main/resources/mappings/model-index.json @@ -38,6 +38,9 @@ }, "compression_level": { "type": "keyword" + }, + "model_version": { + "type": "keyword" } } } diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index 28ef41e04..d6ee2d7fd 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index; import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; import org.opensearch.core.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; import org.opensearch.common.settings.Settings; @@ -69,7 +70,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException MethodComponentContext.EMPTY, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 786061af8..e1f16006d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -18,6 +18,8 @@ import org.apache.lucene.store.IOContext; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -460,85 +462,92 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio ); byte[] modelBytes = JNIService.trainIndex(parameters, dimension, trainingPtr, knnEngine); - Model model = new Model( - new ModelMetadata( - knnEngine, - spaceType, - dimension, - ModelState.CREATED, - "timestamp", - "Empty description", - "", - "", - MethodComponentContext.EMPTY, - VectorDataType.FLOAT, - Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED - ), - modelBytes, - modelId + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + "timestamp", + "Empty description", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ); + Model model = new Model(modelMetadata, modelBytes, modelId); JNICommons.freeVectorData(trainingPtr); - // Setup the model cache to return the correct model - ModelDao modelDao = mock(ModelDao.class); - when(modelDao.get(modelId)).thenReturn(model); - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.getSettings()).thenReturn(Settings.EMPTY); - - ClusterSettings clusterSettings = new ClusterSettings( - Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10kb").build(), - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING) - ); - - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - ModelCache.initialize(modelDao, clusterService); - - // Build the segment and field info - String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); - int docsInSegment = 100; - String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); - - SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() - .directory(directory) - .segmentName(segmentName) - .docsInSegment(docsInSegment) - .codec(codec) - .build(); - - FieldInfo[] fieldInfoArray = new FieldInfo[] { - KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) - .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") - .addAttribute(MODEL_ID, modelId) - .build() }; - - FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); - SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); - - // Add documents to the field - KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( - docsInSegment, - dimension - ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); - - // The document should be created in the correct location - String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); - assertFileInCorrectLocation(state, expectedFile); - - // The footer should be valid - assertValidFooter(state.directory, expectedFile); - - // The document should be readable by faiss - assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); - - // The graph creation statistics should be updated - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + try (MockedStatic modelDaoMockedStatic = Mockito.mockStatic(ModelDao.OpenSearchKNNModelDao.class)) { + // Setup the model cache to return the correct model + ModelDao.OpenSearchKNNModelDao modelDao = mock(ModelDao.OpenSearchKNNModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + modelDaoMockedStatic.when(ModelDao.OpenSearchKNNModelDao::getInstance).thenReturn(modelDao); + + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(Settings.EMPTY); + + ClusterSettings clusterSettings = new ClusterSettings( + Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10kb").build(), + ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING) + ); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ModelCache.initialize(modelDao, clusterService); + + // Build the segment and field info + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(MODEL_ID, modelId) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = + new TestVectorValues.RandomVectorDocValuesProducer(docsInSegment, dimension); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName( + segmentName, + knnEngine.getVersion(), + fieldName, + knnEngine.getExtension() + ); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by faiss + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 174441df8..3d9969a1e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -16,6 +16,8 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; +import org.mockito.MockedStatic; +import org.opensearch.Version; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.xcontent.XContentFactory; @@ -232,82 +234,86 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio ); // Setup model cache - ModelDao modelDao = mock(ModelDao.class); - - // Set model state to created - ModelMetadata modelMetadata1 = new ModelMetadata( - knnEngine, - spaceType, - dimension, - ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), - "", - "", - "", - MethodComponentContext.EMPTY, - VectorDataType.FLOAT, - Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED - ); - - Model mockModel = new Model(modelMetadata1, modelBlob, modelId); - when(modelDao.get(modelId)).thenReturn(mockModel); - when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata1); - - Settings settings = settings(CURRENT).put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); - - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.getSettings()).thenReturn(settings); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - - ModelCache.initialize(modelDao, clusterService); - ModelCache.getInstance().removeAll(); - - // Setup Lucene - setUpMockClusterService(); - Directory dir = newFSDirectory(createTempDir()); - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergeScheduler(new SerialMergeScheduler()); - iwc.setCodec(codec); - - FieldType fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - fieldType.setDocValuesType(DocValuesType.BINARY); - fieldType.putAttribute(KNNConstants.MODEL_ID, modelId); - fieldType.freeze(); - - // Add the documents to the index - float[][] arrays = { { 1.0f, 3.0f, 4.0f }, { 2.0f, 5.0f, 8.0f }, { 3.0f, 6.0f, 9.0f }, { 4.0f, 7.0f, 10.0f } }; - - RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); - String fieldName = "test_vector"; - for (float[] array : arrays) { - VectorField vectorField = new VectorField(fieldName, array, fieldType); - Document doc = new Document(); - doc.add(vectorField); - writer.addDocument(doc); + try (MockedStatic modelDaoMockedStatic = Mockito.mockStatic(ModelDao.OpenSearchKNNModelDao.class)) { + ModelDao.OpenSearchKNNModelDao modelDao = mock(ModelDao.OpenSearchKNNModelDao.class); + modelDaoMockedStatic.when(ModelDao.OpenSearchKNNModelDao::getInstance).thenReturn(modelDao); + + // Set model state to created + ModelMetadata modelMetadata1 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY + ); + + Model mockModel = new Model(modelMetadata1, modelBlob, modelId); + when(modelDao.get(modelId)).thenReturn(mockModel); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata1); + + Settings settings = settings(CURRENT).put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + ModelCache.initialize(modelDao, clusterService); + ModelCache.getInstance().removeAll(); + + // Setup Lucene + setUpMockClusterService(); + Directory dir = newFSDirectory(createTempDir()); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setCodec(codec); + + FieldType fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + fieldType.setDocValuesType(DocValuesType.BINARY); + fieldType.putAttribute(KNNConstants.MODEL_ID, modelId); + fieldType.freeze(); + + // Add the documents to the index + float[][] arrays = { { 1.0f, 3.0f, 4.0f }, { 2.0f, 5.0f, 8.0f }, { 3.0f, 6.0f, 9.0f }, { 4.0f, 7.0f, 10.0f } }; + + RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); + String fieldName = "test_vector"; + for (float[] array : arrays) { + VectorField vectorField = new VectorField(fieldName, array, fieldType); + Document doc = new Document(); + doc.add(vectorField); + writer.addDocument(doc); + } + + IndexReader reader = writer.getReader(); + writer.close(); + + // Make sure that search returns the correct results + KNNWeight.initialize(modelDao); + ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); + NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); + float[] query = { 10.0f, 10.0f, 10.0f }; + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10); + + assertEquals(3, topDocs.scoreDocs[0].doc); + assertEquals(2, topDocs.scoreDocs[1].doc); + assertEquals(1, topDocs.scoreDocs[2].doc); + assertEquals(0, topDocs.scoreDocs[3].doc); + + reader.close(); + dir.close(); + resourceWatcherService.close(); + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } - - IndexReader reader = writer.getReader(); - writer.close(); - - // Make sure that search returns the correct results - KNNWeight.initialize(modelDao); - ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); - NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - float[] query = { 10.0f, 10.0f, 10.0f }; - IndexSearcher searcher = new IndexSearcher(reader); - TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10); - - assertEquals(3, topDocs.scoreDocs[0].doc); - assertEquals(2, topDocs.scoreDocs[1].doc); - assertEquals(1, topDocs.scoreDocs[2].doc); - assertEquals(0, topDocs.scoreDocs[3].doc); - - reader.close(); - dir.close(); - resourceWatcherService.close(); - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } public void testWriteByOldCodec(Codec codec) throws IOException { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index abc4f563e..e07bf3aed 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -14,6 +14,7 @@ import org.apache.lucene.util.BytesRef; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.Explicit; import org.opensearch.common.ValidationException; @@ -223,7 +224,8 @@ public void testBuilder_build_fromModel() { MethodComponentContext.EMPTY, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); @@ -815,7 +817,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { MethodComponentContext.EMPTY, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index 91bb7d3d9..a31fe4a7d 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.UncheckedExecutionException; +import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -51,7 +52,8 @@ public void testGet_normal() throws ExecutionException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), "hello".getBytes(), modelId @@ -91,7 +93,8 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[BYTES_PER_KILOBYTES + 1], modelId @@ -152,7 +155,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[size1], modelId1 @@ -171,7 +175,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[size2], modelId2 @@ -218,7 +223,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[size1], modelId1 @@ -237,7 +243,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[size2], modelId2 @@ -289,7 +296,8 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), "hello".getBytes(), modelId @@ -338,7 +346,8 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[modelSize], modelId @@ -410,7 +419,8 @@ public void testContains() throws ExecutionException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[modelSize1], modelId1 @@ -455,7 +465,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[modelSize1], modelId1 @@ -476,7 +487,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[modelSize2], modelId2 @@ -525,7 +537,8 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ), new byte[BYTES_PER_KILOBYTES * 2], modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 560ea59b2..1edb5cff2 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -18,6 +18,7 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.ResourceNotFoundException; +import org.opensearch.Version; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.core.action.ActionListener; import org.opensearch.action.DocWriteResponse; @@ -145,7 +146,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -168,7 +170,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -199,7 +202,8 @@ public void testPut_withId() throws InterruptedException, IOException { new MethodComponentContext("test", Collections.emptyMap()), VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -263,7 +267,8 @@ public void testPut_withoutModel() throws InterruptedException, IOException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -328,7 +333,8 @@ public void testPut_invalid_badState() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, "any-id" @@ -368,7 +374,8 @@ public void testUpdate() throws IOException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), null, modelId @@ -410,7 +417,8 @@ public void testUpdate() throws IOException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -464,7 +472,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -486,7 +495,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), null, modelId @@ -526,7 +536,8 @@ public void testGetMetadata() throws IOException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -606,7 +617,8 @@ public void testDelete() throws IOException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -643,7 +655,8 @@ public void testDelete() throws IOException, InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId1 @@ -714,7 +727,8 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId @@ -759,7 +773,8 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 6f0b49285..79340d331 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -11,6 +11,7 @@ package org.opensearch.knn.indices; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.ToXContent; @@ -51,7 +52,8 @@ public void testStreams() throws IOException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -71,7 +73,8 @@ public void testStreams() throws IOException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.ON_DISK, - CompressionLevel.x16 + CompressionLevel.x16, + Version.CURRENT ); streamOutput = new BytesStreamOutput(); modelMetadata.writeTo(streamOutput); @@ -93,7 +96,8 @@ public void testGetKnnEngine() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); @@ -113,7 +117,8 @@ public void testGetSpaceType() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(spaceType, modelMetadata.getSpaceType()); @@ -133,7 +138,8 @@ public void testGetDimension() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(dimension, modelMetadata.getDimension()); @@ -153,7 +159,8 @@ public void testGetState() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(modelState, modelMetadata.getState()); @@ -173,7 +180,8 @@ public void testGetTimestamp() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(timeValue, modelMetadata.getTimestamp()); @@ -193,7 +201,8 @@ public void testDescription() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(description, modelMetadata.getDescription()); @@ -213,7 +222,8 @@ public void testGetError() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(error, modelMetadata.getError()); @@ -233,12 +243,34 @@ public void testGetVectorDataType() { MethodComponentContext.EMPTY, vectorDataType, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(vectorDataType, modelMetadata.getVectorDataType()); } + public void testGetModelVersion() { + Version version = Version.CURRENT; + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.DEFAULT, + Mode.NOT_CONFIGURED, + CompressionLevel.NOT_CONFIGURED, + version + ); + + assertEquals(version, modelMetadata.getModelVersion()); + } + public void testSetState() { ModelState modelState = ModelState.FAILED; ModelMetadata modelMetadata = new ModelMetadata( @@ -253,7 +285,8 @@ public void testSetState() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(modelState, modelMetadata.getState()); @@ -277,7 +310,8 @@ public void testSetError() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(error, modelMetadata.getError()); @@ -297,47 +331,9 @@ public void testToString() { String error = "test-error"; String nodeAssignment = ""; MethodComponentContext methodComponentContext = MethodComponentContext.EMPTY; + Version version = Version.CURRENT; String expected = knnEngine.getName() - + "," - + spaceType.getValue() - + "," - + dimension - + "," - + modelState.getName() - + "," - + timestamp - + "," - + description - + "," - + error - + "," - + nodeAssignment - + "," - + methodComponentContext.toClusterStateString() - + "," - + VectorDataType.DEFAULT.getValue() - + "," - + ","; - - ModelMetadata modelMetadata = new ModelMetadata( - knnEngine, - spaceType, - dimension, - modelState, - timestamp, - description, - error, - nodeAssignment, - MethodComponentContext.EMPTY, - VectorDataType.DEFAULT, - Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED - ); - - assertEquals(expected, modelMetadata.toString()); - - expected = knnEngine.getName() + "," + spaceType.getValue() + "," @@ -359,9 +355,11 @@ public void testToString() { + "," + Mode.ON_DISK.getName() + "," - + CompressionLevel.x32.getName(); + + CompressionLevel.x32.getName() + + "," + + version; - modelMetadata = new ModelMetadata( + ModelMetadata modelMetadata = new ModelMetadata( knnEngine, spaceType, dimension, @@ -373,7 +371,8 @@ public void testToString() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.ON_DISK, - CompressionLevel.x32 + CompressionLevel.x32, + Version.CURRENT ); assertEquals(expected, modelMetadata.toString()); @@ -396,7 +395,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -410,7 +410,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -425,7 +426,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -439,7 +441,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -453,7 +456,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -467,7 +471,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -481,7 +486,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -495,7 +501,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -509,7 +516,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -524,7 +532,8 @@ public void testEquals() { new MethodComponentContext("test", Collections.emptyMap()), VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(modelMetadata1, modelMetadata1); @@ -557,7 +566,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -571,7 +581,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -586,7 +597,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -600,7 +612,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -614,7 +627,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -628,7 +642,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -642,7 +657,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -656,7 +672,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -670,7 +687,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -685,7 +703,8 @@ public void testHashCode() { new MethodComponentContext("test", Collections.emptyMap()), VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); @@ -711,6 +730,7 @@ public void testFromString() { String error = "test-error"; String nodeAssignment = "test-node"; MethodComponentContext methodComponentContext = MethodComponentContext.EMPTY; + Version version = Version.CURRENT; String stringRep1 = knnEngine.getName() + "," @@ -730,7 +750,11 @@ public void testFromString() { + "," + methodComponentContext.toClusterStateString() + "," - + VectorDataType.DEFAULT.getValue(); + + VectorDataType.DEFAULT.getValue() + + "," + + "," + + "," + + version.toString(); String stringRep2 = knnEngine.getName() + "," @@ -768,7 +792,9 @@ public void testFromString() { + "," + VectorDataType.DEFAULT.getValue() + "," - + ","; + + "," + + "," + + version; String stringRep4 = knnEngine.getName() + "," @@ -806,7 +832,8 @@ public void testFromString() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ); ModelMetadata expected2 = new ModelMetadata( @@ -821,7 +848,8 @@ public void testFromString() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.V_EMPTY ); ModelMetadata expected3 = new ModelMetadata( @@ -836,7 +864,8 @@ public void testFromString() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.ON_DISK, - CompressionLevel.x32 + CompressionLevel.x32, + Version.CURRENT ); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); @@ -863,6 +892,8 @@ public void testFromResponseMap() throws IOException { String nodeAssignment = "test-node"; MethodComponentContext methodComponentContext = getMethodComponentContext(); MethodComponentContext emptyMethodComponentContext = MethodComponentContext.EMPTY; + Version version = Version.CURRENT; + Version emptyVersion = Version.V_EMPTY; ModelMetadata expected = new ModelMetadata( knnEngine, @@ -876,7 +907,8 @@ public void testFromResponseMap() throws IOException { methodComponentContext, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + version ); ModelMetadata expected2 = new ModelMetadata( knnEngine, @@ -890,7 +922,8 @@ public void testFromResponseMap() throws IOException { emptyMethodComponentContext, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + emptyVersion ); ModelMetadata expected3 = new ModelMetadata( @@ -905,7 +938,8 @@ public void testFromResponseMap() throws IOException { emptyMethodComponentContext, VectorDataType.DEFAULT, Mode.ON_DISK, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); ModelMetadata expected4 = new ModelMetadata( @@ -920,7 +954,8 @@ public void testFromResponseMap() throws IOException { emptyMethodComponentContext, VectorDataType.DEFAULT, Mode.ON_DISK, - CompressionLevel.x16 + CompressionLevel.x16, + Version.CURRENT ); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); @@ -936,12 +971,14 @@ public void testFromResponseMap() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, builder.toString()); + metadataAsMap.put(KNNConstants.MODEL_VERSION, version.toString()); ModelMetadata fromMap = ModelMetadata.getMetadataFromSourceMap(metadataAsMap); assertEquals(expected, fromMap); metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null); metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, null); + metadataAsMap.put(KNNConstants.MODEL_VERSION, emptyVersion); assertEquals(expected2, fromMap); metadataAsMap.put(KNNConstants.MODE_PARAMETER, Mode.ON_DISK.getName()); @@ -978,7 +1015,8 @@ public void testBlockCommasInDescription() { methodComponentContext, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ) ); assertEquals("Model description cannot contain any commas: ','", e.getMessage()); diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index 4e666872f..59ecd66ec 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -11,6 +11,7 @@ package org.opensearch.knn.indices; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -47,7 +48,8 @@ public void testInvalidConstructor() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), null, "test-model" @@ -71,7 +73,8 @@ public void testInvalidDimension() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model" @@ -92,7 +95,8 @@ public void testInvalidDimension() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model" @@ -113,7 +117,8 @@ public void testInvalidDimension() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model" @@ -135,7 +140,8 @@ public void testGetModelMetadata() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); @@ -156,7 +162,8 @@ public void testGetModelBlob() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, "test-model" @@ -179,7 +186,8 @@ public void testGetLength() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[size], "test-model" @@ -199,7 +207,8 @@ public void testGetLength() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), null, "test-model" @@ -222,7 +231,8 @@ public void testSetModelBlob() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), blob1, "test-model" @@ -251,7 +261,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model-1" @@ -269,7 +280,8 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model-1" @@ -287,13 +299,13 @@ public void testEquals() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model-2" ); - assertEquals(model1, model1); assertEquals(model1, model2); assertNotEquals(model1, model3); } @@ -315,7 +327,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model-1" @@ -333,7 +346,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model-1" @@ -351,7 +365,8 @@ public void testHashCode() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[16], "test-model-2" @@ -385,7 +400,8 @@ public void testModelFromSourceMap() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); diff --git a/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java b/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java index edefd10ee..45597b4c9 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelUtilTests.java @@ -23,14 +23,19 @@ public void testGetModelMetadata_whenVariousInputs_thenSuccess() { MockedStatic modelCacheMockedStatic = Mockito.mockStatic(ModelCache.class); modelCacheMockedStatic.when(ModelCache::getInstance).thenReturn(modelCache); - - Mockito.when(modelCache.get(MODEL_ID)).thenReturn(model); - Mockito.when(model.getModelMetadata()).thenReturn(null); - Assert.assertThrows(IllegalArgumentException.class, () -> ModelUtil.getModelMetadata(MODEL_ID)); - - Mockito.when(model.getModelMetadata()).thenReturn(modelMetadata); - Mockito.when(modelMetadata.getState()).thenReturn(ModelState.CREATED); - Assert.assertNotNull(ModelUtil.getModelMetadata(MODEL_ID)); + try (MockedStatic modelDaoMockedStatic = Mockito.mockStatic(ModelDao.OpenSearchKNNModelDao.class)) { + ModelDao.OpenSearchKNNModelDao modelDao = Mockito.mock(ModelDao.OpenSearchKNNModelDao.class); + Mockito.when(modelDao.getMetadata(MODEL_ID)).thenReturn(modelMetadata); + Mockito.when(modelMetadata.getState()).thenReturn(ModelState.FAILED); + modelDaoMockedStatic.when(ModelDao.OpenSearchKNNModelDao::getInstance).thenReturn(modelDao); + + Mockito.when(modelCache.get(MODEL_ID)).thenReturn(model); + Mockito.when(model.getModelMetadata()).thenReturn(null); + Assert.assertThrows(IllegalArgumentException.class, () -> ModelUtil.getModelMetadata(MODEL_ID)); + + Mockito.when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + Assert.assertNotNull(ModelUtil.getModelMetadata(MODEL_ID)); + } modelCacheMockedStatic.close(); } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index 7010dbf43..649bb6c76 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -50,7 +50,8 @@ private ModelMetadata getModelMetadata(ModelState state) { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); } @@ -75,7 +76,7 @@ public void testXContent() throws IOException { Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\",\"model_version\":\"3.0.0\"}"; XContentBuilder xContentBuilder = MediaTypeRegistry.contentBuilder(XContentType.JSON); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); @@ -91,7 +92,7 @@ public void testXContentWithNoModelBlob() throws IOException { Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\",\"model_version\":\"3.0.0\"}"; XContentBuilder xContentBuilder = MediaTypeRegistry.contentBuilder(XContentType.JSON); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index 6252a29ac..8f4cad112 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableSet; import org.junit.Ignore; +import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -84,7 +85,8 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), new byte[128], modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index a03084c63..79292fb53 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.plugin.transport; import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -222,7 +223,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index 45203dae6..cac5c1b9c 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.Version; import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -216,7 +217,8 @@ public void testClusterManagerOperation_GetIndicesUsingModel() throws IOExceptio MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index d0a83ccc5..577762a42 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -11,6 +11,7 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -48,7 +49,8 @@ public void testStreams() throws IOException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -76,7 +78,8 @@ public void testValidate() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); @@ -119,7 +122,8 @@ public void testGetModelMetadata() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index d317fa893..48b93653f 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -11,6 +11,7 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.Version; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; @@ -74,7 +75,8 @@ public void testClusterManagerOperation() throws InterruptedException { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ); // Get update transport action diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java index 4876b1562..9671b6b5a 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java @@ -65,7 +65,7 @@ public void testExecute_success() throws IOException, InterruptedException, Exec // After put finishes, it should call the onResponse function that will call responseListener and then kickoff // training. ModelDao modelDao = mock(ModelDao.class); - when(modelDao.get(modelId)).thenReturn(model); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); doAnswer(invocationOnMock -> { assertEquals(1, trainingJobRunner.getJobCount()); // Make sure job count is correct IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true); diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 32794a33b..14308b915 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -124,7 +124,8 @@ public void testGetModel() { MethodComponentContext.EMPTY, VectorDataType.DEFAULT, Mode.NOT_CONFIGURED, - CompressionLevel.NOT_CONFIGURED + CompressionLevel.NOT_CONFIGURED, + Version.CURRENT ), null, modelID