Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get model group API #1670

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.opensearch.action.ActionType;

public class MLModelGroupGetAction extends ActionType<MLModelGroupGetResponse> {
public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction();

Check warning on line 11 in common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java#L11

Added line #L11 was not covered by tests
public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get";

private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);}

Check warning on line 14 in common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java#L14

Added line #L14 was not covered by tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLModelGroupGetRequest extends ActionRequest {

String modelGroupId;

@Builder
public MLModelGroupGetRequest(String modelGroupId) {
this.modelGroupId = modelGroupId;
}

public MLModelGroupGetRequest(StreamInput in) throws IOException {
super(in);
this.modelGroupId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.modelGroupId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.modelGroupId == null) {
exception = addValidationError("Model group id can't be null", exception);
}

return exception;
}

public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLModelGroupGetRequest) {
return (MLModelGroupGetRequest)actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupGetRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModelGroup;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
@ToString
public class MLModelGroupGetResponse extends ActionResponse implements ToXContentObject {

MLModelGroup mlModelGroup;

@Builder
public MLModelGroupGetResponse(MLModelGroup mlModelGroup) {
this.mlModelGroup = mlModelGroup;
}


public MLModelGroupGetResponse(StreamInput in) throws IOException {
super(in);
mlModelGroup = mlModelGroup.fromStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException{
mlModelGroup.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return mlModelGroup.toXContent(xContentBuilder, params);
}

public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLModelGroupGetResponse) {
return (MLModelGroupGetResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupGetResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLModelGroupGetResponse", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.IOException;
import java.io.UncheckedIOException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

public class MLModelGroupGetRequestTest {
private String modelGroupId;

@Before
public void setUp() {
modelGroupId = "test_id";
}

@Test
public void writeTo_Success() throws IOException {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder()
.modelGroupId(modelGroupId).build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlModelGroupGetRequest.writeTo(bytesStreamOutput);
MLModelGroupGetRequest parsedModel = new MLModelGroupGetRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(parsedModel.getModelGroupId(), modelGroupId);
}

@Test
public void validate_Exception_NullmodelGroupId() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().build();

ActionRequestValidationException exception = mlModelGroupGetRequest.validate();
assertEquals("Validation Failed: 1: Model group id can't be null;", exception.getMessage());
}

@Test
public void fromActionRequest_Success() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder()
.modelGroupId(modelGroupId).build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlModelGroupGetRequest.writeTo(out);
}
};
MLModelGroupGetRequest result = MLModelGroupGetRequest.fromActionRequest(actionRequest);
assertNotSame(result, mlModelGroupGetRequest);
assertEquals(result.getModelGroupId(), mlModelGroupGetRequest.getModelGroupId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequest_IOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLModelGroupGetRequest.fromActionRequest(actionRequest);
}

@Test
public void validate_Success() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build();
ActionRequestValidationException actionRequestValidationException = mlModelGroupGetRequest.validate();
assertNull(actionRequestValidationException);
}

@Test
public void fromActionRequestWithMLModelGroupGetRequest_Success() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build();
MLModelGroupGetRequest mlModelGroupGetRequestFromActionRequest = MLModelGroupGetRequest.fromActionRequest(mlModelGroupGetRequest);
assertSame(mlModelGroupGetRequest, mlModelGroupGetRequestFromActionRequest);
assertEquals(mlModelGroupGetRequest.getModelGroupId(), mlModelGroupGetRequestFromActionRequest.getModelGroupId());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModelGroup;

import java.io.IOException;
import java.io.UncheckedIOException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;

public class MLModelGroupGetResponseTest {

MLModelGroup mlModelGroup;

@Before
public void setUp() {
mlModelGroup = MLModelGroup.builder()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add more complex example with adding all the fields like backend roles.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To test such complex cases, we might need to add them in IT security tests. In this class we can only test two cases where user has access or no access to model group. Will add them in security tests

.name("modelGroup1")
.latestVersion(1)
.description("This is an example model group")
.access("public")
.build();
}

@Test
public void writeTo_Success() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLModelGroupGetResponse response = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
response.writeTo(bytesStreamOutput);
MLModelGroupGetResponse parsedResponse = new MLModelGroupGetResponse(bytesStreamOutput.bytes().streamInput());
assertNotEquals(response.mlModelGroup, parsedResponse.mlModelGroup);
assertEquals(response.mlModelGroup.getName(), parsedResponse.mlModelGroup.getName());
assertEquals(response.mlModelGroup.getDescription(), parsedResponse.mlModelGroup.getDescription());
assertEquals(response.mlModelGroup.getLatestVersion(), parsedResponse.mlModelGroup.getLatestVersion());
}

@Test
public void toXContentTest() throws IOException {
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
mlModelGroupGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals("{\"name\":\"modelGroup1\"," +
"\"latest_version\":1," +
"\"description\":\"This is an example model group\"," +
"\"access\":\"public\"}",
jsonStr);
}

@Test
public void fromActionResponseWithMLModelGroupGetResponse_Success() {
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(mlModelGroupGetResponse);
assertSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse);
assertEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup);
Comment on lines +72 to +73
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are checking references here, is that the goal?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

}

@Test
public void fromActionResponse_Success() {
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
mlModelGroupGetResponse.writeTo(out);
}
};
MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(actionResponse);
assertNotSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse);
assertNotEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup);
}

@Test(expected = UncheckedIOException.class)
public void fromActionResponse_IOException() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLModelGroupGetResponse.fromActionResponse(actionResponse);
}
}
Loading
Loading