Skip to content

Commit

Permalink
Enable tests with mockStatic in MLEngineTest (opensearch-project#2582)
Browse files Browse the repository at this point in the history
* Enable test cases with mockStatic in MLEngineTest

Signed-off-by: Liyun Xiu <[email protected]>

* Create a class to provide mockStatic

Signed-off-by: Liyun Xiu <[email protected]>

* Add TODO and comment about static best practices

Signed-off-by: Liyun Xiu <[email protected]>

---------

Signed-off-by: Liyun Xiu <[email protected]>
  • Loading branch information
chishui authored Jul 3, 2024
1 parent c4cf1b2 commit 4a22eb8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@

import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -50,7 +48,8 @@
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

public class MLEngineTest {
// TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic
public class MLEngineTest extends MLStaticMockBase {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -140,68 +139,58 @@ public void trainLinearRegression() {
assertNotNull(model.getContent());
}

// TODO: fix mockito error
@Ignore
@Test
public void train_NullInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Input should not be null");
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
mlEngine.train(null);
}
}

// TODO: fix mockito error
@Ignore
@Test
public void train_NullInputDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Input data set should not be null");
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
mlEngine.train(MLInput.builder().algorithm(algoName).build());
}
}

// TODO: fix mockito error
@Ignore
@Test
public void train_NullDataFrame() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Input data frame should not be null or empty");
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
DataFrame dataFrame = new DefaultDataFrame(new ColumnMeta[0]);
mlEngine.train(MLInput.builder().inputDataset(new DataFrameInputDataset(dataFrame)).algorithm(algoName).build());
}
}

// TODO: fix mockito error
@Ignore
@Test
public void train_EmptyDataFrame() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Input data frame should not be null or empty");
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(0)).build();
mlEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
}
}

// TODO: fix mockito error
@Ignore
@Test
public void train_UnsupportedAlgorithm() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructTestDataFrame(10)).build();
mlEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build());
Expand Down Expand Up @@ -233,14 +222,12 @@ public void predictWithoutModel() {
mlEngine.predict(mlInput, null);
}

// TODO: fix mockito error
@Ignore
@Test
public void predictUnsupportedAlgorithm() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
FunctionName algoName = FunctionName.LINEAR_REGRESSION;
try (MockedStatic<MLEngineClassLoader> loader = Mockito.mockStatic(MLEngineClassLoader.class)) {
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null);
MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructLinearRegressionPredictionDataFrame()).build();
Input mlInput = MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.opensearch.ml.engine;

import org.mockito.MockSettings;
import org.mockito.MockedStatic;
import org.mockito.internal.creation.MockSettingsImpl;

/**
* This class provides a way to use Mockito's MockedStatic with inline mock maker enabled.
* It can be used as a base class for other test classes that require mocking static methods.
*
* Note: before using this class to mock static function, think twice if your function really has to be
* static as static functions are tightly coupled with coding using them and have bad testability.
*
* Example usage:
*
* public class MyClassTest extends MLStaticMockBase {
* // Test methods go here
* }
*
* It's to overcome the issue described in https://github.com/opensearch-project/OpenSearch/issues/14420
*/
public class MLStaticMockBase {
private static final String inlineMockMaker = "org.mockito.internal.creation.bytebuddy.InlineByteBuddyMockMaker";

private MockSettings mockSettingsWithInlineMockMaker = new MockSettingsImpl().mockMaker(inlineMockMaker);

protected <T> MockedStatic<T> mockStatic(Class<T> classToMock) {
return org.mockito.Mockito.mockStatic(classToMock, mockSettingsWithInlineMockMaker);
}

protected <T> MockedStatic<T> mockStatic(Class<T> classToMock, MockSettings settings) {
MockSettings newSettings = settings.mockMaker(inlineMockMaker);
return org.mockito.Mockito.mockStatic(classToMock, newSettings);
}
}

0 comments on commit 4a22eb8

Please sign in to comment.