Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport multiple PRs to main from 2.x #1663

Merged
merged 24 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
003f898
fix parameter name in preprocess function; fix remote model function …
ylwu-amzn Sep 26, 2023
96b0f94
throw exception when model group not found during update request (#1447)
rbhavna Oct 6, 2023
56d5802
add status code to model tensor (#1443) (#1453)
ylwu-amzn Oct 6, 2023
111f3dd
register new versions to a model group based on the name provided (#1…
rbhavna Oct 6, 2023
bcbcdb8
fixing metrics correlation algorithm (#1448)
dhrubo-os Oct 6, 2023
3dd9496
if model version fails to register, update model group accordingly (#…
rbhavna Oct 6, 2023
9858730
Update Model API (#1350)
b4sjoo Oct 6, 2023
6c6cb18
Add a setting to control the update connector API (#1465)
b4sjoo Oct 9, 2023
b7ffef7
fix update connector API (#1484)
ylwu-amzn Oct 11, 2023
3eddabc
Performance enhacement for predict action by caching model info (#147…
opensearch-trigger-bot[bot] Oct 12, 2023
1ad1091
fix failed ut from PR 1472 (#1479) (#1510)
opensearch-trigger-bot[bot] Oct 12, 2023
3f354e3
[Backport to 2.11] throw exception if remote model doesn't return 2xx…
opensearch-trigger-bot[bot] Oct 12, 2023
3848c19
fix no worker node exception for remote embedding model (#1482) (#1511)
opensearch-trigger-bot[bot] Oct 12, 2023
43b083d
fix for delete model group API throwing incorrect error when model in…
opensearch-trigger-bot[bot] Oct 12, 2023
4e1301d
fix no worker node error on multi-node cluster (#1487) (#1513)
opensearch-trigger-bot[bot] Oct 12, 2023
ccad060
add prefix to show the error is from remote service (#1499) (#1515)
opensearch-trigger-bot[bot] Oct 12, 2023
5049a7f
fix multiple docs support (#1516)
ylwu-amzn Oct 13, 2023
a340b7f
adding another fix issue to the release note (#1498) (#1514)
opensearch-trigger-bot[bot] Oct 12, 2023
b84b076
add bedrockURL to trusted connector regex list (#1461)
rbhavna Oct 6, 2023
20f366c
return parsing exception 400 for parsing errors
Zhangxunmt Nov 4, 2023
ede60e7
add more ut in restupdateconnector
Zhangxunmt Nov 4, 2023
e2f667a
fix format violations
rbhavna Nov 16, 2023
440c8c3
Fix model/connector update API to address security concern (#1595)
b4sjoo Nov 7, 2023
b1840b8
change XContentFactory to MediaTypeRegistry builder in MLRegisterMode…
rbhavna Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
public class CommonValue {

public static Integer NO_SCHEMA_VERSION = 0;
public static final String REMOTE_SERVICE_ERROR = "Error from remote service: ";
public static final String USER = "user";
public static final String META = "_meta";
public static final String SCHEMA_VERSION_FIELD = "schema_version";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ public MLModel(StreamInput input) throws IOException{
modelContentSizeInBytes = input.readOptionalLong();
modelContentHash = input.readOptionalString();
if (input.readBoolean()) {
modelConfig = new TextEmbeddingModelConfig(input);
if (algorithm.equals(FunctionName.METRICS_CORRELATION)) {
modelConfig = new MetricsCorrelationModelConfig(input);
} else {
modelConfig = new TextEmbeddingModelConfig(input);
}
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;
Expand Down Expand Up @@ -69,6 +70,7 @@ public interface Connector extends ToXContentObject, Writeable {

void writeTo(StreamOutput out) throws IOException;

void update(MLCreateConnectorInput updateContent, Function<String, String> function);

<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

@Log4j2
@NoArgsConstructor
Expand Down Expand Up @@ -248,6 +249,38 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
if (updateContent.getDescription() != null) {
this.description = updateContent.getDescription();
}
if (updateContent.getVersion() != null) {
this.version = updateContent.getVersion();
}
if (updateContent.getProtocol() != null) {
this.protocol = updateContent.getProtocol();
}
if (updateContent.getParameters() != null && updateContent.getParameters().size() > 0) {
this.parameters = updateContent.getParameters();
}
if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) {
this.credential = updateContent.getCredential();
encrypt(function);
}
if (updateContent.getActions() != null) {
this.actions = updateContent.getActions();
}
if (updateContent.getBackendRoles() != null) {
this.backendRoles = updateContent.getBackendRoles();
}
if (updateContent.getAccess() != null) {
this.access = updateContent.getAccess();
}
}

@Override
public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
Expand All @@ -28,6 +29,10 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) {
super(modelType, allConfig);
}

public MetricsCorrelationModelConfig(StreamInput in) throws IOException{
super(in);
}

@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -24,7 +25,10 @@
@Getter
public class ModelTensors implements Writeable, ToXContentObject {
public static final String OUTPUT_FIELD = "output";
public static final String STATUS_CODE_FIELD = "status_code";
private List<ModelTensor> mlModelTensors;
@Setter
private Integer statusCode;

@Builder
public ModelTensors(List<ModelTensor> mlModelTensors) {
Expand All @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();
}
if (statusCode != null) {
builder.field(STATUS_CODE_FIELD, statusCode);
}
builder.endObject();
return builder;
}
Expand All @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
mlModelTensors.add(new ModelTensor(in));
}
}
statusCode = in.readOptionalInt();
}

@Override
Expand All @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInt(statusCode);
}

public void filter(ModelResultFilter resultFilter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;
private boolean updateConnector = false;

@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -68,9 +69,10 @@ public MLCreateConnectorInput(String name,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode access,
boolean dryRun
boolean dryRun,
boolean updateConnector
) {
if (!dryRun) {
if (!dryRun && !updateConnector) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
Expand All @@ -92,9 +94,14 @@ public MLCreateConnectorInput(String name,
this.addAllBackendRoles = addAllBackendRoles;
this.access = access;
this.dryRun = dryRun;
this.updateConnector = updateConnector;
}

public static MLCreateConnectorInput parse(XContentParser parser) throws IOException {
return parse(parser, false);
}

public static MLCreateConnectorInput parse(XContentParser parser, boolean updateConnector) throws IOException {
String name = null;
String description = null;
String version = null;
Expand Down Expand Up @@ -159,7 +166,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector);
}

@Override
Expand Down Expand Up @@ -201,10 +208,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeString(name);
output.writeOptionalString(name);
output.writeOptionalString(description);
output.writeString(version);
output.writeString(protocol);
output.writeOptionalString(version);
output.writeOptionalString(protocol);
if (parameters != null) {
output.writeBoolean(true);
output.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString);
Expand Down Expand Up @@ -240,13 +247,14 @@ public void writeTo(StreamOutput output) throws IOException {
output.writeBoolean(false);
}
output.writeBoolean(dryRun);
output.writeBoolean(updateConnector);
}

public MLCreateConnectorInput(StreamInput input) throws IOException {
name = input.readString();
name = input.readOptionalString();
description = input.readOptionalString();
version = input.readString();
protocol = input.readString();
version = input.readOptionalString();
protocol = input.readOptionalString();
if (input.readBoolean()) {
parameters = input.readMap(s -> s.readString(), s -> s.readString());
}
Expand All @@ -268,5 +276,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
this.access = input.readEnum(AccessMode.class);
}
dryRun = input.readBoolean();
updateConnector = input.readBoolean();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,31 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
public class MLUpdateConnectorRequest extends ActionRequest {
String connectorId;
Map<String, Object> updateContent;
MLCreateConnectorInput updateContent;

@Builder
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) {
public MLUpdateConnectorRequest(String connectorId, MLCreateConnectorInput updateContent) {
this.connectorId = connectorId;
this.updateContent = updateContent;
}

public MLUpdateConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.updateContent = in.readMap();
this.updateContent = new MLCreateConnectorInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeMap(this.getUpdateContent());
this.updateContent.writeTo(out);
}

@Override
Expand All @@ -55,14 +54,17 @@ public ActionRequestValidationException validate() {
exception = addValidationError("ML connector id can't be null", exception);
}

if (updateContent == null) {
exception = addValidationError("Update connector content can't be null", exception);
}

return exception;
}

public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();
MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true);

return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build();
return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build();
}

public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class MLUpdateModelAction extends ActionType<UpdateResponse> {
public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction();
public static final String NAME = "cluster:admin/opensearch/ml/models/update";

private MLUpdateModelAction() {
super(NAME, UpdateResponse::new);
}
}
Loading
Loading