Skip to content

Commit

Permalink
1. revert model content hash change; 2.fix task search with permissio…
Browse files Browse the repository at this point in the history
…n issue; 3.fix failure UTs

Signed-off-by: Zan Niu <[email protected]>
  • Loading branch information
zane-neo committed May 24, 2023
1 parent 4345954 commit 3150e4e
Show file tree
Hide file tree
Showing 16 changed files with 91 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("mockModelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class MLRegisterModelInputTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
"\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," +
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" +
"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
Expand All @@ -51,6 +51,8 @@ public class MLRegisterModelInputTest {
private final String version = "version";
private final String url = "url";

private final String modelGroupId = "modelGroupId";

@Before
public void setUp() throws Exception {
config = TextEmbeddingModelConfig.builder()
Expand All @@ -64,6 +66,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.url(url)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand All @@ -86,18 +89,19 @@ public void constructor_NullModelName() {
exceptionRule.expectMessage("model name is null");
MLRegisterModelInput.builder()
.functionName(functionName)
.modelGroupId(modelGroupId)
.modelName(null)
.build();
}

@Test
public void constructor_NullModelVersion() {
public void constructor_NullModelGroupId() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model version is null");
exceptionRule.expectMessage("model group id is null");
MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.version(null)
.modelGroupId(null)
.build();
}

Expand All @@ -109,6 +113,7 @@ public void constructor_NullModelFormat() {
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.modelFormat(null)
.url(url)
.build();
Expand All @@ -122,6 +127,7 @@ public void constructor_NullModelConfig() {
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(null)
.url(url)
Expand All @@ -133,6 +139,7 @@ public void constructor_SuccessWithMinimalSetup() {
MLRegisterModelInput input = MLRegisterModelInput.builder()
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.url(url)
Expand All @@ -158,7 +165,7 @@ public void testToXContent() throws Exception {
public void testToXContent_Incomplete() throws Exception {
String expectedIncompleteInputStr =
"{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," +
"\"version\":\"version\",\"deploy_model\":true}";
"\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"deploy_model\":true}";
input.setUrl(null);
input.setModelConfig(null);
input.setModelFormat(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public void setUp(){
.functionName(FunctionName.KMEANS)
.modelName("modelName")
.version("version")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@ private void readInputStream(MLRegisterModelMetaInput input) throws IOException


@Test
public void testToXContent() throws IOException {
public void testToXContent() throws IOException {{
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelFormat, modelPath, modelName);
String hash = calculateFileHash(modelZipFile);
if (modelContentHash == null || hash.equals(modelContentHash)) {
if (hash.equals(modelContentHash)) {
List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
Map<String, Object> result = new HashMap<>();
result.put(CHUNK_FILES, chunkFiles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
.functionName(functionName)
.modelName(FunctionName.METRICS_CORRELATION.name())
.version(MCORR_ML_VERSION)
.modelGroupId(functionName.name())
.modelFormat(modelFormat)
.hashValue(MODEL_CONTENT_HASH)
.modelConfig(modelConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public void testDownloadPrebuiltModelConfig_WrongModelName() {
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("test_model_name")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -157,6 +158,7 @@ public void testDownloadPrebuiltModelConfig() {
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -176,6 +178,7 @@ public void testDownloadPrebuiltModelMetaList() throws PrivilegedActionException
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -190,6 +193,7 @@ public void testIsModelAllowed_true() throws PrivilegedActionException {
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -204,6 +208,7 @@ public void testIsModelAllowed_WrongModelName() throws PrivilegedActionException
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2-wrong")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -218,6 +223,7 @@ public void testIsModelAllowed_WrongModelVersion() throws PrivilegedActionExcept
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("000")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("User Doesn't have previlege to perform this operation"));
actionListener.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,37 @@

package org.opensearch.ml.action.tasks;

import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

@Log4j2
public class SearchTaskTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private MLSearchHandler mlSearchHandler;
private Client client;

@Inject
public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.mlSearchHandler = mlSearchHandler;
this.client = client;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
mlSearchHandler.search(request, actionListener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, actionListener);
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ private MLRegisterModelInput prepareInput() {
.functionName(FunctionName.BATCH_RCF)
.deployModel(true)
.version("1.0")
.modelGroupId("model group id")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

package org.opensearch.ml.action.tasks;

import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import org.junit.Before;
import org.mockito.Mock;
Expand All @@ -16,12 +17,15 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;


public class SearchTaskTransportActionTests extends OpenSearchTestCase {
@Mock
Client client;
Expand All @@ -41,7 +45,6 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase {
@Mock
ActionListener<SearchResponse> actionListener;

MLSearchHandler mlSearchHandler;
SearchTaskTransportAction searchTaskTransportAction;

@Mock
Expand All @@ -50,12 +53,15 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper));
searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, mlSearchHandler);
searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, client);
ThreadPool threadPool = mock(ThreadPool.class);
when(client.threadPool()).thenReturn(threadPool);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void test_DoExecute() {
searchTaskTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(searchRequest, actionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
Expand Down Expand Up @@ -50,6 +51,7 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -64,7 +66,10 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -178,7 +183,7 @@ public void setup() throws URISyntaxException {
modelName = "model_name1";
modelId = randomAlphaOfLength(10);
modelContentHashValue = "c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8";
version = "1.0.0";
version = "1";
url = "http://testurl";
MLModelConfig modelConfig = TextEmbeddingModelConfig
.builder()
Expand All @@ -191,6 +196,7 @@ public void setup() throws URISyntaxException {
.builder()
.modelName(modelName)
.version(version)
.modelGroupId("modelGroupId")
.functionName(FunctionName.TEXT_EMBEDDING)
.modelFormat(modelFormat)
.modelConfig(modelConfig)
Expand Down Expand Up @@ -263,6 +269,23 @@ public void setup() throws URISyntaxException {
.build();
modelChunk0 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))).build();
modelChunk1 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk2".getBytes(StandardCharsets.UTF_8))).build();

GetResponse getResponse = mock(GetResponse.class);
when(getResponse.isExists()).thenReturn(true);
Map<String, Object> sourceMap = new HashMap<>();
sourceMap.put("latest_version", 0);
when(getResponse.getSourceAsMap()).thenReturn(sourceMap);
doAnswer(invocation -> {
ActionListener<GetResponse> getResponseActionListener = invocation.getArgument(1);
getResponseActionListener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<Void> updateActionListener = invocation.getArgument(1);
updateActionListener.onResponse(null);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));
}

public void testRegisterMLModel_ExceedMaxRunningTask() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private RestRequest getRestRequest_NullModelId() {
final Map<String, Object> modelConfig = Map
.of("model_type", "bert", "embedding_dimension", 384, "framework_type", "sentence_transformers", "all_config", "All Config");
final Map<String, Object> model = Map
.of("name", "test_model", "version", "2", "url", "testUrl", "model_format", "TORCH_SCRIPT", "model_config", modelConfig);
.of("name", "test_model", "version", "2", "model_group_id", "modelGroupId", "url", "testUrl", "model_format", "TORCH_SCRIPT", "model_config", modelConfig);
String requestContent = new Gson().toJson(model).toString();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withMethod(method)
Expand Down
Loading

0 comments on commit 3150e4e

Please sign in to comment.