Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENHANCE] Parametrize saliency maps dumping in export parametrizable #1708

Merged
merged 14 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ def postprocess(self, outputs: Dict[str, np.ndarray], metadata: Dict[str, Any]):
@check_input_parameters_type()
def postprocess_aux_outputs(self, outputs: Dict[str, np.ndarray], metadata: Dict[str, Any]):
"""Post-process for auxiliary outputs."""
saliency_map = outputs["saliency_map"][0]
repr_vector = outputs["feature_vector"].reshape(-1)
logits = outputs[self.out_layer_name].squeeze()
if self.multilabel:
probs = sigmoid_numpy(logits)
Expand All @@ -113,6 +111,13 @@ def postprocess_aux_outputs(self, outputs: Dict[str, np.ndarray], metadata: Dict
else:
probs = softmax_numpy(logits)
act_score = float(np.max(probs) - np.min(probs))

if "saliency_map" in outputs:
saliency_map = outputs["saliency_map"][0]
repr_vector = outputs["feature_vector"].reshape(-1)
else:
saliency_map, repr_vector = None, None

return probs, saliency_map, repr_vector, act_score


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_base_ = ["../base/deployments/base_classification_dynamic.py"]

ir_config = dict(
output_names=["logits", "feature_vector", "saliency_map"],
output_names=["logits"],
)

backend_config = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_base_ = ["../base/deployments/base_classification_dynamic.py"]

ir_config = dict(
output_names=["logits", "feature_vector", "saliency_map"],
output_names=["logits"],
)

backend_config = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_base_ = ["../base/deployments/base_classification_dynamic.py"]

ir_config = dict(
output_names=["logits", "feature_vector", "saliency_map"],
output_names=["logits"],
)

backend_config = dict(
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/classification/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def unload(self):
self.cleanup()

@check_input_parameters_type()
def export(self, export_type: ExportType, output_model: ModelEntity):
def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = False):
GalyaZalesskaya marked this conversation as resolved.
Show resolved Hide resolved
"""Export function of OTX Classification Task."""

logger.info("Exporting the model")
Expand All @@ -205,7 +205,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity):
output_model.optimization_type = ModelOptimizationType.MO

stage_module = "ClsExporter"
results = self._run_task(stage_module, mode="train", export=True)
results = self._run_task(stage_module, mode="train", export=True, dump_features=dump_features)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
if outputs is None:
Expand Down
31 changes: 19 additions & 12 deletions otx/algorithms/classification/tasks/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import os
import tempfile
import warnings
from typing import Any, Dict, Optional, Tuple, Union
from zipfile import ZipFile

Expand Down Expand Up @@ -78,8 +79,6 @@
from openvino.model_zoo.model_api.adapters import OpenvinoAdapter, create_core
from openvino.model_zoo.model_api.models import Model
except ImportError:
import warnings

warnings.warn("ModelAPI was not found.")

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -234,17 +233,25 @@ def infer(
probs_meta = TensorEntity(name="probabilities", numpy=probs.reshape(-1))
dataset_item.append_metadata_item(probs_meta, model=self.model)

feature_vec_media = TensorEntity(name="representation_vector", numpy=repr_vector.reshape(-1))
dataset_item.append_metadata_item(feature_vec_media, model=self.model)
if dump_features:
add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
task="cls",
predicted_scene=predicted_scene,
)
if saliency_map is not None and repr_vector is not None:
feature_vec_media = TensorEntity(name="representation_vector", numpy=repr_vector.reshape(-1))
dataset_item.append_metadata_item(feature_vec_media, model=self.model)

add_saliency_maps_to_dataset_item(
dataset_item=dataset_item,
saliency_map=saliency_map,
model=self.model,
labels=self.task_environment.get_labels(),
task="cls",
predicted_scene=predicted_scene,
)
else:
warnings.warn(
"Could not find Feature Vector and Saliency Map in OpenVINO output. "
"Please rerun OpenVINO export or retrain the model."
)

update_progress_callback(int(i / dataset_size * 100))
return dataset

Expand Down
6 changes: 6 additions & 0 deletions otx/algorithms/common/tasks/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,12 @@ def _initialize(self, options=None): # noqa: C901
assert len(self._precision) == 1
options["precision"] = str(self._precision[0])

options["deploy_cfg"]["dump_features"] = options["dump_features"]
if options["dump_features"]:
output_names = options["deploy_cfg"]["ir_config"]["output_names"]
if "feature_vector" not in output_names and "saliency_map" not in output_names:
options["deploy_cfg"]["ir_config"]["output_names"] += ["feature_vector", "saliency_map"]

self._initialize_post_hook(options)

logger.info("initialized.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_base_ = ["../../base/deployments/base_detection_dynamic.py"]

ir_config = dict(
output_names=["boxes", "labels", "feature_vector", "saliency_map"],
output_names=["boxes", "labels"],
)

backend_config = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_base_ = ["../../base/deployments/base_detection_dynamic.py"]

ir_config = dict(
output_names=["boxes", "labels", "feature_vector", "saliency_map"],
output_names=["boxes", "labels"],
)

backend_config = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
_base_ = ["../../base/deployments/base_detection_dynamic.py"]

ir_config = dict(
output_names=["boxes", "labels", "feature_vector", "saliency_map"],
output_names=["boxes", "labels"],
)

backend_config = dict(
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/detection/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def unload(self):
self.cleanup()

@check_input_parameters_type()
def export(self, export_type: ExportType, output_model: ModelEntity):
def export(self, export_type: ExportType, output_model: ModelEntity, dump_features: bool = False):
"""Export function of OTX Detection Task."""
# copied from OTX inference_task.py
logger.info("Exporting the model")
Expand All @@ -241,6 +241,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity):
stage_module,
mode="train",
export=True,
dump_features=dump_features,
)
outputs = results.get("outputs")
logger.debug(f"results of run_task = {outputs}")
Expand Down
8 changes: 7 additions & 1 deletion otx/cli/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def get_args():
"--save-model-to",
help="Location where exported model will be stored.",
)
parser.add_argument(
"--dump_features",
action="store_true",
help="Whether to return feature vector and saliency map for explanation purposes.",
)
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved

return parser.parse_args()

Expand Down Expand Up @@ -83,8 +88,9 @@ def main():
task = task_class(task_environment=environment)

exported_model = ModelEntity(None, environment.get_model_configuration())
# args.dump_features = True
GalyaZalesskaya marked this conversation as resolved.
Show resolved Hide resolved

task.export(ExportType.OPENVINO, exported_model)
task.export(ExportType.OPENVINO, exported_model, args.dump_features)

if "save_model_to" not in args or not args.save_model_to:
args.save_model_to = str(config_manager.workspace_root / "model-exported")
Expand Down
10 changes: 7 additions & 3 deletions otx/mpa/modules/models/classifiers/sam_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def sam_image_classifier__extract_feat(ctx, self, img):
def sam_image_classifier__simple_test(ctx, self, img, img_metas):
feat, backbone_feat = self.extract_feat(img)
logit = self.head.simple_test(feat)
saliency_map = ReciproCAMHook(self).func(backbone_feat)
feature_vector = FeatureVectorHook.func(backbone_feat)
return logit, feature_vector, saliency_map

if ctx.cfg["dump_features"]:
saliency_map = ReciproCAMHook(self).func(backbone_feat)
feature_vector = FeatureVectorHook.func(backbone_feat)
return logit, feature_vector, saliency_map

return logit
12 changes: 8 additions & 4 deletions otx/mpa/modules/models/detectors/custom_atss_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ def custom_atss__simple_test(ctx, self, img, img_metas, **kwargs):
feat = self.extract_feat(img)
outs = self.bbox_head(feat)
bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, cfg=self.test_cfg, **kwargs)
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results

@mark("custom_atss_forward", inputs=["input"], outputs=["dets", "labels", "feats", "saliencies"])
def __forward_impl(ctx, self, img, img_metas, **kwargs):
Expand Down
10 changes: 7 additions & 3 deletions otx/mpa/modules/models/detectors/custom_maskrcnn_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,16 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img)
feature_vector = FeatureVectorHook.func(x)
saliency_map = ActivationMapHook.func(x[-1])
if proposals is None:
proposals, _ = self.rpn_head.simple_test_rpn(x, img_metas)
out = self.roi_head.simple_test(x, proposals, img_metas, rescale=False)
return (*out, feature_vector, saliency_map)

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(x)
saliency_map = ActivationMapHook.func(x[-1])
return (*out, feature_vector, saliency_map)

return out

@mark("custom_maskrcnn_forward", inputs=["input"], outputs=["dets", "labels", "masks", "feats", "saliencies"])
def __forward_impl(ctx, self, img, img_metas, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,14 @@ def custom_single_stage_detector__simple_test(ctx, self, img, img_metas, **kwarg
feat = self.extract_feat(img)
outs = self.bbox_head(feat)
bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, cfg=self.test_cfg, **kwargs)
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results

@mark("custom_ssd_forward", inputs=["input"], outputs=["dets", "labels", "feats", "saliencies"])
def __forward_impl(ctx, self, img, img_metas, **kwargs):
Expand Down
12 changes: 8 additions & 4 deletions otx/mpa/modules/models/detectors/custom_yolox_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,14 @@ def custom_yolox__simple_test(ctx, self, img, img_metas, **kwargs):
feat = self.extract_feat(img)
outs = self.bbox_head(feat)
bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, cfg=self.test_cfg, **kwargs)
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results

@mark("custom_yolox_forward", inputs=["input"], outputs=["dets", "labels", "feats", "saliencies"])
def __forward_impl(ctx, self, img, img_metas, **kwargs):
Expand Down