🚀 Feature
Add a "classes" option to the torchvision.datasets.ImageFolder.__init__() method that specifies which classes to use under "root", instead of loading every class (folder) under "root".
Motivation
I'd like to specify which classes to use under the "root" path, but the current ImageFolder loads every class (directory) under "root".
In Keras' ImageDataGenerator.flow_from_directory, I can select the classes to use under "root" by passing the optional "classes" argument. I'd like PyTorch's ImageFolder to have the same argument as Keras' ImageDataGenerator.flow_from_directory. (I previously opened the same issue in the wrong repo; it has since been closed.)
Pitch
Add a new optional argument "classes" to the torchvision.datasets.ImageFolder.__init__() method. It takes a list of the classes I need and is passed to the superclass DatasetFolder, which then does the selective class import in the DatasetFolder._find_classes() method.
If the "classes" argument is not passed, ImageFolder loads all classes under "root" (same as the current behavior).
I can imagine that there would be a lot of different types of exceptions to take care of. I would err on the side of not having exceptions. The folders are fairly easy to restructure so as to fit the ImageFolder format. Thoughts?
if class_list_arg_is_provided and provided_class_list_is_valid:
    # keep only the sub-directories named in the provided list
    classes = [d.name for d in os.scandir(dir) if d.is_dir() and d.name in provided_class_list]
else:
    # current behavior: every sub-directory is a class
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
and add a simple method that verifies provided_class_list is valid by checking that:
provided_class_list is of type list or tuple;
each element in provided_class_list is of type str.
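The two checks and the selective scan above can be put together as a minimal sketch using only the standard library. The function names is_valid_class_list and select_classes are hypothetical and are not part of torchvision. Unlike the snippet above, this version raises a TypeError on an invalid list instead of silently falling back to all classes, which is an assumption about the preferred behavior.

```python
import os
import tempfile


def is_valid_class_list(provided_class_list):
    """Return True if the argument is a list/tuple whose elements are all str."""
    return (
        isinstance(provided_class_list, (list, tuple))
        and all(isinstance(name, str) for name in provided_class_list)
    )


def select_classes(root, provided_class_list=None):
    """List class directories under root, optionally restricted to a given list."""
    all_dirs = sorted(d.name for d in os.scandir(root) if d.is_dir())
    if provided_class_list is None:
        # No list given: keep every sub-directory (the current behavior).
        return all_dirs
    if not is_valid_class_list(provided_class_list):
        raise TypeError("classes must be a list or tuple of str")
    wanted = set(provided_class_list)
    return [name for name in all_dirs if name in wanted]


# Small demonstration on a throwaway directory tree.
with tempfile.TemporaryDirectory() as root:
    for name in ("cat", "dog", "bird"):
        os.mkdir(os.path.join(root, name))
    all_classes = select_classes(root)
    some_classes = select_classes(root, ["dog", "cat"])
```

Names in the list that have no matching directory are ignored here; whether that should be an error is one of the cases a real implementation would have to decide.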
I agree with @vincentqb . There are multiple possible ways of providing a filter function for the class names, and I'm not sure that specifying it as a list is the most generic way. One could for example pass instead a filter function, which could support the case you mentioned, but others as well (for example, filtering via a regex).
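To illustrate why a filter function is more general than a list, here is a small standard-library sketch. The name filter_classes and its keep parameter are hypothetical; nothing like this exists in torchvision. Any callable that takes a directory name and returns a truthy value can act as the filter, so a regex and a fixed set of names are handled the same way.

```python
import os
import re
import tempfile


def filter_classes(root, keep=None):
    """Return sorted class directory names under root for which keep(name) is truthy."""
    names = sorted(d.name for d in os.scandir(root) if d.is_dir())
    if keep is None:
        return names
    return [name for name in names if keep(name)]


with tempfile.TemporaryDirectory() as root:
    for name in ("n01_cat", "n02_dog", "misc"):
        os.mkdir(os.path.join(root, name))
    # Regex filter: keep folders whose name starts with "n<digits>_".
    by_regex = filter_classes(root, re.compile(r"n\d+_").match)
    # List-style filter: membership in a fixed set of names.
    by_list = filter_classes(root, {"misc", "n02_dog"}.__contains__)
```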
This is a common-enough request that has been discussed in #145 for example, and we addressed it by letting _find_classes be a method of DatasetFolder in #527. This gives full flexibility for the user, with just a bit of extra code. For example:
from torchvision.datasets import DatasetFolder

class MyDataset(DatasetFolder):
    def _find_classes(self, dir):
        # apply your custom rules here and return (classes, class_to_idx)
        ...

dataset = MyDataset(...)
As such, I'm closing the issue but let me know if you disagree.
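As a rough sketch of what such an override could look like for the original request, the mixin below restricts class discovery to a fixed set of names and returns the (classes, class_to_idx) pair that _find_classes is expected to produce. SelectedClassesMixin, wanted_classes and CatsAndDogs are hypothetical names. The block uses only the standard library so it can run without torchvision, and raising FileNotFoundError for a missing folder is an assumption.

```python
import os
import tempfile


class SelectedClassesMixin:
    """Hypothetical mixin: restrict class discovery to `wanted_classes`."""

    wanted_classes = None  # set on the subclass; None keeps every folder

    def _find_classes(self, directory):
        classes = sorted(d.name for d in os.scandir(directory) if d.is_dir())
        if self.wanted_classes is not None:
            wanted = set(self.wanted_classes)
            missing = wanted - set(classes)
            if missing:
                raise FileNotFoundError(f"no such class folders: {sorted(missing)}")
            classes = [name for name in classes if name in wanted]
        # Indices are contiguous over the selected classes only.
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx


class CatsAndDogs(SelectedClassesMixin):
    wanted_classes = ("cat", "dog")


with tempfile.TemporaryDirectory() as root:
    for name in ("bird", "cat", "dog"):
        os.mkdir(os.path.join(root, name))
    classes, class_to_idx = CatsAndDogs()._find_classes(root)
```

With torchvision installed, listing the mixin before ImageFolder in a subclass's bases should make the dataset pick up this override. The selection is a class attribute because class discovery runs inside the base constructor. Later torchvision releases appear to expose this hook under the public name find_classes, so the method name may need adjusting for the installed version.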