diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 5838451b8a6e..6f1ee9cbe5f1 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -774,6 +774,21 @@ llvm::hash_code hash_value(const VectorLayout& layout) { return llvm::hash_value(layout.as_tuple()); } +std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim) { + switch (dim) { + case VectorLayout::ImplicitDim::kNone: + os << "none"; + break; + case VectorLayout::ImplicitDim::kMinor: + os << "-1"; + break; + case VectorLayout::ImplicitDim::kSecondMinor: + os << "-2"; + break; + } + return os; +} + std::optional parseLayout(mlir::AsmParser& parser) { std::string layout_str; if (failed(parser.parseString(&layout_str))) { diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index db93b5eb9a52..e7b4e1e045f4 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -326,6 +326,7 @@ std::ostream &operator<<(std::ostream &os, const Layout &v); llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v); llvm::hash_code hash_value(const VectorLayout &layout); mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v); +std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim); std::optional parseLayout(mlir::AsmParser &parser); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8e4cac2983c8..fba18cca938f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2209,7 +2209,8 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, } const std::array allow_replicated = {!reduces[0], !reduces[1]}; - if (!src_layout.hasNativeTiling(ctx.target_shape)) { + if ((reduces[0] || reduces[1]) && + !src_layout.hasNativeTiling(ctx.target_shape)) { return multi_reduction_op.emitOpError( "Not implemented: Unsupported input layout: ") << src_layout; @@ -2243,11 +2244,13 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, dst_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; // Anything works. } else if (reduces[0]) { + CHECK_EQ(src_layout.implicit_dim(), VectorLayout::ImplicitDim::kNone); dst_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; } else if (reduces[1]) { + CHECK_EQ(src_layout.implicit_dim(), VectorLayout::ImplicitDim::kNone); dst_implicit_dim = VectorLayout::ImplicitDim::kMinor; } else { - dst_implicit_dim = VectorLayout::ImplicitDim::kNone; + dst_implicit_dim = src_layout.implicit_dim(); } if (dst_layout.implicit_dim() != dst_implicit_dim) { return multi_reduction_op.emitOpError( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index ccf10c09f0c1..73b76b68cdc8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -44,6 +44,7 @@ limitations under the License. #include "absl/log/log.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/include/mlir/IR/Attributes.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" @@ -956,34 +957,65 @@ class VectorLayoutInferer { auto src_ty = op.getSourceVectorType(); auto dst_ty = dyn_cast(op.getDestType()); TPU_CHECK_OP(dst_ty, "only reductions with vector results supported"); - TPU_CHECK_OP(src_ty.getRank() == dst_ty.getRank() + 1, - "only 1D reductions supported"); - int64_t dim = cast(op.getReductionDims()[0]).getInt(); + SmallVector dims; + dims.reserve(op.getReductionDims().size()); + for (Attribute dim_attr : op.getReductionDims()) { + dims.push_back(cast(dim_attr).getInt()); + } int64_t src_rank = src_ty.getRank(); - auto acc_pad = getLayout(op.getAcc()); - TPU_CHECK_OP(is_fully_replicated(acc_pad), + auto acc_layout = getLayout(op.getAcc()); + TPU_CHECK_OP(is_fully_replicated(acc_layout), "only constant accumulators supported"); TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == kNativeBitwidth, "only 32-bit reductions supported"); auto some_src_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_src_layout, "missing vector layout"); auto &src_layout = *some_src_layout; - TPU_CHECK_OP(src_layout.implicit_dim() == ImplicitDim::kNone, - "only 2D layouts supported"); - if (dim == src_rank - 1) { - setLayout( - op, {src_layout, acc_pad}, - VectorLayout(kNativeBitwidth, {src_layout.offsets()[0], std::nullopt}, - default_tiling_, ImplicitDim::kMinor)); - } else if (dim == src_rank - 2) { - setLayout( - op, {src_layout, acc_pad}, - VectorLayout(kNativeBitwidth, {std::nullopt, src_layout.offsets()[1]}, - default_tiling_, ImplicitDim::kSecondMinor)); - } else { - // Reduction happens over the unrolled dimension --- we can keep layout. - setLayout(op, {src_layout, acc_pad}, src_layout); + std::array reduces; + switch (src_layout.implicit_dim()) { + case VectorLayout::ImplicitDim::kNone: + reduces = { + std::find(dims.begin(), dims.end(), src_rank - 2) != dims.end(), + std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end()}; + break; + case VectorLayout::ImplicitDim::kSecondMinor: + reduces = {false, std::find(dims.begin(), dims.end(), src_rank - 1) != + dims.end()}; + break; + case VectorLayout::ImplicitDim::kMinor: + reduces = { + std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end(), + false}; + break; + } + if ((reduces[0] || reduces[1]) && + !src_layout.hasNativeTiling(target_shape_)) { + src_layout = VectorLayout(kNativeBitwidth, src_layout.offsets(), + default_tiling_, src_layout.implicit_dim()); + } + LayoutOffsets out_offsets = src_layout.offsets(); + for (int i = 0; i < out_offsets.size(); ++i) { + if (reduces[i]) { + out_offsets[i] = std::nullopt; + } + } + ImplicitDim out_implicit_dim = src_layout.implicit_dim(); + if ((reduces[0] && reduces[1]) || + (src_layout.implicit_dim() != ImplicitDim::kNone && + (reduces[0] || reduces[1]))) { + TPU_CHECK_OP( + dst_ty.getRank() > 0 && *(dst_ty.getShape().end() - 1) == 1, + "Not implemented: reductions over both trailing dimensions are only " + "supported when the resulting value has a trailing axis of size 1"); + out_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; + } else if (reduces[0]) { + out_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor; + } else if (reduces[1]) { + out_implicit_dim = VectorLayout::ImplicitDim::kMinor; } + setLayout(op, {src_layout, acc_layout}, + VectorLayout(src_layout.bitwidth(), out_offsets, + src_layout.tiling(), out_implicit_dim)); return success(); } diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 8fcc119ec4be..d7fb8c9abf6f 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -3060,7 +3060,7 @@ def _vector_multi_reduction_rule( # pylint: disable=missing-function-docstring reduces = TargetTuple((src_rank - 1) in dims, False) allow_replicated = TargetTuple(not reduces.sublanes, not reduces.lanes) - if not src_layout.has_native_tiling: + if any(reduces) and not src_layout.has_native_tiling: raise NotImplementedError("unsupported input layout") if src_layout.tiling != dst_layout.tiling: raise NotImplementedError("tiling shouldn't change") @@ -3082,11 +3082,13 @@ def _vector_multi_reduction_rule( # pylint: disable=missing-function-docstring ) dst_implicit_dim = ImplicitDim.SECOND_MINOR # Whatever works. elif reduces.lanes: + assert src_layout.implicit_dim is None dst_implicit_dim = ImplicitDim.MINOR elif reduces.sublanes: + assert src_layout.implicit_dim is None dst_implicit_dim = ImplicitDim.SECOND_MINOR else: - dst_implicit_dim = None + dst_implicit_dim = src_layout.implicit_dim if dst_implicit_dim != dst_layout.implicit_dim: raise NotImplementedError("unsupported output implicit dim")