Skip to content

Commit

Permalink
Add dummy XAI to RTDETR (export mode) & disable strong aug (#4106)
Browse files Browse the repository at this point in the history
* Implement warning for unsupported explain mode in DETR model and update transform probabilities to zero in RTDETR recipes

* update changelog

* Update photometric distortion probability in RTDETR recipes
  • Loading branch information
eugene123tw authored Nov 8, 2024
1 parent 0556ea6 commit 88ab4b8
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4082>)
- Fix RTMDet Inst Explain Mode
(<https://github.com/openvinotoolkit/training_extensions/pull/4083>)
- Fix RTDETR Explain Mode
(<https://github.com/openvinotoolkit/training_extensions/pull/4106>)

## \[v2.1.0\]

Expand Down
17 changes: 12 additions & 5 deletions src/otx/algo/detection/base_models/detection_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
Expand Down Expand Up @@ -95,16 +96,22 @@ def export(
explain_mode: bool = False,
) -> dict[str, Any] | tuple[list[Any], list[Any], list[Any]]:
"""Exports the model."""
if explain_mode:
msg = "Explain mode is not supported for DETR models yet."
raise NotImplementedError(msg)

return self.postprocess(
results = self.postprocess(
self._forward_features(batch_inputs),
[meta["img_shape"] for meta in batch_img_metas],
deploy_mode=True,
)

if explain_mode:
# TODO(Eugene): Implement explain mode for DETR model.
warnings.warn("Explain mode is not supported for DETR model. Return dummy values.", stacklevel=2)
xai_output = {
"feature_vector": torch.zeros(1, 1),
"saliency_map": torch.zeros(1),
}
results.update(xai_output) # type: ignore[union-attr]
return results

def postprocess(
self,
outputs: dict[str, Tensor],
Expand Down
3 changes: 0 additions & 3 deletions src/otx/recipe/detection/rtdetr_101.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ overrides:
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
init_args:
p: 0.5
- class_path: torchvision.transforms.v2.RandomZoomOut
init_args:
fill: 0
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
Expand Down
3 changes: 0 additions & 3 deletions src/otx/recipe/detection/rtdetr_18.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ overrides:
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
init_args:
p: 0.5
- class_path: torchvision.transforms.v2.RandomZoomOut
init_args:
fill: 0
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
Expand Down
3 changes: 0 additions & 3 deletions src/otx/recipe/detection/rtdetr_50.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ overrides:
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
init_args:
p: 0.5
- class_path: torchvision.transforms.v2.RandomZoomOut
init_args:
fill: 0
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
Expand Down

0 comments on commit 88ab4b8

Please sign in to comment.