Skip to content

Commit

Permalink
refactor backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J committed Mar 21, 2023
1 parent 65c4998 commit 3f614c0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 22 deletions.
23 changes: 4 additions & 19 deletions mmpose/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import mmcv
import mmengine
from mmengine.fileio import get_file_backend
import mmengine.fileio as fileio
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
Expand Down Expand Up @@ -72,14 +72,7 @@ def __init__(
self.enable = enable
self.out_dir = out_dir
self._test_index = 0

if backend_args is None:
# lazy init at loading
self.backend_args = None
self.file_backend = None
else:
self.backend_args = backend_args.copy()
self.file_backend = get_file_backend(backend_args=backend_args)
self.backend_args = backend_args

def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[PoseDataSample]) -> None:
Expand All @@ -94,10 +87,6 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
if self.enable is False:
return

if self.file_backend is None:
self.file_backend = get_file_backend(
backend_args=self.backend_args)

self._visualizer.set_dataset_meta(runner.val_evaluator.dataset_meta)

# There is no guarantee that the same batch of images
Expand All @@ -106,7 +95,7 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,

# Visualize only the first data
img_path = data_batch['data_samples'][0].get('img_path')
img_bytes = self.file_client.get(img_path)
img_bytes = fileio.get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_sample = outputs[0]

Expand Down Expand Up @@ -144,17 +133,13 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
self.out_dir)
mmengine.mkdir_or_exist(self.out_dir)

if self.file_backend is None:
self.file_backend = get_file_backend(
backend_args=self.backend_args)

self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta)

for data_sample in outputs:
self._test_index += 1

img_path = data_sample.get('img_path')
img_bytes = self.file_client.get(img_path)
img_bytes = fileio.get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_sample = merge_data_samples([data_sample])

Expand Down
5 changes: 2 additions & 3 deletions tools/misc/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import mmcv
import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine import Config, DictAction
from mmengine.fileio import get_file_backend
from mmengine.registry import build_from_cfg, init_default_scope
from mmengine.structures import InstanceData

Expand Down Expand Up @@ -81,7 +81,6 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
backend_args = cfg.get('backend_args', dict(backend='local'))
file_backend = get_file_backend(backend_args=backend_args)

# register all modules in mmpose into the registries
init_default_scope(cfg.get('default_scope', 'mmpose'))
Expand Down Expand Up @@ -122,7 +121,7 @@ def main():
continue
else:
img_path = item['img_path']
img_bytes = file_backend.get(img_path)
img_bytes = fileio.get(img_path, backend_args=backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='bgr')

# forge pseudo data_sample
Expand Down

0 comments on commit 3f614c0

Please sign in to comment.