diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index c14b329586..075019834c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.when; import java.util.Arrays; +import java.util.Collections; import java.util.Map; import org.junit.Assert; @@ -23,6 +24,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLModel; @@ -30,14 +32,18 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.MLStaticMockBase; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import com.google.common.collect.ImmutableMap; -public class RemoteModelTest { +public class RemoteModelTest extends MLStaticMockBase { @Mock MLInput mlInput; @@ -45,6 +51,9 @@ public class RemoteModelTest { @Mock MLModel mlModel; + @Mock + RemoteConnectorExecutor remoteConnectorExecutor; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -73,7 +82,7 @@ public void test_predict_throw_IllegalStateException() { } @Test - public void predict_NullConnectorExecutor() { + public void asyncPredict_NullConnectorExecutor() { ActionListener actionListener = mock(ActionListener.class); remoteModel.asyncPredict(mlInput, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -86,7 +95,18 @@ public void predict_NullConnectorExecutor() { } @Test - public void predict_ModelDeployed_WrongInput() { + public void asyncPredict_ModelDeployed_WrongInput() { + asyncPredict_ModelDeployed_WrongInput("pre_process_function not defined in connector"); + } + + @Test + public void asyncPredict_With_RemoteInferenceInputDataSet() { + when(mlInput.getInputDataset()).thenReturn( + new RemoteInferenceInputDataSet(Collections.emptyMap(), ConnectorAction.ActionType.BATCH_PREDICT)); + asyncPredict_ModelDeployed_WrongInput("no BATCH_PREDICT action found"); + } + + private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) { Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); @@ -95,16 +115,71 @@ public void predict_ModelDeployed_WrongInput() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue() instanceof RuntimeException; - assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); } @Test - public void initModel_RuntimeException() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Tag mismatch!"); + public void asyncPredict_Failure_With_RuntimeException() { + asyncPredict_Failure_With_Throwable( + new RuntimeException("Remote Connection Exception!"), + RuntimeException.class, + "Remote Connection Exception!" + ); + } + + @Test + public void asyncPredict_Failure_With_Throwable() { + asyncPredict_Failure_With_Throwable( + new Error("Remote Connection Error!"), + MLException.class, + "java.lang.Error: Remote Connection Error!" + ); + } + + private void asyncPredict_Failure_With_Throwable( + Throwable actualException, + Class expExceptionClass, + String expExceptionMessage + ) { + ActionListener actionListener = mock(ActionListener.class); + doThrow(actualException) + .when(remoteConnectorExecutor) + .executeAction(ConnectorAction.ActionType.PREDICT.toString(), mlInput, actionListener); + try (MockedStatic loader = mockStatic(MLEngineClassLoader.class)) { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + loader + .when(() -> MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class)) + .thenReturn(remoteConnectorExecutor); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.asyncPredict(mlInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert expExceptionClass.isInstance(argumentCaptor.getValue()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); + } + } + + @Test + public void initModel_Failure_With_RuntimeException() { + initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class, "Tag mismatch!"); + } + + @Test + public void initModel_Failure_With_Throwable() { + initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class, "Decryption Error!"); + } + + private void initModel_Failure_With_Throwable( + Throwable actualException, + Class expExcepClass, + String expExceptionMessage + ) { + exceptionRule.expect(expExcepClass); + exceptionRule.expectMessage(expExceptionMessage); Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); - doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any()); + doThrow(actualException).when(encryptor).decrypt(any()); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); } @@ -129,7 +204,6 @@ public void initModel_WithHeader() { Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); - remoteModel.close(); Assert.assertNull(remoteModel.getConnectorExecutor()); }