Skip to content

Commit

Permalink
'make_dataset' as staticmethod of 'DatasetFolder' (#3215)
Browse files Browse the repository at this point in the history
Summary:
* 'make_dataset' as staticmethod of 'DatasetFolder'

* a better fix

Reviewed By: datumbox

Differential Revision: D25954567

fbshipit-source-id: 514fde3bad4e27518a198276228a36c3217c2163

Co-authored-by: Francisco Massa <[email protected]>
  • Loading branch information
2 people authored and facebook-github-bot committed Jan 21, 2021
1 parent 6a02f85 commit d10758d
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
Expand All @@ -139,6 +139,15 @@ def __init__(
self.samples = samples
self.targets = [s[1] for s in samples]

@staticmethod
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
Expand Down

0 comments on commit d10758d

Please sign in to comment.