diff --git a/README.rst b/README.rst
index 2b334696..11ec2955 100755
--- a/README.rst
+++ b/README.rst
@@ -1,11 +1,12 @@
-Deep Cascade of Convolutional Neural Networks for MR Image Reconstruction
+Deep Cascade of Convolutional Neural Networks and Convolutional Recurrent Neural Networks for MR Image Reconstruction
+
 =========================================================================
 
 Reconstruct MR images from its undersampled measurements using Deep Cascade of
-Convolutional Neural Networks (DC-CNN). This repository contains the
-implementation of DC-CNN using Theano and Lasagne and the simple demo. Note that
+Convolutional Neural Networks (DC-CNN) and Convolutional Recurrent Neural Networks (CRNN-MRI). This repository contains the
+implementation of DC-CNN using Theano and Lasagne, and CRNN-MRI using PyTorch, along with simple demos. Note that
 the library requires the dev version of Lasagne and Theano, as well as pygpu
-backend for using CUFFT Library. Some of the toy dataset borrowed from
+backend for using CUFFT Library. The PyTorch code requires PyTorch 0.4 or higher. Some of the toy dataset borrowed from
 .
 
 1. 2D Reconstruction
@@ -31,6 +32,20 @@ Usage::
 
     python main_3d.py --acceleration_factor 4
 
+----
+
+3. Dynamic Reconstruction using Convolutional Recurrent Neural Networks
+=========================================================================
+
+Reconstruct dynamic MR images from their undersampled measurements using
+Convolutional Recurrent Neural Networks. This is a PyTorch implementation requiring
+PyTorch 0.4.
+
+Usage::
+
+    python main_crnn.py --acceleration_factor 4
+
+
 ----
 
 
@@ -55,3 +70,12 @@ Dynamic Reconstruction::
 ----
 
 The paper is also available on arXiv:
+
+
+Dynamic Reconstruction using CRNN::
+
+    Qin, C., Schlemper, J., Caballero, J., Hajnal, J. V., Price, A., & Rueckert, D. Convolutional Recurrent Neural Networks for Dynamic MR Image Reconstruction. IEEE Transactions on Medical Imaging (2018).
+
+----
+
+The paper is also available on arXiv:
diff --git a/cascadenet_pytorch/__init__.py b/cascadenet_pytorch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cascadenet_pytorch/dnn_io.py b/cascadenet_pytorch/dnn_io.py
new file mode 100644
index 00000000..300ad225
--- /dev/null
+++ b/cascadenet_pytorch/dnn_io.py
@@ -0,0 +1,84 @@
+import numpy as np
+
+def r2c(x, axis=1):
+    """Convert pseudo-complex data (2 real channels) to complex data
+
+    x: ndarray
+        input data
+    axis: int
+        the axis that is used to represent the real and complex channel.
+        e.g. if axis == i, then x.shape looks like (n_1, n_2, ..., n_i-1, 2, n_i+1, ..., nm)
+    """
+    shape = x.shape
+    if axis < 0: axis = x.ndim + axis
+    ctype = np.complex64 if x.dtype == np.float32 else np.complex128
+
+    if axis < len(shape):
+        newshape = tuple([i for i in range(0, axis)]) \
+            + tuple([i for i in range(axis+1, x.ndim)]) + (axis,)
+
+        x = x.transpose(newshape)
+
+    x = np.ascontiguousarray(x).view(dtype=ctype)
+    return x.reshape(x.shape[:-1])
+
+
+def c2r(x, axis=1):
+    """Convert complex data to pseudo-complex data (2 real channels)
+
+    x: ndarray
+        input data
+    axis: int
+        the axis that is used to represent the real and complex channel.
+        e.g.
if axis == i, then x.shape looks like (n_1, n_2, ..., n_i-1, 2, n_i+1, ..., nm) + """ + shape = x.shape + dtype = np.float32 if x.dtype == np.complex64 else np.float64 + + x = np.ascontiguousarray(x).view(dtype=dtype).reshape(shape + (2,)) + + n = x.ndim + if axis < 0: axis = n + axis + if axis < n: + newshape = tuple([i for i in range(0, axis)]) + (n-1,) \ + + tuple([i for i in range(axis, n-1)]) + x = x.transpose(newshape) + + return x + + +def mask_r2c(m): + return m[0] if m.ndim == 3 else m[:, 0] + + +def to_tensor_format(x, mask=False): + """ + Assumes data is of shape (n[, nt], nx, ny). + Reshapes to (n, n_channels, nx, ny[, nt]) + Note: Depth must be the last axis, the dimensions will be reordered + """ + if x.ndim == 4: # n 3D inputs. reorder axes + x = np.transpose(x, (0, 2, 3, 1)) + + if mask: # Hacky solution + x = x*(1+1j) + + x = c2r(x) + + return x + + +def from_tensor_format(x, mask=False): + """ + Assumes data is of shape (n, 2, nx, ny[, nt]). + Reshapes to (n, [nt, ]nx, ny) + """ + if x.ndim == 5: # n 3D inputs. reorder axes + x = np.transpose(x, (0, 1, 4, 2, 3)) + + if mask: + x = mask_r2c(x) + else: + x = r2c(x) + + return x diff --git a/cascadenet_pytorch/kspace_pytorch.py b/cascadenet_pytorch/kspace_pytorch.py new file mode 100644 index 00000000..023f8e27 --- /dev/null +++ b/cascadenet_pytorch/kspace_pytorch.py @@ -0,0 +1,236 @@ +import numpy as np +import torch +import torch.nn as nn + + +def data_consistency(k, k0, mask, noise_lvl=None): + """ + k - input in k-space + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + v = noise_lvl + if v: # noisy case + out = (1 - mask) * k + mask * (k + v * k0) / (1 + v) + else: # noiseless case + out = (1 - mask) * k + mask * k0 + return out + + +class DataConsistencyInKspace(nn.Module): + """ Create data consistency operator + + Warning: note that FFT2 (by the default of torch.fft) is applied to the last 2 axes of the input. + This method detects if the input tensor is 4-dim (2D data) or 5-dim (3D data) + and applies FFT2 to the (nx, ny) axis. 
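+
+    The data consistency step itself (see data_consistency above) leaves the
+    predicted k-space values unchanged at unsampled locations; at sampled
+    locations it restores k0 in the noiseless case, or uses the weighted
+    combination (k + v * k0) / (1 + v) when a noise level v is given.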
+ + """ + + def __init__(self, noise_lvl=None, norm='ortho'): + super(DataConsistencyInKspace, self).__init__() + self.normalized = norm == 'ortho' + self.noise_lvl = noise_lvl + + def forward(self, *input, **kwargs): + return self.perform(*input) + + def perform(self, x, k0, mask): + """ + x - input in image domain, of shape (n, 2, nx, ny[, nt]) + k0 - initially sampled elements in k-space + mask - corresponding nonzero location + """ + + if x.dim() == 4: # input is 2D + x = x.permute(0, 2, 3, 1) + k0 = k0.permute(0, 2, 3, 1) + mask = mask.permute(0, 2, 3, 1) + elif x.dim() == 5: # input is 3D + x = x.permute(0, 4, 2, 3, 1) + k0 = k0.permute(0, 4, 2, 3, 1) + mask = mask.permute(0, 4, 2, 3, 1) + + k = torch.fft(x, 2, normalized=self.normalized) + out = data_consistency(k, k0, mask, self.noise_lvl) + x_res = torch.ifft(out, 2, normalized=self.normalized) + + if x.dim() == 4: + x_res = x_res.permute(0, 3, 1, 2) + elif x.dim() == 5: + x_res = x_res.permute(0, 4, 2, 3, 1) + + return x_res + + +def get_add_neighbour_op(nc, frame_dist, divide_by_n, clipped): + max_sample = max(frame_dist) *2 + 1 + + # for non-clipping, increase the input circularly + if clipped: + padding = (max_sample//2, 0, 0) + else: + padding = 0 + + # expect data to be in this format: (n, nc, nt, nx, ny) (due to FFT) + conv = nn.Conv3d(in_channels=nc, out_channels=nc*len(frame_dist), + kernel_size=(max_sample, 1, 1), + stride=1, padding=padding, bias=False) + + # Although there is only 1 parameter, need to iterate as parameters return generator + conv.weight.requires_grad = False + + # kernel has size nc=2, nc'=8, kt, kx, ky + for i, n in enumerate(frame_dist): + m = max_sample // 2 + #c = 1 / (n * 2 + 1) if divide_by_n else 1 + c = 1 + wt = np.zeros((2, max_sample, 1, 1), dtype=np.float32) + wt[0, m-n:m+n+1] = c + wt2 = np.zeros((2, max_sample, 1, 1), dtype=np.float32) + wt2[1, m-n:m+n+1] = c + + conv.weight.data[2*i] = torch.from_numpy(wt) + conv.weight.data[2*i+1] = torch.from_numpy(wt2) + + conv.cuda() + return conv + + +class KspaceFillNeighbourLayer(nn.Module): + ''' + k-space fill layer - The input data is assumed to be in k-space grid. + + The input data is assumed to be in k-space grid. + This layer should be invoked from AverageInKspaceLayer + ''' + def __init__(self, frame_dist, divide_by_n=False, clipped=True, **kwargs): + # frame_dist is the extent that data sharing goes. + # e.g. current frame is 3, frame_dist = 2, then 1,2, and 4,5 are added for reconstructing 3 + super(KspaceFillNeighbourLayer, self).__init__() + print("fr_d={}, divide_by_n={}, clippd={}".format(frame_dist, divide_by_n, clipped)) + if 0 not in frame_dist: + raise ValueError("There suppose to be a 0 in fr_d in config file!") + frame_dist = [0] + frame_dist # include ID + + self.frame_dist = frame_dist + self.n_samples = [1 + 2*i for i in self.frame_dist] + self.divide_by_n = divide_by_n + self.clipped = clipped + self.op = get_add_neighbour_op(2, frame_dist, divide_by_n, clipped) + + def forward(self, *input, **kwargs): + return self.perform(*input) + + def perform(self, k, mask): + ''' + + Parameters + ------------------------------ + inputs: two 5d tensors, [kspace_data, mask], each of shape (n, 2, NT, nx, ny) + + Returns + ------------------------------ + output: 5d tensor, missing lines of k-space are filled using neighbouring frames. 
+ shape becomes (n* (len(frame_dist), 2, nt, nx, ny) + ''' + max_d = max(self.frame_dist) + k_orig = k + mask_orig = mask + if not self.clipped: + # pad input along nt direction, which is circular boundary condition. Otherwise, just pad outside + # places with 0 (zero-boundary condition) + k = torch.cat([k[:,:,-max_d:], k, k[:,:,:max_d]], 2) + mask = torch.cat([mask[:,:,-max_d:], mask, mask[:,:,:max_d]], 2) + + # start with x, then copy over accumulatedly... + res = self.op(k) + if not self.divide_by_n: + # divide by n basically means for each kspace location, if n non-zero values from neighboring + # time frames contributes to it, then divide this entry by n (like a normalization) + res_mask = self.op(mask) + res = res / res_mask.clamp(min=1) + else: + res_mask = self.op(torch.ones_like(mask)) + res = res / res_mask.clamp(min=1) + + res = data_consistency(res, + k_orig.repeat(1,len(self.frame_dist),1,1,1), + mask_orig.repeat(1,len(self.frame_dist),1,1,1)) + + nb, nc_ri, nt, nx, ny = res.shape # here ri_nc is complicated with data sharing replica and real-img dimension + res = res.reshape(nb, nc_ri//2, 2, nt, nx, ny) + return res + + +class AveragingInKspace(nn.Module): + ''' + Average-in-k-space layer + + First transforms the representation in Fourier domain, + then performs averaging along temporal axis, then transforms back to image + domain. Works only for 5D tensor (see parameter descriptions). + + + Parameters + ----------------------------- + incomings: two 5d tensors, [kspace_data, mask], each of shape (n, 2, nx, ny, nt) + + data_shape: shape of the incoming tensors: (n, 2, nx, ny, nt) (This is for convenience) + + frame_dist: a list of distances of neighbours to sample for each averaging channel + if frame_dist=[1], samples from [-1, 1] for each temporal frames + if frame_dist=[3, 5], samples from [-3,-2,...,0,1,...,3] for one, + [-5,-4,...,0,1,...,5] for the second one + + divide_by_n: bool - Decides how averaging will be done. + True => divide by number of neighbours (=#2*frame_dist+1) + False => divide by number of nonzero contributions + + clipped: bool - By default the layer assumes periodic boundary condition along temporal axis. + True => Averaging will be clipped at the boundary, no circular references. + False => Averages with circular referencing (i.e. at t=0, gets contribution from t=nt-1, so on). + + Returns + ------------------------------ + output: 5d tensor, missing lines of k-space are filled using neighbouring frames. 
+ shape becomes (n, (len(frame_dist))* 2, nx, ny, nt) + ''' + + def __init__(self, frame_dist, divide_by_n=False, clipped=True, norm='ortho'): + super(AveragingInKspace, self).__init__() + self.normalized = norm == 'ortho' + self.frame_dist = frame_dist + self.divide_by_n = divide_by_n + self.kavg = KspaceFillNeighbourLayer(frame_dist, divide_by_n, clipped) + + def forward(self, *input, **kwargs): + return self.perform(*input) + + def perform(self, x, mask): + """ + x - input in image space, shape (n, 2, nx, ny, nt) + mask - corresponding nonzero location + """ + mask = mask.permute(0, 1, 4, 2, 3) + + x = x.permute(0, 4, 2, 3, 1) # put t to front, in convenience for fft + k = torch.fft(x, 2, normalized=self.normalized) + k = k.permute(0, 4, 1, 2, 3) # then put ri to the front, then t + + # data sharing + # nc is the numpy of copies of kspace, specified by frame_dist + out = self.kavg.perform(k, mask) + # after datasharing, it is nb, nc, 2, nt, nx, ny + + nb, nc, _, nt, nx, ny = out.shape # , jo's version + + # out.shape: [nb, 2*len(frame_dist), nt, nx, ny] + # we then detatch confused real/img channel and replica kspace channel due to datasharing (nc) + out = out.permute(0,1,3,4,5,2) # jo version, split ri and nc, put ri to the back for ifft + x_res = torch.ifft(out, 2, normalized=self.normalized) + + + # now nb, nc, nt, nx, ny, ri, put ri to channel position, and after nc (i.e. within each nc) + x_res = x_res.permute(0,1,5,3,4,2).reshape(nb, nc*2, nx,ny, nt)# jo version + + return x_res \ No newline at end of file diff --git a/cascadenet_pytorch/model_pytorch.py b/cascadenet_pytorch/model_pytorch.py new file mode 100755 index 00000000..a5c81d65 --- /dev/null +++ b/cascadenet_pytorch/model_pytorch.py @@ -0,0 +1,411 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable, grad +import numpy as np +import cascadenet_pytorch.kspace_pytorch as cl + + +def lrelu(): + return nn.LeakyReLU(0.01, inplace=True) + + +def relu(): + return nn.ReLU(inplace=True) + + +def conv_block(n_ch, nd, nf=32, ks=3, dilation=1, bn=False, nl='lrelu', conv_dim=2, n_out=None): + + # convolution dimension (2D or 3D) + if conv_dim == 2: + conv = nn.Conv2d + else: + conv = nn.Conv3d + + # output dim: If None, it is assumed to be the same as n_ch + if not n_out: + n_out = n_ch + + # dilated convolution + pad_conv = 1 + if dilation > 1: + # in = floor(in + 2*pad - dilation * (ks-1) - 1)/stride + 1) + # pad = dilation + pad_dilconv = dilation + else: + pad_dilconv = pad_conv + + def conv_i(): + return conv(nf, nf, ks, stride=1, padding=pad_dilconv, dilation=dilation, bias=True) + + conv_1 = conv(n_ch, nf, ks, stride=1, padding=pad_conv, bias=True) + conv_n = conv(nf, n_out, ks, stride=1, padding=pad_conv, bias=True) + + # relu + nll = relu if nl == 'relu' else lrelu + + layers = [conv_1, nll()] + for i in range(nd-2): + if bn: + layers.append(nn.BatchNorm2d(nf)) + layers += [conv_i(), nll()] + + layers += [conv_n] + + return nn.Sequential(*layers) + + +class DnCn(nn.Module): + def __init__(self, n_channels=2, nc=5, nd=5, **kwargs): + super(DnCn, self).__init__() + self.nc = nc + self.nd = nd + print('Creating D{}C{}'.format(nd, nc)) + conv_blocks = [] + dcs = [] + + conv_layer = conv_block + + for i in range(nc): + conv_blocks.append(conv_layer(n_channels, nd, **kwargs)) + dcs.append(cl.DataConsistencyInKspace(norm='ortho')) + + self.conv_blocks = nn.ModuleList(conv_blocks) + self.dcs = dcs + + def forward(self, x, k, m): + for i in range(self.nc): + x_cnn = self.conv_blocks[i](x) + x = x + x_cnn + x = 
self.dcs[i].perform(x, k, m) + + return x + + +class StochasticDnCn(DnCn): + def __init__(self, n_channels=2, nc=5, nd=5, p=None, **kwargs): + super(StochasticDnCn, self).__init__(n_channels, nc, nd, **kwargs) + + self.sample = False + self.p = p + if not p: + self.p = np.linspace(0, 0.5, nc) + print(self.p) + + def forward(self, x, k, m): + for i in range(self.nc): + + # stochastically drop connection + if self.training or self.sample: + if np.random.random() <= self.p[i]: + continue + + x_cnn = self.conv_blocks[i](x) + x = x + x_cnn + x = self.dcs[i].perform(x, k, m) + + return x + + def set_sample(self, flag=True): + self.sample = flag + + +class DnCn3D(nn.Module): + def __init__(self, n_channels=2, nc=5, nd=5, **kwargs): + super(DnCn3D, self).__init__() + self.nc = nc + self.nd = nd + print('Creating D{}C{} (3D)'.format(nd, nc)) + conv_blocks = [] + dcs = [] + + conv_layer = conv_block + + for i in range(nc): + conv_blocks.append(conv_layer(n_channels, nd, **kwargs)) + dcs.append(cl.DataConsistencyInKspace(norm='ortho')) + + self.conv_blocks = nn.ModuleList(conv_blocks) + self.dcs = nn.ModuleList(dcs) + + def forward(self, x, k, m): + for i in range(self.nc): + x_cnn = self.conv_blocks[i](x) + x = x + x_cnn + x = self.dcs[i].perform(x, k, m) + + return x + + +class DnCn3DDS(nn.Module): + def __init__(self, n_channels=2, nc=5, nd=5, fr_d=None, clipped=False, mode='pytorch', **kwargs): + """ + + Parameters + ---------- + + fr_d: frame distance for data sharing layer. e.g. [1, 3, 5] + + """ + super(DnCn3DDS, self).__init__() + self.nc = nc + self.nd = nd + self.mode = mode + print('Creating D{}C{}-DS (3D)'.format(nd, nc)) + if self.mode == 'theano': + print('Initialised with theano mode (backward-compatibility)') + conv_blocks = [] + dcs = [] + kavgs = [] + + if not fr_d: + fr_d = list(range(10)) + self.fr_d = fr_d + + conv_layer = conv_block + + # update input-output channels for data sharing + n_channels = 2 * len(fr_d) + n_out = 2 + kwargs.update({'n_out': 2}) + + for i in range(nc): + kavgs.append(cl.AveragingInKspace(fr_d, i>0, clipped, norm='ortho')) + conv_blocks.append(conv_layer(n_channels, nd, **kwargs)) + dcs.append(cl.DataConsistencyInKspace(norm='ortho')) + + self.conv_blocks = nn.ModuleList(conv_blocks) + self.dcs = nn.ModuleList(dcs) + self.kavgs = nn.ModuleList(kavgs) + + def forward(self, x, k, m): + for i in range(self.nc): + x_ds = self.kavgs[i](x, m) + if self.mode == 'theano': + # transpose the layes + x_ds_tmp = torch.zeros_like(x_ds) + nneigh = len(self.fr_d) + for j in range(nneigh): + x_ds_tmp[:,2*j] = x_ds[:,j] + x_ds_tmp[:,2*j+1] = x_ds[:,j+nneigh] + x_ds = x_ds_tmp + + x_cnn = self.conv_blocks[i](x_ds) + x = x + x_cnn + x = self.dcs[i](x, k, m) + + return x + + +class DnCn3DShared(nn.Module): + def __init__(self, n_channels=2, nc=5, nd=5, **kwargs): + super(DnCn3DShared, self).__init__() + self.nc = nc + self.nd = nd + print('Creating D{}C{}-S (3D)'.format(nd, nc)) + + self.conv_block = conv_block(n_channels, nd, **kwargs) + self.dc = cl.DataConsistencyInKspace(norm='ortho') + + def forward(self, x, k, m): + for i in range(self.nc): + x_cnn = self.conv_block(x) + x = x + x_cnn + x = self.dc.perform(x, k, m) + + return x + + +class CRNNcell(nn.Module): + """ + Convolutional RNN cell that evolves over both time and iterations + + Parameters + ----------------- + input: 4d tensor, shape (batch_size, channel, width, height) + hidden: hidden states in temporal dimension, 4d tensor, shape (batch_size, hidden_size, width, height) + hidden_iteration: hidden states in 
iteration dimension, 4d tensor, shape (batch_size, hidden_size, width, height) + + Returns + ----------------- + output: 4d tensor, shape (batch_size, hidden_size, width, height) + + """ + def __init__(self, input_size, hidden_size, kernel_size): + super(CRNNcell, self).__init__() + self.kernel_size = kernel_size + self.i2h = nn.Conv2d(input_size, hidden_size, kernel_size, padding=self.kernel_size // 2) + self.h2h = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=self.kernel_size // 2) + # add iteration hidden connection + self.ih2ih = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=self.kernel_size // 2) + self.relu = nn.ReLU(inplace=True) + + def forward(self, input, hidden_iteration, hidden): + in_to_hid = self.i2h(input) + hid_to_hid = self.h2h(hidden) + ih_to_ih = self.ih2ih(hidden_iteration) + + hidden = self.relu(in_to_hid + hid_to_hid + ih_to_ih) + + return hidden + + +class BCRNNlayer(nn.Module): + """ + Bidirectional Convolutional RNN layer + + Parameters + -------------------- + incomings: input: 5d tensor, [input_image] with shape (num_seqs, batch_size, channel, width, height) + input_iteration: 5d tensor, [hidden states from previous iteration] with shape (n_seq, n_batch, hidden_size, width, height) + test: True if in test mode, False if in train mode + + Returns + -------------------- + output: 5d tensor, shape (n_seq, n_batch, hidden_size, width, height) + + """ + def __init__(self, input_size, hidden_size, kernel_size): + super(BCRNNlayer, self).__init__() + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.input_size = input_size + self.CRNN_model = CRNNcell(self.input_size, self.hidden_size, self.kernel_size) + + def forward(self, input, input_iteration, test=False): + nt, nb, nc, nx, ny = input.shape + size_h = [nb, self.hidden_size, nx, ny] + if test: + with torch.no_grad(): + hid_init = Variable(torch.zeros(size_h)).cuda() + else: + hid_init = Variable(torch.zeros(size_h)).cuda() + + output_f = [] + output_b = [] + # forward + hidden = hid_init + for i in range(nt): + hidden = self.CRNN_model(input[i], input_iteration[i], hidden) + output_f.append(hidden) + + output_f = torch.cat(output_f) + + # backward + hidden = hid_init + for i in range(nt): + hidden = self.CRNN_model(input[nt - i - 1], input_iteration[nt - i -1], hidden) + + output_b.append(hidden) + output_b = torch.cat(output_b[::-1]) + + output = output_f + output_b + + if nb == 1: + output = output.view(nt, 1, self.hidden_size, nx, ny) + + return output + + +class CRNN_MRI(nn.Module): + """ + Model for Dynamic MRI Reconstruction using Convolutional Neural Networks + + Parameters + ----------------------- + incomings: three 5d tensors, [input_image, kspace_data, mask], each of shape (batch_size, 2, width, height, n_seq) + + Returns + ------------------------------ + output: 5d tensor, [output_image] with shape (batch_size, 2, width, height, n_seq) + """ + def __init__(self, n_ch=2, nf=64, ks=3, nc=5, nd=5): + """ + :param n_ch: number of channels + :param nf: number of filters + :param ks: kernel size + :param nc: number of iterations + :param nd: number of CRNN/BCRNN/CNN layers in each iteration + """ + super(CRNN_MRI, self).__init__() + self.nc = nc + self.nd = nd + self.nf = nf + self.ks = ks + + self.bcrnn = BCRNNlayer(n_ch, nf, ks) + self.conv1_x = nn.Conv2d(nf, nf, ks, padding = ks//2) + self.conv1_h = nn.Conv2d(nf, nf, ks, padding = ks//2) + self.conv2_x = nn.Conv2d(nf, nf, ks, padding = ks//2) + self.conv2_h = nn.Conv2d(nf, nf, ks, padding = ks//2) + self.conv3_x = 
nn.Conv2d(nf, nf, ks, padding = ks//2) + self.conv3_h = nn.Conv2d(nf, nf, ks, padding = ks//2) + self.conv4_x = nn.Conv2d(nf, n_ch, ks, padding = ks//2) + self.relu = nn.ReLU(inplace=True) + + dcs = [] + for i in range(nc): + dcs.append(cl.DataConsistencyInKspace(norm='ortho')) + self.dcs = dcs + + def forward(self, x, k, m, test=False): + """ + x - input in image domain, of shape (n, 2, nx, ny, n_seq) + k - initially sampled elements in k-space + m - corresponding nonzero location + test - True: the model is in test mode, False: train mode + """ + net = {} + n_batch, n_ch, width, height, n_seq = x.size() + size_h = [n_seq*n_batch, self.nf, width, height] + if test: + with torch.no_grad(): + hid_init = Variable(torch.zeros(size_h)).cuda() + else: + hid_init = Variable(torch.zeros(size_h)).cuda() + + for j in range(self.nd-1): + net['t0_x%d'%j]=hid_init + + for i in range(1,self.nc+1): + + x = x.permute(4,0,1,2,3) + x = x.contiguous() + net['t%d_x0' % (i - 1)] = net['t%d_x0' % (i - 1)].view(n_seq, n_batch,self.nf,width, height) + net['t%d_x0'%i] = self.bcrnn(x, net['t%d_x0'%(i-1)], test) + net['t%d_x0'%i] = net['t%d_x0'%i].view(-1,self.nf,width, height) + + net['t%d_x1'%i] = self.conv1_x(net['t%d_x0'%i]) + net['t%d_h1'%i] = self.conv1_h(net['t%d_x1'%(i-1)]) + net['t%d_x1'%i] = self.relu(net['t%d_h1'%i]+net['t%d_x1'%i]) + + net['t%d_x2'%i] = self.conv2_x(net['t%d_x1'%i]) + net['t%d_h2'%i] = self.conv2_h(net['t%d_x2'%(i-1)]) + net['t%d_x2'%i] = self.relu(net['t%d_h2'%i]+net['t%d_x2'%i]) + + net['t%d_x3'%i] = self.conv3_x(net['t%d_x2'%i]) + net['t%d_h3'%i] = self.conv3_h(net['t%d_x3'%(i-1)]) + net['t%d_x3'%i] = self.relu(net['t%d_h3'%i]+net['t%d_x3'%i]) + + net['t%d_x4'%i] = self.conv4_x(net['t%d_x3'%i]) + + x = x.view(-1,n_ch,width, height) + net['t%d_out'%i] = x + net['t%d_x4'%i] + + net['t%d_out'%i] = net['t%d_out'%i].view(-1,n_batch, n_ch, width, height) + net['t%d_out'%i] = net['t%d_out'%i].permute(1,2,3,4,0) + net['t%d_out'%i].contiguous() + net['t%d_out'%i] = self.dcs[i-1].perform(net['t%d_out'%i], k, m) + x = net['t%d_out'%i] + + # clean up i-1 + if test: + to_delete = [ key for key in net if ('t%d'%(i-1)) in key ] + + for elt in to_delete: + del net[elt] + + torch.cuda.empty_cache() + + return net['t%d_out'%i] + + diff --git a/main_crnn.py b/main_crnn.py new file mode 100644 index 00000000..3e851c96 --- /dev/null +++ b/main_crnn.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +from __future__ import print_function, division + +import os +import time +import torch +import torch.optim as optim +from torch.autograd import Variable +import argparse +import matplotlib.pyplot as plt + +from os.path import join +from scipy.io import loadmat + +from utils import compressed_sensing as cs +from utils.metric import complex_psnr + +from cascadenet_pytorch.model_pytorch import * +from cascadenet_pytorch.dnn_io import to_tensor_format +from cascadenet_pytorch.dnn_io import from_tensor_format + + +def prep_input(im, acc=4.0): + """Undersample the batch, then reformat them into what the network accepts. + + Parameters + ---------- + gauss_ivar: float - controls the undersampling rate. 
+ higher the value, more undersampling + """ + mask = cs.cartesian_mask(im.shape, acc, sample_n=8) + im_und, k_und = cs.undersample(im, mask, centred=False, norm='ortho') + im_gnd_l = torch.from_numpy(to_tensor_format(im)) + im_und_l = torch.from_numpy(to_tensor_format(im_und)) + k_und_l = torch.from_numpy(to_tensor_format(k_und)) + mask_l = torch.from_numpy(to_tensor_format(mask, mask=True)) + + return im_und_l, k_und_l, mask_l, im_gnd_l + + +def iterate_minibatch(data, batch_size, shuffle=True): + n = len(data) + + if shuffle: + data = np.random.permutation(data) + + for i in range(0, n, batch_size): + yield data[i:i+batch_size] + + +def create_dummy_data(): + """Create small cardiac data based on patches for demo. + + Note that in practice, at test time the method will need to be applied to + the whole volume. In addition, one would need more data to prevent + overfitting. + + """ + data = loadmat(join(project_root, './data/cardiac.mat'))['seq'] + nx, ny, nt = data.shape + ny_red = 8 + sl = ny//ny_red + data_t = np.transpose(data, (2, 0, 1)) + + # Synthesize data by extracting patches + train = np.array([data_t[..., i:i+sl] for i in np.random.randint(0, sl*3, 20)]) + validate = np.array([data_t[..., i:i+sl] for i in (sl*4, sl*5)]) + test = np.array([data_t[..., i:i+sl] for i in (sl*6, sl*7)]) + + return train, validate, test + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--num_epoch', metavar='int', nargs=1, default=['10'], + help='number of epochs') + parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], + help='batch size') + parser.add_argument('--lr', metavar='float', nargs=1, + default=['0.001'], help='initial learning rate') + parser.add_argument('--acceleration_factor', metavar='float', nargs=1, + default=['4.0'], + help='Acceleration factor for k-space sampling') + parser.add_argument('--debug', action='store_true', help='debug mode') + parser.add_argument('--savefig', action='store_true', + help='Save output images and masks') + + args = parser.parse_args() + cuda = True if torch.cuda.is_available() else False + Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor + + # Project config + model_name = 'crnn_mri' + acc = float(args.acceleration_factor[0]) # undersampling rate + num_epoch = int(args.num_epoch[0]) + batch_size = int(args.batch_size[0]) + Nx, Ny, Nt = 256, 256, 30 + Ny_red = 8 + save_fig = args.savefig + save_every = 5 + + # Configure directory info + project_root = '.' 
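+    # Model checkpoints and, when --savefig is given, reconstruction figures and
+    # undersampling masks are written under <project_root>/models/<model_name>.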
+ save_dir = join(project_root, 'models/%s' % model_name) + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + + # Create dataset + train, validate, test = create_dummy_data() + + # Test creating mask and compute the acceleration rate + dummy_mask = cs.cartesian_mask((10, Nx, Ny//Ny_red), acc, sample_n=8) + sample_und_factor = cs.undersampling_rate(dummy_mask) + print('Undersampling Rate: {:.2f}'.format(sample_und_factor)) + + # Specify network + rec_net = CRNN_MRI() + criterion = torch.nn.MSELoss() + optimizer = optim.Adam(rec_net.parameters(), lr=float(args.lr[0]), betas=(0.5, 0.999)) + + # # build CRNN-MRI with pre-trained parameters + # rec_net.load_state_dict(torch.load('./models/pretrained/crnn_mri_d5_c5.pth')) + + if cuda: + rec_net = rec_net.cuda() + criterion.cuda() + + i = 0 + for epoch in range(num_epoch): + t_start = time.time() + # Training + train_err = 0 + train_batches = 0 + for im in iterate_minibatch(train, batch_size, shuffle=True): + im_und, k_und, mask, im_gnd = prep_input(im, acc) + im_u = Variable(im_und.type(Tensor)) + k_u = Variable(k_und.type(Tensor)) + mask = Variable(mask.type(Tensor)) + gnd = Variable(im_gnd.type(Tensor)) + + optimizer.zero_grad() + rec = rec_net(im_u, k_u, mask, test=False) + loss = criterion(rec, gnd) + loss.backward() + optimizer.step() + + train_err += loss.item() + train_batches += 1 + + if args.debug and train_batches == 20: + break + + validate_err = 0 + validate_batches = 0 + rec_net.eval() + for im in iterate_minibatch(validate, batch_size, shuffle=False): + im_und, k_und, mask, im_gnd = prep_input(im, acc) + with torch.no_grad(): + im_u = Variable(im_und.type(Tensor)) + k_u = Variable(k_und.type(Tensor)) + mask = Variable(mask.type(Tensor)) + gnd = Variable(im_gnd.type(Tensor)) + + pred = rec_net(im_u, k_u, mask, test=True) + err = criterion(pred, gnd) + + validate_err += err + validate_batches += 1 + + if args.debug and validate_batches == 20: + break + + vis = [] + test_err = 0 + base_psnr = 0 + test_psnr = 0 + test_batches = 0 + for im in iterate_minibatch(test, batch_size, shuffle=False): + im_und, k_und, mask, im_gnd = prep_input(im, acc) + with torch.no_grad(): + im_u = Variable(im_und.type(Tensor)) + k_u = Variable(k_und.type(Tensor)) + mask = Variable(mask.type(Tensor)) + gnd = Variable(im_gnd.type(Tensor)) + + pred = rec_net(im_u, k_u, mask, test=True) + err = criterion(pred, gnd) + test_err += err + for im_i, und_i, pred_i in zip(im, + from_tensor_format(im_und.numpy()), + from_tensor_format(pred.data.cpu().numpy())): + base_psnr += complex_psnr(im_i, und_i, peak='max') + test_psnr += complex_psnr(im_i, pred_i, peak='max') + + if save_fig and test_batches % save_every == 0: + vis.append((from_tensor_format(im_gnd.numpy())[0], + from_tensor_format(pred.data.cpu().numpy())[0], + from_tensor_format(im_und.numpy())[0], + from_tensor_format(mask.data.cpu().numpy(), mask=True)[0])) + + test_batches += 1 + if args.debug and test_batches == 20: + break + + t_end = time.time() + + train_err /= train_batches + validate_err /= validate_batches + test_err /= test_batches + base_psnr /= (test_batches*batch_size) + test_psnr /= (test_batches*batch_size) + + # Then we print the results for this epoch: + print("Epoch {}/{}".format(epoch+1, num_epoch)) + print(" time: {}s".format(t_end - t_start)) + print(" training loss:\t\t{:.6f}".format(train_err)) + print(" validation loss:\t{:.6f}".format(validate_err)) + print(" test loss:\t\t{:.6f}".format(test_err)) + print(" base PSNR:\t\t{:.6f}".format(base_psnr)) + print(" test 
PSNR:\t\t{:.6f}".format(test_psnr)) + + # save the model + if epoch in [1, 2, num_epoch-1]: + if save_fig: + + for im_i, pred_i, und_i, mask_i in vis: + im = abs(np.concatenate([und_i[0], pred_i[0], im_i[0], im_i[0] - pred_i[0]], 1)) + plt.imsave(join(save_dir, 'im{0}_x.png'.format(i)), im, cmap='gray') + + im = abs(np.concatenate([und_i[..., 0], pred_i[..., 0], + im_i[..., 0], im_i[..., 0] - pred_i[..., 0]], 0)) + plt.imsave(join(save_dir, 'im{0}_t.png'.format(i)), im, cmap='gray') + plt.imsave(join(save_dir, 'mask{0}.png'.format(i)), + np.fft.fftshift(mask_i[..., 0]), cmap='gray') + i += 1 + + name = '%s_epoch_%d.npz' % (model_name, epoch) + torch.save(rec_net.state_dict(), join(save_dir, name)) + print('model parameters saved at %s' % join(os.getcwd(), name)) + print('') diff --git a/models/pretrained/crnn_mri_d5_c5.pth b/models/pretrained/crnn_mri_d5_c5.pth new file mode 100644 index 00000000..17a64408 Binary files /dev/null and b/models/pretrained/crnn_mri_d5_c5.pth differ diff --git a/requirements.txt b/requirements.txt index 4105b1db..c6afa6a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ pytz==2017.2 scipy==0.19.0 six==1.10.0 Theano==0.9.0 +torch==0.4.0 diff --git a/utils/compressed_sensing.py b/utils/compressed_sensing.py index f6f0e7dc..2bd1c114 100755 --- a/utils/compressed_sensing.py +++ b/utils/compressed_sensing.py @@ -60,17 +60,17 @@ def cartesian_mask(shape, acc, sample_n=10, centred=False): pdf_x += lmda * 1./Nx if sample_n: - pdf_x[Nx/2-sample_n/2:Nx/2+sample_n/2] = 0 + pdf_x[Nx//2-sample_n//2:Nx//2+sample_n//2] = 0 pdf_x /= np.sum(pdf_x) n_lines -= sample_n mask = np.zeros((N, Nx)) - for i in xrange(N): + for i in range(N): idx = np.random.choice(Nx, n_lines, False, pdf_x) mask[i, idx] = 1 if sample_n: - mask[:, Nx/2-sample_n/2:Nx/2+sample_n/2] = 1 + mask[:, Nx//2-sample_n//2:Nx//2+sample_n//2] = 1 size = mask.itemsize mask = as_strided(mask, (N, Nx, Ny), (size * Nx, size, 0))