Skip to content

Commit

Permalink
add validations from appsec
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Jan 26, 2024
1 parent 55e840f commit ed8e581
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.plugin;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -145,7 +146,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR

/**
 * Returns the list of settings exposed by this plugin.
 * <p>
 * This span is rendered from a commit diff: the first {@code return} is the line the commit removes and the
 * second is its replacement, which adds {@code RERANKER_MAX_DOC_FIELDS} to the returned list.
 */
@Override
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);
}

@Override
Expand All @@ -159,7 +160,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchReques
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
// Maps the rerank processor type to the factory that builds it.
// Diff rendering: the first return is the line the commit removes; the second is its replacement,
// which additionally passes parameters.env so the factory holds an Environment next to the client accessor.
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor));
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor, parameters.env));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.env.Environment;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
Expand All @@ -37,6 +38,7 @@ public class RerankProcessorFactory implements Processor.Factory<SearchResponseP
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;
private final Environment environment;

@Override
public SearchResponseProcessor create(
Expand All @@ -49,7 +51,12 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
environment
);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
Expand Down Expand Up @@ -109,22 +116,23 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
// Diff rendering: the bare `String tag` line is removed by the commit and replaced by the two lines after it,
// which append an Environment parameter.
String tag
String tag,
final Environment environment
) {
// Read the "context" sub-map of the processor configuration.
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
// Each key of the context map selects a fetcher type; its value is that fetcher's own configuration.
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
// Diff rendering: the one-argument create(...) call is the removed line; the call that also
// passes the environment is its replacement.
fetchers.add(DocumentContextSourceFetcher.create(cfg));
fetchers.add(DocumentContextSourceFetcher.create(cfg, environment));
break;
default:
// Any key other than DocumentContextSourceFetcher.NAME is rejected.
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
// The query context fetcher is appended last, and only when the caller asks for it.
if (includeQueryContextFetcher) {
// Diff rendering: the no-argument constructor call is the removed line; the next line replaces it.
fetchers.add(new QueryContextSourceFetcher());
fetchers.add(new QueryContextSourceFetcher(environment));
}
return fetchers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.env.Environment;
import org.opensearch.search.SearchHit;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

Expand All @@ -29,8 +32,10 @@ public class DocumentContextSourceFetcher implements ContextSourceFetcher {

public static final String NAME = "document_fields";
public static final String DOCUMENT_CONTEXT_LIST_FIELD = "document_context_list";
public static final int MAX_DOCUMENT_FIELDS = 50;

private final List<String> contextFields;
private final Environment environment;

/**
* Fetch the information needed in order to rerank.
Expand Down Expand Up @@ -87,15 +92,26 @@ public String getName() {
* @param config configuration object grabbed from parsed API request. Should be a list of strings
* @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed
*/
// Diff rendering: the one-argument signature is the line the commit removes; the two-argument
// signature after it is the replacement.
public static DocumentContextSourceFetcher create(Object config) {
public static DocumentContextSourceFetcher create(Object config, Environment environment) {
// The configuration value must be a list.
if (!(config instanceof List)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME));
}
List<?> fields = (List<?>) config;
// An empty list is rejected.
if (fields.size() == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME));
}
// Upper bound on the number of fields, read from the environment's settings via RERANKER_MAX_DOC_FIELDS.
// The setting is read a second time below to build the error message.
if (fields.size() > RERANKER_MAX_DOC_FIELDS.get(environment.settings())) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s must not contain more than %d fields. Configure by setting %s",
NAME,
RERANKER_MAX_DOC_FIELDS.get(environment.settings()),
RERANKER_MAX_DOC_FIELDS.getKey()
)
);
}
// NOTE(review): each element is cast to String without a type check, so a non-string entry raises
// ClassCastException here rather than the IllegalArgumentException the method's Javadoc describes.
List<String> fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList());
// Diff rendering: the one-argument constructor call is the removed line; the next line replaces it.
return new DocumentContextSourceFetcher(fieldsAsStrings);
return new DocumentContextSourceFetcher(fieldsAsStrings, environment);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,25 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.search.SearchExtBuilder;

import lombok.AllArgsConstructor;

/**
* Context Source Fetcher that gets context from the rerank query ext.
*/
@AllArgsConstructor
public class QueryContextSourceFetcher implements ContextSourceFetcher {

public static final String NAME = "query_context";
public static final String QUERY_TEXT_FIELD = "query_text";
public static final String QUERY_TEXT_PATH_FIELD = "query_text_path";

private final Environment environment;

@Override
public void fetchContext(
final SearchRequest searchRequest,
Expand Down Expand Up @@ -65,6 +72,16 @@ public void fetchContext(
} else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
// Case "query_text_path": ser/de the query into a map and then find the text at the path specified
String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD);
if (!validatePath(path)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d",
QUERY_TEXT_PATH_FIELD,
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())
)
);
}
Map<String, Object> map = requestToMap(searchRequest);
// Get the text at the path
Object queryText = ObjectPath.eval(path, map);
Expand Down Expand Up @@ -107,4 +124,11 @@ private static Map<String, Object> requestToMap(final SearchRequest request) thr
Map<String, Object> map = parser.map();
return map;
}

/**
 * Checks a dot-separated path against the index mapping depth limit read from the environment's settings.
 *
 * @param path the path to check; a {@code null} or empty path is accepted without consulting the limit
 * @return {@code true} when the path is absent or has no more dot-separated segments than the limit,
 *         {@code false} otherwise
 */
private boolean validatePath(final String path) {
    final boolean hasSegments = path != null && !path.isEmpty();
    if (hasSegments) {
        // Count segments with the same regex split on a literal dot; the limit is looked up only on this branch.
        final int segmentCount = path.split("\\.").length;
        return segmentCount <= MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings());
    }
    return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,13 @@ public final class NeuralSearchSettings {
false,
Setting.Property.NodeScope
);

/**
 * Limits the number of document fields that can be passed to the reranker.
 * <p>
 * Integer setting with key {@code plugins.neural_search.reranker_max_document_fields}, a default of 50 and
 * node scope. {@code DocumentContextSourceFetcher.create} compares the size of the configured field list
 * against this value and throws {@link IllegalArgumentException} when the list is longer.
 * <p>
 * NOTE(review): no minimum value is declared here, so zero or negative values are not rejected by this
 * definition; confirm whether a lower bound of 1 should be enforced.
 */
public static final Setting<Integer> RERANKER_MAX_DOC_FIELDS = Setting.intSetting(
"plugins.neural_search.reranker_max_document_fields",
50,
Setting.Property.NodeScope
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand All @@ -15,6 +18,8 @@
import org.junit.Before;
import org.mockito.Mock;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand All @@ -37,11 +42,16 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase {
@Mock
private PipelineContext pipelineContext;

@Mock
private Environment environment;

/**
 * Creates fresh mocks for the environment, pipeline context and client accessor, builds the factory under
 * test from them, and stubs the mocked environment to return {@code Settings.EMPTY} (presumably so that
 * settings read through it resolve to their defaults; confirm against the Setting API).
 * <p>
 * This span is rendered from a commit diff: the one-argument {@code RerankProcessorFactory} construction is
 * the line the commit removes and the two-argument construction after it is the replacement.
 */
@Before
public void setup() {
environment = mock(Environment.class);
pipelineContext = mock(PipelineContext.class);
clientAccessor = mock(MLCommonsClientAccessor.class);
factory = new RerankProcessorFactory(clientAccessor);
factory = new RerankProcessorFactory(clientAccessor, environment);
doReturn(Settings.EMPTY).when(environment).settings();
}

public void testRerankProcessorFactory_whenEmptyConfig_thenFail() {
Expand Down Expand Up @@ -187,4 +197,26 @@ public void testCrossEncoder_whenEmptyContextDocField_thenFail() {
);
}

/**
 * Verifies that the factory rejects a {@code document_fields} list with more entries than
 * {@code RERANKER_MAX_DOC_FIELDS} allows, and that the resulting {@link IllegalArgumentException}
 * carries the expected message.
 */
public void testCrossEncoder_whenTooManyDocFields_ThenFail() {
    // 75 field names: more than the setting's default of 50, which applies because setup() stubs empty settings.
    Map<String, Object> config = new HashMap<>(
        Map.of(
            RerankType.ML_OPENSEARCH.getLabel(),
            new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
            RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
            new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, Collections.nCopies(75, "field")))
        )
    );
    String expectedMessage = String.format(
        Locale.ROOT,
        "%s must not contain more than %d fields. Configure by setting %s",
        DocumentContextSourceFetcher.NAME,
        RERANKER_MAX_DOC_FIELDS.get(environment.settings()),
        RERANKER_MAX_DOC_FIELDS.getKey()
    );
    // The String overload of assertThrows uses its first argument only as the assertion-failure description,
    // so the exception is captured here and its message compared explicitly.
    IllegalArgumentException thrown = assertThrows(
        IllegalArgumentException.class,
        () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
    );
    assertEquals(expectedMessage, thrown.getMessage());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
Expand Down Expand Up @@ -65,14 +68,18 @@ public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {
@Mock
private PipelineProcessingContext ppctx;

@Mock
private Environment environment;

private RerankProcessorFactory factory;

private MLOpenSearchRerankProcessor processor;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
factory = new RerankProcessorFactory(mlCommonsClientAccessor);
doReturn(Settings.EMPTY).when(environment).settings();
factory = new RerankProcessorFactory(mlCommonsClientAccessor, environment);
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
Expand Down Expand Up @@ -223,6 +230,27 @@ public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IO
.equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field"));
}

/**
 * Verifies that generating the reranking context with an overly deep {@code query_text_path} reports an
 * {@link IllegalArgumentException} to the listener with the "exceeded the maximum path length" message.
 */
public void testRerankContext_whenQueryTextPathIsExceeedinglyLong_thenFail() throws IOException {
    // 24 dot-separated segments; expected to be over the mapping depth limit resolved from the stubbed settings.
    final String overlongPath = "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.w.x.y.z";
    setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, overlongPath));
    setupSearchResults();

    @SuppressWarnings("unchecked")
    ActionListener<Map<String, Object>> contextListener = mock(ActionListener.class);
    processor.generateRerankingContext(request, response, contextListener);

    // Exactly one failure must have been delivered to the listener.
    ArgumentCaptor<Exception> failureCaptor = ArgumentCaptor.forClass(Exception.class);
    verify(contextListener, times(1)).onFailure(failureCaptor.capture());

    assert (failureCaptor.getValue() instanceof IllegalArgumentException);

    final String expectedMessage = String.format(
        Locale.ROOT,
        "%s exceeded the maximum path length of %d",
        QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
        MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())
    );
    final String actualMessage = failureCaptor.getValue().getMessage();
    assert (actualMessage.equals(expectedMessage));
}

public void testRescoreSearchResponse_HappyPath() throws IOException {
setupSimilarityRescoring();
setupSearchResults();
Expand Down

0 comments on commit ed8e581

Please sign in to comment.