Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Add data structure for entity model
Browse files Browse the repository at this point in the history
This PR adds a new kind of model state called EntityModel.  An entity state contains RCF and threshold models, priority, and recent sets of values (128 sets at most) arising at AD job run time and their start/end timestamps.

Testing done:
1. These data structures are used throughout our code.  They are covered by their callers' unit tests.
  • Loading branch information
kaituo committed Oct 15, 2020
1 parent 2175fde commit 15c2fc0
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.ml;

import java.util.Queue;

import com.amazon.randomcutforest.RandomCutForest;

public class EntityModel {
private String modelId;
// TODO: sample should record timestamp
private Queue<double[]> samples;
private RandomCutForest rcf;
private ThresholdingModel threshold;

public EntityModel(String modelId, Queue<double[]> samples, RandomCutForest rcf, ThresholdingModel threshold) {
this.modelId = modelId;
this.samples = samples;
this.rcf = rcf;
this.threshold = threshold;
}

public String getModelId() {
return this.modelId;
}

public Queue<double[]> getSamples() {
return this.samples;
}

public void addSample(double[] sample) {
this.samples.add(sample);
}

public RandomCutForest getRcf() {
return this.rcf;
}

public ThresholdingModel getThreshold() {
return this.threshold;
}

public void setRcf(RandomCutForest rcf) {
this.rcf = rcf;
}

public void setThreshold(ThresholdingModel threshold) {
this.threshold = threshold;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@

package com.amazon.opendistroforelasticsearch.ad.ml;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import com.amazon.opendistroforelasticsearch.ad.ExpiringState;

/**
* A ML model and states such as usage.
*/
public class ModelState<T> {
public class ModelState<T> implements ExpiringState {

public static String MODEL_ID_KEY = "model_id";
public static String DETECTOR_ID_KEY = "detector_id";
public static String MODEL_TYPE_KEY = "model_type";
public static String LAST_USED_TIME_KEY = "last_used_time";
public static String LAST_CHECKPOINT_TIME_KEY = "last_checkpoint_time";
public static String PRIORITY = "priority";

private T model;
private String modelId;
private String detectorId;
private String modelType;
// time when the ML model was used last time
private Instant lastUsedTime;
private Instant lastCheckpointTime;
private Clock clock;
private float priority;

/**
* Constructor.
Expand All @@ -44,15 +52,41 @@ public class ModelState<T> {
* @param modelId Id of model partition
* @param detectorId Id of detector this model partition is used for
* @param modelType type of model
* @param lastUsedTime time when the ML model was used last time
* @param clock UTC clock
* @param priority Priority of the model state. Used in multi-entity detectors' cache.
*/
public ModelState(T model, String modelId, String detectorId, String modelType, Instant lastUsedTime) {
public ModelState(T model, String modelId, String detectorId, String modelType, Clock clock, float priority) {
this.model = model;
this.modelId = modelId;
this.detectorId = detectorId;
this.modelType = modelType;
this.lastUsedTime = lastUsedTime;
this.lastUsedTime = clock.instant();
// this is inaccurate until we find the last checkpoint time from disk
this.lastCheckpointTime = Instant.MIN;
this.clock = clock;
this.priority = priority;
}

/**
* Create state with zero priority. Used in single-entity detector.
*
* @param <T> Model object's type
* @param model The actual model object
* @param modelId Model Id
* @param detectorId Detector Id
* @param modelType Model type like RCF model
* @param clock UTC clock
*
* @return the created model state
*/
public static <T> ModelState<T> createSingleEntityModelState(
T model,
String modelId,
String detectorId,
String modelType,
Clock clock
) {
return new ModelState<>(model, modelId, detectorId, modelType, clock, 0f);
}

/**
Expand All @@ -64,6 +98,10 @@ public T getModel() {
return this.model;
}

public void setModel(T model) {
this.model = model;
}

/**
* Gets the model ID
*
Expand Down Expand Up @@ -127,6 +165,18 @@ public void setLastCheckpointTime(Instant lastCheckpointTime) {
this.lastCheckpointTime = lastCheckpointTime;
}

/**
* Returns priority of the ModelState
* @return the priority
*/
public float getPriority() {
return priority;
}

public void setPriority(float priority) {
this.priority = priority;
}

/**
* Gets the Model State as a map
*
Expand All @@ -140,7 +190,13 @@ public Map<String, Object> getModelStateAsMap() {
put(MODEL_TYPE_KEY, modelType);
put(LAST_USED_TIME_KEY, lastUsedTime);
put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime);
put(PRIORITY, priority);
}
};
}

@Override
public boolean expired(Duration stateTtl) {
return expired(lastUsedTime, stateTtl, clock.instant());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,10 @@ public void setup() {
expectedResults = new ArrayList<>(
Arrays
.asList(
new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock.instant()),
new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock.instant()),
new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock.instant()),
new ModelState<>(
thresholdingModel,
"thr-model-2",
"detector-2",
ModelManager.ModelType.THRESHOLD.getName(),
clock.instant()
)
new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f),
new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f),
new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f),
new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f)
)
);

Expand Down

0 comments on commit 15c2fc0

Please sign in to comment.