From cd3bdaea465cabb0d72a76df14c5ea093a0d5080 Mon Sep 17 00:00:00 2001 From: Lai <57818076+wnbts@users.noreply.github.com> Date: Thu, 9 Jul 2020 17:26:55 -0700 Subject: [PATCH] change to exhaustive search for training data (#184) --- .../ad/AnomalyDetectorPlugin.java | 2 + .../ad/constant/CommonErrorMessages.java | 1 + .../ad/feature/FeatureManager.java | 96 +++++++++++++---- .../ad/settings/AnomalyDetectorSettings.java | 4 + .../ad/feature/FeatureManagerTests.java | 102 ++++++++++++++---- 5 files changed, 166 insertions(+), 39 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index 12c8759e..0acff42e 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -317,6 +317,8 @@ public Collection createComponents( clock, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.SHINGLE_SIZE, AnomalyDetectorSettings.MAX_MISSING_POINTS, AnomalyDetectorSettings.MAX_NEIGHBOR_DISTANCE, diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonErrorMessages.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonErrorMessages.java index bc9c040e..44361d17 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonErrorMessages.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonErrorMessages.java @@ -25,4 +25,5 @@ public class CommonErrorMessages { public static final String FEATURE_NOT_AVAILABLE_ERR_MSG = "No Feature in current detection window."; public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = "AD memory circuit is broken."; public static final 
String DISABLED_ERR_MSG = "AD plugin is disabled. To enable update opendistro.anomaly_detection.enabled to true"; + public static final String INVALID_SEARCH_QUERY_MSG = "Invalid search query."; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java index fc53ad5a..6aef9e4b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java @@ -27,6 +27,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -41,6 +42,8 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages; import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration; @@ -61,6 +64,8 @@ public class FeatureManager { private final int maxTrainSamples; private final int maxSampleStride; + private final int trainSampleTimeRangeInHours; + private final int minTrainSamples; private final int shingleSize; private final int maxMissingPoints; private final int maxNeighborDistance; @@ -76,6 +81,8 @@ public class FeatureManager { * @param clock clock for system time * @param maxTrainSamples max number of samples from search * @param maxSampleStride max stride between uninterpolated train samples + * @param trainSampleTimeRangeInHours time range in hours for collecting train samples + * @param minTrainSamples min number of train samples * @param shingleSize size of feature shingles 
* @param maxMissingPoints max number of missing points allowed to generate a shingle * @param maxNeighborDistance max distance (number of intervals) between a missing point and a replacement neighbor @@ -89,6 +96,8 @@ public FeatureManager( Clock clock, int maxTrainSamples, int maxSampleStride, + int trainSampleTimeRangeInHours, + int minTrainSamples, int shingleSize, int maxMissingPoints, int maxNeighborDistance, @@ -101,6 +110,8 @@ public FeatureManager( this.clock = clock; this.maxTrainSamples = maxTrainSamples; this.maxSampleStride = maxSampleStride; + this.trainSampleTimeRangeInHours = trainSampleTimeRangeInHours; + this.minTrainSamples = minTrainSamples; this.shingleSize = shingleSize; this.maxMissingPoints = maxMissingPoints; this.maxNeighborDistance = maxNeighborDistance; @@ -225,11 +236,11 @@ public Optional getColdStartData(AnomalyDetector detector) { * Returns to listener data for cold-start training. * * Training data starts with getting samples from (costly) search. - * Samples are increased in size via interpolation and then - * in dimension via shingling. + * Samples are increased in dimension via shingling. 
* * @param detector contains data info (indices, documents, etc) * @param listener onResponse is called with data for cold-start training, or empty if unavailable + * onFailure is called with EndRunException on feature query creation errors */ public void getColdStartData(AnomalyDetector detector, ActionListener> listener) { searchFeatureDao @@ -241,30 +252,75 @@ public void getColdStartData(AnomalyDetector detector, ActionListener latest, AnomalyDetector detector, ActionListener> listener) { if (latest.isPresent()) { - searchFeatureDao - .getFeaturesForSampledPeriods( - detector, - maxTrainSamples, - maxSampleStride, - latest.get(), - ActionListener.wrap(samples -> processColdStartSamples(samples, listener), listener::onFailure) - ); + List> sampleRanges = getColdStartSampleRanges(detector, latest.get()); + try { + searchFeatureDao + .getFeatureSamplesForPeriods( + detector, + sampleRanges, + ActionListener.wrap(samples -> processColdStartSamples(samples, listener), listener::onFailure) + ); + } catch (IOException e) { + listener.onFailure(new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } } else { listener.onResponse(Optional.empty()); } } - private void processColdStartSamples(Optional> samples, ActionListener> listener) { - listener - .onResponse( - samples - .map( - results -> transpose( - interpolator.interpolate(transpose(results.getKey()), results.getValue() * (results.getKey().length - 1) + 1) - ) - ) - .map(points -> batchShingle(points, shingleSize)) + private void processColdStartSamples(List> samples, ActionListener> listener) { + List shingles = new ArrayList<>(); + LinkedList> currentShingle = new LinkedList<>(); + for (Optional sample : samples) { + currentShingle.addLast(sample); + if (currentShingle.size() == this.shingleSize) { + sample.ifPresent(s -> fillAndShingle(currentShingle, this.shingleSize).ifPresent(shingles::add)); + currentShingle.remove(); + } + } + 
listener.onResponse(Optional.of(shingles.toArray(new double[0][0])).filter(results -> results.length > 0)); + } + + private Optional fillAndShingle(LinkedList> shingle, int shingleSize) { + Optional result = null; + if (shingle.stream().filter(s -> s.isPresent()).count() >= shingleSize - this.maxMissingPoints) { + TreeMap search = new TreeMap<>( + IntStream + .range(0, shingleSize) + .filter(i -> shingle.get(i).isPresent()) + .boxed() + .collect(Collectors.toMap(i -> i, i -> shingle.get(i).get())) ); + result = Optional.of(IntStream.range(0, shingleSize).mapToObj(i -> { + Optional> after = Optional.ofNullable(search.ceilingEntry(i)); + Optional> before = Optional.ofNullable(search.floorEntry(i)); + return after + .filter(a -> Math.abs(i - a.getKey()) <= before.map(b -> Math.abs(i - b.getKey())).orElse(Integer.MAX_VALUE)) + .map(Optional::of) + .orElse(before) + .filter(e -> Math.abs(i - e.getKey()) <= maxNeighborDistance) + .map(Entry::getValue) + .orElse(null); + }).filter(d -> d != null).toArray(double[][]::new)) + .filter(d -> d.length == shingleSize) + .map(d -> batchShingle(d, shingleSize)[0]); + } else { + result = Optional.empty(); + } + return result; + } + + private List> getColdStartSampleRanges(AnomalyDetector detector, long endMillis) { + long interval = getDetectionIntervalInMillis(detector); + int numSamples = Math.max((int) (Duration.ofHours(this.trainSampleTimeRangeInHours).toMillis() / interval), this.minTrainSamples); + return IntStream + .rangeClosed(1, numSamples) + .mapToObj(i -> new SimpleImmutableEntry<>(endMillis - (numSamples - i + 1) * interval, endMillis - (numSamples - i) * interval)) + .collect(Collectors.toList()); + } + + private long getDetectionIntervalInMillis(AnomalyDetector detector) { + return ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis(); } /** diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java 
b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java index fd84fe50..a0067e93 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java @@ -192,6 +192,10 @@ private AnomalyDetectorSettings() {} public static final int MAX_SAMPLE_STRIDE = 64; + public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24; + + public static final int MIN_TRAIN_SAMPLES = 512; + public static final int SHINGLE_SIZE = 8; public static final int MAX_MISSING_POINTS = Math.min(2, SHINGLE_SIZE); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java index b6ed0e31..0d14105d 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManagerTests.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.ad.feature; import static java.util.Arrays.asList; +import static java.util.Optional.empty; import static java.util.Optional.ofNullable; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -24,6 +25,7 @@ import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -51,6 +53,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException; import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator; import 
com.amazon.opendistroforelasticsearch.ad.dataprocessor.LinearUniformInterpolator; import com.amazon.opendistroforelasticsearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator; @@ -66,6 +69,8 @@ public class FeatureManagerTests { // configuration private int maxTrainSamples; private int maxSampleStride; + private int trainSampleTimeRangeInHours; + private int minTrainSamples; private int shingleSize; private int maxMissingPoints; private int maxNeighborDistance; @@ -96,6 +101,8 @@ public void setup() { maxTrainSamples = 24; maxSampleStride = 100; + trainSampleTimeRangeInHours = 1; + minTrainSamples = 4; shingleSize = 3; maxMissingPoints = 2; maxNeighborDistance = 2; @@ -114,6 +121,8 @@ public void setup() { clock, maxTrainSamples, maxSampleStride, + trainSampleTimeRangeInHours, + minTrainSamples, shingleSize, maxMissingPoints, maxNeighborDistance, @@ -151,15 +160,37 @@ public void getColdStartData_returnExpected(Long latestTime, Entry> ranges = asList( + entry(0L, 900_000L), + entry(900_000L, 1_800_000L), + entry(1_800_000L, 2_700_000L), + entry(2_700_000L, 3_600_000L) + ); + return new Object[] { + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(2), ar(3), ar(4)), new double[][] { { 1, 2, 3, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(), ar(2), ar(3), ar(4)), new double[][] { { 2, 2, 3, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(), ar(3), ar(4)), new double[][] { { 1, 3, 3, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(2), ar(), ar(4)), new double[][] { { 1, 2, 4, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(), ar(), ar(4)), new double[][] { { 1, 1, 4, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(), ar(2), ar(), ar(4)), new double[][] { { 2, 2, 4, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(), ar(), ar(3), ar(4)), null }, + new Object[] { 3_600_000L, ranges, asList(ar(1), empty(), empty(), empty()), null }, + new Object[] { 3_600_000L, ranges, asList(empty(), 
empty(), empty(), ar(4)), null }, + new Object[] { 3_600_000L, ranges, asList(empty(), empty(), empty(), empty()), null }, + new Object[] { null, null, null, null } }; + } + @Test @SuppressWarnings("unchecked") - @Parameters(method = "getColdStartDataTestData") + @Parameters(method = "getTrainDataTestData") public void getColdStartData_returnExpectedToListener( Long latestTime, - Entry data, - int interpolants, + List> sampleRanges, + List> samples, double[][] expected - ) { + ) throws Exception { + when(detector.getDetectionInterval()).thenReturn(new IntervalTimeConfiguration(15, ChronoUnit.MINUTES)); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.ofNullable(latestTime)); @@ -167,25 +198,30 @@ public void getColdStartData_returnExpectedToListener( }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); if (latestTime != null) { doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(4); - listener.onResponse(ofNullable(data)); + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(samples); return null; - }) - .when(searchFeatureDao) - .getFeaturesForSampledPeriods( - eq(detector), - eq(maxTrainSamples), - eq(maxSampleStride), - eq(latestTime), - any(ActionListener.class) - ); - } - if (data != null) { - when(interpolator.interpolate(argThat(new ArrayEqMatcher<>(data.getKey())), eq(interpolants))).thenReturn(data.getKey()); - doReturn(data.getKey()).when(featureManager).batchShingle(argThat(new ArrayEqMatcher<>(data.getKey())), eq(shingleSize)); + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), any(ActionListener.class)); } ActionListener> listener = mock(ActionListener.class); + featureManager = spy( + new FeatureManager( + searchFeatureDao, + interpolator, + clock, + maxTrainSamples, + maxSampleStride, + trainSampleTimeRangeInHours, + minTrainSamples, + 4, /*shingleSize*/ + 2, 
/*maxMissingPoints*/ + 1, /*maxNeighborDistance*/ + previewSampleRate, + maxPreviewSamples, + featureBufferTtl + ) + ); featureManager.getColdStartData(detector, listener); ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); @@ -209,6 +245,22 @@ public void getColdStartData_throwToListener_whenSearchFail() { verify(listener).onFailure(any(Exception.class)); } + @Test + @SuppressWarnings("unchecked") + public void getColdStartData_throwToListener_onQueryCreationError() throws Exception { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.ofNullable(0L)); + return null; + }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + doThrow(IOException.class).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + featureManager.getColdStartData(detector, listener); + + verify(listener).onFailure(any(EndRunException.class)); + } + private Object[] batchShingleData() { return new Object[] { new Object[] { new double[][] { { 1.0 } }, 1, new double[][] { { 1.0 } } }, @@ -412,4 +464,16 @@ public void getPreviewFeatures_returnExceptionToListener_whenNoDataToPreview() t public void getPreviewFeatures_returnExceptionToListener_whenQueryFail() throws IOException { getPreviewFeaturesTemplate(asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 3 })), false, false); } + + private Entry entry(K key, V value) { + return new SimpleEntry<>(key, value); + } + + private Optional ar(double... values) { + if (values.length == 0) { + return Optional.empty(); + } else { + return Optional.of(values); + } + } }