Skip to content

Commit

Permalink
[Datumaro] Refactor explain and models CLI (#1714)
Browse files Browse the repository at this point in the history
* Update explain cli

* Update model cli

* Update config for models

* Remove input size hint for models
  • Loading branch information
zhiltsov-max authored Jun 17, 2020
1 parent 0e00315 commit be30aa6
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 40 deletions.
27 changes: 10 additions & 17 deletions datumaro/datumaro/cli/commands/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
help="Confidence threshold for detections (default: include all)")
rise_parser.add_argument('-b', '--batch-size', default=1, type=int,
help="Inference batch size (default: %(default)s)")
rise_parser.add_argument('--progressive', action='store_true',
rise_parser.add_argument('--display', action='store_true',
help="Visualize results during computations")

parser.add_argument('-p', '--project', dest='project_dir', default='.',
Expand Down Expand Up @@ -108,16 +108,13 @@ def explain_command(args):
if args.target[0] == TargetKinds.image:
image_path = args.target[1]
image = load_image(image_path)
if model.preferred_input_size() is not None:
h, w = model.preferred_input_size()
image = cv2.resize(image, (w, h))

log.info("Running inference explanation for '%s'" % image_path)
heatmap_iter = rise.apply(image, progressive=args.progressive)
heatmap_iter = rise.apply(image, progressive=args.display)

image = image / 255.0
file_name = osp.splitext(osp.basename(image_path))[0]
if args.progressive:
if args.display:
for i, heatmaps in enumerate(heatmap_iter):
for j, heatmap in enumerate(heatmaps):
hm_painted = cm.jet(heatmap)[:, :, 2::-1]
Expand Down Expand Up @@ -154,35 +151,31 @@ def explain_command(args):
log.info("Running inference explanation for '%s'" % project_name)

for item in dataset:
image = item.image
image = item.image.data
if image is None:
log.warn(
"Dataset item %s does not have image data. Skipping." % \
(item.id))
continue

if model.preferred_input_size() is not None:
h, w = model.preferred_input_size()
image = cv2.resize(image, (w, h))
heatmap_iter = rise.apply(image)

image = image / 255.0
file_name = osp.splitext(osp.basename(image_path))[0]
heatmaps = next(heatmap_iter)

if args.save_dir is not None:
log.info("Saving inference heatmaps at '%s'" % args.save_dir)
log.info("Saving inference heatmaps to '%s'" % args.save_dir)
os.makedirs(args.save_dir, exist_ok=True)

for j, heatmap in enumerate(heatmaps):
save_path = osp.join(args.save_dir,
file_name + '-heatmap-%s.png' % j)
save_image(save_path, heatmap * 255.0)
save_image(osp.join(args.save_dir,
item.id + '-heatmap-%s.png' % j),
heatmap * 255.0, create_dir=True)

if args.progressive:
if not args.save_dir or args.display:
for j, heatmap in enumerate(heatmaps):
disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
cv2.imshow(file_name + '-heatmap-%s' % j, disp)
cv2.imshow(item.id + '-heatmap-%s' % j, disp)
cv2.waitKey(0)
else:
raise NotImplementedError()
Expand Down
37 changes: 24 additions & 13 deletions datumaro/datumaro/cli/contexts/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

from datumaro.components.config import DEFAULT_FORMAT
from datumaro.components.project import Environment
from ...util import add_subparser, MultilineFormatter
from ...util.project import load_project

from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import load_project, generate_next_dir_name


def build_add_parser(parser_ctor=argparse.ArgumentParser):
Expand Down Expand Up @@ -63,19 +64,20 @@ def add_command(args):
except KeyError:
raise CliException("Launcher '%s' is not found" % args.launcher)

cli_plugin = launcher.cli_plugin
cli_plugin = getattr(launcher, 'cli_plugin', launcher)
model_args = cli_plugin.from_cmdline(args.extra_args)

if args.copy:
try:
log.info("Copying model data")
log.info("Copying model data")

model_dir = project.local_model_dir(args.name)
os.makedirs(model_dir, exist_ok=False)
model_dir = project.local_model_dir(args.name)
os.makedirs(model_dir, exist_ok=False)

try:
cli_plugin.copy_model(model_dir, model_args)
except NotImplementedError:
except (AttributeError, NotImplementedError):
log.error("Can't copy: copying is not available for '%s' models" % \
(args.launcher))
args.launcher)

log.info("Adding the model")
project.add_model(args.name, {
Expand Down Expand Up @@ -115,23 +117,32 @@ def remove_command(args):
def build_run_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor()

parser.add_argument('-o', '--output-dir', dest='dst_dir', required=True,
parser.add_argument('-o', '--output-dir', dest='dst_dir',
help="Directory to save output")
parser.add_argument('-m', '--model', dest='model_name', required=True,
help="Model to apply to the project")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
help="Directory of the project to operate on (default: current dir)")
parser.add_argument('--overwrite', action='store_true',
help="Overwrite if exists")
parser.set_defaults(command=run_command)

return parser

def run_command(args):
project = load_project(args.project_dir)

dst_dir = osp.abspath(args.dst_dir)
os.makedirs(dst_dir, exist_ok=False)
dst_dir = args.dst_dir
if dst_dir:
if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir):
raise CliException("Directory '%s' already exists "
"(pass --overwrite overwrite)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-inference' % \
(project.config.project_name))

project.make_dataset().apply_model(
save_dir=dst_dir,
save_dir=osp.abspath(dst_dir),
model=args.model_name)

log.info("Inference results have been saved to '%s'" % dst_dir)
Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/components/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __len__(self):
return len(self.items())

def __iter__(self):
return iter(zip(self.keys(), self.values()))
return iter(self.keys())

def __getitem__(self, key):
default = object()
Expand Down
1 change: 0 additions & 1 deletion datumaro/datumaro/components/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def __init__(self, config=None):

MODEL_SCHEMA = _SchemaBuilder() \
.add('launcher', str) \
.add('model_dir', str, internal=True) \
.add('options', dict) \
.build()

Expand Down
3 changes: 0 additions & 3 deletions datumaro/datumaro/components/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ def __init__(self, model_dir=None):
def launch(self, inputs):
raise NotImplementedError()

def preferred_input_size(self):
return None

def categories(self):
return None
# pylint: enable=no-self-use
Expand Down
3 changes: 1 addition & 2 deletions datumaro/datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,9 +817,8 @@ def remove_model(self, name):

def make_executable_model(self, name):
model = self.get_model(name)
model.model_dir = self.local_model_dir(name)
return self.env.make_launcher(model.launcher,
**model.options, model_dir=model.model_dir)
**model.options, model_dir=self.local_model_dir(name))

def make_source_project(self, name):
source = self.get_source(name)
Expand Down
3 changes: 0 additions & 3 deletions datumaro/datumaro/plugins/openvino_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,3 @@ def categories(self):
def process_outputs(self, inputs, outputs):
return self._interpreter.process_outputs(inputs, outputs)

def preferred_input_size(self):
_, _, h, w = self._input_layout
return (h, w)

0 comments on commit be30aa6

Please sign in to comment.