Skip to content

Commit

Permalink
add first version of proxSVRG alg.
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Sep 5, 2024
1 parent 65e5730 commit 84bd4e5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
71 changes: 53 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,8 @@ def __init__(
)

penalization_factor = data.prior.get_penalisation_factor()

# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
data.prior.set_penalisation_factor(penalization_factor / self._num_subsets)
data.prior.set_up(data.OSEM_image)

self._subset_prior_fct = data.prior
self._prior = data.prior

self._adjoint_ones = self.x.get_uniform_copy(0)

Expand Down Expand Up @@ -182,6 +178,7 @@ def __init__(
self._precond_filter.set_up(data.OSEM_image)

# calculate the initial preconditioner based on the initial image
self._prior_diag_hess = None
self._precond = self.calc_precond(self.x)

if update_objective_interval is None:
Expand All @@ -190,6 +187,10 @@ def __init__(
super().__init__(update_objective_interval=update_objective_interval, **kwargs)
self.configured = True # required by Algorithm

# prox related parameters
self._num_prox_iter = 5
self._prior_prox_step = 1.0

@property
def epoch(self):
return self._update // self._num_subsets
Expand Down Expand Up @@ -218,8 +219,8 @@ def calc_precond(
x_sm = self._precond_filter.process(x)
delta = delta_rel * x_sm.max()

prior_diag_hess = x_sm.get_uniform_copy(0)
prior_diag_hess.fill(
self._prior_diag_hess = x_sm.get_uniform_copy(0)
self._prior_diag_hess.fill(
to_device(
self._python_prior.diag_hessian(
xp.asarray(x_sm.as_array(), device=self._dev)
Expand All @@ -233,7 +234,7 @@ def calc_precond(
* (x_sm + delta)
/ (
self._adjoint_ones
+ (self._precond_hessian_factor * 2) * prior_diag_hess * x_sm
+ (self._precond_hessian_factor * 2) * self._prior_diag_hess * x_sm
)
)

Expand All @@ -242,10 +243,11 @@ def calc_precond(
def update_all_subset_gradients(self) -> None:

self._summed_subset_gradients = self.x.get_uniform_copy(0)
self._subset_gradients = [
f.gradient(self.x) for f in self._subset_likelihood_funcs
]
self._summed_subset_gradients = sum(self._subset_gradients)
self._subset_gradients = []

for f in self._subset_likelihood_funcs:
self._subset_gradients.append(f.gradient(self.x))
self._summed_subset_gradients += self._subset_gradients[-1]

def update(self):

Expand Down Expand Up @@ -299,17 +301,50 @@ def update(self):
# XXX

tau = self._step_size
T = self._precond
prior_prox_pc = 1 / (1 / T + tau * self._prior_diag_hess)

import pdb
T = self._precond + 1e-6 * (-self._fov_mask + 1)

pdb.set_trace()
prior_prox_pc = (T.power(-1) + tau * self._prior_diag_hess).power(-1)

# self.x = self._prior_prox(tmp, tau=tau, T=T, precond=prior_prox_pc)
self.x = self.approximate_prior_prox(
tmp, tau=tau, T=T, prox_precond=prior_prox_pc
)

self._update += 1

def approximate_prior_prox(
self, z: STIR.ImageData, tau, T: STIR.ImageData, prox_precond: STIR.ImageData
):

u = z.maximum(0)

for k in range(self._num_prox_iter):
# compute gradient step
grad = (u - z) / T + tau * self._prior.gradient(u)
u_new = u - self._prior_prox_step * prox_precond * grad
u_new.maximum(0, out=u_new)

## update step size

# if self._adaptive_step_size:
# diff_new = xp.linalg.norm(u_new - u)

# if k == 0:
# u = u_new
# diff = diff_new
# else:
# if diff_new <= diff:
# self._step *= self._up
# u = u_new
# diff = diff_new
# else:
# self._step /= self._down
# else:
# u = u_new

u = u_new

return u

def update_objective(self) -> None:
"""
NB: The objective value is not required by OSEM nor by PETRIC, so this returns `0`.
Expand Down
3 changes: 1 addition & 2 deletions test_petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,8 @@ def test_petric(ds: int, num_iter: int, suffix: str = "", **kwargs):
precond_update_epochs=precond_update_epochs,
)
else:
# for i in range(4):
for i in [0, 1, 3, 2]:
for ns in [25, 50, 10]:
for ns in [25]:
test_petric(
ds=i, num_iter=300, suffix=f"num_sub_{ns}", approx_num_subsets=ns
)

0 comments on commit 84bd4e5

Please sign in to comment.