diff --git a/plsc/__init__.py b/plsc/__init__.py index 97043fd7ba6885..ba55f626ba68fa 100644 --- a/plsc/__init__.py +++ b/plsc/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from plsc import core as core +from plsc import data as data +from plsc import loss as loss +from plsc import metric as metric +from plsc import models as models +from plsc import nn as nn +from plsc import optimizer as optimizer +from plsc import scheduler as scheduler +from plsc import utils as utils diff --git a/plsc/data/dataset/__init__.py b/plsc/data/dataset/__init__.py index 023bc4fdcbba44..af9ba626d8f86f 100644 --- a/plsc/data/dataset/__init__.py +++ b/plsc/data/dataset/__init__.py @@ -14,3 +14,4 @@ from .imagenet_dataset import ImageNetDataset from .face_recognition_dataset import FaceIdentificationDataset, FaceVerificationDataset, FaceRandomDataset +from .imagefolder_dataset import ImageFolder diff --git a/plsc/data/dataset/imagefolder_dataset.py b/plsc/data/dataset/imagefolder_dataset.py new file mode 100644 index 00000000000000..2e3b2be2625952 --- /dev/null +++ b/plsc/data/dataset/imagefolder_dataset.py @@ -0,0 +1,195 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +import numpy as np +import os + +import paddle + +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", + ".tiff", ".webp") + + +class ImageFolder(paddle.io.Dataset): + """ Code ref from https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py + + A generic data loader where the images are arranged in this way by default: :: + root/dog/xxx.png + root/dog/xxy.png + root/dog/[...]/xxz.png + root/cat/123.png + root/cat/nsdf3.png + root/cat/[...]/asd932_.png + This class inherits from :class:`~torchvision.datasets.DatasetFolder` so + the same methods can be overridden to customize the dataset. + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an numpy image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + is_valid_file (callable, optional): A function that takes path of an Image file + and check if the file is a valid file (used to check of corrupt files) + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + + def __init__(self, + root, + transform=None, + target_transform=None, + extensions=IMG_EXTENSIONS): + + self.root = root + classes, class_to_idx = self.find_classes(self.root) + samples = self.make_dataset(self.root, class_to_idx, extensions) + print(f'find total {len(classes)} classes and {len(samples)} images.') + + self.extensions = extensions + + self.classes = classes + self.class_to_idx = class_to_idx + self.imgs = samples + self.targets = [s[1] for s in samples] + + self.transform = transform + self.target_transform = target_transform + + @staticmethod + def make_dataset( + directory, + class_to_idx, + extensions=None, + is_valid_file=None, ): + """Generates a list of samples of a form (path_to_sample, class). + Args: + directory (str): root dataset directory, corresponding to ``self.root``. + class_to_idx (Dict[str, int]): Dictionary mapping class name to class index. + extensions (optional): A list of allowed extensions. + Either extensions or is_valid_file should be passed. Defaults to None. + is_valid_file (optional): A function that takes path of a file + and checks if the file is a valid file + (used to check of corrupt files) both extensions and + is_valid_file should not be passed. Defaults to None. + Raises: + ValueError: In case ``class_to_idx`` is empty. + ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. + FileNotFoundError: In case no valid file was found for any class. + Returns: + List[Tuple[str, int]]: samples of a form (path_to_sample, class) + """ + if class_to_idx is None: + # prevent potential bug since make_dataset() would use the class_to_idx logic of the + # find_classes() function, instead of using that of the find_classes() method, which + # is potentially overridden and thus could have a different logic. + raise ValueError("The class_to_idx parameter cannot be None.") + + directory = os.path.expanduser(directory) + + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError( + "Both extensions and is_valid_file cannot be None or not None at the same time" + ) + + if extensions is not None: + + def is_valid_file(filename: str) -> bool: + return filename.lower().endswith( + extensions + if isinstance(extensions, str) else tuple(extensions)) + + is_valid_file = cast(Callable[[str], bool], is_valid_file) + + instances = [] + available_classes = set() + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted( + os.walk( + target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = path, class_index + instances.append(item) + + if target_class not in available_classes: + available_classes.add(target_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes: + msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + def find_classes(self, directory): + """Find the class folders in a dataset structured as follows:: + directory/ + ├── class_x + │ ├── xxx.ext + │ ├── xxy.ext + │ └── ... + │ └── xxz.ext + └── class_y + ├── 123.ext + ├── nsdf3.ext + └── ... + └── asd932_.ext + This method can be overridden to only consider + a subset of classes, or to adapt to a different dataset directory structure. + Args: + directory(str): Root directory path, corresponding to ``self.root`` + Raises: + FileNotFoundError: If ``dir`` has no class folders. + Returns: + (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index. + """ + + classes = sorted( + entry.name for entry in os.scandir(directory) if entry.is_dir()) + if not classes: + raise FileNotFoundError( + f"Couldn't find any class folder in {directory}.") + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def __getitem__(self, idx): + path, target = self.imgs[idx] + with open(path, 'rb') as f: + sample = f.read() + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return (sample, np.int32(target)) + + def __len__(self) -> int: + return len(self.imgs) + + @property + def class_num(self): + return len(set(self.classes)) diff --git a/plsc/data/preprocess/__init__.py b/plsc/data/preprocess/__init__.py index 7f62e318f9e55e..3ea2e2134d91c8 100644 --- a/plsc/data/preprocess/__init__.py +++ b/plsc/data/preprocess/__init__.py @@ -11,6 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .basic_transforms import Compose, DecodeImage, ResizeImage, CenterCropImage, RandCropImage, RandFlipImage, NormalizeImage, ToCHWImage, ColorJitter, RandomErasing +from .basic_transforms import * from .batch_transforms import Mixup, Cutmix, TransformOpSampler from .timm_autoaugment import TimmAutoAugment diff --git a/plsc/data/preprocess/basic_transforms.py b/plsc/data/preprocess/basic_transforms.py index 4242840a562b6b..51cf7fe63179a4 100644 --- a/plsc/data/preprocess/basic_transforms.py +++ b/plsc/data/preprocess/basic_transforms.py @@ -22,11 +22,31 @@ import random import cv2 import numpy as np -from PIL import Image +from PIL import Image, ImageFilter, ImageOps from paddle.vision.transforms import ColorJitter as PPColorJitter from plsc.utils import logger +__all__ = [ + "Compose", + "DecodeImage", + "RandomApply", + "ResizeImage", + "CenterCropImage", + "RandCropImage", + "RandomResizedCrop", + "RandFlipImage", + "RandomHorizontalFlip", + "NormalizeImage", + "ToCHWImage", + "ColorJitter", + "RandomErasing", + "RandomGrayscale", + "SimCLRGaussianBlur", + "BYOLSolarize", + "MAERandCropImage", +] + class OperatorParamError(ValueError): """ OperatorParamError @@ -88,14 +108,24 @@ def __init__(self, interpolation=None, backend="cv2"): 'bicubic': cv2.INTER_CUBIC, 'lanczos': cv2.INTER_LANCZOS4 } - _pil_interp_from_str = { - 'nearest': Image.NEAREST, - 'bilinear': Image.BILINEAR, - 'bicubic': Image.BICUBIC, - 'box': Image.BOX, - 'lanczos': Image.LANCZOS, - 'hamming': Image.HAMMING - } + if hasattr(Image, "Resampling"): + _pil_interp_from_str = { + 'nearest': Image.Resampling.NEAREST, + 'bilinear': Image.Resampling.BILINEAR, + 'bicubic': Image.Resampling.BICUBIC, + 'box': Image.Resampling.BOX, + 'lanczos': Image.Resampling.LANCZOS, + 'hamming': Image.Resampling.HAMMING + } + else: + _pil_interp_from_str = { + 'nearest': Image.NEAREST, + 'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'box': Image.BOX, + 'lanczos': Image.LANCZOS, + 'hamming': Image.HAMMING + } def _pil_resize(src, size, resample): pil_img = Image.fromarray(src) @@ -227,6 +257,42 @@ def __call__(self, img): return self._resize_func(img, size) +class RandomResizedCrop(RandCropImage): + """ only rename """ + pass + + +class MAERandCropImage(RandCropImage): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + + def __call__(self, img): + size = self.size + + img_h, img_w = img.shape[:2] + + target_area = img_w * img_h * np.random.uniform(*self.scale) + log_ratio = tuple(math.log(x) for x in self.ratio) + aspect_ratio = math.exp(np.random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, img_w) + h = min(h, img_h) + + i = random.randint(0, img_w - w) + j = random.randint(0, img_h - h) + + img = img[j:j + h, i:i + w, :] + + return self._resize_func(img, size) + + class RandFlipImage(object): """ random flip image flip_code: @@ -247,6 +313,17 @@ def __call__(self, img): return img +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img): + if np.random.rand() < self.p: + return cv2.flip(img, 1) + else: + return img + + class NormalizeImage(object): """ normalize image such as substract mean, divide std """ @@ -314,16 +391,18 @@ class ColorJitter(PPColorJitter): """ColorJitter. """ - def __init__(self, *args, **kwargs): + def __init__(self, p=1.0, *args, **kwargs): + self.p = p super().__init__(*args, **kwargs) def __call__(self, img): - if not isinstance(img, Image.Image): - img = np.ascontiguousarray(img) - img = Image.fromarray(img) - img = super()._apply_image(img) - if isinstance(img, Image.Image): - img = np.asarray(img) + if random.random() < self.p: + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = super()._apply_image(img) + if isinstance(img, Image.Image): + img = np.asarray(img) return img @@ -394,3 +473,78 @@ def __call__(self, img): img[x1:x1 + h, y1:y1 + w, 0] = pixels[0] return img return img + + +class RandomApply(object): + def __init__(self, transforms, p=0.5): + self.transforms = transforms + self.p = p + + def __call__(self, img): + if self.p < np.random.rand(): + return img + for t in self.transforms: + img = t(img) + return img + + +class RandomGrayscale(object): + def __init__(self, p=0.1): + self.p = p + + def __call__(self, img): + _, _, num_output_channels = img.shape + + if np.random.rand() < self.p: + + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + + if num_output_channels == 1: + img = img.convert("L") + img = np.array(img, dtype=np.uint8) + elif num_output_channels == 3: + img = img.convert("L") + img = np.array(img, dtype=np.uint8) + img = np.dstack([img, img, img]) + else: + raise ValueError("num_output_channels should be either 1 or 3") + + return img + + +class SimCLRGaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=[.1, 2.], p=1.0): + self.p = p + self.sigma = sigma + + def __call__(self, img): + if random.random() < self.p: + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + sigma = random.uniform(self.sigma[0], self.sigma[1]) + img = img.filter(ImageFilter.GaussianBlur(radius=sigma)) + if isinstance(img, Image.Image): + img = np.asarray(img) + return img + + +class BYOLSolarize(object): + """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" + + def __init__(self, p=1.0): + self.p = p + + def __call__(self, img): + if random.random() < self.p: + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = ImageOps.solarize(x) + if isinstance(img, Image.Image): + img = np.asarray(img) + return img diff --git a/plsc/data/preprocess/timm_autoaugment.py b/plsc/data/preprocess/timm_autoaugment.py index 037ac668fb2712..51ee52a3a62cf8 100644 --- a/plsc/data/preprocess/timm_autoaugment.py +++ b/plsc/data/preprocess/timm_autoaugment.py @@ -35,23 +35,36 @@ translate_const=250, img_mean=_FILL, ) -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +if hasattr(Image, "Resampling"): + BICUBIC = Image.Resampling.BICUBIC + BILINEAR = Image.Resampling.BILINEAR + HAMMING = Image.Resampling.HAMMING + LANCZOS = Image.Resampling.LANCZOS + AFFINE = Image.Transform.AFFINE +else: + BICUBIC = Image.BICUBIC + BILINEAR = Image.BILINEAR + HAMMING = Image.HAMMING + LANCZOS = Image.LANCZOS + AFFINE = Image.AFFINE + +_RANDOM_INTERPOLATION = (BILINEAR, BICUBIC) def _pil_interp(method): if method == 'bicubic': - return Image.BICUBIC + return BICUBIC elif method == 'lanczos': - return Image.LANCZOS + return LANCZOS elif method == 'hamming': - return Image.HAMMING + return HAMMING else: # default bilinear, do we want to allow nearest? - return Image.BILINEAR + return BILINEAR def _interpolation(kwargs): - interpolation = kwargs.pop('resample', Image.BILINEAR) + interpolation = kwargs.pop('resample', BILINEAR) if isinstance(interpolation, (list, tuple)): return random.choice(interpolation) else: @@ -66,40 +79,34 @@ def _check_args_tf(kwargs): def shear_x(img, factor, **kwargs): _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), - **kwargs) + return img.transform(img.size, AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) def shear_y(img, factor, **kwargs): _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), - **kwargs) + return img.transform(img.size, AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) def translate_x_rel(img, pct, **kwargs): pixels = pct * img.size[0] _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), - **kwargs) + return img.transform(img.size, AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) def translate_y_rel(img, pct, **kwargs): pixels = pct * img.size[1] _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), - **kwargs) + return img.transform(img.size, AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) def translate_x_abs(img, pixels, **kwargs): _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), - **kwargs) + return img.transform(img.size, AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) def translate_y_abs(img, pixels, **kwargs): _check_args_tf(kwargs) - return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), - **kwargs) + return img.transform(img.size, AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) def rotate(img, degrees, **kwargs): @@ -129,7 +136,7 @@ def transform(x, y, matrix): matrix) matrix[2] += rotn_center[0] matrix[5] += rotn_center[1] - return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + return img.transform(img.size, AFFINE, matrix, **kwargs) else: return img.rotate(degrees, resample=kwargs['resample']) diff --git a/plsc/metric/lfw_utils.py b/plsc/metric/lfw_utils.py index b23daec29715e4..ab99b18bf265ca 100644 --- a/plsc/metric/lfw_utils.py +++ b/plsc/metric/lfw_utils.py @@ -23,8 +23,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import numpy as np +import warnings +warnings.simplefilter(action='ignore', category=DeprecationWarning) import sklearn + +import numpy as np from scipy import interpolate from sklearn.decomposition import PCA from sklearn.model_selection import KFold diff --git a/plsc/metric/metrics.py b/plsc/metric/metrics.py index 99b8bdabcd774f..884badb9672769 100644 --- a/plsc/metric/metrics.py +++ b/plsc/metric/metrics.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +import warnings +warnings.simplefilter(action='ignore', category=DeprecationWarning) + import sklearn +import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F