Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[bug] Fix from_filepaths (#198)
Browse files: browse the repository at this point in the history
* update

* remove viz

* Update flash/vision/classification/data.py

Co-authored-by: Kaushik B <[email protected]>
  • Loading branch information
tchaton and kaushikb11 authored Mar 30, 2021
1 parent 0b967bb commit 62472d7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 141 deletions.
94 changes: 0 additions & 94 deletions flash/core/data/utils.py

This file was deleted.

62 changes: 20 additions & 42 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,14 @@ def from_filepaths(
test_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None,
test_labels: Optional[Sequence] = None,
predict_filepaths: Optional[Union[str, pathlib.Path, Sequence[Union[str, pathlib.Path]]]] = None,
train_transform: Optional[Callable] = 'default',
val_transform: Optional[Callable] = 'default',
train_transform: Union[str, Dict] = 'default',
val_transform: Union[str, Dict] = 'default',
test_transform: Union[str, Dict] = 'default',
predict_transform: Union[str, Dict] = 'default',
batch_size: int = 64,
num_workers: Optional[int] = None,
seed: Optional[int] = 42,
preprocess_cls: Optional[Type[Preprocess]] = None,
**kwargs,
) -> 'ImageClassificationData':
"""
Expand Down Expand Up @@ -532,59 +535,34 @@ def from_filepaths(
train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)]
else:
train_filepaths = [train_filepaths]

if isinstance(val_filepaths, str):
if os.path.isdir(val_filepaths):
val_filepaths = [os.path.join(val_filepaths, x) for x in os.listdir(val_filepaths)]
else:
val_filepaths = [val_filepaths]

if isinstance(test_filepaths, str):
if os.path.isdir(test_filepaths):
test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)]
else:
test_filepaths = [test_filepaths]
if isinstance(predict_filepaths, str):
if os.path.isdir(predict_filepaths):
predict_filepaths = [os.path.join(predict_filepaths, x) for x in os.listdir(predict_filepaths)]
else:
predict_filepaths = [predict_filepaths]

if train_filepaths is not None and train_labels is not None:
train_dataset = cls._generate_dataset_if_possible(
list(zip(train_filepaths, train_labels)), running_stage=RunningStage.TRAINING
)
else:
train_dataset = None

if val_filepaths is not None and val_labels is not None:
val_dataset = cls._generate_dataset_if_possible(
list(zip(val_filepaths, val_labels)), running_stage=RunningStage.VALIDATING
)
else:
val_dataset = None

if test_filepaths is not None and test_labels is not None:
test_dataset = cls._generate_dataset_if_possible(
list(zip(test_filepaths, test_labels)), running_stage=RunningStage.TESTING
)
else:
test_dataset = None

if predict_filepaths is not None:
predict_dataset = cls._generate_dataset_if_possible(
predict_filepaths, running_stage=RunningStage.PREDICTING
)
else:
predict_dataset = None
preprocess = cls.instantiate_preprocess(
train_transform,
val_transform,
test_transform,
predict_transform,
preprocess_cls=preprocess_cls,
)

return cls(
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
predict_dataset=predict_dataset,
train_transform=train_transform,
val_transform=val_transform,
return cls.from_load_data_inputs(
train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None,
val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None,
test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None,
predict_load_data_input=predict_filepaths,
batch_size=batch_size,
num_workers=num_workers,
seed=seed,
preprocess=preprocess,
**kwargs
)
10 changes: 5 additions & 5 deletions tests/vision/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def _dummy_image_loader(_):


def _rand_image():
return Image.fromarray(np.random.randint(0, 255, (196, 196, 3), dtype="uint8"))
_size = np.random.choice([196, 244])
return Image.fromarray(np.random.randint(0, 255, (_size, _size, 3), dtype="uint8"))


def test_from_filepaths(tmpdir):
Expand All @@ -45,14 +46,13 @@ def test_from_filepaths(tmpdir):
train_filepaths=[tmpdir / "a", tmpdir / "b"],
train_transform=None,
train_labels=[0, 1],
batch_size=1,
batch_size=2,
num_workers=0,
)

data = next(iter(img_data.train_dataloader()))
imgs, labels = data
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1, )
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, )

assert img_data.val_dataloader() is None
assert img_data.test_dataloader() is None
Expand Down

0 comments on commit 62472d7

Please sign in to comment.