Skip to content

Commit

Permalink
Add Search Model Group Rest Action unit tests and minor fixes (#929)
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
rbhavna authored and b4sjoo committed May 30, 2023
1 parent 6210255 commit 2115f54
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ public class CommonValue {
" \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \""+MLModelGroup.BACKEND_ROLES_FIELD+"\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \""+MLModelGroup.ACCESS+"\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
@Getter
public class MLModelGroup implements ToXContentObject {
public static final String MODEL_GROUP_NAME_FIELD = "name"; //name of the model group
// We use int type for version in first release 1.3. In 2.4, we changed to
// use String type for version. Keep this old version field for old models.
public static final String DESCRIPTION_FIELD = "description"; //description of the model group
public static final String LATEST_VERSION_FIELD = "latest_version"; //latest model version added to the model group
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //back_end roles as specified by the owner/admin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@

package org.opensearch.ml.rest;

import static org.opensearch.commons.ConfigConstants.*;
import static org.opensearch.ml.common.MLTask.*;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH;
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT;
import static org.opensearch.ml.stats.MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT;
import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL;
Expand All @@ -16,7 +23,14 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
Expand Down Expand Up @@ -51,6 +65,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.ModelAccessMode;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
Expand All @@ -60,6 +75,7 @@
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
Expand Down Expand Up @@ -619,7 +635,7 @@ private void verifyResponse(Consumer<Map<String, Object>> verificationConsumer,
}
}

public MLRegisterModelInput createRegisterModelInput() {
public MLRegisterModelInput createRegisterModelInput(String modelGroupID) {
MLModelConfig modelConfig = TextEmbeddingModelConfig
.builder()
.modelType("bert")
Expand All @@ -630,6 +646,7 @@ public MLRegisterModelInput createRegisterModelInput() {
.builder()
.modelName("test_model_name")
.version("1.0.0")
.modelGroupId(modelGroupID)
.functionName(FunctionName.TEXT_EMBEDDING)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.modelConfig(modelConfig)
Expand All @@ -639,6 +656,42 @@ public MLRegisterModelInput createRegisterModelInput() {
.build();
}

public MLRegisterModelGroupInput createRegisterModelGroupInput(
List<String> backendRoles,
ModelAccessMode modelAccessMode,
Boolean isAddAllBackendRoles
) {
return MLRegisterModelGroupInput
.builder()
.name("modelGroupName")
.description("This is a test model group")
.backendRoles(backendRoles)
.modelAccessMode(modelAccessMode)
.isAddAllBackendRoles(isAddAllBackendRoles)
.build();
}

public void registerModelGroup(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/model_groups/_register", null, input, null);
verifyResponse(function, response);
}

public void updateModelGroup(RestClient client, String modelGroupId, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper
.makeRequest(client, "POST", "/_plugins/_ml/model_groups/" + modelGroupId + "/_update", null, "", null);
verifyResponse(function, response);
}

public void deleteModelGroup(RestClient client, String modelGroupId, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/model_groups/" + modelGroupId, null, "", null);
verifyResponse(function, response);
}

public void searchModelGroups(RestClient client, String query, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/model_groups/_search", null, query, null);
verifyResponse(function, response);
}

public void registerModel(RestClient client, String input, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "POST", "/_plugins/_ml/models/_register", null, input, null);
verifyResponse(function, response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class RestMLCustomModelActionIT extends MLCommonsRestTestCase {

@Before
public void setup() {
registerModelInput = createRegisterModelInput();
registerModelInput = createRegisterModelInput("testModelGroupID");
}

@Ignore
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.utils.TestHelper.getSearchAllRestRequest;

import java.io.IOException;
import java.util.List;

import org.apache.lucene.search.TotalHits;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

public class RestMLSearchModelGroupActionTests extends OpenSearchTestCase {

private RestMLSearchModelGroupAction restMLSearchModelGroupAction;

NodeClient client;
private ThreadPool threadPool;
@Mock
RestChannel channel;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));

XContentBuilder builder = XContentFactory.jsonBuilder();

doReturn(builder).when(channel).newBuilder();

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(2);

String modelGroupContent = "{\"name\":\"modelName\",\"description\":\"description\",\"model_access_mode\":\"public\"}";
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelGroupContent));
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(
hits,
InternalAggregations.EMPTY,
null,
false,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
actionListener.onResponse(searchResponse);
return null;
}).when(client).execute(eq(MLModelGroupSearchAction.INSTANCE), any(), any());
}

@Override
public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
client.close();
}

public void testConstructor() {
RestMLSearchModelGroupAction mlSearchModelGroupAction = new RestMLSearchModelGroupAction();
assertNotNull(mlSearchModelGroupAction);
}

public void testGetName() {
String actionName = restMLSearchModelGroupAction.getName();
assertFalse(Strings.isNullOrEmpty(actionName));
assertEquals("ml_search_model_group_action", actionName);
}

public void testRoutes() {
List<RestHandler.Route> routes = restMLSearchModelGroupAction.routes();
assertNotNull(routes);
assertFalse(routes.isEmpty());
RestHandler.Route postRoute = routes.get(0);
assertEquals(RestRequest.Method.POST, postRoute.getMethod());
assertThat(postRoute.getMethod(), Matchers.either(Matchers.is(RestRequest.Method.POST)).or(Matchers.is(RestRequest.Method.GET)));
assertEquals("/_plugins/_ml/model_groups/_search", postRoute.getPath());
}

public void testPrepareRequest() throws Exception {
RestRequest request = getSearchAllRestRequest();
restMLSearchModelGroupAction.handleRequest(request, channel, client);

ArgumentCaptor<SearchRequest> argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class);
ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), argumentCaptor.capture(), any());
verify(channel, times(1)).sendResponse(responseCaptor.capture());
SearchRequest searchRequest = argumentCaptor.getValue();
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
searchRequest.source().toString()
);
RestResponse restResponse = responseCaptor.getValue();
assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status());
}

public void testPrepareRequest_timeout() throws Exception {
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(2);

SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(
hits,
InternalAggregations.EMPTY,
null,
true,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
actionListener.onResponse(searchResponse);
return null;
}).when(client).execute(eq(MLModelGroupSearchAction.INSTANCE), any(), any());

RestRequest request = getSearchAllRestRequest();
restMLSearchModelGroupAction.handleRequest(request, channel, client);

ArgumentCaptor<SearchRequest> argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class);
ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), argumentCaptor.capture(), any());
verify(channel, times(1)).sendResponse(responseCaptor.capture());
SearchRequest searchRequest = argumentCaptor.getValue();
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
searchRequest.source().toString()
);
RestResponse restResponse = responseCaptor.getValue();
assertEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void setup() throws IOException {
searchSourceBuilder.size(1000);
searchSourceBuilder.fetchSource(new String[] { "petal_length_in_cm", "petal_width_in_cm" }, null);

mlRegisterModelInput = createRegisterModelInput();
mlRegisterModelInput = createRegisterModelInput("testModelGroupID");
}

@After
Expand Down

0 comments on commit 2115f54

Please sign in to comment.