Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Fix issue where data hole exists for Preview API (#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
yizheliu-amazon authored Nov 19, 2020
1 parent 28410da commit 5d30bb9
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,13 @@ void getSamplesInRangesForEntity(
ActionListener<Entry<List<Entry<Long, Long>>, double[][]>> listener
) throws IOException {
searchFeatureDao
.getColdStartSamplesForPeriods(detector, sampleRanges, entity.getValue(), getSamplesRangesListener(sampleRanges, listener));
.getColdStartSamplesForPeriods(
detector,
sampleRanges,
entity.getValue(),
true,
getSamplesRangesListener(sampleRanges, listener)
);
}

private ActionListener<List<Optional<double[]>>> getSamplesRangesListener(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ public void getColdStartSamplesForPeriods(
AnomalyDetector detector,
List<Entry<Long, Long>> ranges,
String entityName,
boolean includesEmptyBucket,
ActionListener<List<Optional<double[]>>> listener
) throws IOException {
SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entityName);
Expand All @@ -712,6 +713,8 @@ public void getColdStartSamplesForPeriods(
return;
}

long docCountThreshold = includesEmptyBucket ? -1 : 0;

// Extract buckets and order by from_as_string. Currently by default it is ascending. Better not to assume it.
// Example responses from date range bucket aggregation:
// "aggregations":{"date_range":{"buckets":[{"key":"1598865166000-1598865226000","from":1.598865166E12,"
Expand All @@ -730,7 +733,7 @@ public void getColdStartSamplesForPeriods(
.filter(InternalDateRange.class::isInstance)
.flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream())
.filter(bucket -> bucket.getFrom() != null)
.filter(bucket -> bucket.getDocCount() > 0)
.filter(bucket -> bucket.getDocCount() > docCountThreshold)
.sorted(Comparator.comparing((Bucket bucket) -> Long.valueOf(bucket.getFromAsString())))
.map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds()))
.collect(Collectors.toList())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ private void getEntityColdStartData(
detector,
sampleRanges,
entityName,
false,
new ThreadedActionListener<>(
logger,
threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.argThat;
Expand Down Expand Up @@ -517,10 +518,10 @@ public void getPreviewFeatureForEntity() throws IOException {
coldStartSamples.add(Optional.of(new double[] { 30.0 }));

doAnswer(invocation -> {
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(3);
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
listener.onResponse(coldStartSamples);
return null;
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), any());
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

ActionListener<Features> listener = mock(ActionListener.class);

Expand All @@ -541,10 +542,10 @@ public void getPreviewFeatureForEntity_noDataToPreview() throws IOException {
Entity entity = new Entity("fieldName", "value");

doAnswer(invocation -> {
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(3);
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
listener.onResponse(new ArrayList<>());
return null;
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), any());
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

ActionListener<Features> listener = mock(ActionListener.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.amazon.opendistroforelasticsearch.ad.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -202,10 +203,10 @@ public void testColdStart() throws InterruptedException, IOException {
coldStartSamples.add(Optional.of(new double[] { 1.0 }));
coldStartSamples.add(Optional.of(new double[] { -19.0 }));
doAnswer(invocation -> {
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(3);
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
listener.onResponse(coldStartSamples);
return null;
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), any());
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState);

Expand Down Expand Up @@ -268,7 +269,7 @@ public void testMissMin() throws IOException {

entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState);

verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), any());
verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

RandomCutForest forest = model.getRcf();
assertTrue(forest == null);
Expand All @@ -294,10 +295,10 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc
coldStartSamples.add(Optional.empty());
coldStartSamples.add(Optional.of(new double[] { -17.0 }));
doAnswer(invocation -> {
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(3);
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
listener.onResponse(coldStartSamples);
return null;
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), any());
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState);

Expand Down Expand Up @@ -336,10 +337,10 @@ public void testTwoSegments() throws InterruptedException, IOException {
coldStartSamples.add(Optional.of(new double[] { -17.0 }));
coldStartSamples.add(Optional.of(new double[] { -38.0 }));
doAnswer(invocation -> {
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(3);
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
listener.onResponse(coldStartSamples);
return null;
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), any());
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState);

Expand Down Expand Up @@ -408,10 +409,10 @@ public void testNotEnoughSamples() throws InterruptedException, IOException {
coldStartSamples.add(Optional.of(new double[] { 57.0 }));
coldStartSamples.add(Optional.of(new double[] { 1.0 }));
doAnswer(invocation -> {
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(3);
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
listener.onResponse(coldStartSamples);
return null;
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), any());
}).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState);

Expand Down

0 comments on commit 5d30bb9

Please sign in to comment.