Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
change to exhausive search for training data (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnbts authored Jul 10, 2020
1 parent 4bf1818 commit cd3bdae
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ public Collection<Object> createComponents(
clock,
AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS,
AnomalyDetectorSettings.MIN_TRAIN_SAMPLES,
AnomalyDetectorSettings.SHINGLE_SIZE,
AnomalyDetectorSettings.MAX_MISSING_POINTS,
AnomalyDetectorSettings.MAX_NEIGHBOR_DISTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ public class CommonErrorMessages {
public static final String FEATURE_NOT_AVAILABLE_ERR_MSG = "No Feature in current detection window.";
public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = "AD memory circuit is broken.";
public static final String DISABLED_ERR_MSG = "AD plugin is disabled. To enable update opendistro.anomaly_detection.enabled to true";
public static final String INVALID_SEARCH_QUERY_MSG = "Invalid search query.";
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -41,6 +42,8 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
Expand All @@ -61,6 +64,8 @@ public class FeatureManager {

private final int maxTrainSamples;
private final int maxSampleStride;
private final int trainSampleTimeRangeInHours;
private final int minTrainSamples;
private final int shingleSize;
private final int maxMissingPoints;
private final int maxNeighborDistance;
Expand All @@ -76,6 +81,8 @@ public class FeatureManager {
* @param clock clock for system time
* @param maxTrainSamples max number of samples from search
* @param maxSampleStride max stride between uninterpolated train samples
* @param trainSampleTimeRangeInHours time range in hours for collect train samples
* @param minTrainSamples min number of train samples
* @param shingleSize size of feature shingles
* @param maxMissingPoints max number of missing points allowed to generate a shingle
* @param maxNeighborDistance max distance (number of intervals) between a missing point and a replacement neighbor
Expand All @@ -89,6 +96,8 @@ public FeatureManager(
Clock clock,
int maxTrainSamples,
int maxSampleStride,
int trainSampleTimeRangeInHours,
int minTrainSamples,
int shingleSize,
int maxMissingPoints,
int maxNeighborDistance,
Expand All @@ -101,6 +110,8 @@ public FeatureManager(
this.clock = clock;
this.maxTrainSamples = maxTrainSamples;
this.maxSampleStride = maxSampleStride;
this.trainSampleTimeRangeInHours = trainSampleTimeRangeInHours;
this.minTrainSamples = minTrainSamples;
this.shingleSize = shingleSize;
this.maxMissingPoints = maxMissingPoints;
this.maxNeighborDistance = maxNeighborDistance;
Expand Down Expand Up @@ -225,11 +236,11 @@ public Optional<double[][]> getColdStartData(AnomalyDetector detector) {
* Returns to listener data for cold-start training.
*
* Training data starts with getting samples from (costly) search.
* Samples are increased in size via interpolation and then
* in dimension via shingling.
* Samples are increased in dimension via shingling.
*
* @param detector contains data info (indices, documents, etc)
* @param listener onResponse is called with data for cold-start training, or empty if unavailable
* onFailure is called with EndRunException on feature query creation errors
*/
public void getColdStartData(AnomalyDetector detector, ActionListener<Optional<double[][]>> listener) {
searchFeatureDao
Expand All @@ -241,30 +252,75 @@ public void getColdStartData(AnomalyDetector detector, ActionListener<Optional<d

private void getColdStartSamples(Optional<Long> latest, AnomalyDetector detector, ActionListener<Optional<double[][]>> listener) {
if (latest.isPresent()) {
searchFeatureDao
.getFeaturesForSampledPeriods(
detector,
maxTrainSamples,
maxSampleStride,
latest.get(),
ActionListener.wrap(samples -> processColdStartSamples(samples, listener), listener::onFailure)
);
List<Entry<Long, Long>> sampleRanges = getColdStartSampleRanges(detector, latest.get());
try {
searchFeatureDao
.getFeatureSamplesForPeriods(
detector,
sampleRanges,
ActionListener.wrap(samples -> processColdStartSamples(samples, listener), listener::onFailure)
);
} catch (IOException e) {
listener.onFailure(new EndRunException(detector.getDetectorId(), CommonErrorMessages.INVALID_SEARCH_QUERY_MSG, e, true));
}
} else {
listener.onResponse(Optional.empty());
}
}

private void processColdStartSamples(Optional<Entry<double[][], Integer>> samples, ActionListener<Optional<double[][]>> listener) {
listener
.onResponse(
samples
.map(
results -> transpose(
interpolator.interpolate(transpose(results.getKey()), results.getValue() * (results.getKey().length - 1) + 1)
)
)
.map(points -> batchShingle(points, shingleSize))
private void processColdStartSamples(List<Optional<double[]>> samples, ActionListener<Optional<double[][]>> listener) {
List<double[]> shingles = new ArrayList<>();
LinkedList<Optional<double[]>> currentShingle = new LinkedList<>();
for (Optional<double[]> sample : samples) {
currentShingle.addLast(sample);
if (currentShingle.size() == this.shingleSize) {
sample.ifPresent(s -> fillAndShingle(currentShingle, this.shingleSize).ifPresent(shingles::add));
currentShingle.remove();
}
}
listener.onResponse(Optional.of(shingles.toArray(new double[0][0])).filter(results -> results.length > 0));
}

private Optional<double[]> fillAndShingle(LinkedList<Optional<double[]>> shingle, int shingleSize) {
Optional<double[]> result = null;
if (shingle.stream().filter(s -> s.isPresent()).count() >= shingleSize - this.maxMissingPoints) {
TreeMap<Integer, double[]> search = new TreeMap<>(
IntStream
.range(0, shingleSize)
.filter(i -> shingle.get(i).isPresent())
.boxed()
.collect(Collectors.toMap(i -> i, i -> shingle.get(i).get()))
);
result = Optional.of(IntStream.range(0, shingleSize).mapToObj(i -> {
Optional<Entry<Integer, double[]>> after = Optional.ofNullable(search.ceilingEntry(i));
Optional<Entry<Integer, double[]>> before = Optional.ofNullable(search.floorEntry(i));
return after
.filter(a -> Math.abs(i - a.getKey()) <= before.map(b -> Math.abs(i - b.getKey())).orElse(Integer.MAX_VALUE))
.map(Optional::of)
.orElse(before)
.filter(e -> Math.abs(i - e.getKey()) <= maxNeighborDistance)
.map(Entry::getValue)
.orElse(null);
}).filter(d -> d != null).toArray(double[][]::new))
.filter(d -> d.length == shingleSize)
.map(d -> batchShingle(d, shingleSize)[0]);
} else {
result = Optional.empty();
}
return result;
}

private List<Entry<Long, Long>> getColdStartSampleRanges(AnomalyDetector detector, long endMillis) {
long interval = getDetectionIntervalInMillis(detector);
int numSamples = Math.max((int) (Duration.ofHours(this.trainSampleTimeRangeInHours).toMillis() / interval), this.minTrainSamples);
return IntStream
.rangeClosed(1, numSamples)
.mapToObj(i -> new SimpleImmutableEntry<>(endMillis - (numSamples - i + 1) * interval, endMillis - (numSamples - i) * interval))
.collect(Collectors.toList());
}

private long getDetectionIntervalInMillis(AnomalyDetector detector) {
return ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ private AnomalyDetectorSettings() {}

public static final int MAX_SAMPLE_STRIDE = 64;

public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24;

public static final int MIN_TRAIN_SAMPLES = 512;

public static final int SHINGLE_SIZE = 8;

public static final int MAX_MISSING_POINTS = Math.min(2, SHINGLE_SIZE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.amazon.opendistroforelasticsearch.ad.feature;

import static java.util.Arrays.asList;
import static java.util.Optional.empty;
import static java.util.Optional.ofNullable;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
Expand All @@ -24,6 +25,7 @@
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -51,6 +53,7 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator;
import com.amazon.opendistroforelasticsearch.ad.dataprocessor.LinearUniformInterpolator;
import com.amazon.opendistroforelasticsearch.ad.dataprocessor.SingleFeatureLinearUniformInterpolator;
Expand All @@ -66,6 +69,8 @@ public class FeatureManagerTests {
// configuration
private int maxTrainSamples;
private int maxSampleStride;
private int trainSampleTimeRangeInHours;
private int minTrainSamples;
private int shingleSize;
private int maxMissingPoints;
private int maxNeighborDistance;
Expand Down Expand Up @@ -96,6 +101,8 @@ public void setup() {

maxTrainSamples = 24;
maxSampleStride = 100;
trainSampleTimeRangeInHours = 1;
minTrainSamples = 4;
shingleSize = 3;
maxMissingPoints = 2;
maxNeighborDistance = 2;
Expand All @@ -114,6 +121,8 @@ public void setup() {
clock,
maxTrainSamples,
maxSampleStride,
trainSampleTimeRangeInHours,
minTrainSamples,
shingleSize,
maxMissingPoints,
maxNeighborDistance,
Expand Down Expand Up @@ -151,41 +160,68 @@ public void getColdStartData_returnExpected(Long latestTime, Entry<double[][], I
assertTrue(Arrays.deepEquals(expected, results.orElse(null)));
}

private Object[] getTrainDataTestData() {
List<Entry<Long, Long>> ranges = asList(
entry(0L, 900_000L),
entry(900_000L, 1_800_000L),
entry(1_800_000L, 2_700_000L),
entry(2_700_000L, 3_600_000L)
);
return new Object[] {
new Object[] { 3_600_000L, ranges, asList(ar(1), ar(2), ar(3), ar(4)), new double[][] { { 1, 2, 3, 4 } } },
new Object[] { 3_600_000L, ranges, asList(ar(), ar(2), ar(3), ar(4)), new double[][] { { 2, 2, 3, 4 } } },
new Object[] { 3_600_000L, ranges, asList(ar(1), ar(), ar(3), ar(4)), new double[][] { { 1, 3, 3, 4 } } },
new Object[] { 3_600_000L, ranges, asList(ar(1), ar(2), ar(), ar(4)), new double[][] { { 1, 2, 4, 4 } } },
new Object[] { 3_600_000L, ranges, asList(ar(1), ar(), ar(), ar(4)), new double[][] { { 1, 1, 4, 4 } } },
new Object[] { 3_600_000L, ranges, asList(ar(), ar(2), ar(), ar(4)), new double[][] { { 2, 2, 4, 4 } } },
new Object[] { 3_600_000L, ranges, asList(ar(), ar(), ar(3), ar(4)), null },
new Object[] { 3_600_000L, ranges, asList(ar(1), empty(), empty(), empty()), null },
new Object[] { 3_600_000L, ranges, asList(empty(), empty(), empty(), ar(4)), null },
new Object[] { 3_600_000L, ranges, asList(empty(), empty(), empty(), empty()), null },
new Object[] { null, null, null, null } };
}

@Test
@SuppressWarnings("unchecked")
@Parameters(method = "getColdStartDataTestData")
@Parameters(method = "getTrainDataTestData")
public void getColdStartData_returnExpectedToListener(
Long latestTime,
Entry<double[][], Integer> data,
int interpolants,
List<Entry<Long, Long>> sampleRanges,
List<Optional<double[]>> samples,
double[][] expected
) {
) throws Exception {
when(detector.getDetectionInterval()).thenReturn(new IntervalTimeConfiguration(15, ChronoUnit.MINUTES));
doAnswer(invocation -> {
ActionListener<Optional<Long>> listener = invocation.getArgument(1);
listener.onResponse(Optional.ofNullable(latestTime));
return null;
}).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class));
if (latestTime != null) {
doAnswer(invocation -> {
ActionListener<Optional<Entry<double[][], Integer>>> listener = invocation.getArgument(4);
listener.onResponse(ofNullable(data));
ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(2);
listener.onResponse(samples);
return null;
})
.when(searchFeatureDao)
.getFeaturesForSampledPeriods(
eq(detector),
eq(maxTrainSamples),
eq(maxSampleStride),
eq(latestTime),
any(ActionListener.class)
);
}
if (data != null) {
when(interpolator.interpolate(argThat(new ArrayEqMatcher<>(data.getKey())), eq(interpolants))).thenReturn(data.getKey());
doReturn(data.getKey()).when(featureManager).batchShingle(argThat(new ArrayEqMatcher<>(data.getKey())), eq(shingleSize));
}).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), any(ActionListener.class));
}

ActionListener<Optional<double[][]>> listener = mock(ActionListener.class);
featureManager = spy(
new FeatureManager(
searchFeatureDao,
interpolator,
clock,
maxTrainSamples,
maxSampleStride,
trainSampleTimeRangeInHours,
minTrainSamples,
4, /*shingleSize*/
2, /*maxMissingPoints*/
1, /*maxNeighborDistance*/
previewSampleRate,
maxPreviewSamples,
featureBufferTtl
)
);
featureManager.getColdStartData(detector, listener);

ArgumentCaptor<Optional<double[][]>> captor = ArgumentCaptor.forClass(Optional.class);
Expand All @@ -209,6 +245,22 @@ public void getColdStartData_throwToListener_whenSearchFail() {
verify(listener).onFailure(any(Exception.class));
}

@Test
@SuppressWarnings("unchecked")
public void getColdStartData_throwToListener_onQueryCreationError() throws Exception {
doAnswer(invocation -> {
ActionListener<Optional<Long>> listener = invocation.getArgument(1);
listener.onResponse(Optional.ofNullable(0L));
return null;
}).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class));
doThrow(IOException.class).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(), any(ActionListener.class));

ActionListener<Optional<double[][]>> listener = mock(ActionListener.class);
featureManager.getColdStartData(detector, listener);

verify(listener).onFailure(any(EndRunException.class));
}

private Object[] batchShingleData() {
return new Object[] {
new Object[] { new double[][] { { 1.0 } }, 1, new double[][] { { 1.0 } } },
Expand Down Expand Up @@ -412,4 +464,16 @@ public void getPreviewFeatures_returnExceptionToListener_whenNoDataToPreview() t
public void getPreviewFeatures_returnExceptionToListener_whenQueryFail() throws IOException {
getPreviewFeaturesTemplate(asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 3 })), false, false);
}

private <K, V> Entry<K, V> entry(K key, V value) {
return new SimpleEntry<>(key, value);
}

private Optional<double[]> ar(double... values) {
if (values.length == 0) {
return Optional.empty();
} else {
return Optional.of(values);
}
}
}

0 comments on commit cd3bdae

Please sign in to comment.