From bd1f831f3187cd5b1e588a4ced9a3b4627d7fa1c Mon Sep 17 00:00:00 2001 From: Jialei <3217223+jialeicui@users.noreply.github.com> Date: Sun, 8 Oct 2023 14:04:37 +0800 Subject: [PATCH] fix(client): model build with example failed with the latest gradio version (#2814) --- client/starwhale/api/_impl/service.py | 36 +++++++++---------- .../tests/data/sdk/service/default_class.py | 2 +- client/tests/sdk/test_service.py | 13 +++++-- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/client/starwhale/api/_impl/service.py b/client/starwhale/api/_impl/service.py index 28f79c46bb..3c51941d3d 100644 --- a/client/starwhale/api/_impl/service.py +++ b/client/starwhale/api/_impl/service.py @@ -97,7 +97,7 @@ def get_openapi_spec(self) -> t.Any: def _render_api(self, _api: Api, _inst: t.Any) -> None: import gradio - from gradio.components import File, Image, Video, IOComponent + from gradio.components import IOComponent js_func: t.Optional[str] = None if self.hijack and self.hijack.submit: @@ -116,25 +116,21 @@ def _render_api(self, _api: Api, _inst: t.Any) -> None: examples=_api.examples, inputs=[i for i in _api.input if isinstance(i, IOComponent)], ) - if any( - isinstance(i, (File, Image, Video)) - for i in example.dataset.components - ): - # examples should be a list of file path - # use flatten list - to_copy = [i for j in example.examples for i in j] - self.example_resources.extend(to_copy) - # change example resource path for online evaluation - # e.g. /path/to/example.png -> /workdir/src/.starwhale/examples/example.png - if self.hijack and self.hijack.resource_path: - for i in range(len(example.dataset.samples)): - for j in range(len(example.dataset.samples[i])): - origin = example.dataset.samples[i][j] - if origin in to_copy: - name = os.path.basename(origin) - example.dataset.samples[i][j] = os.path.join( - self.hijack.resource_path, name - ) + # examples should be a list of file path + # use flatten list + to_copy = [i for j in example.examples for i in j] + self.example_resources.extend(to_copy) + # change example resource path for online evaluation + # e.g. /path/to/example.png -> /workdir/src/.starwhale/examples/example.png + if self.hijack and self.hijack.resource_path: + for i in range(len(example.dataset.samples)): + for j in range(len(example.dataset.samples[i])): + origin = example.dataset.samples[i][j] + if origin in to_copy: + name = os.path.basename(origin) + example.dataset.samples[i][j] = os.path.join( + self.hijack.resource_path, name + ) with gradio.Column(): for i in _api.output: gradio.components.get_component_instance(i, render=False).render() diff --git a/client/tests/data/sdk/service/default_class.py b/client/tests/data/sdk/service/default_class.py index fedb598788..77831283ef 100644 --- a/client/tests/data/sdk/service/default_class.py +++ b/client/tests/data/sdk/service/default_class.py @@ -13,6 +13,6 @@ def ppl(self, data: bytes, **kw: t.Any) -> t.Any: def handler_foo(self, data: t.Any) -> t.Any: return - @service.api(gradio.Text(), gradio.Json()) + @service.api(gradio.Text(), gradio.Json(), examples=["foo", "bar", __file__]) def cmp(self, ppl_result: t.Iterator) -> t.Any: pass diff --git a/client/tests/sdk/test_service.py b/client/tests/sdk/test_service.py index 6a3db32ceb..b7d301c285 100644 --- a/client/tests/sdk/test_service.py +++ b/client/tests/sdk/test_service.py @@ -1,12 +1,13 @@ import os import json +import tempfile from pathlib import Path import pytest +from tests import ROOT_DIR, BaseTestCase from starwhale.core.model.model import StandaloneModel - -from .. import ROOT_DIR, BaseTestCase +from starwhale.api._impl.service import Hijack class ServiceTestCase(BaseTestCase): @@ -29,8 +30,14 @@ def test_custom_class(self): assert len(spec["dependencies"]) == 2 def test_default_class(self): - svc = StandaloneModel._get_service(["default_class:MyDefaultClass"], self.root) + svc = StandaloneModel._get_service( + ["default_class:MyDefaultClass"], + self.root, + hijack=Hijack(True, tempfile.gettempdir()), + ) assert list(svc.apis.keys()) == ["cmp"] + spec = svc.get_spec() + assert len(spec["dependencies"]) == 2 def test_class_without_api(self): svc = StandaloneModel._get_service(["no_api:NoApi"], self.root)