diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
index 93d51dd1aa..4cf957c499 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
@@ -20,12 +20,12 @@
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLModel;
-import org.opensearch.ml.common.exception.MLLimitExceededException;
 import org.opensearch.ml.common.exception.MLResourceNotFoundException;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.transport.MLTaskResponse;
@@ -178,8 +178,8 @@ public void onResponse(MLModel mlModel) {
                             );
                     } else if (e instanceof MLResourceNotFoundException) {
                         wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
-                    } else if (e instanceof MLLimitExceededException) {
-                        wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE));
+                    } else if (e instanceof CircuitBreakingException) {
+                        wrappedListener.onFailure(e);
                     } else {
                         wrappedListener
                             .onFailure(
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
index ce24619226..6c608c95cd 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
@@ -59,6 +59,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedDeque;
@@ -838,7 +839,9 @@ private ThreadedActionListener threadedActionListener(String threadPoolNa
      * @param runningTaskLimit limit
      */
     public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
-        checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
+        if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
+            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
+        }
         mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
     }
 
diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
index b5f8b46167..b341f4c9f5 100644
--- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
+++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
@@ -143,13 +143,7 @@ public void dispatchTask(
             if (clusterService.localNode().getId().equals(node.getId())) {
                 log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
                 request.setDispatchTask(false);
-                run(
-                    // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
-                    functionName,
-                    request,
-                    transportService,
-                    listener
-                );
+                checkCBAndExecute(functionName, request, listener);
             } else {
                 log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
                 request.setDispatchTask(false);
diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
index b2c71d6ed8..54195ab156 100644
--- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
+++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
     public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener listener) {
         if (!request.isDispatchTask()) {
             log.debug("Run ML request {} locally", request.getRequestID());
-            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
-            executeTask(request, listener);
+            checkCBAndExecute(functionName, request, listener);
             return;
         }
         dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
     protected abstract TransportResponseHandler getResponseHandler(ActionListener listener);
 
     protected abstract void executeTask(Request request, ActionListener listener);
+
+    protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener listener) {
+        if (functionName != FunctionName.REMOTE) {
+            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
+        }
+        executeTask(request, listener);
+    }
 }
diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
index 227518aabf..86fbfb1605 100644
--- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
+++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
@@ -18,12 +18,13 @@
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.XContentHelper;
 import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.common.bytes.BytesReference;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.ml.breaker.MLCircuitBreakerService;
 import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
-import org.opensearch.ml.common.exception.MLLimitExceededException;
 import org.opensearch.ml.stats.MLNodeLevelStat;
 import org.opensearch.ml.stats.MLStats;
 
@@ -92,7 +93,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
         ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
         if (openCircuitBreaker != null) {
             mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
-            throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
+            throw new CircuitBreakingException(
+                openCircuitBreaker.getName() + " is open, please check your resources!",
+                CircuitBreaker.Durability.TRANSIENT
+            );
         }
     }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
index 88d0262c4e..baaf2cec05 100644
--- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
@@ -35,6 +35,8 @@
 import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ml.common.FunctionName;
@@ -43,7 +45,6 @@
 import org.opensearch.ml.common.dataframe.DataFrameBuilder;
 import org.opensearch.ml.common.dataset.DataFrameInputDataset;
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
-import org.opensearch.ml.common.exception.MLLimitExceededException;
 import org.opensearch.ml.common.exception.MLResourceNotFoundException;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
@@ -242,7 +243,7 @@ public void testPrediction_MLLimitExceededException() {
 
         doAnswer(invocation -> {
             ActionListener listener = invocation.getArgument(3);
-            listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!"));
+            listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT));
             return null;
         }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
 
@@ -253,7 +254,7 @@ public void testPrediction_MLLimitExceededException() {
 
         transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
 
-        ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
+        ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(CircuitBreakingException.class);
         verify(actionListener).onFailure(argumentCaptor.capture());
         assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
     }
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java
index fb42b9d073..406286d4a5 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java
@@ -58,7 +58,7 @@ public void testRunWithMemoryCircuitBreaker() throws IOException {
             exception.getMessage(),
             allOf(
                 containsString("Memory Circuit Breaker is open, please check your resources!"),
-                containsString("m_l_limit_exceeded_exception")
+                containsString("circuit_breaking_exception")
             )
         );
 
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
index d301e2b381..69a11a66cc 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
@@ -199,7 +199,7 @@ public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, In
         Response response = createConnector(completionModelConnectorEntity);
         Map responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
-        response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
+        response = registerRemoteModelWithTTLAndSkipHeapMemCheck("openAI-GPT-3.5 completions", connectorId, 1);
         responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
         String predictInput = "{\n" + "  \"parameters\": {\n" + "      \"prompt\": \"Say this is a test\"\n" + "  }\n" + "}";
@@ -814,11 +814,13 @@ public static Response registerRemoteModel(String name, String connectorId) thro
             .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
     }
 
-    public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
+    public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name, String connectorId, int ttl) throws IOException {
         String registerModelGroupEntity = "{\n"
             + "  \"name\": \"remote_model_group\",\n"
             + "  \"description\": \"This is an example description\"\n"
             + "}";
+        String updateJVMHeapThreshold = "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":0}}";
+        TestHelper.makeRequest(client(), "PUT", "/_cluster/settings", null, TestHelper.toHttpEntity(updateJVMHeapThreshold), null);
         Response response = TestHelper
             .makeRequest(
                 client(),