Relax shape constraints on matmul unrolling
Previously we required both the contraction dimension and the RHS tensor dimension
to be multiples of 128. The new code allows an arbitrary contraction dimension for
RHS-transposed matmuls and an arbitrary RHS tensor dimension for non-transposed matmuls.

PiperOrigin-RevId: 568180430
apaszke authored and jax authors committed Sep 25, 2023
1 parent 06e8725 commit c478282
Showing 4 changed files with 117 additions and 27 deletions.
11 changes: 11 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -450,6 +450,17 @@ def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> {
   let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }];
 }
 
+// Expands the granularity of a mask to subelements.
+def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> {
+  let arguments = (ins AnyVector:$input);
+  let results = (outs AnyVector:$result);
+  let assemblyFormat = [{
+    $input attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 
 def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> {
   let dependentDialects = [
     "::mlir::func::FuncDialect",
13 changes: 13 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -115,6 +115,19 @@ void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                CanonicalizeAddOfMatmul<arith::AddIOp>>(context);
 }
 
+LogicalResult MaskCastOp::verify() {
+  auto input_ty = getInput().getType();
+  auto output_ty = getResult().getType();
+  return success(input_ty.getElementType() == output_ty.getElementType() &&
+                 output_ty.getRank() == 3 &&
+                 (input_ty.getRank() == 2 ||
+                  (input_ty.getRank() == 3 &&
+                   input_ty.getDimSize(2) < output_ty.getDimSize(2))) &&
+                 input_ty.getShape().take_front(2) ==
+                     output_ty.getShape().take_front(2));
+}
+
 }  // namespace tpu
 }  // namespace mlir

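Concretely, the verifier accepts casts like `%dst = tpu.mask_cast %src : vector<8x128xi1> -> vector<8x128x2xi1>` (the rendered form of the assembly format above; shapes illustrative). A standalone restatement of the shape rules, with the element-type equality check omitted — a sketch, not code from this commit:

```python
def mask_cast_shapes_ok(in_shape: tuple, out_shape: tuple) -> bool:
  # Mirrors MaskCastOp::verify: result is 3-D; input is 2-D, or 3-D with
  # strictly fewer subelements; the two leading (tile) dims must match.
  return (
      len(out_shape) == 3
      and (len(in_shape) == 2
           or (len(in_shape) == 3 and in_shape[2] < out_shape[2]))
      and in_shape[:2] == out_shape[:2]
  )

assert mask_cast_shapes_ok((8, 128), (8, 128, 2))      # expand a 2-D mask
assert mask_cast_shapes_ok((8, 128, 1), (8, 128, 4))   # widen subelements
assert not mask_cast_shapes_ok((4, 128), (8, 128, 2))  # leading dims differ
```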
21 changes: 13 additions & 8 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -1248,7 +1248,10 @@ class VectorLayoutInferer {
   }
 
   LogicalResult inferMatmul(Operation* op) {
-    auto get_unpadded_layout = [&](Value v) -> std::optional<VectorLayout> {
+    auto get_unpadded_layout =
+        [&](Value v, std::optional<int64_t> major_multiple = std::nullopt,
+            std::optional<int64_t> minor_multiple =
+                std::nullopt) -> std::optional<VectorLayout> {
       auto pad = getLayout(v);
       if (!pad.has_value() || pad->implicit_dim() != ImplicitDim::kNone) {
         return std::nullopt;
@@ -1257,8 +1260,9 @@
       auto tiling = nativeTiling(vty.getElementTypeBitWidth());
       auto shape = vty.getShape().take_back(2);
       if (pad->offsets()[0].value_or(0) != 0 ||
-          pad->offsets()[1].value_or(0) != 0 || shape[0] % tiling[0] != 0 ||
-          shape[1] % tiling[1] != 0) {
+          pad->offsets()[1].value_or(0) != 0 ||
+          shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
+          shape[1] % minor_multiple.value_or(tiling[1]) != 0) {
         return std::nullopt;
       }
       // Override tiling to match the native one.
@@ -1271,11 +1275,12 @@
                    "only 32-bit matmul results supported");
     std::array<Layout, 3> in_layout;
     CHECK_EQ(op->getNumOperands(), 3);
-    for (int i = 0; i < 3; ++i) {
-      if (auto layout = get_unpadded_layout(op->getOperand(i))) {
-        in_layout[i] = *layout;
-      } else {
-        op->emitOpError("padded operands");
+    in_layout[0] = get_unpadded_layout(op->getOperand(0), std::nullopt, 1);
+    in_layout[1] = get_unpadded_layout(op->getOperand(1), 128, 1);
+    in_layout[2] = get_unpadded_layout(op->getOperand(2), std::nullopt, 1);
+    for (Layout &l : in_layout) {
+      if (!l.has_value()) {
+        op->emitOpError("unsupported operand shapes or layouts");
         return failure();
       }
     }
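The multiples threaded through `get_unpadded_layout` amount to the following rule — a restatement under the usual Mosaic native tilings, (8, 128) for 32-bit and (16, 128) for 16-bit element types, which are assumptions here rather than something spelled out in this diff:

```python
def required_row_multiple(operand_index: int, native_tile_rows: int) -> int:
  # Second-to-last (sublane) dim: lhs and acc keep the native multiple,
  # rhs always needs 128 rows; the lane dim is now unconstrained (% 1).
  return 128 if operand_index == 1 else native_tile_rows  # operand 1 is rhs

assert required_row_multiple(0, 16) == 16   # bf16 lhs: native tiling (16, 128)
assert required_row_multiple(1, 16) == 128  # rhs rows: always 128
assert required_row_multiple(2, 8) == 8     # acc is 32-bit: tiling (8, 128)
```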
99 changes: 80 additions & 19 deletions jaxlib/mosaic/python/apply_vector_layout.py
@@ -2730,43 +2730,99 @@ def _matmul_rule(
   acc_type = ir.VectorType(op.acc.type)
   if type_bitwidth(acc_type.element_type) != 32:
     raise NotImplementedError("non-32-bit matmul result")
+  # The code below puts no constraints on the second dimension of both lhs and
+  # rhs. However, the leading axis of lhs needs to be a multiple of the native
+  # tiling, while the leading axis of rhs needs to be a multiple of 128 (no
+  # matter the transpose mode).
   if lhs_type.shape[0] % layout_lhs.tiling[0] != 0:
-    raise NotImplementedError("layout matmul lhs")
-  if rhs_type.shape[0] % 128 != 0 or rhs_type.shape[1] % 128 != 0:
-    raise NotImplementedError("matmul rhs requires padding")
-  layout_attr = op.attributes["out_layout"]
-  lhs_col_ty = ir.VectorType.get(
-      (lhs_type.shape[0], 128), lhs_type.element_type)
-  acc_col_ty = ir.VectorType.get(
-      (lhs_type.shape[0], 128), acc_type.element_type)
+    raise NotImplementedError("Unsupported LHS shape")
+  if rhs_type.shape[0] % 128 != 0:
+    raise NotImplementedError("Unsupported RHS shape")
+  padded_lhs_rows = _round_up(lhs_type.shape[0], to=layout_lhs.tiling[0])
+  lhs_col_ty = ir.VectorType.get((padded_lhs_rows, 128), lhs_type.element_type)
+  if _round_up(lhs_type.shape[0], to=layout_acc.tiling[0]) != padded_lhs_rows:
+    raise NotImplementedError("matmul acc requires less padding than lhs")
+  acc_col_ty = ir.VectorType.get((padded_lhs_rows, 128), acc_type.element_type)
   lhs_tiles = disassemble(layout_lhs, op.lhs)
   acc_tiles = disassemble(layout_acc, op.acc)
+  assert padded_lhs_rows == lhs_tiles.shape[-2] * layout_lhs.tiling[-2]
+  assert padded_lhs_rows == acc_tiles.shape[-2] * layout_acc.tiling[-2]
   lhs_cols = [tpu.RollVectorsOp(lhs_col_ty, lhs_tiles[:, i])
               for i in range(lhs_tiles.shape[1])]
+  if contraction_rem := lhs_type.shape[1] % 128:
+    i32_vreg = native_vreg_ty(i32())
+    contraction_lane_mask = arith.CmpIOp(
+        arith.CmpIPredicate.slt,
+        tpu.IotaOp(i32_vreg, dimension=1),
+        arith.ConstantOp(
+            i32_vreg,
+            ir.DenseElementsAttr.get_splat(
+                i32_vreg, ir.IntegerAttr.get(i32(), contraction_rem)
+            ),
+        ),
+    ).result
+    def mask_last_lane_contraction_tile(zeros, vreg):
+      mask = contraction_lane_mask
+      if vreg.type.shape != mask.type.shape:
+        mask = tpu.MaskCastOp(
+            ir.VectorType.get(vreg.type.shape, ir.IntegerType.get_signless(1)),
+            mask,
+        )
+      return arith.SelectOp(mask, vreg, zeros)
+    lhs_vreg_type = lhs_tiles.flat[0].type
+    lhs_zeros = arith.ConstantOp(
+        lhs_vreg_type,
+        ir.DenseElementsAttr.get_splat(
+            lhs_vreg_type, get_constant(lhs_vreg_type.element_type, 0)
+        ),
+    )
+    lhs_masked_tiles = np.empty_like(lhs_tiles[:, -1])
+    for idx, vreg in np.ndenumerate(lhs_tiles[:, -1]):
+      lhs_masked_tiles[idx] = mask_last_lane_contraction_tile(lhs_zeros, vreg)
+    lhs_cols[-1] = tpu.RollVectorsOp(lhs_col_ty, lhs_masked_tiles)
+  else:
+    mask_last_lane_contraction_tile = None
+  lhs_layout_attr = ir.ArrayAttr.get([ir.StringAttr.get(print_layout(layout_lhs))])
+  rhs_layout_attr = ir.ArrayAttr.get([ir.StringAttr.get(print_layout(layout_rhs))])
+  acc_layout_attr = ir.ArrayAttr.get([ir.StringAttr.get(print_layout(layout_acc))])
   for col in lhs_cols:
-    col.attributes["out_layout"] = layout_attr
+    col.attributes["out_layout"] = lhs_layout_attr
   rhs_tile_ty = ir.VectorType.get((128, 128), rhs_type.element_type)
   rhs_vregs = disassemble(layout_rhs, op.rhs)
   rhs_vregs_per_tile = 16 // layout_rhs.packing
   if transpose_rhs:
     nj, nk = cdiv(tuple(rhs_type.shape), (128, 128))
-    rhs_tiles = rhs_vregs.reshape((nj, rhs_vregs_per_tile, nk, 1)).transpose(
-        2, 0, 1, 3)
+    rhs_full_tiles = rhs_vregs.reshape(
+        (nj, rhs_vregs_per_tile, nk, 1)
+    ).transpose(2, 0, 1, 3)
   else:
     nk, nj = cdiv(tuple(rhs_type.shape), (128, 128))
-    rhs_tiles = rhs_vregs.reshape((nk, rhs_vregs_per_tile, nj, 1)).transpose(
-        0, 2, 1, 3)
+    rhs_full_tiles = rhs_vregs.reshape(
+        (nk, rhs_vregs_per_tile, nj, 1)
+    ).transpose(0, 2, 1, 3)
 
   precision = None
   if "precision" in op.attributes:
     precision = op.attributes["precision"]
+  rhs_vreg_type = rhs_full_tiles.flat[0].type
+  rhs_zeros = arith.ConstantOp(
+      rhs_vreg_type,
+      ir.DenseElementsAttr.get_splat(
+          rhs_vreg_type, get_constant(rhs_vreg_type.element_type, 0)
+      ),
+  )
   for j, k in np.ndindex((nj, nk)):
-    rhs_tile = rhs_tiles[k, j]
+    rhs_tile = rhs_full_tiles[k, j]
     assert rhs_tile.shape == (rhs_vregs_per_tile, 1)
+    if mask_last_lane_contraction_tile is not None and k == nk - 1:
+      rhs_masked_tile = np.empty_like(rhs_tile)
+      for idx, vreg in np.ndenumerate(rhs_tile):
+        rhs_masked_tile[idx] = mask_last_lane_contraction_tile(rhs_zeros, vreg)
+      rhs_tile = rhs_masked_tile
     rhs_rolled_tile = tpu.RollVectorsOp(rhs_tile_ty, list(rhs_tile.flat))
-    rhs_rolled_tile.attributes["out_layout"] = layout_attr
+    rhs_rolled_tile.attributes["out_layout"] = rhs_layout_attr
     acc_col = tpu.RollVectorsOp(acc_col_ty, acc_tiles[:, j])
-    acc_col.attributes["out_layout"] = layout_attr
+    acc_col.attributes["out_layout"] = acc_layout_attr
     new_acc_col = tpu.MatmulOp(
         acc_col_ty, lhs_cols[k], rhs_rolled_tile, acc_col,
         transpose_lhs=transpose_lhs,
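The remainder-masking added above zeroes the lanes past `K % 128` in the final column of vregs, so the padding lanes contribute nothing to the dot product. A numpy model of that select — illustrative only; the pass emits `tpu.iota`, `arith.cmpi`, and `arith.select` on vregs instead:

```python
import numpy as np

rem = 72                              # e.g. a contraction dim K = 128*n + 72
lane = np.arange(128)[None, :]        # tpu.IotaOp(i32_vreg, dimension=1)
mask = lane < rem                     # arith.CmpIOp(slt, iota, splat(rem))
vreg = np.ones((8, 128), np.float32)  # one 8x128 vreg from the last lhs column
masked = np.where(mask, vreg, 0.0)    # arith.SelectOp(mask, vreg, zeros)
assert (masked[:, rem:] == 0).all() and (masked[:, :rem] == 1).all()
```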
@@ -2775,7 +2831,7 @@
     )
     new_acc_tiles = tpu.UnrollVectorsOp([v.type for v in acc_tiles[:, j]],
                                         new_acc_col)
-    new_acc_tiles.attributes["in_layout"] = layout_attr
+    new_acc_tiles.attributes["in_layout"] = acc_layout_attr
     acc_tiles[:, j] = new_acc_tiles.results
   return ctx.replace(op, assemble(op.result.type, layout_out, acc_tiles))
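The `rhs_full_tiles` reshuffle regroups the flat grid of RHS vregs into (128, 128) tiles indexed by `[k, j]`. A numpy model of the index math for the non-transposed case, assuming fp32 (packing 1, hence 16 vregs of shape (8, 128) per tile) and a (256, 256) RHS:

```python
import numpy as np

nk = nj = 2    # cdiv((256, 256), (128, 128))
per_tile = 16  # 16 // packing; packing == 1 for fp32
vreg_grid = np.arange(32 * 2).reshape(32, 2)  # stand-in for disassemble()
full_tiles = vreg_grid.reshape((nk, per_tile, nj, 1)).transpose(0, 2, 1, 3)
assert full_tiles[1, 0].shape == (per_tile, 1)     # 16 vregs of tile (1, 0)
assert full_tiles[1, 0][0, 0] == vreg_grid[16, 0]  # its first vreg
```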

@@ -3172,9 +3228,9 @@ def get_constant(ty: ir.Type, value: Union[int, float]) -> ir.Attribute:
   elif ty == ir.IndexType.get():
     return ir.IntegerAttr.get(ty, value)
   elif ty == ir.BF16Type.get():
-    return ir.FloatAttr.get_bf16(value)
+    return ir.FloatAttr.get(ty, value)
   elif ty == ir.F32Type.get():
-    return ir.FloatAttr.get_f32(value)
+    return ir.FloatAttr.get(ty, value)
   raise NotImplementedError(ty)


@@ -3225,3 +3281,8 @@ def get_memref_tiling(value: ir.Value) -> tuple[int, int]:
       raise NotImplementedError
     return first_tiles
   raise NotImplementedError((first_tiles, *other_tiles))
+
+
+def _round_up(x: int, to: int):
+  assert x >= 0
+  return ((x + to - 1) // to) * to
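For reference, the new `_round_up` helper pads a row count to the next tiling multiple, e.g. `_round_up(100, to=8) == 104` and `_round_up(96, to=8) == 96`.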
