diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 525bc95eda..cb807cbf66 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -49,12 +49,8 @@ import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.nio.file.Path; -import java.util.Arrays; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.security.PrivilegedActionException; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -168,6 +164,9 @@ public class MLModelManagerTests extends OpenSearchTestCase { @Mock private ScriptService scriptService; + @Mock + private MLTask pretrainedMLTask; + @Before public void setup() throws URISyntaxException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; @@ -373,6 +372,31 @@ public void testRegisterMLModel_DownloadModelFileFailure() { verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any()); } + public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException { + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); + MLRegisterModelInput pretrainedInput = mockPretrainedInput(); + doAnswer( + invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pretrainedInput); + return null; + } + ).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any()); + MLTask pretrainedTask = MLTask + .builder() + .taskId("pretrained") + .modelId("pretrained") + .functionName(FunctionName.TEXT_EMBEDDING) + .build(); + modelManager.registerMLModel(pretrainedInput, pretrainedTask); + assertEquals(pretrainedTask.getFunctionName(), FunctionName.SPARSE_ENCODING); + + } + @Ignore public void testRegisterMLModel_DownloadModelFile() throws IOException { doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); @@ -916,4 +940,16 @@ private MLRegisterModelMetaInput prepareRequest() { .build(); return input; } + + private MLRegisterModelInput mockPretrainedInput() + { + return MLRegisterModelInput + .builder() + .modelName(modelName) + .version(version) + .modelGroupId("modelGroupId") + .modelFormat(modelFormat) + .functionName(FunctionName.SPARSE_ENCODING) + .build(); + } }