Skip to content

Commit

Permalink
Addresses issue #145 as per @fmessa's suggestion. (#527)
Browse files Browse the repository at this point in the history
* Addresses issue #145 as per @fmessa's suggestion.

* Removed blank line for styling.
  • Loading branch information
Choco31415 authored and fmassa committed Jun 6, 2018
1 parent 5a0d079 commit 3f6c23c
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,10 @@ def is_image_file(filename):
return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
images = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
Expand Down Expand Up @@ -86,7 +79,7 @@ class DatasetFolder(data.Dataset):
"""

def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = find_classes(root)
classes, class_to_idx = self._find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
Expand All @@ -104,6 +97,24 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
self.transform = transform
self.target_transform = target_transform

def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx

def __getitem__(self, index):
"""
Args:
Expand Down

0 comments on commit 3f6c23c

Please sign in to comment.