diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 8c97ac69f..0d6cff41c 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -185,7 +185,11 @@ def _proj(self, x: jax.Array) -> jax.Array: # apply the forward projector and generate a sinogram def f(x): - x.setflags(write=True) + if not x.flags.writeable: + try: + x.setflags(write=True) + except ValueError("cannot set WRITEABLE flag"): + x = x.copy() if self.num_dims == 2: proj_id, result = astra.create_sino(x, self.proj_id) astra.data2d.delete(proj_id) @@ -199,7 +203,11 @@ def f(x): def _bproj(self, y: jax.Array) -> jax.Array: # apply backprojector def f(y): - y.setflags(write=True) + if not y.flags.writeable: + try: + y.setflags(write=True) + except ValueError("cannot set WRITEABLE flag"): + y = y.copy() if self.num_dims == 2: proj_id, result = astra.create_backprojection(y, self.proj_id) astra.data2d.delete(proj_id)