Skip to content

Commit

Permalink
chat in browser
Browse files Browse the repository at this point in the history
  • Loading branch information
Olivia-liu committed Apr 17, 2024
1 parent 26f5d65 commit 33db115
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
57 changes: 57 additions & 0 deletions chat_in_browser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- coding: UTF-8 -*-
"""
hello_jinja2: Get start with Jinja2 templates
"""
from flask import Flask, render_template, request
from cli import add_arguments_for_generate, arg_init, check_args
from generate import main as generate_main
import subprocess
import sys


convo = ""

def create_app(*args):
app = Flask(__name__)

import subprocess
# create a new process and set up pipes for communication
proc = subprocess.Popen(["python", "generate.py", *args],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE)

@app.route('/')
def main():
output = ""
while True:
line = proc.stdout.readline()
if line.decode('utf-8').startswith("What is your prompt?"):
break
output += line.decode('utf-8').strip() + "\n"
return render_template('chat.html', convo="Hello! What is your prompt?")

@app.route('/chat', methods=['POST'])
def chat():
# Retrieve the HTTP POST request parameter value from 'request.form' dictionary
_prompt = request.form.get('prompt')
proc.stdin.write((_prompt + "\n").encode('utf-8'))
proc.stdin.flush()

output = ""
while True:
line = proc.stdout.readline()
print("\tprinting `line`")
print(line.decode('utf-8') + "\n")
if line.decode('utf-8').startswith("What is your prompt?"):
break
output += line.decode('utf-8').strip() + "\n"

global convo

if _prompt:
convo += "Your prompt:\n" + _prompt + "\n\n"
convo += "My response:\n" + output + "\n\n"

return render_template('chat.html', convo=convo)

return app
2 changes: 1 addition & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _main(
for i in range(start, num_samples):
device_sync(device=builder_args.device)
if i >= 0 and chat_mode:
prompt = input("What is your prompt? ")
prompt = input("What is your prompt? \n")
if is_chat:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
Expand Down
15 changes: 15 additions & 0 deletions templates/chat.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>torchchat</title>
</head>
<body>
<pre>{{ convo }}</pre>
<form action="chat" method="post">
<label for="username">Prompt: </label>
<input type="text" id="prompt" name="prompt"><br>
<input type="submit" value="SEND">
</form>
</body>
</html>

0 comments on commit 33db115

Please sign in to comment.