
Commit

Extend time for llama.cpp test since it gets blocked by a prior long generation and file lock. Add timeout for torch via stopping criteria.
pseudotensor committed Mar 27, 2024
1 parent b43b8ea commit 06b1bf2
Showing 5 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/client_test.py
@@ -92,7 +92,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
document_source_substrings_op='and',
document_content_substrings=[],
document_content_substrings_op='and',
-max_time=20,
+max_time=200,  # nominally want test to complete, not exercise timeout code (llama.cpp gets stuck behind file lock if prior generation is still going)
repetition_penalty=1.0,
do_sample=True,
seed=0,
1 change: 1 addition & 0 deletions src/gpt_langchain.py
@@ -2671,6 +2671,7 @@ def get_llm(use_openai_model=False,
base_model=model_name,
verbose=verbose,
truncation_generation=truncation_generation,
+max_time=max_time,
**gen_kwargs)
# pipe.task = "text-generation"
# below makes it listen only to our prompt removal,
5 changes: 4 additions & 1 deletion src/h2oai_pipeline.py
@@ -18,6 +18,7 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
base_model=None,
stop=None,
truncation_generation=None,
+max_time=None,
verbose=False,
**kwargs):
"""
@@ -65,6 +66,7 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
self.base_model = base_model
self.verbose = verbose
self.truncation_generation = truncation_generation
+self.max_time = max_time

@staticmethod
def get_token_count(x, tokenizer):
@@ -254,7 +256,8 @@ def _forward(self, model_inputs, **generate_kwargs):
model_max_length=self.tokenizer.model_max_length,
prompter=self.prompter,
stop=stop,
-truncation_generation=self.truncation_generation)
+truncation_generation=self.truncation_generation,
+max_time=self.max_time)
generate_kwargs['stopping_criteria'] = self.stopping_criteria
generate_kwargs.pop('stop', None)
# return super()._forward(model_inputs, **generate_kwargs)
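For context, a hedged sketch of the pattern h2oai_pipeline.py follows here: accept a max_time keyword in the pipeline constructor, store it, and feed it into the stopping criteria assembled in _forward(). The class name and the use of transformers' built-in MaxTimeCriteria below are illustrative assumptions, not the repo's exact code (H2OTextGenerationPipeline routes max_time through get_stopping instead).

```python
# Illustrative sketch only, assuming a recent transformers release.
from transformers import MaxTimeCriteria, StoppingCriteriaList, TextGenerationPipeline


class TimedTextGenerationPipeline(TextGenerationPipeline):
    def __init__(self, *args, max_time=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_time = max_time  # wall-clock budget in seconds; None disables it

    def _forward(self, model_inputs, **generate_kwargs):
        if self.max_time is not None:
            # The timer starts when the criteria object is created, i.e. per request.
            generate_kwargs["stopping_criteria"] = StoppingCriteriaList(
                [MaxTimeCriteria(self.max_time)])
        return super()._forward(model_inputs, **generate_kwargs)
```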
15 changes: 12 additions & 3 deletions src/stopping.py
@@ -1,3 +1,5 @@
+import time
+
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

@@ -7,7 +9,7 @@
class StoppingCriteriaSub(StoppingCriteria):

def __init__(self, stops=[], stop_words=[], encounters=[], device="cuda", model_max_length=None, tokenizer=None,
-truncation_generation=False):
+truncation_generation=False, max_time=None):
super().__init__()
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
self.encounters = encounters
@@ -21,8 +23,13 @@ def __init__(self, stops=[], stop_words=[], encounters=[], device="cuda", model_
# not setup for handling existing prompt, only look at new tokens, some models like xwin have funny token handling,
# and despite new tokens present the block looks back into different sized output and matches the stop token
self.look_at_new_tokens_only = max(self.encounters) == 1
+self.max_time = max_time
+self.t0 = time.time()

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+if self.max_time is not None and (time.time() - self.t0) > self.max_time:
+print("Stopping: Took too long: %s" % self.max_time)
+return True
# if self.tokenizer:
# print('stop: %s' % self.tokenizer.decode(input_ids[0]), flush=True)
if self.token_start is None:
@@ -54,7 +61,8 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
human='<human>:', bot="<bot>:", model_max_length=None,
prompter=None,
stop=None,
-truncation_generation=False):
+truncation_generation=False,
+max_time=None):
stop_words = []
encounters = []
# FIXME: prompt_dict unused currently
@@ -159,7 +167,8 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
stop_words=stop_words,
encounters=encounters, device=device,
model_max_length=model_max_length, tokenizer=tokenizer,
-truncation_generation=truncation_generation)])
+truncation_generation=truncation_generation,
+max_time=max_time)])
else:
# nothing to stop on
stopping_criteria = StoppingCriteriaList()
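The core of the change is the wall-clock check inside StoppingCriteriaSub.__call__. A minimal, self-contained sketch of just that technique (the repo's class additionally matches stop words; transformers also ships a similar built-in MaxTimeCriteria):

```python
import time

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class TimeLimitCriteria(StoppingCriteria):
    """Stop generation once a wall-clock budget is exhausted."""

    def __init__(self, max_time=None):
        super().__init__()
        self.max_time = max_time  # seconds; None means no time limit
        self.t0 = time.time()     # the clock starts when the criteria is constructed

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True tells generate() to stop emitting further tokens.
        return self.max_time is not None and (time.time() - self.t0) > self.max_time


# Usage with whatever model/tokenizer you already have loaded:
# outputs = model.generate(**inputs,
#                          stopping_criteria=StoppingCriteriaList([TimeLimitCriteria(200)]))
```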
4 changes: 3 additions & 1 deletion src/utils_langchain.py
@@ -56,10 +56,12 @@ def on_llm_start(

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
-if self.tgen0 is not None and self.max_time is not None and (time.time() - self.tgen0) > self.max_time:
+if False and \
+self.tgen0 is not None and self.max_time is not None and (time.time() - self.tgen0) > self.max_time:
if self.verbose:
print("Took too long in StreamingGradioCallbackHandler: %s" % (time.time() - self.tgen0), flush=True)
self.text_queue.put(self.stop_signal)
+self.do_stop = True
else:
self.text_queue.put(token)

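The utils_langchain.py change disables the queue-based timeout (the if False and guard) and, inside that now-dead branch, sets a do_stop flag rather than pushing stop_signal; the actual timeout is now enforced by the stopping criteria above. For reference, a hedged, stand-alone sketch of that flag-based pattern, with illustrative names and the check left enabled:

```python
import queue
import time
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler


class TimedStreamingHandler(BaseCallbackHandler):
    """Stream tokens into a queue; raise a flag instead of injecting a stop sentinel."""

    def __init__(self, max_time=None):
        self.text_queue = queue.Queue()
        self.max_time = max_time  # seconds; None disables the check
        self.tgen0 = None
        self.do_stop = False

    def on_llm_start(self, serialized, prompts, **kwargs: Any) -> None:
        self.tgen0 = time.time()  # start the clock when generation begins

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if self.tgen0 is not None and self.max_time is not None \
                and (time.time() - self.tgen0) > self.max_time:
            self.do_stop = True  # the consumer of the queue decides how to halt
        else:
            self.text_queue.put(token)
```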
