diff --git a/flash/data/data_utils.py b/flash/data/data_utils.py index 50e6b3bf63..2cabd08552 100644 --- a/flash/data/data_utils.py +++ b/flash/data/data_utils.py @@ -1,9 +1,15 @@ -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import pandas as pd -def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = True) -> Union[Dict, List]: +def labels_from_categorical_csv( + csv: str, + index_col: str, + feature_cols: List, + return_dict: bool = True, + index_col_collate_fn: Any = None +) -> Union[Dict, List]: """ Returns a dictionary with {index_col: label} for each entry in the csv. @@ -17,10 +23,15 @@ def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = Tr df = pd.read_csv(csv) # get names names = df[index_col].to_list() - del df[index_col] + + # apply colate fn to index_col + if index_col_collate_fn: + for i in range(len(names)): + names[i] = index_col_collate_fn(names[i]) # everything else is binary - labels = df.to_numpy().argmax(1).tolist() + feature_df = df[feature_cols] + labels = feature_df.to_numpy().argmax(1).tolist() if return_dict: labels = {name: label for name, label in zip(names, labels)} diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 4af59a8764..6f90f2571d 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -77,7 +77,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]: img = self.transform(img) label = None if self.has_dict_labels: - name = os.path.basename(filename) + name = os.path.splitext(filename)[0] + name = os.path.basename(name) label = self.labels[name] elif self.has_labels: @@ -256,6 +257,7 @@ def from_filepaths( train_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None, train_labels: Optional[Sequence] = None, train_transform: Optional[Callable] = _default_train_transforms, + valid_split: Union[None, float] = None, valid_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None, valid_labels: Optional[Sequence] = None, valid_transform: Optional[Callable] = _default_valid_transforms, @@ -264,6 +266,7 @@ def from_filepaths( loader: Callable = _pil_loader, batch_size: int = 64, num_workers: Optional[int] = None, + seed: int = 1234, **kwargs ): """Creates a ImageClassificationData object from lists of image filepaths and labels @@ -272,6 +275,7 @@ def from_filepaths( train_filepaths: string or sequence of file paths for training dataset. Defaults to ``None``. train_labels: sequence of labels for training dataset. Defaults to ``None``. train_transform: transforms for training dataset. Defaults to ``None``. + valid_split: if not None, generates val split from train dataloader using this value. valid_filepaths: string or sequence of file paths for validation dataset. Defaults to ``None``. valid_labels: sequence of labels for validation dataset. Defaults to ``None``. valid_transform: transforms for validation and testing dataset. Defaults to ``None``. @@ -281,6 +285,7 @@ def from_filepaths( batch_size: the batchsize to use for parallel loading. Defaults to ``64``. num_workers: The number of workers to use for parallelized loading. Defaults to ``None`` which equals the number of available CPU threads. + seed: Used for the train/val splits when valid_split is not None Returns: ImageClassificationData: The constructed data module. @@ -319,14 +324,25 @@ def from_filepaths( loader=loader, transform=train_transform, ) - valid_ds = ( - FilepathDataset( - filepaths=valid_filepaths, - labels=valid_labels, - loader=loader, - transform=valid_transform, - ) if valid_filepaths is not None else None - ) + + if valid_split: + full_length = len(train_ds) + train_split = int((1.0 - valid_split) * full_length) + valid_split = full_length - train_split + train_ds, valid_ds = torch.utils.data.random_split( + train_ds, + [train_split, valid_split], + generator=torch.Generator().manual_seed(seed) + ) + else: + valid_ds = ( + FilepathDataset( + filepaths=valid_filepaths, + labels=valid_labels, + loader=loader, + transform=valid_transform, + ) if valid_filepaths is not None else None + ) test_ds = ( FilepathDataset( diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index f05ecab306..8e442fec1e 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -93,27 +93,33 @@ def test_categorical_csv_labels(tmpdir): train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv') text_file = open(train_csv, 'w') text_file.write( - 'my_id, label_a, label_b, label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n' + 'my_id,label_a,label_b,label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n' ) text_file.close() valid_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv') text_file = open(valid_csv, 'w') text_file.write( - 'my_id, label_a, label_b, label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n' + 'my_id,label_a,label_b,label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n' ) text_file.close() test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv') text_file = open(test_csv, 'w') text_file.write( - 'my_id, label_a, label_b, label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n' + 'my_id,label_a,label_b,label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n' ) text_file.close() - train_labels = labels_from_categorical_csv(train_csv, 'my_id') - valid_labels = labels_from_categorical_csv(valid_csv, 'my_id') - test_labels = labels_from_categorical_csv(test_csv, 'my_id') + def index_col_collate_fn(x): + return os.path.splitext(x)[0] + + train_labels = labels_from_categorical_csv( + train_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn) + valid_labels = labels_from_categorical_csv( + valid_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn) + test_labels = labels_from_categorical_csv( + test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn) data = ImageClassificationData.from_filepaths( batch_size=2, @@ -134,6 +140,16 @@ def test_categorical_csv_labels(tmpdir): for (x, y) in data.test_dataloader(): assert len(x) == 2 + data = ImageClassificationData.from_filepaths( + batch_size=2, + train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'), + train_labels=train_labels, + valid_split=0.5 + ) + + for (x, y) in data.val_dataloader(): + assert len(x) == 1 + def test_from_folders(tmpdir): train_dir = Path(tmpdir / "train")