Skip to content

Commit

Permalink
register new versions to a model group based on the name provided (#1452
Browse files Browse the repository at this point in the history
)

Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna authored and ylwu-amzn committed Nov 20, 2023
1 parent 3495e8d commit 3cf8b32
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.List;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -136,17 +137,76 @@ public TransportRegisterModelAction(

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
User user = RestActionUtils.getUserContext(client);
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(registerModelInput, listener, true);
} else {
doRegister(registerModelInput, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
} else {
checkUserAccess(registerModelInput, listener, false);
}
}

private void checkUserAccess(
MLRegisterModelInput registerModelInput,
ActionListener<MLRegisterModelResponse> listener,
Boolean isModelNameAlreadyExisting
) {
User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper
.validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
log.error("You don't have permissions to perform this operation on this model.");
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
} else {
if (access) {
doRegister(registerModelInput, listener);
return;
}
// if the user does not have access, we need to check three more conditions before throwing exception.
// if we are checking the access based on the name provided in the input, we let user know the name is already used by a
// model group they do not have access to.
if (isModelNameAlreadyExisting) {
// This case handles when user is using the same pre-trained model already registered by another user on the cluster.
// The only way here is for the user to first create model group and use its ID in the request
if (registerModelInput.getUrl() == null
&& registerModelInput.getFunctionName() != FunctionName.REMOTE
&& registerModelInput.getConnectorId() == null) {
listener
.onFailure(
new IllegalArgumentException(
"Without a model group ID, the system will use the model name {"
+ registerModelInput.getModelName()
+ "} to create a new model group. However, this name is taken by another group with id {"
+ registerModelInput.getModelGroupId()
+ "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."
)
);
} else {
listener
.onFailure(
new IllegalArgumentException(
"The name {"
+ registerModelInput.getModelName()
+ "} you provided is unavailable because it is used by another model group with id {"
+ registerModelInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
}
return;
}
// if user does not have access to the model group ID provided in the input, we let user know they do not have access to the
// specified model group
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
}, listener::onFailure));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,32 +63,74 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
MLRegisterModelMetaRequest registerModelMetaRequest = MLRegisterModelMetaRequest.fromActionRequest(request);
MLRegisterModelMetaInput mlUploadInput = registerModelMetaRequest.getMlRegisterModelMetaInput();

User user = RestActionUtils.getUserContext(client);
if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(mlUploadInput.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
mlUploadInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(mlUploadInput, listener, true);
} else {
createModelGroup(mlUploadInput, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
} else {
checkUserAccess(mlUploadInput, listener, false);
}
}

private void checkUserAccess(
MLRegisterModelMetaInput mlUploadInput,
ActionListener<MLRegisterModelMetaResponse> listener,
Boolean isModelNameAlreadyExisting
) {

User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper.validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
if (access) {
createModelGroup(mlUploadInput, listener);
return;
}
if (isModelNameAlreadyExisting) {
listener
.onFailure(
new IllegalArgumentException(
"The name {"
+ mlUploadInput.getName()
+ "} you provided is unavailable because it is used by another model group with id {"
+ mlUploadInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
} else {
log.error("You don't have permissions to perform this operation on this model.");
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
} else {
if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
mlUploadInput.setModelGroupId(modelGroupId);
registerModelMeta(mlUploadInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
registerModelMeta(mlUploadInput, listener);
}
}
}, e -> {
logException("Failed to validate model access", e, log);
listener.onFailure(e);
}));
}

private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionListener<MLRegisterModelMetaResponse> listener) {
if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
mlUploadInput.setModelGroupId(modelGroupId);
registerModelMeta(mlUploadInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
registerModelMeta(mlUploadInput, listener);
}
}

private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterModelMetaInput mlUploadInput) {
return MLRegisterModelGroupInput
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
Expand All @@ -30,6 +32,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -61,6 +64,9 @@
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -144,7 +150,7 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase {
private ConnectorAccessControlHelper connectorAccessControlHelper;

@Before
public void setup() {
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
settings = Settings
.builder()
Expand Down Expand Up @@ -199,6 +205,13 @@ public void setup() {
return null;
}).when(mlTaskDispatcher).dispatch(any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(0);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

when(clusterService.localNode()).thenReturn(node2);
when(node2.getId()).thenReturn("node2Id");

Expand Down Expand Up @@ -461,6 +474,97 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi
);
}

public void test_ModelNameAlreadyExists() throws IOException {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class);
doAnswer(invocation -> {
ActionListenerResponseHandler<MLForwardResponse> handler = invocation.getArgument(3);
handler.handleResponse(forwardResponse);
return null;
}).when(transportService).sendRequest(any(), any(), any(), any());
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", null), actionListener);
ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException {
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(false);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
.modelName("huggingface/sentence-transformers/all-MiniLM-L12-v2")
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.version("1")
.build();

transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Without a model group ID, the system will use the model name {huggingface/sentence-transformers/all-MiniLM-L12-v2} to create a new model group. However, this name is taken by another group with id {model_group_ID} you can't access. To register this pre-trained model, create a new model group and use its ID in your request.",
argumentCaptor.getValue().getMessage()

);
}

public void test_FailureWhenSearchingModelGroupName() throws IOException {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("Runtime exception"));
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener);

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Runtime exception", argumentCaptor.getValue().getMessage());
}

public void test_NoAccessWhenModelNameAlreadyExists() throws IOException {

SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(false);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener);

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"The name {Test Model} you provided is unavailable because it is used by another model group with id {model_group_ID} to which you do not have access. Please provide a different name.",
argumentCaptor.getValue().getMessage()
);
}

private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) {
MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
Expand All @@ -485,4 +589,22 @@ private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) {
return new MLRegisterModelRequest(registerModelInput);
}

private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException {

SearchResponse searchResponse = mock(SearchResponse.class);
String modelContent = "{\n"
+ " \"created_time\": 1684981986069,\n"
+ " \"access\": \"public\",\n"
+ " \"latest_version\": 0,\n"
+ " \"last_updated_time\": 1684981986069,\n"
+ " \"_id\": \"model_group_ID\",\n"
+ " \"name\": \"Test Model\",\n"
+ " \"description\": \"This is an example description\"\n"
+ " }";
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN);
when(searchResponse.getHits()).thenReturn(hits);
return searchResponse;
}

}
Loading

0 comments on commit 3cf8b32

Please sign in to comment.