Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix flaky test of PredictionITTests and RestConnectorToolIT #2437

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,23 @@
import org.opensearch.ml.utils.TestData;
import org.opensearch.plugins.Plugin;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.ParameterizedStaticSettingsOpenSearchIntegTestCase;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.gson.Gson;

public class MLCommonsIntegTestCase extends OpenSearchIntegTestCase {
public class MLCommonsIntegTestCase extends ParameterizedStaticSettingsOpenSearchIntegTestCase {
private Gson gson = new Gson();

public MLCommonsIntegTestCase() {
super(Settings.EMPTY);
}

public MLCommonsIntegTestCase(Settings nodeSettings) {
super(nodeSettings);
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(MachineLearningPlugin.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.action.prediction;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS;
import static org.opensearch.ml.utils.TestData.IRIS_DATA_SIZE;
import static org.opensearch.ml.utils.TestData.TIME_FIELD;

Expand All @@ -16,6 +17,7 @@
import org.junit.rules.ExpectedException;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.action.MLCommonsIntegTestCase;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
Expand Down Expand Up @@ -49,6 +51,14 @@ public class PredictionITTests extends MLCommonsIntegTestCase {
private String logisticRegressionModelId;
private int batchRcfDataSize = 100;

/**
* set ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS to 0 to disable ML_COMMONS_SYNC_UP_JOB
* the cluster will be pre-created with the settings at startup
*/
public PredictionITTests() {
super(Settings.builder().put(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.getKey(), 0).build());
}

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,10 +969,18 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt
}

public String registerConnector(String createConnectorInput) throws IOException, InterruptedException {
Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput);
Response response;
try {
response = RestMLRemoteInferenceIT.createConnector(createConnectorInput);
} catch (Throwable throwable) {
// Add retry for `The ML encryption master key has not been initialized yet. Please retry after waiting for 10 seconds.`
TimeUnit.SECONDS.sleep(10);
response = RestMLRemoteInferenceIT.createConnector(createConnectorInput);
}
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
return connectorId;

}

public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public void tearDown() throws Exception {
}

public void testConnectorToolInFlowAgent_WrongAction() throws IOException, ParseException {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
String registerAgentRequestBody = "{\n"
+ " \"name\": \"Test agent with connector tool\",\n"
+ " \"type\": \"flow\",\n"
Expand All @@ -111,6 +114,9 @@ public void testConnectorToolInFlowAgent_WrongAction() throws IOException, Parse
}

public void testConnectorToolInFlowAgent() throws IOException, ParseException {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
String registerAgentRequestBody = "{\n"
+ " \"name\": \"Test agent with connector tool\",\n"
+ " \"type\": \"flow\",\n"
Expand Down
Loading