Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation-related fixes to the python client #3663

Merged
merged 29 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 109 additions & 2 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def __init__(
hf_token: str | None = None,
max_workers: int = 40,
):
"""
Parameters:
space: The name of the Space to load, e.g. "abidlabs/pictionary". If it is a private Space, you must provide an hf_token. app. Either `space` or `src` must be provided.
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
src: The full URL of the hosted Gradio app to load, e.g. "https://mydomain.com/app" or the shareable link to a Gradio app, e.g. "https://bec81a83-5b5c-471e.gradio.live/". Either `space` or `src` must be provided.
hf_token: The Hugging Face token to use to access private Spaces. If not provided, only public Spaces can be loaded.
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
"""
self.hf_token = hf_token
self.headers = build_hf_headers(
token=hf_token,
Expand Down Expand Up @@ -68,6 +75,15 @@ def predict(
fn_index: int = 0,
result_callbacks: Callable | List[Callable] | None = None,
) -> Future:
"""
Parameters:
*args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app.
api_name: The name of the API endpoint to call. If not provided, the first API will be called. Takes precedence over fn_index.
fn_index: The index of the API endpoint to call. If not provided, the first API will be called.
result_callbacks: A callback function, or list of callback functions, to be called when the result is ready. If a list of functions is provided, they will be called in order. The return values from the remote API are provided as separate parameters into the callback. If None, no callback will be called.
Returns:
A Job object that can be used to retrieve the status and result of the remote API call.
"""
if api_name:
fn_index = self._infer_fn_index(api_name)

Expand All @@ -93,6 +109,67 @@ def fn(future):

return job

def usage(self, all_endpoints=True, print_usage=True) -> str | None:
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
"""
Parameters:
all_endpoints: If True, returns information for both named and unnamed endpoints in the Gradio app. If False, will only return info about named endpoints.
print_usage: If True, prints the usage info to the console. If False, returns the usage info as a string.
"""
named_endpoints: Dict[str, Dict[str, Dict[str, Tuple[str, str]]]] = {}
unnamed_endpoints: Dict[int, Dict[str, Dict[str, Tuple[str, str]]]] = {}
for endpoint in self.endpoints:
if endpoint.is_valid:
if endpoint.api_name:
named_endpoints[endpoint.api_name] = endpoint.get_info()
else:
unnamed_endpoints[endpoint.fn_index] = endpoint.get_info()

usage_info = f"Client.predict() Usage Info\n---------------------------\nNamed endpoints: {len(named_endpoints)}\n"
usage_info += self._render_endpoints_info(
named_endpoints, label_format='api_name="{}"'
)
if unnamed_endpoints and all_endpoints:
usage_info += f"\nUnnamed endpoints: {len(unnamed_endpoints)}\n"
usage_info += self._render_endpoints_info(
unnamed_endpoints, label_format="fn_index={}"
)
if print_usage:
print(usage_info)
else:
return usage_info

def _render_endpoints_info(
self,
endpoints_info: Dict[Any, Dict[str, Dict[str, Tuple[str, str]]]],
label_format: str,
) -> str:
usage_info = ""
for label, info in endpoints_info.items():
parameters = ",".join(list(info["parameters"].keys()))
if parameters:
parameters = f"{parameters}, "
returns = ",".join(list(info["returns"].keys()))
if returns:
returns = f" -> {returns}"
usage_info += (
f" - predict({parameters}{label_format.format(label)}){returns}\n"
)
if parameters:
usage_info += " Parameters:\n"
for name, type in info["parameters"].items():
usage_info += f" - [{type[1]}] {name}: {type[0]}\n"
if returns:
usage_info += " Returns:\n"
for name, type in info["returns"].items():
usage_info += f" - [{type[1]}] {name}: {type[0]}\n"
return usage_info

def __repr__(self):
return self.usage()

def __str__(self):
return self.usage()

def _telemetry_thread(self) -> None:
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
data = {
Expand Down Expand Up @@ -132,7 +209,7 @@ def _get_config(self) -> Dict:
raise ValueError(f"Could not get Gradio config from: {self.src}")
if "allow_flagging" in config:
raise ValueError(
"Gradio 2.x is not supported by this client. Please upgrade this app to Gradio 3.x."
"Gradio 2.x is not supported by this client. Please upgrade your Gradio app to Gradio 3.x or higher."
)
return config

Expand All @@ -145,6 +222,7 @@ def __init__(self, client: Client, fn_index: int, dependency: Dict):
self.ws_url = client.ws_url
self.fn_index = fn_index
self.dependency = dependency
self.api_name: str | None = dependency.get("api_name")
self.headers = client.headers
self.config = client.config
self.use_ws = self._use_websocket(self.dependency)
Expand All @@ -153,10 +231,39 @@ def __init__(self, client: Client, fn_index: int, dependency: Dict):
self.serializers, self.deserializers = self._setup_serializers()
self.is_valid = self.dependency[
"backend_fn"
] # Only a real API endpoint if backend_fn is True
] # Only a real API endpoint if backend_fn is True and serializers are valid
except AssertionError:
self.is_valid = False

def get_info(self) -> Dict[str, Dict[str, Tuple[str, str]]]:
parameters = {}
for i, input in enumerate(self.dependency["inputs"]):
for component in self.config["components"]:
if component["id"] == input:
label = component["props"].get("label", f"parameter_{i}").lower()
if "info" in component:
info = component["info"]["input"]
else:
info = (
self.serializers[i].get_input_type(),
component.get("type", "component").capitalize(),
)
parameters[label] = info
returns = {}
for o, output in enumerate(self.dependency["outputs"]):
for component in self.config["components"]:
if component["id"] == output:
label = component["props"].get("label", f"parameter_{o}").lower()
if "info" in component:
info = component["info"]["output"]
info = (
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
self.deserializers[o].get_output_type(),
component.get("type", "component").capitalize(),
)
returns[label] = info

return {"parameters": parameters, "returns": returns}

def end_to_end_fn(self, *data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
Expand Down
144 changes: 107 additions & 37 deletions client/python/gradio_client/serializing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,22 @@

class Serializable(ABC):
@abstractmethod
def serialize(self, x: Any, load_dir: str | Path = ""):
def get_input_type() -> str:
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
"""
Convert data from human-readable format to serialized format for a browser.
Get the type of input this component accepts (for documentation generation).
"""
pass

@abstractmethod
def deserialize(
self,
x: Any,
save_dir: str | Path | None = None,
root_url: str | None = None,
hf_token: str | None = None,
):
def get_output_type() -> str:
"""
Convert data from serialized format for a browser to human-readable format.
Get the type of input this component accepts (for documentation generation).
"""
pass


class SimpleSerializable(Serializable):
def serialize(self, x: Any, load_dir: str | Path = "") -> Any:
def serialize(self, x: Any, load_dir: str | Path = ""):
"""
Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op.
Parameters:
x: Input data to serialize
load_dir: Ignored
Convert data from human-readable format to serialized format for a browser.
"""
return x

Expand All @@ -50,17 +39,80 @@ def deserialize(
hf_token: str | None = None,
):
"""
Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
Parameters:
x: Input data to deserialize
save_dir: Ignored
root_url: Ignored
hf_token: Ignored
Convert data from serialized format for a browser to human-readable format.
"""
return x


class SimpleSerializable(Serializable):
"""General class that does not perform any serialization or deserialization."""

def get_input_type(self) -> str:
return "Any"

def get_output_type(self) -> str:
return "Any"


class StringSerializable(Serializable):
"""Expects a string as input/output but performs no serialization."""

def get_input_type(self) -> str:
return "str (value)"

def get_output_type(self) -> str:
return "str (value)"


class ListStringSerializable(Serializable):
"""Expects a list of strings as input/output but performs no serialization."""

def get_input_type(self) -> str:
return "List[str] (values)"

def get_output_type(self) -> str:
return "List[str] (values)"


class DropdownSerializable(Serializable):
"""Expects a string or list of strings as input/output but performs no serialization."""

def get_input_type(self) -> str:
return "str | List[str] (value[s])"

def get_output_type(self) -> str:
return "str | List[str] (value[s])"


class BooleanSerializable(Serializable):
"""Expects a boolean as input/output but performs no serialization."""

def get_input_type(self) -> str:
return "bool (value)"

def get_output_type(self) -> str:
return "bool (value)"


class NumberSerializable(Serializable):
"""Expects a number (int/float) as input/output but performs no serialization."""

def get_input_type(self) -> str:
return "int | float (value)"

def get_output_type(self) -> str:
return "int | float (value)"


class ImgSerializable(Serializable):
"""Expects a base64 string as input/output which is ."""

def get_input_type(self) -> str:
return "str (filepath or URL)"

def get_output_type(self) -> str:
return "str (filepath or URL)"

def serialize(
self,
x: str | None,
Expand Down Expand Up @@ -102,6 +154,12 @@ def deserialize(


class FileSerializable(Serializable):
def get_input_type(self) -> str:
return "str (filepath or URL)"

def get_output_type(self) -> str:
return "str (filepath or URL)"

def serialize(
self,
x: str | None,
Expand Down Expand Up @@ -168,6 +226,12 @@ def deserialize(


class JSONSerializable(Serializable):
def get_input_type(self) -> str:
return "str (filepath to json file)"

def get_output_type(self) -> str:
return "str (filepath to json file)"

def serialize(
self,
x: str | None,
Expand Down Expand Up @@ -206,6 +270,12 @@ def deserialize(


class GallerySerializable(Serializable):
def get_input_type(self) -> str:
return "str (directory path)"

def get_output_type(self) -> str:
return "str (directory path)"

def serialize(
self, x: str | None, load_dir: str | Path = ""
) -> List[List[str]] | None:
Expand Down Expand Up @@ -249,33 +319,33 @@ def deserialize(

SERIALIZER_MAPPING = {cls.__name__: cls for cls in Serializable.__subclasses__()}

COMPONENT_MAPPING = {
"textbox": SimpleSerializable,
"number": SimpleSerializable,
"slider": SimpleSerializable,
"checkbox": SimpleSerializable,
"checkboxgroup": SimpleSerializable,
"radio": SimpleSerializable,
"dropdown": SimpleSerializable,
COMPONENT_MAPPING: Dict[str, type] = {
"textbox": StringSerializable,
"number": NumberSerializable,
"slider": NumberSerializable,
"checkbox": BooleanSerializable,
"checkboxgroup": ListStringSerializable,
"radio": StringSerializable,
"dropdown": DropdownSerializable,
"image": ImgSerializable,
"video": FileSerializable,
"audio": FileSerializable,
"file": FileSerializable,
"dataframe": JSONSerializable,
"timeseries": JSONSerializable,
"state": SimpleSerializable,
"button": SimpleSerializable,
"button": StringSerializable,
"uploadbutton": FileSerializable,
"colorpicker": SimpleSerializable,
"colorpicker": StringSerializable,
"label": JSONSerializable,
"highlightedtext": JSONSerializable,
"json": JSONSerializable,
"html": SimpleSerializable,
"html": StringSerializable,
"gallery": GallerySerializable,
"chatbot": JSONSerializable,
"model3d": FileSerializable,
"plot": JSONSerializable,
"markdown": SimpleSerializable,
"dataset": SimpleSerializable,
"code": SimpleSerializable,
"markdown": StringSerializable,
"dataset": StringSerializable,
"code": StringSerializable,
}
Loading