Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for base_dir in image column #184

Merged
merged 1 commit into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions meerkat/columns/image_column.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
from typing import Collection, Sequence

from meerkat.columns.abstract import AbstractColumn
Expand All @@ -13,22 +14,28 @@
logger = logging.getLogger(__name__)


class ImageCell(LambdaCell):
class ImageLoaderMixin:
    """Shared image-loading logic for classes that expose a loader.

    The host class is expected to set ``loader`` (a callable taking a file
    path) and ``transform`` (a callable or ``None``). ``base_dir`` is
    optional and treated as ``None`` when the attribute is absent.
    """

    def fn(self, filepath: str):
        """Load the image at *filepath* and apply the optional transform.

        Args:
            filepath: Path handed to ``self.loader``. When ``base_dir`` is
                set, the path is first joined onto it with
                ``os.path.join``.

        Returns:
            The value returned by ``self.loader``, passed through
            ``self.transform`` when a transform is set.
        """
        # An object whose ``base_dir`` attribute was never assigned is
        # treated the same as ``base_dir=None`` instead of raising
        # AttributeError.
        base_dir = getattr(self, "base_dir", None)
        if base_dir is not None:
            filepath = os.path.join(base_dir, filepath)
        image = self.loader(filepath)
        if self.transform is not None:
            image = self.transform(image)
        return image


class ImageCell(ImageLoaderMixin, LambdaCell):
def __init__(
self,
transform: callable = None,
loader: callable = None,
data: str = None,
base_dir: str = None,
):
self.loader = self.default_loader if loader is None else loader
self.transform = transform
self._data = data

def fn(self, filepath: str):
image = self.loader(filepath)
if self.transform is not None:
image = self.transform(image)
return image
self.base_dir = base_dir

def __eq__(self, other):
return (
Expand All @@ -45,12 +52,13 @@ def __repr__(self):
return f"ImageCell({short_path}, transform={transform})"


class ImageColumn(LambdaColumn):
class ImageColumn(ImageLoaderMixin, LambdaColumn):
def __init__(
self,
data: Sequence[str] = None,
transform: callable = None,
loader: callable = None,
base_dir: str = None,
*args,
**kwargs,
):
Expand All @@ -59,29 +67,31 @@ def __init__(
super(ImageColumn, self).__init__(data, *args, **kwargs)
self.loader = self.default_loader if loader is None else loader
self.transform = transform
self.base_dir = base_dir

def _create_cell(self, data: object) -> ImageCell:
return ImageCell(data=data, loader=self.loader, transform=self.transform)

def fn(self, filepath: str):
image = self.loader(filepath)
if self.transform is not None:
image = self.transform(image)
return image
return ImageCell(
data=data,
loader=self.loader,
transform=self.transform,
base_dir=self.base_dir,
)

@classmethod
def from_filepaths(
cls,
filepaths: Sequence[str],
loader: callable = None,
transform: callable = None,
base_dir: str = None,
*args,
**kwargs,
):
return cls(
data=filepaths,
loader=loader,
transform=transform,
base_dir=base_dir,
*args,
**kwargs,
)
Expand All @@ -92,7 +102,11 @@ def default_loader(cls, *args, **kwargs):

@classmethod
def _state_keys(cls) -> Collection:
return (super()._state_keys() | {"transform", "loader"}) - {"fn"}
return (super()._state_keys() | {"transform", "loader", "base_dir"}) - {"fn"}

# Restore state via the parent class after making sure "base_dir" is present.
def _set_state(self, state: dict):
# Insert base_dir=None into `state` (in place) when the key is missing;
# presumably this covers state saved before base_dir became a state key.
state["base_dir"] = state.get("base_dir", None) # backwards compatibility
super()._set_state(state)

def is_equal(self, other: AbstractColumn) -> bool:
return (
Expand Down
20 changes: 16 additions & 4 deletions tests/meerkat/columns/test_image_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@

class ImageColumnTestBed(AbstractColumnTestBed):

DEFAULT_CONFIG = {"transform": [True, False]}
DEFAULT_CONFIG = {"transform": [True, False], "use_base_dir": [True, False]}

def __init__(
self,
tmpdir: str,
length: int = 16,
transform: bool = False,
use_base_dir: bool = False,
seed: int = 123,
):
self.image_paths = []
Expand All @@ -42,18 +43,28 @@ def __init__(

transform = to_tensor if transform else None

self.base_dir = tmpdir if use_base_dir else None

for i in range(0, length):
self.image_paths.append(os.path.join(tmpdir, "{}.png".format(i)))
self.image_arrays.append((i * np.ones((4, 4, 3))).astype(np.uint8))
im = Image.fromarray(self.image_arrays[-1])
self.ims.append(im)
self.data.append(transform(im) if transform else im)
im.save(self.image_paths[-1])
filename = "{}.png".format(i)
im.save(os.path.join(tmpdir, filename))
if use_base_dir:
self.image_paths.append(filename)
else:
self.image_paths.append(os.path.join(tmpdir, filename))

if transform is not None:
self.data = torch.stack(self.data)
self.transform = transform
self.col = ImageColumn.from_filepaths(
self.image_paths, transform=transform, loader=folder.default_loader
self.image_paths,
transform=transform,
loader=folder.default_loader,
base_dir=self.base_dir,
)

def get_map_spec(
Expand Down Expand Up @@ -172,6 +183,7 @@ def get_data(self, index, materialize: bool = True):
data=self.image_paths[index],
loader=self.col.loader,
transform=self.col.transform,
base_dir=self.base_dir,
)
index = np.arange(len(self.data))[index]
return PandasSeriesColumn([self.image_paths[idx] for idx in index])
Expand Down
2 changes: 1 addition & 1 deletion tests/meerkat/test_datapanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self,
column_configs: Dict[str, AbstractColumn],
consolidated: bool = True,
length: int = 16,
length: int = 4,
tmpdir: str = None,
):
self.column_testbeds = self._build_column_testbeds(
Expand Down