Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize all models into cluster metadata #1499

Merged
merged 16 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
* Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613)
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;

import java.io.File;
Expand Down Expand Up @@ -197,7 +198,7 @@ public static ValidationException validateKnnField(
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
if (!ModelUtil.isModelPresent(modelMetadata) || !ModelUtil.isModelCreated(modelMetadata)) {
exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, field));
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
return exception;
}
Expand Down Expand Up @@ -284,4 +285,5 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod
}
return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.io.IOException;

Expand Down Expand Up @@ -50,7 +51,7 @@ protected void parseCreateField(ParseContext context) throws IOException {
// model when ingestion starts.
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);

if (modelMetadata == null) {
if (!ModelUtil.isModelPresent(modelMetadata) || !ModelUtil.isModelCreated(modelMetadata)) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalStateException(
String.format(
"Model \"%s\" from %s's mapping does not exist. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
Expand Down Expand Up @@ -517,7 +518,7 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
if (!ModelUtil.isModelPresent(modelMetadata) || !ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' does not exist.", modelId));
}
return modelMetadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

Expand Down Expand Up @@ -213,7 +214,7 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
if (!ModelUtil.isModelPresent(modelMetadata) || !ModelUtil.isModelCreated(modelMetadata)) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

Expand Down
8 changes: 1 addition & 7 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
if (ModelState.CREATED.equals(model.getModelMetadata().getState())) {
onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
} else {
onIndexListener = onMetaListener;
}
ActionListener<IndexResponse> onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to discuss a little bit here some thoughts:

So before, the logic was that the model index is the only source of truth until the model is actually created. Then the model-metadata is a source of truth as well.

Now, we are saying that the model metadata will always lag behind the model system index (i.e. on any call to change model state, we will first persist in system index and then update in cluster state). This is going to open up to the possibility that model index and cluster state fall out of sync on failure. We may need to think about how we handle different scenarios.

I think a model's state can be updated in one of the following ways:

  1. Training process (this will be in sync - single thread, synchronous process)
  2. Model deleted (we block the model delete if the model is not in the created or error state - we should confirm that we check this from cluster metadata before blocking)
  3. Node drop (if a node drops, the offending node will know not to change anything in the state correct?)

Are there other cases you can think of that might be of concern?

Copy link
Member Author

@ryanbogan ryanbogan Mar 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmazanec15 The goal of this was to treat the cluster metadata as the main source of truth in case of node drops I believe. This would decrease the chance of nodes being out of sync with the cluster since they could always access the model metadata from the cluster metadata.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote a temporary integration test trying to create a race condition between a model finishing training and entering the created state in ModelDao and then deleting the model immediately. All behavior was as expected and the model was deleted from the cluster metadata as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ryanbogan can you push to another branch and link to the test? I want to take a quick look.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmazanec15 I already deleted it off the local, do you want me to recreate it?


// Create the model index if it does not already exist
Runnable indexModelRunnable = () -> indexRequestBuilder.execute(onIndexListener);
Expand Down
24 changes: 24 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.indices;

public class ModelUtil {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved

public static boolean isModelPresent(ModelMetadata modelMetadata) {
return modelMetadata != null;
}

public static boolean isModelCreated(ModelMetadata modelMetadata) {
return modelMetadata.getState().equals(ModelState.CREATED);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ protected void updateModelsNewCluster() throws IOException, InterruptedException
Model model = modelDao.get(modelId);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
ModelMetadata modelMetadata = model.getModelMetadata();
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
updateModelStateAsFailed(model, "Training failed to complete as cluster crashed");
updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as cluster crashed");
}
}
}
Expand All @@ -123,11 +123,10 @@ protected void updateModelsNodesRemoved(List<DiscoveryNode> removedNodes) throws
List<String> modelIds = searchModelIds();
for (DiscoveryNode removedNode : removedNodes) {
for (String modelId : modelIds) {
Model model = modelDao.get(modelId);
ModelMetadata modelMetadata = model.getModelMetadata();
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId())
&& modelMetadata.getState().equals(ModelState.TRAINING)) {
updateModelStateAsFailed(model, "Training failed to complete as node dropped");
updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as node dropped");
}
}
}
Expand Down Expand Up @@ -158,9 +157,11 @@ public void onFailure(Exception e) {
return modelIds;
}

private void updateModelStateAsFailed(Model model, String msg) throws IOException {
model.getModelMetadata().setState(ModelState.FAILED);
model.getModelMetadata().setError(msg);
private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException,
InterruptedException {
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(msg);
Model model = new Model(modelMetadata, null, modelId);
modelDao.update(model, new ActionListener<IndexResponse>() {
@Override
public void onResponse(IndexResponse indexResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
Expand Down Expand Up @@ -661,6 +662,7 @@ public void testDoToQuery_FromModel() {
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -690,6 +692,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -722,6 +725,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.jni.JNIService;

import java.io.IOException;
Expand Down Expand Up @@ -167,6 +168,7 @@ public void testQueryScoreForFaissWithModel() {
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata);

KNNWeight.initialize(modelDao);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {
// Mock the model dao to return metadata for modelId to recognize it is a duplicate
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(trainingFieldModelMetadata.getState()).thenReturn(ModelState.CREATED);

ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public void testUpdateModelsNewCluster() throws IOException, InterruptedExceptio
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
doAnswer(invocationOnMock -> {
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits searchHits = mock(SearchHits.class);
Expand Down Expand Up @@ -144,6 +145,7 @@ public void testUpdateModelsNodesRemoved() throws IOException, InterruptedExcept
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
DiscoveryNode node1 = mock(DiscoveryNode.class);
when(node1.getEphemeralId()).thenReturn("test-node-model-match");
DiscoveryNode node2 = mock(DiscoveryNode.class);
Expand Down
Loading