From 2eca759a9edd5e3b2a2723089ecb4165e56915f4 Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Thu, 7 Mar 2024 16:50:18 +0000 Subject: [PATCH 1/2] Stupid solution --- .../cil/optimisation/functions/Function.py | 18 +++++++++++++----- .../cil/optimisation/operators/Operator.py | 3 ++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/functions/Function.py b/Wrappers/Python/cil/optimisation/functions/Function.py index fc454bb1dc..d34d07225c 100644 --- a/Wrappers/Python/cil/optimisation/functions/Function.py +++ b/Wrappers/Python/cil/optimisation/functions/Function.py @@ -124,9 +124,12 @@ def proximal_conjugate(self, x, tau, out=None): DataContainer, the value of the proximal operator of the convex conjugate at point :math:`x` for scalar :math:`\tau` or None if `out`. """ - if id(x)==id(out): - raise InPlaceError(message= "The proximal_conjugate of a CIL function cannot be used in place") + # if id(x)==id(out): + # raise InPlaceError(message= "The proximal_conjugate of a CIL function cannot be used in place") + if id(x)==id(out): + x=x.copy() + try: tmp = x x.divide(tau, out=tmp) @@ -146,6 +149,7 @@ def proximal_conjugate(self, x, tau, out=None): if out is None: return val + # Algebra for Function Class @@ -534,16 +538,19 @@ def proximal_conjugate(self, x, tau, out=None): DataContainer, the proximal conjugate operator for the function evaluated at :math:`x` and :math:`\tau` or `None` if `out`. """ - if out is not None and id(x)==id(out): - raise InPlaceError + # if out is not None and id(x)==id(out): + # raise InPlaceError + if id(x)==id(out): + x = x.copy() + try: tmp = x x.divide(tau, out=tmp) except TypeError: tmp = x.divide(tau, dtype=np.float32) - if out is None: + if out is None or id(x)==id(out): val = self.function.proximal(tmp, self.scalar/tau) else: self.function.proximal(tmp, self.scalar/tau, out=out) @@ -556,6 +563,7 @@ def proximal_conjugate(self, x, tau, out=None): if out is None: return val + out = val class SumScalarFunction(SumFunction): diff --git a/Wrappers/Python/cil/optimisation/operators/Operator.py b/Wrappers/Python/cil/optimisation/operators/Operator.py index 7a6e4af309..e361e02a83 100644 --- a/Wrappers/Python/cil/optimisation/operators/Operator.py +++ b/Wrappers/Python/cil/optimisation/operators/Operator.py @@ -137,7 +137,8 @@ def domain(self): @property def range(self): return self.range_geometry() - + + def __rmul__(self, scalar): '''Defines the multiplication by a scalar on the left From 6b2bef9e7d6d521faaee4c3d38d1181e0083e6b4 Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Tue, 1 Oct 2024 10:02:02 +0000 Subject: [PATCH 2/2] A non-satisfactory fix --- .../cil/optimisation/functions/TotalVariation.py | 13 +++++++------ Wrappers/Python/test/test_out_in_place.py | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/functions/TotalVariation.py b/Wrappers/Python/cil/optimisation/functions/TotalVariation.py index 6c438c76a8..316b05d69c 100644 --- a/Wrappers/Python/cil/optimisation/functions/TotalVariation.py +++ b/Wrappers/Python/cil/optimisation/functions/TotalVariation.py @@ -254,10 +254,6 @@ def __call__(self, x): def proximal(self, x, tau, out=None): r""" Returns the proximal operator of the TotalVariation function at :code:`x` .""" - if id(x)==id(out): - raise InPlaceError(message="TotalVariation.proximal cannot be used in place") - - if self.strong_convexity_constant > 0: strongly_convex_factor = (1 + tau * self.strong_convexity_constant) @@ -266,7 +262,8 @@ def proximal(self, x, tau, out=None): solution = self._fista_on_dual_rof(x, tau, out=out) if self.strong_convexity_constant > 0: - x *= strongly_convex_factor + if id(x) != id(solution): + x *= strongly_convex_factor tau *= strongly_convex_factor return solution @@ -306,12 +303,16 @@ def _fista_on_dual_rof(self, x, tau, out=None): if out is None: out = self.gradient_operator.domain_geometry().allocate(0) + if id(x) == id(out): + x_eval= x.copy() + else: + x_eval = x for k in range(self.iterations): t0 = t self.gradient_operator.adjoint(tmp_q, out=out) - out.sapyb(tau_reg_neg, x, 1.0, out=out) + out.sapyb(tau_reg_neg, x_eval, 1.0, out=out) self.projection_C(out, tau=None, out=out) self.gradient_operator.direct(out, out=p1) diff --git a/Wrappers/Python/test/test_out_in_place.py b/Wrappers/Python/test/test_out_in_place.py index d1395864a1..6dfcc63a4d 100644 --- a/Wrappers/Python/test/test_out_in_place.py +++ b/Wrappers/Python/test/test_out_in_place.py @@ -111,6 +111,7 @@ def setUp(self): (WeightedL2NormSquared(weight=b_ig), ig), (TotalVariation(backend='c', warm_start=False, max_iteration=100), ig), (TotalVariation(backend='numpy', warm_start=False, max_iteration=100), ig), + (TotalVariation(backend='numpy', warm_start=False, max_iteration=100, strong_convexity_constant=0.5), ig), (OperatorCompositionFunction(L2NormSquared(), A), ig), (MixedL21Norm(), bg), (SmoothMixedL21Norm(epsilon=0.3), bg),