Skip to content

Commit

Permalink
enhance: make response_field customizable in MLModelTool (opensearch-…
Browse files Browse the repository at this point in the history
…project#2007)

* enhance: make response_field customizable in MLModelTool

Signed-off-by: zhichao-aws <[email protected]>

* ut for malformed response field

Signed-off-by: zhichao-aws <[email protected]>

* remove import * in MLModelToolTests

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored and austintlee committed Mar 18, 2024
1 parent fecb29a commit d63f092
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public static final String TYPE = "MLModelTool";
public static final String RESPONSE_FIELD = "response_field";
public static final String MODEL_ID_FIELD = "model_id";
public static final String DEFAULT_RESPONSE_FIELD = "response";

@Setter
@Getter
Expand All @@ -52,14 +55,18 @@ public class MLModelTool implements Tool {
@Setter
@Getter
private Parser outputParser;
@Setter
@Getter
private String responseField;

public MLModelTool(Client client, String modelId) {
public MLModelTool(Client client, String modelId, String responseField) {
this.client = client;
this.modelId = modelId;
this.responseField = responseField;

outputParser = o -> {
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get(responseField);
};
}

Expand Down Expand Up @@ -132,7 +139,11 @@ public void init(Client client) {

@Override
public MLModelTool create(Map<String, Object> map) {
return new MLModelTool(client, (String) map.get("model_id"));
return new MLModelTool(
client,
(String) map.get(MODEL_ID_FIELD),
(String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@

package org.opensearch.ml.engine.tools;

import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION;

import java.util.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -53,7 +62,73 @@ public void setup() {
}

@Test
public void testMLModelsWithOutputParser() {
public void testMLModelsWithDefaultOutputParserAndDefaultResponseField() throws ExecutionException, InterruptedException {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
doAnswer(invocation -> {

ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);

actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId"));
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
tool.run(null, listener);

future.join();
assertEquals("response 1", future.get());
}

@Test
public void testMLModelsWithDefaultOutputParserAndCustomizedResponseField() throws ExecutionException, InterruptedException {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
doAnswer(invocation -> {

ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);

actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId", "response_field", "action"));
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
tool.run(null, listener);

future.join();
assertEquals("action1", future.get());
}

@Test
public void testMLModelsWithDefaultOutputParserAndMalformedResponseField() throws ExecutionException, InterruptedException {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
doAnswer(invocation -> {

ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);

actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

Tool tool = MLModelTool.Factory.getInstance().create(Map.of("model_id", "modelId", "response_field", "malformed field"));
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
tool.run(null, listener);

future.join();
assertEquals(null, future.get());
}

@Test
public void testMLModelsWithCustomizedOutputParser() {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
Expand Down

0 comments on commit d63f092

Please sign in to comment.