Skip to content

Commit

Permalink
Updating state index after register agent (#250)
Browse files Browse the repository at this point in the history
Adding state index update on agent

Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz authored Dec 4, 2023
1 parent 53daf61 commit b4e6b44
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ public enum WorkflowResources {
/** official workflow step name for creating an ingest-pipeline and associated created resource */
CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"),
/** official workflow step name for creating an index and associated created resource */
CREATE_INDEX("create_index", "index_name");
CREATE_INDEX("create_index", "index_name"),
/** official workflow step name for register an agent and the associated created resource */
REGISTER_AGENT("register_agent", "agent_id");

private final String workflowStep;
private final String resourceCreated;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
Expand Down Expand Up @@ -54,8 +56,9 @@ public class RegisterAgentStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(RegisterAgentStep.class);

private MachineLearningNodeClient mlClient;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

static final String NAME = "register_agent";
static final String NAME = WorkflowResources.REGISTER_AGENT.getWorkflowStep();

private static final String LLM_MODEL_ID = "llm.model_id";
private static final String LLM_PARAMETERS = "llm.parameters";
Expand All @@ -65,10 +68,12 @@ public class RegisterAgentStep implements WorkflowStep {
/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
* @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
*/
public RegisterAgentStep(MachineLearningNodeClient mlClient) {
public RegisterAgentStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) {
this.mlClient = mlClient;
this.mlToolSpecList = new ArrayList<>();
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
}

@Override
Expand All @@ -92,6 +97,36 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
currentNodeInputs.getNodeId()
)
);

try {
String resourceName = WorkflowResources.getResourceByWorkflowStep(getName());
logger.info("Created connector successfully");
flowFrameworkIndicesHandler.updateResourceInStateIndex(
currentNodeInputs.getWorkflowId(),
currentNodeId,
getName(),
mlRegisterAgentResponse.getAgentId(),
ActionListener.wrap(response -> {
logger.info("successfully updated resources created in state index: {}", response.getIndex());
registerAgentModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())),
currentNodeInputs.getWorkflowId(),
currentNodeId
)
);
}, exception -> {
logger.error("Failed to update new created resource", exception);
registerAgentModelFuture.completeExceptionally(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
})
);

} catch (Exception e) {
logger.error("Failed to parse and update new created resource", e);
registerAgentModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public WorkflowStepFactory(
stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, () -> new ModelGroupStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(ToolStep.NAME, ToolStep::new);
stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient));
stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.action.update.UpdateResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
Expand All @@ -29,8 +32,12 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

public class RegisterAgentTests extends OpenSearchTestCase {
Expand All @@ -39,10 +46,12 @@ public class RegisterAgentTests extends OpenSearchTestCase {
@Mock
MachineLearningNodeClient machineLearningNodeClient;

private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

@Override
public void setUp() throws Exception {
super.setUp();

this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
MockitoAnnotations.openMocks(this);

MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false);
Expand Down Expand Up @@ -76,7 +85,7 @@ public void setUp() throws Exception {

public void testRegisterAgent() throws IOException, ExecutionException, InterruptedException {
String agentId = "agent_id";
RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient);
RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient, flowFrameworkIndicesHandler);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLRegisterAgentResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
Expand All @@ -88,6 +97,12 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup
return null;
}).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

CompletableFuture<WorkflowData> future = registerAgentStep.execute(
inputData.getNodeId(),
inputData,
Expand All @@ -103,7 +118,7 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup

public void testRegisterAgentFailure() throws IOException {
String agentId = "agent_id";
RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient);
RegisterAgentStep registerAgentStep = new RegisterAgentStep(machineLearningNodeClient, flowFrameworkIndicesHandler);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLRegisterAgentResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
Expand All @@ -114,6 +129,12 @@ public void testRegisterAgentFailure() throws IOException {
return null;
}).when(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

CompletableFuture<WorkflowData> future = registerAgentStep.execute(
inputData.getNodeId(),
inputData,
Expand Down

0 comments on commit b4e6b44

Please sign in to comment.