From 45ff4f59ac40d1b3f725effc4d3e79ff4f39d24f Mon Sep 17 00:00:00 2001 From: Muneer Kolarkunnu <33829651+akolarkunnu@users.noreply.github.com> Date: Fri, 15 Nov 2024 01:23:41 +0530 Subject: [PATCH] [FEATURE]Improve test coverage for RemoteModel.java (#3205) * [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu * [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu * [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu * [FEATURE]Improve test coverage for RemoteModel.java Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict(). Also renamed some tests to match with testing methods. Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --------- Signed-off-by: Abdul Muneer Kolarkunnu --- .../algorithms/remote/RemoteModelTest.java | 92 +++++++++++++++++-- 1 file changed, 83 insertions(+), 9 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index c14b329586..075019834c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.when; import java.util.Arrays; +import java.util.Collections; import java.util.Map; import org.junit.Assert; @@ -23,6 +24,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLModel; @@ -30,14 +32,18 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.MLStaticMockBase; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import com.google.common.collect.ImmutableMap; -public class RemoteModelTest { +public class RemoteModelTest extends MLStaticMockBase { @Mock MLInput mlInput; @@ -45,6 +51,9 @@ public class RemoteModelTest { @Mock MLModel mlModel; + @Mock + RemoteConnectorExecutor remoteConnectorExecutor; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -73,7 +82,7 @@ public void test_predict_throw_IllegalStateException() { } @Test - public void predict_NullConnectorExecutor() { + public void asyncPredict_NullConnectorExecutor() { ActionListener actionListener = mock(ActionListener.class); remoteModel.asyncPredict(mlInput, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -86,7 +95,18 @@ public void predict_NullConnectorExecutor() { } @Test - public void predict_ModelDeployed_WrongInput() { + public void asyncPredict_ModelDeployed_WrongInput() { + asyncPredict_ModelDeployed_WrongInput("pre_process_function not defined in connector"); + } + + @Test + public void asyncPredict_With_RemoteInferenceInputDataSet() { + when(mlInput.getInputDataset()).thenReturn( + new RemoteInferenceInputDataSet(Collections.emptyMap(), ConnectorAction.ActionType.BATCH_PREDICT)); + asyncPredict_ModelDeployed_WrongInput("no BATCH_PREDICT action found"); + } + + private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) { Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); @@ -95,16 +115,71 @@ public void predict_ModelDeployed_WrongInput() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue() instanceof RuntimeException; - assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); } @Test - public void initModel_RuntimeException() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Tag mismatch!"); + public void asyncPredict_Failure_With_RuntimeException() { + asyncPredict_Failure_With_Throwable( + new RuntimeException("Remote Connection Exception!"), + RuntimeException.class, + "Remote Connection Exception!" + ); + } + + @Test + public void asyncPredict_Failure_With_Throwable() { + asyncPredict_Failure_With_Throwable( + new Error("Remote Connection Error!"), + MLException.class, + "java.lang.Error: Remote Connection Error!" + ); + } + + private void asyncPredict_Failure_With_Throwable( + Throwable actualException, + Class expExceptionClass, + String expExceptionMessage + ) { + ActionListener actionListener = mock(ActionListener.class); + doThrow(actualException) + .when(remoteConnectorExecutor) + .executeAction(ConnectorAction.ActionType.PREDICT.toString(), mlInput, actionListener); + try (MockedStatic loader = mockStatic(MLEngineClassLoader.class)) { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + loader + .when(() -> MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class)) + .thenReturn(remoteConnectorExecutor); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.asyncPredict(mlInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert expExceptionClass.isInstance(argumentCaptor.getValue()); + assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage()); + } + } + + @Test + public void initModel_Failure_With_RuntimeException() { + initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class, "Tag mismatch!"); + } + + @Test + public void initModel_Failure_With_Throwable() { + initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class, "Decryption Error!"); + } + + private void initModel_Failure_With_Throwable( + Throwable actualException, + Class expExcepClass, + String expExceptionMessage + ) { + exceptionRule.expect(expExcepClass); + exceptionRule.expectMessage(expExceptionMessage); Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); - doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any()); + doThrow(actualException).when(encryptor).decrypt(any()); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); } @@ -129,7 +204,6 @@ public void initModel_WithHeader() { Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); - remoteModel.close(); Assert.assertNull(remoteModel.getConnectorExecutor()); }