From aa7ba0fcc6cde2b6e44ac6c9b378ccd0ff825da4 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Mon, 18 Dec 2023 09:40:45 -0700 Subject: [PATCH] Try again to fix the flags.writeable problem --- scico/linop/xray/astra.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 8c97ac69f..9f39d994e 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -185,7 +185,7 @@ def _proj(self, x: jax.Array) -> jax.Array: # apply the forward projector and generate a sinogram def f(x): - x.setflags(write=True) + x = ensure_writeable(x) if self.num_dims == 2: proj_id, result = astra.create_sino(x, self.proj_id) astra.data2d.delete(proj_id) @@ -199,7 +199,7 @@ def f(x): def _bproj(self, y: jax.Array) -> jax.Array: # apply backprojector def f(y): - y.setflags(write=True) + y = ensure_writeable(y) if self.num_dims == 2: proj_id, result = astra.create_backprojection(y, self.proj_id) astra.data2d.delete(proj_id) @@ -230,8 +230,7 @@ def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array: # Just use the CPU FBP alg for now; hitting memory issues with GPU one. def f(sino): - if sino.flags.writeable == False: - sino = sino.copy() + sino = ensure_writeable(sino) sino_id = astra.data2d.create("-sino", self.proj_geom, sino) # create memory for result @@ -258,3 +257,14 @@ def f(sino): return out return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino) + + +def ensure_writeable(x): + """Ensure that `x.flags.writeable` is ``True``, copying if needed.""" + + if not x.flags.writeable: + try: + x.setflags(write=True) + except ValueError: + x = x.copy() + return x