Skip to content

Commit

Permalink
Add stop button for interface automaticallY
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Oct 13, 2022
1 parent 97d74ca commit af5ace6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 21 deletions.
39 changes: 20 additions & 19 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,16 @@ def get_cancel_function(
)
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]

async def cancel(session_hash: str):
from gradio.components import _Keywords
async def cancel(session_hash: str) -> None:

output = {}
for comps in fn_to_comp.values():
for comp in comps:
output[comp] = update(value=_Keywords.NO_VALUE)
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])

for task in asyncio.all_tasks():
if task.get_name() in task_ids:
matching_id = None
for id_ in task_ids:
if task.get_name() == id_:
matching_id = id_
fn_index_ = int(matching_id.split("_")[1])
task.cancel()
await asyncio.gather(task, return_exceptions=True)
for comp in fn_to_comp[fn_index_]:
output[comp] = update(value=None)
return output
matching_tasks = [
task for task in asyncio.all_tasks() if task.get_name() in task_ids
]
for task in matching_tasks:
task.cancel()
await asyncio.gather(*matching_tasks, return_exceptions=True)

return (
cancel,
Expand All @@ -56,7 +45,7 @@ def set_cancel_events(block: Block, event_name: str, cancels: List[Dict[str, Any
event_name,
cancel_fn,
inputs=None,
outputs=output,
outputs=None,
queue=False,
preprocess=False,
cancels=fn_indices_to_cancel,
Expand Down Expand Up @@ -93,6 +82,7 @@ def change(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -147,6 +137,7 @@ def click(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -202,6 +193,7 @@ def submit(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -256,6 +248,7 @@ def edit(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -310,6 +303,7 @@ def clear(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -364,6 +358,7 @@ def play(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -416,6 +411,7 @@ def pause(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -468,6 +464,7 @@ def stop(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
Expand Down Expand Up @@ -522,6 +519,7 @@ def stream(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
self.streaming = True
Expand Down Expand Up @@ -561,6 +559,7 @@ def blur(
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
Expand All @@ -576,6 +575,7 @@ def blur(
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.

Expand All @@ -592,3 +592,4 @@ def blur(
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "blur", cancels)
46 changes: 44 additions & 2 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from mdit_py_plugins.footnote import footnote_plugin

from gradio import Examples, interpretation, utils
from gradio.blocks import Blocks
from gradio.blocks import Blocks, update
from gradio.components import (
Button,
Component,
Expand Down Expand Up @@ -177,6 +177,8 @@ def __init__(
flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
"""
stop_btn_color = "#stop-btn {background: red}"
css = css + stop_btn_color if css else stop_btn_color
super().__init__(
analytics_enabled=analytics_enabled,
mode="interface",
Expand Down Expand Up @@ -472,6 +474,16 @@ def render_flag_btns(flagging_options):
clear_btn = Button("Clear")
if not self.live:
submit_btn = Button("Submit", variant="primary")
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
if inspect.isgeneratorfunction(fn):
stop_btn = Button(
"Stop", visible=False, elem_id="stop-btn"
)

elif self.interface_type == self.InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
Expand All @@ -491,6 +503,15 @@ def render_flag_btns(flagging_options):
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
clear_btn = Button("Clear")
submit_btn = Button("Generate", variant="primary")
if inspect.isgeneratorfunction(fn):
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
stop_btn = Button(
"Stop", visible=False, elem_id="stop-btn"
)
if self.allow_flagging == "manual":
flag_btns = render_flag_btns(self.flagging_options)
if self.interpretation:
Expand Down Expand Up @@ -535,7 +556,7 @@ def render_flag_btns(flagging_options):
postprocess=not (self.api_mode),
)
else:
submit_btn.click(
pred = submit_btn.click(
self.fn,
self.input_components,
self.output_components,
Expand All @@ -544,6 +565,27 @@ def render_flag_btns(flagging_options):
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
)
if inspect.isgeneratorfunction(fn):
submit_btn.click(
lambda: {
submit_btn: update(visible=False),
stop_btn: update(visible=True),
},
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
)
stop_btn.click(
lambda: {
submit_btn: update(visible=True),
stop_btn: update(visible=False),
},
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
cancels=[pred],
)

clear_btn.click(
None,
[],
Expand Down

0 comments on commit af5ace6

Please sign in to comment.