From f5cce0c799d0da3e7f6b9b6cf7de751e1b612d3e Mon Sep 17 00:00:00 2001 From: tianwei Date: Wed, 6 Dec 2023 19:11:40 +0800 Subject: [PATCH] feat(sdk): support Image type to accept numpy and pillow image types (#3063) --- client/starwhale/api/_impl/dataset/model.py | 2 +- client/starwhale/base/data_type.py | 27 ++++++++++++++++--- .../integrations/huggingface/dataset.py | 6 +---- client/tests/base/test_data_type.py | 17 +++++++++--- client/tests/core/test_dataset.py | 4 +-- client/tests/sdk/test_loader.py | 2 +- scripts/example/src/util.py | 2 +- 7 files changed, 43 insertions(+), 17 deletions(-) diff --git a/client/starwhale/api/_impl/dataset/model.py b/client/starwhale/api/_impl/dataset/model.py index 26e1c7b754..f0da1814b3 100644 --- a/client/starwhale/api/_impl/dataset/model.py +++ b/client/starwhale/api/_impl/dataset/model.py @@ -1774,7 +1774,7 @@ def _iter_records() -> t.Iterator[t.Tuple[str, t.Dict]]: record["caption"] = caption_path.read_text().strip() record["file"] = file_cls( - fp=p, + p, display_name=p.name, mime_type=MIMEType.create_by_file_suffix(p), ) diff --git a/client/starwhale/base/data_type.py b/client/starwhale/base/data_type.py index d00dac9dd3..f479d78649 100644 --- a/client/starwhale/base/data_type.py +++ b/client/starwhale/base/data_type.py @@ -12,6 +12,7 @@ import numpy +from starwhale.utils import console from starwhale.consts import SHORT_VERSION_CNT from starwhale.utils.fs import DIGEST_SIZE, FilePosition from starwhale.base.mixin import ASDictMixin @@ -260,7 +261,7 @@ def to_tensor(self) -> t.Any: class Image(BaseArtifact, SwObject): def __init__( self, - fp: _TArtifactFP = "", + fp: t.Any = "", display_name: str = "", shape: t.Optional[_TShape] = None, mime_type: MIMEType = MIMEType.UNDEFINED, @@ -271,8 +272,9 @@ def __init__( ) -> None: self.as_mask = as_mask self.mask_uri = mask_uri + super().__init__( - fp, + self._convert_pil_and_numpy(fp), ArtifactType.Image, display_name=display_name, shape=shape or (None, None, 3), @@ -281,6 +283,25 @@ def __init__( link=link, ) + def _convert_pil_and_numpy(self, source: t.Any) -> t.Any: + try: + # pillow is optional for starwhale, so we need to check if it is installed + from PIL import Image as PILImage + except ImportError: # pragma: no cover + console.trace( + "pillow is not installed, skip try to convert PILImage and numpy.ndarray to bytes" + ) + return source + + if isinstance(source, (PILImage.Image, numpy.ndarray)): + image_bytes = io.BytesIO() + if isinstance(source, numpy.ndarray): + source = PILImage.fromarray(source) + source.save(image_bytes, format="PNG") + return image_bytes.getvalue() + else: + return source + def _do_validate(self) -> None: if self.mime_type not in ( MIMEType.PNG, @@ -338,7 +359,7 @@ def to_tensor(self) -> t.Any: class GrayscaleImage(Image): def __init__( self, - fp: _TArtifactFP = "", + fp: t.Any = "", display_name: str = "", shape: t.Optional[_TShape] = None, as_mask: bool = False, diff --git a/client/starwhale/integrations/huggingface/dataset.py b/client/starwhale/integrations/huggingface/dataset.py index b5759e945e..bbd1eeb5c6 100644 --- a/client/starwhale/integrations/huggingface/dataset.py +++ b/client/starwhale/integrations/huggingface/dataset.py @@ -36,17 +36,13 @@ def _transform_to_starwhale(data: t.Any, feature: t.Any) -> t.Any: from PIL import Image as PILImage if isinstance(data, PILImage.Image): - img_io = io.BytesIO() - data.save(img_io, format=data.format or "PNG") - img_fp = img_io.getvalue() - try: data_mimetype = data.get_format_mimetype() mime_type = MIMEType(data_mimetype) except (ValueError, AttributeError): mime_type = MIMEType.PNG return Image( - fp=img_fp, + data, shape=(data.height, data.width, len(data.getbands())), mime_type=mime_type, ) diff --git a/client/tests/base/test_data_type.py b/client/tests/base/test_data_type.py index fa22409b7d..2b30a941c6 100644 --- a/client/tests/base/test_data_type.py +++ b/client/tests/base/test_data_type.py @@ -128,6 +128,8 @@ def test_numpy_binary(self) -> None: assert torch.equal(torch.from_numpy(np_array), b.to_tensor()) def test_image(self) -> None: + self.fs.create_file("path/to/file", contents="") + fp = io.StringIO("test") img = Image(fp, display_name="t", shape=[28, 28, 3], mime_type=MIMEType.PNG) assert img.to_bytes() == b"test" @@ -159,7 +161,6 @@ def test_image(self) -> None: assert _asdict["shape"] == [28, 28, 1] assert _asdict["_raw_base64_data"] == base64.b64encode(b"test").decode() - self.fs.create_file("path/to/file", contents="") img = GrayscaleImage(Path("path/to/file"), shape=[28, 28, 1]).carry_raw_data() typ = data_store._get_type(img) assert isinstance(typ, data_store.SwObjectType) @@ -168,9 +169,9 @@ def test_image(self) -> None: pixels = numpy.random.randint( low=0, high=256, size=(100, 100, 3), dtype=numpy.uint8 ) - image_bytes = io.BytesIO() - PILImage.fromarray(pixels, mode="RGB").save(image_bytes, format="PNG") - img = Image(image_bytes.getvalue()) + pil_obj = PILImage.fromarray(pixels, mode="RGB") + + img = Image(pil_obj) pil_img = img.to_pil() assert isinstance(pil_img, PILImage.Image) assert pil_img.mode == "RGB" @@ -182,6 +183,14 @@ def test_image(self) -> None: l_array = img.to_numpy("L") assert l_array.shape == (100, 100) + img = Image(pixels) + pil_img = img.to_pil() + assert isinstance(pil_img, PILImage.Image) + assert pil_img.mode == "RGB" + array = img.to_numpy() + assert isinstance(array, numpy.ndarray) + assert (array == pixels).all() + def test_swobject_subclass_init(self) -> None: from starwhale.base import data_type diff --git a/client/tests/core/test_dataset.py b/client/tests/core/test_dataset.py index 8e6077461a..2ab418834c 100644 --- a/client/tests/core/test_dataset.py +++ b/client/tests/core/test_dataset.py @@ -477,7 +477,7 @@ def test_head(self, *args: t.Any) -> None: ( "label-0", { - "img": GrayscaleImage(fp=b"123"), + "img": GrayscaleImage(b"123"), "label": 0, }, ) @@ -486,7 +486,7 @@ def test_head(self, *args: t.Any) -> None: ( "label-1", { - "img": GrayscaleImage(fp=b"456"), + "img": GrayscaleImage(b"456"), "label": 1, }, ) diff --git a/client/tests/sdk/test_loader.py b/client/tests/sdk/test_loader.py index 31f108d67b..3c679631cf 100644 --- a/client/tests/sdk/test_loader.py +++ b/client/tests/sdk/test_loader.py @@ -638,7 +638,7 @@ def test_data_row(self) -> None: assert dr < dr_another assert dr != dr_another - dr_third = DataRow(index=1, features={"data": Image(fp=b""), "label": 10}) + dr_third = DataRow(index=1, features={"data": Image(b""), "label": 10}) assert dr >= dr_third def test_data_row_exceptions(self) -> None: diff --git a/scripts/example/src/util.py b/scripts/example/src/util.py index 0c7a8b4820..afe50b5acd 100644 --- a/scripts/example/src/util.py +++ b/scripts/example/src/util.py @@ -5,7 +5,7 @@ from starwhale import Image -def random_image() -> bytes: +def random_image() -> Image: try: return _random_image_from_pillow() except ImportError: