From a6c4298d27f740117d92476dc7fa10df8cd70014 Mon Sep 17 00:00:00 2001
From: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Date: Tue, 14 May 2024 09:24:29 -0700
Subject: [PATCH] Fix browser init (#797)

Update the prompt string that the browser chat waits for, which was
changed by https://github.com/pytorch/torchchat/pull/476

Change the logging defaults so that a log file is no longer created in
a temp folder without prompting the user; info messages are simply
printed to the console instead.

Replace a few `print` calls with `logging.info`. This way, information
about the bandwidth achieved is printed to the console, but not to the
web-browser chat window.

Test plan:
```
% python3 torchchat.py browser stories110M &
% curl -L http://127.0.0.1:5000
% curl -d "prompt=Once upon a time" -X POST http://127.0.0.1:5000/chat
```

TODOs:
- Add CI that repeats the above steps
- Figure out if spawning the generator from the browser can be avoided

Fixes https://github.com/pytorch/torchchat/issues/785
---
 chat_in_browser.py | 4 +++-
 cli.py             | 5 +----
 generate.py        | 6 +++---
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/chat_in_browser.py b/chat_in_browser.py
index f012ab105..e835fa009 100644
--- a/chat_in_browser.py
+++ b/chat_in_browser.py
@@ -35,7 +35,7 @@ def main():
         except:
             continue

-        if decoded.startswith("System Prompt") and decoded.endswith(": "):
+        if decoded.endswith("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"):
             print(f"| {decoded}")
             proc.stdin.write("\n".encode("utf-8"))
             proc.stdin.flush()
@@ -93,6 +93,8 @@ def chat():
     model_prefix = "Model: "
     if output.startswith(model_prefix):
         output = output[len(model_prefix) :]
+    else:
+        print("But output is", output)

     global convo

diff --git a/cli.py b/cli.py
index 5223cf3ce..a28a98863 100644
--- a/cli.py
+++ b/cli.py
@@ -15,10 +15,7 @@
 from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded

-FORMAT = (
-    "%(levelname)s: %(asctime)-15s: %(filename)s: %(funcName)s: %(module)s: %(message)s"
-)
-logging.basicConfig(filename="/tmp/torchchat.log", level=logging.INFO, format=FORMAT)
+logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)

 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
diff --git a/generate.py b/generate.py
index b7e7c105f..81bcffaee 100644
--- a/generate.py
+++ b/generate.py
@@ -752,12 +752,12 @@ def callback(x):

             # Don't continue here.... because we need to report and reset
             # continue
-            print(
+            logging.info(
                 f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
             )
-            print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+            logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
             if i == 0:
-                print(
+                logging.info(
                     f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***"
                 )
             if start_pos >= max_seq_length:
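
Note: as a possible starting point for the CI TODO above, a minimal smoke test
that repeats the test plan steps could look like the sketch below. The script
name, the 30-second startup wait, and the assertions are assumptions for
illustration only; they are not part of this patch or of any existing
torchchat test:

```python
# smoke_test_browser.py - hypothetical CI helper; repeats the manual test plan.
# Assumes torchchat.py and the stories110M model are available in the CWD.
import subprocess
import time
import urllib.parse
import urllib.request

# Start the browser server in the background, as in the test plan.
proc = subprocess.Popen(["python3", "torchchat.py", "browser", "stories110M"])
try:
    time.sleep(30)  # crude wait for model load; a real CI job should poll instead

    # GET the chat page; urlopen follows redirects, like `curl -L`.
    with urllib.request.urlopen("http://127.0.0.1:5000") as resp:
        assert resp.status == 200, f"unexpected status {resp.status}"

    # POST a prompt, mirroring `curl -d "prompt=..." -X POST .../chat`.
    data = urllib.parse.urlencode({"prompt": "Once upon a time"}).encode()
    with urllib.request.urlopen("http://127.0.0.1:5000/chat", data=data) as resp:
        assert resp.read(), "empty chat response"
finally:
    proc.terminate()
    proc.wait()
```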