From be30aa6354a09d63dd46c83c7a73f27a6d1d6ac7 Mon Sep 17 00:00:00 2001 From: zhiltsov-max Date: Wed, 17 Jun 2020 17:21:10 +0300 Subject: [PATCH] [Datumaro] Refactor explain and models CLI (#1714) * Update explain cli * Update model cli * Update config for models * Remove input size hint for models --- datumaro/datumaro/cli/commands/explain.py | 27 +++++--------- .../datumaro/cli/contexts/model/__init__.py | 37 ++++++++++++------- datumaro/datumaro/components/config.py | 2 +- datumaro/datumaro/components/config_model.py | 1 - datumaro/datumaro/components/launcher.py | 3 -- datumaro/datumaro/components/project.py | 3 +- .../datumaro/plugins/openvino_launcher.py | 3 -- 7 files changed, 36 insertions(+), 40 deletions(-) diff --git a/datumaro/datumaro/cli/commands/explain.py b/datumaro/datumaro/cli/commands/explain.py index 9b5a6432d65..a0a5f1cc64a 100644 --- a/datumaro/datumaro/cli/commands/explain.py +++ b/datumaro/datumaro/cli/commands/explain.py @@ -59,7 +59,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): help="Confidence threshold for detections (default: include all)") rise_parser.add_argument('-b', '--batch-size', default=1, type=int, help="Inference batch size (default: %(default)s)") - rise_parser.add_argument('--progressive', action='store_true', + rise_parser.add_argument('--display', action='store_true', help="Visualize results during computations") parser.add_argument('-p', '--project', dest='project_dir', default='.', @@ -108,16 +108,13 @@ def explain_command(args): if args.target[0] == TargetKinds.image: image_path = args.target[1] image = load_image(image_path) - if model.preferred_input_size() is not None: - h, w = model.preferred_input_size() - image = cv2.resize(image, (w, h)) log.info("Running inference explanation for '%s'" % image_path) - heatmap_iter = rise.apply(image, progressive=args.progressive) + heatmap_iter = rise.apply(image, progressive=args.display) image = image / 255.0 file_name = osp.splitext(osp.basename(image_path))[0] - if args.progressive: + if args.display: for i, heatmaps in enumerate(heatmap_iter): for j, heatmap in enumerate(heatmaps): hm_painted = cm.jet(heatmap)[:, :, 2::-1] @@ -154,35 +151,31 @@ def explain_command(args): log.info("Running inference explanation for '%s'" % project_name) for item in dataset: - image = item.image + image = item.image.data if image is None: log.warn( "Dataset item %s does not have image data. Skipping." % \ (item.id)) continue - if model.preferred_input_size() is not None: - h, w = model.preferred_input_size() - image = cv2.resize(image, (w, h)) heatmap_iter = rise.apply(image) image = image / 255.0 - file_name = osp.splitext(osp.basename(image_path))[0] heatmaps = next(heatmap_iter) if args.save_dir is not None: - log.info("Saving inference heatmaps at '%s'" % args.save_dir) + log.info("Saving inference heatmaps to '%s'" % args.save_dir) os.makedirs(args.save_dir, exist_ok=True) for j, heatmap in enumerate(heatmaps): - save_path = osp.join(args.save_dir, - file_name + '-heatmap-%s.png' % j) - save_image(save_path, heatmap * 255.0) + save_image(osp.join(args.save_dir, + item.id + '-heatmap-%s.png' % j), + heatmap * 255.0, create_dir=True) - if args.progressive: + if not args.save_dir or args.display: for j, heatmap in enumerate(heatmaps): disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2 - cv2.imshow(file_name + '-heatmap-%s' % j, disp) + cv2.imshow(item.id + '-heatmap-%s' % j, disp) cv2.waitKey(0) else: raise NotImplementedError() diff --git a/datumaro/datumaro/cli/contexts/model/__init__.py b/datumaro/datumaro/cli/contexts/model/__init__.py index a2af90f74fd..30cc4da83d3 100644 --- a/datumaro/datumaro/cli/contexts/model/__init__.py +++ b/datumaro/datumaro/cli/contexts/model/__init__.py @@ -11,8 +11,9 @@ from datumaro.components.config import DEFAULT_FORMAT from datumaro.components.project import Environment -from ...util import add_subparser, MultilineFormatter -from ...util.project import load_project + +from ...util import CliException, MultilineFormatter, add_subparser +from ...util.project import load_project, generate_next_dir_name def build_add_parser(parser_ctor=argparse.ArgumentParser): @@ -63,19 +64,20 @@ def add_command(args): except KeyError: raise CliException("Launcher '%s' is not found" % args.launcher) - cli_plugin = launcher.cli_plugin + cli_plugin = getattr(launcher, 'cli_plugin', launcher) model_args = cli_plugin.from_cmdline(args.extra_args) if args.copy: - try: - log.info("Copying model data") + log.info("Copying model data") - model_dir = project.local_model_dir(args.name) - os.makedirs(model_dir, exist_ok=False) + model_dir = project.local_model_dir(args.name) + os.makedirs(model_dir, exist_ok=False) + + try: cli_plugin.copy_model(model_dir, model_args) - except NotImplementedError: + except (AttributeError, NotImplementedError): log.error("Can't copy: copying is not available for '%s' models" % \ - (args.launcher)) + args.launcher) log.info("Adding the model") project.add_model(args.name, { @@ -115,12 +117,14 @@ def remove_command(args): def build_run_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor() - parser.add_argument('-o', '--output-dir', dest='dst_dir', required=True, + parser.add_argument('-o', '--output-dir', dest='dst_dir', help="Directory to save output") parser.add_argument('-m', '--model', dest='model_name', required=True, help="Model to apply to the project") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") + parser.add_argument('--overwrite', action='store_true', + help="Overwrite if exists") parser.set_defaults(command=run_command) return parser @@ -128,10 +132,17 @@ def build_run_parser(parser_ctor=argparse.ArgumentParser): def run_command(args): project = load_project(args.project_dir) - dst_dir = osp.abspath(args.dst_dir) - os.makedirs(dst_dir, exist_ok=False) + dst_dir = args.dst_dir + if dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite overwrite)" % dst_dir) + else: + dst_dir = generate_next_dir_name('%s-inference' % \ + (project.config.project_name)) + project.make_dataset().apply_model( - save_dir=dst_dir, + save_dir=osp.abspath(dst_dir), model=args.model_name) log.info("Inference results have been saved to '%s'" % dst_dir) diff --git a/datumaro/datumaro/components/config.py b/datumaro/datumaro/components/config.py index 520c6e70bd5..ca66eff8a52 100644 --- a/datumaro/datumaro/components/config.py +++ b/datumaro/datumaro/components/config.py @@ -130,7 +130,7 @@ def __len__(self): return len(self.items()) def __iter__(self): - return iter(zip(self.keys(), self.values())) + return iter(self.keys()) def __getitem__(self, key): default = object() diff --git a/datumaro/datumaro/components/config_model.py b/datumaro/datumaro/components/config_model.py index 9bce725ebd7..f46682d2f33 100644 --- a/datumaro/datumaro/components/config_model.py +++ b/datumaro/datumaro/components/config_model.py @@ -21,7 +21,6 @@ def __init__(self, config=None): MODEL_SCHEMA = _SchemaBuilder() \ .add('launcher', str) \ - .add('model_dir', str, internal=True) \ .add('options', dict) \ .build() diff --git a/datumaro/datumaro/components/launcher.py b/datumaro/datumaro/components/launcher.py index 5bcd9ad43ff..1a60ceadb7c 100644 --- a/datumaro/datumaro/components/launcher.py +++ b/datumaro/datumaro/components/launcher.py @@ -17,9 +17,6 @@ def __init__(self, model_dir=None): def launch(self, inputs): raise NotImplementedError() - def preferred_input_size(self): - return None - def categories(self): return None # pylint: enable=no-self-use diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index 7be463455cc..4fe08b13f4f 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -817,9 +817,8 @@ def remove_model(self, name): def make_executable_model(self, name): model = self.get_model(name) - model.model_dir = self.local_model_dir(name) return self.env.make_launcher(model.launcher, - **model.options, model_dir=model.model_dir) + **model.options, model_dir=self.local_model_dir(name)) def make_source_project(self, name): source = self.get_source(name) diff --git a/datumaro/datumaro/plugins/openvino_launcher.py b/datumaro/datumaro/plugins/openvino_launcher.py index c79789de370..4e150b039d0 100644 --- a/datumaro/datumaro/plugins/openvino_launcher.py +++ b/datumaro/datumaro/plugins/openvino_launcher.py @@ -186,6 +186,3 @@ def categories(self): def process_outputs(self, inputs, outputs): return self._interpreter.process_outputs(inputs, outputs) - def preferred_input_size(self): - _, _, h, w = self._input_layout - return (h, w)