diff --git a/src/client_test.py b/src/client_test.py
index e4a28b887..c81ae495b 100644
--- a/src/client_test.py
+++ b/src/client_test.py
@@ -92,7 +92,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
              document_source_substrings_op='and',
              document_content_substrings=[],
              document_content_substrings_op='and',
-             max_time=20,
+             max_time=200,  # nominally want test to complete, not exercise timeout code (llama.cpp gets stuck behind file lock if prior generation is still going)
              repetition_penalty=1.0,
              do_sample=True,
              seed=0,
diff --git a/src/gpt_langchain.py b/src/gpt_langchain.py
index f6abdb83e..a2f3ae586 100644
--- a/src/gpt_langchain.py
+++ b/src/gpt_langchain.py
@@ -2671,6 +2671,7 @@ def get_llm(use_openai_model=False,
                                              base_model=model_name,
                                              verbose=verbose,
                                              truncation_generation=truncation_generation,
+                                             max_time=max_time,
                                              **gen_kwargs)
         # pipe.task = "text-generation"
         # below makes it listen only to our prompt removal,
diff --git a/src/h2oai_pipeline.py b/src/h2oai_pipeline.py
index e7dc461a7..1062cff3a 100644
--- a/src/h2oai_pipeline.py
+++ b/src/h2oai_pipeline.py
@@ -18,6 +18,7 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
                  base_model=None,
                  stop=None,
                  truncation_generation=None,
+                 max_time=None,
                  verbose=False,
                  **kwargs):
         """
@@ -65,6 +66,7 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
         self.base_model = base_model
         self.verbose = verbose
         self.truncation_generation = truncation_generation
+        self.max_time = max_time

     @staticmethod
     def get_token_count(x, tokenizer):
@@ -254,7 +256,8 @@ def _forward(self, model_inputs, **generate_kwargs):
                                                   model_max_length=self.tokenizer.model_max_length,
                                                   prompter=self.prompter,
                                                   stop=stop,
-                                                  truncation_generation=self.truncation_generation)
+                                                  truncation_generation=self.truncation_generation,
+                                                  max_time=self.max_time)
             generate_kwargs['stopping_criteria'] = self.stopping_criteria
             generate_kwargs.pop('stop', None)
             # return super()._forward(model_inputs, **generate_kwargs)
diff --git a/src/stopping.py b/src/stopping.py
index 67e1e59dd..aa4b17a22 100644
--- a/src/stopping.py
+++ b/src/stopping.py
@@ -1,3 +1,5 @@
+import time
+
 import torch
 from transformers import StoppingCriteria, StoppingCriteriaList

@@ -7,7 +9,7 @@ class StoppingCriteriaSub(StoppingCriteria):

     def __init__(self, stops=[], stop_words=[], encounters=[], device="cuda", model_max_length=None, tokenizer=None,
-                 truncation_generation=False):
+                 truncation_generation=False, max_time=None):
         super().__init__()
         assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
         self.encounters = encounters
@@ -21,8 +23,13 @@ def __init__(self, stops=[], stop_words=[], encounters=[], device="cuda", model_
         # not setup for handling existing prompt, only look at new tokens, some models like xwin have funny token handling,
         # and despite new tokens present the block looks back into different sized output and matches the stop token
         self.look_at_new_tokens_only = max(self.encounters) == 1
+        self.max_time = max_time
+        self.t0 = time.time()

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        if self.max_time is not None and (time.time() - self.t0) > self.max_time:
+            print("Stopping: Took too long: %s" % self.max_time)
+            return True
         # if self.tokenizer:
         #     print('stop: %s' % self.tokenizer.decode(input_ids[0]), flush=True)
         if self.token_start is None:
@@ -54,7 +61,8 @@
 def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
                  human=':', bot=":",
                  model_max_length=None, prompter=None, stop=None,
-                 truncation_generation=False):
+                 truncation_generation=False,
+                 max_time=None):
     stop_words = []
     encounters = []
     # FIXME: prompt_dict unused currently
@@ -159,7 +167,8 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
                                  stop_words=stop_words, encounters=encounters, device=device,
                                  model_max_length=model_max_length, tokenizer=tokenizer,
-                                 truncation_generation=truncation_generation)])
+                                 truncation_generation=truncation_generation,
+                                 max_time=max_time)])
     else:
         # nothing to stop on
         stopping_criteria = StoppingCriteriaList()
diff --git a/src/utils_langchain.py b/src/utils_langchain.py
index f9645a168..3ded4bbf3 100644
--- a/src/utils_langchain.py
+++ b/src/utils_langchain.py
@@ -56,10 +56,12 @@ def on_llm_start(

     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         """Run on new LLM token. Only available when streaming is enabled."""
-        if self.tgen0 is not None and self.max_time is not None and (time.time() - self.tgen0) > self.max_time:
+        if False and \
+                self.tgen0 is not None and self.max_time is not None and (time.time() - self.tgen0) > self.max_time:
             if self.verbose:
                 print("Took too long in StreamingGradioCallbackHandler: %s" % (time.time() - self.tgen0), flush=True)
             self.text_queue.put(self.stop_signal)
+            self.do_stop = True
         else:
             self.text_queue.put(token)
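The net effect of the patch is that the generation time budget (`max_time`) now flows from `get_llm` through `H2OTextGenerationPipeline` into the HF stopping criteria, so generation is cut off inside `model.generate()` itself rather than only in the streaming callback (which is why the `on_llm_new_token` check is disabled with `if False`). Below is a minimal, self-contained sketch of that wall-clock stopping behavior; `MaxTimeCriteria` is a hypothetical name used only for illustration, not a class from this patch (the patch extends the existing `StoppingCriteriaSub` instead).

```python
import time

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class MaxTimeCriteria(StoppingCriteria):
    """Stop generation once max_time seconds have elapsed since construction."""

    def __init__(self, max_time=None):
        super().__init__()
        self.max_time = max_time
        self.t0 = time.time()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Returning True signals generate() to stop producing new tokens.
        return self.max_time is not None and (time.time() - self.t0) > self.max_time


if __name__ == "__main__":
    crit = MaxTimeCriteria(max_time=0.1)
    dummy_ids = torch.ones((1, 4), dtype=torch.long)
    dummy_scores = torch.zeros((1, 8))
    print(crit(dummy_ids, dummy_scores))  # budget not yet used up -> no stop
    time.sleep(0.2)
    print(crit(dummy_ids, dummy_scores))  # budget exceeded -> stop
    # In real use the criterion is passed to generation, e.g.:
    # model.generate(..., stopping_criteria=StoppingCriteriaList([crit]))
```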