From 76a4202f8b571caf1a95a1d61db20eedb8e2b7cd Mon Sep 17 00:00:00 2001 From: Erin Date: Sun, 2 Feb 2020 22:47:46 +0000 Subject: [PATCH] adding first version of bval parser --- dmriprep/utils/vectors.py | 112 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/dmriprep/utils/vectors.py b/dmriprep/utils/vectors.py index d9284ef3..50360aaa 100644 --- a/dmriprep/utils/vectors.py +++ b/dmriprep/utils/vectors.py @@ -7,6 +7,7 @@ B0_THRESHOLD = 50 BVEC_NORM_EPSILON = 0.1 +SHELL_DIFF_THRES = 150 class DiffusionGradientTable: @@ -177,6 +178,117 @@ def to_filename(self, filename, filetype='rasb'): else: raise ValueError('Unknown filetype "%s"' % filetype) +class BVALScheme: + """Data structure for bval scheme.""" + + def __init__(self, bvals = None, rasb_file = None, shell_diff_thres = SHELL_DIFF_THRES): + """ + Parse the available bvals into shells + + Parameters + ---------- + bvals : str or os.pathlike or numpy.ndarray + File path of the b-values. + rasb_file : str or os.pathlike + File path to a RAS-B gradient table. If rasb_file is provided, + then bvecs and bvals will be dismissed. + """ + self._bvals = None + self._shell_diff_thres = shell_diff_thres + + if rasb_file is not None: + if isinstance(rasb_file, (str, Path)): + gradients = np.loadtxt(gradients, skiprows=1) + self._bvals = np.squeeze(gradients[..., -1]) + elif bvals is not None: + if isinstance(bvals, (str, Path)): + bvals = np.loadtxt(str(bvals)).flatten() + self._bvals = np.array(bvals) + + self._kclust = self._k_cluster_result() + + def _k_cluster_result(self): + ''' determine the shell system by running k clustering return a dict with masks separated by shell''' + for k in range(1,len(np.unique(self._bvals)) + 1): + kmeans_res = KMeans(n_clusters = k).fit(self._bvals.reshape(-1,1)) + if kmeans_res.inertia_/len(self._bvals) < self._shell_diff_thres: + return kmeans_res + print('Sorry, bval parsing failed - no shells are more than {} apart found'.format(shell_diff_thres)) + return None + + @property + def bvals(self): + """Get the N b-values.""" + return self._bvals + + @property + def shells(self): + '''return sorted shells rounded to nearest 100''' + shells = np.round(np.squeeze(self._kclust.cluster_centers_),-2) + if shells.size == 1: + return np.array(shells).reshape(1,-1) #convert back to iterable type + else: + return np.sort(shells) + + @property + def n_shells(self): + ''' returns number of non-zero shells''' + return sum(self.shells != 0) + + def get_shell_centers(self, shell = 'all'): + ''' returns non rounded shell centers''' + all_centers = np.squeeze(self._kclust.cluster_centers_) + if all_centers.size > 1: + all_centers = np.sort(all_centers) + else: + all_centers = np.array(all_centers).reshape(1,-1) + if shell == 'all': + return all_centers + elif shell in self.shells: + return all_centers[self.shells == shell] + else: + print("could not find shell {} in bvals".format(shell)) + return None + + def get_shell_mask(self, shell): + ''' returns the mask for a given shell''' + shell_center = self.get_shell_centers(shell = shell) + clustid = np.where(self._kclust.cluster_centers_ == shell_center)[0] + mask = self._kclust.labels_ == clustid + return mask + + def get_n_directions_in_shell(self, shell): + ''' returns the number of directions in a shell''' + return sum(self.get_shell_mask(shell)) + + @property + def b0_mask(self): + return self.get_shell_mask(shell = 0) + + @property + def n_b0(self): + '''returns number of b0s''' + return np.sum(self.b0_mask) + + @property + def total_directions(self): + '''prints number of non-b0 directions (assuming they are unique)''' + return np.sum(np.invert(self.b0_mask)) + + def __str__(self): + ''' prints pretty string summary of bvals ''' + shell_list = [] + for shell in self.shells: + if shell > 0: + shell_list.append("{}:{}".format(int(shell), + self.get_n_directions_in_shell(shell))) + shell_str = "{} B0s + {} directions in {} shell(s) | B-value:n_directions {}".format( + self.n_b0, + self.total_directions, + self.n_shells, + ", ".join(shell_list)) + return shell_str + def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, bvec_norm_epsilon=BVEC_NORM_EPSILON, b_scale=True):