Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support nested query in neural sparse tool, vectorDB tool and RAG tool #350

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ public class NeuralSparseSearchTool extends AbstractRetrieverTool {
public static final String TYPE = "NeuralSparseSearchTool";
public static final String MODEL_ID_FIELD = "model_id";
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String NESTED_PATH_FIELD = "nested_path";

private String name = TYPE;
private String modelId;
private String embeddingField;
private String nestedPath;

@Builder
public NeuralSparseSearchTool(
Expand All @@ -46,11 +48,13 @@ public NeuralSparseSearchTool(
String embeddingField,
String[] sourceFields,
Integer docSize,
String modelId
String modelId,
String nestedPath
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.modelId = modelId;
this.embeddingField = embeddingField;
this.nestedPath = nestedPath;
}

@Override
Expand All @@ -61,8 +65,29 @@ protected String getQueryBody(String queryText) {
);
}

Map<String, Object> queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));
Map<String, Object> queryBody;
if (StringUtils.isBlank(nestedPath)) {
queryBody = Map
.of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))));
} else {
queryBody = Map
.of(
"query",
Map
.of(
"nested",
Map
.of(
"path",
nestedPath,
"score_mode",
"max",
"query",
Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId)))
)
)
);
}

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
Expand Down Expand Up @@ -99,6 +124,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class);
String modelId = (String) params.get(MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return NeuralSparseSearchTool
.builder()
.client(client)
Expand All @@ -108,6 +134,7 @@ public NeuralSparseSearchTool create(Map<String, Object> params) {
.sourceFields(sourceFields)
.modelId(modelId)
.docSize(docSize)
.nestedPath(nestedPath)
.build();
}

Expand Down
34 changes: 31 additions & 3 deletions src/main/java/org/opensearch/agent/tools/VectorDBTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ public class VectorDBTool extends AbstractRetrieverTool {
public static final String EMBEDDING_FIELD = "embedding_field";
public static final String K_FIELD = "k";
public static final Integer DEFAULT_K = 10;
public static final String NESTED_PATH_FIELD = "nested_path";

private String name = TYPE;
private String modelId;
private String embeddingField;
private Integer k;
private String nestedPath;

@Builder
public VectorDBTool(
Expand All @@ -53,12 +55,14 @@ public VectorDBTool(
String[] sourceFields,
Integer docSize,
String modelId,
Integer k
Integer k,
String nestedPath
) {
super(client, xContentRegistry, index, sourceFields, docSize);
this.modelId = modelId;
this.embeddingField = embeddingField;
this.k = k;
this.nestedPath = nestedPath;
}

@Override
Expand All @@ -69,8 +73,30 @@ protected String getQueryBody(String queryText) {
);
}

Map<String, Object> queryBody = Map
.of("query", Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k))));
Map<String, Object> queryBody;
if (StringUtils.isBlank(nestedPath)) {
queryBody = Map
.of("query", Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k))));

} else {
queryBody = Map
.of(
"query",
Map
.of(
"nested",
Map
.of(
"path",
nestedPath,
"score_mode",
"max",
"query",
Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k)))
)
)
);
}

try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBody));
Expand Down Expand Up @@ -108,6 +134,7 @@ public VectorDBTool create(Map<String, Object> params) {
String modelId = (String) params.get(MODEL_ID_FIELD);
Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE;
Integer k = params.containsKey(K_FIELD) ? Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K;
String nestedPath = (String) params.get(NESTED_PATH_FIELD);
return VectorDBTool
.builder()
.client(client)
Expand All @@ -118,6 +145,7 @@ public VectorDBTool create(Map<String, Object> params) {
.modelId(modelId)
.docSize(docSize)
.k(k)
.nestedPath(nestedPath)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class NeuralSparseSearchToolTests {
public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh";
public static final String TEST_EMBEDDING_FIELD = "test embedding";
public static final String TEST_MODEL_ID = "123fsd23134";
public static final String TEST_NESTED_PATH = "nested_path";
private Map<String, Object> params = new HashMap<>();

@Before
Expand Down Expand Up @@ -60,6 +61,22 @@ public void testGetQueryBody() {
assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithNestedPath() {
params.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params);
Map<String, Map<String, Map<String, Object>>> nestedQueryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class);
assertEquals("nested_path", nestedQueryBody.get("query").get("nested").get("path"));
assertEquals("max", nestedQueryBody.get("query").get("nested").get("score_mode"));
Map<String, Map<String, Map<String, String>>> queryBody = (Map<String, Map<String, Map<String, String>>>) nestedQueryBody
.get("query")
.get("nested")
.get("query");
assertEquals("123fsd23134sdfouh", queryBody.get("neural_sparse").get("test embedding").get("query_text"));
assertEquals("123fsd23134", queryBody.get("neural_sparse").get("test embedding").get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
Expand Down Expand Up @@ -110,6 +127,11 @@ public void testCreateToolsParseParams() {
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.MODEL_ID_FIELD, 123))
);

assertThrows(
ClassCastException.class,
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.NESTED_PATH_FIELD, 123))
);

assertThrows(
JsonSyntaxException.class,
() -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.SOURCE_FIELD, "123"))
Expand Down
10 changes: 9 additions & 1 deletion src/test/java/org/opensearch/agent/tools/RAGToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class RAGToolTests {
public static final String TEST_INFERENCE_MODEL_ID = "1234";
public static final String TEST_NEURAL_QUERY_TYPE = "neural";
public static final String TEST_NEURAL_SPARSE_QUERY_TYPE = "neural_sparse";
public static final String TEST_NESTED_PATH = "nested_path";

static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY = getQueryNamedXContentRegistry();
private RAGTool ragTool;
Expand Down Expand Up @@ -422,6 +423,7 @@ public void testFactoryNeuralQuery() {
assertEquals(factoryMock.getDefaultVersion(), null);
assertNotNull(RAGTool.Factory.getInstance());

params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
RAGTool rAGtool1 = factoryMock.create(params);
VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY);
params.put(VectorDBTool.MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID);
Expand All @@ -436,6 +438,7 @@ public void testFactoryNeuralQuery() {
assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields());
assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry());
assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType());
assertEquals(((VectorDBTool) rAGtool1.getQueryTool()).getNestedPath(), ((VectorDBTool) rAGtool2.getQueryTool()).getNestedPath());
}

@Test
Expand All @@ -450,6 +453,8 @@ public void testFactoryNeuralSparseQuery() {
assertEquals(factoryMock.getDefaultType(), RAGTool.TYPE);
assertEquals(factoryMock.getDefaultVersion(), null);

params.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
params.put("query_type", "neural_sparse");
RAGTool rAGtool1 = factoryMock.create(params);
NeuralSparseSearchTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY);
NeuralSparseSearchTool queryTool = NeuralSparseSearchTool.Factory.getInstance().create(params);
Expand All @@ -463,7 +468,10 @@ public void testFactoryNeuralSparseQuery() {
assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields());
assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry());
assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType());

assertEquals(
((NeuralSparseSearchTool) rAGtool1.getQueryTool()).getNestedPath(),
((NeuralSparseSearchTool) rAGtool2.getQueryTool()).getNestedPath()
);
}

private static NamedXContentRegistry getQueryNamedXContentRegistry() {
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class VectorDBToolTests {
public static final String TEST_EMBEDDING_FIELD = "test embedding";
public static final String TEST_MODEL_ID = "123fsd23134";
public static final Integer TEST_K = 123;
public static final String TEST_NESTED_PATH = "nested_path";
private Map<String, Object> params = new HashMap<>();

@Before
Expand Down Expand Up @@ -61,6 +62,22 @@ public void testGetQueryBody() {
assertEquals(123.0, queryBody.get("query").get("neural").get("test embedding").get("k"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithNestedPath() {
params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH);
VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params);
Map<String, Map<String, Map<String, Object>>> nestedQueryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class);
assertEquals("nested_path", nestedQueryBody.get("query").get("nested").get("path"));
assertEquals("max", nestedQueryBody.get("query").get("nested").get("score_mode"));
Map<String, Map<String, Map<String, String>>> queryBody = (Map<String, Map<String, Map<String, String>>>) nestedQueryBody
.get("query")
.get("nested")
.get("query");
assertEquals("123fsd23134sdfouh", queryBody.get("neural").get("test embedding").get("query_text"));
assertEquals("123fsd23134", queryBody.get("neural").get("test embedding").get("model_id"));
}

@Test
@SneakyThrows
public void testGetQueryBodyWithJsonObjectString() {
Expand Down Expand Up @@ -103,6 +120,11 @@ public void testCreateToolsParseParams() {

assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.MODEL_ID_FIELD, 123)));

assertThrows(
ClassCastException.class,
() -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.NESTED_PATH_FIELD, 123))
);

assertThrows(JsonSyntaxException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.SOURCE_FIELD, "123")));

// although it will be parsed as integer, but the parameters value should always be String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -22,6 +21,7 @@

public class NeuralSparseSearchToolIT extends BaseAgentToolsIT {
public static String TEST_INDEX_NAME = "test_index";
public static String TEST_NESTED_INDEX_NAME = "test_index_nested";

private String modelId;
private String registerAgentRequestBody;
Expand Down Expand Up @@ -64,12 +64,55 @@ private void prepareIndex() {
addDocToIndex(TEST_INDEX_NAME, "2", List.of("text", "embedding"), List.of("text doc 3", Map.of("test", 5, "a", 6)));
}

@SneakyThrows
private void prepareNestedIndex() {
createIndexWithConfiguration(
TEST_NESTED_INDEX_NAME,
"{\n"
+ " \"mappings\": {\n"
+ " \"properties\": {\n"
+ " \"text\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"embedding\": {\n"
+ " \"type\": \"nested\",\n"
+ " \"properties\":{\n"
+ " \"sparse\":{\n"
+ " \"type\":\"rank_features\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}"
);
addDocToIndex(
TEST_NESTED_INDEX_NAME,
"0",
List.of("text", "embedding"),
List.of("text doc 1", Map.of("sparse", List.of(Map.of("hello", 1, "world", 2))))
);
addDocToIndex(
TEST_NESTED_INDEX_NAME,
"1",
List.of("text", "embedding"),
List.of("text doc 2", Map.of("sparse", List.of(Map.of("a", 3, "b", 4))))
);
addDocToIndex(
TEST_NESTED_INDEX_NAME,
"2",
List.of("text", "embedding"),
List.of("text doc 3", Map.of("sparse", List.of(Map.of("test", 5, "a", 6))))
);
}

@Before
@SneakyThrows
public void setUp() {
super.setUp();
prepareModel();
prepareIndex();
prepareNestedIndex();
registerAgentRequestBody = Files
.readString(
Path
Expand Down Expand Up @@ -127,6 +170,23 @@ public void testNeuralSparseSearchToolInFlowAgent() {
);
}

public void testNeuralSparseSearchToolInFlowAgent_withNestedIndex() {
String registerAgentRequestBodyNested = registerAgentRequestBody;
registerAgentRequestBodyNested = registerAgentRequestBodyNested.replace("\"nested_path\": \"\"", "\"nested_path\": \"embedding\"");
registerAgentRequestBodyNested = registerAgentRequestBodyNested
.replace("\"embedding_field\": \"embedding\"", "\"embedding_field\": \"embedding.sparse\"");
registerAgentRequestBodyNested = registerAgentRequestBodyNested
.replace("\"index\": \"test_index\"", "\"index\": \"test_index_nested\"");
String agentId = createAgent(registerAgentRequestBodyNested);
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
assertEquals(
"The agent execute response not equal with expected.",
"{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n"
+ "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n",
result
);
}

public void testNeuralSparseSearchToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() {
String agentId = createAgent(registerAgentRequestBody.replace("text", "text2"));
String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}");
Expand Down
Loading
Loading