Skip to content

Commit

Permalink
reply to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
negvet committed Mar 27, 2023
1 parent 76a93f5 commit 32d851b
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
# pylint: disable=too-many-locals



class DetClassProbabilityMapHook(BaseRecordingForwardHook):
"""Saliency map hook for object detection models."""

Expand Down
4 changes: 2 additions & 2 deletions otx/cli/tools/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def get_args():
"--explain-all-classes",
action="store_true",
help="Provides explanations for all classes. Otherwise, explains only predicted classes."
"This feature supported by algorithms that can generate explanations per class.",
"This feature is supported by algorithms that can generate explanations per each class.",
)
parser.add_argument(
"--overlay-weight",
type=float,
default=0.5,
help="Weight of the saliency map when overlaying the input image with saliency map",
help="Weight of the saliency map when overlaying the input image with saliency map.",
)
add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",))
override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")]
Expand Down
38 changes: 35 additions & 3 deletions tests/e2e/cli/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
otx_eval_deployment_testing,
otx_eval_openvino_testing,
otx_eval_testing,
otx_explain_all_classes_openvino_testing,
otx_explain_openvino_testing,
otx_explain_process_saliency_maps_openvino_testing,
otx_explain_testing,
otx_explain_testing_all_classes,
otx_explain_testing_process_saliency_maps,
otx_export_testing,
otx_export_testing_w_features,
otx_hpo_testing,
Expand All @@ -43,7 +47,7 @@
"--val-data-roots": "tests/assets/car_tree_bug",
"--test-data-roots": "tests/assets/car_tree_bug",
"--input": "tests/assets/car_tree_bug/images/train",
"train_params": ["params", "--learning_parameters.num_iters", "5", "--learning_parameters.batch_size", "4"],
"train_params": ["params", "--learning_parameters.num_iters", "10", "--learning_parameters.batch_size", "4"],
}

# Class-Incremental learning w/ 'vehicle', 'person', 'non-vehicle' classes
Expand Down Expand Up @@ -152,14 +156,42 @@ def test_otx_eval_openvino(self, template, tmp_dir_path):
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_explain(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
otx_explain_testing(template, tmp_dir_path, otx_dir, args)
otx_explain_testing(template, tmp_dir_path, otx_dir, args, trained=True)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_explain_all_classes(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args, trained=True)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_explain_openvino(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args)
otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args, trained=True)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args, trained=True)

@e2e_pytest_component
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
Expand Down
10 changes: 1 addition & 9 deletions tests/integration/cli/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@
"train_params": ["params", "--learning_parameters.num_iters", "1", "--learning_parameters.batch_size", "4"],
}

num_iters_per_model = {
"Custom_Object_Detection_YOLOX": "10",
"Custom_Object_Detection_Gen3_ATSS": "30",
"Custom_Object_Detection_Gen3_SSD": "10",
}

args_semisl = {
"--train-data-roots": "tests/assets/car_tree_bug",
"--val-data-roots": "tests/assets/car_tree_bug",
Expand Down Expand Up @@ -89,9 +83,7 @@ class TestDetectionCLI:
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_train(self, template, tmp_dir_path):
tmp_dir_path = tmp_dir_path / "detection"
args1 = copy.deepcopy(args)
args1["train_params"][2] = num_iters_per_model[template.model_template_id]
otx_train_testing(template, tmp_dir_path, otx_dir, args1)
otx_train_testing(template, tmp_dir_path, otx_dir, args)

@e2e_pytest_component
@pytest.mark.parametrize("template", default_templates, ids=default_templates_ids)
Expand Down
13 changes: 6 additions & 7 deletions tests/test_suite/run_test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def otx_export_testing_w_features(template, root):
with open(path_to_xml, encoding="utf-8") as stream:
xml_model = stream.read()
assert "feature_vector" in xml_model
assert "saliency_map" in xml_model


def otx_eval_testing(template, root, otx_dir, args):
Expand Down Expand Up @@ -685,7 +684,7 @@ def otx_explain_testing(template, root, otx_dir, args, trained=False):
assert os.path.exists(output_dir)
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])
assert all([os.path.splitext(fname)[1] == ".tiff" for fname in os.listdir(output_dir)])


def otx_explain_testing_all_classes(template, root, otx_dir, args):
Expand Down Expand Up @@ -729,7 +728,7 @@ def otx_explain_testing_all_classes(template, root, otx_dir, args):
assert len(os.listdir(output_dir)) == len(os.listdir(output_dir_explain_only_predicted_classes))
else:
assert len(os.listdir(output_dir)) >= len(os.listdir(output_dir_explain_only_predicted_classes))
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])
assert all([os.path.splitext(fname)[1] == ".tiff" for fname in os.listdir(output_dir)])


def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args, trained=False):
Expand Down Expand Up @@ -768,7 +767,7 @@ def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args, tra
assert os.path.exists(output_dir)
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "png" for fname in os.listdir(output_dir)])
assert all([os.path.splitext(fname)[1] == ".png" for fname in os.listdir(output_dir)])


def otx_explain_openvino_testing(template, root, otx_dir, args, trained=False):
Expand Down Expand Up @@ -807,7 +806,7 @@ def otx_explain_openvino_testing(template, root, otx_dir, args, trained=False):
assert os.path.exists(output_dir)
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])
assert all([os.path.splitext(fname)[1] == ".tiff" for fname in os.listdir(output_dir)])


def otx_explain_all_classes_openvino_testing(template, root, otx_dir, args):
Expand Down Expand Up @@ -852,7 +851,7 @@ def otx_explain_all_classes_openvino_testing(template, root, otx_dir, args):
assert len(os.listdir(output_dir)) == len(os.listdir(output_dir_explain_only_predicted_classes))
else:
assert len(os.listdir(output_dir)) >= len(os.listdir(output_dir_explain_only_predicted_classes))
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])
assert all([os.path.splitext(fname)[1] == ".tiff" for fname in os.listdir(output_dir)])


def otx_explain_process_saliency_maps_openvino_testing(template, root, otx_dir, args, trained=False):
Expand Down Expand Up @@ -892,7 +891,7 @@ def otx_explain_process_saliency_maps_openvino_testing(template, root, otx_dir,
assert os.path.exists(output_dir)
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "png" for fname in os.listdir(output_dir)])
assert all([os.path.splitext(fname)[1] == ".png" for fname in os.listdir(output_dir)])


def otx_find_testing():
Expand Down

0 comments on commit 32d851b

Please sign in to comment.