Serialize all models into cluster metadata #1499

Merged
merged 16 commits into from
Apr 23, 2024
Changes from 5 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
+ * Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
### Infrastructure
8 changes: 1 addition & 7 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
@@ -330,13 +330,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
);
}, listener::onFailure);

- // After the model is indexed, update metadata only if the model is in CREATED state
- ActionListener<IndexResponse> onIndexListener;
- if (ModelState.CREATED.equals(model.getModelMetadata().getState())) {
- onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
- } else {
- onIndexListener = onMetaListener;
- }
+ ActionListener<IndexResponse> onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
Member

I'd like to discuss a few thoughts here:

So before, the logic was that the model index is the only source of truth until the model is actually created. Then the model-metadata is a source of truth as well.

Now, we are saying that the model metadata will always lag behind the model system index (i.e., on any call to change model state, we first persist to the system index and then update the cluster state). This opens up the possibility that the model index and cluster state fall out of sync on failure. We may need to think about how we handle different scenarios.

I think a model's state can be updated in one of the following ways:

  1. Training process (this will be in sync - single thread, synchronous process)
  2. Model deleted (we block the model delete if the model is not in the created or error state - we should confirm that we check this from cluster metadata before blocking)
  3. Node drop (if a node drops, the offending node will know not to change anything in the state, correct?)

Are there other cases you can think of that might be of concern?
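
To make the ordering concern above concrete, here is a minimal sketch of a two-step write in which the system index is updated first and the cluster metadata second. Everything in it is an illustrative assumption: `TwoPhaseModelWrite`, its two maps, and the `failMetadataUpdate` flag are hypothetical in-memory stand-ins, not the plugin's `ModelDao` or OpenSearch APIs. It only shows how a failure between the two writes leaves the stores disagreeing.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;

// Hypothetical in-memory stand-in; NOT the k-NN plugin's real ModelDao.
class TwoPhaseModelWrite {
    enum State { TRAINING, CREATED, FAILED }

    // Stand-in for the model system index (written first).
    final Map<String, State> systemIndex = new HashMap<>();
    // Stand-in for model metadata held in cluster state (written second).
    final Map<String, State> clusterMetadata = new HashMap<>();
    // Illustration hook: when true, the second write fails.
    boolean failMetadataUpdate = false;

    // Persist to the system index, then update cluster metadata.
    void put(String modelId, State state, Runnable onSuccess, Consumer<Exception> onFailure) {
        systemIndex.put(modelId, state);
        if (failMetadataUpdate) {
            // The first write already happened, so the two stores now differ.
            onFailure.accept(new IllegalStateException("cluster metadata update failed"));
            return;
        }
        clusterMetadata.put(modelId, state);
        onSuccess.run();
    }

    // True when both stores hold the same state for the model.
    boolean inSync(String modelId) {
        return Objects.equals(systemIndex.get(modelId), clusterMetadata.get(modelId));
    }

    public static void main(String[] args) {
        TwoPhaseModelWrite store = new TwoPhaseModelWrite();
        store.put("m1", State.TRAINING, () -> {}, e -> {});
        System.out.println("in sync after first put: " + store.inSync("m1"));

        store.failMetadataUpdate = true;
        store.put("m1", State.CREATED, () -> {}, e -> System.out.println("failure: " + e.getMessage()));
        System.out.println("in sync after failed put: " + store.inSync("m1"));
    }
}
```

In this sketch the failure callback is the only place a caller learns about the divergence, which is why each of the three scenarios listed above needs a defined recovery path.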

Member Author

@ryanbogan ryanbogan Mar 26, 2024

@jmazanec15 I believe the goal of this was to treat the cluster metadata as the main source of truth in case of node drops. This would reduce the chance of nodes falling out of sync with the cluster, since they could always read the model metadata from the cluster metadata.

Member Author

I wrote a temporary integration test that tries to create a race condition: a model finishes training and enters the created state in ModelDao, and the model is then deleted immediately. All behavior was as expected, and the model was deleted from the cluster metadata as well.

Member

@ryanbogan can you push to another branch and link to the test? I want to take a quick look.

Member Author

@jmazanec15 I already deleted it from my local branch. Do you want me to recreate it?


// Create the model index if it does not already exist
Runnable indexModelRunnable = () -> indexRequestBuilder.execute(onIndexListener);
@@ -109,10 +109,9 @@ protected void updateModelsNewCluster() throws IOException, InterruptedException
if (modelDao.isCreated()) {
List<String> modelIds = searchModelIds();
for (String modelId : modelIds) {
- Model model = modelDao.get(modelId);
- ModelMetadata modelMetadata = model.getModelMetadata();
+ ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
Member

Is this going to be backwards compatible?

Member Author

I should probably add a version check, since they won't be in the metadata on older clusters if it's in created state.

Member Author

@ryanbogan ryanbogan Apr 16, 2024

Actually, I'm not sure there is a way to make it backwards compatible, since the old models wouldn't be in the cluster metadata. I think I have to revert to the `get` call.
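
For illustration only, the compatibility gap discussed in this thread can be sketched as a lookup that prefers cluster metadata and falls back to the system index when a model is missing there. This is not what the PR does (the author proposes reverting to the `get` call). `MetadataLookupFallback`, its maps, and `resolveState` are hypothetical names with in-memory stand-ins. The sketch assumes, based on the `ModelDao` diff above, that before this change only models in the CREATED state were written to cluster metadata.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory stand-in; NOT the k-NN plugin's real ModelDao.
class MetadataLookupFallback {
    enum State { TRAINING, CREATED, FAILED }

    // Stand-in for model metadata in cluster state. Assumption: models written
    // before the change appear here only once they reached CREATED.
    final Map<String, State> clusterMetadata = new HashMap<>();
    // Stand-in for the model system index, which holds every model.
    final Map<String, State> systemIndex = new HashMap<>();

    // Prefer cluster metadata; fall back to the system index when absent.
    State resolveState(String modelId) {
        State fromClusterState = clusterMetadata.get(modelId);
        if (fromClusterState != null) {
            return fromClusterState;
        }
        return systemIndex.get(modelId); // null if the model is unknown everywhere
    }

    public static void main(String[] args) {
        MetadataLookupFallback lookup = new MetadataLookupFallback();
        // A model that was still training before the upgrade: system index only.
        lookup.systemIndex.put("old-model", State.TRAINING);
        // A model written after the change: present in both stores.
        lookup.systemIndex.put("new-model", State.TRAINING);
        lookup.clusterMetadata.put("new-model", State.TRAINING);

        System.out.println("old-model: " + lookup.resolveState("old-model"));
        System.out.println("new-model: " + lookup.resolveState("new-model"));
    }
}
```

A fallback like this adds a second read path that has to be kept consistent, which may be one reason simply keeping the original `get` call is the simpler option.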

if (modelMetadata.getState().equals(ModelState.TRAINING)) {
- updateModelStateAsFailed(model, "Training failed to complete as cluster crashed");
+ updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as cluster crashed");
}
}
}
@@ -123,11 +122,10 @@ protected void updateModelsNodesRemoved(List<DiscoveryNode> removedNodes) throws
List<String> modelIds = searchModelIds();
for (DiscoveryNode removedNode : removedNodes) {
for (String modelId : modelIds) {
- Model model = modelDao.get(modelId);
- ModelMetadata modelMetadata = model.getModelMetadata();
+ ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId())
&& modelMetadata.getState().equals(ModelState.TRAINING)) {
- updateModelStateAsFailed(model, "Training failed to complete as node dropped");
+ updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as node dropped");
}
}
}
@@ -158,9 +156,11 @@ public void onFailure(Exception e) {
return modelIds;
}

- private void updateModelStateAsFailed(Model model, String msg) throws IOException {
- model.getModelMetadata().setState(ModelState.FAILED);
- model.getModelMetadata().setError(msg);
+ private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException,
+ InterruptedException {
+ modelMetadata.setState(ModelState.FAILED);
+ modelMetadata.setError(msg);
+ Model model = new Model(modelMetadata, null, modelId);
modelDao.update(model, new ActionListener<IndexResponse>() {
@Override
public void onResponse(IndexResponse indexResponse) {
@@ -101,6 +101,7 @@ public void testUpdateModelsNewCluster() throws IOException, InterruptedExceptio
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
+ when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
doAnswer(invocationOnMock -> {
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits searchHits = mock(SearchHits.class);
@@ -144,6 +145,7 @@ public void testUpdateModelsNodesRemoved() throws IOException, InterruptedExcept
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
+ when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
DiscoveryNode node1 = mock(DiscoveryNode.class);
when(node1.getEphemeralId()).thenReturn("test-node-model-match");
DiscoveryNode node2 = mock(DiscoveryNode.class);