Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable export of feature vectors for semantic segmentation task #4055

Merged
merged 2 commits
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/otx/algo/segmentation/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,8 @@

def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
msg = "Explain mode is not supported for this model."
raise NotImplementedError(msg)

Check warning on line 167 in src/otx/algo/segmentation/huggingface_model.py

View check run for this annotation

Codecov / codecov/patch

src/otx/algo/segmentation/huggingface_model.py#L165-L167

Added lines #L165 - L167 were not covered by tests

return self.model(image)
2 changes: 1 addition & 1 deletion src/otx/algo/segmentation/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _exporter(self) -> OTXModelExporter:
swap_rgb=False,
via_onnx=False,
onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK},
output_names=None,
output_names=["preds", "feature_vector"] if self.explain_mode else None,
)

@property
Expand Down
15 changes: 12 additions & 3 deletions src/otx/algo/segmentation/segmentors/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch.nn.functional as f
from torch import Tensor, nn

from otx.algo.explain.explain_algo import feature_vector_fn

if TYPE_CHECKING:
from otx.core.data.entity.base import ImageInfo

Expand Down Expand Up @@ -58,7 +60,7 @@
- If mode is "predict", returns the predicted outputs.
- Otherwise, returns the model outputs after interpolation.
"""
outputs = self.extract_features(inputs)
enc_feats, outputs = self.extract_features(inputs)
outputs = f.interpolate(outputs, size=inputs.size()[2:], mode="bilinear", align_corners=True)

if mode == "tensor":
Expand All @@ -76,12 +78,19 @@
if mode == "predict":
return outputs.argmax(dim=1)

if mode == "explain":
feature_vector = feature_vector_fn(enc_feats)
return {

Check warning on line 83 in src/otx/algo/segmentation/segmentors/base_model.py

View check run for this annotation

Codecov / codecov/patch

src/otx/algo/segmentation/segmentors/base_model.py#L81-L83

Added lines #L81 - L83 were not covered by tests
"preds": outputs,
"feature_vector": feature_vector,
}

return outputs

def extract_features(self, inputs: Tensor) -> Tensor:
def extract_features(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
"""Extract features from the backbone and head."""
enc_feats = self.backbone(inputs)
return self.decode_head(enc_feats)
return enc_feats, self.decode_head(enc_feats)

def calculate_loss(
self,
Expand Down
1 change: 0 additions & 1 deletion src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def __init__(
self.input_size = input_size
self.classification_layers: dict[str, dict[str, Any]] = {}
self.model = self._create_model()
self._explain_mode = False
self.optimizer_callable = ensure_callable(optimizer)
self.scheduler_callable = ensure_callable(scheduler)
self.metric_callable = ensure_callable(metric)
Expand Down
81 changes: 63 additions & 18 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@
"""

def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
mode = "loss" if self.training else "predict"
if self.training:
mode = "loss"
elif self.explain_mode:
mode = "explain"

Check warning on line 125 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L124-L125

Added lines #L124 - L125 were not covered by tests
else:
mode = "predict"

Check warning on line 127 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L127

Added line #L127 was not covered by tests

if self.train_type == OTXTrainType.SEMI_SUPERVISED and mode == "loss":
if not isinstance(entity, dict):
Expand Down Expand Up @@ -155,6 +160,16 @@
losses[k] = v
return losses

if self.explain_mode:
return SegBatchPredEntity(

Check warning on line 164 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L164

Added line #L164 was not covered by tests
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=outputs["preds"],
feature_vector=outputs["feature_vector"],
)

return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
Expand Down Expand Up @@ -199,14 +214,24 @@
swap_rgb=False,
via_onnx=False,
onnx_export_configuration=None,
output_names=None,
output_names=["preds", "feature_vector"] if self.explain_mode else None,
)

def _convert_pred_entity_to_compute_metric(
self,
preds: SegBatchPredEntity,
inputs: SegBatchDataEntity,
) -> MetricInput:
"""Convert prediction and input entities to a format suitable for metric computation.

Args:
preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.

Returns:
MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
corresponding to the predicted and target masks for metric evaluation.
"""
return [
{
"preds": pred_mask,
Expand All @@ -228,8 +253,26 @@

def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
raw_outputs = self.model(inputs=image, mode="tensor")
return torch.softmax(raw_outputs, dim=1)
if self.explain_mode:
outputs = self.model(inputs=image, mode="explain")
outputs["preds"] = torch.softmax(outputs["preds"], dim=1)
return outputs

Check warning on line 259 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L256-L259

Added lines #L256 - L259 were not covered by tests

outputs = self.model(inputs=image, mode="tensor")
return torch.softmax(outputs, dim=1)

Check warning on line 262 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L261-L262

Added lines #L261 - L262 were not covered by tests

def forward_explain(self, inputs: SegBatchDataEntity) -> SegBatchPredEntity:
"""Model forward explain function."""
outputs = self.model(inputs=inputs.images, mode="explain")

Check warning on line 266 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L266

Added line #L266 was not covered by tests

return SegBatchPredEntity(

Check warning on line 268 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L268

Added line #L268 was not covered by tests
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=outputs["preds"],
feature_vector=outputs["feature_vector"],
)

def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
"""Returns a dummy input for semantic segmentation model."""
Expand Down Expand Up @@ -308,32 +351,34 @@
outputs: list[ImageResultWithSoftPrediction],
inputs: SegBatchDataEntity,
) -> SegBatchPredEntity | OTXBatchLossEntity:
if outputs and outputs[0].saliency_map.size != 1:
predicted_s_maps = [out.saliency_map for out in outputs]
predicted_f_vectors = [out.feature_vector for out in outputs]
return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
saliency_map=predicted_s_maps,
feature_vector=predicted_f_vectors,
)

masks = [tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs]
predicted_f_vectors = (

Check warning on line 355 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L354-L355

Added lines #L354 - L355 were not covered by tests
[out.feature_vector for out in outputs] if outputs and outputs[0].feature_vector.size != 1 else []
)
return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
masks=masks,
feature_vector=predicted_f_vectors,
)

def _convert_pred_entity_to_compute_metric(
self,
preds: SegBatchPredEntity,
inputs: SegBatchDataEntity,
) -> MetricInput:
"""Convert prediction and input entities to a format suitable for metric computation.

Args:
preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.

Returns:
MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
corresponding to the predicted and target masks for metric evaluation.
"""
return [
{
"preds": pred_mask,
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def test_otx_e2e_cli(
# 5) otx export with XAI
if "instance_segmentation/rtmdet_inst_tiny" in recipe:
return
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
return # Supported only for classification, detection and instance segmentation task.
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]):
return # Supported only for classification, detection and segmentation tasks.

unsupported_models = ["dino", "rtdetr"]
if any(model in model_name for model in unsupported_models):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def test_otx_e2e(
# 5) otx export with XAI
if "instance_segmentation/rtmdet_inst_tiny" in recipe:
return
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
return # Supported only for classification, detection and instance segmentation task.
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]):
return # Supported only for classification, detection and segmentation tasks.

if "dino" in model_name:
return # DINO is not supported.
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/algo/segmentation/segmentors/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def test_forward_returns_prediction(self, model, inputs):
def test_extract_features(self, model, inputs):
images = inputs[0]
features = model.extract_features(images)
assert isinstance(features, torch.Tensor)
assert features.shape == (1, 2, 256, 256)
assert isinstance(features, tuple)
assert isinstance(features[0], torch.Tensor)
assert isinstance(features[1], torch.Tensor)
assert features[1].shape == (1, 2, 256, 256)

def test_calculate_loss(self, model, inputs):
model.criterion.name = "CrossEntropyLoss"
Expand Down
Loading