From 23d5c10ff0704f66ad7ec65a8cdcd09bd2420591 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Longeri?=
Date: Thu, 5 Dec 2024 11:37:42 -0800
Subject: [PATCH] [Mosaic:TPU] Fix fully replicated relayout

It was incorrect since batch dims are not replicated

PiperOrigin-RevId: 703189919
---
 .../tpu/transforms/apply_vector_layout.cc | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 50a7d57346a6..5cbb5e620c88 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -6588,15 +6588,20 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
                                     /*use_implicit_shape=*/true);
   }
   if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
-      !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
+      !src.offsets()[1].has_value()) {
     // A fully replicated value is always easy to relayout
-    // It would be nice to be able to assert this here, but given replicated
-    // values our rules can introduce equivalent expressions.
-    // assert all(t is src_tiles_list[0] for t in src_tiles_list)
     xla::Array<Value> dst_tiles(
-        /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape),
-        /*value=*/src_tiles.data()[0]);
-    return assemble_with_mask_check(dst_tiles);
+        dst.tileArrayImplicitShape(vty.getShape(), target_shape));
+    SmallVector<int64_t> idxs;
+    dst_tiles.Each([&](const absl::Span<const int64_t> src_idx, Value *vreg) {
+      idxs.assign(src_idx.begin(), src_idx.end());
+      dst.eraseImplicit(idxs);
+      src.insertImplicit<int64_t>(idxs, 0);
+      *(idxs.end() - 2) = 0;
+      *(idxs.end() - 1) = 0;
+      *vreg = src_tiles(idxs);
+    });
+    return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true);
   }
   // Consider (1,128),-2 -> (8,128). In this case we can change the implicit