diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java
index 7b316bd7..30b82f0a 100644
--- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java
+++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDao.java
@@ -15,13 +15,16 @@

 package com.amazon.opendistroforelasticsearch.ad.feature;

+import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY;
 import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix;

 import java.io.IOException;
 import java.util.AbstractMap.SimpleEntry;
+import java.util.AbstractMap.SimpleImmutableEntry;
 import java.util.ArrayDeque;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -35,20 +38,38 @@
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.RangeQueryBuilder;
+import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.AggregatorFactories;
+import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
 import org.elasticsearch.search.aggregations.bucket.range.InternalDateRange;
+import org.elasticsearch.search.aggregations.bucket.range.InternalDateRange.Bucket;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.InternalTDigestPercentiles;
 import org.elasticsearch.search.aggregations.metrics.Max;
+import org.elasticsearch.search.aggregations.metrics.Min;
 import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue;
 import org.elasticsearch.search.aggregations.metrics.Percentile;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.threadpool.ThreadPool;

+import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin;
+import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
+import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
 import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator;
 import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
+import com.amazon.opendistroforelasticsearch.ad.model.Feature;
 import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
 import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;
 import com.amazon.opendistroforelasticsearch.ad.util.ParseUtils;
@@ -59,6 +80,8 @@ public class SearchFeatureDao {

     protected static final String AGG_NAME_MAX = "max_timefield";
+    protected static final String AGG_NAME_MIN = "min_timefield";
+    protected static final String AGG_NAME_TERM = "term_agg";

     private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class);
@@ -67,6 +90,8 @@ public class SearchFeatureDao {
     private final NamedXContentRegistry xContent;
     private final Interpolator interpolator;
     private final ClientUtil clientUtil;
+    private ThreadPool threadPool;
+    private int maxEntitiesPerQuery;

     /**
      * Constructor injection.
@@ -75,12 +100,26 @@ public class SearchFeatureDao {
      * @param xContent ES XContentRegistry
      * @param interpolator interpolator for missing values
      * @param clientUtil utility for ES client
+     * @param threadPool accessor to different threadpools
+     * @param settings ES settings
+     * @param clusterService ES ClusterService
      */
-    public SearchFeatureDao(Client client, NamedXContentRegistry xContent, Interpolator interpolator, ClientUtil clientUtil) {
+    public SearchFeatureDao(
+        Client client,
+        NamedXContentRegistry xContent,
+        Interpolator interpolator,
+        ClientUtil clientUtil,
+        ThreadPool threadPool,
+        Settings settings,
+        ClusterService clusterService
+    ) {
         this.client = client;
         this.xContent = xContent;
         this.interpolator = interpolator;
         this.clientUtil = clientUtil;
+        this.threadPool = threadPool;
+        this.maxEntitiesPerQuery = MAX_ENTITIES_PER_QUERY.get(settings);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerQuery = it);
     }

     /**
@@ -129,6 +168,47 @@ private Optional<Long> getLatestDataTime(SearchResponse searchResponse) {
             .map(agg -> (long) agg.getValue());
     }

+    /**
+     * Get the entity's earliest and latest timestamps
+     * @param detector detector config
+     * @param entityName entity's name
+     * @param listener listener to return back the requested timestamps
+     */
+    public void getEntityMinMaxDataTime(
+        AnomalyDetector detector,
+        String entityName,
+        ActionListener<Entry<Optional<Long>, Optional<Long>>> listener
+    ) {
+        TermQueryBuilder term = new TermQueryBuilder(detector.getCategoryField().get(0), entityName);
+        BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(term);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
+            .query(internalFilterQuery)
+            .aggregation(AggregationBuilders.max(AGG_NAME_MAX).field(detector.getTimeField()))
+            .aggregation(AggregationBuilders.min(AGG_NAME_MIN).field(detector.getTimeField()))
+            .trackTotalHits(false)
+            .size(0);
+        SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder);
+        client
+            .search(
+                searchRequest,
+                ActionListener.wrap(response -> { listener.onResponse(parseMinMaxDataTime(response)); }, listener::onFailure)
+            );
+    }
+
+    private Entry<Optional<Long>, Optional<Long>> parseMinMaxDataTime(SearchResponse searchResponse) {
+        Optional<Map<String, Aggregation>> mapOptional = Optional
+            .ofNullable(searchResponse)
+            .map(SearchResponse::getAggregations)
+            .map(aggs -> aggs.asMap());
+
+        Optional<Long> latest = mapOptional.map(map -> (Max) map.get(AGG_NAME_MAX)).map(agg -> (long) agg.getValue());
+
+        Optional<Long> earliest = mapOptional.map(map -> (Min) map.get(AGG_NAME_MIN)).map(agg -> (long) agg.getValue());
+
+        return new SimpleImmutableEntry<>(earliest, latest);
+    }
+
     /**
      * Gets features for the given time period.
      * This function also adds given detector to negative cache before sending es request.
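Note: the following is a minimal sketch of how a caller might consume the new getEntityMinMaxDataTime API. The training-window length and error handling are illustrative assumptions, not part of this change.

```java
// Hypothetical caller: derive a cold start training window from the entity's data extent.
searchFeatureDao.getEntityMinMaxDataTime(detector, "app_0", ActionListener.wrap(minMax -> {
    Optional<Long> earliest = minMax.getKey();
    Optional<Long> latest = minMax.getValue();
    if (earliest.isPresent() && latest.isPresent()) {
        long trainingMillis = 24 * 60 * 60 * 1000L; // assumed one-day training window
        long windowStart = Math.max(earliest.get(), latest.get() - trainingMillis);
        // fetch samples in [windowStart, latest.get()] via getColdStartSamplesForPeriods
    }
}, e -> logger.warn("Failed to get min/max timestamps for entity app_0", e)));
```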
@@ -569,4 +649,137 @@ private Optional<double[]> parseAggregations(Optional<Aggregations> aggregations
             )
             .filter(result -> Arrays.stream(result).noneMatch(d -> Double.isNaN(d) || Double.isInfinite(d)));
     }
+
+    public void getColdStartSamplesForPeriods(
+        AnomalyDetector detector,
+        List<Entry<String, String>> ranges,
+        String entityName,
+        ActionListener<List<Optional<double[]>>> listener
+    ) throws IOException {
+        SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entityName);
+
+        client.search(request, ActionListener.wrap(response -> {
+            Aggregations aggs = response.getAggregations();
+            if (aggs == null) {
+                listener.onResponse(Collections.emptyList());
+                return;
+            }
+
+            // Extract buckets and order them by from_as_string. The default order is currently
+            // ascending, but it is better not to assume it.
+            // Example response from a date range bucket aggregation:
+            // "aggregations":{"date_range":{"buckets":[{"key":"1598865166000-1598865226000","from":1.598865166E12,"
+            // from_as_string":"1598865166000","to":1.598865226E12,"to_as_string":"1598865226000","doc_count":3,
+            // "deny_max":{"value":154.0}},{"key":"1598869006000-1598869066000","from":1.598869006E12,
+            // "from_as_string":"1598869006000","to":1.598869066E12,"to_as_string":"1598869066000","doc_count":3,
+            // "deny_max":{"value":141.0}},
+            listener
+                .onResponse(
+                    aggs
+                        .asList()
+                        .stream()
+                        .filter(InternalDateRange.class::isInstance)
+                        .flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream())
+                        .filter(bucket -> bucket.getFrom() != null)
+                        .sorted(Comparator.comparing((Bucket bucket) -> Long.valueOf(bucket.getFromAsString())))
+                        .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds()))
+                        .collect(Collectors.toList())
+                );
+        }, listener::onFailure));
+    }
+
+    /**
+     * Get features by entities. An entity is one combination of values of the
+     * configured categorical fields. A categorical field in this setting refers
+     * to an Elasticsearch field of type keyword or ip. For example, an entity
+     * can be the IP address 182.3.4.5.
+     *
+     * @param detector Accessor to the detector object
+     * @param startMilli Start of time range to query
+     * @param endMilli End of time range to query
+     * @param listener Listener to return entities and their data points
+     */
+    public void getFeaturesByEntities(
+        AnomalyDetector detector,
+        long startMilli,
+        long endMilli,
+        ActionListener<Map<String, double[]>> listener
+    ) {
+        try {
+            RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField())
+                .gte(startMilli)
+                .lt(endMilli)
+                .format("epoch_millis");
+
+            BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(detector.getFilterQuery()).filter(rangeQuery);
+
+            /* Terms aggregation implementation. */
+            // Support one category field
+            TermsAggregationBuilder termsAgg = AggregationBuilders
+                .terms(AGG_NAME_TERM)
+                .field(detector.getCategoryField().get(0))
+                .size(maxEntitiesPerQuery);
+            for (Feature feature : detector.getFeatureAttributes()) {
+                AggregatorFactories.Builder internalAgg = ParseUtils
+                    .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId());
+                termsAgg.subAggregation(internalAgg.getAggregatorFactories().iterator().next());
+            }
+
+            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
+                .query(internalFilterQuery)
+                .size(0)
+                .aggregation(termsAgg)
+                .trackTotalHits(false);
+            SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder);
+
+            ActionListener<SearchResponse> termsListener = ActionListener.wrap(response -> {
+                Aggregations aggs = response.getAggregations();
+                if (aggs == null) {
+                    listener.onResponse(Collections.emptyMap());
+                    return;
+                }
+
+                Map<String, double[]> results = aggs
+                    .asList()
+                    .stream()
+                    .filter(agg -> AGG_NAME_TERM.equals(agg.getName()))
+                    .flatMap(agg -> ((Terms) agg).getBuckets().stream())
+                    .collect(
+                        Collectors.toMap(Terms.Bucket::getKeyAsString, bucket -> parseBucket(bucket, detector.getEnabledFeatureIds()).get())
+                    );
+
+                listener.onResponse(results);
+            }, listener::onFailure);
+
+            client
+                .search(
+                    searchRequest,
+                    new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, termsListener, false)
+                );
+
+        } catch (IOException e) {
+            throw new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true);
+        }
+    }
+
+    private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List<Entry<String, String>> ranges, String entityName) {
+        try {
+            SearchSourceBuilder searchSourceBuilder = ParseUtils.generateEntityColdStartQuery(detector, ranges, entityName, xContent);
+            return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder);
+        } catch (IOException e) {
+            logger
+                .warn(
+                    "Failed to create cold start feature search request for "
+                        + detector.getDetectorId()
+                        + " from "
+                        + ranges.get(0).getKey()
+                        + " to "
+                        + ranges.get(ranges.size() - 1).getKey(),
+                    e
+                );
+            throw new IllegalStateException(e);
+        }
+    }
+
+    private Optional<double[]> parseBucket(MultiBucketsAggregation.Bucket bucket, List<String> featureIds) {
+        return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds);
+    }
 }
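Note: the SearchFeatureDao constructor above follows the standard Elasticsearch dynamic-setting pattern. Below is a self-contained sketch of that pattern; the setting key, default, and minimum are illustrative assumptions (the real MAX_ENTITIES_PER_QUERY lives in AnomalyDetectorSettings), and only the consumer wiring mirrors this diff.

```java
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;

class EntityQueryLimit {
    // Hypothetical declaration for illustration only.
    static final Setting<Integer> MAX_ENTITIES_PER_QUERY = Setting
        .intSetting("opendistro.anomaly_detection.max_entities_per_query", 1000, 1, Setting.Property.NodeScope, Setting.Property.Dynamic);

    private volatile int maxEntitiesPerQuery;

    EntityQueryLimit(Settings settings, ClusterService clusterService) {
        // Read the value configured at startup, then subscribe to cluster-level updates so a
        // settings change takes effect without a node restart. The setting must be registered
        // with ClusterSettings (the tests below do exactly that), or the registration throws.
        this.maxEntitiesPerQuery = MAX_ENTITIES_PER_QUERY.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerQuery = it);
    }
}
```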
diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ParseUtils.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ParseUtils.java
index 66b753f2..9b4a4bee 100644
--- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ParseUtils.java
+++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/util/ParseUtils.java
@@ -22,6 +22,7 @@

 import java.io.IOException;
 import java.time.Instant;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -35,6 +36,7 @@
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.RangeQueryBuilder;
+import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.BaseAggregationBuilder;
@@ -44,6 +46,7 @@

 import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
 import com.amazon.opendistroforelasticsearch.ad.model.Feature;
+import com.amazon.opendistroforelasticsearch.ad.model.FeatureData;

 /**
  * Parsing utility functions.
@@ -341,4 +344,50 @@ public static String generateInternalFeatureQueryTemplate(AnomalyDetector detect

         return internalSearchSourceBuilder.toString();
     }
+
+    public static SearchSourceBuilder generateEntityColdStartQuery(
+        AnomalyDetector detector,
+        List<Entry<String, String>> ranges,
+        String entityName,
+        NamedXContentRegistry xContentRegistry
+    ) throws IOException {
+
+        TermQueryBuilder term = new TermQueryBuilder(detector.getCategoryField().get(0), entityName);
+        BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()).filter(term);
+
+        DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(detector.getTimeField()).format("epoch_millis");
+        for (Entry<String, String> range : ranges) {
+            dateRangeBuilder.addRange(range.getKey(), range.getValue());
+        }
+
+        if (detector.getFeatureAttributes() != null) {
+            for (Feature feature : detector.getFeatureAttributes()) {
+                AggregatorFactories.Builder internalAgg = parseAggregators(
+                    feature.getAggregation().toString(),
+                    xContentRegistry,
+                    feature.getId()
+                );
+                dateRangeBuilder.subAggregation(internalAgg.getAggregatorFactories().iterator().next());
+            }
+        }
+
+        return new SearchSourceBuilder().query(internalFilterQuery).size(0).aggregation(dateRangeBuilder);
+    }
+
+    /**
+     * Map feature data to its Id and name
+     * @param currentFeature Feature data
+     * @param detector Detector Config object
+     * @return a list of feature data with Id and name
+     */
+    public static List<FeatureData> getFeatureData(double[] currentFeature, AnomalyDetector detector) {
+        List<String> featureIds = detector.getEnabledFeatureIds();
+        List<String> featureNames = detector.getEnabledFeatureNames();
+        int featureLen = featureIds.size();
+        List<FeatureData> featureData = new ArrayList<>();
+        for (int i = 0; i < featureLen; i++) {
+            featureData.add(new FeatureData(featureIds.get(i), featureNames.get(i), currentFeature[i]));
+        }
+        return featureData;
+    }
 }
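Note: a hedged usage sketch for generateEntityColdStartQuery. The entity name and ranges are made up (borrowed from the sample response comment in SearchFeatureDao), and detector and xContentRegistry are assumed to be in scope.

```java
// Hypothetical usage: build a cold start query for entity "app_0" over two epoch-millisecond ranges.
List<Entry<String, String>> ranges = Arrays.asList(
    new SimpleEntry<>("1598865166000", "1598865226000"),
    new SimpleEntry<>("1598869006000", "1598869066000")
);
SearchSourceBuilder source = ParseUtils.generateEntityColdStartQuery(detector, ranges, "app_0", xContentRegistry);
// The builder filters by the detector's filter query plus a term query on the category field,
// and nests each feature aggregation under one "date_range" bucket per range.
SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0]), source);
```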
diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDaoTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDaoTests.java
index c6976c2d..44a5b506 100644
--- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDaoTests.java
+++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/SearchFeatureDaoTests.java
@@ -16,10 +16,14 @@
 package com.amazon.opendistroforelasticsearch.ad.feature;

 import static java.util.Arrays.asList;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.AnyOf.anyOf;
+import static org.hamcrest.core.IsInstanceOf.instanceOf;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.any;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.anyObject;
 import static org.mockito.Matchers.eq;
@@ -30,22 +34,32 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;

+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.time.ZoneId;
 import java.time.temporal.ChronoUnit;
 import java.util.AbstractMap.SimpleEntry;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
+import java.util.concurrent.ExecutorService;
 import java.util.function.BiConsumer;

 import junitparams.JUnitParamsRunner;
 import junitparams.Parameters;

+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.action.ActionFuture;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.MultiSearchRequest;
@@ -53,24 +67,45 @@
 import org.elasticsearch.action.search.MultiSearchResponse.Item;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchResponseSections;
+import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.time.DateFormatter;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.index.mapper.DateFieldMapper;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.script.TemplateScript;
 import org.elasticsearch.script.TemplateScript.Factory;
+import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.AggregatorFactories;
+import org.elasticsearch.search.aggregations.BucketOrder;
+import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
+import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.InternalMax;
+import org.elasticsearch.search.aggregations.metrics.InternalMin;
 import org.elasticsearch.search.aggregations.metrics.InternalTDigestPercentiles;
 import org.elasticsearch.search.aggregations.metrics.Max;
+import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
 import org.elasticsearch.search.aggregations.metrics.Percentile;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -84,20 +119,27 @@
 import org.powermock.modules.junit4.PowerMockRunner;
 import org.powermock.modules.junit4.PowerMockRunnerDelegate;

+import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin;
+import com.amazon.opendistroforelasticsearch.ad.NodeStateManager;
+import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
 import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator;
 import com.amazon.opendistroforelasticsearch.ad.dataprocessor.LinearUniformInterpolator;
 import com.amazon.opendistroforelasticsearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator;
 import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
+import com.amazon.opendistroforelasticsearch.ad.model.Feature;
 import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
-import com.amazon.opendistroforelasticsearch.ad.transport.TransportStateManager;
+import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
 import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;
 import com.amazon.opendistroforelasticsearch.ad.util.ParseUtils;
+import com.google.common.collect.ImmutableList;

 @PowerMockIgnore("javax.management.*")
 @RunWith(PowerMockRunner.class)
 @PowerMockRunnerDelegate(JUnitParamsRunner.class)
 @PrepareForTest({ ParseUtils.class })
 public class SearchFeatureDaoTests {
+    private final Logger LOG = LogManager.getLogger(SearchFeatureDaoTests.class);
+
     private SearchFeatureDao searchFeatureDao;

     @Mock
@@ -128,19 +170,26 @@ public class SearchFeatureDaoTests {
     @Mock
     private Max max;
     @Mock
-    private TransportStateManager stateManager;
+    private NodeStateManager stateManager;

     @Mock
     private AnomalyDetector detector;

+    @Mock
+    private ThreadPool threadPool;
+
+    @Mock
+    private ClusterService clusterService;
+
     private SearchSourceBuilder featureQuery = new SearchSourceBuilder();
-    private Map<String, String> searchRequestParams;
+    // private Map<String, String> searchRequestParams;
     private SearchRequest searchRequest;
     private SearchSourceBuilder searchSourceBuilder;
     private MultiSearchRequest multiSearchRequest;
     private Map<String, Aggregation> aggsMap;
-    private List<String> aggsList;
+    // private List<String> aggsList;
     private IntervalTimeConfiguration detectionInterval;
+    // private Settings settings;

     @Before
     public void setup() throws Exception {
@@ -148,20 +197,38 @@ public class SearchFeatureDaoTests {
         PowerMockito.mockStatic(ParseUtils.class);

         Interpolator interpolator = new LinearUniformInterpolator(new SingleFeatureLinearUniformInterpolator());
-        searchFeatureDao = spy(new SearchFeatureDao(client, xContent, interpolator, clientUtil));
+
+        ExecutorService executorService = mock(ExecutorService.class);
+        when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService);
+        doAnswer(invocation -> {
+            Runnable runnable = invocation.getArgument(0);
+            runnable.run();
+            return null;
+        }).when(executorService).execute(any(Runnable.class));
+
+        Settings settings = Settings.EMPTY;
+        ClusterSettings clusterSettings = new ClusterSettings(
+            Settings.EMPTY,
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY)))
+        );
+        when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
+
+        searchFeatureDao = spy(new SearchFeatureDao(client, xContent, interpolator, clientUtil, threadPool, settings, clusterService));

         detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES);
         when(detector.getTimeField()).thenReturn("testTimeField");
         when(detector.getIndices()).thenReturn(Arrays.asList("testIndices"));
         when(detector.generateFeatureQuery()).thenReturn(featureQuery);
         when(detector.getDetectionInterval()).thenReturn(detectionInterval);
+        when(detector.getFilterQuery()).thenReturn(QueryBuilders.matchAllQuery());
+        when(detector.getCategoryField()).thenReturn(Collections.singletonList("a"));

         searchSourceBuilder = SearchSourceBuilder
             .fromXContent(XContentType.JSON.xContent().createParser(xContent, LoggingDeprecationHandler.INSTANCE, "{}"));
-        searchRequestParams = new HashMap<>();
+        // searchRequestParams = new HashMap<>();
         searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]));
         aggsMap = new HashMap<>();
-        aggsList = new ArrayList<>();
+        // aggsList = new ArrayList<>();

         when(max.getName()).thenReturn(SearchFeatureDao.AGG_NAME_MAX);
         List<Aggregation> list = new ArrayList<>();
@@ -648,4 +715,213 @@ public void getFeaturesForSampledPeriods_throwToListener_whenSamplingFail() {
     private <K, V> Entry<K, V> pair(K key, V value) {
         return new SimpleEntry<>(key, value);
     }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testNormalGetFeaturesByEntities() throws IOException {
+        SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN);
+
+        String aggregationId = "deny_max";
+        String featureName = "deny max";
+        AggregationBuilder builder = new MaxAggregationBuilder("deny_max").field("deny");
+        AggregatorFactories.Builder aggBuilder = AggregatorFactories.builder();
+        aggBuilder.addAggregator(builder);
+        when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(aggregationId));
+        when(detector.getFeatureAttributes()).thenReturn(Collections.singletonList(new Feature(aggregationId, featureName, true, builder)));
+        when(ParseUtils.parseAggregators(anyString(), any(), anyString())).thenReturn(aggBuilder);
+
+        String app0Name = "app_0";
+        double app0Max = 1976.0;
+        InternalAggregation app0Agg = new InternalMax(aggregationId, app0Max, DocValueFormat.RAW, Collections.emptyMap());
+        StringTerms.Bucket app0Bucket = new StringTerms.Bucket(
+            new BytesRef(app0Name.getBytes(StandardCharsets.UTF_8), 0, app0Name.getBytes(StandardCharsets.UTF_8).length),
+            3,
+            InternalAggregations.from(Collections.singletonList(app0Agg)),
+            false,
+            0,
+            DocValueFormat.RAW
+        );
+
+        String app1Name = "app_1";
+        double app1Max = 3604.0;
+        InternalAggregation app1Agg = new InternalMax(aggregationId, app1Max, DocValueFormat.RAW, Collections.emptyMap());
+        StringTerms.Bucket app1Bucket = new StringTerms.Bucket(
+            new BytesRef(app1Name.getBytes(StandardCharsets.UTF_8), 0, app1Name.getBytes(StandardCharsets.UTF_8).length),
+            3,
+            InternalAggregations.from(Collections.singletonList(app1Agg)),
+            false,
+            0,
+            DocValueFormat.RAW
+        );
+
+        List<StringTerms.Bucket> stringBuckets = ImmutableList.of(app0Bucket, app1Bucket);
+
+        StringTerms termsAgg = new StringTerms(
+            "term_agg",
+            BucketOrder.count(false),
+            1,
+            0,
+            Collections.emptyMap(),
+            DocValueFormat.RAW,
+            1,
+            false,
+            0,
+            stringBuckets,
+            0
+        );
+
+        InternalAggregations internalAggregations = InternalAggregations.from(Collections.singletonList(termsAgg));
+
+        SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1);
+
+        // Simulated response:
+        // {"took":507,"timed_out":false,"_shards":{"total":1,"successful":1,
+        // "skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]},
+        // "aggregations":{"term_agg":{"doc_count_error_upper_bound":0,
+        // "sum_other_doc_count":0,"buckets":[{"key":"app_0","doc_count":3,
+        // "deny_max":{"value":1976.0}},{"key":"app_1","doc_count":3,
+        // "deny_max":{"value":3604.0}}]}}}
+        SearchResponse searchResponse = new SearchResponse(
+            searchSections,
+            null,
+            1,
+            1,
+            0,
+            507,
+            ShardSearchFailure.EMPTY_ARRAY,
+            SearchResponse.Clusters.EMPTY
+        );
+
+        doAnswer(invocation -> {
+            SearchRequest request = invocation.getArgument(0);
+            assertEquals(1, request.indices().length);
+            assertTrue(detector.getIndices().contains(request.indices()[0]));
+            AggregatorFactories.Builder aggs = request.source().aggregations();
+            assertEquals(1, aggs.count());
+            Collection<AggregationBuilder> factory = aggs.getAggregatorFactories();
+            assertTrue(!factory.isEmpty());
+            assertThat(factory.iterator().next(), instanceOf(TermsAggregationBuilder.class));
+
+            ActionListener<SearchResponse> listener = invocation.getArgument(1);
+            listener.onResponse(searchResponse);
+            return null;
+        }).when(client).search(any(SearchRequest.class), any(ActionListener.class));
+
+        ActionListener<Map<String, double[]>> listener = mock(ActionListener.class);
+        searchFeatureDao.getFeaturesByEntities(detector, 10L, 20L, listener);
+
+        ArgumentCaptor<Map<String, double[]>> captor = ArgumentCaptor.forClass(Map.class);
+        verify(listener).onResponse(captor.capture());
+        Map<String, double[]> result = captor.getValue();
+        assertEquals(2, result.size());
+        assertEquals(app0Max, result.get(app0Name)[0], 0.001);
+        assertEquals(app1Max, result.get(app1Name)[0], 0.001);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testEmptyGetFeaturesByEntities() {
+        SearchResponseSections searchSections = new SearchResponseSections(null, null, null, false, false, null, 1);
+
+        SearchResponse searchResponse = new SearchResponse(
+            searchSections,
+            null,
+            1,
+            1,
+            0,
+            507,
+            ShardSearchFailure.EMPTY_ARRAY,
+            SearchResponse.Clusters.EMPTY
+        );
+
+        doAnswer(invocation -> {
+            ActionListener<SearchResponse> listener = invocation.getArgument(1);
+            listener.onResponse(searchResponse);
+            return null;
+        }).when(client).search(any(SearchRequest.class), any(ActionListener.class));
+
+        ActionListener<Map<String, double[]>> listener = mock(ActionListener.class);
+        searchFeatureDao.getFeaturesByEntities(detector, 10L, 20L, listener);
+
+        ArgumentCaptor<Map<String, double[]>> captor = ArgumentCaptor.forClass(Map.class);
+        verify(listener).onResponse(captor.capture());
+        Map<String, double[]> result = captor.getValue();
+        assertEquals(0, result.size());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test(expected = EndRunException.class)
+    public void testParseIOException() throws Exception {
+        String aggregationId = "deny_max";
+        String featureName = "deny max";
+        AggregationBuilder builder = new MaxAggregationBuilder("deny_max").field("deny");
+        AggregatorFactories.Builder aggBuilder = AggregatorFactories.builder();
+        aggBuilder.addAggregator(builder);
+        when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(aggregationId));
+        when(detector.getFeatureAttributes()).thenReturn(Collections.singletonList(new Feature(aggregationId, featureName, true, builder)));
+        PowerMockito.doThrow(new IOException()).when(ParseUtils.class, "parseAggregators", anyString(), any(), anyString());
+
+        ActionListener<Map<String, double[]>> listener = mock(ActionListener.class);
+        searchFeatureDao.getFeaturesByEntities(detector, 10L, 20L, listener);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testGetEntityMinMaxDataTime() {
+        // Simulated response: {"took":11,"timed_out":false,"_shards":{"total":1,
+        // "successful":1,"skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]},
+        // "aggregations":{"min_timefield":{"value":1.602211285E12,
+        // "value_as_string":"2020-10-09T02:41:25.000Z"},
+        // "max_timefield":{"value":1.602348325E12,"value_as_string":"2020-10-10T16:45:25.000Z"}}}
+        DocValueFormat dateFormat = new DocValueFormat.DateTime(
+            DateFormatter.forPattern("strict_date_optional_time||epoch_millis"),
+            ZoneId.of("UTC"),
+            DateFieldMapper.Resolution.MILLISECONDS
+        );
+        double earliest = 1.602211285E12;
+        double latest = 1.602348325E12;
+        InternalMin minInternal = new InternalMin("min_timefield", earliest, dateFormat, new HashMap<>());
+        InternalMax maxInternal = new InternalMax("max_timefield", latest, dateFormat, new HashMap<>());
+        InternalAggregations internalAggregations = InternalAggregations.from(Arrays.asList(minInternal, maxInternal));
+        SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN);
+        SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1);
+
+        SearchResponse searchResponse = new SearchResponse(
+            searchSections,
+            null,
+            1,
+            1,
+            0,
+            11,
+            ShardSearchFailure.EMPTY_ARRAY,
+            SearchResponse.Clusters.EMPTY
+        );
+
+        doAnswer(invocation -> {
+            SearchRequest request = invocation.getArgument(0);
+            assertEquals(1, request.indices().length);
+            assertTrue(detector.getIndices().contains(request.indices()[0]));
+            AggregatorFactories.Builder aggs = request.source().aggregations();
+            assertEquals(2, aggs.count());
+            Collection<AggregationBuilder> factory = aggs.getAggregatorFactories();
+            assertTrue(!factory.isEmpty());
+            Iterator<AggregationBuilder> iterator = factory.iterator();
+            while (iterator.hasNext()) {
+                assertThat(iterator.next(), anyOf(instanceOf(MaxAggregationBuilder.class), instanceOf(MinAggregationBuilder.class)));
+            }
+
+            ActionListener<SearchResponse> listener = invocation.getArgument(1);
+            listener.onResponse(searchResponse);
+            return null;
+        }).when(client).search(any(SearchRequest.class), any(ActionListener.class));
+
+        ActionListener<Entry<Optional<Long>, Optional<Long>>> listener = mock(ActionListener.class);
+        searchFeatureDao.getEntityMinMaxDataTime(detector, "app_1", listener);
+
+        ArgumentCaptor<Entry<Optional<Long>, Optional<Long>>> captor = ArgumentCaptor.forClass(Entry.class);
+        verify(listener).onResponse(captor.capture());
+        Entry<Optional<Long>, Optional<Long>> result = captor.getValue();
+        assertEquals((long) earliest, result.getKey().get().longValue());
+        assertEquals((long) latest, result.getValue().get().longValue());
+    }
 }
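Note: tying the new pieces together, a hypothetical glue sketch showing how the per-entity vectors returned by getFeaturesByEntities can feed ParseUtils.getFeatureData; the scoring and persistence steps are assumptions outside this diff.

```java
searchFeatureDao.getFeaturesByEntities(detector, startMilli, endMilli, ActionListener.wrap(entityFeatures -> {
    for (Map.Entry<String, double[]> entity : entityFeatures.entrySet()) {
        // Pair each raw value with its feature id and name for result indexing.
        List<FeatureData> featureData = ParseUtils.getFeatureData(entity.getValue(), detector);
        // score entity.getKey() against its model and persist featureData with the result
    }
}, e -> logger.warn("Failed to get features for detector " + detector.getDetectorId(), e)));
```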