Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Add transport action for model inference #249

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedResponse;

/**
 * Action type for the entity result transport action.
 *
 * <p>The action is registered under {@link #NAME} and answers with an
 * {@link AcknowledgedResponse}. Use the shared {@link #INSTANCE}; the
 * constructor is private.
 */
public class EntityResultAction extends ActionType<AcknowledgedResponse> {
    /** Name under which this action is registered. */
    public static final String NAME = "cluster:admin/opendistro/ad/entity/result";

    /** The single shared instance of this action type. */
    public static final EntityResultAction INSTANCE = new EntityResultAction();

    // Not instantiable from outside: callers go through INSTANCE.
    private EntityResultAction() {
        super(NAME, in -> new AcknowledgedResponse(in));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;

import static org.elasticsearch.action.ValidateActions.addValidationError;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonMessageAttributes;

/**
 * Transport request carrying the feature values of a batch of entities for one
 * detector and one data interval.
 *
 * <p>The wire format is: detector id, the entity map, start, end. The stream
 * constructor and {@link #writeTo(StreamOutput)} must stay in that order.
 */
public class EntityResultRequest extends ActionRequest implements ToXContentObject {

    private final String detectorId;
    // entity name -> feature values of that entity for the interval
    private final Map<String, double[]> entities;
    // interval boundaries; validate() requires 0 < start <= end
    private final long start;
    private final long end;

    /**
     * Reads a request from the wire.
     *
     * @param in stream positioned at a serialized request
     * @throws IOException if the stream cannot be read
     */
    public EntityResultRequest(StreamInput in) throws IOException {
        super(in);
        this.detectorId = in.readString();
        this.entities = in.readMap(StreamInput::readString, StreamInput::readDoubleArray);
        this.start = in.readLong();
        this.end = in.readLong();
    }

    /**
     * Creates a request from its parts. Nothing is checked here; call
     * {@link #validate()} to check the values.
     *
     * @param detectorId id of the detector the entities belong to
     * @param entities entity name mapped to that entity's feature values
     * @param start start of the data interval
     * @param end end of the data interval
     */
    public EntityResultRequest(String detectorId, Map<String, double[]> entities, long start, long end) {
        super();
        this.detectorId = detectorId;
        this.entities = entities;
        this.start = start;
        this.end = end;
    }

    /** @return the detector id */
    public String getDetectorId() {
        return this.detectorId;
    }

    /** @return the entity map itself (not a copy) */
    public Map<String, double[]> getEntities() {
        return this.entities;
    }

    /** @return start of the data interval */
    public long getStart() {
        return this.start;
    }

    /** @return end of the data interval */
    public long getEnd() {
        return this.end;
    }

    /**
     * Serializes the request in the order the stream constructor reads it.
     *
     * @param out destination stream
     * @throws IOException if the stream cannot be written
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(this.detectorId);
        out.writeMap(this.entities, StreamOutput::writeString, StreamOutput::writeDoubleArray);
        out.writeLong(this.start);
        out.writeLong(this.end);
    }

    /**
     * Checks that the detector id is present, the entity map is not null, and
     * the interval is positive and ordered.
     *
     * @return the accumulated validation errors, or {@code null} if the request is valid
     */
    @Override
    public ActionRequestValidationException validate() {
        ActionRequestValidationException validationException = null;
        if (Strings.isEmpty(detectorId)) {
            validationException = addValidationError(CommonErrorMessages.AD_ID_MISSING_MSG, validationException);
        }
        // A null map would otherwise surface later as a bare NullPointerException
        // in writeTo/toXContent or when the entities are iterated.
        if (entities == null) {
            validationException = addValidationError("Entities are missing", validationException);
        }
        if (start <= 0 || end <= 0 || start > end) {
            validationException = addValidationError(
                String.format(Locale.ROOT, "%s: start %d, end %d", CommonErrorMessages.INVALID_TIMESTAMP_ERR_MSG, start, end),
                validationException
            );
        }
        return validationException;
    }

    /**
     * Renders the request as one object: id, start, end, then one field per
     * entity whose key is the entity name and whose value is its feature array.
     *
     * @param builder builder to write into
     * @param params unused
     * @return the same builder
     * @throws IOException if the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(CommonMessageAttributes.ID_JSON_KEY, detectorId);
        builder.field(CommonMessageAttributes.START_JSON_KEY, start);
        builder.field(CommonMessageAttributes.END_JSON_KEY, end);
        // NOTE(review): entity names share the object with the id/start/end keys;
        // an entity named like one of those keys would collide - confirm this cannot happen.
        for (Map.Entry<String, double[]> entity : entities.entrySet()) {
            builder.field(entity.getKey(), entity.getValue());
        }
        builder.endObject();
        return builder;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;

import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Map.Entry;
import java.util.Optional;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;

import com.amazon.opendistroforelasticsearch.ad.NodeStateManager;
import com.amazon.opendistroforelasticsearch.ad.breaker.ADCircuitBreakerService;
import com.amazon.opendistroforelasticsearch.ad.caching.CacheProvider;
import com.amazon.opendistroforelasticsearch.ad.caching.EntityCache;
import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao;
import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelState;
import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingResult;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult;
import com.amazon.opendistroforelasticsearch.ad.model.Entity;
import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
import com.amazon.opendistroforelasticsearch.ad.transport.handler.MultitiEntityResultHandler;
import com.amazon.opendistroforelasticsearch.ad.util.ParseUtils;

public class EntityResultTransportAction extends HandledTransportAction<EntityResultRequest, AcknowledgedResponse> {

private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class);
private ModelManager manager;
private ADCircuitBreakerService adCircuitBreakerService;
private MultitiEntityResultHandler anomalyResultHandler;
private CheckpointDao checkpointDao;
private EntityCache cache;
private final NodeStateManager stateManager;
private final int coolDownMinutes;
private final Clock clock;

/**
 * Wires the transport action to its collaborators and registers it under
 * {@link EntityResultAction#NAME}.
 *
 * @param actionFilters action filters passed to the parent handler
 * @param transportService transport service the action is registered with
 * @param manager model manager used to score entities
 * @param adCircuitBreakerService breaker consulted before any work is done
 * @param anomalyResultHandler handler that collects and flushes entity results
 * @param checkpointDao checkpoint DAO flushed after a batch is processed
 * @param entityCache cache the entity models are looked up in
 * @param stateManager source of the detector and the last index-throttled time
 * @param settings settings from which the cooldown (in minutes) is read
 * @param clock clock used for the cooldown comparison
 */
@Inject
public EntityResultTransportAction(
    ActionFilters actionFilters,
    TransportService transportService,
    ModelManager manager,
    ADCircuitBreakerService adCircuitBreakerService,
    MultitiEntityResultHandler anomalyResultHandler,
    CheckpointDao checkpointDao,
    CacheProvider entityCache,
    NodeStateManager stateManager,
    Settings settings,
    Clock clock
) {
    super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new);

    // throttling / cooldown state
    this.clock = clock;
    this.stateManager = stateManager;
    this.coolDownMinutes = (int) COOLDOWN_MINUTES.get(settings).getMinutes();

    // model scoring
    this.cache = entityCache;
    this.manager = manager;
    this.adCircuitBreakerService = adCircuitBreakerService;

    // persistence of results and checkpoints
    this.checkpointDao = checkpointDao;
    this.anomalyResultHandler = anomalyResultHandler;
}

/**
 * Entry point of the action. Fails immediately with a
 * {@link LimitExceededException} when the circuit breaker is open; otherwise
 * asks the state manager for the detector and continues in the listener built
 * by {@code onGetDetector}. Any exception thrown while starting that lookup is
 * logged and forwarded to the listener.
 *
 * @param task the task this execution belongs to (not used here)
 * @param request entities to score together with the detector id and interval
 * @param listener receives the acknowledgement or the failure
 */
@Override
protected void doExecute(Task task, EntityResultRequest request, ActionListener<AcknowledgedResponse> listener) {
    if (adCircuitBreakerService.isOpen()) {
        LimitExceededException breakerOpen = new LimitExceededException(
            request.getDetectorId(),
            CommonErrorMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG
        );
        listener.onFailure(breakerOpen);
        return;
    }

    try {
        final String id = request.getDetectorId();
        ActionListener<Optional<AnomalyDetector>> detectorListener = onGetDetector(listener, id, request);
        stateManager.getAnomalyDetector(id, detectorListener);
    } catch (Exception e) {
        LOG.error("fail to get entity's anomaly grade", e);
        listener.onFailure(e);
    }
}

/**
 * Builds the listener that continues the request once the detector lookup returns.
 *
 * On success: if the detector is absent, the caller's listener fails with an
 * EndRunException. Otherwise every entity in the request is visited; entities
 * whose name exceeds AnomalyDetectorSettings.MAX_ENTITY_LENGTH or whose model
 * is not returned by the cache are skipped. The remaining ones are scored, and
 * qualifying results are added to one bulk request. The bulk request and the
 * checkpoint DAO are then flushed and the caller's listener is acknowledged.
 * Neither flush call takes a listener, so the acknowledgement does not
 * reflect the outcome of the flushes.
 *
 * On failure: the exception is logged with the detector id and interval and
 * forwarded to the caller's listener.
 *
 * @param listener listener of the transport action to complete
 * @param detectorId id of the detector being processed
 * @param request request holding the entities and the data interval
 * @return listener to hand to the detector lookup
 */
private ActionListener<Optional<AnomalyDetector>> onGetDetector(
ActionListener<AcknowledgedResponse> listener,
String detectorId,
EntityResultRequest request
) {
return ActionListener.wrap(detectorOptional -> {
if (!detectorOptional.isPresent()) {
listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", true));
return;
}

AnomalyDetector detector = detectorOptional.get();
// we only support 1 categorical field now
String categoricalField = detector.getCategoryField().get(0);

// one bulk request collects the results of all entities in this request
ADResultBulkRequest currentBulkRequest = new ADResultBulkRequest();
// index pressure is high. Only save anomalies
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, how did we determine the index pressure is high?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We get exceptions from ES if index pressure is high.

// true while the injected clock's current instant is still before
// (last index-throttled time + coolDownMinutes)
boolean onlySaveAnomalies = stateManager
.getLastIndexThrottledTime()
.plus(Duration.ofMinutes(coolDownMinutes))
.isAfter(clock.instant());

// NOTE(review): this and the execution end time below use Instant.now(), while the
// cooldown check above uses the injected clock - confirm whether that is intended.
Instant executionStartTime = Instant.now();
for (Entry<String, double[]> entity : request.getEntities().entrySet()) {
String entityName = entity.getKey();
// For ES, the limit of the document ID is 512 bytes.
// skip an entity if the entity's name is more than 256 characters
// since we are using it as part of document id.
if (entityName.length() > AnomalyDetectorSettings.MAX_ENTITY_LENGTH) {
continue;
}

double[] datapoint = entity.getValue();
String modelId = manager.getEntityModelId(detectorId, entityName);
// the lookup is also given the detector, datapoint and entity name
ModelState<EntityModel> entityModel = cache.get(modelId, detector, datapoint, entityName);
if (entityModel == null) {
// cache miss
continue;
}
ThresholdingResult result = manager.getAnomalyResultForEntity(detectorId, datapoint, entityName, entityModel, modelId);
// result.getRcfScore() = 0 means the model is not initialized
// result.getGrade() = 0 means it is not an anomaly
// So many EsRejectedExecutionException if we write no matter what
if (result.getRcfScore() > 0 && (!onlySaveAnomalies || result.getGrade() > 0)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If model not initialized (rcfScore == 0) for a long time such as not enough data/some error, how does user know what's going on if we don't write AD result ? The init progress bar is only for most active entity.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the entity is in cache, profile can call getTotalUpdates(String detectorId, String entityId).

If not, profile API has to go to a checkpoint, load it to memory, and check its total updates.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cache/checkpoint will be cleared periodically. Have some concern for Ops, if user want to know what happens, we can only rely on service log. We can address this in next phase, it's ok if we have enough logs for now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed.

this.anomalyResultHandler
.write(
new AnomalyResult(
detectorId,
result.getRcfScore(),
result.getGrade(),
result.getConfidence(),
ParseUtils.getFeatureData(datapoint, detector),
Instant.ofEpochMilli(request.getStart()),
Instant.ofEpochMilli(request.getEnd()),
executionStartTime,
Instant.now(),
null,
Arrays.asList(new Entity(categoricalField, entityName))
),
currentBulkRequest
);
}
}
this.anomalyResultHandler.flush(currentBulkRequest, detectorId);
// bulk all accumulated checkpoint requests
this.checkpointDao.flush();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about add listener as param of flush function ? If flush successfully, execute listener.onResponse(new AcknowledgedResponse(true)); , if fail, execute listener.onFailure. Otherwise user never know the flush succeed or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't do it because:
First, flush is asynchronous. We don't know when it is gonna finish. In our performance testing, we find our queue has 70 k checkpoints. Hold the job too long may fail the following jobs.
Second, I don't know what's the action item for customers if they know their checkpoints fail. Mostly they are agnostic to checkpoints.

Copy link
Contributor

@ylwu-amzn ylwu-amzn Oct 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't know when it is gonna finish. In our performance testing, we find our queue has 70 k checkpoints. Hold the job too long may fail the following jobs.

Make sense. Do we catch exception in checkpointDao and write in AD result ?

what's the action item for customers if they know their checkpoints fail

User know the system is under pressure, they can
1.Scale up/out their cluster
2.Stop some testing/low priority detector
3.Tune detector configuration like use less features, tune detector interval etc

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't in checkpointDao. We write to logs. Write to state index might be an option. We can have a field called checkpointError. How's that?

The action item you mentioned great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's ok to track in state index.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do it after release.


listener.onResponse(new AcknowledgedResponse(true));
}, exception -> {
LOG.error(
new ParameterizedMessage(
"fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]",
detectorId,
request.getStart(),
request.getEnd()
),
exception
);
listener.onFailure(exception);
});
}
}
Loading