diff --git a/meerkat/columns/image_column.py b/meerkat/columns/image_column.py index 353767b93..ee856cead 100644 --- a/meerkat/columns/image_column.py +++ b/meerkat/columns/image_column.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from typing import Collection, Sequence from meerkat.columns.abstract import AbstractColumn @@ -13,22 +14,28 @@ logger = logging.getLogger(__name__) -class ImageCell(LambdaCell): +class ImageLoaderMixin: + def fn(self, filepath: str): + if self.base_dir is not None: + filepath = os.path.join(self.base_dir, filepath) + image = self.loader(filepath) + if self.transform is not None: + image = self.transform(image) + return image + + +class ImageCell(ImageLoaderMixin, LambdaCell): def __init__( self, transform: callable = None, loader: callable = None, data: str = None, + base_dir: str = None, ): self.loader = self.default_loader if loader is None else loader self.transform = transform self._data = data - - def fn(self, filepath: str): - image = self.loader(filepath) - if self.transform is not None: - image = self.transform(image) - return image + self.base_dir = base_dir def __eq__(self, other): return ( @@ -45,12 +52,13 @@ def __repr__(self): return f"ImageCell({short_path}, transform={transform})" -class ImageColumn(LambdaColumn): +class ImageColumn(ImageLoaderMixin, LambdaColumn): def __init__( self, data: Sequence[str] = None, transform: callable = None, loader: callable = None, + base_dir: str = None, *args, **kwargs, ): @@ -59,15 +67,15 @@ def __init__( super(ImageColumn, self).__init__(data, *args, **kwargs) self.loader = self.default_loader if loader is None else loader self.transform = transform + self.base_dir = base_dir def _create_cell(self, data: object) -> ImageCell: - return ImageCell(data=data, loader=self.loader, transform=self.transform) - - def fn(self, filepath: str): - image = self.loader(filepath) - if self.transform is not None: - image = self.transform(image) - return image + return ImageCell( + data=data, + loader=self.loader, + transform=self.transform, + base_dir=self.base_dir, + ) @classmethod def from_filepaths( @@ -75,6 +83,7 @@ def from_filepaths( filepaths: Sequence[str], loader: callable = None, transform: callable = None, + base_dir: str = None, *args, **kwargs, ): @@ -82,6 +91,7 @@ def from_filepaths( data=filepaths, loader=loader, transform=transform, + base_dir=base_dir, *args, **kwargs, ) @@ -92,7 +102,11 @@ def default_loader(cls, *args, **kwargs): @classmethod def _state_keys(cls) -> Collection: - return (super()._state_keys() | {"transform", "loader"}) - {"fn"} + return (super()._state_keys() | {"transform", "loader", "base_dir"}) - {"fn"} + + def _set_state(self, state: dict): + state["base_dir"] = state.get("base_dir", None) # backwards compatibility + super()._set_state(state) def is_equal(self, other: AbstractColumn) -> bool: return ( diff --git a/tests/meerkat/columns/test_image_column.py b/tests/meerkat/columns/test_image_column.py index 448ed8d8a..5f2acc70f 100644 --- a/tests/meerkat/columns/test_image_column.py +++ b/tests/meerkat/columns/test_image_column.py @@ -26,13 +26,14 @@ class ImageColumnTestBed(AbstractColumnTestBed): - DEFAULT_CONFIG = {"transform": [True, False]} + DEFAULT_CONFIG = {"transform": [True, False], "use_base_dir": [True, False]} def __init__( self, tmpdir: str, length: int = 16, transform: bool = False, + use_base_dir: bool = False, seed: int = 123, ): self.image_paths = [] @@ -42,18 +43,28 @@ def __init__( transform = to_tensor if transform else None + self.base_dir = tmpdir if use_base_dir else None + for i in range(0, length): - self.image_paths.append(os.path.join(tmpdir, "{}.png".format(i))) self.image_arrays.append((i * np.ones((4, 4, 3))).astype(np.uint8)) im = Image.fromarray(self.image_arrays[-1]) self.ims.append(im) self.data.append(transform(im) if transform else im) - im.save(self.image_paths[-1]) + filename = "{}.png".format(i) + im.save(os.path.join(tmpdir, filename)) + if use_base_dir: + self.image_paths.append(filename) + else: + self.image_paths.append(os.path.join(tmpdir, filename)) + if transform is not None: self.data = torch.stack(self.data) self.transform = transform self.col = ImageColumn.from_filepaths( - self.image_paths, transform=transform, loader=folder.default_loader + self.image_paths, + transform=transform, + loader=folder.default_loader, + base_dir=self.base_dir, ) def get_map_spec( @@ -172,6 +183,7 @@ def get_data(self, index, materialize: bool = True): data=self.image_paths[index], loader=self.col.loader, transform=self.col.transform, + base_dir=self.base_dir, ) index = np.arange(len(self.data))[index] return PandasSeriesColumn([self.image_paths[idx] for idx in index]) diff --git a/tests/meerkat/test_datapanel.py b/tests/meerkat/test_datapanel.py index 78eeba304..49a106534 100644 --- a/tests/meerkat/test_datapanel.py +++ b/tests/meerkat/test_datapanel.py @@ -49,7 +49,7 @@ def __init__( self, column_configs: Dict[str, AbstractColumn], consolidated: bool = True, - length: int = 16, + length: int = 4, tmpdir: str = None, ): self.column_testbeds = self._build_column_testbeds(