Skip to content

Commit

Permalink
Fixing MachineLearningNodeClient create connector, deploy model, regi…
Browse files Browse the repository at this point in the history
…ster model group actions (opensearch-project#1580) (opensearch-project#1584)

Signed-off-by: Joshua Palis <[email protected]>
(cherry picked from commit 882246c)

Co-authored-by: Joshua Palis <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and joshpalis authored Nov 2, 2023
1 parent 3aa10c1 commit d8e8ff2
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ public void registerModelGroup(
ActionListener<MLRegisterModelGroupResponse> listener
) {
MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput);
client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener);
client
.execute(
MLRegisterModelGroupAction.INSTANCE,
mlRegisterModelGroupRequest,
getMlRegisterModelGroupResponseActionListener(listener)
);
}

/**
Expand Down Expand Up @@ -236,18 +241,41 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false);
client
.execute(
MLDeployModelAction.INSTANCE,
deployModelRequest,
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener);
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
}

private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionListener(ActionListener<MLDeployModelResponse> listener) {
ActionListener<MLDeployModelResponse> actionListener = wrapActionListener(listener, response -> {
MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
}

private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
ActionListener<MLCreateConnectorResponse> listener
) {
ActionListener<MLCreateConnectorResponse> actionListener = wrapActionListener(listener, response -> {
MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response);
return createConnectorResponse;
});
return actionListener;
}

private ActionListener<MLRegisterModelGroupResponse> getMlRegisterModelGroupResponseActionListener(
ActionListener<MLRegisterModelGroupResponse> listener
) {
ActionListener<MLRegisterModelGroupResponse> actionListener = wrapActionListener(listener, response -> {
MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response);
return registerModelGroupResponse;
});
return actionListener;
}

private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLCreateConnectorResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -42,4 +47,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

public static MLCreateConnectorResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCreateConnectorResponse) {
return (MLCreateConnectorResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLCreateConnectorResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLCreateConnectorResponse", e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLDeployModelResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -57,4 +62,21 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.endObject();
return builder;
}

public static MLDeployModelResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLDeployModelResponse) {
return (MLDeployModelResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLDeployModelResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLDeployModelResponse", e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterModelGroupResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -49,4 +54,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

public static MLRegisterModelGroupResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterModelGroupResponse) {
return (MLRegisterModelGroupResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterModelGroupResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelGroupResponse", e);
}

}
}

0 comments on commit d8e8ff2

Please sign in to comment.