From f8a1432ac5be63bd1517c4ea912c1612ce3bc51a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 12:48:50 +0100 Subject: [PATCH] Fixes --- flash/image/segmentation/heads.py | 7 ++----- tests/image/segmentation/test_heads.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index e793b53fed..4444b5c3ab 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -56,11 +56,8 @@ def _get_backbone_meta(backbone): backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = sum([ - [0], - [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)], - [len(backbone) - 1], - ]) + stage_indices = ([0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + + [len(backbone) - 1]) out_pos = stage_indices[-1] # use C5 which has output_stride = 16 out_layer = str(out_pos) out_inplanes = backbone[out_pos].out_channels diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index 218dbbf441..ec90b03670 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -13,8 +13,8 @@ # limitations under the License. import pytest import torch -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS