diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 565a79c7a19..ef3ae7af896 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -124,7 +124,7 @@ def __init__( super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self._find_classes(self.root) - samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) + samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) if len(samples) == 0: msg = "Found 0 files in subfolders of: {}\n".format(self.root) if extensions is not None: @@ -139,6 +139,15 @@ def __init__( self.samples = samples self.targets = [s[1] for s in samples] + @staticmethod + def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> List[Tuple[str, int]]: + return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) + def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: """ Finds the class folders in a dataset.