Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PointRend #34

Merged
merged 2 commits into from
Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions configs/cityscapes_pointrend_deeplabv3_plus.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
DATASET:
NAME: "cityscape"
MEAN: [0.5, 0.5, 0.5]
STD: [0.5, 0.5, 0.5]
TRAIN:
EPOCHS: 400
BATCH_SIZE: 2
CROP_SIZE: 768
TEST:
BATCH_SIZE: 2
CROP_SIZE: (1024, 2048)
# TEST_MODEL_PATH: trained_models/deeplabv3_plus_xception_segmentron.pth

SOLVER:
LR: 0.01

MODEL:
MODEL_NAME: "PointRend"
BACKBONE: "xception65"
BN_EPS_FOR_ENCODER: 1e-3
DEEPLABV3_PLUS:
ENABLE_DECODER: False

6 changes: 5 additions & 1 deletion segmentron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,15 @@ def remove_irrelevant_cfg(self):
from ..models.model_zoo import MODEL_REGISTRY
model_list = MODEL_REGISTRY.get_list()
model_list_lower = [x.lower() for x in model_list]
# print('model_list:', model_list)

assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\
.format(model_list, model_name)
pop_keys = []
for key in self.MODEL.keys():
if key.lower() in model_list_lower:
if model_name.lower() == 'pointrend' and \
key.lower() == self.MODEL.POINTREND.BASEMODEL.lower():
continue
if key.lower() in model_list_lower and key.lower() != model_name.lower():
pop_keys.append(key)
for key in pop_keys:
Expand Down
3 changes: 3 additions & 0 deletions segmentron/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@
cfg.MODEL.CGNET.STAGE2_BLOCK_NUM = 3
cfg.MODEL.CGNET.STAGE3_BLOCK_NUM = 21

########################## PointRend config ##################################
cfg.MODEL.POINTREND.BASEMODEL = 'DeepLabV3_Plus'

########################## hrnet config ######################################
cfg.MODEL.HRNET.PRETRAINED_LAYERS = ['*']
cfg.MODEL.HRNET.STEM_INPLANES = 64
Expand Down
3 changes: 3 additions & 0 deletions segmentron/data/dataloader/pascal_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __getitem__(self, index):
img, target = self._sync_transform(img, target)
elif self.mode == 'val':
img, target = self._val_sync_transform(img, target)
elif self.mode == 'testval':
logging.warn("Use mode of testval, you should set batch size=1")
img, target = self._img_transform(img), self._mask_transform(target)
else:
raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode))
# general resize, normalize and toTensor
Expand Down
1 change: 1 addition & 0 deletions segmentron/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .espnetv2 import ESPNetV2
from .enet import ENet
from .edanet import EDANet
from .pointrend import PointRend
10 changes: 9 additions & 1 deletion segmentron/models/backbones/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import torch.utils.model_zoo as model_zoo

from ...utils.download import download
from ...utils.registry import Registry
from ...config import cfg

Expand Down Expand Up @@ -42,7 +43,14 @@ def load_backbone_pretrained(model, backbone):
return
else:
logging.info('load backbone pretrained model from url..')
msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False)
try:
msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False)
except Exception as e:
logging.warning(e)
logging.info('Use torch download failed, try custom method!')

msg = model.load_state_dict(torch.load(download(model_urls[backbone],
path=os.path.join(torch.hub._get_torch_home(), 'checkpoints'))), strict=False)
logging.info(msg)


Expand Down
166 changes: 166 additions & 0 deletions segmentron/models/pointrend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models._utils import IntermediateLayerGetter
from .model_zoo import MODEL_REGISTRY
from .segbase import SegBaseModel
from ..config import cfg


@MODEL_REGISTRY.register(name='PointRend')
class PointRend(SegBaseModel):
def __init__(self):
super(PointRend, self).__init__(need_backbone=False)
model_name = cfg.MODEL.POINTREND.BASEMODEL
self.backbone = MODEL_REGISTRY.get(model_name)()

self.head = PointHead(num_classes=self.nclass)

def forward(self, x):
c1, _, _, c4 = self.backbone.encoder(x)

out = self.backbone.head(c4, c1)

result = {'res2': c1, 'coarse': out}
result.update(self.head(x, result["res2"], result["coarse"]))
if not self.training:
return (result['fine'],)
return result


class PointHead(nn.Module):
def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
super().__init__()
self.mlp = nn.Conv1d(in_c, num_classes, 1)
self.k = k
self.beta = beta

def forward(self, x, res2, out):
"""
1. Fine-grained features are interpolated from res2 for DeeplabV3
2. During training we sample as many points as there are on a stride 16 feature map of the input
3. To measure prediction uncertainty
we use the same strategy during training and inference: the difference between the most
confident and second most confident class probabilities.
"""
if not self.training:
return self.inference(x, res2, out)

points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)

coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)

feature_representation = torch.cat([coarse, fine], dim=1)

rend = self.mlp(feature_representation)

return {"rend": rend, "points": points}

@torch.no_grad()
def inference(self, x, res2, out):
"""
During inference, subdivision uses N=8096
(i.e., the number of points in the stride 16 map of a 1024×2048 image)
"""
num_points = 8096

while out.shape[-1] != x.shape[-1]:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

points_idx, points = sampling_points(out, num_points, training=self.training)

coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)

feature_representation = torch.cat([coarse, fine], dim=1)

rend = self.mlp(feature_representation)

B, C, H, W = out.shape
points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
out = (out.reshape(B, C, -1)
.scatter_(2, points_idx, rend)
.view(B, C, H, W))

return {"fine": out}


def point_sample(input, point_coords, **kwargs):
"""
From Detectron2, point_features.py#19
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
[0, 1] x [0, 1] square.
Args:
input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
[0, 1] x [0, 1] normalized point coordinates.
Returns:
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
features for points in `point_coords`. The features are obtained via bilinear
interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
"""
add_dim = False
if point_coords.dim() == 3:
add_dim = True
point_coords = point_coords.unsqueeze(2)
output = F.grid_sample(input, 2.0 * point_coords - 1.0)#, **kwargs)
if add_dim:
output = output.squeeze(3)
return output


@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
"""
Follows 3.1. Point Selection for Inference and Training
In Train:, `The sampling strategy selects N points on a feature map to train on.`
In Inference, `then selects the N most uncertain points`
Args:
mask(Tensor): [B, C, H, W]
N(int): `During training we sample as many points as there are on a stride 16 feature map of the input`
k(int): Over generation multiplier
beta(float): ratio of importance points
training(bool): flag
Return:
selected_point(Tensor) : flattened indexing points [B, num_points, 2]
"""
assert mask.dim() == 4, "Dim must be N(Batch)CHW"
device = mask.device
B, _, H, W = mask.shape
mask, _ = mask.sort(1, descending=True)

if not training:
H_step, W_step = 1 / H, 1 / W
N = min(H * W, N)
uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
_, idx = uncertainty_map.view(B, -1).topk(N, dim=1)

points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step
points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
return idx, points

# Official Comment : point_features.py#92
# It is crucial to calculate uncertanty based on the sampled prediction value for the points.
# Calculating uncertainties of the coarse predictions first and sampling them for points leads
# to worse results. To illustrate the difference: a sampled point between two coarse predictions
# with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one
# calculates uncertainties for the coarse predictions first (-1 and -1) and sampe it for the
# center point, they will get -1 unceratinty.

over_generation = torch.rand(B, k * N, 2, device=device)
over_generation_map = point_sample(mask, over_generation, align_corners=False)

uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1])
_, idx = uncertainty_map.topk(int(beta * N), -1)

shift = (k * N) * torch.arange(B, dtype=torch.long, device=device)

idx += shift[:, None]

importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2)
coverage = torch.rand(B, N - int(beta * N), 2, device=device)
return torch.cat([importance, coverage], 1).to(device)
31 changes: 30 additions & 1 deletion segmentron/solver/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from torch.autograd import Variable
from .lovasz_losses import lovasz_softmax
from ..models.pointrend import point_sample
from ..data.dataloader import datasets
from ..config import cfg

Expand Down Expand Up @@ -360,6 +361,32 @@ def forward(self, *inputs):
return dict(loss=self._aux_forward(*inputs))


class PointRendLoss(nn.CrossEntropyLoss):
def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
super(PointRendLoss, self).__init__(ignore_index=ignore_index)
self.aux = aux
self.aux_weight = aux_weight
self.ignore_index = ignore_index

def forward(self, *inputs, **kwargs):
result, gt = tuple(inputs)

pred = F.interpolate(result["coarse"], gt.shape[-2:], mode="bilinear", align_corners=True)
seg_loss = F.cross_entropy(pred, gt, ignore_index=self.ignore_index)

gt_points = point_sample(
gt.float().unsqueeze(1),
result["points"],
mode="nearest",
align_corners=False
).squeeze_(1).long()
points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index)

loss = seg_loss + points_loss

return dict(loss=loss)


def get_segmentation_loss(model, use_ohem=False, **kwargs):
if use_ohem:
return MixSoftmaxCrossEntropyOHEMLoss(**kwargs)
Expand All @@ -373,11 +400,13 @@ def get_segmentation_loss(model, use_ohem=False, **kwargs):
logging.info('Use dice loss!')
return DiceLoss(**kwargs)


model = model.lower()
if model == 'icnet':
return ICNetLoss(**kwargs)
elif model == 'encnet':
return EncNetLoss(**kwargs)
elif model == 'pointrend':
logging.info('Use pointrend loss!')
return PointRendLoss(**kwargs)
else:
return MixSoftmaxCrossEntropyLoss(**kwargs)
5 changes: 1 addition & 4 deletions segmentron/utils/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ def update(self, preds, labels):
"""

def reduce_tensor(tensor):
if isinstance(tensor, torch.Tensor):
rt = tensor.clone()
else:
rt = copy.deepcopy(tensor)
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
return rt

Expand Down
7 changes: 4 additions & 3 deletions tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def __init__(self, args):
# create network
self.model = get_segmentation_model().to(self.device)

if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)
if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'named_modules') and \
cfg.MODEL.BN_EPS_FOR_ENCODER:
logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)

if args.distributed:
self.model = nn.parallel.DistributedDataParallel(self.model,
Expand Down