From fa8bb7dd0d31518335dd204f6c4df014bbb79ea0 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Thu, 1 Aug 2024 12:01:24 -0700 Subject: [PATCH] Fix code review comments Signed-off-by: Vijayan Balasubramanian --- .../knn/integ/PainlessScriptFieldsIT.java | 131 +++++++++++++++ .../knn/integ/PainlessScriptHelper.java | 97 +++++++++++ ...riptIT.java => PainlessScriptScoreIT.java} | 152 ++---------------- .../org/opensearch/knn/KNNRestTestCase.java | 16 +- 4 files changed, 246 insertions(+), 150 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java create mode 100644 src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java rename src/test/java/org/opensearch/knn/integ/{PainlessScriptIT.java => PainlessScriptScoreIT.java} (85%) diff --git a/src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java b/src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java new file mode 100644 index 000000000..c68ca2a6f --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.integ.PainlessScriptHelper.MappingProperty; +import org.opensearch.script.Script; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.integ.PainlessScriptHelper.createMapping; + +// PainlessScriptScoreIT already tests every similarity method with 
different field types. Hence, +// we don't have to recreate all tests for script_fields. From an implementation point of view, +// it is clear that if a similarity method is supported by script_score, then the same applies to script_fields +// provided the script_fields context is supported. Hence, we test one similarity method to verify that the script_fields +// context is supported by this plugin. +public class PainlessScriptFieldsIT extends KNNRestTestCase { + + private static final String NUMERIC_INDEX_FIELD_NAME = "price"; + + private void buildTestIndex(Map knnDocuments) throws Exception { + List properties = buildMappingProperties(); + buildTestIndex(knnDocuments, properties); + } + + private void buildTestIndex(Map knnDocuments, List properties) throws Exception { + createKnnIndex(INDEX_NAME, createMapping(properties)); + for (Map.Entry data : knnDocuments.entrySet()) { + addKnnDoc(INDEX_NAME, data.getKey(), FIELD_NAME, data.getValue()); + } + } + + private Map getKnnVectorTestData() { + Map data = new HashMap<>(); + data.put("1", new Float[] { 100.0f, 1.0f }); + data.put("2", new Float[] { 99.0f, 2.0f }); + data.put("3", new Float[] { 97.0f, 3.0f }); + data.put("4", new Float[] { 98.0f, 4.0f }); + return data; + } + + private Map getCosineTestData() { + Map data = new HashMap<>(); + data.put("0", new Float[] { 1.0f, -1.0f }); + data.put("2", new Float[] { 1.0f, 1.0f }); + data.put("1", new Float[] { 1.0f, 0.0f }); + return data; + } + + /* + The doc['field'] will throw an error if the field is missing from the mappings. 
+ */ + private List buildMappingProperties() { + List properties = new ArrayList<>(); + properties.add(new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2")); + properties.add(new MappingProperty(NUMERIC_INDEX_FIELD_NAME, "integer")); + return properties; + } + + @SneakyThrows + public void testCosineSimilarity_whenUsedInScriptFields_thenExecutesScript() { + String source = String.format(Locale.ROOT, "1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME); + String scriptFieldName = "similarity"; + Request request = buildPainlessScriptFieldsRequest(source, 3, getCosineTestData(), scriptFieldName); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName); + assertEquals(3, results.size()); + + String[] expectedDocIDs = { "0", "1", "2" }; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testGetValue_whenUsedInScriptFields_thenReturnsDocValues() { + String source = String.format(Locale.ROOT, "doc['%s'].value[0]", FIELD_NAME); + String scriptFieldName = "doc_value_field"; + Map testData = getKnnVectorTestData(); + Request request = buildPainlessScriptFieldsRequest(source, testData.size(), testData, scriptFieldName); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName); + assertEquals(testData.size(), results.size()); + + String[] expectedDocIDs = { "1", "2", "3", "4" }; + for (int i = 0; i < results.size(); i++) { + 
assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + deleteKNNIndex(INDEX_NAME); + } + + private Request buildPainlessScriptFieldsRequest( + final String source, + final int size, + final Map documents, + final String scriptFieldName + ) throws Exception { + buildTestIndex(documents); + return constructScriptFieldsContextSearchRequest( + INDEX_NAME, + scriptFieldName, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + size, + Collections.emptyMap() + ); + } +} diff --git a/src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java b/src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java new file mode 100644 index 000000000..800f1a1ec --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodContext; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class PainlessScriptHelper { + + /** + * Utility to create an index mapping with multiple fields + */ + public static String createMapping(List properties) throws IOException { + Objects.requireNonNull(properties); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); + for (MappingProperty property : properties) { + XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType()); + if (property.getDimension() != null) { + builder.field("dimension", property.getDimension()); + } + + if (property.getDocValues() != null) { + builder.field("doc_values", property.getDocValues()); + } + + if (property.getKnnMethodContext() != 
null) { + builder.startObject(KNNConstants.KNN_METHOD); + property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + } + + builder.endObject(); + } + xContentBuilder.endObject().endObject(); + return xContentBuilder.toString(); + } + + static class MappingProperty { + + private final String name; + private final String type; + private String dimension; + + private KNNMethodContext knnMethodContext; + private Boolean docValues; + + MappingProperty(String name, String type) { + this.name = name; + this.type = type; + } + + MappingProperty dimension(String dimension) { + this.dimension = dimension; + return this; + } + + MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) { + this.knnMethodContext = knnMethodContext; + return this; + } + + MappingProperty docValues(boolean docValues) { + this.docValues = docValues; + return this; + } + + KNNMethodContext getKnnMethodContext() { + return knnMethodContext; + } + + String getDimension() { + return dimension; + } + + String getName() { + return name; + } + + String getType() { + return type; + } + + Boolean getDocValues() { + return docValues; + } + } +} diff --git a/src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/integ/PainlessScriptScoreIT.java similarity index 85% rename from src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java rename to src/test/java/org/opensearch/knn/integ/PainlessScriptScoreIT.java index 144339244..03e8f68a2 100644 --- a/src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/integ/PainlessScriptScoreIT.java @@ -6,29 +6,25 @@ package org.opensearch.knn.integ; import lombok.SneakyThrows; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.opensearch.client.Request; -import org.opensearch.client.Response; -import org.opensearch.client.ResponseException; import org.opensearch.common.settings.Settings; -import 
org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.integ.PainlessScriptHelper.MappingProperty; import org.opensearch.script.Script; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -36,44 +32,16 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.integ.PainlessScriptHelper.createMapping; -public class PainlessScriptIT extends KNNRestTestCase { +public class PainlessScriptScoreIT extends KNNRestTestCase { public static final int AGGREGATION_FIELD_NAME_MIN_LENGTH = 2; public static final int AGGREGATION_FIELD_NAME_MAX_LENGTH = 5; private static final String NUMERIC_INDEX_FIELD_NAME = "price"; - /** - * Utility to create a Index Mapping with multiple fields - */ - protected String createMapping(List properties) 
throws IOException { - Objects.requireNonNull(properties); - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); - for (MappingProperty property : properties) { - XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType()); - if (property.getDimension() != null) { - builder.field("dimension", property.getDimension()); - } - - if (property.getDocValues() != null) { - builder.field("doc_values", property.getDocValues()); - } - - if (property.getKnnMethodContext() != null) { - builder.startObject(KNNConstants.KNN_METHOD); - property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.endObject(); - } - - builder.endObject(); - } - xContentBuilder.endObject().endObject(); - return xContentBuilder.toString(); - } - /* creates KnnIndex based on properties, we add single non-knn vector documents to verify whether actions works on non-knn vector documents as well @@ -161,42 +129,6 @@ public void testL2ScriptScoreFails() throws Exception { deleteKNNIndex(INDEX_NAME); } - public void testCosineSimilarityScriptFields() throws Exception { - String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME); - String scriptFieldName = "similarity"; - Request request = buildPainlessScriptFieldsRequest(source, 3, getCosineTestData(), scriptFieldName); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName); - assertEquals(3, results.size()); - - String[] expectedDocIDs = { "0", "1", "2" }; - for (int i = 0; i < results.size(); i++) { - assertEquals(expectedDocIDs[i], results.get(i).getDocId()); - } - deleteKNNIndex(INDEX_NAME); - } - - public void testScriptFieldsGetValueReturnsDocValues() 
throws Exception { - String source = String.format("doc['%s'].value[0]", FIELD_NAME); - String scriptFieldName = "doc_value_field"; - Map testData = getKnnVectorTestData(); - Request request = buildPainlessScriptFieldsRequest(source, testData.size(), testData, scriptFieldName); - - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName); - assertEquals(testData.size(), results.size()); - - String[] expectedDocIDs = { "1", "2", "3", "4" }; - for (int i = 0; i < results.size(); i++) { - assertEquals(expectedDocIDs[i], results.get(i).getDocId()); - } - deleteKNNIndex(INDEX_NAME); - } - private Request buildPainlessScoreScriptRequest(String source, int size, Map documents) throws Exception { buildTestIndex(documents); QueryBuilder qb = new MatchAllQueryBuilder(); @@ -211,20 +143,6 @@ private Request buildPainlessScoreScriptRequest(String source, int size, Map documents, String scriptFieldName) - throws Exception { - buildTestIndex(documents); - return constructScriptFieldsContextSearchRequest( - INDEX_NAME, - scriptFieldName, - Collections.emptyMap(), - Script.DEFAULT_SCRIPT_LANG, - source, - size, - Collections.emptyMap() - ); - } - private Request buildPainlessScoreScriptRequest( String source, int size, @@ -721,54 +639,4 @@ private Response buildIndexAndRunPainlessScript( deleteKNNIndex(INDEX_NAME); } } - - static class MappingProperty { - - private final String name; - private final String type; - private String dimension; - - private KNNMethodContext knnMethodContext; - private Boolean docValues; - - MappingProperty(String name, String type) { - this.name = name; - this.type = type; - } - - MappingProperty dimension(String dimension) { - this.dimension = dimension; - return this; - } - - MappingProperty 
knnMethodContext(KNNMethodContext knnMethodContext) { - this.knnMethodContext = knnMethodContext; - return this; - } - - MappingProperty docValues(boolean docValues) { - this.docValues = docValues; - return this; - } - - KNNMethodContext getKnnMethodContext() { - return knnMethodContext; - } - - String getDimension() { - return dimension; - } - - String getName() { - return name; - } - - String getType() { - return type; - } - - Boolean getDocValues() { - return docValues; - } - } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 184050056..3f4a37d91 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -303,7 +303,7 @@ protected List parseSearchResponseScore(String responseBody, String field return knnSearchResponses; } - protected List parseSearchResponseScriptFields(String responseBody, String scriptFieldName) throws IOException { + protected List parseSearchResponseScriptFields(final String responseBody, final String scriptFieldName) throws IOException { @SuppressWarnings("unchecked") List hits = (List) ((Map) createParser( MediaTypeRegistry.getDefaultMediaType().xContent(), @@ -1028,13 +1028,13 @@ protected Request constructScriptedMetricAggregationSearchRequest( } protected Request constructScriptFieldsContextSearchRequest( - String indexName, - String fieldName, - Map scriptParams, - String language, - String source, - int size, - Map searchParams + final String indexName, + final String fieldName, + final Map scriptParams, + final String language, + final String source, + final int size, + final Map searchParams ) throws Exception { Script script = buildScript(source, language, scriptParams); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("size", size).startObject("query");