diff --git a/CHANGELOG.md b/CHANGELOG.md index ccac8d7c36..aae0bf9db4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) * Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) +* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNAllowlistExtension.java b/src/main/java/org/opensearch/knn/plugin/script/KNNAllowlistExtension.java index 959063d618..47b99af17e 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNAllowlistExtension.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNAllowlistExtension.java @@ -9,6 +9,7 @@ import org.opensearch.painless.spi.PainlessExtension; import org.opensearch.painless.spi.Whitelist; import org.opensearch.painless.spi.WhitelistLoader; +import org.opensearch.script.FieldScript; import org.opensearch.script.ScoreScript; import org.opensearch.script.ScriptContext; import org.opensearch.script.ScriptedMetricAggContexts; @@ -33,6 +34,8 @@ public Map, List> getContextWhitelists() { ScriptedMetricAggContexts.CombineScript.CONTEXT, allowLists, ScriptedMetricAggContexts.ReduceScript.CONTEXT, + allowLists, + FieldScript.CONTEXT, allowLists ); } diff --git a/src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java b/src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java new file mode 100644 index 0000000000..2ee662c225 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.integ.PainlessScriptHelper.MappingProperty; +import org.opensearch.script.Script; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.opensearch.knn.integ.PainlessScriptHelper.createMapping; + +// PainlesScriptScoreIT already tests every similarity methods with different field type. Hence, +// we don't have to recreate all tests for script_fields. From implementation point of view, +// it is clear if similarity method is supported by script_score, then same is applicable for script_fields +// provided script_fields context is supported. Hence, we test for one similarity method to verify that script_fields +// context is supported by this plugin. +public final class PainlessScriptFieldsIT extends KNNRestTestCase { + + private static final String NUMERIC_INDEX_FIELD_NAME = "price"; + + private void buildTestIndex(final Map knnDocuments) throws Exception { + List properties = buildMappingProperties(); + buildTestIndex(knnDocuments, properties); + } + + private void buildTestIndex(final Map knnDocuments, final List properties) throws Exception { + createKnnIndex(INDEX_NAME, createMapping(properties)); + for (Map.Entry data : knnDocuments.entrySet()) { + addKnnDoc(INDEX_NAME, data.getKey(), FIELD_NAME, data.getValue()); + } + } + + private Map getKnnVectorTestData() { + Map data = new HashMap<>(); + data.put("1", new Float[] { 100.0f, 1.0f }); + data.put("2", new Float[] { 99.0f, 2.0f }); + data.put("3", new Float[] { 97.0f, 3.0f }); + data.put("4", new Float[] { 98.0f, 4.0f }); + return data; + } + + private Map getCosineTestData() { + Map data = new HashMap<>(); + data.put("0", new Float[] { 1.0f, -1.0f }); + data.put("2", new Float[] { 1.0f, 1.0f }); + data.put("1", new Float[] { 1.0f, 0.0f }); + return data; + } + + /* + The doc['field'] will throw an error if field is missing from the mappings. + */ + private List buildMappingProperties() { + List properties = new ArrayList<>(); + properties.add(MappingProperty.builder().name(FIELD_NAME).type(KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").build()); + properties.add(MappingProperty.builder().name(NUMERIC_INDEX_FIELD_NAME).type("integer").build()); + return properties; + } + + @SneakyThrows + public void testCosineSimilarity_whenUsedInScriptFields_thenExecutesScript() { + String source = String.format(Locale.ROOT, "1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME); + String scriptFieldName = "similarity"; + Request request = buildPainlessScriptFieldsRequest(source, 3, getCosineTestData(), scriptFieldName); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName); + assertEquals(3, results.size()); + + String[] expectedDocIDs = { "0", "1", "2" }; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testGetValue_whenUsedInScriptFields_thenReturnsDocValues() { + String source = String.format(Locale.ROOT, "doc['%s'].value[0]", FIELD_NAME); + String scriptFieldName = "doc_value_field"; + Map testData = getKnnVectorTestData(); + Request request = buildPainlessScriptFieldsRequest(source, testData.size(), testData, scriptFieldName); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName); + assertEquals(testData.size(), results.size()); + + String[] expectedDocIDs = { "1", "2", "3", "4" }; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + deleteKNNIndex(INDEX_NAME); + } + + private Request buildPainlessScriptFieldsRequest( + final String source, + final int size, + final Map documents, + final String scriptFieldName + ) throws Exception { + buildTestIndex(documents); + return constructScriptFieldsContextSearchRequest( + INDEX_NAME, + scriptFieldName, + Collections.emptyMap(), + Script.DEFAULT_SCRIPT_LANG, + source, + size, + Collections.emptyMap() + ); + } +} diff --git a/src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java b/src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java new file mode 100644 index 0000000000..e53fc41e45 --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.engine.KNNMethodContext; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public final class PainlessScriptHelper { + /** + * Utility to create a Index Mapping with multiple fields + */ + public static String createMapping(final List properties) throws IOException { + Objects.requireNonNull(properties); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); + for (MappingProperty property : properties) { + XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType()); + if (property.getDimension() != null) { + builder.field("dimension", property.getDimension()); + } + + if (property.getDocValues() != null) { + builder.field("doc_values", property.getDocValues()); + } + + if (property.getKnnMethodContext() != null) { + builder.startObject(KNNConstants.KNN_METHOD); + property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + } + + builder.endObject(); + } + xContentBuilder.endObject().endObject(); + return xContentBuilder.toString(); + } + + @Getter + @Builder + final static class MappingProperty { + private final String name; + private final String type; + private String dimension; + private KNNMethodContext knnMethodContext; + private Boolean docValues; + } +} diff --git a/src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/integ/PainlessScriptScoreIT.java similarity index 91% rename from src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java rename to src/test/java/org/opensearch/knn/integ/PainlessScriptScoreIT.java index 2fed9fc255..c8835f70f1 100644 --- a/src/test/java/org/opensearch/knn/integ/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/integ/PainlessScriptScoreIT.java @@ -7,10 +7,8 @@ import lombok.SneakyThrows; import org.opensearch.common.settings.Settings; -import org.opensearch.core.xcontent.ToXContent; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -20,15 +18,13 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.integ.PainlessScriptHelper.MappingProperty; import org.opensearch.script.Script; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -36,44 +32,16 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.integ.PainlessScriptHelper.createMapping; -public class PainlessScriptIT extends KNNRestTestCase { +public final class PainlessScriptScoreIT extends KNNRestTestCase { public static final int AGGREGATION_FIELD_NAME_MIN_LENGTH = 2; public static final int AGGREGATION_FIELD_NAME_MAX_LENGTH = 5; private static final String NUMERIC_INDEX_FIELD_NAME = "price"; - /** - * Utility to create a Index Mapping with multiple fields - */ - protected String createMapping(List properties) throws IOException { - Objects.requireNonNull(properties); - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); - for (MappingProperty property : properties) { - XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType()); - if (property.getDimension() != null) { - builder.field("dimension", property.getDimension()); - } - - if (property.getDocValues() != null) { - builder.field("doc_values", property.getDocValues()); - } - - if (property.getKnnMethodContext() != null) { - builder.startObject(KNNConstants.KNN_METHOD); - property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.endObject(); - } - - builder.endObject(); - } - xContentBuilder.endObject().endObject(); - return xContentBuilder.toString(); - } - /* creates KnnIndex based on properties, we add single non-knn vector documents to verify whether actions works on non-knn vector documents as well @@ -148,8 +116,8 @@ private Map getCosineTestData() { */ private List buildMappingProperties() { List properties = new ArrayList<>(); - properties.add(new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2")); - properties.add(new MappingProperty(NUMERIC_INDEX_FIELD_NAME, "integer")); + properties.add(MappingProperty.builder().name(FIELD_NAME).type(KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").build()); + properties.add(MappingProperty.builder().name(NUMERIC_INDEX_FIELD_NAME).type("integer").build()); return properties; } @@ -568,9 +536,13 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception { new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") + MappingProperty.builder() + .name(FIELD_NAME) + .type(KNNVectorFieldMapper.CONTENT_TYPE) + .dimension("2") .knnMethodContext(knnMethodContext) .docValues(randomBoolean()) + .build() ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); @@ -671,54 +643,4 @@ private Response buildIndexAndRunPainlessScript( deleteKNNIndex(INDEX_NAME); } } - - static class MappingProperty { - - private final String name; - private final String type; - private String dimension; - - private KNNMethodContext knnMethodContext; - private Boolean docValues; - - MappingProperty(String name, String type) { - this.name = name; - this.type = type; - } - - MappingProperty dimension(String dimension) { - this.dimension = dimension; - return this; - } - - MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) { - this.knnMethodContext = knnMethodContext; - return this; - } - - MappingProperty docValues(boolean docValues) { - this.docValues = docValues; - return this; - } - - KNNMethodContext getKnnMethodContext() { - return knnMethodContext; - } - - String getDimension() { - return dimension; - } - - String getName() { - return name; - } - - String getType() { - return type; - } - - Boolean getDocValues() { - return docValues; - } - } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index a909358695..0c1dbb2ce2 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -303,6 +303,31 @@ protected List parseSearchResponseScore(String responseBody, String field return knnSearchResponses; } + protected List parseSearchResponseScriptFields(final String responseBody, final String scriptFieldName) throws IOException { + @SuppressWarnings("unchecked") + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("hits")).get("hits"); + + @SuppressWarnings("unchecked") + List knnSearchResponses = hits.stream().map(hit -> { + @SuppressWarnings("unchecked") + final float[] vector = Floats.toArray( + Arrays.stream( + ((ArrayList) ((Map) ((Map) hit).get("fields")).get(scriptFieldName)).toArray() + ).map(Object::toString).map(Float::valueOf).collect(Collectors.toList()) + ); + return new KNNResult( + (String) ((Map) hit).get("_id"), + vector, + ((Double) ((Map) hit).get("_score")).floatValue() + ); + }).collect(Collectors.toList()); + + return knnSearchResponses; + } + /** * Parse the response of Aggregation to extract the value */ @@ -1002,6 +1027,37 @@ protected Request constructScriptedMetricAggregationSearchRequest( return request; } + protected Request constructScriptFieldsContextSearchRequest( + final String indexName, + final String fieldName, + final Map scriptParams, + final String language, + final String source, + final int size, + final Map searchParams + ) throws Exception { + Script script = buildScript(source, language, scriptParams); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("size", size).startObject("query"); + builder.startObject("match_all"); + builder.endObject(); + builder.endObject(); + builder.startObject("script_fields"); + builder.startObject(fieldName); + builder.field("script", script); + builder.endObject(); + builder.endObject(); + builder.endObject(); + URIBuilder uriBuilder = new URIBuilder("/" + indexName + "/_search"); + if (Objects.nonNull(searchParams)) { + for (Map.Entry entry : searchParams.entrySet()) { + uriBuilder.addParameter(entry.getKey(), entry.getValue().toString()); + } + } + Request request = new Request("POST", uriBuilder.toString()); + request.setJsonEntity(builder.toString()); + return request; + } + protected Request constructScriptScoreContextSearchRequest( String indexName, QueryBuilder qb,