Skip to content

Commit

Permalink
add test for function name
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Nov 14, 2023
1 parent 64a6079 commit bdf89cf
Showing 1 changed file with 42 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,8 @@
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.security.PrivilegedActionException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;

Expand Down Expand Up @@ -168,6 +164,9 @@ public class MLModelManagerTests extends OpenSearchTestCase {
@Mock
private ScriptService scriptService;

@Mock
private MLTask pretrainedMLTask;

@Before
public void setup() throws URISyntaxException {
String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";
Expand Down Expand Up @@ -373,6 +372,31 @@ public void testRegisterMLModel_DownloadModelFileFailure() {
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any());
}

public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException {
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
MLRegisterModelInput pretrainedInput = mockPretrainedInput();
doAnswer(
invocation -> {
ActionListener<MLRegisterModelInput> listener = (ActionListener<MLRegisterModelInput>) invocation.getArguments()[2];
listener.onResponse(pretrainedInput);
return null;
}
).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any());
MLTask pretrainedTask = MLTask
.builder()
.taskId("pretrained")
.modelId("pretrained")
.functionName(FunctionName.TEXT_EMBEDDING)
.build();
modelManager.registerMLModel(pretrainedInput, pretrainedTask);
assertEquals(pretrainedTask.getFunctionName(), FunctionName.SPARSE_ENCODING);

}

@Ignore
public void testRegisterMLModel_DownloadModelFile() throws IOException {
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
Expand Down Expand Up @@ -916,4 +940,16 @@ private MLRegisterModelMetaInput prepareRequest() {
.build();
return input;
}

private MLRegisterModelInput mockPretrainedInput()
{
return MLRegisterModelInput
.builder()
.modelName(modelName)
.version(version)
.modelGroupId("modelGroupId")
.modelFormat(modelFormat)
.functionName(FunctionName.SPARSE_ENCODING)
.build();
}
}

0 comments on commit bdf89cf

Please sign in to comment.