diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py index 7e7801925937f..328e70e743b43 100644 --- a/src/lightning_app/components/serve/gradio.py +++ b/src/lightning_app/components/serve/gradio.py @@ -31,6 +31,8 @@ class ServeGradio(LightningWork, abc.ABC): outputs: Any examples: Optional[List] = None enable_queue: bool = False + title: Optional[str] = None + description: Optional[str] = None def __init__(self, *args, **kwargs): requires("gradio")(super().__init__(*args, **kwargs)) @@ -58,7 +60,14 @@ def run(self, *args, **kwargs): self._model = self.build_model() fn = partial(self.predict, *args, **kwargs) fn.__name__ = self.predict.__name__ - gradio.Interface(fn=fn, inputs=self.inputs, outputs=self.outputs, examples=self.examples).launch( + gradio.Interface( + fn=fn, + inputs=self.inputs, + outputs=self.outputs, + examples=self.examples, + title=self.title, + description=self.description, + ).launch( server_name=self.host, server_port=self.port, enable_queue=self.enable_queue, diff --git a/tests/tests_app/components/serve/test_gradio.py b/tests/tests_app/components/serve/test_gradio.py index 8dcdeec70a341..0b57656e6aa31 100644 --- a/tests/tests_app/components/serve/test_gradio.py +++ b/tests/tests_app/components/serve/test_gradio.py @@ -27,4 +27,6 @@ def predict(self, *args, **kwargs): comp.run() assert comp.model == "model" assert comp.predict() == "prediction" - gradio_mock.Interface.assert_called_once_with(fn=ANY, inputs=ANY, outputs=ANY, examples=ANY) + gradio_mock.Interface.assert_called_once_with( + fn=ANY, inputs=ANY, outputs=ANY, examples=ANY, title=None, description=None + )