From 5000d7f48d30d19da66ea684d3e8ba7c26e04df1 Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 09:53:03 +0100
Subject: [PATCH 1/8] removed MNIBITEFolder as data source from main.py

---
 main.py | 71 ++++++++++++---------------------------------------------
 1 file changed, 15 insertions(+), 56 deletions(-)

diff --git a/main.py b/main.py
index cf5cae8..5769b09 100644
--- a/main.py
+++ b/main.py
@@ -78,23 +78,10 @@ def update(images):
 def main(args):
     model = network.Simple()
 
-    test_loader = data.DataLoader(dataset.MNIBITEFolder(
-        map(lambda d: os.path.join(args.datadir, d), args.test)),
-        shuffle=True, batch_size=128, num_workers=4)
-    train_loader = data.DataLoader(dataset.MNIBITEFolder(
-        map(lambda d: os.path.join(args.datadir, d), args.train),
-        transform=transforms.Compose([
-            # for some reason not using this gives better performance
-            #transform.RandomZoom(),
-            #transform.RandomRotate(),
-            transform.RandomFlipUpDown(),
-            transform.RandomFlipLeftRight(),
-        ])),
-        shuffle=True, batch_size=128, num_workers=4)
-
-    if args.show_images:
-        image_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
-            int(args.train[0]), transform.RegionCrop()), shuffle=True)
+    test_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
+        int(args.test[0]), transform.RegionCrop()), shuffle=True)
+    train_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
+        int(args.train[0]), transform.RegionCrop()), shuffle=True)
 
     test_losses = []
     train_losses = []
@@ -105,70 +92,43 @@ def main(args):
     if args.show_loss:
         update_loss = loss_plot()
     if args.show_images:
-        update_images = image_plot('training images',
-            ['MRI', 'US', 'RE'])
-    if args.show_patches:
-        update_patches = image_plot('training and testing patches',
-            ['MRI', 'US', 'RE'], rows=2)
+        update_images = image_plot('Training Images', ['MRI', 'US', 'OUT'])
 
     for epoch in range(1, args.epochs+1):
         test_loss = 0
         train_loss = 0
 
-        for step, (mr, us) in enumerate(train_loader):
+        for step, (mr, us) in enumerate(test_loader):
             inputs = autograd.Variable(mr).unsqueeze(1)
             targets = autograd.Variable(us).unsqueeze(1)
             results = model(inputs)
 
-            optimizer.zero_grad()
             loss = criterion(results, targets)
-            loss.backward()
-            optimizer.step()
-
-            train_loss += loss.data[0]
-
-            train_patches = [
-                inputs.data[0][0].numpy(),
-                targets.data[0][0].numpy(),
-                results.data[0][0].numpy(),
-            ]
-        train_losses.append(train_loss)
+            test_loss += loss.data[0]
 
-        for step, (mr, us) in enumerate(test_loader):
+        for step, (mr, us) in enumerate(train_loader):
             inputs = autograd.Variable(mr).unsqueeze(1)
             targets = autograd.Variable(us).unsqueeze(1)
             results = model(inputs)
 
+            optimizer.zero_grad()
             loss = criterion(results, targets)
-            test_loss += loss.data[0]
+            loss.backward()
+            optimizer.step()
+
+            train_loss += loss.data[0]
 
-            test_patches = [
-                inputs.data[0][0].numpy(),
-                targets.data[0][0].numpy(),
-                results.data[0][0].numpy(),
-            ]
         test_losses.append(test_loss)
+        train_losses.append(train_loss)
 
         if args.show_loss:
            update_loss(train_losses, test_losses)
         if args.show_images:
-            for _, (mr, us) in enumerate(image_loader):
-                if np.any(us.numpy()) and sum(us.numpy().shape[1:3]) > 30:
-                    inputs = autograd.Variable(mr).unsqueeze(1)
-                    targets = autograd.Variable(us).unsqueeze(1)
-                    results = model(inputs)
-                    break
-
             update_images([
                 inputs.data[0][0].numpy(),
                 targets.data[0][0].numpy(),
                 results.data[0][0].numpy(),
             ])
-        if args.show_patches:
-            update_patches([
-                *train_patches,
-                *test_patches,
-            ])
 
         print(f'testing (epoch: {epoch}, loss: {test_loss}')
        print(f'training (epoch: {epoch}, loss: {train_loss})')
@@ -183,6 +143,5 @@ def main(args):
     parser.add_argument('--datadir', type=str, nargs='?', default='mnibite')
     parser.add_argument('--show-loss', dest='show_loss', action='store_true')
     parser.add_argument('--show-images', dest='show_images', action='store_true')
-    parser.add_argument('--show-patches', dest='show_patches', action='store_true')
-    parser.set_defaults(show_loss=False, show_images=False, show_patches=False)
+    parser.set_defaults(show_loss=False, show_images=False)
     main(parser.parse_args())
\ No newline at end of file
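Note: the per-epoch evaluate-then-train loop that main.py converges on in this patch can be exercised in isolation. The sketch below is illustrative only — random tensors stand in for the MNIBITE slices, a single convolution stands in for network.Simple, and it uses current PyTorch idioms (loss.item() instead of the loss.data[0] of the 0.1.x API used throughout this series):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # stand-in for the MR/US slice pairs returned by MNIBITENative
    # (assumption: 16 slices of 64x64; real volumes are much larger)
    mr = torch.rand(16, 64, 64)
    us = torch.rand(16, 64, 64)
    train_loader = DataLoader(TensorDataset(mr, us), shuffle=True)

    model = nn.Conv2d(1, 1, 3, padding=1)        # placeholder for network.Simple()
    criterion = nn.MSELoss(reduction='sum')      # size_average=False in the 2017 API
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(1, 3):
        train_loss = 0
        for mr_slice, us_slice in train_loader:
            inputs = mr_slice.unsqueeze(1)       # add channel dim: (N, 1, H, W)
            targets = us_slice.unsqueeze(1)
            results = model(inputs)

            optimizer.zero_grad()
            loss = criterion(results, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()            # loss.data[0] in the patch
        print(f'training (epoch: {epoch}, loss: {train_loss})')
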
From 6b1d0b5a55b9306f9f44051624b8c9532f5c28ed Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 09:59:05 +0100
Subject: [PATCH 2/8] updated scripts/patch.py to respect axis

---
 mrtous/dataset.py |  1 +
 scripts/patch.py  | 18 ++++++++----------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/mrtous/dataset.py b/mrtous/dataset.py
index c4de03e..eaee4a0 100644
--- a/mrtous/dataset.py
+++ b/mrtous/dataset.py
@@ -40,6 +40,7 @@ class MNIBITENative(Dataset):
     def __init__(self, root, id, transform=None, axis='z'):
         self.mr = MINC2(os.path.join(root, f'{id:02d}_mr.mnc'), axis)
         self.us = MINC2(os.path.join(root, f'{id:02d}_us.mnc'), axis)
+        self.axis = axis
         self.transform = transform
 
     def __getitem__(self, index):
diff --git a/scripts/patch.py b/scripts/patch.py
index f954aaf..cf440d1 100644
--- a/scripts/patch.py
+++ b/scripts/patch.py
@@ -4,8 +4,6 @@
 import numpy as np
 import skimage as sk
 
-sys.path.append('..')
-
 from mrtous import dataset
 from skimage import io, util, exposure
 
@@ -27,6 +25,8 @@ def main(args):
     os.makedirs(targetdir, exist_ok=True)
 
     for mnibite in mnibites:
+        axis = mnibite.axis
+
         for _, (mr_image, us_image) in enumerate(mnibite):
             mr_patches = image_to_patches(mr_image, args.targetsize)
             us_patches = image_to_patches(us_image, args.targetsize)
@@ -41,19 +41,17 @@ def main(args):
                     exposure.rescale_intensity(us_patches[index],
                         out_range='float'))
 
-                io.imsave(os.path.join(targetdir, f'{index}_mr.png'),
+                io.imsave(os.path.join(targetdir, f'{index}_{axis}_mr.png'),
                     mr_patch, plugin='freeimage')
-                io.imsave(os.path.join(targetdir, f'{index}_us.png'),
+                io.imsave(os.path.join(targetdir, f'{index}_{axis}_us.png'),
                     us_patch, plugin='freeimage')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--dataset', type=int, required=True)
-    parser.add_argument('--datadir', type=str, nargs='?')
-    parser.add_argument('--threshold', type=float, nargs='?')
-    parser.add_argument('--targetdir', type=str, nargs='?')
-    parser.add_argument('--targetsize', type=int, nargs='?')
-    parser.set_defaults(datadir='mnibite', targetdir='mnibite',
-        targetsize=25, threshold=.1)
+    parser.add_argument('--datadir', type=str, nargs='?', default='mnibite')
+    parser.add_argument('--targetdir', type=str, nargs='?', default='.')
+    parser.add_argument('--targetsize', type=int, nargs='?', default=25)
+    parser.add_argument('--threshold', type=float, nargs='?', default=.5)
     main(parser.parse_args())
\ No newline at end of file
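Note: scripts/patch.py cuts each slice into overlapping square patches and now tags the saved files with the slicing axis. A minimal sketch of that step, assuming a random array in place of a real MR slice (the .png suffix is what this patch uses; later in the series it switches to .tif):

    import numpy as np
    from skimage import util

    def image_to_patches(image, size, overlap=.3):
        # overlapping square windows with a stride of ~30% of the patch size,
        # mirroring scripts/patch.py
        stride = int(np.ceil(overlap * size))
        patches = util.view_as_windows(image, size, stride)
        return np.reshape(patches, [-1, size, size])

    image = np.random.rand(100, 120)      # stand-in for one MR slice
    patches = image_to_patches(image, 25)
    print(patches.shape)                  # (number_of_patches, 25, 25)

    # axis-tagged file names as introduced by this patch
    axis = 'z'
    for index in range(3):
        print(f'{index}_{axis}_mr.png')
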
From cb934f665cddf183696fde8b52e99a40153e861f Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 10:11:21 +0100
Subject: [PATCH 3/8] removed Simple from network, updated weight init

---
 main.py           |  3 ++-
 mrtous/network.py | 24 +++++------------------
 2 files changed, 7 insertions(+), 20 deletions(-)

diff --git a/main.py b/main.py
index 5769b09..f4781ff 100644
--- a/main.py
+++ b/main.py
@@ -76,7 +76,8 @@ def update(images):
     return update
 
 def main(args):
-    model = network.Simple()
+    model = network.Basic()
+    model.apply(network.normal_init)
 
     test_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
         int(args.test[0]), transform.RegionCrop()), shuffle=True)
diff --git a/mrtous/network.py b/mrtous/network.py
index fd5535f..baa2354 100644
--- a/mrtous/network.py
+++ b/mrtous/network.py
@@ -1,8 +1,11 @@
 import torch
 import torch.nn as nn
 
-def init_weight(layer):
-    layer.weight.data.normal_(0.5, 0.2)
+def normal_init(model):
+    classname = model.__class__.__name__
+
+    if classname.find('Conv') != -1:
+        model.weight.data.normal_(0.5, 0.3)
 
 class Basic(nn.Module):
 
@@ -12,24 +15,7 @@ def __init__(self):
         self.conv = nn.Conv2d(1, 3, 3, padding=1)
         self.conn = nn.Conv2d(3, 1, 1)
 
-        init_weight(self.conv)
-
     def forward(self, x):
         x = self.conv(x)
         x = self.conn(x)
-        return x
-
-class Simple(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
-        self.conv2 = nn.Conv2d(16, 64, 3, padding=1)
-        self.final = nn.Conv2d(64, 1, 1)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.final(x)
         return x
\ No newline at end of file
From 9d56cc7741cb43bdeb7da82458b7746a3e8304ce Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 10:24:39 +0100
Subject: [PATCH 4/8] merged RandomFlipLeftRight and RandomFlipUpDown

---
 mrtous/dataset.py   | 2 +-
 mrtous/transform.py | 9 ++-------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/mrtous/dataset.py b/mrtous/dataset.py
index eaee4a0..aadc4c1 100644
--- a/mrtous/dataset.py
+++ b/mrtous/dataset.py
@@ -38,9 +38,9 @@ def __getitem__(self, index):
 class MNIBITENative(Dataset):
 
     def __init__(self, root, id, transform=None, axis='z'):
+        self.axis = axis
         self.mr = MINC2(os.path.join(root, f'{id:02d}_mr.mnc'), axis)
         self.us = MINC2(os.path.join(root, f'{id:02d}_us.mnc'), axis)
-        self.axis = axis
         self.transform = transform
 
     def __getitem__(self, index):
diff --git a/mrtous/transform.py b/mrtous/transform.py
index d8cb213..7900b1e 100644
--- a/mrtous/transform.py
+++ b/mrtous/transform.py
@@ -22,17 +22,12 @@ def __call__(self, mr, us):
 
         return mr, us
 
-class RandomFlipUpDown(object):
+class RandomFlip(object):
 
     def __call__(self, image):
         if random.random() > .5:
             image = np.flipud(image)
-        return image
-
-class RandomFlipLeftRight(object):
-
-    def __call__(self, image):
-        if random.random() > .5:
+        if random.random() > .5
             image = np.fliplr(image)
         return image
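Note: the two pieces introduced in patches 3 and 4 — classname-based weight initialization applied through Module.apply, and the merged RandomFlip transform — can be condensed into one runnable sketch. The Sequential model is a stand-in for network.Basic, and the colon missing after the second random.random() > .5 in patch 4 (the repository fixes it in patch 5) is written correctly here:

    import random
    import numpy as np
    import torch.nn as nn

    def normal_init(model):
        # weight init keyed on the layer's class name, as in mrtous/network.py
        classname = model.__class__.__name__
        if classname.find('Conv') != -1:
            model.weight.data.normal_(0.5, 0.3)

    class RandomFlip(object):
        # merged up/down and left/right flips from mrtous/transform.py
        def __call__(self, image):
            if random.random() > .5:
                image = np.flipud(image)
            if random.random() > .5:
                image = np.fliplr(image)
            return image

    net = nn.Sequential(nn.Conv2d(1, 3, 3, padding=1), nn.Conv2d(3, 1, 1))
    net.apply(normal_init)                      # visits every submodule
    flipped = RandomFlip()(np.arange(16.).reshape(4, 4))
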
From b8a959f607ec019b3e5fbf54246af8d30ea71b0e Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 11:44:49 +0100
Subject: [PATCH 5/8] updated dataset, transform and main.py to use float64

---
 main.py             |  8 +++---
 mrtous/dataset.py   | 63 +++++++++++++++++++++++++++------------------
 mrtous/transform.py | 31 +++++++++++++++++++---
 3 files changed, 70 insertions(+), 32 deletions(-)

diff --git a/main.py b/main.py
index f4781ff..8f0b3fe 100644
--- a/main.py
+++ b/main.py
@@ -2,10 +2,9 @@
 import argparse
 import numpy as np
 
-from mrtous import dataset, transform, network
+from mrtous import dataset, network
 from torch import nn, optim, autograd
 from torch.utils import data
-from torchvision import transforms
 from matplotlib import pyplot as plt
 from mpl_toolkits import axes_grid1
 
@@ -78,11 +77,12 @@ def update(images):
 def main(args):
     model = network.Basic()
     model.apply(network.normal_init)
+    model.double()
 
     test_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
-        int(args.test[0]), transform.RegionCrop()), shuffle=True)
+        int(args.test[0])), shuffle=True)
     train_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
-        int(args.train[0]), transform.RegionCrop()), shuffle=True)
+        int(args.train[0])), shuffle=True)
 
     test_losses = []
     train_losses = []
diff --git a/mrtous/dataset.py b/mrtous/dataset.py
index aadc4c1..9ecb006 100644
--- a/mrtous/dataset.py
+++ b/mrtous/dataset.py
@@ -1,15 +1,15 @@
 import os
 import h5py
 import numpy as np
-import skimage as sk
-from skimage import io, util
-from torch.utils.data import Dataset
+import skimage
+import skimage.io
 
-def normalize(value, vrange):
-    return (np.array(value, np.float32)-np.min(vrange)) / np.sum(np.abs(vrange))
+from mrtous import transform
+from torch.utils import data
+from torchvision import transforms
 
-class MINC2(Dataset):
+class MINC2(data.Dataset):
 
     AXES = ['x', 'y', 'z']
 
@@ -22,7 +22,7 @@ def __init__(self, filename, axis='z'):
             self.volume = f['minc-2.0/image/0/image']
             self.vrange = f['minc-2.0/image/0/image'].attrs['valid_range']
             self.length = f['minc-2.0/dimensions/'+axis+'space'].attrs['length']
-            self.volume = normalize(self.volume, self.vrange)
+            self.volume = np.array(self.volume, np.float64)
 
     def __len__(self):
         return self.length
@@ -31,31 +31,44 @@ def __getitem__(self, index):
         if self.axis == self.AXES[2]:
             return self.volume[index]
         if self.axis == self.AXES[1]:
-            return np.flipud(self.volume[:, index])
+            return self.volume[:, index]
         if self.axis == self.AXES[0]:
-            return np.flipud(self.volume[:, :, index])
+            return self.volume[:, :, index]
 
-class MNIBITENative(Dataset):
+class MNIBITENative(data.Dataset):
 
-    def __init__(self, root, id, transform=None, axis='z'):
-        self.axis = axis
+    def __init__(self, root, id, axis='z'):
         self.mr = MINC2(os.path.join(root, f'{id:02d}_mr.mnc'), axis)
         self.us = MINC2(os.path.join(root, f'{id:02d}_us.mnc'), axis)
+        assert len(self.mr) == len(self.us)
+
+        self.axis = axis
+
+        self.input_transform = transforms.Compose([
+            transform.Normalize(self.mr.vrange),
+            transform.CenterCrop(300),
+        ])
+        self.target_transform = transforms.Compose([
+            transform.Normalize(self.us.vrange),
+            transform.CenterCrop(300),
+        ])
 
     def __getitem__(self, index):
-        mr = self.mr[index]
-        us = self.us[index]
-        if self.transform is not None:
-            mr, us = self.transform(mr, us)
+        mr, us = self.mr[index], self.us[index]
+
+        if self.input_transform is not None:
+            mr = self.input_transform(mr)
+        if self.target_transform is not None:
+            us = self.target_transform(us)
+
         return mr, us
 
     def __len__(self):
         return len(self.mr)
 
-class MNIBITEFolder(Dataset):
+class MNIBITEFolder(data.Dataset):
 
-    def __init__(self, root, transform=None, target_transform=None):
+    def __init__(self, root, input_transform=None, target_transform=None):
         if type(root) is str:
             root = [root]
 
@@ -72,19 +85,19 @@ def __init__(self, root, transform=None, target_transform=None):
 
         assert(len(self.mr_fnames) == len(self.us_fnames))
 
-        self.transform = transform
+        self.input_transform = input_transform
         self.target_transform = target_transform
 
     def __getitem__(self, index):
-        mr = sk.img_as_float(io.imread(self.mr_fnames[index]))
-        us = sk.img_as_float(io.imread(self.us_fnames[index]))
+        mr = skimage.img_as_float(skimage.io.imread(self.mr_fnames[index]))
+        us = skimage.img_as_float(skimage.io.imread(self.us_fnames[index]))
 
-        if self.transform is not None:
-            mr = self.transform(mr)
+        if self.input_transform is not None:
+            mr = self.input_transform(mr)
         if self.target_transform is not None:
             us = self.target_transform(us)
 
-        return mr.astype(np.float32), us.astype(np.float32)
+        return mr.astype(np.float64), us.astype(np.float64)
 
     def __len__(self):
         return len(self.mr_fnames)
\ No newline at end of file
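Note: MINC2 indexes the HDF5 image volume along a configurable anatomical axis. A sketch of that slicing with a plain NumPy array standing in for the 'minc-2.0/image/0/image' dataset (the shape is made up; real MNI-BITE volumes are larger):

    import numpy as np

    # stand-in for the HDF5 dataset at 'minc-2.0/image/0/image'
    volume = np.random.rand(30, 40, 50)   # (z, y, x) ordering assumed

    def get_slice(volume, index, axis='z'):
        # mirrors MINC2.__getitem__: index along the chosen axis
        if axis == 'z':
            return volume[index]
        if axis == 'y':
            return volume[:, index]
        if axis == 'x':
            return volume[:, :, index]
        raise ValueError("axis must be 'x', 'y' or 'z'")

    print(get_slice(volume, 0, 'z').shape)   # (40, 50)
    print(get_slice(volume, 0, 'y').shape)   # (30, 50)
    print(get_slice(volume, 0, 'x').shape)   # (30, 40)
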
diff --git a/mrtous/transform.py b/mrtous/transform.py
index 7900b1e..e173852 100644
--- a/mrtous/transform.py
+++ b/mrtous/transform.py
@@ -2,13 +2,38 @@
 import numpy as np
 import scipy as sp
 
-from skimage import filters, transform
+import skimage.filters
+import skimage.transform
+
+class Normalize(object):
+
+    def __init__(self, vrange):
+        self.vrange = vrange
+
+    def __call__(self, image):
+        image -= np.min(self.vrange)
+        image /= np.sum(np.abs(self.vrange))
+        return image
+
+class CenterCrop(object):
+
+    def __init__(self, size):
+        self.width = size
+        self.height = size
+
+    def __call__(self, image):
+        xlen, ylen = image.shape
+
+        xoff = xlen // 2 - self.width // 2
+        yoff = ylen // 2 - self.height // 2
+
+        return image[xoff:xoff+self.width, yoff:yoff+self.height]
 
 class RegionCrop(object):
 
     def __call__(self, mr, us):
         if np.any(mr) and np.any(us):
-            mask = us > filters.threshold_otsu(us)
+            mask = us > skimage.filters.threshold_otsu(us)
 
             x = np.where(np.any(mask, 0))[0][[0, -1]]
             y = np.where(np.any(mask, 1))[0][[0, -1]]
@@ -27,7 +52,7 @@ class RandomFlip(object):
 
     def __call__(self, image):
         if random.random() > .5:
             image = np.flipud(image)
-        if random.random() > .5
+        if random.random() > .5:
             image = np.fliplr(image)
         return image
From e3a0da70787b5751ede4d526bf27363b89a631d6 Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 12:02:38 +0100
Subject: [PATCH 6/8] updated dataset and scripts/patch.py to support tiff
 patches

---
 mrtous/dataset.py |  9 ++++-----
 scripts/patch.py  | 25 ++++++++++---------------
 2 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/mrtous/dataset.py b/mrtous/dataset.py
index 9ecb006..e22ddcf 100644
--- a/mrtous/dataset.py
+++ b/mrtous/dataset.py
@@ -78,19 +78,18 @@ def __init__(self, root, input_transform=None, target_transform=None):
         for root in root:
             for fname in os.listdir(root):
                 fname = os.path.join(root, fname)
-                if fname.endswith('_mr.png'):
+                if fname.endswith('_mr.tif'):
                     self.mr_fnames.append(fname)
-                if fname.endswith('_us.png'):
+                if fname.endswith('_us.tif'):
                     self.us_fnames.append(fname)
-
         assert(len(self.mr_fnames) == len(self.us_fnames))
 
         self.input_transform = input_transform
         self.target_transform = target_transform
 
     def __getitem__(self, index):
-        mr = skimage.img_as_float(skimage.io.imread(self.mr_fnames[index]))
-        us = skimage.img_as_float(skimage.io.imread(self.us_fnames[index]))
+        mr = skimage.io.imread(self.mr_fnames[index])
+        us = skimage.io.imread(self.us_fnames[index])
 
         if self.input_transform is not None:
             mr = self.input_transform(mr)
diff --git a/scripts/patch.py b/scripts/patch.py
index cf440d1..9a9a154 100644
--- a/scripts/patch.py
+++ b/scripts/patch.py
@@ -1,15 +1,17 @@
 import os
 import sys
+import math
 import argparse
 import numpy as np
-import skimage as sk
+
+import skimage
+import skimage.io
+import skimage.util
 
 from mrtous import dataset
-from skimage import io, util, exposure
 
 def image_to_patches(image, size):
-    stride = int(np.ceil(.3*size))
-    patches = util.view_as_windows(image, size, stride)
+    patches = skimage.util.view_as_windows(image, size, int(math.ceil(.3*size)))
     return np.reshape(patches, [-1, size, size])
 
 def main(args):
@@ -34,17 +36,10 @@ def main(args):
         indices, = np.where(us_patches.sum((1, 2)) > targetsum)
 
             for index in indices:
-                mr_patch = sk.img_as_uint(
-                    exposure.rescale_intensity(mr_patches[index],
-                        out_range='float'))
-                us_patch = sk.img_as_uint(
-                    exposure.rescale_intensity(us_patches[index],
-                        out_range='float'))
-
-                io.imsave(os.path.join(targetdir, f'{index}_{axis}_mr.png'),
-                    mr_patch, plugin='freeimage')
-                io.imsave(os.path.join(targetdir, f'{index}_{axis}_us.png'),
-                    us_patch, plugin='freeimage')
+                skimage.io.imsave(os.path.join(targetdir,
+                    f'{index}_{axis}_mr.tif'), mr_patches[index])
+                skimage.io.imsave(os.path.join(targetdir,
+                    f'{index}_{axis}_us.tif'), us_patches[index])
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
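Note: the new Normalize and CenterCrop transforms map a raw slice from its MINC2 valid_range into roughly [0, 1] and cut a fixed window from the middle; the float values are what patch 6 then stores losslessly as TIFF. A self-contained sketch with an invented valid_range of (0, 4095) and a 400x400 slice:

    import numpy as np

    class Normalize(object):
        # shift by the range minimum, scale by the range magnitude,
        # as in mrtous/transform.py
        def __init__(self, vrange):
            self.vrange = vrange

        def __call__(self, image):
            image -= np.min(self.vrange)
            image /= np.sum(np.abs(self.vrange))
            return image

    class CenterCrop(object):
        def __init__(self, size):
            self.width = self.height = size

        def __call__(self, image):
            xlen, ylen = image.shape
            xoff = xlen // 2 - self.width // 2
            yoff = ylen // 2 - self.height // 2
            return image[xoff:xoff+self.width, yoff:yoff+self.height]

    slice_ = np.random.uniform(0, 4095, (400, 400))   # made-up slice and range
    out = CenterCrop(300)(Normalize((0, 4095))(slice_))
    print(out.shape, out.min() >= 0, out.max() <= 1)  # (300, 300) True True
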
From 535a6ffbd1137c0afbfd4e33fe91483b703efead Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Fri, 10 Mar 2017 12:44:08 +0100
Subject: [PATCH 7/8] updated main.py to use masked loss

---
 main.py             | 67 +++++++++++++++++++++++++--------------------
 mrtous/transform.py | 18 ------------
 2 files changed, 38 insertions(+), 47 deletions(-)

diff --git a/main.py b/main.py
index 8f0b3fe..4ba994c 100644
--- a/main.py
+++ b/main.py
@@ -2,11 +2,18 @@
 import argparse
 import numpy as np
 
-from mrtous import dataset, network
-from torch import nn, optim, autograd
-from torch.utils import data
-from matplotlib import pyplot as plt
-from mpl_toolkits import axes_grid1
+import torch.nn
+import torch.optim
+import torch.autograd
+import torch.utils.data
+
+import matplotlib.pyplot as plt
+import mpl_toolkits.axes_grid1 as axes_grid
+
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+
+from mrtous import dataset, network, transform
 
 VMIN = 0.0
 VMAX = 1.0
@@ -44,7 +51,7 @@ def image_plot(title, subtitles, rows=1, cols=3):
     fig = plt.figure(figsize=(8, 4))
     fig.suptitle(title)
 
-    grid = axes_grid1.ImageGrid(fig, 111, (rows, cols), axes_pad=.1,
+    grid = axes_grid.ImageGrid(fig, 111, (rows, cols), axes_pad=.1,
         cbar_mode='single', cbar_location='right', label_mode=1)
 
     imgs = []
@@ -74,21 +81,24 @@ def update(images):
 
     return update
 
+def threshold(image):
+    value = np.mean(image) - 2*np.var(image)
+
+    mask = image > value
+    mask = torch.from_numpy(mask.astype(int))
+
+    return Variable(mask).double()
+
 def main(args):
-    model = network.Basic()
+    model = network.Basic().double()
     model.apply(network.normal_init)
-    model.double()
 
-    test_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
-        int(args.test[0])), shuffle=True)
-    train_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
-        int(args.train[0])), shuffle=True)
+    train_loader = DataLoader(dataset.MNIBITENative(args.datadir, 1))
 
     test_losses = []
     train_losses = []
 
-    criterion = nn.MSELoss(size_average=False)
-    optimizer = optim.Adam(model.parameters())
+    optimizer = torch.optim.Adam(model.parameters())
 
     if args.show_loss:
         update_loss = loss_plot()
@@ -99,25 +109,24 @@ def main(args):
         test_loss = 0
         train_loss = 0
 
-        for step, (mr, us) in enumerate(test_loader):
-            inputs = autograd.Variable(mr).unsqueeze(1)
-            targets = autograd.Variable(us).unsqueeze(1)
-            results = model(inputs)
+        loader = DataLoader(dataset.MNIBITENative(
+            args.datadir, 1), shuffle=True)
 
-            loss = criterion(results, targets)
-            test_loss += loss.data[0]
+        for _, (mr, us) in enumerate(loader):
+            if np.any(us.numpy()) and us.sum() > 100:
+                mask = threshold(us.numpy())
 
-        for step, (mr, us) in enumerate(train_loader):
-            inputs = autograd.Variable(mr).unsqueeze(1)
-            targets = autograd.Variable(us).unsqueeze(1)
-            results = model(inputs)
+                inputs = Variable(mr).unsqueeze(1)
+                targets = Variable(us).unsqueeze(1)
+                results = model(inputs)
 
-            optimizer.zero_grad()
-            loss = criterion(results, targets)
-            loss.backward()
-            optimizer.step()
-
-            train_loss += loss.data[0]
+                optimizer.zero_grad()
+                loss = results[0].mul(mask).dist(targets[0].mul(mask), 2)
+                loss.div_(mask.sum().data[0])
+                loss.backward()
+                optimizer.step()
+
+                train_loss += loss.data[0]
 
         test_losses.append(test_loss)
         train_losses.append(train_loss)
diff --git a/mrtous/transform.py b/mrtous/transform.py
index e173852..4998fe6 100644
--- a/mrtous/transform.py
+++ b/mrtous/transform.py
@@ -29,24 +29,6 @@ def __call__(self, image):
 
         return image[xoff:xoff+self.width, yoff:yoff+self.height]
 
-class RegionCrop(object):
-
-    def __call__(self, mr, us):
-        if np.any(mr) and np.any(us):
-            mask = us > skimage.filters.threshold_otsu(us)
-
-            x = np.where(np.any(mask, 0))[0][[0, -1]]
-            y = np.where(np.any(mask, 1))[0][[0, -1]]
-
-            if np.abs(np.diff(x)[0]) < 10 or np.abs(np.diff(y)[0]) < 10:
-                # "mark" samples which are too small to be filtered
-                return np.zeros_like(us), np.zeros_like(mr)
-
-            mr = mr[y[0]:y[1], x[0]:x[1]]
-            us = us[y[0]:y[1], x[0]:x[1]]
-
-        return mr, us
-
 class RandomFlip(object):
 
     def __call__(self, image):
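Note: the masked loss introduced here restricts the L2 distance to ultrasound pixels brighter than mean - 2*variance and divides by the mask size. The sketch below reproduces that computation against current PyTorch (the patch itself uses the pre-0.4 Variable and loss.data[0] API), with random tensors in place of the MR/US slices and a doubling op in place of the network:

    import numpy as np
    import torch

    def threshold_mask(image):
        # "bright enough" ultrasound pixels, as in threshold() above
        value = np.mean(image) - 2 * np.var(image)
        return torch.from_numpy((image > value).astype(np.float64))

    us = np.random.rand(300, 300)              # stand-in for an ultrasound slice
    mr = torch.rand(1, 1, 300, 300, dtype=torch.float64, requires_grad=True)
    target = torch.from_numpy(us).unsqueeze(0).unsqueeze(0)

    mask = threshold_mask(us)
    result = mr * 2.0                          # stand-in for model(inputs)

    # L2 distance restricted to the masked region, normalized by the mask size
    loss = (result[0] * mask).dist(target[0] * mask, 2) / mask.sum()
    loss.backward()
    print(float(loss))
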
From a33e1749931bfcd8c681b54f212512bd2c23a463 Mon Sep 17 00:00:00 2001
From: Bodo Kaiser
Date: Tue, 14 Mar 2017 09:31:39 +0100
Subject: [PATCH 8/8] rewrote datasets to fix transform issue

---
 main.py           | 46 +++++++++++-----------
 mrtous/dataset.py | 98 ++++++++++++++++++++++++-----------------------
 mrtous/network.py |  2 +
 3 files changed, 76 insertions(+), 70 deletions(-)

diff --git a/main.py b/main.py
index 4ba994c..42070b7 100644
--- a/main.py
+++ b/main.py
@@ -1,19 +1,18 @@
-import os
 import argparse
 import numpy as np
-
+import os
+import torch
 import torch.nn
-import torch.optim
-import torch.autograd
-import torch.utils.data
-
-import matplotlib.pyplot as plt
-import mpl_toolkits.axes_grid1 as axes_grid
 
+from torch.optim import Adam
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 
-from mrtous import dataset, network, transform
+from mrtous.network import Basic
+from mrtous.dataset import MnibiteNative
+
+from matplotlib import pyplot as plt
+from mpl_toolkits.axes_grid1 import ImageGrid
 
 VMIN = 0.0
 VMAX = 1.0
@@ -50,7 +49,7 @@ def image_plot(title, subtitles, rows=1, cols=3):
     fig = plt.figure(figsize=(8, 4))
     fig.suptitle(title)
 
-    grid = axes_grid.ImageGrid(fig, 111, (rows, cols), axes_pad=.1,
+    grid = ImageGrid(fig, 111, (rows, cols), axes_pad=.1,
         cbar_mode='single', cbar_location='right', label_mode=1)
 
     imgs = []
@@ -90,15 +89,19 @@ def threshold(image):
     return Variable(mask).double()
 
 def main(args):
-    model = network.Basic().double()
-    model.apply(network.normal_init)
+    model = Basic().double()
+
+    dataset = MnibiteNative(args.datadir, int(args.train))
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
 
-    train_loader = DataLoader(dataset.MNIBITENative(args.datadir, 1))
+    mr, us = dataset[120]
+    fixed_inputs = Variable(torch.from_numpy(mr)).unsqueeze(0).unsqueeze(0)
+    fixed_targets = Variable(torch.from_numpy(us)).unsqueeze(0).unsqueeze(0)
 
     test_losses = []
     train_losses = []
 
-    optimizer = torch.optim.Adam(model.parameters())
+    optimizer = Adam(model.parameters())
 
     if args.show_loss:
         update_loss = loss_plot()
@@ -109,10 +112,7 @@ def main(args):
         test_loss = 0
         train_loss = 0
 
-        loader = DataLoader(dataset.MNIBITENative(
-            args.datadir, 1), shuffle=True)
-
-        for _, (mr, us) in enumerate(loader):
+        for mr, us in dataloader:
             if np.any(us.numpy()) and us.sum() > 100:
                 mask = threshold(us.numpy())
 
@@ -135,9 +135,9 @@ def main(args):
             update_loss(train_losses, test_losses)
         if args.show_images:
             update_images([
-                inputs.data[0][0].numpy(),
-                targets.data[0][0].numpy(),
-                results.data[0][0].numpy(),
+                fixed_inputs.data[0][0].numpy(),
+                fixed_targets.data[0][0].numpy(),
+                model(fixed_inputs).data[0][0].numpy(),
             ])
 
         print(f'testing (epoch: {epoch}, loss: {test_loss}')
@@ -147,8 +147,8 @@ def main(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--test', type=str, nargs='+', default=['11'])
-    parser.add_argument('--train', type=str, nargs='+', default=['13'])
+    parser.add_argument('--test', type=int, nargs='?', default=11)
+    parser.add_argument('--train', type=int, nargs='?', default=13)
     parser.add_argument('--epochs', type=int, nargs='?', default=20)
     parser.add_argument('--datadir', type=str, nargs='?', default='mnibite')
     parser.add_argument('--show-loss', dest='show_loss', action='store_true')
diff --git a/mrtous/dataset.py b/mrtous/dataset.py
index e22ddcf..de04737 100644
--- a/mrtous/dataset.py
+++ b/mrtous/dataset.py
@@ -1,56 +1,64 @@
-import os
 import h5py
 import numpy as np
+import os
 import skimage
 import skimage.io
 
-from mrtous import transform
-from torch.utils import data
-from torchvision import transforms
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from mrtous.transform import Normalize, CenterCrop
 
-class MINC2(data.Dataset):
+class Minc2Z(Dataset):
 
-    AXES = ['x', 'y', 'z']
+    def __init__(self, filename):
+        self.hdf = h5py.File(filename, 'r')
 
-    def __init__(self, filename, axis='z'):
-        if not axis in self.AXES:
-            raise ValueError('axis must be "x", "y" or "z"')
-        self.axis = axis
+    @property
+    def volume(self):
+        return self.hdf['minc-2.0/image/0/image']
 
-        with h5py.File(filename, 'r') as f:
-            self.volume = f['minc-2.0/image/0/image']
-            self.vrange = f['minc-2.0/image/0/image'].attrs['valid_range']
-            self.length = f['minc-2.0/dimensions/'+axis+'space'].attrs['length']
-            self.volume = np.array(self.volume, np.float64)
+    @property
+    def vrange(self):
+        return self.volume.attrs['valid_range']
+
+    def __getitem__(self, index):
+        return self.volume[index].astype(np.float64)
 
     def __len__(self):
-        return self.length
+        return self.volume.shape[0]
+
+class Minc2Y(Minc2Z):
 
     def __getitem__(self, index):
-        if self.axis == self.AXES[2]:
-            return self.volume[index]
-        if self.axis == self.AXES[1]:
-            return self.volume[:, index]
-        if self.axis == self.AXES[0]:
-            return self.volume[:, :, index]
+        return self.volume[:, index].astype(np.float64)
+
+    def __len__(self):
+        return self.volume.shape[1]
 
-class MNIBITENative(data.Dataset):
+class Minc2X(Minc2Z):
 
-    def __init__(self, root, id, axis='z'):
-        self.mr = MINC2(os.path.join(root, f'{id:02d}_mr.mnc'), axis)
-        self.us = MINC2(os.path.join(root, f'{id:02d}_us.mnc'), axis)
-        assert len(self.mr) == len(self.us)
+    def __getitem__(self, index):
+        return self.volume[:, :, index].astype(np.float64)
+
+    def __len__(self):
+        return self.volume.shape[2]
+
+class MnibiteNative(Dataset):
+
+    def __init__(self, root, id):
+        self.mr = Minc2Z(os.path.join(root, f'{id:02d}_mr.mnc'))
+        self.us = Minc2Z(os.path.join(root, f'{id:02d}_us.mnc'))
+        assert len(self.mr) == len(self.us)
 
-        self.axis = axis
-
-        self.input_transform = transforms.Compose([
-            transform.Normalize(self.mr.vrange),
-            transform.CenterCrop(300),
+        self.input_transform = Compose([
+            Normalize(self.mr.vrange),
+            CenterCrop(300),
         ])
-        self.target_transform = transforms.Compose([
-            transform.Normalize(self.us.vrange),
-            transform.CenterCrop(300),
+        self.target_transform = Compose([
+            Normalize(self.us.vrange),
+            CenterCrop(300),
         ])
 
     def __getitem__(self, index):
@@ -66,22 +74,18 @@ def __getitem__(self, index):
 
         return mr, us
 
     def __len__(self):
         return len(self.mr)
 
-class MNIBITEFolder(data.Dataset):
+class MnibiteFolder(Dataset):
 
     def __init__(self, root, input_transform=None, target_transform=None):
-        if type(root) is str:
-            root = [root]
-
         self.mr_fnames = []
         self.us_fnames = []
 
-        for root in root:
-            for fname in os.listdir(root):
-                fname = os.path.join(root, fname)
-                if fname.endswith('_mr.tif'):
-                    self.mr_fnames.append(fname)
-                if fname.endswith('_us.tif'):
-                    self.us_fnames.append(fname)
+        for fname in os.listdir(root):
+            fname = os.path.join(root, fname)
+            if fname.endswith('_mr.tif'):
+                self.mr_fnames.append(fname)
+            if fname.endswith('_us.tif'):
+                self.us_fnames.append(fname)
         assert(len(self.mr_fnames) == len(self.us_fnames))
 
         self.input_transform = input_transform
         self.target_transform = target_transform
diff --git a/mrtous/network.py b/mrtous/network.py
index baa2354..55adb15 100644
--- a/mrtous/network.py
+++ b/mrtous/network.py
@@ -15,6 +15,8 @@ def __init__(self):
         self.conv = nn.Conv2d(1, 3, 3, padding=1)
         self.conn = nn.Conv2d(3, 1, 1)
 
+        self.apply(normal_init)
+
     def forward(self, x):
         x = self.conv(x)
         x = self.conn(x)
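Note: after this series the Basic module initializes its own convolution weights via self.apply(normal_init), and main.py runs it in double precision on 300x300 center-cropped slices. A short end-to-end shape check under those assumptions (the input tensor is random and only mirrors the CenterCrop(300) output size):

    import torch
    import torch.nn as nn

    def normal_init(model):
        classname = model.__class__.__name__
        if classname.find('Conv') != -1:
            model.weight.data.normal_(0.5, 0.3)

    class Basic(nn.Module):
        # as of patch 8 the module initializes its own weights in __init__
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 3, 3, padding=1)
            self.conn = nn.Conv2d(3, 1, 1)
            self.apply(normal_init)

        def forward(self, x):
            x = self.conv(x)
            x = self.conn(x)
            return x

    model = Basic().double()                 # matches the float64 pipeline
    fixed_inputs = torch.rand(1, 1, 300, 300, dtype=torch.float64)
    print(model(fixed_inputs).shape)         # torch.Size([1, 1, 300, 300])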