diff --git a/pylops/basicoperators/restriction.py b/pylops/basicoperators/restriction.py index 2595c971..c2fceab7 100644 --- a/pylops/basicoperators/restriction.py +++ b/pylops/basicoperators/restriction.py @@ -12,7 +12,7 @@ get_array_module, get_normalize_axis_index, inplace_set, - to_cupy_conditional, + to_numpy, ) from pylops.utils.typing import DTypeLike, InputDimsLike, IntNDArray, NDArray @@ -146,7 +146,7 @@ def __init__( # explicitly create a list of indices in the n-dimensional # model space which will be used in _rmatvec to place the input if ncp != np: - self.iavamask = _compute_iavamask(self.dims, axis, iava, ncp) + self.iavamask = _compute_iavamask(self.dims, axis, to_numpy(iava), ncp) self.inplace = inplace self.axis = axis self.iavareshape = iavareshape @@ -173,7 +173,6 @@ def _rmatvec(self, x: NDArray) -> NDArray: ) else: if not hasattr(self, "iavamask"): - self.iava = to_cupy_conditional(x, self.iava) self.iavamask = _compute_iavamask(self.dims, self.axis, self.iava, ncp) y = ncp.zeros(int(self.shape[-1]), dtype=self.dtype) y = inplace_set(x.ravel(), y, self.iavamask)