Skip to content

Commit

Permalink
Merge pull request #757 from EmmaRenauld/option_noRandomSeed
Browse files Browse the repository at this point in the history
ENH: Adding option to skip randomization of seed. + other fixes
  • Loading branch information
arnaudbore authored Nov 9, 2023
2 parents 77ea9d1 + ae7d874 commit d531467
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 81 deletions.
16 changes: 12 additions & 4 deletions scilpy/tracking/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def __init__(self, datavolume, step_size, rk_order, space, origin):
# By default, normalizing directions. Adding option for child classes.
self.normalize_directions = True

self.line_rng_generator = None # Will be reset at each new streamline.

def reset_data(self, new_data=None):
"""
Reset data before starting a new process. In current implementation,
Expand All @@ -81,7 +83,7 @@ def reset_data(self, new_data=None):
"""
self.datavolume.data = new_data

def prepare_forward(self, seeding_pos):
def prepare_forward(self, seeding_pos, random_generator):
"""
Prepare information necessary at the first point of the
streamline for forward propagation: v_in and any other information
Expand All @@ -92,6 +94,7 @@ def prepare_forward(self, seeding_pos):
seeding_pos: tuple(x,y,z)
The seeding position. Important, position must be in the same space
and origin as self.space, self.origin!
random_generator: numpy Generator.
Returns
-------
Expand All @@ -100,6 +103,8 @@ def prepare_forward(self, seeding_pos):
Return PropagationStatus.ERROR if no good tracking direction can be
set at current seeding position.
"""
# To be defined by child classes.
# Should set self.line_rng_generator = random_generator
raise NotImplementedError

def prepare_backward(self, line, forward_dir):
Expand Down Expand Up @@ -421,7 +426,7 @@ def _get_sf(self, pos):
sf /= sf_max
return sf

def prepare_forward(self, seeding_pos):
def prepare_forward(self, seeding_pos, random_generator):
"""
Prepare information necessary at the first point of the
streamline for forward propagation: v_in and any other information
Expand All @@ -437,6 +442,7 @@ def prepare_forward(self, seeding_pos):
seeding_pos: tuple(x,y,z)
The seeding position. Important, position must be in the same space
and origin as self.space, self.origin!
random_generator: numpy Generator
Returns
-------
Expand All @@ -452,9 +458,10 @@ def prepare_forward(self, seeding_pos):
# "more probable" peak.
sf = self._get_sf(seeding_pos)
sf[sf < self.sf_threshold_init] = 0
self.line_rng_generator = random_generator

if np.sum(sf) > 0:
ind = sample_distribution(sf)
ind = sample_distribution(sf, self.line_rng_generator)
return TrackingDirection(self.dirs[ind], ind)

# Else: sf at current position is smaller than acceptable threshold in
Expand Down Expand Up @@ -485,7 +492,8 @@ def _sample_next_direction(self, pos, v_in):

# Sampling one.
if np.sum(sf) > 0:
v_out = directions[sample_distribution(sf)]
v_out = directions[sample_distribution(sf,
self.line_rng_generator)]
else:
return None
elif self.algo == 'det':
Expand Down
175 changes: 120 additions & 55 deletions scilpy/tracking/seed.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-
import logging

import numpy as np

from dipy.io.stateful_tractogram import Space, Origin


class SeedGenerator(object):
class SeedGenerator:
"""
Class to get seeding positions.
Expand All @@ -19,63 +17,83 @@ class SeedGenerator(object):
in the range x = [0, 3], y = [3, 6], z = [6, 9].
"""
def __init__(self, data, voxres,
space=Space('vox'), origin=Origin('center')):
space=Space('vox'), origin=Origin('center'), n_repeats=1):
"""
Parameters
----------
data: np.array
data: np.ndarray
The data, ex, loaded from nibabel img.get_fdata(). It will be used
to find all voxels with values > 0, but will not be kept in memory.
voxres: np.array(3,)
voxres: np.ndarray(3,)
The pixel resolution, ex, using img.header.get_zooms()[:3].
n_repeats: int
Number of times a same seed position is returned.
If used, we supposed that calls to either get_next_pos or
get_next_n_pos are used sequentially. Not verified.
"""
self.voxres = voxres

self.n_repeats = n_repeats
self.origin = origin
self.space = space
if space == Space.RASMM:
raise NotImplementedError("We do not support rasmm space.")
elif space not in [Space.VOX, Space.VOXMM]:
raise ValueError("Space should be a choice of Dipy Space.")
if origin not in [Origin.NIFTI, Origin.TRACKVIS]:
raise ValueError("Origin should be a choice of Dipy Origin.")

# self.seed are all the voxels where a seed could be placed
# (voxel space, origin=corner, int numbers).
self.seeds_vox = np.array(np.where(np.squeeze(data) > 0),
dtype=float).transpose()
self.seeds_vox_corner = np.array(np.where(np.squeeze(data) > 0),
dtype=float).transpose()

if len(self.seeds_vox) == 0:
logging.warning("There are no positive voxels in the seeding "
"mask!")
if len(self.seeds_vox_corner) == 0:
raise ValueError("There are no positive voxels in the seeding "
"mask!")

def get_next_pos(self, random_generator, indices, which_seed):
# We use this to remember last offset if n_repeats > 1:
self.previous_offset = None

def get_next_pos(self, random_generator, shuffled_indices, which_seed):
"""
Generate the next seed position (Space=voxmm, origin=corner).
See self.init()_generator to get the generator and shuffled_indices.
To be used with self.n_repeats, we suppose that sequential
get_next_pos calls are used with sequentials values of which_seed.
Parameters
----------
random_generator : numpy random generator
random_generator: numpy random generator
Initialized numpy number generator.
indices : List
shuffled_indices: np.array
Indices of current seeding map.
which_seed : int
which_seed: int
Seed number to be processed.
(which_seed // self.n_repeats corresponds to the index of the
chosen seed in the flattened seeding mask).
Return
------
seed_pos: tuple
Position of next seed expressed in mm.
"""
len_seeds = len(self.seeds_vox)
if len_seeds == 0:
return []
nb_seed_voxels = len(self.seeds_vox_corner)

# Voxel selection from the seeding mask
ind = which_seed % len_seeds
x, y, z = self.seeds_vox[indices[ind]]

# Subvoxel initial positioning. Right now x, y, z are in vox space,
# origin=corner, so between 0 and 1.
r_x = random_generator.uniform(0, 1)
r_y = random_generator.uniform(0, 1)
r_z = random_generator.uniform(0, 1)
ind = (which_seed // self.n_repeats) % nb_seed_voxels
x, y, z = self.seeds_vox_corner[shuffled_indices[ind]]

if which_seed % self.n_repeats == 0:
# Subvoxel initial positioning. Right now x, y, z are in vox space,
# origin=corner, so between 0 and 1.
r_x, r_y, r_z = random_generator.uniform(0, 1, size=3)
self.previous_offset = (r_x, r_y, r_z)
else:
# Supposing that calls to get_next_pos are used correctly:
# previous_offset should already exist and correspond to the
# correct offset.
r_x, r_y, r_z = self.previous_offset

# Moving inside the voxel
x += r_x
Expand All @@ -95,45 +113,81 @@ def get_next_pos(self, random_generator, indices, which_seed):
else:
raise NotImplementedError("We do not support rasmm space.")

def get_next_n_pos(self, random_generator, indices, which_seeds):
def get_next_n_pos(self, random_generator, shuffled_indices,
which_seed_start, n):
"""
Generate the next n seed positions. Intended for GPU usage.
Equivalent to:
for s in range(which_seed_start, which_seed_start + nb_seeds):
self.get_next_pos(..., s)
See description of get_next_pos for more information.
To be used with self.n_repeats, we suppose that sequential
get_next_n_pos calls are used with sequential values of
which_seed_start (with steps of nb_seeds).
Parameters
----------
random_generator : numpy random generator
random_generator: numpy random generator
Initialized numpy number generator.
indices : numpy array
shuffled_indices: np.array
Indices of current seeding map.
which_seeds : numpy array
Seed numbers (i.e. IDs) to be processed.
which_seed_start: int
First seed numbers to be processed.
(which_seed_start // self.n_repeats corresponds to the index of the
chosen seed in the flattened seeding mask).
n: int
Number of seeds to get.
Return
------
seed_pos: List[List]
seeds: List[List]
Positions of next seeds expressed seed_generator's space and
origin.
"""

len_seeds = len(self.seeds_vox)

if len_seeds == 0:
return []
nb_seed_voxels = len(self.seeds_vox_corner)
which_seeds = np.arange(which_seed_start, which_seed_start + n)

# Voxel selection from the seeding mask
inds = which_seeds % len_seeds
# Same seed is re-used n_repeats times
inds = (which_seeds // self.n_repeats) % nb_seed_voxels

# Sub-voxel initial positioning
# Prepare sub-voxel random movement now (faster out of loop)
n = len(which_seeds)
r_x = random_generator.uniform(0, 1, size=n)
r_y = random_generator.uniform(0, 1, size=n)
r_z = random_generator.uniform(0, 1, size=n)
r_x = np.zeros((n,))
r_y = np.zeros((n,))
r_z = np.zeros((n,))

# Find where which_seeds % self.n_repeats == 0
# Note. If where_new_offsets[0] is False, supposing that calls to
# get_next_n_pos are used correctly: previous_offset should already
# exist and correspond to the correct offset.
where_new_offsets = ~(which_seeds % self.n_repeats).astype(bool)
ind_first = np.where(where_new_offsets)[0][0]
if not where_new_offsets[0]:
assert self.previous_offset is not None

# First continuing previous_offset.
r_x[0:ind_first] = self.previous_offset[0]
r_y[0:ind_first] = self.previous_offset[1]
r_z[0:ind_first] = self.previous_offset[2]

# Then starting and repeating new offsets.
nb_new_offsets = np.sum(where_new_offsets)
new_offsets_x = random_generator.uniform(0, 1, size=nb_new_offsets)
new_offsets_y = random_generator.uniform(0, 1, size=nb_new_offsets)
new_offsets_z = random_generator.uniform(0, 1, size=nb_new_offsets)
nb_r = n - ind_first
r_x[ind_first:] = np.repeat(new_offsets_x, self.n_repeats)[:nb_r]
r_y[ind_first:] = np.repeat(new_offsets_y, self.n_repeats)[:nb_r]
r_z[ind_first:] = np.repeat(new_offsets_z, self.n_repeats)[:nb_r]

# Save previous offset for next batch
self.previous_offset = (r_x[-1], r_y[-1], r_z[-1])

seeds = []
# Looping. toDo, see if can be done faster.
for i in range(len(which_seeds)):
x, y, z = self.seeds_vox[indices[inds[i]]]
x, y, z = self.seeds_vox_corner[shuffled_indices[inds[i]]]

# Moving inside the voxel
x += r_x[i]
Expand All @@ -158,29 +212,40 @@ def get_next_n_pos(self, random_generator, indices, which_seeds):

return seeds

def init_generator(self, random_initial_value, first_seed_of_chunk):
def init_generator(self, rng_seed, numbers_to_skip):
"""
Initialize numpy number generator according to user's parameter
and indexes from the seeding map.
Initialize a numpy number generator according to user's parameters.
Returns also the suffled index of all voxels.
The values are not stored in this classed, but are returned to the
user, who should add them as arguments in the methods
self.get_next_pos()
self.get_next_n_pos()
The use of this is that with multiprocessing, each process may have its
own generator, with less risk of using the wrong one when they are
managed by the user.
Parameters
----------
random_initial_value : int
rng_seed : int
The "seed" for the random generator.
first_seed_of_chunk : int
Number of seeds to skip (skip parameter + multi-processor skip).
numbers_to_skip : int
Number of seeds (i.e. voxels) to skip. Useful if you want to
continue sampling from the same generator as in a first experiment
(with a fixed rng_seed).
Return
------
random_generator : numpy random generator
Initialized numpy number generator.
indices : ndarray
Indices of current seeding map.
Shuffled indices of current seeding map, shuffled with current
generator.
"""
random_generator = np.random.RandomState(random_initial_value)
random_generator = np.random.RandomState(rng_seed)

# 1. Initializing seeding maps indices (shuffling in-place)
indices = np.arange(len(self.seeds_vox))
indices = np.arange(len(self.seeds_vox_corner))
random_generator.shuffle(indices)

# 2. Initializing the random generator
Expand All @@ -189,7 +254,7 @@ def init_generator(self, random_initial_value, first_seed_of_chunk):
# process (i.e this chunk)'s set of random numbers. Producing only
# 100000 at the time to prevent RAM overuse.
# (Multiplying by 3 for x,y,z)
random_numbers_to_skip = first_seed_of_chunk * 3
random_numbers_to_skip = numbers_to_skip * 3
# toDo: see if 100000 is ok, and if we can create something not
# hard-coded
while random_numbers_to_skip > 100000:
Expand Down
10 changes: 10 additions & 0 deletions scilpy/tracking/tests/test_propagator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-


def test_class_propagator():
"""
We will not test the tracker / propagator : too big to be tested, and only
used in scil_compute_local_tracking_dev, which is intented for developping
and testing new parameters.
"""
pass
Loading

0 comments on commit d531467

Please sign in to comment.