Skip to content

Commit

Permalink
Merge pull request dipy#2947 from gabknight/RF_bootDG
Browse files Browse the repository at this point in the history
RF - BootDirectionGetter
  • Loading branch information
skoudoro authored Oct 23, 2023
2 parents d93f5b2 + b278fb4 commit 09e7231
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 187 deletions.
132 changes: 116 additions & 16 deletions dipy/direction/bootstrap_direction_getter.pyx
Original file line number Diff line number Diff line change
@@ -1,27 +1,77 @@
cimport numpy as cnp
import numpy as np

from dipy.core.interpolation cimport trilinear_interpolate4d_c
from dipy.data import default_sphere
from dipy.direction.closest_peak_direction_getter cimport (closest_peak,
BasePmfDirectionGetter)
from dipy.direction.pmf import BootPmfGen
from dipy.direction.closest_peak_direction_getter cimport closest_peak
from dipy.direction.peaks import peak_directions
from dipy.reconst import shm
from dipy.tracking.direction_getter cimport DirectionGetter


cdef class BootDirectionGetter(BasePmfDirectionGetter):
cdef class BootDirectionGetter(DirectionGetter):

cdef:
cnp.ndarray dwi_mask
cnp.ndarray vox_data
dict _pf_kwargs
double cos_similarity
double min_separation_angle
double relative_peak_threshold
double[:] pmf
double[:, :] R
double[:, :, :, :] data
int max_attempts
int sh_order
object H
object model
object sphere


def __init__(self, data, model, max_angle, sphere=default_sphere,
max_attempts=5, sh_order=0, b_tol=20, **kwargs):
cdef:
cnp.ndarray x, y, z, r
double[:] theta, phi
double[:, :] B

def __init__(self, pmfgen, maxangle, sphere=default_sphere,
max_attempts=5, **kwargs):
if max_attempts < 1:
raise ValueError("max_attempts must be greater than 0.")

if b_tol <= 0:
raise ValueError("b_tol must be greater than 0.")

self._pf_kwargs = kwargs
self.data = np.asarray(data, dtype=float)
self.model = model
self.cos_similarity = np.cos(np.deg2rad(max_angle))
self.sphere = sphere
self.sh_order = sh_order
self.max_attempts = max_attempts
BasePmfDirectionGetter.__init__(self, pmfgen, maxangle, sphere, **kwargs)

if self.sh_order == 0:
if hasattr(model, "sh_order"):
self.sh_order = model.sh_order
else:
self.sh_order = 4 # DEFAULT Value

self.dwi_mask = model.gtab.b0s_mask == 0
x, y, z = model.gtab.gradients[self.dwi_mask].T
r, theta, phi = shm.cart2sphere(x, y, z)
if r.max() - r.min() >= b_tol:
raise ValueError("BootDirectionGetter only supports single shell \
data.")
B, _, _ = shm.real_sh_descoteaux(self.sh_order, theta, phi)
self.H = shm.hat(B)
self.R = shm.lcr_matrix(self.H)

self.vox_data = np.empty(self.data.shape[3])
self.pmf = np.empty(sphere.vertices.shape[0])


@classmethod
def from_data(cls, data, model, max_angle, sphere=default_sphere,
sh_order=0, max_attempts=5, **kwargs):
sh_order=0, max_attempts=5, b_tol=20, **kwargs):
"""Create a BootDirectionGetter using HARDI data and an ODF type model
Parameters
Expand All @@ -41,17 +91,67 @@ cdef class BootDirectionGetter(BasePmfDirectionGetter):
max_attempts : int
Max number of bootstrap samples used to find tracking direction
before giving up.
pmf_threshold : float
Threshold for ODF functions.
b_tol : float
Maximum difference between b-values to be considered single shell.
relative_peak_threshold : float in [0., 1.]
Relative threshold for excluding ODF peaks.
min_separation_angle : float in [0, 90]
Angular threshold for excluding ODF peaks.
"""
boot_gen = BootPmfGen(np.asarray(data, dtype=float), model, sphere,
sh_order=sh_order)
return cls(boot_gen, max_angle, sphere, max_attempts, **kwargs)
return cls(data, model, max_angle, sphere, max_attempts, sh_order,
b_tol, **kwargs)


cpdef cnp.ndarray[cnp.float_t, ndim=2] initial_direction(self,
double[::1] point):
"""Returns best directions at seed location to start tracking.
Parameters
----------
point : ndarray, shape (3,)
The point in an image at which to lookup tracking directions.
Returns
-------
directions : ndarray, shape (N, 3)
Possible tracking directions from point. ``N`` may be 0, all
directions should be unique.
"""
cdef:
double[:] pmf = self.get_pmf_no_boot(point)

return peak_directions(pmf, self.sphere, **self._pf_kwargs)[0]


cpdef double[:] get_pmf(self, double[::1] point):
"""Produces an ODF from a SH bootstrap sample"""
if trilinear_interpolate4d_c(self.data, &point[0], self.vox_data) != 0:
self.__clear_pmf()
else:
self.vox_data[self.dwi_mask] = shm.bootstrap_data_voxel(
self.vox_data[self.dwi_mask], self.H, self.R)
self.pmf = self.model.fit(self.vox_data).odf(self.sphere)
return self.pmf


cpdef double[:] get_pmf_no_boot(self, double[::1] point):
if trilinear_interpolate4d_c(self.data, &point[0], self.vox_data) != 0:
self.__clear_pmf()
else:
self.pmf = self.model.fit(self.vox_data).odf(self.sphere)
return self.pmf


cdef void __clear_pmf(self) nogil:
cdef:
cnp.npy_intp len_pmf = self.pmf.shape[0]
cnp.npy_intp i

for i in range(len_pmf):
self.pmf[i] = 0.0


cdef int get_direction_c(self, double* point, double* direction):
"""Attempt direction getting on a few bootstrap samples.
Expand All @@ -63,12 +163,12 @@ cdef class BootDirectionGetter(BasePmfDirectionGetter):
1 otherwise.
"""
cdef:
double[:] pmf,
double[:] pmf
cnp.ndarray[cnp.float_t, ndim=2] peaks

for _ in range(self.max_attempts):
pmf = self._get_pmf(point)
peaks = self._get_peak_directions(pmf)
pmf = self.get_pmf(<double[:3]> point)
peaks = peak_directions(pmf, self.sphere, **self._pf_kwargs)[0]
if len(peaks) > 0:
return closest_peak(peaks, direction, self.cos_similarity)
return 1
13 changes: 0 additions & 13 deletions dipy/direction/pmf.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,3 @@ cdef class SHCoeffPmfGen(PmfGen):
double[:, :] B
double[:] coeff
pass


cdef class BootPmfGen(PmfGen):
cdef:
int sh_order
double[:, :] R
object model
object H
np.ndarray vox_data
np.ndarray dwi_mask

cpdef double[:] get_pmf_no_boot(self, double[::1] point)
pass
54 changes: 0 additions & 54 deletions dipy/direction/pmf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -130,57 +130,3 @@ cdef class SHCoeffPmfGen(PmfGen):
_sum = _sum + (self.B[i, j] * self.coeff[j])
self.pmf[i] = _sum
return self.pmf


cdef class BootPmfGen(PmfGen):

def __init__(self,
double[:, :, :, :] dwi_array,
object model,
object sphere,
int sh_order=0,
double tol=1e-2):
cdef:
double b_range
cnp.ndarray x, y, z, r
double[:] theta, phi
double[:, :] B

PmfGen.__init__(self, dwi_array, sphere)
self.sh_order = sh_order
if self.sh_order == 0:
if hasattr(model, "sh_order"):
self.sh_order = model.sh_order
else:
self.sh_order = 4 # DEFAULT Value

self.dwi_mask = model.gtab.b0s_mask == 0
x, y, z = model.gtab.gradients[self.dwi_mask].T
r, theta, phi = shm.cart2sphere(x, y, z)
b_range = (r.max() - r.min()) / r.min()
if b_range > tol:
raise ValueError("BootPmfGen only supports single shell data.")
B, _, _ = shm.real_sh_descoteaux(self.sh_order, theta, phi)
self.H = shm.hat(B)
self.R = shm.lcr_matrix(self.H)
self.vox_data = np.empty(dwi_array.shape[3])

self.model = model
self.pmf = np.empty(len(sphere.theta))

cpdef double[:] get_pmf(self, double[::1] point):
"""Produces an ODF from a SH bootstrap sample"""
if trilinear_interpolate4d_c(self.data, &point[0], self.vox_data) != 0:
self.__clear_pmf()
else:
self.vox_data[self.dwi_mask] = shm.bootstrap_data_voxel(
self.vox_data[self.dwi_mask], self.H, self.R)
self.pmf = self.model.fit(self.vox_data).odf(self.sphere)
return self.pmf

cpdef double[:] get_pmf_no_boot(self, double[::1] point):
if trilinear_interpolate4d_c(self.data, &point[0], self.vox_data) != 0:
self.__clear_pmf()
else:
self.pmf = self.model.fit(self.vox_data).odf(self.sphere)
return self.pmf
Loading

0 comments on commit 09e7231

Please sign in to comment.