diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index ac73b397e7..4c83b5060b 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -183,7 +183,12 @@ public void registerModelGroup( ActionListener listener ) { MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput); - client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener); + client + .execute( + MLRegisterModelGroupAction.INSTANCE, + mlRegisterModelGroupRequest, + getMlRegisterModelGroupResponseActionListener(listener) + ); } /** @@ -236,18 +241,41 @@ public void register(MLRegisterModelInput mlInput, ActionListener listener) { MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false); - client - .execute( - MLDeployModelAction.INSTANCE, - deployModelRequest, - ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) - ); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener)); } @Override public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput); - client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener); + client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener)); + } + + private ActionListener getMlDeployModelResponseActionListener(ActionListener listener) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response); + return deployModelResponse; + }); + return actionListener; + } + + private ActionListener getMlCreateConnectorResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response); + return createConnectorResponse; + }); + return actionListener; + } + + private ActionListener getMlRegisterModelGroupResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response); + return registerModelGroupResponse; + }); + return actionListener; } private ActionListener getMlPredictionTaskResponseActionListener(ActionListener listener) { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java index bf7b78e775..68ce877baa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLCreateConnectorResponse extends ActionResponse implements ToXContentObject { @@ -42,4 +47,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLCreateConnectorResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateConnectorResponse) { + return (MLCreateConnectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateConnectorResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLCreateConnectorResponse", e); + } + + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java index ddf0104c9e..ca35af68f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java @@ -7,6 +7,8 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; @@ -14,7 +16,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLDeployModelResponse extends ActionResponse implements ToXContentObject { @@ -57,4 +62,21 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } + + public static MLDeployModelResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLDeployModelResponse) { + return (MLDeployModelResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeployModelResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLDeployModelResponse", e); + } + + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java index 2b70ede72f..01c63d18de 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterModelGroupResponse extends ActionResponse implements ToXContentObject { @@ -49,4 +54,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLRegisterModelGroupResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterModelGroupResponse) { + return (MLRegisterModelGroupResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterModelGroupResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelGroupResponse", e); + } + + } }