Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance(sdk): format Image shape with pillow.Image and numpy.ndarray shape #3077

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions client/starwhale/base/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,37 +270,58 @@ def __init__(
dtype: t.Type = numpy.uint8,
link: t.Optional[Link] = None,
) -> None:
"""Starwhale Image type.

Arguments:
fp: (str, bytes, Path, io.IOBase, pillow.Image, numpy.ndarray) The image data source.
If the argument is str, bytes, Path, io.IOBase, we will try to read the data from the source.
If the argument is pillow.Image or numpy.ndarray, we will try to convert it to bytes.
display_name: (str, optional) The display name of the image, default is "".
shape: (tuple, optional) The shape of the image numpy.ndarray, default is None.
shape = (height, width, channel)
mime_type: (MIMEType, optional) The mime type of the image, default is MIMEType.UNDEFINED.
dtype: (numpy.dtype, optional) The numpy dtype of the image, default is numpy.uint8.
"""
self.as_mask = as_mask
self.mask_uri = mask_uri

fp, shape = self._convert_pil_and_numpy(fp, shape)
super().__init__(
self._convert_pil_and_numpy(fp),
fp,
ArtifactType.Image,
display_name=display_name,
shape=shape or (None, None, 3),
shape=shape,
mime_type=mime_type,
dtype=dtype,
link=link,
)

def _convert_pil_and_numpy(self, source: t.Any) -> t.Any:
def _convert_pil_and_numpy(
self, source: t.Any, shape: t.Optional[_TShape] = None
) -> t.Any:
shape = shape or (None, None, 3)
try:
# pillow is optional for starwhale, so we need to check if it is installed
from PIL import Image as PILImage
except ImportError: # pragma: no cover
console.trace(
"pillow is not installed, skip try to convert PILImage and numpy.ndarray to bytes"
)
return source
return source, shape

if isinstance(source, (PILImage.Image, numpy.ndarray)):
image_bytes = io.BytesIO()
if isinstance(source, numpy.ndarray):
source = PILImage.fromarray(source)
source.save(image_bytes, format="PNG")
return image_bytes.getvalue()
# shape = (height, width, channel) for numpy
return image_bytes.getvalue(), (
source.height,
source.width,
len(source.getbands()),
)
else:
return source
return source, shape

def _do_validate(self) -> None:
if self.mime_type not in (
Expand Down
6 changes: 3 additions & 3 deletions client/tests/base/test_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_image(self) -> None:
assert typ.attrs["_raw_base64_data"] == data_store.STRING

pixels = numpy.random.randint(
low=0, high=256, size=(100, 100, 3), dtype=numpy.uint8
low=0, high=256, size=(200, 100, 3), dtype=numpy.uint8
)
pil_obj = PILImage.fromarray(pixels, mode="RGB")

Expand All @@ -179,9 +179,9 @@ def test_image(self) -> None:
assert l_pil_img.mode == "L"
array = img.to_numpy()
assert isinstance(array, numpy.ndarray)
assert array.shape == (100, 100, 3)
assert array.shape == (200, 100, 3) == tuple(img.shape)
l_array = img.to_numpy("L")
assert l_array.shape == (100, 100)
assert l_array.shape == (200, 100)

img = Image(pixels)
pil_img = img.to_pil()
Expand Down
Loading