diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java index f8736b80fe..62ed498e4f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java @@ -5,38 +5,48 @@ package org.opensearch.ml.action.models; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.rules.ExpectedException; +import org.opensearch.OpenSearchTimeoutException; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ml.action.MLCommonsIntegTestCase; -import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.test.OpenSearchIntegTestCase; @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 2) public class GetModelITTests extends MLCommonsIntegTestCase { - private String irisIndexName; - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + private static final int MAX_RETRIES = 3; - @Before - public void setUp() throws Exception { - super.setUp(); - irisIndexName = "iris_data_for_model_it"; - loadIrisData(irisIndexName); - } - - @Ignore public void testGetModel_IndexNotFound() { - exceptionRule.expect(MLResourceNotFoundException.class); - MLModel model = getModel("test_id"); + testGetModelExceptionsWithRetry(MLResourceNotFoundException.class, "test_id"); } public void testGetModel_NullModelIdException() { - exceptionRule.expect(ActionRequestValidationException.class); - MLModel model = getModel(null); + testGetModelExceptionsWithRetry(ActionRequestValidationException.class, null); + } + + private void testGetModelExceptionsWithRetry(Class expectedException, String modelId) { + assertThrows(expectedException, () -> { + for (int retryAttempt = 1; retryAttempt <= MAX_RETRIES; retryAttempt++) { + try { + getModel(modelId); + return; + } catch (OpenSearchTimeoutException e) { + logger.info("GetModelITTests attempt: {}", retryAttempt); + + if (retryAttempt == MAX_RETRIES) { + logger.error("Failed to execute test GetModelITTests after {} retries due to timeout", MAX_RETRIES); + throw e; + } + + // adding small delay between retries + try { + Thread.sleep(1000); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Thread was interrupted during retry", ie); + } + } + } + }); } }