diff --git a/flash/core/model.py b/flash/core/model.py index eafaea4af6..9914b4cb61 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -181,8 +181,10 @@ def _resolve( new_preprocess: Optional[Preprocess], new_postprocess: Optional[Postprocess], ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]: - """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, choosing ``new_*`` if it is not - None or a base class (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) and ``old_*`` otherwise. + """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, + choosing ``new_*`` if it is not None or a base class + (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) + and ``old_*`` otherwise. Args: old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden. @@ -204,7 +206,8 @@ def _resolve( return preprocess, postprocess def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]: - """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` + """Build a :class:`.DataPipeline` incorporating available + :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. These will be overridden in the following resolution order (lowest priority first): - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. @@ -213,7 +216,8 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O - :class:`.DataPipeline` passed to this method. Args: - data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`. + data_pipeline: Optional highest priority source of + :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`. Returns: The fully resolved :class:`.DataPipeline`. diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 528ce99063..5b6d9dca30 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -115,7 +115,4 @@ def __init__( def forward(self, x) -> torch.Tensor: x = self.backbone(x) - if self.hparams.multi_label: - return self.head(x) - else: - return torch.softmax(self.head(x), -1) + return self.head(x)