Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Oct 18, 2023
1 parent f29a7da commit 70fae68
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/otx/algorithms/segmentation/utils/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ def get_seg_model_api_configuration(label_schema: LabelSchemaEntity, hyperparams
("model_info", "blur_strength"): str(hyperparams.postprocessing.blur_strength),
("model_info", "labels"): all_labels.strip(),
("model_info", "label_ids"): all_label_ids.strip(),
("model_info", "task_type"): "segmentation",
}
4 changes: 4 additions & 0 deletions tests/unit/algorithms/classification/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,7 @@ def test_get_cls_model_api_configuration(default_hierarchical_data):
assert len(model_api_cfg) > 0
assert model_api_cfg[("model_info", "confidence_threshold")] == str(config["confidence_threshold"])
assert ("model_info", "hierarchical_config") in model_api_cfg
assert ("model_info", "labels") in model_api_cfg
assert ("model_info", "label_ids") in model_api_cfg
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split())
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())
4 changes: 4 additions & 0 deletions tests/unit/algorithms/detection/utils/test_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ def test_get_det_model_api_configuration():
tiling_parameters.tile_overlap / tiling_parameters.tile_ir_scale_factor
)
assert model_api_cfg[("model_info", "max_pred_number")] == str(tiling_parameters.tile_max_number)
assert ("model_info", "labels") in model_api_cfg
assert ("model_info", "label_ids") in model_api_cfg
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split())
assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split())

0 comments on commit 70fae68

Please sign in to comment.