From 4a41aa0a46085f95437bf9853c1786836c3d2321 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 5 Dec 2024 08:03:37 -0800 Subject: [PATCH] [pallas:mosaic_gpu] Removed unnecessarily strict check in `emit_pipeline` PiperOrigin-RevId: 703117465 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 069b8d9e78d3..ee3f03f1849f 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -181,17 +181,6 @@ def emit_pipeline( delay_release = 0 # No need to delay anything. def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): - for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)): - if any( - spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore - for idx in range(1, len(grid) + 1) - if spec.block_shape is not None - ): - raise NotImplementedError( - f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block" - f" shape {spec.block_shape}." - ) - in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( [