Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Rename TransportState to NodeState #259

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;
package com.amazon.opendistroforelasticsearch.ad;

import java.time.Clock;
import java.time.Duration;
Expand All @@ -27,16 +27,16 @@
* Storing intermediate state during the execution of transport action
*
*/
public class TransportState {
public class NodeState implements ExpiringState {
private String detectorId;
// detector definition
private AnomalyDetector detectorDef;
// number of partitions
private int partitonNumber;
// checkpoint fetch time
private Instant lastAccessTime;
// last detection error. Used by DetectorStateHandler to check if the error for a
// detector has changed or not. If changed, trigger indexing.
// last detection error recorded in result index. Used by DetectorStateHandler
// to check if the error for a detector has changed or not. If changed, trigger indexing.
private Optional<String> lastDetectionError;
// last training error. Used to save cold start error by a concurrent cold start thread.
private Optional<AnomalyDetectionException> lastColdStartException;
Expand All @@ -47,7 +47,7 @@ public class TransportState {
// cold start running flag to prevent concurrent cold start
private boolean coldStartRunning;

public TransportState(String detectorId, Clock clock) {
public NodeState(String detectorId, Clock clock) {
this.detectorId = detectorId;
this.detectorDef = null;
this.partitonNumber = -1;
Expand Down Expand Up @@ -182,7 +182,8 @@ private void refreshLastUpdateTime() {
* @param stateTtl time to leave for the state
* @return whether the transport state is expired
*/
@Override
public boolean expired(Duration stateTtl) {
return lastAccessTime.plus(stateTtl).isBefore(clock.instant());
return expired(lastAccessTime, stateTtl, clock.instant());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.ad.transport;
package com.amazon.opendistroforelasticsearch.ad;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -29,57 +30,76 @@
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonName;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelPartitioner;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.transport.BackPressureRouting;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;

/**
* ADStateManager is used by transport layer to manage AnomalyDetector object
* and the number of partitions for a detector id.
* NodeStateManager is used to manage states shared by transport and ml components
* like AnomalyDetector object
*
*/
public class TransportStateManager {
private static final Logger LOG = LogManager.getLogger(TransportStateManager.class);
private ConcurrentHashMap<String, TransportState> transportStates;
public class NodeStateManager implements MaintenanceState, CleanState {
private static final Logger LOG = LogManager.getLogger(NodeStateManager.class);
private ConcurrentHashMap<String, NodeState> states;
private Client client;
private ModelManager modelManager;
private ModelPartitioner modelPartitioner;
private NamedXContentRegistry xContentRegistry;
private ClientUtil clientUtil;
// map from ES node id to the node's backpressureMuter
private Map<String, BackPressureRouting> backpressureMuter;
private final Clock clock;
private final Settings settings;
private final Duration stateTtl;
// last time we are throttled due to too much index pressure
private Instant lastIndexThrottledTime;

public static final String NO_ERROR = "no_error";

public TransportStateManager(
/**
* Constructor
*
* @param client Client to make calls to ElasticSearch
* @param xContentRegistry ES named content registry
* @param settings ES settings
* @param clientUtil AD Client utility
* @param clock A UTC clock
* @param stateTtl Max time to keep state in memory
* @param modelPartitioner Used to partiton a RCF forest

*/
public NodeStateManager(
Client client,
NamedXContentRegistry xContentRegistry,
ModelManager modelManager,
Settings settings,
ClientUtil clientUtil,
Clock clock,
Duration stateTtl
Duration stateTtl,
ModelPartitioner modelPartitioner
) {
this.transportStates = new ConcurrentHashMap<>();
this.states = new ConcurrentHashMap<>();
this.client = client;
this.modelManager = modelManager;
this.modelPartitioner = modelPartitioner;
this.xContentRegistry = xContentRegistry;
this.clientUtil = clientUtil;
this.backpressureMuter = new ConcurrentHashMap<>();
this.clock = clock;
this.settings = settings;
this.stateTtl = stateTtl;
this.lastIndexThrottledTime = Instant.MIN;
}

/**
Expand All @@ -90,20 +110,30 @@ public TransportStateManager(
* @throws LimitExceededException when there is no sufficient resource available
*/
public int getPartitionNumber(String adID, AnomalyDetector detector) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null && state.getPartitonNumber() > 0) {
return state.getPartitonNumber();
}

int partitionNum = modelManager.getPartitionedForestSizes(detector).getKey();
state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
int partitionNum = modelPartitioner.getPartitionedForestSizes(detector).getKey();
state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setPartitonNumber(partitionNum);

return partitionNum;
}

/**
* Get Detector config object if present
* @param adID detector Id
* @return the Detecor config object or empty Optional
*/
public Optional<AnomalyDetector> getAnomalyDetectorIfPresent(String adID) {
NodeState state = states.get(adID);
return Optional.ofNullable(state).map(NodeState::getDetectorDef);
}

public void getAnomalyDetector(String adID, ActionListener<Optional<AnomalyDetector>> listener) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null && state.getDetectorDef() != null) {
listener.onResponse(Optional.of(state.getDetectorDef()));
} else {
Expand All @@ -127,7 +157,12 @@ private ActionListener<GetResponse> onGetDetectorResponse(String adID, ActionLis
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser::getTokenLocation);
AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId());
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
// end execution if all features are disabled
if (detector.getEnabledFeatureIds().isEmpty()) {
listener.onFailure(new EndRunException(adID, CommonErrorMessages.ALL_FEATURES_DISABLED_ERR_MSG, true));
return;
}
ohltyler marked this conversation as resolved.
Show resolved Hide resolved
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setDetectorDef(detector);

listener.onResponse(Optional.of(detector));
Expand All @@ -145,13 +180,13 @@ private ActionListener<GetResponse> onGetDetectorResponse(String adID, ActionLis
* @param listener listener to handle get request
*/
public void getDetectorCheckpoint(String adID, ActionListener<Boolean> listener) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null && state.doesCheckpointExists()) {
listener.onResponse(Boolean.TRUE);
return;
}

GetRequest request = new GetRequest(CommonName.CHECKPOINT_INDEX_NAME, modelManager.getRcfModelId(adID, 0));
GetRequest request = new GetRequest(CommonName.CHECKPOINT_INDEX_NAME, modelPartitioner.getRcfModelId(adID, 0));

clientUtil.<GetRequest, GetResponse>asyncRequest(request, client::get, onGetCheckpointResponse(adID, listener));
}
Expand All @@ -161,7 +196,7 @@ private ActionListener<GetResponse> onGetCheckpointResponse(String adID, ActionL
if (response == null || !response.isExists()) {
listener.onResponse(Boolean.FALSE);
} else {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setCheckpointExists(true);
listener.onResponse(Boolean.TRUE);
}
Expand All @@ -173,8 +208,9 @@ private ActionListener<GetResponse> onGetCheckpointResponse(String adID, ActionL
*
* @param adID detector ID
*/
@Override
public void clear(String adID) {
transportStates.remove(adID);
states.remove(adID);
}

/**
Expand All @@ -183,18 +219,9 @@ public void clear(String adID) {
* java.util.ConcurrentModificationException.
*
*/
@Override
public void maintenance() {
transportStates.entrySet().stream().forEach(entry -> {
String detectorId = entry.getKey();
try {
TransportState state = entry.getValue();
if (state.expired(stateTtl)) {
transportStates.remove(detectorId);
}
} catch (Exception e) {
LOG.warn("Failed to finish maintenance for detector id " + detectorId, e);
}
});
maintenance(states, stateTtl);
}

public boolean isMuted(String nodeId) {
Expand Down Expand Up @@ -232,7 +259,7 @@ public boolean hasRunningQuery(AnomalyDetector detector) {
* @return last error for the detector
*/
public String getLastDetectionError(String adID) {
return Optional.ofNullable(transportStates.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR);
return Optional.ofNullable(states.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR);
}

/**
Expand All @@ -241,7 +268,7 @@ public String getLastDetectionError(String adID) {
* @param error error, can be null
*/
public void setLastDetectionError(String adID, String error) {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setLastDetectionError(error);
}

Expand All @@ -251,7 +278,7 @@ public void setLastDetectionError(String adID, String error) {
* @param exception exception, can be null
*/
public void setLastColdStartException(String adID, AnomalyDetectionException exception) {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setLastColdStartException(exception);
}

Expand All @@ -262,7 +289,7 @@ public void setLastColdStartException(String adID, AnomalyDetectionException exc
* @return last cold start exception for the detector
*/
public Optional<AnomalyDetectionException> fetchColdStartException(String adID) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state == null) {
return Optional.empty();
}
Expand All @@ -279,7 +306,7 @@ public Optional<AnomalyDetectionException> fetchColdStartException(String adID)
* @return running or not
*/
public boolean isColdStartRunning(String adID) {
TransportState state = transportStates.get(adID);
NodeState state = states.get(adID);
if (state != null) {
return state.isColdStartRunning();
}
Expand All @@ -290,10 +317,24 @@ public boolean isColdStartRunning(String adID) {
/**
* Mark the cold start status of the detector
* @param adID detector ID
* @param running whether it is running
* @return a callback when cold start is done
*/
public void setColdStartRunning(String adID, boolean running) {
TransportState state = transportStates.computeIfAbsent(adID, id -> new TransportState(id, clock));
state.setColdStartRunning(running);
public Releasable markColdStartRunning(String adID) {
NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
state.setColdStartRunning(true);
return () -> {
NodeState nodeState = states.get(adID);
if (nodeState != null) {
nodeState.setColdStartRunning(false);
}
};
}

public Instant getLastIndexThrottledTime() {
return lastIndexThrottledTime;
}

public void setLastIndexThrottledTime(Instant lastIndexThrottledTime) {
this.lastIndexThrottledTime = lastIndexThrottledTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult;
import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
import com.amazon.opendistroforelasticsearch.ad.transport.TransportStateManager;
import com.amazon.opendistroforelasticsearch.ad.transport.handler.AnomalyIndexHandler;
import com.amazon.opendistroforelasticsearch.ad.transport.handler.DetectionStateHandler;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;
Expand Down Expand Up @@ -152,7 +151,7 @@ public void setup() throws Exception {
AnomalyDetectionIndices anomalyDetectionIndices = mock(AnomalyDetectionIndices.class);
IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class);
IndexUtils indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameResolver);
TransportStateManager stateManager = mock(TransportStateManager.class);
NodeStateManager stateManager = mock(NodeStateManager.class);
detectorStateHandler = new DetectionStateHandler(
client,
settings,
Expand Down
Loading