Skip to content

Commit

Permalink
[python] Adds rolling batch support (deepjavalibrary#828)
Browse files Browse the repository at this point in the history
* [python] Adds rolling batch support

* change to fair lock

---------

Co-authored-by: sindhuso <[email protected]>
  • Loading branch information
frankfliu and sindhuvahinis authored Jun 21, 2023
1 parent be62009 commit 40514cf
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 6 deletions.
7 changes: 4 additions & 3 deletions engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,10 @@ def inference(self, inputs: Input):
**model_kwargs))
return outputs
if self.task == "text-generation":
tokenized_inputs = self.tokenizer(
input_data, padding=True,
return_tensors="pt").to(self.device)
tokenized_inputs = self.tokenizer(input_data,
padding=True,
return_tensors="pt").to(
self.device)
with torch.no_grad():
output_tokens = self.model.generate(
input_ids=tokenized_inputs.input_ids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
String value = (String) entry.getValue();
if (!"env".equals(key)) {
pyEnv.addParameter(key, value);
properties.put(key, value);
}
logger.debug("{}={}", key, value);
switch (key) {
Expand Down
21 changes: 18 additions & 3 deletions engines/python/src/main/java/ai/djl/python/engine/PyPredictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class PyPredictor<I, O> extends Predictor<I, O> {

private PyProcess process;
private int timeout;
private boolean isRollingBatch;
private RollingBatch rollingBatch;

public PyPredictor(
Model model,
Expand All @@ -48,6 +50,12 @@ public PyPredictor(
super(model, translator, device, false);
this.process = process;
this.timeout = timeout;
isRollingBatch = Boolean.parseBoolean(model.getProperty("rolling_batch", "false"));
if (isRollingBatch) {
int maxRollingBatchSize =
Integer.parseInt(model.getProperty("max_rolling_batch_size", "3"));
rollingBatch = new RollingBatch(process, maxRollingBatchSize, timeout);
}
}

/** {@inheritDoc} */
Expand All @@ -62,7 +70,12 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
if (first instanceof Input) {
int size = inputs.size();
if (size == 1) {
Output output = process.predict((Input) first, timeout, false);
Output output;
if (isRollingBatch) {
output = rollingBatch.addInput((Input) first, timeout);
} else {
output = process.predict((Input) first, timeout, false);
}
return Collections.singletonList((O) output);
}

Expand Down Expand Up @@ -120,8 +133,7 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {

/** {@inheritDoc} */
@Override
protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
throws TranslateException {
protected NDList predictInternal(TranslatorContext ctx, NDList ndList) {
Input inputs = new Input();
inputs.addProperty("Content-Type", "tensor/ndlist");
inputs.add(ndList.encode());
Expand All @@ -135,5 +147,8 @@ protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
public void close() {
// Tear down in order: base predictor resources first, then the python
// worker process, and finally the rolling-batch dispatcher (which is
// only created when the "rolling_batch" model property is enabled).
super.close();
process.stopPythonProcess();
if (rollingBatch != null) {
rollingBatch.shutdown();
}
}
}
168 changes: 168 additions & 0 deletions engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.python.engine;

import ai.djl.inference.streaming.ChunkedBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;

import com.google.gson.JsonObject;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Aggregates concurrent text-generation requests into rolling batches that are executed on a
 * single {@link PyProcess}.
 *
 * <p>Callers enqueue requests via {@link #addInput(Input, int)} and immediately receive an {@link
 * Output} backed by a {@link ChunkedBytesSupplier}; a dedicated batching thread repeatedly sends
 * the current set of in-flight requests to the python worker and streams each token back to the
 * corresponding supplier until the worker marks a request as finished.
 */
class RollingBatch implements Runnable {

    private static final Logger logger = LoggerFactory.getLogger(RollingBatch.class);

    // Shared by all RollingBatch instances. Idle cached threads expire on their own, so this
    // pool is intentionally never shut down on behalf of a single instance (doing so would
    // break every other live RollingBatch).
    private static final ExecutorService threadPool = Executors.newCachedThreadPool();

    private final PyProcess process;
    private final int maxRollingBatchSize;
    private final int timeout; // seconds, forwarded to PyProcess.predict()
    private volatile boolean stop; // written by shutdown(), read by the batching thread
    private final List<Request> list; // in-flight requests, guarded by lock
    private volatile Thread currentThread; // set by run(), interrupted by shutdown()
    private final ReentrantLock lock;
    private final Condition canAdd; // signaled when a batch slot frees up
    private final Condition canRead; // signaled when a request is enqueued

    /**
     * Constructs a {@code RollingBatch} and starts its batching thread.
     *
     * @param process the python worker that executes each batched prediction
     * @param maxRollingBatchSize maximum number of concurrently in-flight requests
     * @param timeout per-prediction timeout in seconds
     */
    RollingBatch(PyProcess process, int maxRollingBatchSize, int timeout) {
        this.process = process;
        this.maxRollingBatchSize = maxRollingBatchSize;
        this.timeout = timeout;
        list = new ArrayList<>(3);
        // fair lock so waiting callers are admitted in arrival order
        lock = new ReentrantLock(true);
        canAdd = lock.newCondition();
        canRead = lock.newCondition();
        threadPool.submit(this);
    }

    /** {@inheritDoc} */
    @Override
    public void run() {
        currentThread = Thread.currentThread();
        while (!stop) {
            // lock() before try: if acquisition itself fails, finally must not unlock
            lock.lock();
            try {
                // await() in a loop: Condition is subject to spurious wakeups, and waking
                // with an empty list would dispatch a zero-sized batch to the worker
                while (list.isEmpty()) {
                    canRead.await();
                }

                Input batch = new Input();
                int size = list.size();
                batch.addProperty("batch_size", String.valueOf(size));
                for (int i = 0; i < size; ++i) {
                    Request req = list.get(i);
                    String prefix = "batch_" + i + ".data";
                    batch.add(prefix, req.getRequest());
                }

                // TODO: Handle error case

                Output output = process.predict(batch, timeout, false);
                PairList<String, BytesSupplier> content = output.getContent();
                if (content.size() != size) {
                    throw new TranslateException(
                            "Batch output size mismatch, expected: "
                                    + size
                                    + ", actual: "
                                    + content.size());
                }
                for (int i = 0; i < size; ++i) {
                    Request status = list.get(i);
                    String json = content.get(i).getValue().getAsString();
                    status.addResponse(json);
                }
                // drop finished requests; their suppliers received the final chunk above
                list.removeIf(status -> status.last);
                if (list.size() < maxRollingBatchSize) {
                    // removeIf may have freed several slots, wake every waiting producer
                    canAdd.signalAll();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // restore interrupt status
                break;
            } catch (TranslateException e) {
                logger.error("RollingBatch thread died, killing python process.", e);
                process.stopPythonProcess();
                break; // the worker is gone, no further batch can be served
            } finally {
                lock.unlock();
            }
        }
    }

    /**
     * Enqueues a request into the rolling batch, blocking while the batch is full.
     *
     * @param input the request payload
     * @param timeout maximum seconds to wait for a free batch slot
     * @return a streaming {@link Output} that is filled in by the batching thread
     * @throws TranslateException if no slot frees up within the timeout or the wait is
     *     interrupted
     */
    public Output addInput(Input input, int timeout) throws TranslateException {
        lock.lock();
        try {
            // Track the remaining wait explicitly so the capacity condition is re-checked
            // after every wakeup without extending the overall deadline.
            long remaining = TimeUnit.SECONDS.toNanos(timeout);
            while (list.size() >= maxRollingBatchSize) {
                if (remaining <= 0L) {
                    throw new TranslateException("Time out in: " + timeout);
                }
                remaining = canAdd.awaitNanos(remaining);
            }
            Request req = new Request(input);
            list.add(req);
            canRead.signal();
            return req.output;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore interrupt status
            throw new TranslateException("Interrupted", e);
        } finally {
            lock.unlock();
        }
    }

    /** Stops the batching thread. The shared static thread pool is left running. */
    public void shutdown() {
        stop = true;
        Thread thread = currentThread;
        if (thread != null) { // run() may not have started yet
            thread.interrupt();
        }
    }

    /** A single in-flight request and the streaming output it feeds. */
    private static final class Request {

        Input input;
        ChunkedBytesSupplier data; // chunks appended as tokens arrive
        Output output; // handed to the caller up front, backed by data
        String nextToken; // last token produced, resent as the next-step input
        boolean last; // worker signaled generation is complete

        Request(Input input) {
            this.input = input;
            data = new ChunkedBytesSupplier();
            output = new Output();
            output.add(data);
        }

        /** Returns the original payload on the first step, the previous token afterwards. */
        BytesSupplier getRequest() {
            if (nextToken != null) {
                return BytesSupplier.wrap("{\"inputs\": [\"" + nextToken + "\"]}");
            }
            return input.getData();
        }

        /** Parses one worker response step and streams the token to the caller. */
        void addResponse(String json) {
            JsonObject element = JsonUtils.GSON.fromJson(json, JsonObject.class);
            last = element.get("last").getAsBoolean();
            nextToken = element.get("data").getAsString();
            data.appendContent(BytesSupplier.wrap(nextToken), last);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.inference.streaming.ChunkedBytesSupplier;
import ai.djl.inference.streaming.PublisherBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
Expand All @@ -30,6 +31,7 @@
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.RandomUtils;

import com.google.gson.JsonElement;
import com.google.gson.reflect.TypeToken;
Expand Down Expand Up @@ -365,6 +367,39 @@ public void testHuggingfaceModel() throws TranslateException, IOException, Model
}
}

@Test
public void testRollingBatch() throws TranslateException, IOException, ModelException {
    // Load the rolling-batch test model through the Python engine.
    Criteria<Input, Output> criteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optEngine("Python")
                    .optModelPath(Paths.get("src/test/resources/rolling_batch"))
                    .build();
    try (ZooModel<Input, Output> model = criteria.loadModel();
            Predictor<Input, Output> predictor = model.newPredictor()) {
        // Issue five generation requests with randomized max_length so the
        // rolling batch retires them at different steps.
        List<Output> outputs = new ArrayList<>();
        for (int i = 0; i < 5; ++i) {
            int maxLength = RandomUtils.nextInt(10) + 5;
            String payload =
                    "{\"inputs\": \"request"
                            + i
                            + "\", \"parameters\": {\"max_length\": "
                            + maxLength
                            + "}}";
            Input request = new Input();
            request.add(payload);
            request.addProperty("Content-Type", "application/json");
            outputs.add(predictor.predict(request));
        }

        // The final request streams its tokens through a ChunkedBytesSupplier.
        ChunkedBytesSupplier supplier = (ChunkedBytesSupplier) outputs.get(4).getData();
        Assert.assertNull(supplier.pollChunk());
        String result = supplier.getAsString();
        System.out.println(result);
        Assert.assertTrue(result.startsWith(" token_request4_"));
    }
}

@Test
public void testModelException() throws TranslateException, IOException, ModelException {
Criteria<Input, Output> criteria =
Expand Down
Loading

0 comments on commit 40514cf

Please sign in to comment.