Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Removed unnecessarily strict check in `emit_pipel…
Browse files Browse the repository at this point in the history
…ine`

PiperOrigin-RevId: 703117465
  • Loading branch information
superbobry authored and Google-ML-Automation committed Dec 5, 2024
1 parent 5fe5206 commit 4a41aa0
Showing 1 changed file with 0 additions and 11 deletions.
11 changes: 0 additions & 11 deletions jax/_src/pallas/mosaic_gpu/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,6 @@ def emit_pipeline(
delay_release = 0 # No need to delay anything.

def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
if any(
spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore
for idx in range(1, len(grid) + 1)
if spec.block_shape is not None
):
raise NotImplementedError(
f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
f" shape {spec.block_shape}."
)

in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
in_smem_refs, out_smem_refs = util.split_list(
[
Expand Down

0 comments on commit 4a41aa0

Please sign in to comment.