Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support random video compression during training #646

Merged
merged 4 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ jobs:
- name: Install FaceXLib
run: pip install facexlib
if: ${{matrix.torch > '1.7'}}
- name: Install av
run: python -m pip install av
- name: Install unittest dependencies
run: |
pip install -r requirements.txt
Expand Down Expand Up @@ -141,6 +143,8 @@ jobs:
- name: Install FaceXLib
run: python -m pip install facexlib
if: ${{matrix.torch > '1.7'}}
- name: Install av
run: python -m pip install av
- name: Install unittest dependencies
run: |
python -m pip install -r requirements.txt
Expand Down Expand Up @@ -207,6 +211,8 @@ jobs:
- name: Install FaceXLib
run: python -m pip install facexlib
if: ${{matrix.torch > '1.7'}}
- name: Install av
run: python -m pip install av
- name: Install unittest dependencies
run: |
python -m pip install -r requirements.txt
Expand Down
5 changes: 3 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .normalization import Normalize, RescaleToZeroOne
from .random_degradations import (DegradationsWithShuffle, RandomBlur,
RandomJPEGCompression, RandomNoise,
RandomResize)
RandomResize, RandomVideoCompression)
from .random_down_sampling import RandomDownSampling

__all__ = [
Expand All @@ -40,5 +40,6 @@
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike', 'GenerateHeatmap', 'MATLABLikeResize', 'CopyValues',
'Quantize', 'RandomBlur', 'RandomJPEGCompression', 'RandomNoise',
'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking'
'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking',
'RandomVideoCompression'
]
77 changes: 77 additions & 0 deletions mmedit/datasets/pipelines/random_degradations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import io
import logging
import random

import cv2
Expand All @@ -6,6 +8,12 @@
from mmedit.datasets.pipelines import blur_kernels as blur_kernels
from ..registry import PIPELINES

try:
import av
has_av = True
except ImportError:
has_av = False


@PIPELINES.register_module()
class RandomBlur:
Expand Down Expand Up @@ -322,11 +330,80 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class RandomVideoCompression:
    """Apply random video compression to the input.

    Modified keys are the attributes specified in "keys".

    Args:
        params (dict): A dictionary specifying the degradation settings.
            The keys read here are ``codec`` (candidate codec names),
            ``codec_prob`` (sampling weights for ``codec``), ``bitrate``
            (inclusive ``[low, high]`` range the bitrate is drawn from) and
            the optional ``prob`` (probability of applying the compression,
            1 when absent).
        keys (list[str]): A list specifying the keys whose values are
            modified.
    """

    def __init__(self, params, keys):
        assert has_av, 'Please install av to use video compression.'

        self.keys = keys
        self.params = params
        # Only let CRITICAL records from the 'libav' logger through.
        logging.getLogger('libav').setLevel(logging.CRITICAL)

    def _apply_random_compression(self, imgs):
        """Encode the frames with a random codec/bitrate and decode them.

        Args:
            imgs (list[np.ndarray]): Frames to compress. They are scaled by
                255 and encoded as 'rgb24', so they are presumably float
                (H, W, 3) images in [0, 1] -- TODO confirm against callers.

        Returns:
            list[np.ndarray]: The decoded frames as float32 arrays divided
                by 255. An empty input yields an empty list.
        """
        # Nothing to encode; also avoids indexing ``imgs[0]`` below.
        if len(imgs) == 0:
            return []

        codec = random.choices(self.params['codec'],
                               self.params['codec_prob'])[0]
        bitrate = self.params['bitrate']
        # ``randint`` excludes the upper bound, hence the ``+ 1``.
        bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)

        buf = io.BytesIO()
        with av.open(buf, 'w', 'mp4') as container:
            stream = container.add_stream(codec, rate=1)
            stream.height = imgs[0].shape[0]
            stream.width = imgs[0].shape[1]
            stream.pix_fmt = 'yuv420p'
            stream.bit_rate = bitrate

            for img in imgs:
                # Clip before the uint8 cast so that values outside [0, 1]
                # saturate instead of wrapping around.
                img = np.clip(255 * img, 0, 255).astype(np.uint8)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame.pict_type = 'NONE'
                for packet in stream.encode(frame):
                    container.mux(packet)

            # Flush stream
            for packet in stream.encode():
                container.mux(packet)

        outputs = []
        with av.open(buf, 'r', 'mp4') as container:
            if container.streams.video:
                for frame in container.decode(video=0):
                    outputs.append(
                        frame.to_rgb().to_ndarray().astype(np.float32) / 255.)

        return outputs

    def __call__(self, results):
        """Compress ``results[key]`` for every configured key.

        Args:
            results (dict): The data dictionary flowing through the pipeline.

        Returns:
            dict: The same dictionary; untouched when the random draw
                exceeds ``params['prob']``.
        """
        if np.random.uniform() > self.params.get('prob', 1):
            return results

        for key in self.keys:
            results[key] = self._apply_random_compression(results[key])

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params}, keys={self.keys})')
        return repr_str


# Maps each degradation type name to the pipeline class of the same name.
allowed_degradations = dict(
    RandomBlur=RandomBlur,
    RandomResize=RandomResize,
    RandomNoise=RandomNoise,
    RandomJPEGCompression=RandomJPEGCompression,
    RandomVideoCompression=RandomVideoCompression,
)


Expand Down
29 changes: 28 additions & 1 deletion tests/test_data/test_pipelines/test_random_degradations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from mmedit.datasets.pipelines import (DegradationsWithShuffle, RandomBlur,
RandomJPEGCompression, RandomNoise,
RandomResize)
RandomResize, RandomVideoCompression)


def test_random_noise():
Expand Down Expand Up @@ -64,6 +64,33 @@ def test_random_jpeg_compression():
+ "keys=['lq'])"


def test_random_video_compression():
    """Check the compressed output, the ``prob`` skip path and the repr."""
    # Five references to the same all-ones 8x8 RGB frame.
    frame = np.ones((8, 8, 3)).astype(np.float32)
    results = {'lq': [frame] * 5}

    always_on = RandomVideoCompression(
        params={
            'codec': ['libx264', 'h264', 'mpeg4'],
            'codec_prob': [1 / 3., 1 / 3., 1 / 3.],
            'bitrate': [1e4, 1e5],
        },
        keys=['lq'])
    results = always_on(results)
    # Frame count and frame shape must survive the encode/decode round trip.
    assert results['lq'][0].shape == (8, 8, 3)
    assert len(results['lq']) == 5

    # With prob=0 the degradation is always skipped.
    params = {
        'codec': ['libx264', 'h264', 'mpeg4'],
        'codec_prob': [1 / 3., 1 / 3., 1 / 3.],
        'bitrate': [1e4, 1e5],
        'prob': 0,
    }
    model = RandomVideoCompression(params=params, keys=['lq'])
    assert model(results) == results

    expected_repr = model.__class__.__name__ + f'(params={params}, ' \
        + "keys=['lq'])"
    assert repr(model) == expected_repr


def test_random_resize():
results = {}
results['lq'] = np.ones((8, 8, 3)).astype(np.float32)
Expand Down