Skip to content

Commit

Permalink
Merge 3db63ef into 4dc809a
Browse files Browse the repository at this point in the history
  • Loading branch information
yamengxi authored Dec 2, 2020
2 parents 4dc809a + 3db63ef commit 8aa049f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 6 deletions.
8 changes: 4 additions & 4 deletions mmseg/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug
from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
SegRescale)
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip,
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)

__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'Rerange', 'RGB2Gray'
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
]
48 changes: 46 additions & 2 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random

from ..builder import PIPELINES
Expand Down Expand Up @@ -415,7 +415,6 @@ def __call__(self, results):
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reranged results.
"""
Expand All @@ -439,6 +438,51 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class CLAHE(object):
"""Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.
Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
Input image will be divided into equally sized rectangular tiles.
It defines the number of tiles in row and column. Default: (8, 8).
"""

def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
assert isinstance(clip_limit, (float, int))
self.clip_limit = clip_limit
assert is_tuple_of(tile_grid_size, int)
assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size

def __call__(self, results):
"""Call function to Use CLAHE method process images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Processed results.
"""

for i in range(results['img'].shape[2]):
results['img'][:, :, i] = mmcv.clahe(
np.array(results['img'][:, :, i], dtype=np.uint8),
self.clip_limit, self.tile_grid_size)

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(clip_limit={self.clip_limit}, '\
f'tile_grid_size={self.tile_grid_size})'
return repr_str


@PIPELINES.register_module()
class RandomCrop(object):
"""Random crop the image & seg.
Expand Down
40 changes: 40 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,46 @@ def test_rerange():
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'


def test_CLAHE():
# test assertion if clip_limit is None
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', clip_limit=None)
build_from_cfg(transform, PIPELINES)

# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
build_from_cfg(transform, PIPELINES)

# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
build_from_cfg(transform, PIPELINES)

transform = dict(type='CLAHE', clip_limit=2)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)

converted_img = np.empty(original_img.shape)
for i in range(original_img.shape[2]):
converted_img[:, :, i] = mmcv.clahe(
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))

assert np.allclose(results['img'], converted_img)
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'


def test_seg_rescale():
results = dict()
seg = np.array(
Expand Down

0 comments on commit 8aa049f

Please sign in to comment.