Skip to content

Commit

Permalink
clean up SVRG code (remove direct python calls)
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Aug 30, 2024
1 parent 9d296fd commit 94d1ee3
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 89 deletions.
44 changes: 23 additions & 21 deletions main_SVRG.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
verbose: bool = False,
complete_gradient_epochs: None | list[int] = None,
precond_update_epochs: None | list[int] = None,
precond_hessian_factor: float = 16.0,
**kwargs,
):
"""
Expand All @@ -79,6 +80,7 @@ def __init__(
self._update = 0
self._step_size = initial_step_size
self._subset_number_list = []
self._precond_hessian_factor = precond_hessian_factor

self._data_sub, self._acq_models, self._obj_funs = partitioner.data_partition(
data.acquired_data,
Expand All @@ -98,15 +100,12 @@ def __init__(
for f in self._obj_funs: # add prior evenly to every objective function
f.set_prior(data.prior)

self._subset_adjoint_ones = []
self._adjoint_ones = self.x.get_uniform_copy(0)

for i in range(num_subsets):
if self._verbose:
print(f"Calculating subset {i} sensitivity")
subset_adjoint_ones = self._obj_funs[i].get_subset_sensitivity(0)
self._subset_adjoint_ones.append(subset_adjoint_ones)

self._adjoint_ones = np.sum(self._subset_adjoint_ones)
self._adjoint_ones += self._obj_funs[i].get_subset_sensitivity(0)

self._fov_mask = self.x.get_uniform_copy(0)
tmp = 1.0 * (self._adjoint_ones.as_array() > 0)
Expand Down Expand Up @@ -143,6 +142,10 @@ def __init__(
self._python_prior.kappa = data.kappa.as_array()
self._python_prior.scale = data.prior.get_penalisation_factor()

self._precond_filter = STIR.SeparableGaussianImageFilter()
self._precond_filter.set_fwhms([5.0, 5.0, 5.0])
self._precond_filter.set_up(data.OSEM_image)

# calculate the initial preconditioner based on the initial image
self._precond = self.calc_precond(self.x)

Expand All @@ -157,22 +160,23 @@ def calc_precond(
self,
x: STIR.ImageData,
delta_rel: float = 1e-6,
prior_diag_factor: float = 16.0,
) -> STIR.ImageData:

# generate a smoothed version of the input image
# to avoid high values, especially in first and last slices
xx = x.get_uniform_copy(0)
sig = 5.0 / (2.35 * np.array(xx.spacing))
xx.fill(gaussian_filter(x.as_array(), sig))

delta = delta_rel * xx.max()
prior_diag_hess = xx.get_uniform_copy(0)

prior_diag_hess.fill(self._python_prior.diag_hessian(xx.as_array()))

precond = (xx + delta) / (
self._adjoint_ones + prior_diag_factor * prior_diag_hess * xx
x_sm = self._precond_filter.process(x)
delta = delta_rel * x_sm.max()

prior_diag_hess = x_sm.get_uniform_copy(0)
prior_diag_hess.fill(self._python_prior.diag_hessian(x_sm.as_array()))

precond = (
self._fov_mask
* (x_sm + delta)
/ (
self._adjoint_ones
+ self._precond_hessian_factor * prior_diag_hess * x_sm
)
)

return precond
Expand Down Expand Up @@ -226,12 +230,10 @@ def update(self):
)

### Objective has to be maximized -> "+" for gradient ascent
self.x = self.x + self._step_size * self._precond * self._fov_mask * grad
self.x = self.x + self._step_size * self._precond * grad

# enforce non-negative constraint
tmp = self.x.as_array()
np.clip(tmp, 0, None, out=tmp)
self.x.fill(tmp)
self.x.maximum(0, out=self.x)

self._update += 1

Expand Down
131 changes: 63 additions & 68 deletions test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def test_petric(
metric_period: int,
complete_gradient_epochs: None | list[int] = None,
precond_update_epochs: None | list[int] = None,
precond_hessian_factor: float = 16.0,
):

# get arguments and values such that we can dump them in the outdir
Expand All @@ -346,39 +347,25 @@ def test_petric(
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")

sdir_name = f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}_phf_{precond_hessian_factor}"

if ds == 0:
srcdir = SRCDIR / "Siemens_mMR_NEMA_IQ"
outdir = (
OUTDIR
/ "mMR_NEMA"
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
)
outdir = OUTDIR / "mMR_NEMA" / sdir_name
metrics = [
MetricsWithTimeout(outdir=outdir, transverse_slice=72, coronal_slice=109)
]
elif ds == 1:
srcdir = SRCDIR / "NeuroLF_Hoffman_Dataset"
outdir = (
OUTDIR
/ "NeuroLF_Hoffman"
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
)
outdir = OUTDIR / "NeuroLF_Hoffman" / sdir_name
metrics = [MetricsWithTimeout(outdir=outdir, transverse_slice=72)]
elif ds == 2:
srcdir = SRCDIR / "Siemens_Vision600_thorax"
outdir = (
OUTDIR
/ "Vision600_thorax"
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
)
outdir = OUTDIR / "Vision600_thorax" / sdir_name
metrics = [MetricsWithTimeout(outdir=outdir)]
elif ds == 3:
srcdir = SRCDIR / "Siemens_mMR_ACR"
outdir = (
OUTDIR
/ "Siemens_mMR_ACR"
/ f"{formatted_datetime}_ss_{step_size}_n_{num_iter}_subs_{num_subsets}"
)
outdir = OUTDIR / "Siemens_mMR_ACR" / sdir_name
metrics = [MetricsWithTimeout(outdir=outdir)]
else:
raise ValueError(f"Unknown data set {ds}")
Expand Down Expand Up @@ -446,6 +433,7 @@ def test_petric(
num_subsets=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
precond_hessian_factor=precond_hessian_factor,
)
algo.run(num_iter, callbacks=metrics + submission_callbacks)

Expand Down Expand Up @@ -489,51 +477,58 @@ def test_petric(
precond_update_epochs=precond_update_epochs,
)
else:
for step_size in [1.0, 1.5]:
# data set 0 "mMR_NEMA_IQ" - num views 252
for num_subsets in [28]:
test_petric(
step_size=step_size,
ds=0,
num_iter=300,
num_subsets=num_subsets,
metric_period=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)

# data set 1 "neuro LF" - num views 128
for num_subsets in [16]:
test_petric(
step_size=step_size,
ds=1,
num_iter=300,
num_subsets=num_subsets,
metric_period=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)

# data set 2 "vision" - num views 50
for num_subsets in [25]:
test_petric(
step_size=step_size,
ds=2,
num_iter=200,
num_subsets=num_subsets,
metric_period=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)

# data set 4 "mMR_ACR" - num views 252
for num_subsets in [28]:
test_petric(
step_size=1.0,
ds=3,
num_iter=3 * 28 + 1,
num_subsets=num_subsets,
metric_period=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
)
# for phf in [8.0, 32.0, 4.0]:
for phf in [16]:
for step_size in [1.0]:
# data set 0 "mMR_NEMA_IQ" - num views 252
for num_subsets in [28]:
test_petric(
step_size=step_size,
ds=0,
num_iter=300,
num_subsets=num_subsets,
metric_period=num_subsets,
complete_gradient_epochs=complete_gradient_epochs,
precond_update_epochs=precond_update_epochs,
precond_hessian_factor=phf,
)

# # data set 1 "neuro LF" - num views 128
# for num_subsets in [32]:
# test_petric(
# step_size=step_size,
# ds=1,
# num_iter=300,
# num_subsets=num_subsets,
# metric_period=num_subsets,
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# precond_hessian_factor=phf,
# )
#
# # data set 2 "vision" - num views 50
# for num_subsets in [25]:
# test_petric(
# step_size=step_size,
# ds=2,
# num_iter=300,
# num_subsets=num_subsets,
# metric_period=num_subsets,
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# precond_hessian_factor=phf,
# )
#
# # data set 4 "mMR_ACR" - num views 252
# for num_subsets in [28]:
# test_petric(
# step_size=step_size,
# ds=3,
# num_iter=300,
# num_subsets=num_subsets,
# metric_period=num_subsets,
# complete_gradient_epochs=complete_gradient_epochs,
# precond_update_epochs=precond_update_epochs,
# precond_hessian_factor=phf,
# )
#

0 comments on commit 94d1ee3

Please sign in to comment.