Add parse and remove large files
dscarmo committed Mar 28, 2023
1 parent e8c63e4 commit 9e6533b
Showing 16 changed files with 574 additions and 91 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
# Custom
*.ckpt
*.nii.gz
*.csv

5 changes: 5 additions & 0 deletions CHANGELOG
@@ -1,3 +1,8 @@
v2.5
* Added PARSE vessel segmentation output
* Deprecated some old code
* Moved large files to a separate release

v2.4

* Better input/output flow, option to use less memory by not showing activations
12 changes: 10 additions & 2 deletions README.md
@@ -1,5 +1,5 @@
# Modified EfficientDet Segmentation (MEDSeg)
Official repository for reproducing lung, COVID-19 and airway automated segmentation using our MEDSeg model.
Official repository for reproducing lung, COVID-19, airway and pulmonary artery automated segmentation using our MEDSeg model.

The original publication for this method, **Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models**, was published at the 17th International Symposium on Medical Information Processing and Analysis (SIPAIM 2021) and won the "SIPAIM Society Award".
http://dx.doi.org/10.1117/12.2606118
@@ -10,8 +10,10 @@ https://www.youtube.com/watch?v=PlhNUD0Y4hg
We have also applied this model in the ATM22 Challenge (https://atm22.grand-challenge.org/). Airway segmentation is included, with a CLI argument (--atm_mode) to segment only the airway, using less memory. A short paper about this, **Open-source tool for Airway Segmentation in Computed Tomography using 2.5D Modified EfficientDet: Contribution to the ATM22 Challenge**, has been published on arXiv: https://arxiv.org/pdf/2209.15094.pdf

We have also trained this model for the PARSE Challenge (https://parse2022.grand-challenge.org/) on pulmonary artery segmentation. Pulmonary artery labels are now included in the outputs. The model achieved around 0.7 Dice in testing. A paper detailing this application will be published in the future.

## Citation
* **COVID-19 segmentation**: Carmo, Diedre, et al. "Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models." 17th International Symposium on Medical Information Processing and Analysis. Vol. 12088. SPIE, 2021.
* **COVID-19 segmentation and method in general**: Carmo, Diedre, et al. "Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models." 17th International Symposium on Medical Information Processing and Analysis. Vol. 12088. SPIE, 2021.

* @inproceedings{carmo2021multitasking,\
title={Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models},\
@@ -68,6 +70,8 @@ All additional required libraries and the tool itself will be installed with the

If you use virtual environments, it is safer to install in a new virtual environment to avoid conflicts.

Finally, due to the large size of the network weights, you need to go to the Releases page of this repository, download the data.zip file, and extract the .ckpt files into the medseg folder. The .ckpt files should be at the same directory level as the run.py file.
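
If you prefer to script this step, here is a minimal sketch (assuming data.zip has already been downloaded from the Releases page into the repository root, and that the .ckpt files sit at the top level of the archive):

```python
import zipfile

# Unpack the released checkpoints so the .ckpt files end up inside medseg/,
# at the same level as run.py.
with zipfile.ZipFile("data.zip") as archive:
    archive.extractall("medseg")
```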

## Running

To run, just call it in a terminal.
@@ -91,3 +95,7 @@ If you have any problems, make sure your pip is the one from your miniconda installation
by checking that pip --version points to the miniconda directory.

If you have any issues, feel free to create an issue on this repository.

### Known Issue

"Long prediction" mode is not working due to recent changes in the architecutre. However not using it should be enough for most cases, Long Prediction uses more models in the final ensemble.
2 changes: 1 addition & 1 deletion medseg/__init__.py
@@ -1 +1 @@
__version__ = "2.4.1"
__version__ = "2.5.0"
Binary file removed medseg/airway.ckpt
Binary file not shown.
83 changes: 78 additions & 5 deletions medseg/architecture.py
@@ -1,25 +1,62 @@
'''
If you use please cite:
CARMO, Diedre et al. Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models. In: 17th International Symposium on Medical Information Processing and Analysis. SPIE, 2021. p. 65-74.
'''
import torch
from torch import nn

from efficientnet_pytorch.utils import round_filters
from medseg.edet.modeling_efficientdet import EfficientDetForSemanticSegmentation


class MEDSeg(nn.Module):
def __init__(self, nin=3, nout=3, apply_sigmoid=False, dropout=None, backbone="effnet", pretrained=True, expand_bifpn="conv"):
def __init__(self, nin=3, nout=3, apply_sigmoid=False, dropout=None, backbone="effnet", pretrained=True, expand_bifpn="upsample", imnet_norm=False,
num_classes_atm=None,
num_classes_rec=None,
num_classes_vessel=None,
stem_replacement=False,
new_latent_space=False,
compound_coef=4): # compound always has been 4 by default before
super().__init__()
print("WARNING: default expand_bifpn changed to upsample!")
self.model = EfficientDetForSemanticSegmentation(num_classes=nout,
load_weights=pretrained,
apply_sigmoid=apply_sigmoid,
expand_bifpn=expand_bifpn,
dropout=dropout,
backbone=backbone)
backbone=backbone,
compound_coef=compound_coef,
num_classes_atm=num_classes_atm,
num_classes_rec=num_classes_rec,
num_classes_vessel=num_classes_vessel,
new_latent_space=new_latent_space)

self.feature_adapters = self.model.feature_adapters

if imnet_norm:
print("Performing imnet normalization internally, assuming inputs between 1 and 0")
self.imnet_norm = ImNetNorm()
else:
self.imnet_norm = nn.Identity()

self.nin = nin
if self.nin not in [1, 3]:
self.in_conv = nn.Conv2d(in_channels=self.nin, out_channels=3, kernel_size=1, stride=1, padding=0, bias=False)

print(f"MEDSeg initialized. nin: {nin}, nout: {nout}, apply_sigmoid: {apply_sigmoid}, dropout: {dropout}, backbone: {backbone}, pretrained: {pretrained}, expand_bifpn: {expand_bifpn}, align DISABLED")
if stem_replacement:
assert backbone == "effnet", "Stem replacement only valid for efficientnet"
print("Performing stem replacement on EfficientNet backbone (this runs after initialization)")
self.model.backbone_net.model._conv_stem = EffNet3DStemReplacement(self.model.backbone_net.model)

print(f"MEDSeg initialized. nin: {nin}, nout: {nout}, apply_sigmoid: {apply_sigmoid}, dropout: {dropout},"
f"backbone: {backbone}, pretrained: {pretrained}, expand_bifpn: {expand_bifpn}, pad align DISABLED, stem_replacement {stem_replacement}"
f"new latent space extraction {new_latent_space}")

def extract_backbone_features(self, inputs):
return self.model.extract_backbone_features(inputs)

def extract_bifpn_features(self, features):
return self.model.extract_bifpn_features(features)

def forward(self, x):
if self.nin == 1:
x_in = torch.zeros(size=(x.shape[0], 3) + x.shape[2:], device=x.device, dtype=x.dtype)
@@ -32,5 +69,41 @@ def forward(self, x):
else:
x = self.in_conv(x)

x = self.imnet_norm(x)

return self.model(x)
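
# Hypothetical usage sketch (illustration only, not part of this file; the batch
# size and 512x512 input resolution are assumptions):
#
#     model = MEDSeg(nin=1, nout=3)     # single-channel CT slices in, 3 classes out
#     x = torch.randn(2, 1, 512, 512)   # [B, C, H, W]
#     with torch.no_grad():
#         y = model(x)                  # output of the enabled segmentation head(s)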



class EffNet3DStemReplacement(nn.Module):
def __init__(self, effnet_pytorch_instance):
super().__init__()
out_channels = round_filters(32, effnet_pytorch_instance._global_params)
self.conv = nn.Conv3d(1, out_channels, kernel_size=3, stride=1, padding="valid", bias=False)
self.pad = nn.ZeroPad2d(1)
self.conv_pool = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2, padding=0, bias=False)

def forward(self, x):
'''
x is 4D batch but will be treated as 5D
'''
x = self.conv(x.unsqueeze(1)).squeeze(2) # [B, 3, X, Y] -> [B, 1, 3, X, Y]
# -> [B, OUT_CH, 1, X, Y] -> [B, OUT_CH, X, Y]
x = self.pad(x)
x = self.conv_pool(x)
return x
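
    # Shape walk-through (illustrative; assumes 512x512 inputs and the default
    # EfficientNet-B4 backbone, where round_filters(32, ...) == 48):
    #   [B, 3, 512, 512] -> unsqueeze(1) -> [B, 1, 3, 512, 512]
    #   Conv3d(k=3, padding="valid") -> [B, 48, 1, 510, 510] -> squeeze(2) -> [B, 48, 510, 510]
    #   ZeroPad2d(1) -> [B, 48, 512, 512] -> Conv2d(k=2, s=2) -> [B, 48, 256, 256]
    # The net effect matches the stride-2 2D stem it replaces, after first mixing
    # the three input slices with a true 3D convolution.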


class ImNetNorm():
'''
    Assumes input between 0 and 1
'''
def __init__(self):
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]

def __call__(self, xim):
with torch.no_grad():
for i in range(3):
xim[:, i] = (xim[:, i] - self.mean[i])/self.std[i]

return xim
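
    # Worked example for channel 0 (illustrative): an input value of 0.5 maps to
    # (0.5 - 0.485) / 0.229 ≈ 0.066. Note that the normalization happens in place
    # on xim, under torch.no_grad().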
Binary file removed medseg/best_coedet.ckpt
Binary file not shown.
211 changes: 211 additions & 0 deletions medseg/convnext.py
@@ -0,0 +1,211 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model


class Block(nn.Module):
r""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

x = input + self.drop_path(x)
return x
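
    # Shapes are preserved end to end: [N, C, H, W] in and out; the pointwise MLP
    # expands to 4*C channels internally (pwconv1 -> GELU -> pwconv2) before the
    # residual connection adds the result back to the input.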

class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` -
https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(self, in_chans=3, num_classes=1000,
depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
layer_scale_init_value=1e-6, head_init_scale=1.,
):
super().__init__()

self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)

self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(4):
stage = nn.Sequential(
*[Block(dim=dims[i], drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
)
self.stages.append(stage)
cur += depths[i]

self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
self.head = nn.Linear(dims[-1], num_classes)

self.apply(self._init_weights)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)

def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)

def forward_features(self, x):
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)

def forward_seg_features(self, x, convnext_expansion_scale, range_limit=3):
outs = []
for i in range(range_limit):
x = self.downsample_layers[i](x)
if convnext_expansion_scale <= 0:
outs.append(self.stages[i](x))
else:
outs.append(F.upsample_bilinear(self.stages[i](x), scale_factor=convnext_expansion_scale))
return outs
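
    # Note: x advances only through the downsample layers; each stage output is
    # collected for the decoder (and optionally upsampled bilinearly when
    # convnext_expansion_scale > 0) but is not fed into the next stage.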

def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x

class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )

def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
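
    # The channels_first branch normalizes over dim 1 manually:
    #   y = weight * (x - mean) / sqrt(var + eps) + bias,
    # with weight and bias broadcast over the spatial dimensions.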


model_urls = {
"convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
"convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
"convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
"convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
"convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
"convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
"convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
"convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
"convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}

@register_model
def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
if pretrained:
url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint["model"])
return model

@register_model
def convnext_small(pretrained=False,in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
if pretrained:
url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model

@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
if pretrained:
url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model

@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
if pretrained:
url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model

@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
if pretrained:
assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
url = model_urls['convnext_xlarge_22k']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
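
# Hypothetical usage sketch (illustration only, not part of this file; shapes
# assume a 224x224 input to convnext_tiny):
#
#     net = convnext_tiny(pretrained=False)
#     feats = net.forward_seg_features(torch.randn(1, 3, 224, 224), convnext_expansion_scale=0)
#     [tuple(f.shape) for f in feats]   # (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14)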