diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f0f116184..fe0dfed42e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where the backbone would not be frozen when finetuning the `QuestionAnswering` task with certain model types ([#1275](https://github.com/PyTorchLightning/lightning-flash/pull/1275)) +- Fixed a bug where the default Flash zero configurations for `ObjectDetector`, `InstanceSegmentation`, and `KeypointDetector` would error with the latest version of some requirements ([#1306](https://github.com/PyTorchLightning/lightning-flash/pull/1306)) + ## [0.7.0] - 2022-02-15 ### Added diff --git a/flash/image/detection/cli.py b/flash/image/detection/cli.py index 47298246de..3fa6a8a05b 100644 --- a/flash/image/detection/cli.py +++ b/flash/image/detection/cli.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Any, Dict, Optional from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI @@ -22,7 +22,7 @@ def from_coco_128( val_split: float = 0.1, - image_size: Tuple[int, int] = (128, 128), + transform_kwargs: Optional[Dict[str, Any]] = None, batch_size: int = 1, **data_module_kwargs, ) -> ObjectDetectionData: @@ -32,7 +32,7 @@ def from_coco_128( train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=val_split, - transform_kwargs=dict(image_size=image_size), + transform_kwargs=dict(image_size=(128, 128)) if transform_kwargs is None else transform_kwargs, batch_size=batch_size, **data_module_kwargs, ) diff --git a/flash/image/instance_segmentation/cli.py b/flash/image/instance_segmentation/cli.py index 90a80e46ac..31aa610873 100644 --- a/flash/image/instance_segmentation/cli.py +++ b/flash/image/instance_segmentation/cli.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional, Tuple +from typing import Any, Callable, Dict, Optional from flash.core.utilities.flash_cli import FlashCLI from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires @@ -34,8 +34,8 @@ def from_pets( test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, val_split: float = 0.1, - image_size: Tuple[int, int] = (128, 128), parser: Optional[Callable] = None, + transform_kwargs: Optional[Dict[str, Any]] = None, batch_size: int = 1, **data_module_kwargs, ) -> InstanceSegmentationData: @@ -53,7 +53,7 @@ def from_pets( test_folder=test_folder, test_ann_file=test_ann_file, predict_folder=predict_folder, - transform_kwargs=dict(image_size=image_size), + transform_kwargs=dict(image_size=(128, 128)) if transform_kwargs is None else transform_kwargs, parser=parser, val_split=val_split, batch_size=batch_size, diff --git a/flash/image/keypoint_detection/cli.py b/flash/image/keypoint_detection/cli.py index ecb6e07e4a..67b4154620 100644 --- a/flash/image/keypoint_detection/cli.py +++ b/flash/image/keypoint_detection/cli.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Tuple +from typing import Any, Callable, Dict, Optional from flash.core.utilities.flash_cli import FlashCLI from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires @@ -33,8 +33,8 @@ def from_biwi( test_ann_file: Optional[str] = None, predict_folder: Optional[str] = None, val_split: float = 0.1, - image_size: Tuple[int, int] = (128, 128), parser: Optional[Callable] = None, + transform_kwargs: Optional[Dict[str, Any]] = None, batch_size: int = 1, **data_module_kwargs, ) -> KeypointDetectionData: @@ -53,7 +53,7 @@ def from_biwi( test_ann_file=test_ann_file, predict_folder=predict_folder, val_split=val_split, - transform_kwargs=dict(image_size=image_size), + transform_kwargs=dict(image_size=(128, 128)) if transform_kwargs is None else transform_kwargs, batch_size=batch_size, parser=parser, **data_module_kwargs,