Skip to content

Commit

Permalink
Merge pull request dipy#2940 from jhlegarreta/FilterSHBasisWarningsIn…
Browse files Browse the repository at this point in the history
…PTTTests

TEST: Filter legacy SH bases warnings in PTT direction getter test
  • Loading branch information
skoudoro authored Oct 19, 2023
2 parents 5f914f7 + 2ac1a76 commit 268b718
Showing 1 changed file with 47 additions and 30 deletions.
77 changes: 47 additions & 30 deletions dipy/direction/tests/test_ptt_direction_getter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
"""Test file for Parallel Transport Tracking Algorithm."""
import warnings

import numpy as np
import numpy.testing as npt
from dipy.core.sphere import unit_octahedron
from dipy.data import get_fnames, default_sphere
from dipy.direction import PTTDirectionGetter
from dipy.io.image import load_nifti
from dipy.reconst.shm import SphHarmFit, SphHarmModel, sh_to_sf
from dipy.reconst.shm import (
SphHarmFit,
SphHarmModel,
sh_to_sf,
descoteaux07_legacy_msg,
tournier07_legacy_msg,
)
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.stopping_criterion import BinaryStoppingCriterion
from dipy.tracking.streamline import Streamlines
Expand Down Expand Up @@ -99,39 +107,48 @@ def fit(self, data, mask=None):
dir = unit_octahedron.vertices[0].copy()

# Make ptt_dg from shm_coeffs
dg = PTTDirectionGetter.from_shcoeff(fit.shm_coeff, 90,
unit_octahedron)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=descoteaux07_legacy_msg,
category=PendingDeprecationWarning)
dg = PTTDirectionGetter.from_shcoeff(fit.shm_coeff, 90,
unit_octahedron)
npt.assert_equal(dg.get_direction(point, dir), 1)

# Make ptt_dg from pmf
pmf = np.zeros((3, 3, 3, unit_octahedron.theta.shape[0]))
dg = PTTDirectionGetter.from_pmf(pmf, 90, unit_octahedron)
npt.assert_equal(dg.get_direction(point, dir), 1)

# Check probe_length ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_length=0)

# Check probe_radius ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_radius=-1)

# Check probe_quality ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_quality=1)

# Check probe_length ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_count=0)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=tournier07_legacy_msg,
category=PendingDeprecationWarning)

# Check probe_length ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_length=0)

# Check probe_radius ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_radius=-1)

# Check probe_quality ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_quality=1)

# Check probe_length ValueError
npt.assert_raises(ValueError,
PTTDirectionGetter.from_shcoeff,
fit.shm_coeff, 90, unit_octahedron,
basis_type="tournier07",
probe_count=0)

0 comments on commit 268b718

Please sign in to comment.