Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge load / resize / cache to optimize data loading efficiency for classification #2438

Merged
merged 15 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import numpy as np
from mmcls.datasets import PIPELINES
from mmcls.datasets.pipelines import Compose
from mmcls.datasets.pipelines import Compose, Resize
from mmcv.utils.registry import build_from_cfg
from PIL import Image, ImageFilter
from torchvision import transforms as T

import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base
import otx.algorithms.common.adapters.mmcv.pipelines.load_image_from_otx_dataset as load_image_base

# TODO: refactoring to common modules
# TODO: refactoring to Sphinx style.
Expand All @@ -23,6 +23,44 @@ class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
"""Pipeline element that loads an image from a OTX Dataset on the fly."""


@PIPELINES.register_module()
class LoadResizeDataFromOTXDataset(load_image_base.LoadResizeDataFromOTXDataset):
"""Load and resize image & annotation with cache support."""

def _create_load_img_op(self, cfg: Dict) -> Any:
"""Creates image loading operation."""
return build_from_cfg(cfg, PIPELINES)

def _create_resize_op(self, cfg: Optional[Dict]) -> Optional[Any]:
"""Creates resize operation."""
if cfg is None:
return None
return build_from_cfg(cfg, PIPELINES)


@PIPELINES.register_module()
class ResizeTo(Resize):
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
"""Resize to specific size.

This operation works if the input is not in desired shape.
If it's already in the shape, it just returns input dict for efficiency.

Args:
size (tuple): Images scales for resizing (h, w).
"""

def __call__(self, results: Dict[str, Any]):
"""Callback function of ResizeTo.

Args:
results: Inputs to be transformed.
"""
img_shape = results.get("img_shape", (0, 0))
if img_shape[0] == self.size[0] and img_shape[1] == self.size[1]:
return results
return super().__call__(results)


@PIPELINES.register_module()
class RandomAppliedTrans:
"""Randomly applied transformations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
__resize_target_size = 224

__train_pipeline = [
dict(type="LoadImageFromOTXDataset"),
dict(
type="LoadResizeDataFromOTXDataset",
load_img_cfg=dict(type="LoadImageFromOTXDataset", enable_memcache=False), # To be cached after resize
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
resize_cfg=dict(type="Resize", size=__resize_target_size, downscale_only=True),
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
# To be resized in this op only if input is larger than expected size
# for speed & cache memory efficiency.
),
dict(type="RandomResizedCrop", size=__resize_target_size, efficientnet_style=True),
dict(type="RandomFlip", flip_prob=0.5, direction="horizontal"),
dict(type="Normalize", **__img_norm_cfg),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Pipeline element that loads an image from a OTX Dataset on the fly."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional, Tuple

import numpy as np

from otx.algorithms.common.utils.data import get_image
from otx.core.data.caching import MemCacheHandlerError, MemCacheHandlerSingleton

_CACHE_DIR = TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with

# TODO: refactoring to common modules


class LoadImageFromOTXDataset:
"""Pipeline element that loads an image from a OTX Dataset on the fly.

Can do conversion to float 32 if needed.
Expected entries in the 'results' dict that should be passed to this pipeline element are:
results['dataset_item']: dataset_item from which to load the image
results['dataset_id']: id of the dataset to which the item belongs
results['index']: index of the item in the dataset

Args:
to_float32 (bool, optional): True to convert images to fp32. defaults to False.
enable_memcache (bool, optional): True to enable in-memory cache. defaults to True.
"""

def __init__(self, to_float32: bool = False, enable_memcache: bool = True):
self._to_float32 = to_float32
self._enable_memcache = enable_memcache
try:
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
self._mem_cache_handler = MemCacheHandlerSingleton.get()
except MemCacheHandlerError:
# Create a null handler
MemCacheHandlerSingleton.create(mode="null", mem_size=0)
self._mem_cache_handler = MemCacheHandlerSingleton.get()

@staticmethod
def _get_unique_key(results: Dict[str, Any]) -> Tuple:
"""Returns unique key of data item based on the contents."""
# TODO: We should improve it by assigning an unique id to DatasetItemEntity.
# This is because there is a case which
# d_item.media.path is None, but d_item.media.data is not None
d_item = results["dataset_item"]
return d_item.media.path, d_item.roi.id

def __call__(self, results: Dict[str, Any]):
"""Callback function of LoadImageFromOTXDataset."""
img = None
if self._enable_memcache:
key = self._get_unique_key(results)
img, meta = self._mem_cache_handler.get(key)

if img is None:
# Get image (possibly from file cache)
img = get_image(results, _CACHE_DIR.name, to_float32=False)
if self._enable_memcache:
self._mem_cache_handler.put(key, img)

if self._to_float32:
img = img.astype(np.float32)
shape = img.shape

if img.shape[0] != results["height"]:
results["height"] = img.shape[0]

if img.shape[1] != results["width"]:
results["width"] = img.shape[1]

filename = f"Dataset item index {results['index']}"
results["filename"] = filename
results["ori_filename"] = filename
results["img"] = img
results["img_shape"] = shape
results["ori_shape"] = shape
# Set initial values for default meta_keys
results["pad_shape"] = shape
num_channels = 1 if len(shape) < 3 else shape[2]
results["img_norm_cfg"] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False,
)
results["img_fields"] = ["img"]
results["entity_id"] = results.get("entity_id")
results["label_id"] = results.get("label_id")

return results


class LoadResizeDataFromOTXDataset(LoadImageFromOTXDataset):
"""Load and resize image & annotation with cache support.

This base operation loads image and optionally loads annotations.
Then, resize the image and annotation accordingly if resize_cfg given & it's beneficial,
e.g. the size is smaller than original input size.
Finally, if enabled, cache the result and use pre-computed ones from next iterations.

Args:
load_img_cfg (Dict): Creates image loading operation based on the config
load_ann_cfg (Dict, optional): Optionally creates annotation loading operation based on the config.
Defaults to None.
resize_cfg (Dict, optional): Optionally creates resize operation based on the config. Defaults to None.
enable_memcache (bool, optional): True to enable in-memory cache. Defaults to True.
"""

def __init__(
self,
load_img_cfg: Dict,
load_ann_cfg: Optional[Dict] = None,
resize_cfg: Optional[Dict] = None,
**kwargs,
):
super().__init__(**kwargs)
load_img_cfg = load_img_cfg.copy()
load_img_cfg["enable_memcache"] = False # will use outer cache
self._load_img_op = self._create_load_img_op(load_img_cfg)
self._load_ann_op = self._create_load_ann_op(load_ann_cfg)
self._downscale_only = resize_cfg.pop("downscale_only", False) if resize_cfg else False
self._resize_op = self._create_resize_op(resize_cfg)
if self._resize_op is not None:
self._resize_shape = resize_cfg.get("size", resize_cfg.get("img_scale"))
if isinstance(self._resize_shape, int):
self._resize_shape = (self._resize_shape, self._resize_shape)
assert isinstance(self._resize_shape, tuple), f"Random scale is not supported by {self.__class__.__name__}"
else:
self._resize_shape = None

def _create_load_img_op(self, cfg: Dict) -> Any:
"""Creates image loading operation."""
return LoadImageFromOTXDataset(**cfg) # Should be overrided in task-specific implementation if needed

def _create_load_ann_op(self, cfg: Optional[Dict]) -> Optional[Any]:
"""Creates annotation loading operation."""
return None # Should be overrided in task-specific implementation

def _create_resize_op(self, cfg: Optional[Dict]) -> Optional[Any]:
"""Creates resize operation."""
return None # Should be overrided in task-specific implementation

def _load_img(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""Load image and fill the results dict."""
return self._load_img_op(results)

def _load_ann_if_any(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""Load annotations and fill the results dict."""
if self._load_ann_op is None:
return results
return self._load_ann_op(results)

def _resize_img_ann_if_any(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""Resize image and annotations if needed and fill the results dict."""
if self._resize_op is None:
return results
original_shape = results.get("img_shape", self._resize_shape)
if original_shape is None:
return results
if self._downscale_only:
if original_shape[0] * original_shape[1] <= self._resize_shape[0] * self._resize_shape[1]:
# No benfit of early resizing if resize_shape is larger than original_shape
return results
return self._resize_op(results)

def _load_cache(self, results: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Try to load pre-computed results from cache."""
if not self._enable_memcache:
return None
key = self._get_unique_key(results)
img, meta = self._mem_cache_handler.get(key)
if img is None or meta is None:
return None
results = meta.copy()
results["img"] = img
return results

def _save_cache(self, results: Dict[str, Any]):
"""Try to save pre-computed results to cache."""
if not self._enable_memcache:
return
key = self._get_unique_key(results)
meta = results.copy()
meta.pop("dataset_item") # remove irrlevant info
img = meta.pop("img")
self._mem_cache_handler.put(key, img, meta)

def __call__(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""Callback function."""
cached_results = self._load_cache(results)
if cached_results:
return cached_results
results = self._load_img(results)
results = self._load_ann_if_any(results)
results = self._resize_img_ann_if_any(results)
self._save_cache(results)
return results
11 changes: 8 additions & 3 deletions src/otx/algorithms/common/adapters/mmcv/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ class InputSizeManager:
"MultiScaleFlipAug": ["transforms"],
"AutoAugment": ["policies"],
"TwoCropTransform": ["view0", "view1", "pipeline"],
"LoadResizeDataFromOTXDataset": ["resize_cfg"],
}
SUBSET_TYPES: Tuple[str, str, str, str] = ("train", "val", "test", "unlabeled")

Expand Down Expand Up @@ -785,7 +786,9 @@ def _estimate_post_img_size(
for sub_pipeline_name in sub_pipeline_names:
if sub_pipeline_name in pipeline:
sub_pipeline = pipeline[sub_pipeline_name]
if isinstance(sub_pipeline[0], list):
if isinstance(sub_pipeline, dict):
sub_pipeline = [sub_pipeline]
elif isinstance(sub_pipeline[0], list):
sub_pipeline = sub_pipeline[0]
post_img_size = self._estimate_post_img_size(sub_pipeline, post_img_size)
break
Expand Down Expand Up @@ -829,7 +832,9 @@ def _set_pipeline_size_value(self, pipeline: Dict, scale: Tuple[Union[int, float
if pipeline_name == pipeline["type"]:
for sub_pipeline_name in sub_pipeline_names:
if sub_pipeline_name in pipeline:
if isinstance(pipeline[sub_pipeline_name][0], dict):
if isinstance(pipeline[sub_pipeline_name], dict):
self._set_pipeline_size_value(pipeline[sub_pipeline_name], scale)
elif isinstance(pipeline[sub_pipeline_name][0], dict):
for sub_pipeline in pipeline[sub_pipeline_name]:
self._set_pipeline_size_value(sub_pipeline, scale)
elif isinstance(pipeline[sub_pipeline_name][0], list):
Expand All @@ -839,7 +844,7 @@ def _set_pipeline_size_value(self, pipeline: Dict, scale: Tuple[Union[int, float
else:
raise ValueError(
"Dataset pipeline in pipeline wrapper type should be"
"either list[dict] or list[list[dict]]."
"either dict, list[dict] or list[list[dict]]."
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from mmdet.datasets.builder import PIPELINES

import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base
import otx.algorithms.common.adapters.mmcv.pipelines.load_image_from_otx_dataset as load_image_base
from otx.algorithms.detection.adapters.mmdet.datasets.dataset import (
get_annotation_mmdet_format,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from mmseg.datasets.builder import PIPELINES

import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base
import otx.algorithms.common.adapters.mmcv.pipelines.load_image_from_otx_dataset as load_image_base
from otx.algorithms.segmentation.adapters.mmseg.datasets.dataset import (
get_annotation_mmseg_format,
)
Expand Down
Loading