diff --git a/cascadenet/network/layers/__init__.py b/cascadenet/network/layers/__init__.py index eeb2b21b..a226092a 100755 --- a/cascadenet/network/layers/__init__.py +++ b/cascadenet/network/layers/__init__.py @@ -9,7 +9,6 @@ from .data_consistency import * from .helper import * from .kspace_averaging import * -from .conv3d import * try: from .conv3d import * except ImportError as e: diff --git a/main_2d.py b/main_2d.py index 40e376b9..489f970e 100755 --- a/main_2d.py +++ b/main_2d.py @@ -21,7 +21,7 @@ from cascadenet.util.helpers import to_lasagne_format -def prep_input(im, gauss_ivar=1e-3): +def prep_input(im, acc=4): """Undersample the batch, then reformat them into what the network accepts. Parameters @@ -29,14 +29,8 @@ def prep_input(im, gauss_ivar=1e-3): gauss_ivar: float - controls the undersampling rate. higher the value, more undersampling """ - mask = cs.cartesian_mask(im.shape, gauss_ivar, - centred=False, - sample_high_freq=True, - sample_centre=True, - sample_n=8) - + mask = cs.cartesian_mask(im.shape, acc, sample_n=8) im_und, k_und = cs.undersample(im, mask, centred=False, norm='ortho') - im_gnd_l = to_lasagne_format(im) im_und_l = to_lasagne_format(im_und) k_und_l = to_lasagne_format(k_und) @@ -124,10 +118,13 @@ def compile_fn(network, net_config, args): default=['0.001'], help='initial learning rate') parser.add_argument('--l2', metavar='float', nargs=1, default=['1e-6'], help='l2 regularisation') - parser.add_argument('--gauss_ivar', metavar='float', nargs=1, - default=['0.0015'], - help='Sensitivity for Gaussian Distribution which' - 'decides the undersampling rate of the Cartesian mask') + parser.add_argument('--acceleration_factor', metavar='float', nargs=1, + default=['4.0'], + help='Acceleration factor for k-space sampling') + # parser.add_argument('--gauss_ivar', metavar='float', nargs=1, + # default=['0.0015'], + # help='Sensitivity for Gaussian Distribution which' + # 'decides the undersampling rate of the Cartesian mask') parser.add_argument('--debug', action='store_true', help='debug mode') parser.add_argument('--savefig', action='store_true', help='Save output images and masks') @@ -136,7 +133,8 @@ def compile_fn(network, net_config, args): # Project config model_name = 'd2_c2' - gauss_ivar = float(args.gauss_ivar[0]) # undersampling rate + #gauss_ivar = float(args.gauss_ivar[0]) # undersampling rate + acc = float(args.acceleration_factor[0]) # undersampling rate num_epoch = int(args.num_epoch[0]) batch_size = int(args.batch_size[0]) Nx, Ny = 128, 128 @@ -155,25 +153,29 @@ def compile_fn(network, net_config, args): net_config, net, = build_d5_c5(input_shape) # Compute acceleration rate - dummy_mask = cs.cartesian_mask((500, Nx, Ny), gauss_ivar, - sample_high_freq=True, - sample_centre=True, sample_n=8) - acc = cs.undersampling_rate(dummy_mask) - print('Acceleration Rate: {:.2f}'.format(acc)) + dummy_mask = cs.cartesian_mask((10, Nx, Ny), acc, sample_n=8) + sample_und_factor = cs.undersampling_rate(dummy_mask) + print('Undersampling Rate: {:.2f}'.format(sample_und_factor)) # Compile function train_fn, val_fn = compile_fn(net, net_config, args) + # D5-C5 with pre-trained parameters + with np.load('./models/pretrained/d5_c5.npz') as f: + param_values = [f['arr_{0}'.format(i)] for i in range(len(f.files))] + lasagne.layers.set_all_param_values(net, param_values) + # Create dataset train, validate, test = create_dummy_data() + print('Start Training...') for epoch in xrange(num_epoch): t_start = time.time() # Training train_err = 0 train_batches = 0 for im in iterate_minibatch(train, batch_size, shuffle=True): - im_und, k_und, mask, im_gnd = prep_input(im, gauss_ivar=gauss_ivar) + im_und, k_und, mask, im_gnd = prep_input(im, acc=acc) err = train_fn(im_und, mask, k_und, im_gnd)[0] train_err += err train_batches += 1 @@ -184,7 +186,7 @@ def compile_fn(network, net_config, args): validate_err = 0 validate_batches = 0 for im in iterate_minibatch(validate, batch_size, shuffle=False): - im_und, k_und, mask, im_gnd = prep_input(im, gauss_ivar=gauss_ivar) + im_und, k_und, mask, im_gnd = prep_input(im, acc=acc) err, pred = val_fn(im_und, mask, k_und, im_gnd) validate_err += err validate_batches += 1 @@ -198,7 +200,7 @@ def compile_fn(network, net_config, args): test_psnr = 0 test_batches = 0 for im in iterate_minibatch(test, batch_size, shuffle=False): - im_und, k_und, mask, im_gnd = prep_input(im, gauss_ivar=gauss_ivar) + im_und, k_und, mask, im_gnd = prep_input(im, acc=acc) err, pred = val_fn(im_und, mask, k_und, im_gnd) test_err += err