diff --git a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java index 6c489171b65ce8..19af6fbffb1ef3 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java @@ -3,12 +3,17 @@ import com.codahale.metrics.Timer; import com.datahub.util.exception.ESQueryException; import com.fasterxml.jackson.core.type.TypeReference; +import com.linkedin.data.template.StringArray; import com.linkedin.metadata.config.search.SearchConfiguration; import com.linkedin.metadata.config.search.custom.CustomSearchConfiguration; import com.linkedin.metadata.models.EntitySpec; import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.query.AutoCompleteResult; import com.linkedin.metadata.query.SearchFlags; +import com.linkedin.metadata.query.filter.ConjunctiveCriterion; +import com.linkedin.metadata.query.filter.ConjunctiveCriterionArray; +import com.linkedin.metadata.query.filter.Criterion; +import com.linkedin.metadata.query.filter.CriterionArray; import com.linkedin.metadata.query.filter.Filter; import com.linkedin.metadata.query.filter.SortCriterion; import com.linkedin.metadata.search.ScrollResult; @@ -132,14 +137,15 @@ public SearchResult search(@Nonnull String entityName, @Nonnull String input, @N final String finalInput = input.isEmpty() ? "*" : input; Timer.Context searchRequestTimer = MetricUtils.timer(this.getClass(), "searchRequest").time(); EntitySpec entitySpec = entityRegistry.getEntitySpec(entityName); + Filter transformedFilters = transformFilterForEntities(postFilters, indexConvention); // Step 1: construct the query final SearchRequest searchRequest = SearchRequestHandler .getBuilder(entitySpec, searchConfiguration, customSearchConfiguration) - .getSearchRequest(finalInput, postFilters, sortCriterion, from, size, searchFlags); + .getSearchRequest(finalInput, transformedFilters, sortCriterion, from, size, searchFlags); searchRequest.indices(indexConvention.getIndexName(entitySpec)); searchRequestTimer.stop(); // Step 2: execute the query and extract results, validated against document model as well - return executeAndExtract(entitySpec, searchRequest, postFilters, from, size); + return executeAndExtract(entitySpec, searchRequest, transformedFilters, from, size); } /** @@ -155,12 +161,14 @@ public SearchResult search(@Nonnull String entityName, @Nonnull String input, @N public SearchResult filter(@Nonnull String entityName, @Nullable Filter filters, @Nullable SortCriterion sortCriterion, int from, int size) { EntitySpec entitySpec = entityRegistry.getEntitySpec(entityName); + Filter transformedFilters = transformFilterForEntities(filters, indexConvention); final SearchRequest searchRequest = SearchRequestHandler .getBuilder(entitySpec, searchConfiguration, customSearchConfiguration) - .getFilterRequest(filters, sortCriterion, from, size); + .getFilterRequest(transformedFilters, sortCriterion, from, size); + searchRequest.indices(indexConvention.getIndexName(entitySpec)); - return executeAndExtract(entitySpec, searchRequest, filters, from, size); + return executeAndExtract(entitySpec, searchRequest, transformedFilters, from, size); } /** @@ -180,7 +188,7 @@ public AutoCompleteResult autoComplete(@Nonnull String entityName, @Nonnull Stri try { EntitySpec entitySpec = entityRegistry.getEntitySpec(entityName); AutocompleteRequestHandler builder = AutocompleteRequestHandler.getBuilder(entitySpec); - SearchRequest req = builder.getSearchRequest(query, field, requestParams, limit); + SearchRequest req = builder.getSearchRequest(query, field, transformFilterForEntities(requestParams, indexConvention), limit); req.indices(indexConvention.getIndexName(entitySpec)); SearchResponse searchResponse = client.search(req, RequestOptions.DEFAULT); return builder.extractResult(searchResponse, query); @@ -202,7 +210,7 @@ public AutoCompleteResult autoComplete(@Nonnull String entityName, @Nonnull Stri @Nonnull public Map aggregateByValue(@Nullable String entityName, @Nonnull String field, @Nullable Filter requestParams, int limit) { - final SearchRequest searchRequest = SearchRequestHandler.getAggregationRequest(field, requestParams, limit); + final SearchRequest searchRequest = SearchRequestHandler.getAggregationRequest(field, transformFilterForEntities(requestParams, indexConvention), limit); String indexName; if (entityName == null) { indexName = indexConvention.getAllEntityIndicesPattern(); @@ -261,10 +269,11 @@ public ScrollResult scroll(@Nonnull List entities, @Nonnull String input pitId = createPointInTime(indexArray, keepAlive); } + Filter transformedFilters = transformFilterForEntities(postFilters, indexConvention); // Step 1: construct the query final SearchRequest searchRequest = SearchRequestHandler .getBuilder(entitySpecs, searchConfiguration, customSearchConfiguration) - .getSearchRequest(finalInput, postFilters, sortCriterion, sort, pitId, keepAlive, size, searchFlags); + .getSearchRequest(finalInput, transformedFilters, sortCriterion, sort, pitId, keepAlive, size, searchFlags); // PIT specifies indices in creation so it doesn't support specifying indices on the request, so we only specify if not using PIT if (!supportsPointInTime()) { @@ -273,7 +282,46 @@ public ScrollResult scroll(@Nonnull List entities, @Nonnull String input scrollRequestTimer.stop(); // Step 2: execute the query and extract results, validated against document model as well - return executeAndExtract(entitySpecs, searchRequest, postFilters, scrollId, keepAlive, size); + return executeAndExtract(entitySpecs, searchRequest, transformedFilters, scrollId, keepAlive, size); + } + + private static Criterion transformEntityTypeCriterion(Criterion criterion, IndexConvention indexConvention) { + return criterion.setField("_index").setValues( + new StringArray(criterion.getValues().stream().map( + indexConvention::getEntityIndexName).collect( + Collectors.toList()))).setValue(indexConvention.getEntityIndexName(criterion.getValue())); + } + + private static ConjunctiveCriterion transformConjunctiveCriterion(ConjunctiveCriterion conjunctiveCriterion, + IndexConvention indexConvention) { + return new ConjunctiveCriterion().setAnd( + conjunctiveCriterion.getAnd().stream().map( + criterion -> criterion.getField().equals("_entityType") + ? transformEntityTypeCriterion(criterion, indexConvention) + : criterion) + .collect(Collectors.toCollection(CriterionArray::new))); + } + + private static ConjunctiveCriterionArray transformConjunctiveCriterionArray(ConjunctiveCriterionArray criterionArray, + IndexConvention indexConvention) { + return new ConjunctiveCriterionArray( + criterionArray.stream().map( + conjunctiveCriterion -> transformConjunctiveCriterion(conjunctiveCriterion, indexConvention)) + .collect(Collectors.toList())); + } + + /** + * Allows filtering on entities which are stored as different indices under the hood by transforming the tag + * _entityType to _index and updating the type to the index name. + * @param filter The filter to parse and transform if needed + * @param indexConvention The index convention used to generate the index name for an entity + * @return A filter, with the changes if necessary + */ + static Filter transformFilterForEntities(Filter filter, @Nonnull IndexConvention indexConvention) { + if (filter != null && filter.getOr() != null) { + return new Filter().setOr(transformConjunctiveCriterionArray(filter.getOr(), indexConvention)); + } + return filter; } private boolean supportsPointInTime() { diff --git a/metadata-io/src/test/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAOTest.java b/metadata-io/src/test/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAOTest.java new file mode 100644 index 00000000000000..2e744469e9aa13 --- /dev/null +++ b/metadata-io/src/test/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAOTest.java @@ -0,0 +1,111 @@ +package com.linkedin.metadata.search.elasticsearch.query; + +import com.google.common.collect.ImmutableList; +import com.linkedin.data.template.StringArray; +import com.linkedin.metadata.ESSampleDataFixture; +import com.linkedin.metadata.query.filter.Condition; +import com.linkedin.metadata.query.filter.ConjunctiveCriterion; +import com.linkedin.metadata.query.filter.ConjunctiveCriterionArray; +import com.linkedin.metadata.query.filter.CriterionArray; +import com.linkedin.metadata.query.filter.Filter; +import com.linkedin.metadata.utils.elasticsearch.IndexConvention; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Import; +import org.springframework.test.context.testng.AbstractTestNGSpringContextTests; +import org.testng.annotations.Test; + +import com.linkedin.metadata.query.filter.Criterion; +import org.springframework.beans.factory.annotation.Qualifier; + +import static org.testng.Assert.*; + + +@Import(ESSampleDataFixture.class) +public class ESSearchDAOTest extends AbstractTestNGSpringContextTests { + @Autowired + @Qualifier("sampleDataIndexConvention") + IndexConvention indexConvention; + + @Test + public void testTransformFilterForEntitiesNoChange() { + Criterion c = new Criterion().setValue("urn:li:tag:abc").setValues( + new StringArray(ImmutableList.of("urn:li:tag:abc", "urn:li:tag:def")) + ).setNegated(false).setCondition(Condition.EQUAL).setField("tags.keyword"); + + Filter f = new Filter().setOr( + new ConjunctiveCriterionArray(new ConjunctiveCriterion().setAnd(new CriterionArray(c)))); + + Filter transformedFilter = ESSearchDAO.transformFilterForEntities(f, indexConvention); + assertEquals(f, transformedFilter); + } + + @Test + public void testTransformFilterForEntitiesNullFilter() { + Filter transformedFilter = ESSearchDAO.transformFilterForEntities(null, indexConvention); + assertNotNull(indexConvention); + assertEquals(null, transformedFilter); + } + + @Test + public void testTransformFilterForEntitiesWithChanges() { + + Criterion c = new Criterion().setValue("dataset").setValues( + new StringArray(ImmutableList.of("dataset")) + ).setNegated(false).setCondition(Condition.EQUAL).setField("_entityType"); + + Filter f = new Filter().setOr( + new ConjunctiveCriterionArray(new ConjunctiveCriterion().setAnd(new CriterionArray(c)))); + Filter originalF = null; + try { + originalF = f.copy(); + } catch (CloneNotSupportedException e) { + fail(e.getMessage()); + } + assertEquals(f, originalF); + + Filter transformedFilter = ESSearchDAO.transformFilterForEntities(f, indexConvention); + assertNotEquals(originalF, transformedFilter); + + Criterion expectedNewCriterion = new Criterion().setValue("smpldat_datasetindex_v2").setValues( + new StringArray(ImmutableList.of("smpldat_datasetindex_v2")) + ).setNegated(false).setCondition(Condition.EQUAL).setField("_index"); + + Filter expectedNewFilter = new Filter().setOr( + new ConjunctiveCriterionArray(new ConjunctiveCriterion().setAnd(new CriterionArray(expectedNewCriterion)))); + + assertEquals(expectedNewFilter, transformedFilter); + } + + @Test + public void testTransformFilterForEntitiesWithSomeChanges() { + + Criterion criterionChanged = new Criterion().setValue("dataset").setValues( + new StringArray(ImmutableList.of("dataset")) + ).setNegated(false).setCondition(Condition.EQUAL).setField("_entityType"); + Criterion criterionUnchanged = new Criterion().setValue("urn:li:tag:abc").setValues( + new StringArray(ImmutableList.of("urn:li:tag:abc", "urn:li:tag:def")) + ).setNegated(false).setCondition(Condition.EQUAL).setField("tags.keyword"); + + Filter f = new Filter().setOr( + new ConjunctiveCriterionArray(new ConjunctiveCriterion().setAnd(new CriterionArray(criterionChanged, criterionUnchanged)))); + Filter originalF = null; + try { + originalF = f.copy(); + } catch (CloneNotSupportedException e) { + fail(e.getMessage()); + } + assertEquals(f, originalF); + + Filter transformedFilter = ESSearchDAO.transformFilterForEntities(f, indexConvention); + assertNotEquals(originalF, transformedFilter); + + Criterion expectedNewCriterion = new Criterion().setValue("smpldat_datasetindex_v2").setValues( + new StringArray(ImmutableList.of("smpldat_datasetindex_v2")) + ).setNegated(false).setCondition(Condition.EQUAL).setField("_index"); + + Filter expectedNewFilter = new Filter().setOr( + new ConjunctiveCriterionArray(new ConjunctiveCriterion().setAnd(new CriterionArray(expectedNewCriterion, criterionUnchanged)))); + + assertEquals(expectedNewFilter, transformedFilter); + } +}