Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validations from appsec #562

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
- Fix Flaky test reported in #433 ([#533](https://github.com/opensearch-project/neural-search/pull/533))
- Enable support for default model id on HybridQueryBuilder ([#541](https://github.com/opensearch-project/neural-search/pull/541))
- Add vaalidations for reranker requests per #555 ([#562](https://github.com/opensearch-project/neural-search/pull/562))
### Infrastructure
- BWC tests for Neural Search ([#515](https://github.com/opensearch-project/neural-search/pull/515))
- Github action to run integ tests in secure opensearch cluster ([#535](https://github.com/opensearch-project/neural-search/pull/535))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.plugin;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -145,7 +146,7 @@

@Override
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);

Check warning on line 149 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L149

Added line #L149 was not covered by tests
}

@Override
Expand All @@ -159,7 +160,10 @@
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor));
return Map.of(

Check warning on line 163 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L163

Added line #L163 was not covered by tests
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService())

Check warning on line 165 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L165

Added line #L165 was not covered by tests
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
Expand All @@ -37,6 +38,7 @@ public class RerankProcessorFactory implements Processor.Factory<SearchResponseP
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;
private final ClusterService clusterService;

@Override
public SearchResponseProcessor create(
Expand All @@ -49,7 +51,12 @@ public SearchResponseProcessor create(
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(
config,
includeQueryContextFetcher,
tag,
clusterService
);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
Expand Down Expand Up @@ -109,22 +116,23 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag
String tag,
final ClusterService clusterService
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
fetchers.add(DocumentContextSourceFetcher.create(cfg));
fetchers.add(DocumentContextSourceFetcher.create(cfg, clusterService));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
if (includeQueryContextFetcher) {
fetchers.add(new QueryContextSourceFetcher());
fetchers.add(new QueryContextSourceFetcher(clusterService));
}
return fetchers;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.search.SearchHit;

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -87,14 +90,25 @@ public String getName() {
* @param config configuration object grabbed from parsed API request. Should be a list of strings
* @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed
*/
public static DocumentContextSourceFetcher create(Object config) {
public static DocumentContextSourceFetcher create(Object config, ClusterService clusterService) {
if (!(config instanceof List)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME));
}
List<?> fields = (List<?>) config;
if (fields.size() == 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME));
}
if (fields.size() > RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings())) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s must not contain more than %d fields. Configure by setting %s",
NAME,
RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings()),
RERANKER_MAX_DOC_FIELDS.getKey()
)
);
}
List<String> fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList());
return new DocumentContextSourceFetcher(fieldsAsStrings);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,36 @@

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.search.SearchExtBuilder;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
* Context Source Fetcher that gets context from the rerank query ext.
*/
@Log4j2
@AllArgsConstructor
public class QueryContextSourceFetcher implements ContextSourceFetcher {

public static final String NAME = "query_context";
public static final String QUERY_TEXT_FIELD = "query_text";
public static final String QUERY_TEXT_PATH_FIELD = "query_text_path";

public static final Integer MAX_QUERY_PATH_STRLEN = 1000;

private final ClusterService clusterService;

@Override
public void fetchContext(
final SearchRequest searchRequest,
Expand Down Expand Up @@ -65,6 +76,7 @@
} else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) {
// Case "query_text_path": ser/de the query into a map and then find the text at the path specified
String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD);
validatePath(path);
Map<String, Object> map = requestToMap(searchRequest);
// Get the text at the path
Object queryText = ObjectPath.eval(path, map);
Expand Down Expand Up @@ -107,4 +119,32 @@
Map<String, Object> map = parser.map();
return map;
}

private void validatePath(final String path) throws IllegalArgumentException {
if (path == null || path.isEmpty()) {
return;

Check warning on line 125 in src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java#L125

Added line #L125 was not covered by tests
}
if (path.length() > MAX_QUERY_PATH_STRLEN) {
log.error(String.format(Locale.ROOT, "invalid %s due to too many characters: %s", QUERY_TEXT_PATH_FIELD, path));
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d characters",
QUERY_TEXT_PATH_FIELD,
MAX_QUERY_PATH_STRLEN
)
);
}
if (path.split("\\.").length > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings())) {
log.error(String.format(Locale.ROOT, "invalid %s due to too many nested fields: %s", QUERY_TEXT_PATH_FIELD, path));
throw new IllegalArgumentException(
String.format(

Check warning on line 141 in src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java#L139-L141

Added lines #L139 - L141 were not covered by tests
Locale.ROOT,
"%s exceeded the maximum path length of %d nested fields",
QUERY_TEXT_PATH_FIELD,
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings())

Check warning on line 145 in src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java#L145

Added line #L145 was not covered by tests
)
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,13 @@ public final class NeuralSearchSettings {
false,
Setting.Property.NodeScope
);

/**
* Limits the number of document fields that can be passed to the reranker.
*/
public static final Setting<Integer> RERANKER_MAX_DOC_FIELDS = Setting.intSetting(
"plugins.neural_search.reranker_max_document_fields",
50,
Setting.Property.NodeScope
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand All @@ -15,6 +18,8 @@
import org.junit.Before;
import org.mockito.Mock;
import org.opensearch.OpenSearchParseException;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand All @@ -37,11 +42,16 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase {
@Mock
private PipelineContext pipelineContext;

@Mock
private ClusterService clusterService;

@Before
public void setup() {
clusterService = mock(ClusterService.class);
pipelineContext = mock(PipelineContext.class);
clientAccessor = mock(MLCommonsClientAccessor.class);
factory = new RerankProcessorFactory(clientAccessor);
factory = new RerankProcessorFactory(clientAccessor, clusterService);
doReturn(Settings.EMPTY).when(clusterService).getSettings();
}

public void testRerankProcessorFactory_whenEmptyConfig_thenFail() {
Expand Down Expand Up @@ -187,4 +197,26 @@ public void testCrossEncoder_whenEmptyContextDocField_thenFail() {
);
}

public void testCrossEncoder_whenTooManyDocFields_thenFail() {
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")),
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, Collections.nCopies(75, "field")))
)
);
assertThrows(
String.format(
Locale.ROOT,
"%s must not contain more than %d fields. Configure by setting %s",
DocumentContextSourceFetcher.NAME,
RERANKER_MAX_DOC_FIELDS.get(clusterService.getSettings()),
RERANKER_MAX_DOC_FIELDS.getKey()
),
IllegalArgumentException.class,
() -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponse.Clusters;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
Expand All @@ -49,6 +52,8 @@
import org.opensearch.search.pipeline.Processor.PipelineContext;
import org.opensearch.test.OpenSearchTestCase;

import lombok.SneakyThrows;

public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {

@Mock
Expand All @@ -65,14 +70,18 @@ public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase {
@Mock
private PipelineProcessingContext ppctx;

@Mock
private ClusterService clusterService;

private RerankProcessorFactory factory;

private MLOpenSearchRerankProcessor processor;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
factory = new RerankProcessorFactory(mlCommonsClientAccessor);
doReturn(Settings.EMPTY).when(clusterService).getSettings();
factory = new RerankProcessorFactory(mlCommonsClientAccessor, clusterService);
Map<String, Object> config = new HashMap<>(
Map.of(
RerankType.ML_OPENSEARCH.getLabel(),
Expand Down Expand Up @@ -223,6 +232,51 @@ public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IO
.equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field"));
}

@SneakyThrows
public void testRerankContext_whenQueryTextPathIsExceeedinglyManyCharacters_thenFail() {
// "eighteencharacters" * 60 = 1080 character string > max len of 1024
setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "eighteencharacters".repeat(60)));
setupSearchResults();
@SuppressWarnings("unchecked")
ActionListener<Map<String, Object>> listener = mock(ActionListener.class);
processor.generateRerankingContext(request, response, listener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue() instanceof IllegalArgumentException);
assert (argCaptor.getValue()
.getMessage()
.equals(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d characters",
QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
QueryContextSourceFetcher.MAX_QUERY_PATH_STRLEN
)
));
}

@SneakyThrows
public void textRerankContext_whenQueryTextPathIsExceeedinglyDeeplyNested_thenFail() {
setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "a.b.c.d.e.f.g.h.i.j.k.l.m.n.o.p.q.r.s.t.w.x.y.z"));
setupSearchResults();
@SuppressWarnings("unchecked")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this annotation? I don't see ActionListener<Map<String, Object>> listener = mock(ActionListener.class); is causing any warning in my local.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

¯_(ツ)_/¯. I get a warning when I remove it. Makes sense since it's a typecast from generic action listener to action listener for that map thingy.

ActionListener<Map<String, Object>> listener = mock(ActionListener.class);
processor.generateRerankingContext(request, response, listener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue() instanceof IllegalArgumentException);
assert (argCaptor.getValue()
.getMessage()
.equals(
String.format(
Locale.ROOT,
"%s exceeded the maximum path length of %d nested fields",
QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD,
MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(clusterService.getSettings())
)
));
}

public void testRescoreSearchResponse_HappyPath() throws IOException {
setupSimilarityRescoring();
setupSearchResults();
Expand Down
Loading