diff --git a/src/otx/algo/segmentation/huggingface_model.py b/src/otx/algo/segmentation/huggingface_model.py index 83629896ed8..2a3a65eab69 100644 --- a/src/otx/algo/segmentation/huggingface_model.py +++ b/src/otx/algo/segmentation/huggingface_model.py @@ -162,4 +162,8 @@ def _exporter(self) -> OTXModelExporter: def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + msg = "Explain mode is not supported for this model." + raise NotImplementedError(msg) + return self.model(image) diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py index 73ade597fd8..b7b1e6e2f7a 100644 --- a/src/otx/algo/segmentation/litehrnet.py +++ b/src/otx/algo/segmentation/litehrnet.py @@ -81,7 +81,7 @@ def _exporter(self) -> OTXModelExporter: swap_rgb=False, via_onnx=False, onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK}, - output_names=None, + output_names=["preds", "feature_vector"] if self.explain_mode else None, ) @property diff --git a/src/otx/algo/segmentation/segmentors/base_model.py b/src/otx/algo/segmentation/segmentors/base_model.py index c66c49f84f5..7d7550bbf76 100644 --- a/src/otx/algo/segmentation/segmentors/base_model.py +++ b/src/otx/algo/segmentation/segmentors/base_model.py @@ -10,6 +10,8 @@ import torch.nn.functional as f from torch import Tensor, nn +from otx.algo.explain.explain_algo import feature_vector_fn + if TYPE_CHECKING: from otx.core.data.entity.base import ImageInfo @@ -58,7 +60,7 @@ def forward( - If mode is "predict", returns the predicted outputs. - Otherwise, returns the model outputs after interpolation. """ - outputs = self.extract_features(inputs) + enc_feats, outputs = self.extract_features(inputs) outputs = f.interpolate(outputs, size=inputs.size()[2:], mode="bilinear", align_corners=True) if mode == "tensor": @@ -76,12 +78,19 @@ def forward( if mode == "predict": return outputs.argmax(dim=1) + if mode == "explain": + feature_vector = feature_vector_fn(enc_feats) + return { + "preds": outputs, + "feature_vector": feature_vector, + } + return outputs - def extract_features(self, inputs: Tensor) -> Tensor: + def extract_features(self, inputs: Tensor) -> tuple[Tensor, Tensor]: """Extract features from the backbone and head.""" enc_feats = self.backbone(inputs) - return self.decode_head(enc_feats) + return enc_feats, self.decode_head(enc_feats) def calculate_loss( self, diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index ea6308afc6d..ac2331ff885 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -124,7 +124,6 @@ def __init__( self.input_size = input_size self.classification_layers: dict[str, dict[str, Any]] = {} self.model = self._create_model() - self._explain_mode = False self.optimizer_callable = ensure_callable(optimizer) self.scheduler_callable = ensure_callable(scheduler) self.metric_callable = ensure_callable(metric) diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index a7eecdffe8c..0003307e376 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -119,7 +119,12 @@ def _build_model(self) -> nn.Module: """ def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]: - mode = "loss" if self.training else "predict" + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" if self.train_type == OTXTrainType.SEMI_SUPERVISED and mode == "loss": if not isinstance(entity, dict): @@ -155,6 +160,16 @@ def _customize_outputs( losses[k] = v return losses + if self.explain_mode: + return SegBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=[], + masks=outputs["preds"], + feature_vector=outputs["feature_vector"], + ) + return SegBatchPredEntity( batch_size=len(outputs), images=inputs.images, @@ -199,7 +214,7 @@ def _exporter(self) -> OTXModelExporter: swap_rgb=False, via_onnx=False, onnx_export_configuration=None, - output_names=None, + output_names=["preds", "feature_vector"] if self.explain_mode else None, ) def _convert_pred_entity_to_compute_metric( @@ -207,6 +222,16 @@ def _convert_pred_entity_to_compute_metric( preds: SegBatchPredEntity, inputs: SegBatchDataEntity, ) -> MetricInput: + """Convert prediction and input entities to a format suitable for metric computation. + + Args: + preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks. + inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks. + + Returns: + MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys + corresponding to the predicted and target masks for metric evaluation. + """ return [ { "preds": pred_mask, @@ -228,8 +253,26 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: """Model forward function used for the model tracing during model exportation.""" - raw_outputs = self.model(inputs=image, mode="tensor") - return torch.softmax(raw_outputs, dim=1) + if self.explain_mode: + outputs = self.model(inputs=image, mode="explain") + outputs["preds"] = torch.softmax(outputs["preds"], dim=1) + return outputs + + outputs = self.model(inputs=image, mode="tensor") + return torch.softmax(outputs, dim=1) + + def forward_explain(self, inputs: SegBatchDataEntity) -> SegBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(inputs=inputs.images, mode="explain") + + return SegBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=[], + masks=outputs["preds"], + feature_vector=outputs["feature_vector"], + ) def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity: """Returns a dummy input for semantic segmentation model.""" @@ -308,25 +351,17 @@ def _customize_outputs( outputs: list[ImageResultWithSoftPrediction], inputs: SegBatchDataEntity, ) -> SegBatchPredEntity | OTXBatchLossEntity: - if outputs and outputs[0].saliency_map.size != 1: - predicted_s_maps = [out.saliency_map for out in outputs] - predicted_f_vectors = [out.feature_vector for out in outputs] - return SegBatchPredEntity( - batch_size=len(outputs), - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=[], - masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs], - saliency_map=predicted_s_maps, - feature_vector=predicted_f_vectors, - ) - + masks = [tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs] + predicted_f_vectors = ( + [out.feature_vector for out in outputs] if outputs and outputs[0].feature_vector.size != 1 else [] + ) return SegBatchPredEntity( batch_size=len(outputs), images=inputs.images, imgs_info=inputs.imgs_info, scores=[], - masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs], + masks=masks, + feature_vector=predicted_f_vectors, ) def _convert_pred_entity_to_compute_metric( @@ -334,6 +369,16 @@ def _convert_pred_entity_to_compute_metric( preds: SegBatchPredEntity, inputs: SegBatchDataEntity, ) -> MetricInput: + """Convert prediction and input entities to a format suitable for metric computation. + + Args: + preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks. + inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks. + + Returns: + MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys + corresponding to the predicted and target masks for metric evaluation. + """ return [ { "preds": pred_mask, diff --git a/tests/e2e/cli/test_cli.py b/tests/e2e/cli/test_cli.py index 07ac0daf103..54154114390 100644 --- a/tests/e2e/cli/test_cli.py +++ b/tests/e2e/cli/test_cli.py @@ -220,8 +220,8 @@ def test_otx_e2e_cli( # 5) otx export with XAI if "instance_segmentation/rtmdet_inst_tiny" in recipe: return - if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): - return # Supported only for classification, detection and instance segmentation task. + if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]): + return # Supported only for classification, detection and segmentation tasks. unsupported_models = ["dino", "rtdetr"] if any(model in model_name for model in unsupported_models): diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index c754173f39d..19d65ee3431 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -242,8 +242,8 @@ def test_otx_e2e( # 5) otx export with XAI if "instance_segmentation/rtmdet_inst_tiny" in recipe: return - if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): - return # Supported only for classification, detection and instance segmentation task. + if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]): + return # Supported only for classification, detection and segmentation tasks. if "dino" in model_name: return # DINO is not supported. diff --git a/tests/unit/algo/segmentation/segmentors/test_base_model.py b/tests/unit/algo/segmentation/segmentors/test_base_model.py index 33a33af4dda..32a20baa2c8 100644 --- a/tests/unit/algo/segmentation/segmentors/test_base_model.py +++ b/tests/unit/algo/segmentation/segmentors/test_base_model.py @@ -43,8 +43,10 @@ def test_forward_returns_prediction(self, model, inputs): def test_extract_features(self, model, inputs): images = inputs[0] features = model.extract_features(images) - assert isinstance(features, torch.Tensor) - assert features.shape == (1, 2, 256, 256) + assert isinstance(features, tuple) + assert isinstance(features[0], torch.Tensor) + assert isinstance(features[1], torch.Tensor) + assert features[1].shape == (1, 2, 256, 256) def test_calculate_loss(self, model, inputs): model.criterion.name = "CrossEntropyLoss"