From 40514cf3847506a3ebfafae37297456f7a4d7262 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Wed, 21 Jun 2023 13:30:33 -0700
Subject: [PATCH] [python] Adds rolling batch support (#828)

* [python] Adds rolling batch support

* change to fair lock

---------

Co-authored-by: sindhuso
---
 engines/python/setup/djl_python/deepspeed.py  |   7 +-
 .../java/ai/djl/python/engine/PyModel.java    |   1 +
 .../ai/djl/python/engine/PyPredictor.java     |  21 ++-
 .../ai/djl/python/engine/RollingBatch.java    | 168 ++++++++++++++++++
 .../ai/djl/python/engine/PyEngineTest.java    |  35 ++++
 .../src/test/resources/rolling_batch/model.py | 127 +++++++++++++
 .../rolling_batch/serving.properties          |   2 +
 7 files changed, 355 insertions(+), 6 deletions(-)
 create mode 100644 engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
 create mode 100644 engines/python/src/test/resources/rolling_batch/model.py
 create mode 100644 engines/python/src/test/resources/rolling_batch/serving.properties

diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py
index cc57fc76d99..9dad7e7e3a6 100644
--- a/engines/python/setup/djl_python/deepspeed.py
+++ b/engines/python/setup/djl_python/deepspeed.py
@@ -329,9 +329,10 @@ def inference(self, inputs: Input):
                                                   **model_kwargs))
             return outputs
         if self.task == "text-generation":
-            tokenized_inputs = self.tokenizer(
-                input_data, padding=True,
-                return_tensors="pt").to(self.device)
+            tokenized_inputs = self.tokenizer(input_data,
+                                              padding=True,
+                                              return_tensors="pt").to(
+                                                  self.device)
             with torch.no_grad():
                 output_tokens = self.model.generate(
                     input_ids=tokenized_inputs.input_ids,
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
index f121be9b24f..80bb120d42c 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -85,6 +85,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
                 String value = (String) entry.getValue();
                 if (!"env".equals(key)) {
                     pyEnv.addParameter(key, value);
+                    properties.put(key, value);
                 }
                 logger.debug("{}={}", key, value);
                 switch (key) {
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
index d927042ff21..152aafbb918 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
@@ -38,6 +38,8 @@ class PyPredictor<I, O> extends Predictor<I, O> {
 
     private PyProcess process;
     private int timeout;
+    private boolean isRollingBatch;
+    private RollingBatch rollingBatch;
 
     public PyPredictor(
             Model model,
@@ -48,6 +50,12 @@ public PyPredictor(
         super(model, translator, device, false);
         this.process = process;
         this.timeout = timeout;
+        isRollingBatch = Boolean.parseBoolean(model.getProperty("rolling_batch", "false"));
+        if (isRollingBatch) {
+            int maxRollingBatchSize =
+                    Integer.parseInt(model.getProperty("max_rolling_batch_size", "3"));
+            rollingBatch = new RollingBatch(process, maxRollingBatchSize, timeout);
+        }
     }
 
     /** {@inheritDoc} */
@@ -62,7 +70,12 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
         if (first instanceof Input) {
             int size = inputs.size();
             if (size == 1) {
-                Output output = process.predict((Input) first, timeout, false);
+                Output output;
+                if (isRollingBatch) {
+                    output = rollingBatch.addInput((Input) first, timeout);
+                } else {
+                    output = process.predict((Input) first, timeout, false);
+                }
                 return Collections.singletonList((O) output);
             }
 
@@ -120,8 +133,7 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
 
     /** {@inheritDoc} */
     @Override
-    protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
-            throws TranslateException {
+    protected NDList predictInternal(TranslatorContext ctx, NDList ndList) {
         Input inputs = new Input();
         inputs.addProperty("Content-Type", "tensor/ndlist");
         inputs.add(ndList.encode());
@@ -135,5 +147,8 @@ protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
     public void close() {
         super.close();
         process.stopPythonProcess();
+        if (rollingBatch != null) {
+            rollingBatch.shutdown();
+        }
     }
 }
diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
new file mode 100644
index 00000000000..cbab336f90e
--- /dev/null
+++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
@@ -0,0 +1,168 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.python.engine;
+
+import ai.djl.inference.streaming.ChunkedBytesSupplier;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.BytesSupplier;
+import ai.djl.translate.TranslateException;
+import ai.djl.util.JsonUtils;
+import ai.djl.util.PairList;
+
+import com.google.gson.JsonObject;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+class RollingBatch implements Runnable {
+
+    private static final Logger logger = LoggerFactory.getLogger(RollingBatch.class);
+
+    private static ExecutorService threadPool = Executors.newCachedThreadPool();
+
+    private PyProcess process;
+    private int maxRollingBatchSize;
+    private int timeout;
+    private boolean stop;
+    private List<Request> list;
+    private Thread currentThread;
+    private ReentrantLock lock;
+    private Condition canAdd;
+    private Condition canRead;
+
+    RollingBatch(PyProcess process, int maxRollingBatchSize, int timeout) {
+        this.process = process;
+        this.maxRollingBatchSize = maxRollingBatchSize;
+        this.timeout = timeout;
+        list = new ArrayList<>(3);
+        lock = new ReentrantLock(true);
+        canAdd = lock.newCondition();
+        canRead = lock.newCondition();
+        threadPool.submit(this);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void run() {
+        currentThread = Thread.currentThread();
+        while (!stop) {
+            try {
+                lock.lock();
+                if (list.isEmpty()) {
+                    canRead.await();
+                }
+
+                Input batch = new Input();
+                int size = list.size();
+                batch.addProperty("batch_size", String.valueOf(size));
+                for (int i = 0; i < size; ++i) {
+                    Request req = list.get(i);
+                    String prefix = "batch_" + i + ".data";
+                    batch.add(prefix, req.getRequest());
+                }
+
+                // TODO: Handle error case
+
+                Output output = process.predict(batch, timeout, false);
+                PairList<String, BytesSupplier> content = output.getContent();
+                if (content.size() != size) {
+                    throw new TranslateException(
+                            "Batch output size mismatch, expected: "
+                                    + size
+                                    + ", actual: "
+                                    + content.size());
+                }
+                for (int i = 0; i < size; ++i) {
+                    Request status = list.get(i);
+                    String json = content.get(i).getValue().getAsString();
+                    status.addResponse(json);
+                }
+                list.removeIf(status -> status.last);
+                if (list.size() < maxRollingBatchSize) {
+                    canAdd.signal();
+                }
+            } catch (InterruptedException e) {
+                break;
+            } catch (TranslateException e) {
+                logger.error("RollingBatch thread died, killing python process.", e);
+                process.stopPythonProcess();
+            } finally {
+                lock.unlock();
+            }
+        }
+    }
+
+    public Output addInput(Input input, int timeout) throws TranslateException {
+        try {
+            lock.lock();
+            if (list.size() >= maxRollingBatchSize) {
+                if (!canAdd.await(timeout, TimeUnit.SECONDS)) {
+                    throw new TranslateException("Time out in: " + timeout);
+                }
+            }
+            Request req = new Request(input);
+            list.add(req);
+            canRead.signal();
+            return req.output;
+        } catch (InterruptedException e) {
+            throw new TranslateException("Interrupted", e);
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    public void shutdown() {
+        this.stop = true;
+        threadPool.shutdown();
+        currentThread.interrupt();
+    }
+
+    private static final class Request {
+
+        Input input;
+        ChunkedBytesSupplier data;
+        Output output;
+        String nextToken;
+        boolean last;
+
+        Request(Input input) {
+            this.input = input;
+            data = new ChunkedBytesSupplier();
+            output = new Output();
+            output.add(data);
+        }
+
+        BytesSupplier getRequest() {
+            if (nextToken != null) {
+                return BytesSupplier.wrap("{\"inputs\": [\"" + nextToken + "\"]}");
+            }
+            return input.getData();
+        }
+
+        void addResponse(String json) {
+            JsonObject element = JsonUtils.GSON.fromJson(json, JsonObject.class);
+            last = element.get("last").getAsBoolean();
+            nextToken = element.get("data").getAsString();
+            data.appendContent(BytesSupplier.wrap(nextToken), last);
+        }
+    }
+}
diff --git a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java
index 57ed34e373a..d0ab428b7f1 100644
--- a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java
+++ b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java
@@ -16,6 +16,7 @@
 import ai.djl.engine.Engine;
 import ai.djl.engine.EngineException;
 import ai.djl.inference.Predictor;
+import ai.djl.inference.streaming.ChunkedBytesSupplier;
 import ai.djl.inference.streaming.PublisherBytesSupplier;
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
@@ -30,6 +31,7 @@
 import ai.djl.translate.NoopTranslator;
 import ai.djl.translate.TranslateException;
 import ai.djl.util.JsonUtils;
+import ai.djl.util.RandomUtils;
 
 import com.google.gson.JsonElement;
 import com.google.gson.reflect.TypeToken;
@@ -365,6 +367,39 @@ public void testHuggingfaceModel() throws TranslateException, IOException, Model
         }
     }
 
+    @Test
+    public void testRollingBatch() throws TranslateException, IOException, ModelException {
+        Criteria<Input, Output> criteria =
+                Criteria.builder()
+                        .setTypes(Input.class, Output.class)
+                        .optEngine("Python")
+                        .optModelPath(Paths.get("src/test/resources/rolling_batch"))
+                        .build();
+        try (ZooModel<Input, Output> model = criteria.loadModel();
+                Predictor<Input, Output> predictor = model.newPredictor()) {
+            List<Output> list = new ArrayList<>();
+            for (int i = 0; i < 5; ++i) {
+                Input input = new Input();
+                input.add(
+                        "{\"inputs\": \"request"
\"request" + + i + + "\", \"parameters\": {\"max_length\": " + + (RandomUtils.nextInt(10) + 5) + + "}}"); + input.addProperty("Content-Type", "application/json"); + Output output = predictor.predict(input); + list.add(output); + } + + Output output = list.get(4); + ChunkedBytesSupplier cbs = (ChunkedBytesSupplier) output.getData(); + Assert.assertNull(cbs.pollChunk()); + String ret = cbs.getAsString(); + System.out.println(ret); + Assert.assertTrue(ret.startsWith(" token_request4_")); + } + } + @Test public void testModelException() throws TranslateException, IOException, ModelException { Criteria criteria = diff --git a/engines/python/src/test/resources/rolling_batch/model.py b/engines/python/src/test/resources/rolling_batch/model.py new file mode 100644 index 00000000000..715122d5dc7 --- /dev/null +++ b/engines/python/src/test/resources/rolling_batch/model.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +""" +PyTorch resnet18 model example. +""" + +import logging +import time + +from djl_python import Input +from djl_python import Output + + +class Request(object): + + def __init__(self, input_text: str, max_length, initial: bool = False): + self.input_text = input_text + self.max_length = max_length + self.token_sent = 0 + self.next_token = None + self.initial = initial + + def set_next_token(self, next_token: str): + self.next_token = next_token + self.initial = False + + def get_next_token(self) -> str: + self.token_sent += 1 + return f" token_{self.input_text}_{self.token_sent}" + + def is_last_token(self) -> bool: + return self.token_sent >= self.max_length + + +class RollingBatch(object): + """ + Resnet18 Model implementation. + """ + + def __init__(self): + self.pending_requests = [] + self.initialized = False + self.max_rolling_batch_size = None + + def initialize(self, properties: dict): + """ + Initialize model. + """ + self.max_rolling_batch_size = int( + properties.get("max_rolling_batch_size", "3")) + self.initialized = True + + def inference(self, inputs): + """ + Custom service entry point function. 
+
+        :param inputs: the Input object holding a batch of requests
+        :return: the Output object to be sent back
+        """
+        outputs = Output()
+        try:
+            batch_size = inputs.get_batch_size()
+            if batch_size < len(self.pending_requests):
+                raise ValueError("mismatched rolling batch requests")
+
+            batch = inputs.get_batches()
+            self._merge_request(batch)
+            time.sleep(0.1)
+
+            for i in range(batch_size):
+                req = self.pending_requests[i]
+                res = {
+                    "data": req.get_next_token(),
+                    "last": req.is_last_token(),
+                }
+                outputs.add_as_json(res, batch_index=i)
+
+            # remove finished requests from pending_requests
+            for i in range(1, batch_size + 1):
+                if self.pending_requests[batch_size - i].is_last_token():
+                    self.pending_requests.pop(batch_size - i)
+        except Exception as e:
+            logging.exception("rolling batch inference failed")
+            # error handling
+            outputs = Output().error(str(e))
+
+        return outputs
+
+    def _merge_request(self, batch):
+        for i, item in enumerate(batch):
+            input_map = item.get_as_json()
+            data = input_map.pop("inputs", input_map)
+            parameters = input_map.pop("parameters", {})
+            if i >= len(self.pending_requests):
+                max_length = parameters.pop("max_length", 50)
+                self.pending_requests.append(
+                    Request(data, max_length, initial=True))
+            else:
+                self.pending_requests[i].set_next_token(data)
+
+
+_service = RollingBatch()
+
+
+def handle(inputs: Input):
+    """
+    Default handler function
+    """
+    if not _service.initialized:
+        # stateful model
+        _service.initialize(inputs.get_properties())
+
+    if inputs.is_empty():
+        # initialization request
+        return None
+
+    return _service.inference(inputs)
diff --git a/engines/python/src/test/resources/rolling_batch/serving.properties b/engines/python/src/test/resources/rolling_batch/serving.properties
new file mode 100644
index 00000000000..9c4c639af48
--- /dev/null
+++ b/engines/python/src/test/resources/rolling_batch/serving.properties
@@ -0,0 +1,2 @@
+option.rolling_batch=true
+option.max_rolling_batch_size=3
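Reviewer note: the per-step contract between the Java RollingBatch and the Python handler is easy to miss in the diff, so here is a minimal standalone sketch of it, distilled from the test model above. The "batch_size" property, the "batch_{i}.data" request keys, and the per-request JSON response fields "data" and "last" all come from this patch; the step() helper, its max_length parameter, and the fixed stopping rule are illustrative assumptions, not part of the committed code.

import json


def step(batch_payloads, tokens_sent, max_length=5):
    """One decoding step over the current rolling batch (illustrative only).

    :param batch_payloads: raw payloads, one per "batch_{i}.data" slot
    :param tokens_sent: mutable per-slot counters of tokens already emitted
    :param max_length: assumed stopping rule for this sketch
    :return: one JSON string per slot, each carrying "data" and "last"
    """
    responses = []
    for i, payload in enumerate(batch_payloads):
        tokens_sent[i] += 1
        responses.append(
            json.dumps({
                # next generated token for this request
                "data": f" token_{payload}_{tokens_sent[i]}",
                # whether this request has finished generating
                "last": tokens_sent[i] >= max_length,
            }))
    return responses


if __name__ == "__main__":
    counters = [0, 0]
    print(step(["request0", "request1"], counters))
    print(step(["request0", "request1"], counters))

On the Java side, Request.addResponse() reads the "last" flag to decide when list.removeIf(...) drops a request from the batch, so a handler that never sets "last" to true would occupy its batch slot indefinitely and eventually block addInput() on the canAdd condition.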