
Commit

Don't stream with repair_json once long, else gets too slow, do at end always
pseudotensor committed Nov 1, 2024
1 parent 1f0a8f3 commit 89d4051
Showing 3 changed files with 73 additions and 18 deletions.
3 changes: 3 additions & 0 deletions src/enums.py
@@ -943,3 +943,6 @@ def gr_to_lg(image_audio_loaders,
'.wmf': 'WMF', '.emf': 'WMF', '.xbm': 'XBM', '.xpm': 'XPM'}

VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}


max_stream_string_for_json = 1000
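
The new max_stream_string_for_json constant caps how long a partial response can grow before per-chunk JSON repair is skipped during streaming. Below is a minimal, self-contained sketch of the pattern the gen.py changes apply at each streaming site (repair_json_cheaply and chunks are illustrative stand-ins, not actual h2oGPT helpers; only get_json, json_schema_type, and the threshold mirror names from this commit):

MAX_STREAM_STRING_FOR_JSON = 1000  # mirrors max_stream_string_for_json above

def repair_json_cheaply(chunks, get_json, json_schema_type=None):
    """Yield partial responses, repairing JSON only while the text is short."""
    response = ''
    for chunk in chunks:
        response += chunk
        # During streaming, only run the (potentially slow) JSON repair/parse
        # while the accumulated text is still short.
        if len(str(response)) < MAX_STREAM_STRING_FOR_JSON:
            yield get_json(response, json_schema_type=json_schema_type)
        else:
            yield response
    # Always repair once at the end, in case it was skipped during streaming.
    yield get_json(response, json_schema_type=json_schema_type)

Skipping repair on long partial strings keeps streaming responsive; the unconditional call after the loop guarantees the final yielded response is still parsed/repaired JSON.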
86 changes: 69 additions & 17 deletions src/gen.py
@@ -79,7 +79,8 @@
is_json_model, is_vision_model, \
model_state_none0, other_model_state_defaults0, image_batch_image_prompt0, image_batch_final_prompt0, \
tokens_per_image, openai_supports_functiontools, openai_supports_parallel_functiontools, does_support_functiontools, \
json_object_post_prompt_reminder0, json_code_post_prompt_reminder0, json_code2_post_prompt_reminder0
json_object_post_prompt_reminder0, json_code_post_prompt_reminder0, json_code2_post_prompt_reminder0, \
max_stream_string_for_json

from utils import set_seed, clear_torch_cache, NullContext, wrapped_partial, EThread, get_githash, \
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, \
@@ -3443,7 +3444,9 @@ def evaluate(
response = r['response']
if response_format in ['json_object', 'json_code']:
response_raw = response
response = get_json(response, json_schema_type=json_schema_type)
# this can get expensive if long, so only do if small, else do only at end
if len(str(response)) < max_stream_string_for_json:
response = get_json(response, json_schema_type=json_schema_type)
sources = r['sources']
num_prompt_tokens = r['num_prompt_tokens']
ntokens = r.get('ntokens')
@@ -3462,6 +3465,9 @@ def evaluate(
sources=sources,
ntokens=ntokens,
))
if response_format in ['json_object', 'json_code']:
# always do at end, in case didn't before due to length
response = get_json(response, json_schema_type=json_schema_type)
save_dict.update(dict(prompt=prompt, output=response, where_from="run_qa_db", extra_dict=extra_dict))
yield dict(response=response, sources=sources, save_dict=save_dict, llm_answers=llm_answers,
response_no_refs=response_no_refs, sources_str=sources_str, prompt_raw=prompt_raw)
@@ -3557,15 +3563,17 @@ def evaluate(
if base_model in ['o1-mini', 'o1-preview']:
gen_server_kwargs['max_completion_tokens'] = gen_server_kwargs.pop('max_tokens')
max_reasoning_tokens = int(os.getenv("MAX_REASONING_TOKENS", 25000))
gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(100, gen_server_kwargs['max_completion_tokens'])
gen_server_kwargs['max_completion_tokens'] = max_reasoning_tokens + max(100, gen_server_kwargs[
'max_completion_tokens'])
gen_server_kwargs['temperature'] = 1.0
gen_server_kwargs.pop('presence_penalty', None)
gen_server_kwargs.pop('n', None)
gen_server_kwargs.pop('frequency_penalty', None)
gen_server_kwargs.pop('top_p', None)
try:
if inf_type in ['vllm', 'vllm_chat'] and chosen_model_state['json_vllm']:
response_format_real = response_format if not (guided_json or guided_regex or guided_choice or guided_grammar) else 'text'
response_format_real = response_format if not (
guided_json or guided_regex or guided_choice or guided_grammar) else 'text'
vllm_extra_dict = get_vllm_extra_dict(tokenizer, stop_sequences=stop_sequences,
response_format=response_format_real,
guided_json=guided_json,
@@ -3613,7 +3621,8 @@ def evaluate(
sanitize_bot_response=sanitize_bot_response)
if response_format in ['json_object', 'json_code']:
response_raw = response
response = get_json(response, json_schema_type=json_schema_type)
if len(str(response)) < max_stream_string_for_json:
response = get_json(response, json_schema_type=json_schema_type)
yield dict(response=response, sources=sources, save_dict={},
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw='')
@@ -3622,6 +3631,12 @@ def evaluate(
print("Took too long for OpenAI or VLLM: %s" % (time.time() - tgen0), flush=True)
break
time.sleep(0.005)
if response_format in ['json_object', 'json_code']:
# always do at end, in case didn't before due to length
response = get_json(response, json_schema_type=json_schema_type)
yield dict(response=response, sources=sources, save_dict={},
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw='')
elif inf_type in ['vllm_chat', 'openai_chat']:
other_dict = dict(timeout=max_time)
if system_prompt in [None, 'None', 'auto']:
@@ -3681,7 +3696,8 @@ def evaluate(
response = text
if response_format in ['json_object', 'json_code']:
response_raw = response
response = get_json(response, json_schema_type=json_schema_type)
if len(str(response)) < max_stream_string_for_json:
response = get_json(response, json_schema_type=json_schema_type)
else:
# NOTE: If some stream failure like wrong model, don't get back response and no failure
tgen0 = time.time()
@@ -3701,6 +3717,12 @@ def evaluate(
print("Took too long for OpenAI or VLLM Chat: %s" % (time.time() - tgen0),
flush=True)
break
if response_format in ['json_object', 'json_code']:
# always do at end, in case didn't before due to length
response = get_json(response, json_schema_type=json_schema_type)
yield dict(response=response, sources=sources, save_dict={},
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw='')
else:
raise RuntimeError("No such OpenAI mode: %s" % inference_server)
finally:
@@ -3783,7 +3805,8 @@ def evaluate(
for response1 in get_llava_stream(**llava_kwargs):
if response_format in ['json_object', 'json_code']:
response_raw = response1
response = get_json(response1, json_schema_type=json_schema_type)
if len(str(response)) < max_stream_string_for_json:
response = get_json(response1, json_schema_type=json_schema_type)
else:
response = response1
yield dict(response=response, sources=[], save_dict={}, error='',
@@ -3794,7 +3817,12 @@ def evaluate(
if verbose:
print("Took too long for TGI: %s" % (time.time() - tgen0), flush=True)
break

if response_format in ['json_object', 'json_code']:
# always do at end, in case didn't before due to length
response = get_json(response, json_schema_type=json_schema_type)
yield dict(response=response, sources=sources, save_dict={},
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw='')
else:
if gr_client is not None:
# Note: h2oGPT gradio server could handle input token size issues for prompt,
@@ -3992,16 +4020,27 @@ def evaluate(
gener = gr_client.simple_stream(**gr_stream_kwargs)
response = ''
response_raw = ''
for res_dict in gener:
if 'response' in res_dict:
response = res_dict['response']
res_dict = {}
for res_dict1 in gener:
if 'response' in res_dict1:
response = res_dict1['response']
if response_format in ['json_object', 'json_code']:
response_raw = response
response = get_json(response, json_schema_type=json_schema_type)
res_dict['response'] = response
res_dict['llm_answers'] = res_dict.get('llm_answers', {})
res_dict['llm_answers']['response_raw'] = response_raw
if len(str(response)) < max_stream_string_for_json:
response = get_json(response, json_schema_type=json_schema_type)
res_dict1['response'] = response
res_dict1['llm_answers'] = res_dict1.get('llm_answers', {})
res_dict1['llm_answers']['response_raw'] = response_raw
res_dict = res_dict1
yield res_dict1
if response_format in ['json_object', 'json_code']:
# always do at end, in case didn't before due to length
response = get_json(response, json_schema_type=json_schema_type)
res_dict['response'] = response
res_dict['llm_answers'] = res_dict.get('llm_answers', {})
res_dict['llm_answers']['response_raw'] = response_raw
yield res_dict

# listen to inner gradio
num_prompt_tokens += res_dict.get('save_dict', {}).get('extra_dict', {}).get('num_prompt_tokens',
num_prompt_tokens)
@@ -4059,7 +4098,8 @@ def evaluate(
sources = []
if response_format in ['json_object', 'json_code']:
response_raw = response
response = get_json(response, json_schema_type=json_schema_type)
if len(str(response)) < max_stream_string_for_json:
response = get_json(response, json_schema_type=json_schema_type)
yield dict(response=response, sources=sources, save_dict={},
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw='')
@@ -4068,6 +4108,12 @@ def evaluate(
if verbose:
print("Took too long for TGI: %s" % (time.time() - tgen0), flush=True)
break
if response_format in ['json_object', 'json_code']:
# always do at end, in case didn't before due to length
response = get_json(response, json_schema_type=json_schema_type)
yield dict(response=response, sources=sources, save_dict={},
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw='')
else:
raise RuntimeError("Failed to get client: %s" % inference_server)
if isinstance(model, GradioClient) and not regenerate_gradio_clients and gr_client is not None:
@@ -4257,7 +4303,8 @@ def evaluate(
sanitize_bot_response=sanitize_bot_response)
if response_format in ['json_object', 'json_code']:
response_raw = response
response = get_json(response, json_schema_type=json_schema_type)
if len(str(response)) < max_stream_string_for_json:
response = get_json(response, json_schema_type=json_schema_type)
ret = dict(response=response, sources=sources, save_dict=save_dict,
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw=prompt)
@@ -4267,6 +4314,11 @@ def evaluate(
if verbose:
print("Took too long for Torch: %s" % (time.time() - tgen0), flush=True)
break
if response_format in ['json_object', 'json_code']:
response = get_json(response, json_schema_type=json_schema_type)
ret = dict(response=response, sources=sources, save_dict=save_dict,
llm_answers=dict(response_raw=response_raw),
response_no_refs=response, sources_str='', prompt_raw=prompt)
if stream_output:
# will yield at end if required
# yield if anything left over as can happen (FIXME: Understand better)
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
__version__ = "d6fa55bc87b9de6d13278b148946b9e4d54bceb9"
__version__ = "1f0a8f36fd8a811958eef2c671bc343d77a187fb"
