Skip to content

Commit

Permalink
Add script_fields context to KNNAllowlist (opensearch-project#1917)
Browse files Browse the repository at this point in the history
Include script_fields context to existing
supported context for knn methods.
Added test cases for method and doc values.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Aug 19, 2024
1 parent 7ad7d2f commit f1f2f84
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 88 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844)
* Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936)
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.painless.spi.PainlessExtension;
import org.opensearch.painless.spi.Whitelist;
import org.opensearch.painless.spi.WhitelistLoader;
import org.opensearch.script.FieldScript;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptedMetricAggContexts;
Expand All @@ -33,6 +34,8 @@ public Map<ScriptContext<?>, List<Whitelist>> getContextWhitelists() {
ScriptedMetricAggContexts.CombineScript.CONTEXT,
allowLists,
ScriptedMetricAggContexts.ReduceScript.CONTEXT,
allowLists,
FieldScript.CONTEXT,
allowLists
);
}
Expand Down
131 changes: 131 additions & 0 deletions src/test/java/org/opensearch/knn/integ/PainlessScriptFieldsIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.integ;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.integ.PainlessScriptHelper.MappingProperty;
import org.opensearch.script.Script;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.integ.PainlessScriptHelper.createMapping;

// PainlesScriptScoreIT already tests every similarity methods with different field type. Hence,
// we don't have to recreate all tests for script_fields. From implementation point of view,
// it is clear if similarity method is supported by script_score, then same is applicable for script_fields
// provided script_fields context is supported. Hence, we test for one similarity method to verify that script_fields
// context is supported by this plugin.
public final class PainlessScriptFieldsIT extends KNNRestTestCase {

private static final String NUMERIC_INDEX_FIELD_NAME = "price";

private void buildTestIndex(final Map<String, Float[]> knnDocuments) throws Exception {
List<MappingProperty> properties = buildMappingProperties();
buildTestIndex(knnDocuments, properties);
}

private void buildTestIndex(final Map<String, Float[]> knnDocuments, final List<MappingProperty> properties) throws Exception {
createKnnIndex(INDEX_NAME, createMapping(properties));
for (Map.Entry<String, Float[]> data : knnDocuments.entrySet()) {
addKnnDoc(INDEX_NAME, data.getKey(), FIELD_NAME, data.getValue());
}
}

private Map<String, Float[]> getKnnVectorTestData() {
Map<String, Float[]> data = new HashMap<>();
data.put("1", new Float[] { 100.0f, 1.0f });
data.put("2", new Float[] { 99.0f, 2.0f });
data.put("3", new Float[] { 97.0f, 3.0f });
data.put("4", new Float[] { 98.0f, 4.0f });
return data;
}

private Map<String, Float[]> getCosineTestData() {
Map<String, Float[]> data = new HashMap<>();
data.put("0", new Float[] { 1.0f, -1.0f });
data.put("2", new Float[] { 1.0f, 1.0f });
data.put("1", new Float[] { 1.0f, 0.0f });
return data;
}

/*
The doc['field'] will throw an error if field is missing from the mappings.
*/
private List<MappingProperty> buildMappingProperties() {
List<MappingProperty> properties = new ArrayList<>();
properties.add(MappingProperty.builder().name(FIELD_NAME).type(KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").build());
properties.add(MappingProperty.builder().name(NUMERIC_INDEX_FIELD_NAME).type("integer").build());
return properties;
}

@SneakyThrows
public void testCosineSimilarity_whenUsedInScriptFields_thenExecutesScript() {
String source = String.format(Locale.ROOT, "1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME);
String scriptFieldName = "similarity";
Request request = buildPainlessScriptFieldsRequest(source, 3, getCosineTestData(), scriptFieldName);
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

List<KNNResult> results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName);
assertEquals(3, results.size());

String[] expectedDocIDs = { "0", "1", "2" };
for (int i = 0; i < results.size(); i++) {
assertEquals(expectedDocIDs[i], results.get(i).getDocId());
}
deleteKNNIndex(INDEX_NAME);
}

@SneakyThrows
public void testGetValue_whenUsedInScriptFields_thenReturnsDocValues() {
String source = String.format(Locale.ROOT, "doc['%s'].value[0]", FIELD_NAME);
String scriptFieldName = "doc_value_field";
Map<String, Float[]> testData = getKnnVectorTestData();
Request request = buildPainlessScriptFieldsRequest(source, testData.size(), testData, scriptFieldName);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

List<KNNResult> results = parseSearchResponseScriptFields(EntityUtils.toString(response.getEntity()), scriptFieldName);
assertEquals(testData.size(), results.size());

String[] expectedDocIDs = { "1", "2", "3", "4" };
for (int i = 0; i < results.size(); i++) {
assertEquals(expectedDocIDs[i], results.get(i).getDocId());
}
deleteKNNIndex(INDEX_NAME);
}

private Request buildPainlessScriptFieldsRequest(
final String source,
final int size,
final Map<String, Float[]> documents,
final String scriptFieldName
) throws Exception {
buildTestIndex(documents);
return constructScriptFieldsContextSearchRequest(
INDEX_NAME,
scriptFieldName,
Collections.emptyMap(),
Script.DEFAULT_SCRIPT_LANG,
source,
size,
Collections.emptyMap()
);
}
}
58 changes: 58 additions & 0 deletions src/test/java/org/opensearch/knn/integ/PainlessScriptHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.integ;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.engine.KNNMethodContext;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public final class PainlessScriptHelper {
/**
* Utility to create a Index Mapping with multiple fields
*/
public static String createMapping(final List<MappingProperty> properties) throws IOException {
Objects.requireNonNull(properties);
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties");
for (MappingProperty property : properties) {
XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType());
if (property.getDimension() != null) {
builder.field("dimension", property.getDimension());
}

if (property.getDocValues() != null) {
builder.field("doc_values", property.getDocValues());
}

if (property.getKnnMethodContext() != null) {
builder.startObject(KNNConstants.KNN_METHOD);
property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
}

builder.endObject();
}
xContentBuilder.endObject().endObject();
return xContentBuilder.toString();
}

@Getter
@Builder
final static class MappingProperty {
private final String name;
private final String type;
private String dimension;
private KNNMethodContext knnMethodContext;
private Boolean docValues;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@

import lombok.SneakyThrows;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -20,60 +18,30 @@
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.knn.integ.PainlessScriptHelper.MappingProperty;
import org.opensearch.script.Script;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.integ.PainlessScriptHelper.createMapping;

public class PainlessScriptIT extends KNNRestTestCase {
public final class PainlessScriptScoreIT extends KNNRestTestCase {

public static final int AGGREGATION_FIELD_NAME_MIN_LENGTH = 2;
public static final int AGGREGATION_FIELD_NAME_MAX_LENGTH = 5;
private static final String NUMERIC_INDEX_FIELD_NAME = "price";

/**
* Utility to create a Index Mapping with multiple fields
*/
protected String createMapping(List<MappingProperty> properties) throws IOException {
Objects.requireNonNull(properties);
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties");
for (MappingProperty property : properties) {
XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType());
if (property.getDimension() != null) {
builder.field("dimension", property.getDimension());
}

if (property.getDocValues() != null) {
builder.field("doc_values", property.getDocValues());
}

if (property.getKnnMethodContext() != null) {
builder.startObject(KNNConstants.KNN_METHOD);
property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
}

builder.endObject();
}
xContentBuilder.endObject().endObject();
return xContentBuilder.toString();
}

/*
creates KnnIndex based on properties, we add single non-knn vector documents to verify whether actions
works on non-knn vector documents as well
Expand Down Expand Up @@ -148,8 +116,8 @@ private Map<String, Float[]> getCosineTestData() {
*/
private List<MappingProperty> buildMappingProperties() {
List<MappingProperty> properties = new ArrayList<>();
properties.add(new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2"));
properties.add(new MappingProperty(NUMERIC_INDEX_FIELD_NAME, "integer"));
properties.add(MappingProperty.builder().name(FIELD_NAME).type(KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").build());
properties.add(MappingProperty.builder().name(NUMERIC_INDEX_FIELD_NAME).type("integer").build());
return properties;
}

Expand Down Expand Up @@ -568,9 +536,13 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception {
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
properties.add(
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2")
MappingProperty.builder()
.name(FIELD_NAME)
.type(KNNVectorFieldMapper.CONTENT_TYPE)
.dimension("2")
.knnMethodContext(knnMethodContext)
.docValues(randomBoolean())
.build()
);

String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
Expand Down Expand Up @@ -671,54 +643,4 @@ private Response buildIndexAndRunPainlessScript(
deleteKNNIndex(INDEX_NAME);
}
}

static class MappingProperty {

private final String name;
private final String type;
private String dimension;

private KNNMethodContext knnMethodContext;
private Boolean docValues;

MappingProperty(String name, String type) {
this.name = name;
this.type = type;
}

MappingProperty dimension(String dimension) {
this.dimension = dimension;
return this;
}

MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) {
this.knnMethodContext = knnMethodContext;
return this;
}

MappingProperty docValues(boolean docValues) {
this.docValues = docValues;
return this;
}

KNNMethodContext getKnnMethodContext() {
return knnMethodContext;
}

String getDimension() {
return dimension;
}

String getName() {
return name;
}

String getType() {
return type;
}

Boolean getDocValues() {
return docValues;
}
}
}
Loading

0 comments on commit f1f2f84

Please sign in to comment.