From 24e4585184bff51c5e88f7fece05547890a5d0cf Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 23 Apr 2024 10:26:46 -0700 Subject: [PATCH] Serialize all models into cluster metadata (#1499) * Remove transport calls in TrainingJobRunner and TrainingJobClusterStateListener Signed-off-by: Ryan Bogan * Fix tests Signed-off-by: Ryan Bogan * Add changelog Signed-off-by: Ryan Bogan * Fix CMake Faiss bug Signed-off-by: Ryan Bogan * Add state checks for existing cluster metadata calls Signed-off-by: Ryan Bogan * Remove CMake bug fix Signed-off-by: Ryan Bogan * Fix changelog Signed-off-by: Ryan Bogan * Fix failing tests Signed-off-by: Ryan Bogan * Refactor and add two more created state checks Signed-off-by: Ryan Bogan * Rebase and fix new tests Signed-off-by: Ryan Bogan * Refactor created checks and modify error messages Signed-off-by: Ryan Bogan * Refactor cluster state listener transport calls Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan (cherry picked from commit dc8eb6b919ff6025b7d1e2aeccba28d32793bb96) --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/IndexUtil.java | 6 ++-- .../knn/index/mapper/ModelFieldMapper.java | 5 +-- .../knn/index/query/KNNQueryBuilder.java | 5 +-- .../opensearch/knn/index/query/KNNWeight.java | 5 +-- .../org/opensearch/knn/indices/ModelDao.java | 8 +---- .../org/opensearch/knn/indices/ModelUtil.java | 30 ++++++++++++++++++ .../TrainingJobClusterStateListener.java | 31 +++++++++++++------ .../knn/index/query/KNNQueryBuilderTests.java | 4 +++ .../knn/index/query/KNNWeightTests.java | 4 ++- .../transport/TrainingModelRequestTests.java | 1 + .../TrainingJobClusterStateListenerTests.java | 2 ++ 12 files changed, 77 insertions(+), 25 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/indices/ModelUtil.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b37975ba..7999e7bff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) * Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604) * Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613) +* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499) ### Bug Fixes * Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630) ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 92b94c2e2..e4523bb5e 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -23,6 +23,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import java.io.File; @@ -199,8 +200,8 @@ public static ValidationException validateKnnField( } ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, field)); + if (!ModelUtil.isModelCreated(modelMetadata)) { + exception.addValidationError(String.format("Model \"%s\" for field \"%s\" is not created.", modelId, field)); return exception; } @@ -286,4 +287,5 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod } return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine); } + } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index ce92d2967..554871279 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -11,6 +11,7 @@ import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import java.io.IOException; @@ -50,10 +51,10 @@ protected void parseCreateField(ParseContext context) throws IOException { // model when ingestion starts. ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId); - if (modelMetadata == null) { + if (!ModelUtil.isModelCreated(modelMetadata)) { throw new IllegalStateException( String.format( - "Model \"%s\" from %s's mapping does not exist. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.", + "Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.", modelId, context.mapperService().index().getName(), MODEL_ID diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 760a91815..37114f3cb 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -31,6 +31,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryShardContext; @@ -548,8 +549,8 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie } ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - throw new IllegalArgumentException(String.format("Model ID '%s' does not exist.", modelId)); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); } return modelMetadata; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 6b323e124..8939a569e 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -40,6 +40,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; @@ -213,8 +214,8 @@ private Map doANNSearch(final LeafReaderContext context, final B String modelId = fieldInfo.getAttribute(MODEL_ID); if (modelId != null) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - throw new RuntimeException("Model \"" + modelId + "\" does not exist."); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new RuntimeException("Model \"" + modelId + "\" is not created."); } knnEngine = modelMetadata.getKnnEngine(); diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0c0f08545..6940fcd39 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -341,13 +341,7 @@ private void putInternal(Model model, ActionListener listener, Do ); }, listener::onFailure); - // After the model is indexed, update metadata only if the model is in CREATED state - ActionListener onIndexListener; - if (ModelState.CREATED.equals(model.getModelMetadata().getState())) { - onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener); - } else { - onIndexListener = onMetaListener; - } + ActionListener onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener); // Create the model index if it does not already exist Runnable indexModelRunnable = () -> indexRequestBuilder.execute(onIndexListener); diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java new file mode 100644 index 000000000..3daaed138 --- /dev/null +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.indices; + +/** + * A utility class for models. + */ +public class ModelUtil { + + public static boolean isModelPresent(ModelMetadata modelMetadata) { + return modelMetadata != null; + } + + public static boolean isModelCreated(ModelMetadata modelMetadata) { + if (!isModelPresent(modelMetadata)) { + return false; + } + return modelMetadata.getState().equals(ModelState.CREATED); + } + +} diff --git a/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java b/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java index 45d2197e8..7e39ff7b3 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java @@ -109,10 +109,9 @@ protected void updateModelsNewCluster() throws IOException, InterruptedException if (modelDao.isCreated()) { List modelIds = searchModelIds(); for (String modelId : modelIds) { - Model model = modelDao.get(modelId); - ModelMetadata modelMetadata = model.getModelMetadata(); + ModelMetadata modelMetadata = getModelMetadata(modelId); if (modelMetadata.getState().equals(ModelState.TRAINING)) { - updateModelStateAsFailed(model, "Training failed to complete as cluster crashed"); + updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as cluster crashed"); } } } @@ -123,11 +122,10 @@ protected void updateModelsNodesRemoved(List removedNodes) throws List modelIds = searchModelIds(); for (DiscoveryNode removedNode : removedNodes) { for (String modelId : modelIds) { - Model model = modelDao.get(modelId); - ModelMetadata modelMetadata = model.getModelMetadata(); + ModelMetadata modelMetadata = getModelMetadata(modelId); if (modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId()) && modelMetadata.getState().equals(ModelState.TRAINING)) { - updateModelStateAsFailed(model, "Training failed to complete as node dropped"); + updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as node dropped"); } } } @@ -158,9 +156,11 @@ public void onFailure(Exception e) { return modelIds; } - private void updateModelStateAsFailed(Model model, String msg) throws IOException { - model.getModelMetadata().setState(ModelState.FAILED); - model.getModelMetadata().setError(msg); + private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException, + InterruptedException { + modelMetadata.setState(ModelState.FAILED); + modelMetadata.setError(msg); + Model model = new Model(modelMetadata, null, modelId); modelDao.update(model, new ActionListener() { @Override public void onResponse(IndexResponse indexResponse) { @@ -173,4 +173,17 @@ public void onFailure(Exception e) { } }); } + + private ModelMetadata getModelMetadata(String modelId) throws ExecutionException, InterruptedException { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + // On versions prior to 2.14, only models in created state are present in model metadata. + if (modelMetadata == null) { + log.info( + "Model metadata is null in cluster metadata. This can happen for models training on nodes prior to OpenSearch version 2.14.0. Fetching model information from system index." + ); + Model model = modelDao.get(modelId); + return model.getModelMetadata(); + } + return modelMetadata; + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index e47f583e7..ddc961093 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -39,6 +39,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import org.opensearch.plugins.SearchPlugin; import java.io.IOException; @@ -683,6 +684,7 @@ public void testDoToQuery_FromModel() { when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -712,6 +714,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -744,6 +747,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 795635b68..e80190ff8 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -45,6 +45,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.jni.JNIService; import java.io.IOException; @@ -167,6 +168,7 @@ public void testQueryScoreForFaissWithModel() { ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(spaceType); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata); KNNWeight.initialize(modelDao); @@ -254,7 +256,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { when(fieldInfo.getAttribute(eq(MODEL_ID))).thenReturn(modelId); RuntimeException ex = expectThrows(RuntimeException.class, () -> knnWeight.scorer(leafReaderContext)); - assertEquals(String.format("Model \"%s\" does not exist.", modelId), ex.getMessage()); + assertEquals(String.format("Model \"%s\" is not created.", modelId), ex.getMessage()); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index bdae54cad..b39c48635 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -661,6 +661,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { // Mock the model dao to return metadata for modelId to recognize it is a duplicate ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(trainingFieldModelMetadata.getState()).thenReturn(ModelState.CREATED); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(null); diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java index 7994e73d2..672a54110 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java @@ -101,6 +101,7 @@ public void testUpdateModelsNewCluster() throws IOException, InterruptedExceptio ModelDao modelDao = mock(ModelDao.class); when(modelDao.isCreated()).thenReturn(true); when(modelDao.get(modelId)).thenReturn(model); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); doAnswer(invocationOnMock -> { SearchResponse searchResponse = mock(SearchResponse.class); SearchHits searchHits = mock(SearchHits.class); @@ -144,6 +145,7 @@ public void testUpdateModelsNodesRemoved() throws IOException, InterruptedExcept ModelDao modelDao = mock(ModelDao.class); when(modelDao.isCreated()).thenReturn(true); when(modelDao.get(modelId)).thenReturn(model); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); DiscoveryNode node1 = mock(DiscoveryNode.class); when(node1.getEphemeralId()).thenReturn("test-node-model-match"); DiscoveryNode node2 = mock(DiscoveryNode.class);