Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix icevision task metric names (#1252)
Browse files Browse the repository at this point in the history
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
  • Loading branch information
ethanwharris and krshrimali committed Mar 30, 2022
1 parent 6c12885 commit 3e4f39a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

<<<<<<< HEAD
- Fixed examples (question answering), where NLTK's `punkt` module needs to be downloaded first. ([#1215](https://github.com/PyTorchLightning/lightning-flash/pull/1215/files))
- Fixed normalizing inputs to video classification ([#1213](https://github.com/PyTorchLightning/lightning-flash/pull/1213))
- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
Expand All @@ -21,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DDP spawn support for `ObjectDetector`, `InstanceSegmentation`, and `KeypointDetector` ([#1222](https://github.com/PyTorchLightning/lightning-flash/pull/1222))
- Fixed a bug where `InstanceSegmentation` would fail if samples had an inconsistent number of bboxes, labels, and masks (these will now be treated as negative samples) ([#1222](https://github.com/PyTorchLightning/lightning-flash/pull/1222))
- Fixed a bug where collate functions were never called in the `ImageEmbedder` class. ([#1217](https://github.com/PyTorchLightning/lightning-flash/pull/1217))
- Fixed a bug where `ObjectDetector`, `InstanceSegmentation`, and `KeypointDetector` would log train and validation metrics with the same name ([#1252](https://github.com/PyTorchLightning/lightning-flash/pull/1252))

### Removed

Expand Down
1 change: 1 addition & 0 deletions flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
RunningStage.VALIDATING: "val",
RunningStage.PREDICTING: "predict",
RunningStage.SERVING: "serve",
RunningStage.SANITY_CHECKING: "val",
}
_STAGES_PREFIX_VALUES = {"train", "test", "val", "predict", "serve"}

Expand Down
4 changes: 2 additions & 2 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class SimpleCOCOMetric(COCOMetric):
def finalize(self) -> Dict[str, float]:
logs = super().finalize()
return {
"Precision (IoU=0.50:0.95,area=all)": logs["AP (IoU=0.50:0.95) area=all"],
"Recall (IoU=0.50:0.95,area=all,maxDets=100)": logs["AR (IoU=0.50:0.95) area=all maxDets=100"],
"precision (IoU=0.50:0.95,area=all)": logs["AP (IoU=0.50:0.95) area=all"],
"recall (IoU=0.50:0.95,area=all,maxDets=100)": logs["AR (IoU=0.50:0.95) area=all maxDets=100"],
}


Expand Down
9 changes: 6 additions & 3 deletions flash/core/integrations/icevision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,27 @@

from torch import nn

from flash.core.data.utils import _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _ICEVISION_AVAILABLE

if _ICEVISION_AVAILABLE:
from icevision.backbones import BackboneConfig


def _log_with_prog_bar_override(self, name, value, **kwargs):
def _log_with_name_and_prog_bar_override(self, name, value, **kwargs):
if "prog_bar" not in kwargs:
kwargs["prog_bar"] = True
return self._original_log(name.split("/")[-1], value, **kwargs)
metric = name.split("/")[-1]
metric = f"{_STAGES_PREFIX[self.trainer.state.stage]}_{metric}"
return self._original_log(metric, value, **kwargs)


def icevision_model_adapter(model_type):
adapter = model_type.lightning.ModelAdapter
if not hasattr(adapter, "_original_log"):
adapter._original_log = adapter.log
adapter.log = _log_with_prog_bar_override
adapter.log = _log_with_name_and_prog_bar_override
return adapter


Expand Down

0 comments on commit 3e4f39a

Please sign in to comment.