diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java index 7ff76a9830..66c828bead 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ExecutionContext.java @@ -13,6 +13,14 @@ import lombok.AllArgsConstructor; import lombok.Data; +/** + * This class encapsulates several parameters that are used in a split-batch request case. + * A batch request is that in neural-search side multiple fields are send in one request to ml-commons, + * but the remote model doesn't accept list of string inputs so in ml-commons the request needs split. + * sequence is used to identify the index of the split request. + * countDownLatch is used to wait for all the split requests to finish. + * exceptionHolder is used to hold any exception thrown in a split-batch request. + */ @Data @AllArgsConstructor public class ExecutionContext { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index acbaea1087..80306c2003 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -60,8 +60,6 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle private final MLGuard mlGuard; - private final static Gson GSON = StringUtils.gson; - public MLSdkAsyncHttpResponseHandler( ExecutionContext executionContext, ActionListener> actionListener, @@ -108,17 +106,19 @@ private void processResponse( ) { if (Strings.isBlank(body)) { log.error("Remote model response body is empty!"); - if (executionContext.getExceptionHolder().get() == null) + if (executionContext.getExceptionHolder().get() == null) { executionContext .getExceptionHolder() .compareAndSet(null, new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST)); + } } else { if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { log.error("Remote server returned error code: {}", statusCode); - if (executionContext.getExceptionHolder().get() == null) + if (executionContext.getExceptionHolder().get() == null) { executionContext .getExceptionHolder() .compareAndSet(null, new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode))); + } } else { try { ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard); @@ -126,10 +126,11 @@ private void processResponse( tensorOutputs.put(executionContext.getSequence(), tensors); } catch (Exception e) { log.error("Failed to process response body: {}", body, e); - if (executionContext.getExceptionHolder().get() == null) + if (executionContext.getExceptionHolder().get() == null) { executionContext .getExceptionHolder() .compareAndSet(null, new MLException("Fail to execute predict in aws connector", e)); + } } } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/LoggerUtil.java b/plugin/src/main/java/org/opensearch/ml/utils/LoggerUtil.java new file mode 100644 index 0000000000..1f1a7adc23 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/utils/LoggerUtil.java @@ -0,0 +1,2 @@ +package org.opensearch.ml.utils;public class LoggerUtil { +}