This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Semantic Segmentation target masks broken >0.7.5 #1489

Closed
newzealandpaul opened this issue Nov 25, 2022 · 2 comments · Fixed by #1509
Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed), Priority

Comments

newzealandpaul commented Nov 25, 2022

🐛 Bug

The switch to Albumentations in newer releases of lightning-flash seems to have broken the transformation of segmentation targets.

This is what I expect masks to look like (screenshot showing below code sample running on 0.7.5):

[Screenshot: 2022-11-25 25-10-10-49--259_chrome]

This is what it looks like on the latest release (0.8.1):

[Screenshot: 2022-11-25 25-13-13-32--468_chrome]

To Reproduce

Run the sample below with lightning-flash==0.7.5 and with lightning-flash==0.8.1, then compare the behavior.

Code sample

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

import matplotlib.pyplot as plt
import numpy as np

# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
# download_data(
#     "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
#     "./data",
# )

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    num_classes=21,
    batch_size=4,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

n = 3
fig, axarr = plt.subplots(ncols=2, nrows=n, figsize=(8, 4*n))

for batch in datamodule.train_dataloader():
    print(batch.keys())
    for i in range(n):
        segm = batch['target'][i]
        print(segm.shape)
        img = np.rollaxis(batch['input'][i].numpy(), 0, 3)
        axarr[i, 0].imshow(img)
        axarr[i, 1].imshow(segm)
    break
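
To compare the two releases numerically rather than only by eye, one option is to summarize the values in each target mask. The sketch below is not part of the original report: `summarize_mask` and `flatten` are hypothetical helpers, and the check assumes that a correctly transformed mask holds only integer class indices in the range [0, num_classes). It makes no claim about the root cause of the bug; it only flags values that cannot be valid class IDs.

```python
from collections import Counter


def flatten(values):
    """Yield the scalar entries of an arbitrarily nested list or tuple."""
    for value in values:
        if isinstance(value, (list, tuple)):
            yield from flatten(value)
        else:
            yield value


def summarize_mask(mask, num_classes):
    """Report which values occur in a mask and which cannot be class IDs.

    Hypothetical helper. Assumes valid masks contain only integer class
    indices in [0, num_classes).
    """
    # torch tensors and NumPy arrays both expose .tolist(); plain nested
    # lists are accepted as they are.
    if hasattr(mask, "tolist"):
        mask = mask.tolist()
    counts = Counter(flatten(mask))
    invalid = sorted(
        value
        for value in counts
        if not float(value).is_integer() or not 0 <= value < num_classes
    )
    return {"unique_values": sorted(counts), "invalid_values": invalid}
```

Inside the loop above, this could be called as `print(summarize_mask(batch['target'][i], datamodule.num_classes))` on both versions; a non-empty `invalid_values` list on one release but not the other would point at the target transform.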

Environment

  • OS: Ubuntu WSL2
  • Python version: 3.10.8
  • GPU model: RTX 3080
  • CUDA Version: 11.6
newzealandpaul added the bug / fix and help wanted labels Nov 25, 2022
noname202 commented

I can confirm this issue. I just spent a significant amount of time trying to figure out whether something was wrong with my code. Is there any estimate of when this will be fixed?

Borda added the Priority label Dec 23, 2022
Borda (Member) commented Dec 23, 2022

Hi @newzealandpaul @noname202, we are sorry about this bug, which seems to be very critical for segmentation... :( Would you be interested in trying to debug it? I believe that @ethanwharris could eventually give a hand... 🐰
