Skip to content

Commit

Permalink
add tv to example and tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrhardt committed Oct 18, 2018
1 parent c5838b8 commit 5efd25a
Showing 1 changed file with 60 additions and 42 deletions.
102 changes: 60 additions & 42 deletions examples/Python/PET/interactive/spdhg_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,17 @@ def imshow(image, newfig=True, limits=None, title=''):
# close all plots
plt.close('all')

#%% create OSMAPOSL reconstructor
# This implements the Ordered Subsets Maximum A-Posteriori One Step Late
# Since we are not using a penalty, or prior in this example, it
# defaults to using MLEM, but we will modify it to OSEM
#%% create TBC reconstructor

import pCIL # the code from this module needs to be imported somehow differently
#from pCIL import ZeroFun
from ccpi.optimisation.funcs import ZeroFun, IndicatorBox
from plugins.regularisers import FGP_TV
from ccpi.optimisation.funcs import ZeroFun, IndicatorBox # IndicatorBox does currently not work for SIRF
#from ccpi.filters.regularisers import FGP_TV

# IndicatorBox does currently not work for SIRF
#g_reg = IndicatorBox(lower=0,upper=1)
#x = g_reg.prox(image, 1)


data = acquired_data
background = data.copy()
background.fill(5)
Expand All @@ -147,35 +147,24 @@ def imshow(image, newfig=True, limits=None, title=''):

# the FGP_TV will output a CCPi DataContainer not a SIRF one, so
# we will need to wrap it in something compatible

from plugins.regularisers import FGP_TV # needs to be properly imported
class FGP_TV_SIRF(FGP_TV):
def prox(self, x, sigma):
print("calling FGP")
out = super(FGP_TV, self).prox(x, sigma)
out = super(FGP_TV_SIRF, self).prox(x, sigma)
y = x.copy()
y.fill(out.as_array())
return y

g_reg = FGP_TV_SIRF(lambdaReg=.1,
iterationsTV=1000,
g_reg = FGP_TV_SIRF(lambdaReg=.5,
iterationsTV=200,
tolerance=1e-5,
methodTV=0,
nonnegativity=1,
printing=0,
device='cpu')

g = FGP_TV(lambdaReg=.1,
iterationsTV=1000,
tolerance=1e-5,
methodTV=0,
nonnegativity=1,
printing=0,
device='cpu')

g_reg = IndicatorBox(lower=0,upper=1)


class OperatorInd():
class OperatorSubsetPET():

def __init__(self, op, subset_num, num_subsets):
self.__op__ = op
Expand Down Expand Up @@ -208,12 +197,14 @@ def adjoint(self, x):
num_subsets=self.__num_subsets__)

def allocate_direct(self, x=None):
# y = self.__op__.acq_templ.create_uniform_image()
y = self.__op__.img_templ.copy()
if x is not None:
y.fill(x)
return y

def allocate_adjoint(self, x=None):
# y = self.__op__.acq_templ.get_uniform_copy()
y = pet.AcquisitionData(self.__op__.acq_templ)
if x is not None:
y.fill(x)
Expand All @@ -228,7 +219,6 @@ def sub2sirf(self, x):
return y

def sirf2sub(self, x):
#y = numpy.zeros(len(self.ind))
return x.as_array().flatten()[self.ind]

#class SubsetOperator():
Expand All @@ -250,36 +240,64 @@ def sirf2sub(self, x):
# return OperatorInd(self.__op__, i, len(self))


def SubsetOperator(op, nsubsets):
return [OperatorInd(op, ind, nsubsets) for ind in range(nsubsets)]
def OperatorSubsetsPET(op, num_subsets):
return [OperatorSubsetPET(op, ind, num_subsets)
for ind in range(num_subsets)]

niter = 20
num_subsets = 16
num_epochs = 3
num_iter = num_epochs * num_subsets

A = SubsetOperator(am, 14)
A_norms = [pCIL.PowerMethodNonsquare(Ai, 10, x0=image.copy()) for Ai in A]
A = OperatorSubsetsPET(am, num_subsets)
A_norms = [1.05 * pCIL.PowerMethodNonsquare(Ai, 10, x0=image.copy())
for Ai in A]

# increase the norms to allow for inaccuracies in their computation
Ls = [1.05 * L for L in A_norms]
As = OperatorSubsetsPET(am, num_subsets)
As_norms = [1.05 * pCIL.PowerMethodNonsquare(Ai, 10, x0=image.copy())
for Ai in As]

f = [pCIL.KullbackLeibler(op.sirf2sub(noisy_data), op.sirf2sub(background))
for op in A]

fs = [pCIL.KullbackLeibler(op.sirf2sub(noisy_data), op.sirf2sub(background))
for op in As]


#%%
recon_noreg = pCIL.spdhg(f, g_noreg, A, A_norms=Ls)
recon_subsets_noreg = pCIL.spdhg(fs, g_noreg, As, A_norms=As_norms)

# %%
for i in range(3):
print(recon_noreg.iter)
recon_noreg.update()
for i in range(num_iter):
print(recon_subsets_noreg.iter)
recon_subsets_noreg.update()

#%% does currently not work!
#recon_reg = pCIL.spdhg(f, g_reg, A, A_norms=Ls)
#%%
recon_reg = pCIL.spdhg(f, g_reg, A, A_norms=A_norms)

# %%
#for i in range(niter):
# print(recon_reg)
# recon_reg.update()
for i in range(num_epochs):
print(recon_reg.iter)
recon_reg.update()

#%%
recon_subsets_reg = pCIL.spdhg(fs, g_reg, As, A_norms=As_norms)

# %%
for i in range(num_iter):
print(recon_subsets_reg.iter)
recon_subsets_reg.update()

#%% show result
imshow3(recon_noreg.x, limits=[-0.3,cmax], title='recon noreg')
#imshow3(recon_reg.x, limits=[-0.3,cmax], title='recon reg')
plt.figure()

plt.subplot(1, 3, 1)
imshow3(recon_subsets_noreg.x, limits=[-0.3, cmax],
newfig=False, title='recon subsets noreg')
plt.subplot(1, 3, 2)

imshow3(recon_reg.x, limits=[-0.3, cmax],
newfig=False, title='recon reg')

plt.subplot(1, 3, 3)
imshow3(recon_subsets_reg.x, limits=[-0.3, cmax],
newfig=False, title='recon subsets reg')

0 comments on commit 5efd25a

Please sign in to comment.