diff --git a/src/simfmri/handlers/acquisition/workers.py b/src/simfmri/handlers/acquisition/workers.py index f899806..961ca29 100644 --- a/src/simfmri/handlers/acquisition/workers.py +++ b/src/simfmri/handlers/acquisition/workers.py @@ -230,24 +230,25 @@ def acq_noncartesian( if nufft_backend == "stacked": kwargs["z_index"] = "auto" logger.debug("extra kwargs %s", kwargs) - if "gpunufft" in nufft_backend: - from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps - - smaps = make_pinned_smaps(sim.smaps) - else: - smaps = sim.smaps op_kwargs = dict( shape=sim.shape, n_coils=sim.n_coils, density=False, backend_name=nufft_backend, ) + if "gpunufft" in nufft_backend: + logger.debug("Using gpunufft, pinning smaps") + from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps + + op_kwargs["pinned_smaps"] = make_pinned_smaps(sim.smaps) + op_kwargs["smaps"] = None + else: + op_kwargs["smaps"] = sim.smaps scheduler = kspace_bulk_shot(trajectory_gen, sim.n_frames, n_shot_sim_frame) with Parallel(n_jobs=n_jobs, backend="loky", mmap_mode="r") as par: par( delayed(_single_worker)( sim_frame, - smaps, shot_batch, shot_pos, op_kwargs, @@ -287,7 +288,6 @@ def work_generator(sim: SimData, kspace_bulk_gen: Generator) -> Generator[tuple] def _single_worker( sim_frame: np.ndarray, - smaps: np.ndarray, shot_batch: np.ndarray, shot_pos: tuple[int, int], op_kwargs: Mapping[str, Any], @@ -302,17 +302,7 @@ def _single_worker( category=UserWarning, module="mrinufft", ) - if "gpunufft" in [ - op_kwargs.get("backend_name", None), - op_kwargs.get("nufft_backend", None), - ]: - fourier_op = get_operator( - samples=shot_batch, - pinned_smaps=smaps, - **op_kwargs, - ) - else: - fourier_op = get_operator(samples=shot_batch, smaps=smaps, **op_kwargs) + fourier_op = get_operator(samples=shot_batch, **op_kwargs) kspace = fourier_op.op(sim_frame) L = shot_batch.shape[1]