From d8e8ff2ccee3648806c1eed293ec869d58f1f577 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 2 Nov 2023 16:53:42 -0700 Subject: [PATCH] Fixing MachineLearningNodeClient create connector, deploy model, register model group actions (#1580) (#1584) Signed-off-by: Joshua Palis (cherry picked from commit 882246c04daacc98eb9c72fbdcabcce497856423) Co-authored-by: Joshua Palis --- .../ml/client/MachineLearningNodeClient.java | 44 +++++++++++++++---- .../connector/MLCreateConnectorResponse.java | 22 ++++++++++ .../deploy/MLDeployModelResponse.java | 22 ++++++++++ .../MLRegisterModelGroupResponse.java | 22 ++++++++++ 4 files changed, 102 insertions(+), 8 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index ac73b397e7..4c83b5060b 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -183,7 +183,12 @@ public void registerModelGroup( ActionListener listener ) { MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput); - client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener); + client + .execute( + MLRegisterModelGroupAction.INSTANCE, + mlRegisterModelGroupRequest, + getMlRegisterModelGroupResponseActionListener(listener) + ); } /** @@ -236,18 +241,41 @@ public void register(MLRegisterModelInput mlInput, ActionListener listener) { MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false); - client - .execute( - MLDeployModelAction.INSTANCE, - deployModelRequest, - ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) - ); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener)); } @Override public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput); - client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener); + client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener)); + } + + private ActionListener getMlDeployModelResponseActionListener(ActionListener listener) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response); + return deployModelResponse; + }); + return actionListener; + } + + private ActionListener getMlCreateConnectorResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response); + return createConnectorResponse; + }); + return actionListener; + } + + private ActionListener getMlRegisterModelGroupResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response); + return registerModelGroupResponse; + }); + return actionListener; } private ActionListener getMlPredictionTaskResponseActionListener(ActionListener listener) { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java index bf7b78e775..68ce877baa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLCreateConnectorResponse extends ActionResponse implements ToXContentObject { @@ -42,4 +47,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLCreateConnectorResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateConnectorResponse) { + return (MLCreateConnectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateConnectorResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLCreateConnectorResponse", e); + } + + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java index ddf0104c9e..ca35af68f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java @@ -7,6 +7,8 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; @@ -14,7 +16,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLDeployModelResponse extends ActionResponse implements ToXContentObject { @@ -57,4 +62,21 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } + + public static MLDeployModelResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLDeployModelResponse) { + return (MLDeployModelResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeployModelResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLDeployModelResponse", e); + } + + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java index 2b70ede72f..01c63d18de 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterModelGroupResponse extends ActionResponse implements ToXContentObject { @@ -49,4 +54,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLRegisterModelGroupResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterModelGroupResponse) { + return (MLRegisterModelGroupResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterModelGroupResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelGroupResponse", e); + } + + } }