Skip to content

Commit

Permalink
change UT for dictionary returns
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Nov 28, 2023
1 parent 7c5ba06 commit 2b13cb3
Showing 1 changed file with 10 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.utils.StringUtils.gson;

@Log4j2
public class PPLToolTests {
Expand Down Expand Up @@ -136,11 +137,13 @@ public void setup() {

@Test
public void testTool() {
Tool tool = PPLTool.Factory.getInstance().create(Collections.emptyMap());
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
assertEquals(PPLTool.TYPE, tool.getName());

tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.<String>wrap(pplResult ->{
assertEquals(pplResult, "ppl result");
tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.<String>wrap(executePPLResult ->{
Map<String, String> returnResults = gson.fromJson(executePPLResult, Map.class);
assertEquals("ppl result", returnResults.get("executionResult"));
assertEquals("source=demo | head 1", returnResults.get("ppl"));
}, e -> {
log.info(e);
}));
Expand All @@ -149,7 +152,7 @@ public void testTool() {

@Test
public void testTool_getMappingFailure(){
Tool tool = PPLTool.Factory.getInstance().create(Collections.emptyMap());
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
assertEquals(PPLTool.TYPE, tool.getName());
Exception exception = new Exception("get mapping error");
doAnswer(invocation -> {
Expand All @@ -167,7 +170,7 @@ public void testTool_getMappingFailure(){

@Test
public void testTool_predictModelFailure(){
Tool tool = PPLTool.Factory.getInstance().create(Collections.emptyMap());
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
assertEquals(PPLTool.TYPE, tool.getName());
Exception exception = new Exception("predict model error");
doAnswer(invocation -> {
Expand All @@ -185,7 +188,7 @@ public void testTool_predictModelFailure(){

@Test
public void testTool_searchFailure(){
Tool tool = PPLTool.Factory.getInstance().create(Collections.emptyMap());
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
assertEquals(PPLTool.TYPE, tool.getName());
Exception exception = new Exception("search error");
doAnswer(invocation -> {
Expand All @@ -203,7 +206,7 @@ public void testTool_searchFailure(){

@Test
public void testTool_executePPLFailure(){
Tool tool = PPLTool.Factory.getInstance().create(Collections.emptyMap());
Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
assertEquals(PPLTool.TYPE, tool.getName());
Exception exception = new Exception("execute ppl error");
doAnswer(invocation -> {
Expand Down

0 comments on commit 2b13cb3

Please sign in to comment.