forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add some transforms and imagefoler dataset, and ignore warnings (Padd…
- Loading branch information
1 parent
b9c7c2a
commit 19c36c9
Showing
8 changed files
with
411 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union | ||
import numpy as np | ||
import os | ||
|
||
import paddle | ||
|
||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", | ||
".tiff", ".webp") | ||
|
||
|
||
class ImageFolder(paddle.io.Dataset): | ||
""" Code ref from https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py | ||
A generic data loader where the images are arranged in this way by default: :: | ||
root/dog/xxx.png | ||
root/dog/xxy.png | ||
root/dog/[...]/xxz.png | ||
root/cat/123.png | ||
root/cat/nsdf3.png | ||
root/cat/[...]/asd932_.png | ||
This class inherits from :class:`~torchvision.datasets.DatasetFolder` so | ||
the same methods can be overridden to customize the dataset. | ||
Args: | ||
root (string): Root directory path. | ||
transform (callable, optional): A function/transform that takes in an numpy image | ||
and returns a transformed version. E.g, ``transforms.RandomCrop`` | ||
target_transform (callable, optional): A function/transform that takes in the | ||
target and transforms it. | ||
loader (callable, optional): A function to load an image given its path. | ||
is_valid_file (callable, optional): A function that takes path of an Image file | ||
and check if the file is a valid file (used to check of corrupt files) | ||
Attributes: | ||
classes (list): List of the class names sorted alphabetically. | ||
class_to_idx (dict): Dict with items (class_name, class_index). | ||
imgs (list): List of (image path, class_index) tuples | ||
""" | ||
|
||
def __init__(self, | ||
root, | ||
transform=None, | ||
target_transform=None, | ||
extensions=IMG_EXTENSIONS): | ||
|
||
self.root = root | ||
classes, class_to_idx = self.find_classes(self.root) | ||
samples = self.make_dataset(self.root, class_to_idx, extensions) | ||
print(f'find total {len(classes)} classes and {len(samples)} images.') | ||
|
||
self.extensions = extensions | ||
|
||
self.classes = classes | ||
self.class_to_idx = class_to_idx | ||
self.imgs = samples | ||
self.targets = [s[1] for s in samples] | ||
|
||
self.transform = transform | ||
self.target_transform = target_transform | ||
|
||
@staticmethod | ||
def make_dataset( | ||
directory, | ||
class_to_idx, | ||
extensions=None, | ||
is_valid_file=None, ): | ||
"""Generates a list of samples of a form (path_to_sample, class). | ||
Args: | ||
directory (str): root dataset directory, corresponding to ``self.root``. | ||
class_to_idx (Dict[str, int]): Dictionary mapping class name to class index. | ||
extensions (optional): A list of allowed extensions. | ||
Either extensions or is_valid_file should be passed. Defaults to None. | ||
is_valid_file (optional): A function that takes path of a file | ||
and checks if the file is a valid file | ||
(used to check of corrupt files) both extensions and | ||
is_valid_file should not be passed. Defaults to None. | ||
Raises: | ||
ValueError: In case ``class_to_idx`` is empty. | ||
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. | ||
FileNotFoundError: In case no valid file was found for any class. | ||
Returns: | ||
List[Tuple[str, int]]: samples of a form (path_to_sample, class) | ||
""" | ||
if class_to_idx is None: | ||
# prevent potential bug since make_dataset() would use the class_to_idx logic of the | ||
# find_classes() function, instead of using that of the find_classes() method, which | ||
# is potentially overridden and thus could have a different logic. | ||
raise ValueError("The class_to_idx parameter cannot be None.") | ||
|
||
directory = os.path.expanduser(directory) | ||
|
||
both_none = extensions is None and is_valid_file is None | ||
both_something = extensions is not None and is_valid_file is not None | ||
if both_none or both_something: | ||
raise ValueError( | ||
"Both extensions and is_valid_file cannot be None or not None at the same time" | ||
) | ||
|
||
if extensions is not None: | ||
|
||
def is_valid_file(filename: str) -> bool: | ||
return filename.lower().endswith( | ||
extensions | ||
if isinstance(extensions, str) else tuple(extensions)) | ||
|
||
is_valid_file = cast(Callable[[str], bool], is_valid_file) | ||
|
||
instances = [] | ||
available_classes = set() | ||
for target_class in sorted(class_to_idx.keys()): | ||
class_index = class_to_idx[target_class] | ||
target_dir = os.path.join(directory, target_class) | ||
if not os.path.isdir(target_dir): | ||
continue | ||
for root, _, fnames in sorted( | ||
os.walk( | ||
target_dir, followlinks=True)): | ||
for fname in sorted(fnames): | ||
path = os.path.join(root, fname) | ||
if is_valid_file(path): | ||
item = path, class_index | ||
instances.append(item) | ||
|
||
if target_class not in available_classes: | ||
available_classes.add(target_class) | ||
|
||
empty_classes = set(class_to_idx.keys()) - available_classes | ||
if empty_classes: | ||
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " | ||
if extensions is not None: | ||
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" | ||
raise FileNotFoundError(msg) | ||
|
||
return instances | ||
|
||
def find_classes(self, directory): | ||
"""Find the class folders in a dataset structured as follows:: | ||
directory/ | ||
├── class_x | ||
│ ├── xxx.ext | ||
│ ├── xxy.ext | ||
│ └── ... | ||
│ └── xxz.ext | ||
└── class_y | ||
├── 123.ext | ||
├── nsdf3.ext | ||
└── ... | ||
└── asd932_.ext | ||
This method can be overridden to only consider | ||
a subset of classes, or to adapt to a different dataset directory structure. | ||
Args: | ||
directory(str): Root directory path, corresponding to ``self.root`` | ||
Raises: | ||
FileNotFoundError: If ``dir`` has no class folders. | ||
Returns: | ||
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index. | ||
""" | ||
|
||
classes = sorted( | ||
entry.name for entry in os.scandir(directory) if entry.is_dir()) | ||
if not classes: | ||
raise FileNotFoundError( | ||
f"Couldn't find any class folder in {directory}.") | ||
|
||
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} | ||
return classes, class_to_idx | ||
|
||
def __getitem__(self, idx): | ||
path, target = self.imgs[idx] | ||
with open(path, 'rb') as f: | ||
sample = f.read() | ||
if self.transform is not None: | ||
sample = self.transform(sample) | ||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
return (sample, np.int32(target)) | ||
|
||
def __len__(self) -> int: | ||
return len(self.imgs) | ||
|
||
@property | ||
def class_num(self): | ||
return len(set(self.classes)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.