Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge load / resize / cache to optimize data loading efficiency for detection & instance segmentation #2453

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ def _get_unique_key(results: Dict[str, Any]) -> Tuple:
# TODO: We should improve it by assigning an unique id to DatasetItemEntity.
# This is because there is a case which
# d_item.media.path is None, but d_item.media.data is not None
if "cache_key" in results:
return results["cache_key"]
d_item = results["dataset_item"]
return d_item.media.path, d_item.roi.id
results["cache_key"] = d_item.media.path, d_item.roi.id
return results["cache_key"]

def __call__(self, results: Dict[str, Any]):
"""Callback function of LoadImageFromOTXDataset."""
Expand Down Expand Up @@ -177,7 +180,6 @@ def _save_cache(self, results: Dict[str, Any]):
return
key = self._get_unique_key(results)
meta = results.copy()
meta.pop("dataset_item") # remove irrlevant info
img = meta.pop("img")
self._mem_cache_handler.put(key, img, meta)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
# SPDX-License-Identifier: Apache-2.0
#

from .load_pipelines import LoadAnnotationFromOTXDataset, LoadImageFromOTXDataset
from .load_pipelines import (
LoadAnnotationFromOTXDataset,
LoadImageFromOTXDataset,
LoadResizeDataFromOTXDataset,
ResizeTo,
)
from .torchvision2mmdet import (
BranchImage,
ColorJitter,
Expand All @@ -19,6 +24,8 @@
__all__ = [
"LoadImageFromOTXDataset",
"LoadAnnotationFromOTXDataset",
"LoadResizeDataFromOTXDataset",
"ResizeTo",
"ColorJitter",
"RandomGrayscale",
"RandomErasing",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
"""Collection Pipeline for detection task."""
# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
# Copyright (C) 2021-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import copy
from typing import Any, Dict
from typing import Any, Dict, Optional

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.builder import PIPELINES, build_from_cfg
from mmdet.datasets.pipelines import Resize

import otx.algorithms.common.adapters.mmcv.pipelines.load_image_from_otx_dataset as load_image_base
from otx.algorithms.detection.adapters.mmdet.datasets.dataset import (
Expand All @@ -30,6 +21,50 @@ class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset):
"""Pipeline element that loads an image from a OTX Dataset on the fly."""


@PIPELINES.register_module()
class LoadResizeDataFromOTXDataset(load_image_base.LoadResizeDataFromOTXDataset):
    """Load and resize image & annotation with cache support.

    Specializes the base OTX load-resize pipeline element for mmdet by building
    the sub-operations through mmdet's PIPELINES registry.
    """

    def _create_load_ann_op(self, cfg: Optional[Dict]) -> Optional[Any]:
        """Creates the annotation-loading operation from ``cfg`` via the PIPELINES registry.

        Returns None when no config is given, so annotation loading is skipped.
        """
        if cfg is None:
            return None
        return build_from_cfg(cfg, PIPELINES)

    def _create_resize_op(self, cfg: Optional[Dict]) -> Optional[Any]:
        """Creates the resize operation from ``cfg`` via the PIPELINES registry.

        Returns None when no config is given, so resizing is skipped.
        """
        if cfg is None:
            return None
        return build_from_cfg(cfg, PIPELINES)


@PIPELINES.register_module()
class ResizeTo(Resize):
    """Resize input to a specific target size, skipping redundant work.

    When the incoming image already matches the target scale, the input dict is
    returned unchanged for efficiency; otherwise the call falls through to the
    parent ``Resize`` implementation.

    Args:
        img_scale (tuple): Image scales for resizing (w, h).
    """

    def __init__(self, **kwargs):
        # override=True allows this op to re-resize results that a previous
        # Resize step has already annotated.
        super().__init__(override=True, **kwargs)

    def __call__(self, results: Dict[str, Any]):
        """Resize ``results`` unless the image already has the target shape.

        Args:
            results: Inputs to be transformed.
        """
        current_shape = results.get("img_shape", (0, 0))
        target_scale = self.img_scale[0]
        # NOTE(review): img_shape is typically (h, w, c) while img_scale is
        # documented as (w, h); this early-exit comparison only short-circuits
        # when both orderings agree — confirm intent for non-square targets.
        already_at_target = (
            current_shape[0] == target_scale[0]
            and current_shape[1] == target_scale[1]
        )
        if already_at_target:
            return results
        return super().__call__(results)


@PIPELINES.register_module()
class LoadAnnotationFromOTXDataset:
"""Pipeline element that loads an annotation from a OTX Dataset on the fly.
Expand Down Expand Up @@ -84,7 +119,7 @@ def _load_masks(results, ann_info):

def __call__(self, results: Dict[str, Any]):
"""Callback function of LoadAnnotationFromOTXDataset."""
dataset_item = results.pop("dataset_item")
dataset_item = results.pop("dataset_item") # Prevent unnecessary deepcopy
label_list = results.pop("ann_info")["label_list"]
ann_info = get_annotation_mmdet_format(dataset_item, label_list, self.domain, self.min_size)
if self.with_bbox:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,24 @@
__img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)

train_pipeline = [
dict(type="LoadImageFromOTXDataset", enable_memcache=True),
dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
dict(
type="LoadResizeDataFromOTXDataset",
load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
resize_cfg=dict(
type="Resize",
img_scale=(1088, 800), # max sizes in random image scales
keep_ratio=True,
downscale_only=True,
), # Resize to intermediate size if org image is bigger
enable_memcache=True, # Cache after resizing image & annotations
),
dict(type="MinIoURandomCrop", min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3),
dict(
type="Resize",
img_scale=[(992, 736), (896, 736), (1088, 736), (992, 672), (992, 800)],
multiscale_mode="value",
keep_ratio=False,
override=True, # Allow multiple resize
),
dict(type="RandomFlip", flip_ratio=0.5),
dict(type="Normalize", **__img_norm_cfg),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
__img_norm_cfg = dict(mean=(103.53, 116.28, 123.675), std=(1.0, 1.0, 1.0), to_rgb=True)

train_pipeline = [
dict(type="LoadImageFromOTXDataset", enable_memcache=True),
dict(
type="LoadAnnotationFromOTXDataset",
domain="instance_segmentation",
with_bbox=True,
with_mask=True,
poly2mask=False,
type="LoadResizeDataFromOTXDataset",
load_ann_cfg=dict(
type="LoadAnnotationFromOTXDataset",
domain="instance_segmentation",
with_bbox=True,
with_mask=True,
poly2mask=False,
),
resize_cfg=dict(
type="Resize",
img_scale=__img_size,
keep_ratio=False,
),
enable_memcache=True, # Cache after resizing image & annotations
),
dict(type="Resize", img_scale=__img_size, keep_ratio=False),
dict(type="RandomFlip", flip_ratio=0.5),
dict(type="Normalize", **__img_norm_cfg),
dict(type="Pad", size_divisor=32),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
__img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
dict(type="LoadImageFromOTXDataset", enable_memcache=True),
dict(
type="LoadAnnotationFromOTXDataset",
domain="instance_segmentation",
with_bbox=True,
with_mask=True,
poly2mask=False,
type="LoadResizeDataFromOTXDataset",
load_ann_cfg=dict(
type="LoadAnnotationFromOTXDataset",
domain="instance_segmentation",
with_bbox=True,
with_mask=True,
poly2mask=False,
),
resize_cfg=dict(
type="Resize",
img_scale=__img_size,
keep_ratio=False,
),
enable_memcache=True, # Cache after resizing image & annotations
),
dict(type="Resize", img_scale=__img_size, keep_ratio=False),
dict(type="RandomFlip", flip_ratio=0.5),
dict(type="Normalize", **__img_norm_cfg),
dict(type="DefaultFormatBundle"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
hue_delta=18,
),
dict(type="RandomFlip", flip_ratio=0.5),
dict(type="Resize", img_scale=__img_size, keep_ratio=True),
dict(type="Resize", img_scale=__img_size, keep_ratio=True, override=True), # Allow multiple resize
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
dict(type="Pad", pad_to_square=True, pad_val=114.0),
dict(type="Normalize", **__img_norm_cfg),
dict(type="DefaultFormatBundle"),
Expand Down Expand Up @@ -82,8 +82,18 @@
dataset=dict(
type=__dataset_type,
pipeline=[
dict(type="LoadImageFromOTXDataset", to_float32=False, enable_memcache=True),
dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
dict(
type="LoadResizeDataFromOTXDataset",
load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
resize_cfg=dict(
type="Resize",
img_scale=__img_size,
keep_ratio=True,
downscale_only=True,
), # Resize to intermediate size if org image is bigger
to_float32=False,
enable_memcache=True, # Cache after resizing image & annotations
),
],
),
pipeline=train_pipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,18 @@
__img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)

train_pipeline = [
dict(type="LoadImageFromOTXDataset", to_float32=True, enable_memcache=True),
dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
dict(
type="LoadResizeDataFromOTXDataset",
load_ann_cfg=dict(type="LoadAnnotationFromOTXDataset", with_bbox=True),
resize_cfg=dict(
type="Resize",
img_scale=__img_size,
keep_ratio=True,
downscale_only=True,
), # Resize to intermediate size if org image is bigger
to_float32=True,
enable_memcache=True, # Cache after resizing image & annotations
),
dict(
type="PhotoMetricDistortion",
brightness_delta=32,
Expand All @@ -31,7 +41,7 @@
hue_delta=18,
),
dict(type="MinIoURandomCrop", min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.1),
dict(type="Resize", img_scale=__img_size, keep_ratio=False),
dict(type="Resize", img_scale=__img_size, keep_ratio=False, override=True), # Allow multiple resize
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
dict(type="Normalize", **__img_norm_cfg),
dict(type="RandomFlip", flip_ratio=0.5),
dict(type="DefaultFormatBundle"),
Expand Down
11 changes: 6 additions & 5 deletions src/otx/core/data/caching/mem_cache_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def get(self, key: Any) -> Tuple[Optional[np.ndarray], Optional[Dict]]:

addr = self._cache_addr[key]

offset, count, shape, strides, meta = addr
offset, count, dtype, shape, strides, meta = addr

data = np.frombuffer(self._arr, dtype=np.uint8, count=count, offset=offset)
data = np.frombuffer(self._arr, dtype=dtype, count=count, offset=offset)
return np.lib.stride_tricks.as_strided(data, shape, strides), meta

def put(self, key: Any, data: np.ndarray, meta: Optional[Dict] = None) -> Optional[int]:
Expand All @@ -82,20 +82,21 @@ def put(self, key: Any, data: np.ndarray, meta: Optional[Dict] = None) -> Option
if self._freeze.value:
return None

assert data.dtype == np.uint8
data_bytes = data.size * data.itemsize

with self._lock:
new_page = self._cur_page.value + data.size
new_page = self._cur_page.value + data_bytes

if key in self._cache_addr or new_page > self.mem_size:
return None

offset = ct.byref(self._arr, self._cur_page.value)
ct.memmove(offset, data.ctypes.data, data.size)
ct.memmove(offset, data.ctypes.data, data_bytes)

self._cache_addr[key] = (
self._cur_page.value,
data.size,
data.dtype,
data.shape,
data.strides,
meta,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,19 @@ def test_enable_memcache(self, fxt_caching_dataset_cls, fxt_data_list):

# The second round requires no read.
assert mock.call_count == 0


@pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"])
def test_memcache_image_itemtype(mode):
    """Cached images must round-trip losslessly for both uint8 and float32.

    Verifies the dtype is stored in and restored from the cache address entry,
    i.e. the handler is no longer hard-coded to uint8 payloads.
    """
    img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
    MemCacheHandlerSingleton.create(mode, img.size * img.itemsize)
    cache = MemCacheHandlerSingleton.get()
    cache.put("img_u8", img)
    img_cached, _ = cache.get("img_u8")
    assert np.array_equal(img, img_cached)
    # `np.float` was removed in NumPy 1.24 — use an explicit dtype. float32
    # matches the "img_f32" key and exercises a non-1-byte itemsize (4 bytes).
    img = np.random.rand(10, 10, 3).astype(np.float32)
    MemCacheHandlerSingleton.create(mode, img.size * img.itemsize)
    cache = MemCacheHandlerSingleton.get()
    cache.put("img_f32", img)
    img_cached, _ = cache.get("img_f32")
    assert np.array_equal(img, img_cached)
Loading