Skip to content

Commit

Permalink
add some transforms and imagefolder dataset, and ignore warnings (Padd…
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Dec 7, 2022
1 parent b9c7c2a commit 19c36c9
Show file tree
Hide file tree
Showing 8 changed files with 411 additions and 38 deletions.
10 changes: 10 additions & 0 deletions plsc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from plsc import core as core
from plsc import data as data
from plsc import loss as loss
from plsc import metric as metric
from plsc import models as models
from plsc import nn as nn
from plsc import optimizer as optimizer
from plsc import scheduler as scheduler
from plsc import utils as utils
1 change: 1 addition & 0 deletions plsc/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@

from .imagenet_dataset import ImageNetDataset
from .face_recognition_dataset import FaceIdentificationDataset, FaceVerificationDataset, FaceRandomDataset
from .imagefolder_dataset import ImageFolder
195 changes: 195 additions & 0 deletions plsc/data/dataset/imagefolder_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import numpy as np
import os

import paddle

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif",
".tiff", ".webp")


class ImageFolder(paddle.io.Dataset):
    """A generic data loader for images arranged in class subdirectories::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    Adapted from
    https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
    (this class subclasses ``paddle.io.Dataset``, not torchvision's
    ``DatasetFolder``). ``find_classes`` and ``make_dataset`` can be
    overridden to customize class discovery and sample collection.

    Args:
        root (str): Root directory path.
        transform (callable, optional): A function/transform applied to each
            sample. Note that ``__getitem__`` passes the raw, undecoded file
            bytes, so the transform pipeline is expected to start with a
            decode step (e.g. ``DecodeImage``).
        target_transform (callable, optional): A function/transform that takes
            in the target (class index) and transforms it.
        extensions (tuple, optional): Allowed lowercase file extensions.
            Mutually exclusive with ``is_valid_file``.
        is_valid_file (callable, optional): A function that takes the path of
            an image file and checks if it is a valid file (used to skip
            corrupt files). When given, ``extensions`` is not used for
            filtering.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
        targets (list): Class index of each sample, aligned with ``imgs``.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 extensions=IMG_EXTENSIONS,
                 is_valid_file=None):

        self.root = root
        classes, class_to_idx = self.find_classes(self.root)
        # make_dataset() requires exactly one of extensions / is_valid_file,
        # so the default extensions are dropped when a predicate is supplied.
        samples = self.make_dataset(
            self.root, class_to_idx,
            None if is_valid_file is not None else extensions, is_valid_file)
        print(f'find total {len(classes)} classes and {len(samples)} images.')

        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.imgs = samples
        self.targets = [s[1] for s in samples]

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def make_dataset(
            directory,
            class_to_idx,
            extensions=None,
            is_valid_file=None, ):
        """Generates a list of samples of a form (path_to_sample, class).

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to skip corrupt files); extensions and
                is_valid_file must not both be passed. Defaults to None.
        Raises:
            ValueError: In case ``class_to_idx`` is empty.
            ValueError: In case both or neither of ``extensions`` and ``is_valid_file`` are passed.
            FileNotFoundError: In case no valid file was found for any class.
        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")

        directory = os.path.expanduser(directory)

        both_none = extensions is None and is_valid_file is None
        both_something = extensions is not None and is_valid_file is not None
        if both_none or both_something:
            raise ValueError(
                "Both extensions and is_valid_file cannot be None or not None at the same time"
            )

        if extensions is not None:
            # Case-insensitive suffix check; str.endswith accepts a tuple.
            def is_valid_file(filename: str) -> bool:
                return filename.lower().endswith(
                    extensions
                    if isinstance(extensions, str) else tuple(extensions))

        is_valid_file = cast(Callable[[str], bool], is_valid_file)

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            # followlinks=True so symlinked subdirectories are traversed;
            # sorting makes sample order deterministic across runs.
            for root, _, fnames in sorted(
                    os.walk(
                        target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
            if extensions is not None:
                msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
            raise FileNotFoundError(msg)

        return instances

    def find_classes(self, directory):
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``
        Raises:
            FileNotFoundError: If ``directory`` has no class folders.
        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """

        classes = sorted(
            entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(
                f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, idx):
        """Return ``(sample, target)`` for index ``idx``.

        ``sample`` is the raw file bytes (decoding is left to ``transform``);
        ``target`` is the class index as ``np.int32``.
        """
        path, target = self.imgs[idx]
        with open(path, 'rb') as f:
            sample = f.read()
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return (sample, np.int32(target))

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.imgs)

    @property
    def class_num(self):
        """Number of distinct classes found under ``root``."""
        return len(set(self.classes))
2 changes: 1 addition & 1 deletion plsc/data/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .basic_transforms import Compose, DecodeImage, ResizeImage, CenterCropImage, RandCropImage, RandFlipImage, NormalizeImage, ToCHWImage, ColorJitter, RandomErasing
from .basic_transforms import *
from .batch_transforms import Mixup, Cutmix, TransformOpSampler
from .timm_autoaugment import TimmAutoAugment
Loading

0 comments on commit 19c36c9

Please sign in to comment.