Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Rename TransportState to NodeState
Browse files Browse the repository at this point in the history
Previously, we have TransportState and TransportStateManager to track states used by transport layers.  Now the state is not only used by the transport layer. Methods like getDetector are used by ModelManager etc.  This PR renames the class to reflect the fact.

This PR also modifies how we track whether a cold start is running or not.  Previously, the caller had to manually set it on and off.  And we have a code everywhere.  Now, we return a Releasable object that can be called automatically.  The current way is more concise and easier to avoid bugs.

This PR also adds states to track last index throttling time as we face index rejections issues.

Testing done:
1. will add unit tests.
2. end-to-end testing pass.
  • Loading branch information
kaituo committed Oct 15, 2020
1 parent 2175fde commit 6d19b63
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;
package com.amazon.opendistroforelasticsearch.ad;

import java.time.Clock;
import java.time.Duration;
Expand All @@ -27,16 +27,16 @@
* Storing intermediate state during the execution of transport action
*
*/
public class TransportState {
public class NodeState implements ExpiringState {
private String detectorId;
// detector definition
private AnomalyDetector detectorDef;
// number of partitions
private int partitonNumber;
// checkpoint fetch time
private Instant lastAccessTime;
// last detection error. Used by DetectorStateHandler to check if the error for a
// detector has changed or not. If changed, trigger indexing.
// last detection error recorded in result index. Used by DetectorStateHandler
// to check if the error for a detector has changed or not. If changed, trigger indexing.
private Optional<String> lastDetectionError;
// last training error. Used to save cold start error by a concurrent cold start thread.
private Optional<AnomalyDetectionException> lastColdStartException;
Expand All @@ -47,7 +47,7 @@ public class TransportState {
// cold start running flag to prevent concurrent cold start
private boolean coldStartRunning;

public TransportState(String detectorId, Clock clock) {
public NodeState(String detectorId, Clock clock) {
this.detectorId = detectorId;
this.detectorDef = null;
this.partitonNumber = -1;
Expand Down Expand Up @@ -182,7 +182,8 @@ private void refreshLastUpdateTime() {
* @param stateTtl time to leave for the state
* @return whether the transport state is expired
*/
@Override
public boolean expired(Duration stateTtl) {
return lastAccessTime.plus(stateTtl).isBefore(clock.instant());
return expired(lastAccessTime, stateTtl, clock.instant());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;
package com.amazon.opendistroforelasticsearch.ad;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -29,57 +30,76 @@
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonName;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelPartitioner;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.transport.BackPressureRouting;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;

/**
* ADStateManager is used by transport layer to manage AnomalyDetector object
* and the number of partitions for a detector id.
* NodeStateManager is used to manage states shared by transport and ml components
* like AnomalyDetector object
*
*/
public class TransportStateManager {
private static final Logger LOG = LogManager.getLogger(TransportStateManager.class);
private ConcurrentHashMap<String, TransportState> transportStates;
public class NodeStateManager implements MaintenanceState, CleanState {
private static final Logger LOG = LogManager.getLogger(NodeStateManager.class);
private ConcurrentHashMap<String, NodeState> states;
private Client client;
private ModelManager modelManager;
private ModelPartitioner modelPartitioner;
private NamedXContentRegistry xContentRegistry;
private ClientUtil clientUtil;
// map from ES node id to the node's backpressureMuter
private Map<String, BackPressureRouting> backpressureMuter;
private final Clock clock;
private final Settings settings;
private final Duration stateTtl;
// last time we are throttled due to too much index pressure
private Instant lastIndexThrottledTime;

public static final String NO_ERROR = "no_error";

public TransportStateManager(
/**
* Constructor
*
* @param client Client to make calls to ElasticSearch
* @param xContentRegistry ES named content registry
* @param settings ES settings
* @param clientUtil AD Client utility
* @param clock A UTC clock
* @param stateTtl Max time to keep state in memory
* @param modelPartitioner Used to partiton a RCF forest
*/
public NodeStateManager(
Client client,
NamedXContentRegistry xContentRegistry,
ModelManager modelManager,
Settings settings,
ClientUtil clientUtil,
Clock clock,
Duration stateTtl
Duration stateTtl,
ModelPartitioner modelPartitioner
) {
this.transportStates = new ConcurrentHashMap<>();
this.states = new ConcurrentHashMap<>();
this.client = client;
this.modelManager = modelManager;
this.modelPartitioner = modelPartitioner;
this.xContentRegistry = xContentRegistry;
this.clientUtil = clientUtil;
this.backpressureMuter = new ConcurrentHashMap<>();
this.clock = clock;
this.settings = settings;
this.stateTtl = stateTtl;
this.lastIndexThrottledTime = Instant.MIN;
}

/**
Expand All @@ -90,20 +110,30 @@ public TransportStateManager(
* @throws LimitExceededException when there is no sufficient resource available
*/
public int getPartitionNumber(String adID, AnomalyDetector detector) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null && state.getPartitonNumber() > 0) {
return state.getPartitonNumber();
}

int partitionNum = modelManager.getPartitionedForestSizes(detector).getKey();
state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
int partitionNum = modelPartitioner.getPartitionedForestSizes(detector).getKey();
state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setPartitonNumber(partitionNum);

return partitionNum;
}

/**
* Get Detector config object if present
* @param adID detector Id
* @return the Detecor config object or empty Optional
*/
public Optional<AnomalyDetector> getAnomalyDetectorIfPresent(String adID) {
NodeState state = states.get(adID);
return Optional.ofNullable(state).map(NodeState::getDetectorDef);
}

public void getAnomalyDetector(String adID, ActionListener<Optional<AnomalyDetector>> listener) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null && state.getDetectorDef() != null) {
listener.onResponse(Optional.of(state.getDetectorDef()));
} else {
Expand All @@ -127,7 +157,12 @@ private ActionListener<GetResponse> onGetDetectorResponse(String adID, ActionLis
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation);
AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId());
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
// end execution if all features are disabled
if (detector.getEnabledFeatureIds().isEmpty()) {
listener.onFailure(new EndRunException(adID, CommonErrorMessages.ALL_FEATURES_DISABLED_ERR_MSG, true));
return;
}
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setDetectorDef(detector);

listener.onResponse(Optional.of(detector));
Expand All @@ -145,13 +180,13 @@ private ActionListener<GetResponse> onGetDetectorResponse(String adID, ActionLis
* @param listener listener to handle get request
*/
public void getDetectorCheckpoint(String adID, ActionListener<Boolean> listener) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null && state.doesCheckpointExists()) {
listener.onResponse(Boolean.TRUE);
return;
}

GetRequest request = new GetRequest(CommonName.CHECKPOINT_INDEX_NAME, modelManager.getRcfModelId(adID, 0));
GetRequest request = new GetRequest(CommonName.CHECKPOINT_INDEX_NAME, modelPartitioner.getRcfModelId(adID, 0));

clientUtil.<GetRequest, GetResponse>asyncRequest(request, client::get, onGetCheckpointResponse(adID, listener));
}
Expand All @@ -161,7 +196,7 @@ private ActionListener<GetResponse> onGetCheckpointResponse(String adID, ActionL
if (response == null || !response.isExists()) {
listener.onResponse(Boolean.FALSE);
} else {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setCheckpointExists(true);
listener.onResponse(Boolean.TRUE);
}
Expand All @@ -173,8 +208,9 @@ private ActionListener<GetResponse> onGetCheckpointResponse(String adID, ActionL
*
* @param adID detector ID
*/
@Override
public void clear(String adID) {
transportStates.remove(adID);
states.remove(adID);
}

/**
Expand All @@ -183,18 +219,9 @@ public void clear(String adID) {
* java.util.ConcurrentModificationException.
*
*/
@Override
public void maintenance() {
transportStates.entrySet().stream().forEach(entry -> {
String detectorId = entry.getKey();
try {
TransportState state = entry.getValue();
if (state.expired(stateTtl)) {
transportStates.remove(detectorId);
}
} catch (Exception e) {
LOG.warn("Failed to finish maintenance for detector id " + detectorId, e);
}
});
maintenance(states, stateTtl);
}

public boolean isMuted(String nodeId) {
Expand Down Expand Up @@ -232,7 +259,7 @@ public boolean hasRunningQuery(AnomalyDetector detector) {
* @return last error for the detector
*/
public String getLastDetectionError(String adID) {
return Optional.ofNullable(transportStates.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR);
return Optional.ofNullable(states.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR);
}

/**
Expand All @@ -241,7 +268,7 @@ public String getLastDetectionError(String adID) {
* @param error error, can be null
*/
public void setLastDetectionError(String adID, String error) {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setLastDetectionError(error);
}

Expand All @@ -251,7 +278,7 @@ public void setLastDetectionError(String adID, String error) {
* @param exception exception, can be null
*/
public void setLastColdStartException(String adID, AnomalyDetectionException exception) {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setLastColdStartException(exception);
}

Expand All @@ -262,7 +289,7 @@ public void setLastColdStartException(String adID, AnomalyDetectionException exc
* @return last cold start exception for the detector
*/
public Optional<AnomalyDetectionException> fetchColdStartException(String adID) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state == null) {
return Optional.empty();
}
Expand All @@ -279,7 +306,7 @@ public Optional<AnomalyDetectionException> fetchColdStartException(String adID)
* @return running or not
*/
public boolean isColdStartRunning(String adID) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null) {
return state.isColdStartRunning();
}
Expand All @@ -290,10 +317,24 @@ public boolean isColdStartRunning(String adID) {
/**
* Mark the cold start status of the detector
* @param adID detector ID
* @param running whether it is running
* @return a callback when cold start is done
*/
public void setColdStartRunning(String adID, boolean running) {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
state.setColdStartRunning(running);
public Releasable markColdStartRunning(String adID) {
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setColdStartRunning(true);
return () -> {
NodeState nodeState = states.get(adID);
if (nodeState != null) {
nodeState.setColdStartRunning(false);
}
};
}

public Instant getLastIndexThrottledTime() {
return lastIndexThrottledTime;
}

public void setLastIndexThrottledTime(Instant lastIndexThrottledTime) {
this.lastIndexThrottledTime = lastIndexThrottledTime;
}
}

0 comments on commit 6d19b63

Please sign in to comment.