Skip to content

Commit

Permalink
fix(client): model build with example failed with the latest gradio v…
Browse files Browse the repository at this point in the history
…ersion (#2814)
  • Loading branch information
jialeicui authored Oct 8, 2023
1 parent 9a79136 commit bd1f831
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
36 changes: 16 additions & 20 deletions client/starwhale/api/_impl/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_openapi_spec(self) -> t.Any:

def _render_api(self, _api: Api, _inst: t.Any) -> None:
import gradio
from gradio.components import File, Image, Video, IOComponent
from gradio.components import IOComponent

js_func: t.Optional[str] = None
if self.hijack and self.hijack.submit:
Expand All @@ -116,25 +116,21 @@ def _render_api(self, _api: Api, _inst: t.Any) -> None:
examples=_api.examples,
inputs=[i for i in _api.input if isinstance(i, IOComponent)],
)
if any(
isinstance(i, (File, Image, Video))
for i in example.dataset.components
):
# examples should be a list of file path
# use flatten list
to_copy = [i for j in example.examples for i in j]
self.example_resources.extend(to_copy)
# change example resource path for online evaluation
# e.g. /path/to/example.png -> /workdir/src/.starwhale/examples/example.png
if self.hijack and self.hijack.resource_path:
for i in range(len(example.dataset.samples)):
for j in range(len(example.dataset.samples[i])):
origin = example.dataset.samples[i][j]
if origin in to_copy:
name = os.path.basename(origin)
example.dataset.samples[i][j] = os.path.join(
self.hijack.resource_path, name
)
# examples should be a list of file path
# use flatten list
to_copy = [i for j in example.examples for i in j]
self.example_resources.extend(to_copy)
# change example resource path for online evaluation
# e.g. /path/to/example.png -> /workdir/src/.starwhale/examples/example.png
if self.hijack and self.hijack.resource_path:
for i in range(len(example.dataset.samples)):
for j in range(len(example.dataset.samples[i])):
origin = example.dataset.samples[i][j]
if origin in to_copy:
name = os.path.basename(origin)
example.dataset.samples[i][j] = os.path.join(
self.hijack.resource_path, name
)
with gradio.Column():
for i in _api.output:
gradio.components.get_component_instance(i, render=False).render()
Expand Down
2 changes: 1 addition & 1 deletion client/tests/data/sdk/service/default_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def ppl(self, data: bytes, **kw: t.Any) -> t.Any:
def handler_foo(self, data: t.Any) -> t.Any:
return

@service.api(gradio.Text(), gradio.Json())
@service.api(gradio.Text(), gradio.Json(), examples=["foo", "bar", __file__])
def cmp(self, ppl_result: t.Iterator) -> t.Any:
pass
13 changes: 10 additions & 3 deletions client/tests/sdk/test_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import json
import tempfile
from pathlib import Path

import pytest

from tests import ROOT_DIR, BaseTestCase
from starwhale.core.model.model import StandaloneModel

from .. import ROOT_DIR, BaseTestCase
from starwhale.api._impl.service import Hijack


class ServiceTestCase(BaseTestCase):
Expand All @@ -29,8 +30,14 @@ def test_custom_class(self):
assert len(spec["dependencies"]) == 2

def test_default_class(self):
svc = StandaloneModel._get_service(["default_class:MyDefaultClass"], self.root)
svc = StandaloneModel._get_service(
["default_class:MyDefaultClass"],
self.root,
hijack=Hijack(True, tempfile.gettempdir()),
)
assert list(svc.apis.keys()) == ["cmp"]
spec = svc.get_spec()
assert len(spec["dependencies"]) == 2

def test_class_without_api(self):
svc = StandaloneModel._get_service(["no_api:NoApi"], self.root)
Expand Down

0 comments on commit bd1f831

Please sign in to comment.