diff --git a/src/enums.py b/src/enums.py
index 204df1175..0f388e749 100644
--- a/src/enums.py
+++ b/src/enums.py
@@ -943,3 +943,6 @@ def gr_to_lg(image_audio_loaders,
                     '.wmf': 'WMF', '.emf': 'WMF', '.xbm': 'XBM', '.xpm': 'XPM'}
 VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
+
+
+max_stream_string_for_json = 1000
diff --git a/src/gen.py b/src/gen.py
index 16f462de5..9f5f25620 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -79,7 +79,8 @@
     is_json_model, is_vision_model, \
     model_state_none0, other_model_state_defaults0, image_batch_image_prompt0, image_batch_final_prompt0, \
     tokens_per_image, openai_supports_functiontools, openai_supports_parallel_functiontools, does_support_functiontools, \
-    json_object_post_prompt_reminder0, json_code_post_prompt_reminder0, json_code2_post_prompt_reminder0
+    json_object_post_prompt_reminder0, json_code_post_prompt_reminder0, json_code2_post_prompt_reminder0, \
+    max_stream_string_for_json
 
 from utils import set_seed, clear_torch_cache, NullContext, wrapped_partial, EThread, get_githash, \
     import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, \
@@ -3443,7 +3444,9 @@ def evaluate(
                     response = r['response']
                     if response_format in ['json_object', 'json_code']:
                         response_raw = response
-                        response = get_json(response, json_schema_type=json_schema_type)
+                        # this can get expensive if long, so only do if small, else do only at end
+                        if len(str(response)) < max_stream_string_for_json:
+                            response = get_json(response, json_schema_type=json_schema_type)
                     sources = r['sources']
                     num_prompt_tokens = r['num_prompt_tokens']
                     ntokens = r.get('ntokens')
@@ -3462,6 +3465,9 @@ def evaluate(
                                       sources=sources,
                                       ntokens=ntokens,
                                       ))
+        if response_format in ['json_object', 'json_code']:
+            # always do at end, in case didn't before due to length
+            response = get_json(response, json_schema_type=json_schema_type)
         save_dict.update(dict(prompt=prompt, output=response, where_from="run_qa_db", extra_dict=extra_dict))
         yield dict(response=response, sources=sources, save_dict=save_dict, llm_answers=llm_answers,
                    response_no_refs=response_no_refs, sources_str=sources_str, prompt_raw=prompt_raw)
@@ -3557,7 +3563,8 @@ def evaluate(
             if base_model in ['o1-mini', 'o1-preview']:
                 gen_server_kwargs['max_completion_tokens'] = gen_server_kwargs.pop('max_tokens')
                 max_reasoning_tokens = int(os.getenv("MAX_REASONING_TOKENS", 25000))
-                gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(100, gen_server_kwargs['max_completion_tokens'])
+                gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(100, gen_server_kwargs[
+                    'max_completion_tokens'])
                 gen_server_kwargs['temperature'] = 1.0
                 gen_server_kwargs.pop('presence_penalty', None)
                 gen_server_kwargs.pop('n', None)
@@ -3565,7 +3572,8 @@ def evaluate(
                 gen_server_kwargs.pop('top_p', None)
             try:
                 if inf_type in ['vllm', 'vllm_chat'] and chosen_model_state['json_vllm']:
-                    response_format_real = response_format if not (guided_json or guided_regex or guided_choice or guided_grammar) else 'text'
+                    response_format_real = response_format if not (
+                            guided_json or guided_regex or guided_choice or guided_grammar) else 'text'
                     vllm_extra_dict = get_vllm_extra_dict(tokenizer, stop_sequences=stop_sequences,
                                                           response_format=response_format_real,
                                                           guided_json=guided_json,
@@ -3613,7 +3621,8 @@ def evaluate(
                                                         sanitize_bot_response=sanitize_bot_response)
                         if response_format in ['json_object', 'json_code']:
                             response_raw = response
-                            response = get_json(response, json_schema_type=json_schema_type)
+                            if len(str(response)) < max_stream_string_for_json:
+                                response = get_json(response, json_schema_type=json_schema_type)
                         yield dict(response=response, sources=sources, save_dict={},
                                    llm_answers=dict(response_raw=response_raw),
                                    response_no_refs=response, sources_str='', prompt_raw='')
@@ -3622,6 +3631,12 @@ def evaluate(
                                 print("Took too long for OpenAI or VLLM: %s" % (time.time() - tgen0), flush=True)
                             break
                         time.sleep(0.005)
+                    if response_format in ['json_object', 'json_code']:
+                        # always do at end, in case didn't before due to length
+                        response = get_json(response, json_schema_type=json_schema_type)
+                        yield dict(response=response, sources=sources, save_dict={},
+                                   llm_answers=dict(response_raw=response_raw),
+                                   response_no_refs=response, sources_str='', prompt_raw='')
                 elif inf_type in ['vllm_chat', 'openai_chat']:
                     other_dict = dict(timeout=max_time)
                     if system_prompt in [None, 'None', 'auto']:
@@ -3681,7 +3696,8 @@ def evaluate(
                         response = text
                         if response_format in ['json_object', 'json_code']:
                             response_raw = response
-                            response = get_json(response, json_schema_type=json_schema_type)
+                            if len(str(response)) < max_stream_string_for_json:
+                                response = get_json(response, json_schema_type=json_schema_type)
                     else:
                         # NOTE: If some stream failure like wrong model, don't get back response and no failure
                         tgen0 = time.time()
@@ -3701,6 +3717,12 @@ def evaluate(
                                     print("Took too long for OpenAI or VLLM Chat: %s" % (time.time() - tgen0),
                                           flush=True)
                                 break
+                        if response_format in ['json_object', 'json_code']:
+                            # always do at end, in case didn't before due to length
+                            response = get_json(response, json_schema_type=json_schema_type)
+                            yield dict(response=response, sources=sources, save_dict={},
+                                       llm_answers=dict(response_raw=response_raw),
+                                       response_no_refs=response, sources_str='', prompt_raw='')
                 else:
                     raise RuntimeError("No such OpenAI mode: %s" % inference_server)
             finally:
@@ -3783,7 +3805,8 @@ def evaluate(
                 for response1 in get_llava_stream(**llava_kwargs):
                     if response_format in ['json_object', 'json_code']:
                         response_raw = response1
-                        response = get_json(response1, json_schema_type=json_schema_type)
+                        if len(str(response)) < max_stream_string_for_json:
+                            response = get_json(response1, json_schema_type=json_schema_type)
                     else:
                         response = response1
                     yield dict(response=response, sources=[], save_dict={}, error='',
@@ -3794,7 +3817,12 @@ def evaluate(
                         if verbose:
                             print("Took too long for TGI: %s" % (time.time() - tgen0), flush=True)
                         break
-
+                if response_format in ['json_object', 'json_code']:
+                    # always do at end, in case didn't before due to length
+                    response = get_json(response, json_schema_type=json_schema_type)
+                    yield dict(response=response, sources=sources, save_dict={},
+                               llm_answers=dict(response_raw=response_raw),
+                               response_no_refs=response, sources_str='', prompt_raw='')
             else:
                 if gr_client is not None:
                     # Note: h2oGPT gradio server could handle input token size issues for prompt,
@@ -3992,16 +4020,27 @@ def evaluate(
                     gener = gr_client.simple_stream(**gr_stream_kwargs)
                     response = ''
                     response_raw = ''
-                    for res_dict in gener:
-                        if 'response' in res_dict:
-                            response = res_dict['response']
+                    res_dict = {}
+                    for res_dict1 in gener:
+                        if 'response' in res_dict1:
+                            response = res_dict1['response']
                             if response_format in ['json_object', 'json_code']:
                                 response_raw = response
-                                response = get_json(response, json_schema_type=json_schema_type)
-                        res_dict['response'] = response
-                        res_dict['llm_answers'] = res_dict.get('llm_answers', {})
-                        res_dict['llm_answers']['response_raw'] = response_raw
+                                if len(str(response)) < max_stream_string_for_json:
+                                    response = get_json(response, json_schema_type=json_schema_type)
+                            res_dict1['response'] = response
+                            res_dict1['llm_answers'] = res_dict1.get('llm_answers', {})
+                            res_dict1['llm_answers']['response_raw'] = response_raw
+                        res_dict = res_dict1
+                        yield res_dict1
+                    if response_format in ['json_object', 'json_code']:
+                        # always do at end, in case didn't before due to length
+                        response = get_json(response, json_schema_type=json_schema_type)
+                        res_dict['response'] = response
+                        res_dict['llm_answers'] = res_dict.get('llm_answers', {})
+                        res_dict['llm_answers']['response_raw'] = response_raw
                     yield res_dict
+                    # listen to inner gradio
                     num_prompt_tokens += res_dict.get('save_dict', {}).get('extra_dict', {}).get('num_prompt_tokens', num_prompt_tokens)
@@ -4059,7 +4098,8 @@ def evaluate(
                             sources = []
                         if response_format in ['json_object', 'json_code']:
                             response_raw = response
-                            response = get_json(response, json_schema_type=json_schema_type)
+                            if len(str(response)) < max_stream_string_for_json:
+                                response = get_json(response, json_schema_type=json_schema_type)
                         yield dict(response=response, sources=sources, save_dict={},
                                    llm_answers=dict(response_raw=response_raw),
                                    response_no_refs=response, sources_str='', prompt_raw='')
@@ -4068,6 +4108,12 @@ def evaluate(
                             if verbose:
                                 print("Took too long for TGI: %s" % (time.time() - tgen0), flush=True)
                             break
+                    if response_format in ['json_object', 'json_code']:
+                        # always do at end, in case didn't before due to length
+                        response = get_json(response, json_schema_type=json_schema_type)
+                        yield dict(response=response, sources=sources, save_dict={},
+                                   llm_answers=dict(response_raw=response_raw),
+                                   response_no_refs=response, sources_str='', prompt_raw='')
                 else:
                     raise RuntimeError("Failed to get client: %s" % inference_server)
         if isinstance(model, GradioClient) and not regenerate_gradio_clients and gr_client is not None:
@@ -4257,7 +4303,8 @@ def evaluate(
                                               sanitize_bot_response=sanitize_bot_response)
                 if response_format in ['json_object', 'json_code']:
                     response_raw = response
-                    response = get_json(response, json_schema_type=json_schema_type)
+                    if len(str(response)) < max_stream_string_for_json:
+                        response = get_json(response, json_schema_type=json_schema_type)
                 ret = dict(response=response, sources=sources, save_dict=save_dict,
                            llm_answers=dict(response_raw=response_raw),
                            response_no_refs=response, sources_str='', prompt_raw=prompt)
@@ -4267,6 +4314,11 @@ def evaluate(
                     if verbose:
                         print("Took too long for Torch: %s" % (time.time() - tgen0), flush=True)
                     break
+            if response_format in ['json_object', 'json_code']:
+                response = get_json(response, json_schema_type=json_schema_type)
+                ret = dict(response=response, sources=sources, save_dict=save_dict,
+                           llm_answers=dict(response_raw=response_raw),
+                           response_no_refs=response, sources_str='', prompt_raw=prompt)
             if stream_output:
                 # will yield at end if required
                 # yield if anything left over as can happen (FIXME: Understand better)
diff --git a/src/version.py b/src/version.py
index 3a03496ce..df1cb3f1d 100644
--- a/src/version.py
+++ b/src/version.py
@@ -1 +1 @@
-__version__ = "d6fa55bc87b9de6d13278b148946b9e4d54bceb9"
+__version__ = "1f0a8f36fd8a811958eef2c671bc343d77a187fb"
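The diff repeats one pattern at every streaming call site in gen.py: while chunks arrive, JSON extraction is attempted only when the partial response is shorter than max_stream_string_for_json, and a final extraction always runs once the stream ends so long responses are still parsed exactly once. Below is a minimal, self-contained sketch of that pattern for illustration only; it is not code from the repo. get_json_sketch, stream_with_bounded_json, and the cumulative chunk list are hypothetical stand-ins for the repo's get_json() and its generator plumbing.

import json

max_stream_string_for_json = 1000  # same threshold the diff adds to src/enums.py


def get_json_sketch(text, json_schema_type=None):
    # Hypothetical stand-in for get_json(): best-effort parse, else return text unchanged.
    try:
        return json.dumps(json.loads(text))
    except ValueError:
        return text


def stream_with_bounded_json(chunks, response_format='json_object', json_schema_type=None):
    # chunks: cumulative partial responses, as produced by a streaming backend.
    response = ''
    response_raw = ''
    for chunk in chunks:
        response = response_raw = chunk
        if response_format in ['json_object', 'json_code']:
            # this can get expensive if long, so only do if small, else do only at end
            if len(str(response)) < max_stream_string_for_json:
                response = get_json_sketch(response, json_schema_type=json_schema_type)
        yield dict(response=response, response_raw=response_raw)
    if response_format in ['json_object', 'json_code']:
        # always do at end, in case didn't before due to length
        response = get_json_sketch(response, json_schema_type=json_schema_type)
        yield dict(response=response, response_raw=response_raw)


if __name__ == '__main__':
    partials = ['{"answer"', '{"answer": "4', '{"answer": "42"}']
    for out in stream_with_bounded_json(partials):
        print(out)

The trailing extraction is why each call site in the diff also gains a final yield after its streaming loop: intermediate yields may carry raw text when the partial response exceeds the threshold, but the last yield always carries the parsed result.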