Skip to content

Commit

Permalink
add acknowledge check for index creation in missing places
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Jul 23, 2024
1 parent 9b413a7 commit 4321361
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,58 +114,65 @@ private void initMasterKey() {

CountDownLatch latch = new CountDownLatch(1);
mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (getResponse == null || !getResponse.isExists()) {
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
final String generatedMasterKey = generateMasterKey();
indexRequest
.source(ImmutableMap.of(MASTER_KEY, generatedMasterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
indexRequest.opType(DocWriteRequest.OpType.CREATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
this.masterKey = generatedMasterKey;
log.info("ML encryption master key initialized successfully");
latch.countDown();
}, e -> {

if (ExceptionUtils.getRootCause(e) instanceof VersionConflictEngineException) {
GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
client.get(getMasterKeyRequest, ActionListener.wrap(getMasterKeyResponse -> {
if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) {
final String masterKey = (String) getMasterKeyResponse.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
log.info("ML encryption master key already initialized, no action needed");
latch.countDown();
} else {
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
if (!r) {
exceptionRef.set(new RuntimeException("No response to create ML Config index"));
latch.countDown();
} else {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (getResponse == null || !getResponse.isExists()) {
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
final String generatedMasterKey = generateMasterKey();
indexRequest
.source(ImmutableMap.of(MASTER_KEY, generatedMasterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
indexRequest.opType(DocWriteRequest.OpType.CREATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
this.masterKey = generatedMasterKey;
log.info("ML encryption master key initialized successfully");
latch.countDown();
}, e -> {

if (ExceptionUtils.getRootCause(e) instanceof VersionConflictEngineException) {
GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (
ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()
) {
client.get(getMasterKeyRequest, ActionListener.wrap(getMasterKeyResponse -> {
if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) {
final String masterKey = (String) getMasterKeyResponse.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
log.info("ML encryption master key already initialized, no action needed");
latch.countDown();
} else {
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
latch.countDown();
}
}, error -> {
log.debug("Failed to get ML encryption master key", e);
exceptionRef.set(error);
latch.countDown();
}
}, error -> {
log.debug("Failed to get ML encryption master key", e);
exceptionRef.set(error);
latch.countDown();
}));
}));
}
} else {
log.debug("Failed to index ML encryption master key", e);
exceptionRef.set(e);
latch.countDown();
}
} else {
log.debug("Failed to index ML encryption master key", e);
exceptionRef.set(e);
latch.countDown();
}
}));
} else {
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
log.info("ML encryption master key already initialized, no action needed");
}));
} else {
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
log.info("ML encryption master key already initialized, no action needed");
latch.countDown();
}
}, e -> {
log.debug("Failed to get ML encryption master key from config index", e);
exceptionRef.set(e);
latch.countDown();
}
}, e -> {
log.debug("Failed to get ML encryption master key from config index", e);
exceptionRef.set(e);
latch.countDown();
}));
}));
}
}
}, e -> {
log.debug("Failed to init ML config index", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,22 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() {
encryptor.decrypt("test");
}

@Test
public void decrypt_NoResponseToInitConfigIndex() {
exceptionRule.expect(RuntimeException.class);
exceptionRule.expectMessage("No response to create ML Config index");

doAnswer(invocation -> {
ActionListener<Boolean> actionListener = (ActionListener) invocation.getArgument(0);
actionListener.onResponse(false);
return null;
}).when(mlIndicesHandler).initMLConfigIndex(any());

Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler);
Assert.assertNull(encryptor.getMasterKey());
encryptor.decrypt("test");
}

@Test
public void decrypt_MLConfigIndexNotFound() {
exceptionRule.expect(ResourceNotFoundException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti
throw new Exception("Chunk size exceeds 10MB");
}
mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
if (!res) {
wrappedListener.onFailure(new RuntimeException("No response to create ML Model index"));
return;
}
int chunkNum = uploadModelChunkInput.getChunkNumber();
MLModel mlModel = MLModel
.builder()
Expand Down
26 changes: 18 additions & 8 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,13 @@ public void run() {
return;
}
// refresh model status
mlIndicesHandler
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
log.error("Failed to init model index", e);
}));
mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
if (!res) {
log.error("No response to create ML model index");
return;
}
refreshModelState(modelWorkerNodes, deployingModels);
}, e -> { log.error("Failed to init model index", e); }));
}, ex -> { log.error("Failed to sync model routing", ex); }));
}, e -> { log.error("Failed to sync model routing", e); }));
}
Expand All @@ -211,10 +214,13 @@ private void undeployExpiredModels(
log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures());
}

mlIndicesHandler
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
log.error("Failed to init model index", e);
}));
mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
if (!res) {
log.error("No response to create ML model index");
return;
}
refreshModelState(modelWorkerNodes, deployingModels);
}, e -> { log.error("Failed to init model index", e); }));
}, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
}

Expand All @@ -224,6 +230,10 @@ void initMLConfig() {
return;
}
mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
if (!r) {
log.debug("Failed to initialize or update ML Config index");
return;
}
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(getResponse -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
}

mlIndicesHandler.initModelGroupIndexIfAbsent(ActionListener.wrap(res -> {
if (!res) {
wrappedListener.onFailure(new RuntimeException("No response to create ML Model Group index"));
return;
}
IndexRequest indexRequest = new IndexRequest(ML_MODEL_GROUP_INDEX);
indexRequest
.source(
Expand Down
16 changes: 16 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
String modelName = mlRegisterModelMetaInput.getName();
mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> {
if (!res) {
wrappedListener.onFailure(new RuntimeException("No response to create ML Model index"));
return;
}
Instant now = Instant.now();
MLModel mlModelMeta = MLModel
.builder()
Expand Down Expand Up @@ -528,6 +532,10 @@ private void indexRemoteModel(
}

mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(boolResponse -> {
if (!boolResponse) {
listener.onFailure(new RuntimeException("No response to create ML Model index"));
return;
}
MLModel mlModelMeta = MLModel
.builder()
.name(modelName)
Expand Down Expand Up @@ -596,6 +604,10 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St
registerModelInput.getConnector().encrypt(mlEngine::encrypt);
}
mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> {
if (!res) {
handleException(functionName, taskId, new RuntimeException("No response to create ML Model index"));
return;
}
MLModel mlModelMeta = MLModel
.builder()
.name(modelName)
Expand Down Expand Up @@ -666,6 +678,10 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
String modelGroupId = registerModelInput.getModelGroupId();
Instant now = Instant.now();
mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> {
if (!res) {
handleException(functionName, taskId, new RuntimeException("No response to create ML Model index"));
return;
}
MLModel mlModelMeta = MLModel
.builder()
.name(modelName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ public void testUploadModelChunk() {
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_NoResponseInitModelIndex() {
doAnswer(invocation -> {
ActionListener<Boolean> actionListener = invocation.getArgument(0);
actionListener.onResponse(false);
return null;
}).when(mlIndicesHandler).initModelIndexIfAbsent(any());

MLUploadModelChunkInput uploadModelChunkInput = prepareRequest();
mlModelChunkUploader.uploadModelChunk(uploadModelChunkInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("No response to create ML Model index", argumentCaptor.getValue().getMessage());
}

private MLUploadModelChunkInput prepareRequest() {
final byte[] content = new byte[] { 1, 2, 3, 4 };
MLUploadModelChunkInput input = MLUploadModelChunkInput.builder().chunkNumber(0).modelId("someModelId").content(content).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,23 @@ public void test_NotFoundGetModelGroup() throws IOException {
assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage());
}

public void test_NoResponseoInitModelGroup() throws IOException {
doAnswer(invocation -> {
ActionListener<Boolean> actionListener = invocation.getArgument(0);
actionListener.onResponse(false);
return null;
}).when(mlIndicesHandler).initModelGroupIndexIfAbsent(any());

when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false);

MLRegisterModelGroupInput mlRegisterModelGroupInput = prepareRequest(null, null, null);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, actionListener);

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("No response to create ML Model Group index", argumentCaptor.getValue().getMessage());
}

private MLRegisterModelGroupInput prepareRequest(List<String> backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) {
return MLRegisterModelGroupInput
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,20 @@ public void testRegisterModelMeta_FailedToInitIndexIfPresent() {
verify(actionListener).onFailure(argumentCaptor.capture());
}

public void testRegisterModelMeta_NoResponseToInitIndex() {
setupForModelMeta();
doAnswer(invocation -> {
ActionListener<Boolean> actionListener = invocation.getArgument(0);
actionListener.onResponse(false);
return null;
}).when(mlIndicesHandler).initModelIndexIfAbsent(any());
MLRegisterModelMetaInput mlUploadInput = prepareRequest();
modelManager.registerModelMeta(mlUploadInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("No response to create ML Model index", argumentCaptor.getValue().getMessage());
}

public void test_trackPredictDuration_sync() {
Supplier<String> mockResult = () -> {
try {
Expand Down

0 comments on commit 4321361

Please sign in to comment.