diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 5cf464797bf5..218e53a92007 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -450,6 +450,17 @@ def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> {
   let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }];
 }
 
+// Expands the granularity of a mask to subelements.
+def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> {
+  let arguments = (ins AnyVector:$input);
+  let results = (outs AnyVector:$result);
+  let assemblyFormat = [{
+    $input attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+
 def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> {
   let dependentDialects = [
     "::mlir::func::FuncDialect",
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 62cf24045c46..aa2c1d96347c 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -115,6 +115,18 @@ void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                CanonicalizeAddOfMatmul<arith::AddIOp>>(context);
 }
 
+LogicalResult MaskCastOp::verify() {
+  auto input_ty = getInput().getType();
+  auto output_ty = getResult().getType();
+  return success(input_ty.getElementType() == output_ty.getElementType() &&
+                 output_ty.getRank() == 3 &&
+                 (input_ty.getRank() == 2 ||
+                  (input_ty.getRank() == 3 &&
+                   input_ty.getDimSize(2) < output_ty.getDimSize(2))) &&
+                 input_ty.getShape().take_front(2) ==
+                     output_ty.getShape().take_front(2));
+}
+
 }  // namespace tpu
 
 }  // namespace mlir
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index ec8c7ebc41f7..c32f278af0c1 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -1248,7 +1248,10 @@ class VectorLayoutInferer {
   }
 
   LogicalResult inferMatmul(Operation* op) {
-    auto get_unpadded_layout = [&](Value v) -> std::optional<VectorLayout> {
+    auto get_unpadded_layout =
+        [&](Value v, std::optional<int64_t> major_multiple = std::nullopt,
+            std::optional<int64_t> minor_multiple =
+                std::nullopt) -> std::optional<VectorLayout> {
       auto pad = getLayout(v);
       if (!pad.has_value() || pad->implicit_dim() != ImplicitDim::kNone) {
         return std::nullopt;
@@ -1257,8 +1260,9 @@ class VectorLayoutInferer {
       auto tiling = nativeTiling(vty.getElementTypeBitWidth());
       auto shape = vty.getShape().take_back(2);
       if (pad->offsets()[0].value_or(0) != 0 ||
-          pad->offsets()[1].value_or(0) != 0 || shape[0] % tiling[0] != 0 ||
-          shape[1] % tiling[1] != 0) {
+          pad->offsets()[1].value_or(0) != 0 ||
+          shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
+          shape[1] % minor_multiple.value_or(tiling[1]) != 0) {
         return std::nullopt;
       }
       // Override tiling to match the native one.
@@ -1271,11 +1275,12 @@ class VectorLayoutInferer {
                  "only 32-bit matmul results supported");
     std::array<Layout, 3> in_layout;
     CHECK_EQ(op->getNumOperands(), 3);
-    for (int i = 0; i < 3; ++i) {
-      if (auto layout = get_unpadded_layout(op->getOperand(i))) {
-        in_layout[i] = *layout;
-      } else {
-        op->emitOpError("padded operands");
+    in_layout[0] = get_unpadded_layout(op->getOperand(0), std::nullopt, 1);
+    in_layout[1] = get_unpadded_layout(op->getOperand(1), 128, 1);
+    in_layout[2] = get_unpadded_layout(op->getOperand(2), std::nullopt, 1);
+    for (Layout &l : in_layout) {
+      if (!l.has_value()) {
+        op->emitOpError("unsupported operand shapes or layouts");
         return failure();
       }
     }
diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py
index 80fb209eb3e9..9f4654bfb8dd 100644
--- a/jaxlib/mosaic/python/apply_vector_layout.py
+++ b/jaxlib/mosaic/python/apply_vector_layout.py
@@ -2730,43 +2730,99 @@ def _matmul_rule(
   acc_type = ir.VectorType(op.acc.type)
   if type_bitwidth(acc_type.element_type) != 32:
     raise NotImplementedError("non-32-bit matmul result")
+  # The code below puts no constraints on the second dimension of both lhs and
+  # rhs. However, the leading axis of lhs needs to be a multiple of the native
+  # tiling, while the leading axis of rhs needs to be a multiple of 128
+  # (regardless of the transpose mode).
   if lhs_type.shape[0] % layout_lhs.tiling[0] != 0:
-    raise NotImplementedError("layout matmul lhs")
-  if rhs_type.shape[0] % 128 != 0 or rhs_type.shape[1] % 128 != 0:
-    raise NotImplementedError("matmul rhs requires padding")
-  layout_attr = op.attributes["out_layout"]
-  lhs_col_ty = ir.VectorType.get(
-      (lhs_type.shape[0], 128), lhs_type.element_type)
-  acc_col_ty = ir.VectorType.get(
-      (lhs_type.shape[0], 128), acc_type.element_type)
+    raise NotImplementedError("Unsupported LHS shape")
+  if rhs_type.shape[0] % 128 != 0:
+    raise NotImplementedError("Unsupported RHS shape")
+  padded_lhs_rows = _round_up(lhs_type.shape[0], to=layout_lhs.tiling[0])
+  lhs_col_ty = ir.VectorType.get((padded_lhs_rows, 128), lhs_type.element_type)
+  if _round_up(lhs_type.shape[0], to=layout_acc.tiling[0]) != padded_lhs_rows:
+    raise NotImplementedError("matmul acc requires less padding than lhs")
+  acc_col_ty = ir.VectorType.get((padded_lhs_rows, 128), acc_type.element_type)
   lhs_tiles = disassemble(layout_lhs, op.lhs)
   acc_tiles = disassemble(layout_acc, op.acc)
+  assert padded_lhs_rows == lhs_tiles.shape[-2] * layout_lhs.tiling[-2]
+  assert padded_lhs_rows == acc_tiles.shape[-2] * layout_acc.tiling[-2]
   lhs_cols = [tpu.RollVectorsOp(lhs_col_ty, lhs_tiles[:, i])
               for i in range(lhs_tiles.shape[1])]
+  if contraction_rem := lhs_type.shape[1] % 128:
+    i32_vreg = native_vreg_ty(i32())
+    contraction_lane_mask = arith.CmpIOp(
+        arith.CmpIPredicate.slt,
+        tpu.IotaOp(i32_vreg, dimension=1),
+        arith.ConstantOp(
+            i32_vreg,
+            ir.DenseElementsAttr.get_splat(
+                i32_vreg, ir.IntegerAttr.get(i32(), contraction_rem)
+            ),
+        ),
+    ).result
+    def mask_last_lane_contraction_tile(zeros, vreg):
+      mask = contraction_lane_mask
+      if vreg.type.shape != mask.type.shape:
+        mask = tpu.MaskCastOp(
+            ir.VectorType.get(vreg.type.shape, ir.IntegerType.get_signless(1)),
+            mask,
+        )
+      return arith.SelectOp(mask, vreg, zeros)
+    lhs_vreg_type = lhs_tiles.flat[0].type
+    lhs_zeros = arith.ConstantOp(
+        lhs_vreg_type,
+        ir.DenseElementsAttr.get_splat(
+            lhs_vreg_type, get_constant(lhs_vreg_type.element_type, 0)
+        ),
+    )
+    lhs_masked_tiles = np.empty_like(lhs_tiles[:, -1])
+    for idx, vreg in np.ndenumerate(lhs_tiles[:, -1]):
+      lhs_masked_tiles[idx] = mask_last_lane_contraction_tile(lhs_zeros, vreg)
+    lhs_cols[-1] = tpu.RollVectorsOp(lhs_col_ty, lhs_masked_tiles)
+  else:
+    mask_last_lane_contraction_tile = None
+  lhs_layout_attr = ir.ArrayAttr.get([ir.StringAttr.get(print_layout(layout_lhs))])
+  rhs_layout_attr = ir.ArrayAttr.get([ir.StringAttr.get(print_layout(layout_rhs))])
+  acc_layout_attr = ir.ArrayAttr.get([ir.StringAttr.get(print_layout(layout_acc))])
   for col in lhs_cols:
-    col.attributes["out_layout"] = layout_attr
+    col.attributes["out_layout"] = lhs_layout_attr
   rhs_tile_ty = ir.VectorType.get((128, 128), rhs_type.element_type)
   rhs_vregs = disassemble(layout_rhs, op.rhs)
   rhs_vregs_per_tile = 16 // layout_rhs.packing
   if transpose_rhs:
     nj, nk = cdiv(tuple(rhs_type.shape), (128, 128))
-    rhs_tiles = rhs_vregs.reshape((nj, rhs_vregs_per_tile, nk, 1)).transpose(
-        2, 0, 1, 3)
+    rhs_full_tiles = rhs_vregs.reshape(
+        (nj, rhs_vregs_per_tile, nk, 1)
+    ).transpose(2, 0, 1, 3)
   else:
     nk, nj = cdiv(tuple(rhs_type.shape), (128, 128))
-    rhs_tiles = rhs_vregs.reshape((nk, rhs_vregs_per_tile, nj, 1)).transpose(
-        0, 2, 1, 3)
+    rhs_full_tiles = rhs_vregs.reshape(
+        (nk, rhs_vregs_per_tile, nj, 1)
+    ).transpose(0, 2, 1, 3)
 
   precision = None
   if "precision" in op.attributes:
     precision = op.attributes["precision"]
+  rhs_vreg_type = rhs_full_tiles.flat[0].type
+  rhs_zeros = arith.ConstantOp(
+      rhs_vreg_type,
+      ir.DenseElementsAttr.get_splat(
+          rhs_vreg_type, get_constant(rhs_vreg_type.element_type, 0)
+      ),
+  )
   for j, k in np.ndindex((nj, nk)):
-    rhs_tile = rhs_tiles[k, j]
+    rhs_tile = rhs_full_tiles[k, j]
     assert rhs_tile.shape == (rhs_vregs_per_tile, 1)
+    if mask_last_lane_contraction_tile is not None and k == nk - 1:
+      rhs_masked_tile = np.empty_like(rhs_tile)
+      for idx, vreg in np.ndenumerate(rhs_tile):
+        rhs_masked_tile[idx] = mask_last_lane_contraction_tile(rhs_zeros, vreg)
+      rhs_tile = rhs_masked_tile
     rhs_rolled_tile = tpu.RollVectorsOp(rhs_tile_ty, list(rhs_tile.flat))
-    rhs_rolled_tile.attributes["out_layout"] = layout_attr
+    rhs_rolled_tile.attributes["out_layout"] = rhs_layout_attr
     acc_col = tpu.RollVectorsOp(acc_col_ty, acc_tiles[:, j])
-    acc_col.attributes["out_layout"] = layout_attr
+    acc_col.attributes["out_layout"] = acc_layout_attr
     new_acc_col = tpu.MatmulOp(
         acc_col_ty, lhs_cols[k], rhs_rolled_tile, acc_col,
         transpose_lhs=transpose_lhs,
@@ -2775,7 +2831,7 @@ def _matmul_rule(
     )
     new_acc_tiles = tpu.UnrollVectorsOp([v.type for v in acc_tiles[:, j]],
                                         new_acc_col)
-    new_acc_tiles.attributes["in_layout"] = layout_attr
+    new_acc_tiles.attributes["in_layout"] = acc_layout_attr
     acc_tiles[:, j] = new_acc_tiles.results
 
   return ctx.replace(op, assemble(op.result.type, layout_out, acc_tiles))
@@ -3172,9 +3228,9 @@ def get_constant(ty: ir.Type, value: Union[int, float]) -> ir.Attribute:
   elif ty == ir.IndexType.get():
     return ir.IntegerAttr.get(ty, value)
   elif ty == ir.BF16Type.get():
-    return ir.FloatAttr.get_bf16(value)
+    return ir.FloatAttr.get(ty, value)
   elif ty == ir.F32Type.get():
-    return ir.FloatAttr.get_f32(value)
+    return ir.FloatAttr.get(ty, value)
   raise NotImplementedError(ty)
 
 
@@ -3225,3 +3281,8 @@ def get_memref_tiling(value: ir.Value) -> tuple[int, int]:
       raise NotImplementedError
     return first_tiles
   raise NotImplementedError((first_tiles, *other_tiles))
+
+
+def _round_up(x: int, to: int):
+  assert x >= 0
+  return ((x + to - 1) // to) * to
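
Reviewer note (not part of the patch): the MaskCastOp verifier accepts a rank-2 mask widening into a rank-3 one whose two leading dimensions match. As a mental model, here is a minimal NumPy sketch; the function name is hypothetical and the replication semantics are an assumption about what "expanding granularity to subelements" means, not something the patch states:

    import numpy as np

    def mask_cast_model(mask, subelements):
      # Assumed semantics: expand a (sublanes, lanes) bool mask to
      # (sublanes, lanes, subelements) by replicating each lane's bit, so a
      # 32-bit lane mask can guard packed 16- or 8-bit subelements.
      assert mask.ndim == 2 and mask.dtype == np.bool_
      return np.repeat(mask[:, :, np.newaxis], subelements, axis=2)

    wide = mask_cast_model(np.ones((8, 128), dtype=np.bool_), 2)
    assert wide.shape == (8, 128, 2)  # rank 3, leading two dims preserved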
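A second note, on the masking path in _matmul_rule: when the contraction size is not a multiple of 128, the rule zeroes the lanes past lhs_type.shape[1] % 128 in the last contraction tile of both operands, so padding lanes contribute nothing to the accumulated products. A runnable NumPy sketch of that arithmetic (illustrative only; the patch builds the equivalent out of tpu.iota, arith.cmpi, and arith.select on vregs):

    import numpy as np

    def mask_contraction_tail(tile, contraction_rem):
      # Keep the first contraction_rem lanes of a 128-lane tile, zero the rest.
      lane_ids = np.arange(tile.shape[-1])   # models tpu.iota along dimension 1
      keep = lane_ids < contraction_rem      # models arith.cmpi slt
      return np.where(keep, tile, 0)         # models arith.select against zeros

    # A tile where only the first 5 contraction lanes hold real data:
    tile = np.arange(8 * 128, dtype=np.float32).reshape(8, 128)
    masked = mask_contraction_tail(tile, 5)
    assert (masked[:, :5] == tile[:, :5]).all() and (masked[:, 5:] == 0).all()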
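Finally, on the new helper: _round_up drives the padded_lhs_rows computation above. Copying its definition verbatim from the patch, a few worked examples:

    def _round_up(x: int, to: int):
      assert x >= 0
      return ((x + to - 1) // to) * to

    assert _round_up(10, to=8) == 16   # e.g. pad 10 LHS rows up to an 8-row tile
    assert _round_up(16, to=8) == 16   # already aligned values are unchanged
    assert _round_up(0, to=128) == 0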