Skip to content

Commit

Permalink
Try again to fix the flags.writeable problem
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Dec 18, 2023
1 parent 88f04fe commit aa7ba0f
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions scico/linop/xray/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _proj(self, x: jax.Array) -> jax.Array:
# apply the forward projector and generate a sinogram

def f(x):
x.setflags(write=True)
x = ensure_writeable(x)
if self.num_dims == 2:
proj_id, result = astra.create_sino(x, self.proj_id)
astra.data2d.delete(proj_id)
Expand All @@ -199,7 +199,7 @@ def f(x):
def _bproj(self, y: jax.Array) -> jax.Array:
# apply backprojector
def f(y):
y.setflags(write=True)
y = ensure_writeable(y)
if self.num_dims == 2:
proj_id, result = astra.create_backprojection(y, self.proj_id)
astra.data2d.delete(proj_id)
Expand Down Expand Up @@ -230,8 +230,7 @@ def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array:

# Just use the CPU FBP alg for now; hitting memory issues with GPU one.
def f(sino):
if sino.flags.writeable == False:
sino = sino.copy()
sino = ensure_writeable(sino)
sino_id = astra.data2d.create("-sino", self.proj_geom, sino)

# create memory for result
Expand All @@ -258,3 +257,14 @@ def f(sino):
return out

return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino)


def ensure_writeable(x):
"""Ensure that `x.flags.writeable` is ``True``, copying if needed."""

if not x.flags.writeable:
try:
x.setflags(write=True)
except ValueError:
x = x.copy()
return x

0 comments on commit aa7ba0f

Please sign in to comment.