Skip to content

Commit

Permalink
cartesian mask upated
Browse files Browse the repository at this point in the history
  • Loading branch information
Jo Schlemper committed Nov 10, 2017
1 parent 8695efe commit c550f1e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
1 change: 0 additions & 1 deletion cascadenet/network/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .data_consistency import *
from .helper import *
from .kspace_averaging import *
from .conv3d import *
try:
from .conv3d import *
except ImportError as e:
Expand Down
44 changes: 23 additions & 21 deletions main_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,16 @@
from cascadenet.util.helpers import to_lasagne_format


def prep_input(im, gauss_ivar=1e-3):
def prep_input(im, acc=4):
"""Undersample the batch, then reformat them into what the network accepts.
Parameters
----------
gauss_ivar: float - controls the undersampling rate.
higher the value, more undersampling
"""
mask = cs.cartesian_mask(im.shape, gauss_ivar,
centred=False,
sample_high_freq=True,
sample_centre=True,
sample_n=8)

mask = cs.cartesian_mask(im.shape, acc, sample_n=8)
im_und, k_und = cs.undersample(im, mask, centred=False, norm='ortho')

im_gnd_l = to_lasagne_format(im)
im_und_l = to_lasagne_format(im_und)
k_und_l = to_lasagne_format(k_und)
Expand Down Expand Up @@ -124,10 +118,13 @@ def compile_fn(network, net_config, args):
default=['0.001'], help='initial learning rate')
parser.add_argument('--l2', metavar='float', nargs=1,
default=['1e-6'], help='l2 regularisation')
parser.add_argument('--gauss_ivar', metavar='float', nargs=1,
default=['0.0015'],
help='Sensitivity for Gaussian Distribution which'
'decides the undersampling rate of the Cartesian mask')
parser.add_argument('--acceleration_factor', metavar='float', nargs=1,
default=['4.0'],
help='Acceleration factor for k-space sampling')
# parser.add_argument('--gauss_ivar', metavar='float', nargs=1,
# default=['0.0015'],
# help='Sensitivity for Gaussian Distribution which'
# 'decides the undersampling rate of the Cartesian mask')
parser.add_argument('--debug', action='store_true', help='debug mode')
parser.add_argument('--savefig', action='store_true',
help='Save output images and masks')
Expand All @@ -136,7 +133,8 @@ def compile_fn(network, net_config, args):

# Project config
model_name = 'd2_c2'
gauss_ivar = float(args.gauss_ivar[0]) # undersampling rate
#gauss_ivar = float(args.gauss_ivar[0]) # undersampling rate
acc = float(args.acceleration_factor[0]) # undersampling rate
num_epoch = int(args.num_epoch[0])
batch_size = int(args.batch_size[0])
Nx, Ny = 128, 128
Expand All @@ -155,25 +153,29 @@ def compile_fn(network, net_config, args):
net_config, net, = build_d5_c5(input_shape)

# Compute acceleration rate
dummy_mask = cs.cartesian_mask((500, Nx, Ny), gauss_ivar,
sample_high_freq=True,
sample_centre=True, sample_n=8)
acc = cs.undersampling_rate(dummy_mask)
print('Acceleration Rate: {:.2f}'.format(acc))
dummy_mask = cs.cartesian_mask((10, Nx, Ny), acc, sample_n=8)
sample_und_factor = cs.undersampling_rate(dummy_mask)
print('Undersampling Rate: {:.2f}'.format(sample_und_factor))

# Compile function
train_fn, val_fn = compile_fn(net, net_config, args)

# D5-C5 with pre-trained parameters
with np.load('./models/pretrained/d5_c5.npz') as f:
param_values = [f['arr_{0}'.format(i)] for i in range(len(f.files))]
lasagne.layers.set_all_param_values(net, param_values)

# Create dataset
train, validate, test = create_dummy_data()

print('Start Training...')
for epoch in xrange(num_epoch):
t_start = time.time()
# Training
train_err = 0
train_batches = 0
for im in iterate_minibatch(train, batch_size, shuffle=True):
im_und, k_und, mask, im_gnd = prep_input(im, gauss_ivar=gauss_ivar)
im_und, k_und, mask, im_gnd = prep_input(im, acc=acc)
err = train_fn(im_und, mask, k_und, im_gnd)[0]
train_err += err
train_batches += 1
Expand All @@ -184,7 +186,7 @@ def compile_fn(network, net_config, args):
validate_err = 0
validate_batches = 0
for im in iterate_minibatch(validate, batch_size, shuffle=False):
im_und, k_und, mask, im_gnd = prep_input(im, gauss_ivar=gauss_ivar)
im_und, k_und, mask, im_gnd = prep_input(im, acc=acc)
err, pred = val_fn(im_und, mask, k_und, im_gnd)
validate_err += err
validate_batches += 1
Expand All @@ -198,7 +200,7 @@ def compile_fn(network, net_config, args):
test_psnr = 0
test_batches = 0
for im in iterate_minibatch(test, batch_size, shuffle=False):
im_und, k_und, mask, im_gnd = prep_input(im, gauss_ivar=gauss_ivar)
im_und, k_und, mask, im_gnd = prep_input(im, acc=acc)

err, pred = val_fn(im_und, mask, k_und, im_gnd)
test_err += err
Expand Down

0 comments on commit c550f1e

Please sign in to comment.