Add ignore_eos support to deep-speed and trt-llm backends (#69)
In order to make benchmarks more apples-to-apples, it is useful to be able to keep generating output even after the `eos` token has been emitted. I've added support for that to the DeepSpeed-MII and TRT-LLM backends.
Timur Abishev authored Nov 17, 2023
1 parent c91c354 commit cdb2520
Showing 3 changed files with 27 additions and 15 deletions.
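To make the new flag concrete, here is a minimal sketch of the kind of client request both backends now understand. The endpoint URL is an assumption for illustration, and the `max_length` key follows the deepspeed-mii backend (field names may differ per backend); only `ignore_eos` itself comes from this commit:

import requests

# Hypothetical model endpoint; the URL is illustrative, not part of this commit.
# Both backends pop `ignore_eos` from the request body (default: False).
resp = requests.post(
    "http://localhost:8080/v1/models/model:predict",
    json={
        "prompt": "Once upon a time",
        "max_length": 512,
        "ignore_eos": True,  # keep generating even after the model emits EOS
    },
)
print(resp.json())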
deepspeed-mii/config.yaml (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@ model_metadata:
 - text-generation
 python_version: py311
 requirements:
-  - deepspeed-mii==0.1.0
+  - deepspeed-mii==0.1.1
 resources:
   cpu: "3"
   memory: 14Gi
deepspeed-mii/model/model.py (3 changes: 2 additions & 1 deletion)

@@ -43,7 +43,8 @@ def load(self):
     def predict(self, request: Dict):
         prompt = request.pop("prompt")
         generate_args = {
-            "max_new_tokens": request.pop("max_length", DEFAULT_RESPONSE_MAX_LENGTH)
+            "max_new_tokens": request.pop("max_length", DEFAULT_RESPONSE_MAX_LENGTH),
+            "ignore_eos": request.pop("ignore_eos", False),
         }
 
         if request.pop("stream", False):
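On the deepspeed-mii side the change is a passthrough: the flag lands in `generate_args`, and the pinned deepspeed-mii==0.1.1 accepts it as a generation kwarg. A standalone sketch of the equivalent direct call (the model name is illustrative, and the call shape is assumed from the mii 0.1.x pipeline API):

import mii

# Assumes deepspeed-mii==0.1.1, where `ignore_eos` is a supported generate kwarg.
pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")  # illustrative model choice
responses = pipe(["Once upon a time"], max_new_tokens=128, ignore_eos=True)
print(responses[0])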
llama/llama-2-7b-trt-llm/model/model.py (37 changes: 24 additions & 13 deletions)
@@ -4,6 +4,7 @@
 
 import numpy as np
 from client import TritonClient, UserData
+from transformers import AutoTokenizer
 from utils import download_engine, prepare_grpc_tensor
 
 TRITON_MODEL_REPOSITORY_PATH = Path("/packages/inflight_batcher_llm/")
@@ -21,7 +22,10 @@ def load(self):
         tensor_parallel_count = self._config["model_metadata"].get(
             "tensor_parallelism", 1
         )
-        is_hf_token = "hf_access_token" in self._secrets._base_secrets.keys()
+        if "hf_access_token" in self._secrets._base_secrets.keys():
+            hf_access_token = self._secrets["hf_access_token"]
+        else:
+            hf_access_token = None
         is_external_engine_repo = "engine_repository" in self._config["model_metadata"]
 
         # Instantiate TritonClient
@@ -36,20 +40,23 @@
             download_engine(
                 engine_repository=self._config["model_metadata"]["engine_repository"],
                 fp=self._data_dir,
-                auth_token=self._secrets["hf_access_token"] if is_hf_token else None,
+                auth_token=hf_access_token,
             )
 
         # Load Triton Server and model
-        env = {
-            "triton_tokenizer_repository": self._config["model_metadata"][
-                "tokenizer_repository"
-            ],
-        }
-        if is_hf_token:
-            env["HUGGING_FACE_HUB_TOKEN"] = self._secrets["hf_access_token"]
+        tokenizer_repository = self._config["model_metadata"]["tokenizer_repository"]
+        env = {"triton_tokenizer_repository": tokenizer_repository}
+        if hf_access_token is not None:
+            env["HUGGING_FACE_HUB_TOKEN"] = hf_access_token
 
         self.triton_client.load_server_and_model(env=env)
 
+        # setup eos token
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_repository, token=hf_access_token
+        )
+        self.eos_token_id = tokenizer.eos_token_id
+
     def predict(self, model_input):
         user_data = UserData()
         model_name = "ensemble"
@@ -61,7 +68,7 @@ def predict(self, model_input):
         bad_words_list = model_input.get("bad_words_list", [""])
         stop_words_list = model_input.get("stop_words_list", [""])
         repetition_penalty = model_input.get("repetition_penalty", 1.0)
-        end_id = model_input.get("end_id", 2)  # EOS token for Llama2
+        ignore_eos = model_input.get("ignore_eos", False)
 
         input0 = [[prompt]]
         input0_data = np.array(input0).astype(object)
@@ -73,8 +80,6 @@
         beam_width = [[beam_width]]
         beam_width_data = np.array(beam_width, dtype=np.uint32)
         repetition_penalty_data = np.array([[repetition_penalty]], dtype=np.float32)
-        end_id = [[end_id]]
-        end_id_data = np.array(end_id, dtype=np.uint32)
 
         inputs = [
             prepare_grpc_tensor("text_input", input0_data),
@@ -84,9 +89,15 @@
             prepare_grpc_tensor("stream", streaming_data),
             prepare_grpc_tensor("beam_width", beam_width_data),
             prepare_grpc_tensor("repetition_penalty", repetition_penalty_data),
-            prepare_grpc_tensor("end_id", end_id_data),
         ]
 
+        if not ignore_eos:
+            end_id_data = np.array([[self.eos_token_id]], dtype=np.uint32)
+            inputs.append(prepare_grpc_tensor("end_id", end_id_data))
+        else:
+            # do nothing, trt-llm by default doesn't stop on `eos`
+            pass
+
         # Start GRPC stream in a separate thread
         stream_thread = Thread(
             target=self.triton_client.start_grpc_stream,
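The mechanism on the TRT-LLM side is worth spelling out: per the comment in the diff, TRT-LLM only stops on `eos` when an `end_id` tensor is supplied, so omitting the tensor is what implements `ignore_eos`; when stopping is wanted, the tensor now carries the tokenizer's real EOS id instead of the previously hardcoded Llama 2 value of 2. A self-contained sketch of that decision (the helper name is hypothetical; the tensor name and dtype follow the diff above):

import numpy as np
from transformers import AutoTokenizer

def end_id_inputs(ignore_eos: bool, tokenizer_repository: str) -> list:
    # Hypothetical helper mirroring the predict() logic above: return the
    # (name, tensor) pairs to append, or nothing to let generation run past EOS.
    if ignore_eos:
        return []  # no end_id tensor -> TRT-LLM generates up to the token limit
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repository)
    end_id_data = np.array([[tokenizer.eos_token_id]], dtype=np.uint32)
    return [("end_id", end_id_data)]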
