Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model access control dev3 #910

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("mockModelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class MLRegisterModelInputTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
"\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," +
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" +
"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
Expand All @@ -51,6 +51,8 @@ public class MLRegisterModelInputTest {
private final String version = "version";
private final String url = "url";

private final String modelGroupId = "modelGroupId";

@Before
public void setUp() throws Exception {
config = TextEmbeddingModelConfig.builder()
Expand All @@ -64,6 +66,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.url(url)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand All @@ -86,18 +89,19 @@ public void constructor_NullModelName() {
exceptionRule.expectMessage("model name is null");
MLRegisterModelInput.builder()
.functionName(functionName)
.modelGroupId(modelGroupId)
.modelName(null)
.build();
}

@Test
public void constructor_NullModelVersion() {
public void constructor_NullModelGroupId() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model version is null");
exceptionRule.expectMessage("model group id is null");
MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.version(null)
.modelGroupId(null)
.build();
}

Expand All @@ -109,6 +113,7 @@ public void constructor_NullModelFormat() {
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.modelFormat(null)
.url(url)
.build();
Expand All @@ -122,6 +127,7 @@ public void constructor_NullModelConfig() {
.functionName(functionName)
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(null)
.url(url)
Expand All @@ -133,6 +139,7 @@ public void constructor_SuccessWithMinimalSetup() {
MLRegisterModelInput input = MLRegisterModelInput.builder()
.modelName(modelName)
.version(version)
.modelGroupId(modelGroupId)
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.url(url)
Expand All @@ -158,7 +165,7 @@ public void testToXContent() throws Exception {
public void testToXContent_Incomplete() throws Exception {
String expectedIncompleteInputStr =
"{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," +
"\"version\":\"version\",\"deploy_model\":true}";
"\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"deploy_model\":true}";
input.setUrl(null);
input.setModelConfig(null);
input.setModelFormat(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public void setUp(){
.functionName(FunctionName.KMEANS)
.modelName("modelName")
.version("version")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@ private void readInputStream(MLRegisterModelMetaInput input) throws IOException


@Test
public void testToXContent() throws IOException {
public void testToXContent() throws IOException {{
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelFormat, modelPath, modelName);
String hash = calculateFileHash(modelZipFile);
if (modelContentHash == null || hash.equals(modelContentHash)) {
if (hash.equals(modelContentHash)) {
List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
Map<String, Object> result = new HashMap<>();
result.put(CHUNK_FILES, chunkFiles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
.functionName(functionName)
.modelName(FunctionName.METRICS_CORRELATION.name())
.version(MCORR_ML_VERSION)
.modelGroupId(functionName.name())
.modelFormat(modelFormat)
.hashValue(MODEL_CONTENT_HASH)
.modelConfig(modelConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public void testDownloadPrebuiltModelConfig_WrongModelName() {
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("test_model_name")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -157,6 +158,7 @@ public void testDownloadPrebuiltModelConfig() {
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -176,6 +178,7 @@ public void testDownloadPrebuiltModelMetaList() throws PrivilegedActionException
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -190,6 +193,7 @@ public void testIsModelAllowed_true() throws PrivilegedActionException {
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -204,6 +208,7 @@ public void testIsModelAllowed_WrongModelName() throws PrivilegedActionException
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2-wrong")
.version("1.0.1")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand All @@ -218,6 +223,7 @@ public void testIsModelAllowed_WrongModelVersion() throws PrivilegedActionExcept
MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder()
.modelName("huggingface/sentence-transformers/all-mpnet-base-v2")
.version("000")
.modelGroupId("mockGroupId")
.modelFormat(modelFormat)
.deployModel(false)
.modelNodeIds(new String[]{"node_id1"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("User Doesn't have previlege to perform this operation"));
actionListener.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,37 @@

package org.opensearch.ml.action.tasks;

import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

@Log4j2
public class SearchTaskTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private MLSearchHandler mlSearchHandler;
private Client client;

@Inject
public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.mlSearchHandler = mlSearchHandler;
this.client = client;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
mlSearchHandler.search(request, actionListener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, actionListener);
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ private MLRegisterModelInput prepareInput() {
.functionName(FunctionName.BATCH_RCF)
.deployModel(true)
.version("1.0")
.modelGroupId("model group id")
.modelName("Test Model")
.modelConfig(
new TextEmbeddingModelConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

package org.opensearch.ml.action.tasks;

import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import org.junit.Before;
import org.mockito.Mock;
Expand All @@ -16,12 +17,15 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;


public class SearchTaskTransportActionTests extends OpenSearchTestCase {
@Mock
Client client;
Expand All @@ -41,7 +45,6 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase {
@Mock
ActionListener<SearchResponse> actionListener;

MLSearchHandler mlSearchHandler;
SearchTaskTransportAction searchTaskTransportAction;

@Mock
Expand All @@ -50,12 +53,15 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper));
searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, mlSearchHandler);
searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, client);
ThreadPool threadPool = mock(ThreadPool.class);
when(client.threadPool()).thenReturn(threadPool);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void test_DoExecute() {
searchTaskTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(searchRequest, actionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
Expand Down Expand Up @@ -50,6 +51,7 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -64,7 +66,10 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -178,7 +183,7 @@ public void setup() throws URISyntaxException {
modelName = "model_name1";
modelId = randomAlphaOfLength(10);
modelContentHashValue = "c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8";
version = "1.0.0";
version = "1";
url = "http://testurl";
MLModelConfig modelConfig = TextEmbeddingModelConfig
.builder()
Expand All @@ -191,6 +196,7 @@ public void setup() throws URISyntaxException {
.builder()
.modelName(modelName)
.version(version)
.modelGroupId("modelGroupId")
.functionName(FunctionName.TEXT_EMBEDDING)
.modelFormat(modelFormat)
.modelConfig(modelConfig)
Expand Down Expand Up @@ -263,6 +269,23 @@ public void setup() throws URISyntaxException {
.build();
modelChunk0 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))).build();
modelChunk1 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk2".getBytes(StandardCharsets.UTF_8))).build();

GetResponse getResponse = mock(GetResponse.class);
when(getResponse.isExists()).thenReturn(true);
Map<String, Object> sourceMap = new HashMap<>();
sourceMap.put("latest_version", 0);
when(getResponse.getSourceAsMap()).thenReturn(sourceMap);
doAnswer(invocation -> {
ActionListener<GetResponse> getResponseActionListener = invocation.getArgument(1);
getResponseActionListener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<Void> updateActionListener = invocation.getArgument(1);
updateActionListener.onResponse(null);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));
}

public void testRegisterMLModel_ExceedMaxRunningTask() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private RestRequest getRestRequest_NullModelId() {
final Map<String, Object> modelConfig = Map
.of("model_type", "bert", "embedding_dimension", 384, "framework_type", "sentence_transformers", "all_config", "All Config");
final Map<String, Object> model = Map
.of("name", "test_model", "version", "2", "url", "testUrl", "model_format", "TORCH_SCRIPT", "model_config", modelConfig);
.of("name", "test_model", "version", "2", "model_group_id", "modelGroupId", "url", "testUrl", "model_format", "TORCH_SCRIPT", "model_config", modelConfig);
String requestContent = new Gson().toJson(model).toString();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withMethod(method)
Expand Down
Loading