Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reset realtime task state #201

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 48 additions & 49 deletions src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public class ADTaskCacheManager {
// This field is to record all batch tasks. Both single entity detector task
// and HC entity task will be cached in this field.
// Key: task id
private final Map<String, ADBatchTaskCache> taskCaches;
private final Map<String, ADBatchTaskCache> batchTaskCaches;

// We use this field to record all detector level tasks which running on the
// coordinating node to resolve race condition. We will check if
Expand All @@ -89,7 +89,7 @@ public class ADTaskCacheManager {
private Map<String, String> detectorTasks;

// Use this field to cache all HC tasks. Key is detector id
private Map<String, ADHCBatchTaskCache> hcTaskCaches;
private Map<String, ADHCBatchTaskCache> hcBatchTaskCaches;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the differences between hcBatchTaskCaches and batchTaskCaches? Do they have overlapping contents? Is it that hcBatchTaskCaches holds the detector-level task cache, while batchTaskCaches holds the entity-level task cache (if it is an HC detector) and the detector-level task cache (if it is a single-stream detector)? If yes, then I am more confused. Why can't we combine these two maps?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your understanding is correct. hcBatchTaskCaches caches HC detector-level information on the coordinating node; I will add more comments. I think the existing comment for batchTaskCaches is enough — please read it and suggest more if it is still not clear.

They don't have overlapping contents. Combining these two maps would mix the coordinating-node and worker-node caches together, which I would prefer to avoid for now. If there are no big concerns or cons, how about we keep the current design?

// cache deleted detector level tasks
private Queue<String> deletedDetectorTasks;
// cache deleted detectors
Expand All @@ -112,10 +112,10 @@ public ADTaskCacheManager(Settings settings, ClusterService clusterService, Memo
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it);
this.maxCachedDeletedTask = MAX_CACHED_DELETED_TASKS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_CACHED_DELETED_TASKS, it -> maxCachedDeletedTask = it);
taskCaches = new ConcurrentHashMap<>();
batchTaskCaches = new ConcurrentHashMap<>();
this.memoryTracker = memoryTracker;
this.detectorTasks = new ConcurrentHashMap<>();
this.hcTaskCaches = new ConcurrentHashMap<>();
this.hcBatchTaskCaches = new ConcurrentHashMap<>();
this.realtimeTaskCaches = new ConcurrentHashMap<>();
this.deletedDetectorTasks = new ConcurrentLinkedQueue<>();
this.deletedDetectors = new ConcurrentLinkedQueue<>();
Expand Down Expand Up @@ -148,7 +148,7 @@ public synchronized void add(ADTask adTask) {
memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR);
ADBatchTaskCache taskCache = new ADBatchTaskCache(adTask);
taskCache.getCacheMemorySize().set(neededCacheSize);
taskCaches.put(taskId, taskCache);
batchTaskCaches.put(taskId, taskCache);
}

/**
Expand All @@ -168,7 +168,7 @@ public synchronized void add(String detectorId, ADTask adTask) {
if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) {
ADHCBatchTaskCache adhcBatchTaskCache = new ADHCBatchTaskCache();
adhcBatchTaskCache.setIsCoordinatingNode(true);
this.hcTaskCaches.put(detectorId, adhcBatchTaskCache);
this.hcBatchTaskCaches.put(detectorId, adhcBatchTaskCache);
}
}

Expand Down Expand Up @@ -290,7 +290,7 @@ public Deque<Map.Entry<Long, Optional<double[]>>> getShingle(String taskId) {
* @return true if task exists in cache; otherwise, return false.
*/
public boolean contains(String taskId) {
return taskCaches.containsKey(taskId);
return batchTaskCaches.containsKey(taskId);
}

/**
Expand All @@ -300,7 +300,7 @@ public boolean contains(String taskId) {
* @return true if there is task in cache; otherwise return false
*/
public boolean containsTaskOfDetector(String detectorId) {
return taskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getDetectorId())).findAny().isPresent();
return batchTaskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getDetectorId())).findAny().isPresent();
}

/**
Expand All @@ -310,7 +310,7 @@ public boolean containsTaskOfDetector(String detectorId) {
* @return list of task id
*/
public List<String> getTasksOfDetector(String detectorId) {
return taskCaches
return batchTaskCaches
.values()
.stream()
.filter(v -> Objects.equals(detectorId, v.getDetectorId()))
Expand All @@ -332,11 +332,11 @@ private ADBatchTaskCache getBatchTaskCache(String taskId) {
if (!contains(taskId)) {
throw new IllegalArgumentException("AD task not in cache");
}
return taskCaches.get(taskId);
return batchTaskCaches.get(taskId);
}

private List<ADBatchTaskCache> getBatchTaskCacheByDetectorId(String detectorId) {
return taskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getDetectorId())).collect(Collectors.toList());
return batchTaskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getDetectorId())).collect(Collectors.toList());
}

/**
Expand Down Expand Up @@ -376,7 +376,7 @@ public void remove(String taskId) {
if (contains(taskId)) {
ADBatchTaskCache taskCache = getBatchTaskCache(taskId);
memoryTracker.releaseMemory(taskCache.getCacheMemorySize().get(), true, HISTORICAL_SINGLE_ENTITY_DETECTOR);
taskCaches.remove(taskId);
batchTaskCaches.remove(taskId);
// can't remove detector id from cache here as it's possible that some task running on
// other worker nodes
}
Expand All @@ -388,7 +388,7 @@ public void remove(String taskId) {
* @param detectorId detector id
*/
public void removeHistoricalTaskCache(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
// this will happen only on coordinating node. When worker nodes left,
// we will reset task state as STOPPED and clean up cache, add this warning
Expand All @@ -403,7 +403,7 @@ public void removeHistoricalTaskCache(String detectorId) {
);
}
taskCache.clear();
hcTaskCaches.remove(detectorId);
hcBatchTaskCaches.remove(detectorId);
}
List<String> tasksOfDetector = getTasksOfDetector(detectorId);
for (String taskId : tasksOfDetector) {
Expand Down Expand Up @@ -467,7 +467,7 @@ public ADTaskCancellationState cancelByDetectorId(String detectorId, String reas
cache.cancel(reason, userName);
}
}
ADHCBatchTaskCache hcTaskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache hcTaskCache = hcBatchTaskCaches.get(detectorId);
if (hcTaskCache != null) {
hcTaskCache.setHistoricalAnalysisCancelled(true);
hcTaskCache.clearPendingEntities();
Expand All @@ -487,7 +487,7 @@ public boolean isCancelled(String taskId) {
ADBatchTaskCache taskCache = getBatchTaskCache(taskId);
String detectorId = taskCache.getDetectorId();

ADHCBatchTaskCache hcTaskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache hcTaskCache = hcBatchTaskCaches.get(detectorId);
boolean hcDetectorStopped = false;
if (hcTaskCache != null) {
hcDetectorStopped = hcTaskCache.getHistoricalAnalysisCancelled();
Expand Down Expand Up @@ -524,14 +524,14 @@ public String getCancelledBy(String taskId) {
* @return task count
*/
public int size() {
return taskCaches.size();
return batchTaskCaches.size();
}

/**
* Clear all tasks.
*/
public void clear() {
taskCaches.clear();
batchTaskCaches.clear();
detectorTasks.clear();
}

Expand Down Expand Up @@ -570,7 +570,7 @@ public long shingleMemorySize(int shingleSize, int enabledFeatureSize) {
* @return true if top entity inited; otherwise return false
*/
public synchronized boolean topEntityInited(String detectorId) {
return hcTaskCaches.containsKey(detectorId) ? hcTaskCaches.get(detectorId).getTopEntitiesInited() : false;
return hcBatchTaskCaches.containsKey(detectorId) ? hcBatchTaskCaches.get(detectorId).getTopEntitiesInited() : false;
}

/**
Expand All @@ -589,7 +589,7 @@ public void setTopEntityInited(String detectorId) {
* @return entity count
*/
public int getPendingEntityCount(String detectorId) {
return hcTaskCaches.containsKey(detectorId) ? hcTaskCaches.get(detectorId).getPendingEntityCount() : 0;
return hcBatchTaskCaches.containsKey(detectorId) ? hcBatchTaskCaches.get(detectorId).getPendingEntityCount() : 0;
}

/**
Expand All @@ -599,15 +599,15 @@ public int getPendingEntityCount(String detectorId) {
* @return count of detector's running entity in cache
*/
public int getRunningEntityCount(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
return taskCache.getRunningEntityCount();
}
return 0;
}

public int getTempEntityCount(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
return taskCache.getTempEntityCount();
}
Expand All @@ -621,7 +621,7 @@ public int getTempEntityCount(String detectorId) {
* @return total top entity count
*/
public Integer getTopEntityCount(String detectorId) {
ADHCBatchTaskCache batchTaskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache batchTaskCache = hcBatchTaskCaches.get(detectorId);
if (batchTaskCache != null) {
return batchTaskCache.getTopEntityCount();
} else {
Expand All @@ -637,7 +637,7 @@ public Integer getTopEntityCount(String detectorId) {
* @return detector's running entities in cache
*/
public List<String> getRunningEntities(String detectorId) {
if (hcTaskCaches.containsKey(detectorId)) {
if (hcBatchTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = getExistingHCTaskCache(detectorId);
return Arrays.asList(hcTaskCache.getRunningEntities());
}
Expand Down Expand Up @@ -699,7 +699,6 @@ public synchronized int scaleDownHCDetectorTaskSlots(String detectorId, int delt
if (adTaskSlotLimit != null && delta > 0) {
int newTaskSlots = taskSlots - delta;
if (newTaskSlots > 0) {
taskSlots = newTaskSlots;
logger.info("Scale down task slots of detector {} from {} to {}", detectorId, taskSlots, newTaskSlots);
adTaskSlotLimit.setDetectorTaskSlots(newTaskSlots);
return newTaskSlots;
Expand Down Expand Up @@ -734,7 +733,7 @@ public int getDetectorTaskSlots(String detectorId) {
}

public int getUnfinishedEntityCount(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
return taskCache.getUnfinishedEntityCount();
}
Expand All @@ -754,7 +753,7 @@ public int getTotalDetectorTaskSlots() {
}

public int getTotalBatchTaskCount() {
return taskCaches.size();
return batchTaskCaches.size();
}

/**
Expand All @@ -777,8 +776,8 @@ public int getAvailableNewEntityTaskLanes(String detectorId) {
}

private ADHCBatchTaskCache getExistingHCTaskCache(String detectorId) {
if (hcTaskCaches.containsKey(detectorId)) {
return hcTaskCaches.get(detectorId);
if (hcBatchTaskCaches.containsKey(detectorId)) {
return hcBatchTaskCaches.get(detectorId);
} else {
throw new IllegalArgumentException("Can't find HC detector in cache");
}
Expand All @@ -796,7 +795,7 @@ public void addPendingEntities(String detectorId, List<String> entities) {
}

private ADHCBatchTaskCache getOrCreateHCTaskCache(String detectorId) {
return hcTaskCaches.computeIfAbsent(detectorId, id -> new ADHCBatchTaskCache());
return hcBatchTaskCaches.computeIfAbsent(detectorId, id -> new ADHCBatchTaskCache());
}

/**
Expand All @@ -809,7 +808,7 @@ public boolean isHCTaskRunning(String detectorId) {
return true;
}
// Only running tasks will be in cache.
Optional<ADBatchTaskCache> entityTask = this.taskCaches
Optional<ADBatchTaskCache> entityTask = this.batchTaskCaches
.values()
.stream()
.filter(cache -> Objects.equals(detectorId, cache.getDetectorId()) && cache.getEntity() != null)
Expand All @@ -823,7 +822,7 @@ public boolean isHCTaskRunning(String detectorId) {
* @return true if find detector id in HC cache
*/
public boolean isHCTaskCoordinatingNode(String detectorId) {
return hcTaskCaches.containsKey(detectorId) && hcTaskCaches.get(detectorId).isCoordinatingNode();
return hcBatchTaskCaches.containsKey(detectorId) && hcBatchTaskCaches.get(detectorId).isCoordinatingNode();
}

/**
Expand All @@ -845,8 +844,8 @@ public void setTopEntityCount(String detectorId, Integer count) {
* @return one entity
*/
public synchronized String pollEntity(String detectorId) {
if (this.hcTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = this.hcTaskCaches.get(detectorId);
if (this.hcBatchTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = this.hcBatchTaskCaches.get(detectorId);
String entity = hcTaskCache.pollEntity();
return entity;
} else {
Expand All @@ -872,8 +871,8 @@ public void addPendingEntity(String detectorId, String entity) {
* @param entity entity value
*/
public synchronized void moveToRunningEntity(String detectorId, String entity) {
if (this.hcTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = this.hcTaskCaches.get(detectorId);
if (this.hcBatchTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = this.hcBatchTaskCaches.get(detectorId);
hcTaskCache.moveToRunningEntity(entity);
}
}
Expand Down Expand Up @@ -919,8 +918,8 @@ public int increaseEntityTaskRetry(String detectorId, String taskId) {
* @param entity entity value
*/
public void removeEntity(String detectorId, String entity) {
if (hcTaskCaches.containsKey(detectorId)) {
hcTaskCaches.get(detectorId).removeEntity(entity);
if (hcBatchTaskCaches.containsKey(detectorId)) {
hcBatchTaskCaches.get(detectorId).removeEntity(entity);
}
}

Expand All @@ -941,7 +940,7 @@ public Entity getEntity(String taskId) {
* @return true if detector still has entity in cache
*/
public synchronized boolean hasEntity(String detectorId) {
return hcTaskCaches.containsKey(detectorId) && hcTaskCaches.get(detectorId).hasEntity();
return hcBatchTaskCaches.containsKey(detectorId) && hcBatchTaskCaches.get(detectorId).hasEntity();
}

/**
Expand All @@ -953,8 +952,8 @@ public synchronized boolean hasEntity(String detectorId) {
*/
public boolean removeRunningEntity(String detectorId, String entity) {
logger.debug("Remove entity from running entities cache: {}", entity);
if (hcTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = hcTaskCaches.get(detectorId);
if (hcBatchTaskCaches.containsKey(detectorId)) {
ADHCBatchTaskCache hcTaskCache = hcBatchTaskCaches.get(detectorId);
return hcTaskCache.removeRunningEntity(entity);
}
return false;
Expand All @@ -966,7 +965,7 @@ public boolean removeRunningEntity(String detectorId, String entity) {
* @return true if can get semaphore
*/
public boolean tryAcquireTaskUpdatingSemaphore(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
return taskCache.tryAcquireTaskUpdatingSemaphore();
}
Expand All @@ -978,7 +977,7 @@ public boolean tryAcquireTaskUpdatingSemaphore(String detectorId) {
* @param detectorId detector id
*/
public void releaseTaskUpdatingSemaphore(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
taskCache.releaseTaskUpdatingSemaphore();
}
Expand All @@ -990,8 +989,8 @@ public void releaseTaskUpdatingSemaphore(String detectorId) {
* @param detectorId detector id
*/
public void clearPendingEntities(String detectorId) {
if (hcTaskCaches.containsKey(detectorId)) {
hcTaskCaches.get(detectorId).clearPendingEntities();
if (hcBatchTaskCaches.containsKey(detectorId)) {
hcBatchTaskCaches.get(detectorId).clearPendingEntities();
}
}

Expand Down Expand Up @@ -1138,8 +1137,8 @@ public String pollDeletedDetector() {
* @return true if detector level task state changed
*/
public synchronized boolean isDetectorTaskStateChanged(String detectorId, String newState) {
if (hcTaskCaches.containsKey(detectorId)) {
return !Objects.equals(hcTaskCaches.get(detectorId).getDetectorTaskState(), newState);
if (hcBatchTaskCaches.containsKey(detectorId)) {
return !Objects.equals(hcBatchTaskCaches.get(detectorId).getDetectorTaskState(), newState);
}
return true;
}
Expand All @@ -1158,15 +1157,15 @@ public String getDetectorTaskId(String detectorId) {
}

public Instant getLastScaleEntityTaskLaneTime(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
return taskCache.getLastScaleEntityTaskSlotsTime();
}
return null;
}

public void refreshLastScaleEntityTaskLaneTime(String detectorId) {
ADHCBatchTaskCache taskCache = hcTaskCaches.get(detectorId);
ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId);
if (taskCache != null) {
taskCache.setLastScaleEntityTaskSlotsTime(Instant.now());
}
Expand Down