diff --git a/CHANGELOG.md b/CHANGELOG.md index 3920ca210..689201aeb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) - Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639)) - Hide user and credential field from search response ([#680](https://github.com/opensearch-project/flow-framework/pull/680)) +- Throw the correct error message in status API for WorkflowSteps ([#676](https://github.com/opensearch-project/flow-framework/pull/676)) ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/flowframework/exception/WorkflowStepException.java b/src/main/java/org/opensearch/flowframework/exception/WorkflowStepException.java index 3575034fc..8434dc848 100644 --- a/src/main/java/org/opensearch/flowframework/exception/WorkflowStepException.java +++ b/src/main/java/org/opensearch/flowframework/exception/WorkflowStepException.java @@ -8,6 +8,9 @@ */ package org.opensearch.flowframework.exception; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchParseException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -64,4 +67,19 @@ public RestStatus getRestStatus() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { return builder.startObject().field("error", this.getMessage()).endObject(); } + + /** + * Getter for safe exceptions + * @param ex exception + * @return exception if safe + */ + public static Exception getSafeException(Exception ex) { + if (ex instanceof IllegalArgumentException + || ex instanceof OpenSearchStatusException + || ex instanceof OpenSearchParseException + || (ex instanceof OpenSearchException && ex.getCause() instanceof OpenSearchParseException)) { + return ex; + } + return null; + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index add83edfe..b1a4cc5da 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -266,7 +266,7 @@ private void executeWorkflow(List workflowSequence, String workflow status = ExceptionsHelper.status(ex); } logger.error("Provisioning failed for workflow {} during step {}.", workflowId, currentStepId, ex); - String errorMessage = (ex.getCause() == null ? ex.getClass().getName() : ex.getCause().getClass().getName()) + String errorMessage = (ex.getCause() == null ? ex.getMessage() : ex.getCause().getClass().getName()) + " during step " + currentStepId + ", restStatus: " diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java index ce9bca27e..e23d88b63 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java @@ -34,6 +34,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.PIPELINE_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to create either a search or ingest pipeline @@ -137,8 +138,9 @@ public void onResponse(AcknowledgedResponse acknowledgedResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed step " + pipelineToBeCreated; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed step " + pipelineToBeCreated : e.getMessage()); logger.error(errorMessage, e); createPipelineFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 51bab0a8f..fe4e54b6a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -49,6 +49,7 @@ import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Abstract local model registration step @@ -215,9 +216,10 @@ public PlainActionFuture execute( }, exception -> { registerLocalModelFuture.onFailure(exception); }) ); }, exception -> { - String errorMessage = "Failed to register local model in step " + currentNodeId; - logger.error(errorMessage, exception); - registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(exception))); + Exception e = getSafeException(exception); + String errorMessage = (e == null ? "Failed to register local model in step " + currentNodeId : e.getMessage()); + logger.error(errorMessage, e); + registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); })); } catch (IllegalArgumentException iae) { registerLocalModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 02f5bf336..484807ce3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -44,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** @@ -121,8 +122,9 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to create connector"; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to create connector" : ex.getMessage()); logger.error(errorMessage, e); createConnectorFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index f28fa64d4..32ca9e9f6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -38,6 +38,7 @@ import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to create an index @@ -136,10 +137,11 @@ public PlainActionFuture execute( logger.error(errorMessage, ex); createIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(ex))); } - }, e -> { - String errorMessage = "Failed to create the index " + indexName; + }, ex -> { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to create the index " + indexName : e.getMessage()); logger.error(errorMessage, e); - createIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + createIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); })); } catch (Exception e) { createIndexFuture.onFailure(e); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java index 0dba99f7a..b49be90b3 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java @@ -24,6 +24,7 @@ import java.util.Set; import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to delete a agent for a remote model @@ -82,8 +83,9 @@ public void onResponse(DeleteResponse deleteResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to delete agent " + agentId; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to delete agent " + agentId : e.getMessage()); logger.error(errorMessage, e); deleteAgentFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java index 11a2b2d62..6a1a5e0c7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -24,6 +24,7 @@ import java.util.Set; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to delete a connector for a remote model @@ -82,8 +83,9 @@ public void onResponse(DeleteResponse deleteResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to delete connector " + connectorId; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to delete connector " + connectorId : e.getMessage()); logger.error(errorMessage, e); deleteConnectorFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java index c8071f7cd..66a3e4ec0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java @@ -24,6 +24,7 @@ import java.util.Set; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to delete a model for a remote model @@ -83,8 +84,9 @@ public void onResponse(DeleteResponse deleteResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to delete model " + modelId; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to delete model " + modelId : e.getMessage()); logger.error(errorMessage, e); deleteModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 929f5f570..56c2a6181 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to deploy a model @@ -115,8 +116,9 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to deploy model " + modelId; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to deploy model " + modelId : e.getMessage()); logger.error(errorMessage, e); deployModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 8042d5244..d485f6f5c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -48,6 +48,7 @@ import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; import static org.opensearch.flowframework.util.ParseUtils.getStringToStringMap; /** @@ -133,8 +134,9 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to register the agent"; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to register the agent" : e.getMessage()); logger.error(errorMessage, e); registerAgentModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index 10824c2d5..1a2fcebe9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -38,6 +38,7 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to register a model group @@ -118,8 +119,9 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to register model group"; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to register model group" : e.getMessage()); logger.error(errorMessage, e); registerModelGroupFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index cfdc21cd9..cce5d6ee8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -38,6 +38,7 @@ import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to register a remote model @@ -184,8 +185,9 @@ void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegi } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to register remote model"; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to register remote model" : e.getMessage()); logger.error(errorMessage, e); registerRemoteModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java index 00eec6d29..ab7bc1f16 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java @@ -29,6 +29,7 @@ import static org.opensearch.flowframework.common.CommonValue.SUCCESS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; +import static org.opensearch.flowframework.exception.WorkflowStepException.getSafeException; /** * Step to undeploy model @@ -96,8 +97,9 @@ public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { } @Override - public void onFailure(Exception e) { - String errorMessage = "Failed to undeploy model " + modelId; + public void onFailure(Exception ex) { + Exception e = getSafeException(ex); + String errorMessage = (e == null ? "Failed to undeploy model " + modelId : e.getMessage()); logger.error(errorMessage, e); undeployModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index d5989cc9a..a03d15256 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -472,8 +472,6 @@ public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Excepti assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); getAndAssertWorkflowStatus(client(), workflowId, State.FAILED, ProvisioningProgress.FAILED); - String error = getAndWorkflowStatusError(client(), workflowId); - assertTrue(error.contains("org.opensearch.flowframework.exception.WorkflowStepException during step create_ingest_pipeline")); } public void testAllDefaultUseCasesCreation() throws Exception { diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 3b6af8ffa..538747b94 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.support.PlainActionFuture; @@ -24,6 +25,7 @@ import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.RemoteTransportException; import java.io.IOException; import java.util.Collections; @@ -139,4 +141,25 @@ public void testCreateIndexStepFailure() throws ExecutionException, InterruptedE assertTrue(ex.getCause() instanceof Exception); assertEquals("Failed to create the index demo", ex.getCause().getMessage()); } + + public void testCreateIndexStepUnsafeFailure() throws ExecutionException, InterruptedException, IOException { + @SuppressWarnings({ "unchecked" }) + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + PlainActionFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + assertFalse(future.isDone()); + verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); + + actionListenerCaptor.getValue().onFailure(new RemoteTransportException("test", new ResourceNotFoundException("test"))); + + assertTrue(future.isDone()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof Exception); + assertEquals("Failed to create the index demo", ex.getCause().getMessage()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index 859b7bf0d..f8f2bce8f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -229,7 +229,7 @@ public void testRegisterLocalCustomModelFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new IllegalArgumentException("test")); + actionListener.onFailure(new IllegalArgumentException("Failed to register local model in step test-node-id")); return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 431827a1c..dcd2098d5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -217,7 +217,7 @@ public void testRegisterLocalPretrainedModelFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new IllegalArgumentException("test")); + actionListener.onFailure(new IllegalArgumentException("Failed to register local model in step test-node-id")); return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index e98b7d5d5..df6fd94e2 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -222,7 +222,7 @@ public void testRegisterLocalSparseEncodingModelFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new IllegalArgumentException("test")); + actionListener.onFailure(new IllegalArgumentException("Failed to register local model in step test-node-id")); return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 11eb6af05..1312d1638 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; @@ -23,6 +24,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.RemoteTransportException; import java.io.IOException; import java.util.Collections; @@ -185,7 +187,28 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { public void testRegisterRemoteModelFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new IllegalArgumentException("test")); + actionListener.onFailure(new IllegalArgumentException("Failed to register remote model")); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register remote model", ex.getCause().getMessage()); + + } + + public void testRegisterRemoteModelUnSafeFailure() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RemoteTransportException("test", new ResourceNotFoundException("test"))); return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any());