diff --git a/CHANGELOG.md b/CHANGELOG.md index d26cdcd97c..5f7860d7fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `predict_kwargs` in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990)) + - Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) - Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index e723bc2cd5..7d155dd5b9 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -46,13 +46,14 @@ class IceVisionAdapter(Adapter): required_extras: str = "image" - def __init__(self, model_type, model, icevision_adapter, backbone): + def __init__(self, model_type, model, icevision_adapter, backbone, predict_kwargs): super().__init__() self.model_type = model_type self.model = model self.icevision_adapter = icevision_adapter self.backbone = backbone + self.predict_kwargs = predict_kwargs @classmethod @catch_url_error @@ -62,6 +63,7 @@ def from_task( num_classes: int, backbone: str, head: str, + predict_kwargs: Dict, pretrained: bool = True, metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, @@ -77,7 +79,7 @@ def from_task( **kwargs, ) icevision_adapter = icevision_adapter(model=model, metrics=metrics) - return cls(model_type, model, icevision_adapter, backbone) + return cls(model_type, model, icevision_adapter, backbone, predict_kwargs) @staticmethod def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): @@ -198,7 +200,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return batch def forward(self, batch: Any) -> Any: - return 
from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False)) + return from_icevision_predictions( + self.model_type.predict_from_dl(self.model, [batch], show_pbar=False, **self.predict_kwargs) + ) def training_epoch_end(self, outputs) -> None: return self.icevision_adapter.training_epoch_end(outputs) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 0a080af611..94905f81e5 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -33,6 +33,7 @@ class ObjectDetector(AdapterTask): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. + predict_kwargs: dictionary containing parameters that will be used during the prediction phase. kwargs: additional kwargs nessesary for initializing the backbone task """ @@ -50,10 +51,12 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-3, output: OUTPUT_TYPE = None, + predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, @@ -61,6 +64,7 @@ def __init__( backbone=backbone, head=head, pretrained=pretrained, + predict_kwargs=predict_kwargs, **kwargs, ) @@ -75,3 +79,12 @@ def __init__( def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo + + @property + def predict_kwargs(self) -> Dict[str, Any]: + """The kwargs used for the prediction step.""" + return self.adapter.predict_kwargs + + @predict_kwargs.setter + def predict_kwargs(self, predict_kwargs: Dict[str, Any]): + self.adapter.predict_kwargs = predict_kwargs diff --git a/flash/image/instance_segmentation/model.py 
b/flash/image/instance_segmentation/model.py index 50c1936b9e..eb0e257653 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -40,6 +40,7 @@ class InstanceSegmentation(AdapterTask): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. + predict_kwargs: dictionary containing parameters that will be used during the prediction phase. **kwargs: additional kwargs used for initializing the task """ @@ -57,10 +58,12 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, output: OUTPUT_TYPE = None, + predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, @@ -68,6 +71,7 @@ def __init__( backbone=backbone, head=head, pretrained=pretrained, + predict_kwargs=predict_kwargs, **kwargs, ) @@ -96,3 +100,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: input_transform=InstanceSegmentationInputTransform(), output_transform=InstanceSegmentationOutputTransform(), ) + + @property + def predict_kwargs(self) -> Dict[str, Any]: + """The kwargs used for the prediction step.""" + return self.adapter.predict_kwargs + + @predict_kwargs.setter + def predict_kwargs(self, predict_kwargs: Dict[str, Any]): + self.adapter.predict_kwargs = predict_kwargs diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 3b404d8235..1993ee1ac9 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -34,6 +34,7 @@ class KeypointDetector(AdapterTask): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. 
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. + predict_kwargs: dictionary containing parameters that will be used during the prediction phase. **kwargs: additional kwargs used for initializing the task """ @@ -52,10 +53,12 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, output: OUTPUT_TYPE = None, + predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, @@ -64,6 +67,7 @@ def __init__( backbone=backbone, head=head, pretrained=pretrained, + predict_kwargs=predict_kwargs, **kwargs, ) @@ -78,3 +82,12 @@ def __init__( def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo + + @property + def predict_kwargs(self) -> Dict[str, Any]: + """The kwargs used for the prediction step.""" + return self.adapter.predict_kwargs + + @predict_kwargs.setter + def predict_kwargs(self, predict_kwargs: Dict[str, Any]): + self.adapter.predict_kwargs = predict_kwargs diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 5c4997b151..903948a6de 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -134,3 +134,20 @@ def test_cli(): main() except SystemExit: pass + + +@pytest.mark.parametrize("head", ["retinanet"]) +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing") +def test_predict(tmpdir, head): + model = ObjectDetector(num_classes=2, head=head, pretrained=False) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + dl = model.process_train_dataset(ds, trainer, 2, 
0, False, None) + trainer.fit(model, dl) + dl = model.process_predict_dataset(ds, batch_size=2) + predictions = trainer.predict(model, dl) + assert len(predictions[0][0]["bboxes"]) > 0 + model.predict_kwargs = {"detection_threshold": 2} + predictions = trainer.predict(model, dl) + assert len(predictions[0][0]["bboxes"]) == 0