Skip to content

Commit

Permalink
refactor: move smaps to op_kwargs.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Dec 5, 2023
1 parent 4dd6a8c commit dea99a4
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions src/simfmri/handlers/acquisition/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,24 +230,25 @@ def acq_noncartesian(
if nufft_backend == "stacked":
kwargs["z_index"] = "auto"
logger.debug("extra kwargs %s", kwargs)
if "gpunufft" in nufft_backend:
from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps

smaps = make_pinned_smaps(sim.smaps)
else:
smaps = sim.smaps
op_kwargs = dict(
shape=sim.shape,
n_coils=sim.n_coils,
density=False,
backend_name=nufft_backend,
)
if "gpunufft" in nufft_backend:
logger.debug("Using gpunufft, pinning smaps")
from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps

op_kwargs["pinned_smaps"] = make_pinned_smaps(sim.smaps)
op_kwargs["smaps"] = None
else:
op_kwargs["smaps"] = sim.smaps
scheduler = kspace_bulk_shot(trajectory_gen, sim.n_frames, n_shot_sim_frame)
with Parallel(n_jobs=n_jobs, backend="loky", mmap_mode="r") as par:
par(
delayed(_single_worker)(
sim_frame,
smaps,
shot_batch,
shot_pos,
op_kwargs,
Expand Down Expand Up @@ -287,7 +288,6 @@ def work_generator(sim: SimData, kspace_bulk_gen: Generator) -> Generator[tuple]

def _single_worker(
sim_frame: np.ndarray,
smaps: np.ndarray,
shot_batch: np.ndarray,
shot_pos: tuple[int, int],
op_kwargs: Mapping[str, Any],
Expand All @@ -302,17 +302,7 @@ def _single_worker(
category=UserWarning,
module="mrinufft",
)
if "gpunufft" in [
op_kwargs.get("backend_name", None),
op_kwargs.get("nufft_backend", None),
]:
fourier_op = get_operator(
samples=shot_batch,
pinned_smaps=smaps,
**op_kwargs,
)
else:
fourier_op = get_operator(samples=shot_batch, smaps=smaps, **op_kwargs)
fourier_op = get_operator(samples=shot_batch, **op_kwargs)
kspace = fourier_op.op(sim_frame)
L = shot_batch.shape[1]

Expand Down

0 comments on commit dea99a4

Please sign in to comment.