diff --git a/generate.py b/generate.py index eb294318f..4c292cf4e 100644 --- a/generate.py +++ b/generate.py @@ -314,6 +314,7 @@ def go_gradio(**kwargs): ) with gr.Row(): clear = gr.Button("Clear") + stop_btn = gr.Button(value="Stop") flag_btn = gr.Button("Flag") else: text_output = gr.Textbox(lines=5, label="Output") @@ -396,7 +397,7 @@ def go_gradio(**kwargs): ) if not kwargs['chat']: submit = gr.Button("Submit") - submit.click(fun, inputs=inputs_list, outputs=text_output) + click_event = submit.click(fun, inputs=inputs_list, outputs=text_output) # examples after submit or any other buttons for chat or no chat if kwargs['examples'] is not None and kwargs['show_examples']: @@ -441,7 +442,7 @@ def bot(*args): yield history return - instruction.submit(user, + click_event = instruction.submit(user, inputs_list + [text_output], # matching user() inputs [instruction, text_output], queue=stream_output).then( bot, inputs_list + [text_output], text_output @@ -451,6 +452,9 @@ def bot(*args): # callback for logging flagged input/output callback.setup(inputs_list + [text_output], "flagged_data_points") flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False) + # don't pass text_output, don't want to clear output, just stop it + # FIXME: have to click once to stop output and second time to stop GPUs going + stop_btn.click(lambda: None, None, None, cancels=[click_event, click_event], queue=False) demo.queue(concurrency_count=1) favicon_path = "h2o-logo.svg"