forked from amiltonwong/erfnet_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
282 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Code to produce segmentation output in Pytorch for all cityscapes subset | ||
# Sept 2017 | ||
# Eduardo Romera | ||
####################### | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
import os | ||
import importlib | ||
import time | ||
|
||
from PIL import Image | ||
from argparse import ArgumentParser | ||
|
||
from torch.autograd import Variable | ||
from torch.utils.data import DataLoader | ||
from torchvision.transforms import Compose, CenterCrop, Normalize, Scale | ||
from torchvision.transforms import ToTensor, ToPILImage | ||
|
||
from dataset import cityscapes | ||
from erfnet import ERFNet | ||
from transform import Relabel, ToLabel, Colorize | ||
from iouEval import iouEval, getColorEntry | ||
|
||
NUM_CHANNELS = 3 | ||
NUM_CLASSES = 20 | ||
|
||
image_transform = ToPILImage() | ||
input_transform_cityscapes = Compose([ | ||
Scale(512, Image.BILINEAR), | ||
ToTensor(), | ||
]) | ||
target_transform_cityscapes = Compose([ | ||
Scale(512, Image.NEAREST), | ||
ToLabel(), | ||
Relabel(255, 19), #ignore label to 19 | ||
]) | ||
|
||
cityscapes_trainIds2labelIds = Compose([ | ||
Relabel(19, 255), | ||
Relabel(18, 33), | ||
Relabel(17, 32), | ||
Relabel(16, 31), | ||
Relabel(15, 28), | ||
Relabel(14, 27), | ||
Relabel(13, 26), | ||
Relabel(12, 25), | ||
Relabel(11, 24), | ||
Relabel(10, 23), | ||
Relabel(9, 22), | ||
Relabel(8, 21), | ||
Relabel(7, 20), | ||
Relabel(6, 19), | ||
Relabel(5, 17), | ||
Relabel(4, 13), | ||
Relabel(3, 12), | ||
Relabel(2, 11), | ||
Relabel(1, 8), | ||
Relabel(0, 7), | ||
Relabel(255, 0), | ||
ToPILImage(), | ||
Scale(1024, Image.NEAREST), | ||
]) | ||
|
||
def main(args): | ||
|
||
modelpath = args.loadDir + args.loadModel | ||
weightspath = args.loadDir + args.loadWeights | ||
|
||
print ("Loading model: " + modelpath) | ||
print ("Loading weights: " + weightspath) | ||
|
||
model = ERFNet(NUM_CLASSES) | ||
|
||
model = torch.nn.DataParallel(model) | ||
if (not args.cpu): | ||
model = model.cuda() | ||
|
||
def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements | ||
own_state = model.state_dict() | ||
for name, param in state_dict.items(): | ||
if name not in own_state: | ||
print(name, " not loaded") | ||
continue | ||
own_state[name].copy_(param) | ||
return model | ||
|
||
model = load_my_state_dict(model, torch.load(weightspath)) | ||
print ("Model and weights LOADED successfully") | ||
|
||
model.eval() | ||
|
||
if(not os.path.exists(args.datadir)): | ||
print ("Error: datadir could not be loaded") | ||
|
||
|
||
loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) | ||
|
||
|
||
iouEvalVal = iouEval(NUM_CLASSES) | ||
|
||
start = time.time() | ||
|
||
for step, (images, labels, filename, filenameGt) in enumerate(loader): | ||
if (not args.cpu): | ||
images = images.cuda() | ||
labels = labels.cuda() | ||
|
||
inputs = Variable(images, volatile=True) | ||
outputs = model(inputs) | ||
|
||
iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels) | ||
|
||
filenameSave = filename[0].split("leftImg8bit/")[1] | ||
|
||
print (step, filenameSave) | ||
|
||
|
||
iouVal, iou_classes = iouEvalVal.getIoU() | ||
|
||
iou_classes_str = [] | ||
for i in range(iou_classes.size(0)): | ||
iouStr = getColorEntry(iou_classes[i])+'{:0.2f}'.format(iou_classes[i]*100) + '\033[0m' | ||
iou_classes_str.append(iouStr) | ||
|
||
print("---------------------------------------") | ||
print("Took ", time.time()-start, "seconds") | ||
print("=======================================") | ||
#print("TOTAL IOU: ", iou * 100, "%") | ||
print("Per-Class IoU:") | ||
print(iou_classes_str[0], "Road") | ||
print(iou_classes_str[1], "sidewalk") | ||
print(iou_classes_str[2], "building") | ||
print(iou_classes_str[3], "wall") | ||
print(iou_classes_str[4], "fence") | ||
print(iou_classes_str[5], "pole") | ||
print(iou_classes_str[6], "traffic light") | ||
print(iou_classes_str[7], "traffic sign") | ||
print(iou_classes_str[8], "vegetation") | ||
print(iou_classes_str[9], "terrain") | ||
print(iou_classes_str[10], "sky") | ||
print(iou_classes_str[11], "person") | ||
print(iou_classes_str[12], "rider") | ||
print(iou_classes_str[13], "car") | ||
print(iou_classes_str[14], "truck") | ||
print(iou_classes_str[15], "bus") | ||
print(iou_classes_str[16], "train") | ||
print(iou_classes_str[17], "motorcycle") | ||
print(iou_classes_str[18], "bicycle") | ||
print("=======================================") | ||
iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m' | ||
print ("MEAN IoU: ", iouStr, "%") | ||
|
||
if __name__ == '__main__': | ||
parser = ArgumentParser() | ||
|
||
parser.add_argument('--state') | ||
|
||
parser.add_argument('--loadDir',default="../trained_models/") | ||
parser.add_argument('--loadWeights', default="erfnet_pretrained.pth") | ||
parser.add_argument('--loadModel', default="erfnet.py") | ||
parser.add_argument('--subset', default="val") #can be val or train (must have labels) | ||
parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") | ||
parser.add_argument('--num-workers', type=int, default=4) | ||
parser.add_argument('--batch-size', type=int, default=1) | ||
parser.add_argument('--cpu', action='store_true') | ||
|
||
main(parser.parse_args()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Code for evaluating IoU | ||
# Nov 2017 | ||
# Eduardo Romera | ||
####################### | ||
|
||
import torch | ||
|
||
class iouEval: | ||
|
||
def __init__(self, nClasses, ignoreIndex=19): | ||
self.nClasses = nClasses | ||
self.ignoreIndex = ignoreIndex | ||
self.reset() | ||
|
||
def reset (self): | ||
self.tp = torch.zeros(self.nClasses-1).double() | ||
self.fp = torch.zeros(self.nClasses-1).double() | ||
self.fn = torch.zeros(self.nClasses-1).double() | ||
|
||
def addBatch(self, x, y): #x=preds, y=targets | ||
#sizes should be "batch_size x nClasses x H x W" | ||
|
||
#if size is "batch_size x 1 x H x W" scatter to onehot | ||
if (x.size(1) == 1): | ||
x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)).cuda() | ||
x_onehot.scatter_(1, x, 1).float() | ||
else: | ||
x_onehot = x.float() | ||
|
||
if (y.size(1) == 1): | ||
y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)).cuda() | ||
y_onehot.scatter_(1, y, 1).float() | ||
else: | ||
y_onehot = y.float() | ||
#TODO cuda only if x is cuda | ||
|
||
ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1) | ||
x_onehot = x_onehot[:, :self.ignoreIndex] | ||
y_onehot = y_onehot[:, :self.ignoreIndex] | ||
|
||
#print(type(x_onehot)) | ||
#print(type(y_onehot)) | ||
#print(x_onehot.size()) | ||
#print(y_onehot.size()) | ||
|
||
tpmult = x_onehot * y_onehot #times prediction and gt coincide is 1 | ||
tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() | ||
fpmult = x_onehot * (1-y_onehot-ignores) #times prediction says its that class and gt says its not (subtracting cases when its ignore label!) | ||
fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() | ||
fnmult = (1-x_onehot) * (y_onehot) #times prediction says its not that class and gt says it is | ||
fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() | ||
|
||
self.tp += tp.double().cpu() | ||
self.fp += fp.double().cpu() | ||
self.fn += fn.double().cpu() | ||
|
||
def getIoU(self): | ||
num = self.tp | ||
den = self.tp + self.fp + self.fn + 1e-15 | ||
iou = num / den | ||
return torch.mean(iou), iou #returns "iou mean", "iou per class" | ||
|
||
# Class for colors | ||
class colors: | ||
RED = '\033[31;1m' | ||
GREEN = '\033[32;1m' | ||
YELLOW = '\033[33;1m' | ||
BLUE = '\033[34;1m' | ||
MAGENTA = '\033[35;1m' | ||
CYAN = '\033[36;1m' | ||
BOLD = '\033[1m' | ||
UNDERLINE = '\033[4m' | ||
ENDC = '\033[0m' | ||
|
||
# Colored value output if colorized flag is activated. | ||
def getColorEntry(val): | ||
if not isinstance(val, float): | ||
return colors.ENDC | ||
if (val < .20): | ||
return colors.RED | ||
elif (val < .40): | ||
return colors.YELLOW | ||
elif (val < .60): | ||
return colors.BLUE | ||
elif (val < .80): | ||
return colors.CYAN | ||
else: | ||
return colors.GREEN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters