Skip to content

Commit

Permalink
[python] Fixes device mismatch issue for streaming token (deepjavalib…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jun 6, 2023
1 parent fa01765 commit b72e5d8
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 27 deletions.
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,9 @@ def inference(self, inputs: Input):
else:
stream_generator = StreamingUtils.get_stream_generator(
"DeepSpeed")
device = torch.cuda.current_device()
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer,
stream_generator(self.model, self.tokenizer, device,
input_data, **model_kwargs))
return outputs
if self.task == "text-generation":
Expand Down
13 changes: 9 additions & 4 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ def inference(self, inputs):
else:
stream_generator = StreamingUtils.get_stream_generator(
"Accelerate")
device = "cpu" if self.device_id < 0 else f"cuda:{self.device_id}"
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer, data,
**parameters))
device, **parameters))
return outputs

prediction = self.hf_pipeline(data, **parameters)
Expand Down Expand Up @@ -198,7 +199,10 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
kwargs.pop("tokenizer", None)
model = AutoModelForCausalLM.from_pretrained(
model_id_or_path, **kwargs)
hf_pipeline = pipeline(task=task, model=model, tokenizer=tokenizer)
hf_pipeline = pipeline(task=task,
model=model,
tokenizer=tokenizer,
device=self.device_id)

# wrap specific pipeline to support better ux
if task == "conversational":
Expand All @@ -217,16 +221,17 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
padding_side="left")
device = "cpu" if self.device_id < 0 else f"cuda:{self.device_id}"
model_config = AutoConfig.from_pretrained(model_id_or_path,
kwargs=kwargs)
architectures = model_config.architectures
if architectures and architectures[0].endswith(
"ForConditionalGeneration"):
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_id_or_path, **kwargs)
model_id_or_path, **kwargs).to(device)
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_id_or_path, **kwargs)
model_id_or_path, **kwargs).to(device)

@staticmethod
def wrap_conversation_pipeline(hf_pipeline):
Expand Down
31 changes: 10 additions & 21 deletions engines/python/setup/djl_python/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __next__(self):


class StreamingUtils:

DEFAULT_MAX_NEW_TOKENS = 50
SUPPORTED_MODEL_ARCH_SUFFIXES_CAUSAL_LM = ("CausalLM", "GPT2LMHeadModel")
SUPPORTED_MODEL_ARCH_SUFFIXES_SEQ_2_SEQ_LM = (
Expand Down Expand Up @@ -104,7 +103,7 @@ def get_stream_generator(execution_engine: str):

@staticmethod
@torch.inference_mode()
def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
StreamingUtils._validate_inputs(model, inputs)
generic_model_class = StreamingUtils._get_generic_model_class(model)
if not tokenizer.pad_token:
Expand All @@ -115,16 +114,14 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):

max_new_tokens = kwargs.get("max_new_tokens",
StreamingUtils.DEFAULT_MAX_NEW_TOKENS)
tokenized_inputs = tokenizer(inputs, return_tensors="pt",
padding=True).to(
StreamingUtils._get_current_device())
input_ids = tokenized_inputs["input_ids"]
tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True)
input_ids = tokenized_inputs["input_ids"].to(device)
past_key_values = None
decoding_method = StreamingUtils._get_decoding_method(**kwargs)
new_tokens_count = 0
unfinished_sequences = torch.ones((len(inputs), 1),
dtype=torch.long,
device=input_ids.device)
device=device)
stop_generation = False
engine = None
if "engine" in kwargs.keys():
Expand All @@ -135,7 +132,7 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):

if generic_model_class == "CausalLM":
input_length = input_ids.shape[1]
all_decoder_input_ids = tokenized_inputs["input_ids"]
all_decoder_input_ids = input_ids
is_pad_token_equal_to_eos_token = tokenizer.pad_token == tokenizer.eos_token
attention_mask = input_ids.new_zeros(len(inputs),
input_length + max_new_tokens)
Expand All @@ -145,13 +142,12 @@ def _hf_model_stream_generator(model, tokenizer, inputs, **kwargs):
curr_length = input_length

if generic_model_class == "Seq2SeqLM":
attention_mask = tokenized_inputs["attention_mask"]
attention_mask = tokenized_inputs["attention_mask"].to(device)
decoder_attention_mask = None
encoder_last_hidden_state = None
decoder_input_ids = torch.tensor(
tokenizer.bos_token_id,
device=StreamingUtils._get_current_device()).repeat(
len(inputs)).view(-1, 1)
decoder_input_ids = torch.tensor(tokenizer.bos_token_id,
device=device).repeat(
len(inputs)).view(-1, 1)
all_decoder_input_ids = decoder_input_ids

while True:
Expand Down Expand Up @@ -299,7 +295,7 @@ def _sampling_decoding(logits, input_ids, **kwargs):
processors.append(TypicalLogitsWarper(mass=kwargs["typical_p"]))

logits[-1:, :] = processors(input_ids, logits[-1:, :])
generator = torch.Generator(StreamingUtils._get_current_device())
generator = torch.Generator(input_ids.device)
probs = torch.nn.functional.softmax(logits[-1])
if "manual_seed" in kwargs:
generator.manual_seed(kwargs["manual_seed"])
Expand All @@ -318,10 +314,3 @@ def _get_decoding_method(**kwargs):
return StreamingUtils._sampling_decoding
else:
return StreamingUtils._greedy_decoding

@staticmethod
def _get_current_device():
if torch.cuda.is_available():
return torch.device(torch.cuda.current_device())
else:
return torch.device("cpu")
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/transformers-neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def infer(self, inputs):
model_kwargs["engine"] = "transformers-neuronx"
outputs.add_stream_content(
stream_generator(self.model, self.tokenizer,
input_text, **model_kwargs))
input_text, "cpu", **model_kwargs))
return outputs

encoded_inputs = self.tokenizer.batch_encode_plus(
Expand Down

0 comments on commit b72e5d8

Please sign in to comment.