Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

elevate project to the next level #9

Closed
wants to merge 8 commits into from
127 changes: 48 additions & 79 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import os
import argparse
import numpy as np
import os
import torch
import torch.nn

from torch.optim import Adam
from torch.autograd import Variable
from torch.utils.data import DataLoader

from mrtous.network import Basic
from mrtous.dataset import MnibiteNative

from mrtous import dataset, transform, network
from torch import nn, optim, autograd
from torch.utils import data
from torchvision import transforms
from matplotlib import pyplot as plt
from mpl_toolkits import axes_grid1
from mpl_toolkits.axes_grid1 import ImageGrid

VMIN = 0.0
VMAX = 1.0
Expand Down Expand Up @@ -45,7 +50,7 @@ def image_plot(title, subtitles, rows=1, cols=3):
fig = plt.figure(figsize=(8, 4))
fig.suptitle(title)

grid = axes_grid1.ImageGrid(fig, 111, (rows, cols), axes_pad=.1,
grid = ImageGrid(fig, 111, (rows, cols), axes_pad=.1,
cbar_mode='single', cbar_location='right', label_mode=1)

imgs = []
Expand Down Expand Up @@ -75,99 +80,64 @@ def update(images):

return update

def threshold(image):
    """Build a binary foreground mask from a numpy image.

    Pixels strictly above ``mean(image) - 2*var(image)`` are kept (1),
    the rest are zeroed. The mask is returned as a double-precision
    autograd ``Variable`` so it can participate in loss computations.

    NOTE(review): the cutoff subtracts twice the *variance*, not the
    standard deviation — confirm this is intended.
    """
    cutoff = np.mean(image) - 2 * np.var(image)
    binary = (image > cutoff).astype(int)
    return Variable(torch.from_numpy(binary)).double()

def main(args):
model = network.Simple()

test_loader = data.DataLoader(dataset.MNIBITEFolder(
map(lambda d: os.path.join(args.datadir, d), args.test)),
shuffle=True, batch_size=128, num_workers=4)
train_loader = data.DataLoader(dataset.MNIBITEFolder(
map(lambda d: os.path.join(args.datadir, d), args.train),
transform=transforms.Compose([
# for some reason not using this gives better performance
#transform.RandomZoom(),
#transform.RandomRotate(),
transform.RandomFlipUpDown(),
transform.RandomFlipLeftRight(),
])),
shuffle=True, batch_size=128, num_workers=4)
model = Basic().double()

if args.show_images:
image_loader = data.DataLoader(dataset.MNIBITENative(args.datadir,
int(args.train[0]), transform.RegionCrop()), shuffle=True)
dataset = MnibiteNative(args.datadir, int(args.train))
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

mr, us = dataset[120]
fixed_inputs = Variable(torch.from_numpy(mr)).unsqueeze(0).unsqueeze(0)
fixed_targets = Variable(torch.from_numpy(us)).unsqueeze(0).unsqueeze(0)

test_losses = []
train_losses = []

criterion = nn.MSELoss(size_average=False)
optimizer = optim.Adam(model.parameters())
optimizer = Adam(model.parameters())

if args.show_loss:
update_loss = loss_plot()
if args.show_images:
update_images = image_plot('training images',
['MRI', 'US', 'RE'])
if args.show_patches:
update_patches = image_plot('training and testing patches',
['MRI', 'US', 'RE'], rows=2)
update_images = image_plot('Training Images', ['MRI', 'US', 'OUT'])

for epoch in range(1, args.epochs+1):
test_loss = 0
train_loss = 0

for step, (mr, us) in enumerate(train_loader):
inputs = autograd.Variable(mr).unsqueeze(1)
targets = autograd.Variable(us).unsqueeze(1)
results = model(inputs)

optimizer.zero_grad()
loss = criterion(results, targets)
loss.backward()
optimizer.step()
for mr, us in dataloader:
if np.any(us.numpy()) and us.sum() > 100:
mask = threshold(us.numpy())

train_loss += loss.data[0]

train_patches = [
inputs.data[0][0].numpy(),
targets.data[0][0].numpy(),
results.data[0][0].numpy(),
]
train_losses.append(train_loss)
inputs = Variable(mr).unsqueeze(1)
targets = Variable(us).unsqueeze(1)
results = model(inputs)

for step, (mr, us) in enumerate(test_loader):
inputs = autograd.Variable(mr).unsqueeze(1)
targets = autograd.Variable(us).unsqueeze(1)
results = model(inputs)
optimizer.zero_grad()
loss = results[0].mul(mask).dist(targets[0].mul(mask), 2)
loss.div_(mask.sum().data[0])
loss.backward()
optimizer.step()

loss = criterion(results, targets)
test_loss += loss.data[0]
train_loss += loss.data[0]

test_patches = [
inputs.data[0][0].numpy(),
targets.data[0][0].numpy(),
results.data[0][0].numpy(),
]
test_losses.append(test_loss)
train_losses.append(train_loss)

if args.show_loss:
update_loss(train_losses, test_losses)
if args.show_images:
for _, (mr, us) in enumerate(image_loader):
if np.any(us.numpy()) and sum(us.numpy().shape[1:3]) > 30:
inputs = autograd.Variable(mr).unsqueeze(1)
targets = autograd.Variable(us).unsqueeze(1)
results = model(inputs)
break

update_images([
inputs.data[0][0].numpy(),
targets.data[0][0].numpy(),
results.data[0][0].numpy(),
])
if args.show_patches:
update_patches([
*train_patches,
*test_patches,
fixed_inputs.data[0][0].numpy(),
fixed_targets.data[0][0].numpy(),
model(fixed_inputs).data[0][0].numpy(),
])

print(f'testing (epoch: {epoch}, loss: {test_loss}')
Expand All @@ -177,12 +147,11 @@ def main(args):

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Record ids of the MNIBITE scans used for testing and training.
    # (Duplicate --test/--train definitions left over from the old version
    # would raise argparse.ArgumentError; only the int-typed ones survive.)
    parser.add_argument('--test', type=int, nargs='?', default=11)
    parser.add_argument('--train', type=int, nargs='?', default=13)
    parser.add_argument('--epochs', type=int, nargs='?', default=20)
    parser.add_argument('--datadir', type=str, nargs='?', default='mnibite')
    parser.add_argument('--show-loss', dest='show_loss', action='store_true')
    parser.add_argument('--show-images', dest='show_images', action='store_true')
    parser.set_defaults(show_loss=False, show_images=False)
    main(parser.parse_args())
119 changes: 68 additions & 51 deletions mrtous/dataset.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,106 @@
import os
import h5py
import numpy as np
import skimage as sk
import os

import skimage
import skimage.io

from skimage import io, util
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from mrtous.transform import Normalize, CenterCrop

class Minc2Z(Dataset):
    """Dataset of 2D slices taken along the z axis of a MINC2 volume.

    Keeps the HDF5 file handle open for the lifetime of the dataset and
    reads slices lazily on indexing.
    """

    def __init__(self, filename):
        # MINC2 files are HDF5 containers; keep the handle open for lazy reads.
        self.hdf = h5py.File(filename, 'r')

    @property
    def volume(self):
        # The raw 3D image dataset inside the MINC2 hierarchy.
        return self.hdf['minc-2.0/image/0/image']

    @property
    def vrange(self):
        # Valid intensity range stored as an HDF5 attribute (used for normalization).
        return self.volume.attrs['valid_range']

    def __getitem__(self, index):
        # Slice along the first (z) dimension, converted to float64.
        return self.volume[index].astype(np.float64)

    def __len__(self):
        return self.volume.shape[0]
class Minc2Y(Minc2Z):
    """Minc2 volume sliced along the y axis (second dimension)."""

    def __getitem__(self, index):
        # Slice along the second (y) dimension, converted to float64.
        return self.volume[:, index].astype(np.float64)

    def __len__(self):
        # Unreachable leftover `return self.length` from the old version removed.
        return self.volume.shape[1]

class Minc2X(Minc2Z):
    """Minc2 volume sliced along the x axis (third dimension)."""

    def __getitem__(self, index):
        # Slice along the third (x) dimension, converted to float64.
        return self.volume[:, :, index].astype(np.float64)

    def __len__(self):
        return self.volume.shape[2]

class MNIBITENative(Dataset):
class MnibiteNative(Dataset):
    """Paired MR/US slices read from two native MINC2 volumes.

    Both volumes are sliced along the z axis; each sample is normalized
    by its volume's valid range and center-cropped to 300x300.
    """

    def __init__(self, root, id):
        # File layout: e.g. root/13_mr.mnc and root/13_us.mnc for id 13.
        self.mr = Minc2Z(os.path.join(root, f'{id:02d}_mr.mnc'))
        self.us = Minc2Z(os.path.join(root, f'{id:02d}_us.mnc'))
        # The two volumes must contain the same number of slices.
        assert len(self.mr) == len(self.us)

        self.input_transform = Compose([
            Normalize(self.mr.vrange),
            CenterCrop(300),
        ])
        self.target_transform = Compose([
            Normalize(self.us.vrange),
            CenterCrop(300),
        ])

    def __getitem__(self, index):
        mr, us = self.mr[index], self.us[index]

        if self.input_transform is not None:
            mr = self.input_transform(mr)
        if self.target_transform is not None:
            us = self.target_transform(us)

        return mr, us

    def __len__(self):
        return len(self.mr)

class MnibiteFolder(Dataset):
    """Dataset of paired MR/US ``.tif`` slices stored in one folder.

    Files are paired positionally: the i-th ``*_mr.tif`` is matched with
    the i-th ``*_us.tif``.
    """

    def __init__(self, root, input_transform=None, target_transform=None):
        self.mr_fnames = []
        self.us_fnames = []

        # os.listdir order is arbitrary; sort so the mr/us lists align by
        # their shared filename prefix (e.g. 0001_mr.tif <-> 0001_us.tif).
        for fname in sorted(os.listdir(root)):
            fname = os.path.join(root, fname)
            if fname.endswith('_mr.tif'):
                self.mr_fnames.append(fname)
            if fname.endswith('_us.tif'):
                self.us_fnames.append(fname)
        assert(len(self.mr_fnames) == len(self.us_fnames))

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        mr = skimage.io.imread(self.mr_fnames[index])
        us = skimage.io.imread(self.us_fnames[index])

        if self.input_transform is not None:
            mr = self.input_transform(mr)
        if self.target_transform is not None:
            us = self.target_transform(us)

        # Model runs in double precision, so cast both images to float64.
        return mr.astype(np.float64), us.astype(np.float64)

    def __len__(self):
        return len(self.mr_fnames)
24 changes: 6 additions & 18 deletions mrtous/network.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch
import torch.nn as nn

def init_weight(layer):
layer.weight.data.normal_(0.5, 0.2)
def normal_init(model):
    """Initialize convolution weights in-place from N(0.5, 0.3).

    Intended for use with ``nn.Module.apply``: it is called once per
    submodule, re-initializes only layers whose class name contains
    'Conv' (e.g. Conv2d), and leaves every other module untouched.
    """
    classname = model.__class__.__name__

    # 'in' is the idiomatic membership test (equivalent to find(...) != -1).
    if 'Conv' in classname:
        model.weight.data.normal_(0.5, 0.3)

class Basic(nn.Module):

Expand All @@ -12,24 +15,9 @@ def __init__(self):
self.conv = nn.Conv2d(1, 3, 3, padding=1)
self.conn = nn.Conv2d(3, 1, 1)

init_weight(self.conv)
self.apply(normal_init)

def forward(self, x):
x = self.conv(x)
x = self.conn(x)
return x

class Simple(nn.Module):
    """Three-layer fully convolutional net: 1 -> 16 -> 64 -> 1 channels.

    Every convolution is padded so the spatial size is preserved; the
    output has the same height and width as the input. Note there are
    no nonlinearities between the layers.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 64, 3, padding=1)
        self.final = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        # Feed the input through the fixed layer pipeline in order.
        out = x
        for layer in (self.conv1, self.conv2, self.final):
            out = layer(out)
        return out
Loading