From e098d192d769b44e6571a9056e3595abf37d43e1 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Thu, 4 Nov 2021 16:06:43 -0700 Subject: [PATCH] Add WorkLoadManager.unregisterModel This replaces the use of scaling a worker to 0 to unregister it. The dedicated method should make the behavior more transparent. --- .../ai/djl/serving/models/ModelManager.java | 27 +++++++++++++++++-- .../ai/djl/serving/wlm/WorkLoadManager.java | 16 ++++++++--- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/serving/src/main/java/ai/djl/serving/models/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java index 35359f502..7468c85a9 100644 --- a/serving/src/main/java/ai/djl/serving/models/ModelManager.java +++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java @@ -36,6 +36,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -180,10 +181,11 @@ public boolean unregisterWorkflow(String workflowName, String version) { logger.warn("Model not found: " + workflowName); return false; } + Set candidateModelsToUnregister = new HashSet<>(); if (version == null) { // unregister all versions for (WorkflowInfo workflow : endpoint.getWorkflows()) { - scaleWorkers(workflow, null, 0, 0); + candidateModelsToUnregister.addAll(workflow.getWorkflow().getModels()); workflow.getWorkflow().close(); } startupModels.remove(workflowName); @@ -195,13 +197,20 @@ public boolean unregisterWorkflow(String workflowName, String version) { logger.warn("Workflow not found: " + workflowName + ':' + version); return false; } - scaleWorkers(workflow, null, 0, 0); + candidateModelsToUnregister.addAll(workflow.getWorkflow().getModels()); workflow.getWorkflow().close(); startupModels.remove(workflowName); } if (endpoint.getWorkflows().isEmpty()) { endpoints.remove(workflowName); } + + // Unregister candidate models if they are not used for a remaining endpoint + candidateModelsToUnregister.removeAll(getModels()); + for (ModelInfo model : candidateModelsToUnregister) { + wlm.unregisterModel(model); + } + return true; } @@ -250,6 +259,20 @@ public Map getEndpoints() { return endpoints; } + /** + * Returns all models in an endpoint. + * + * @return all models in an endpoint + */ + public Set getModels() { + return getEndpoints() + .values() + .stream() + .flatMap(e -> e.getWorkflows().stream()) + .flatMap(w -> w.getWorkflow().getModels().stream()) + .collect(Collectors.toSet()); + } + /** * Returns a version of workflow. * diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java index beb08a0a9..777642e00 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java @@ -69,6 +69,17 @@ public List getWorkers(ModelInfo modelInfo) { return list; } + /** + * Removes a model from management. + * + * @param model the model to remove + */ + public void unregisterModel(ModelInfo model) { + WorkerPool pool = getWorkerPoolForModel(model); + pool.scaleWorkers(null, 0, 0); + workerPools.remove(model); + } + /** * Adds an inference job to the job queue of the next free worker. scales up worker if * necessary. @@ -236,9 +247,6 @@ public WorkerPool scaleWorkers(String deviceName, int newMinWorkers, int newMaxW cleanup(); List threads; - if (minWorkers == 0) { - workerPools.remove(model); - } threads = getWorkers(); List fixedPoolThread = @@ -274,7 +282,7 @@ private void addThreads(ModelInfo model, int count, boolean permanent) { WorkerThread thread = WorkerThread.builder() .setModel(model) - .setJobQueue(getWorkerPoolForModel(model).getJobQueue()) + .setJobQueue(jobQueue) .optFixPoolThread(permanent) .build();