diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index b834d981b6c4..8e4cac2983c8 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -271,12 +272,15 @@ FailureOr getZeroIntOrFloatAttr(Type ty) {
   return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
 }
 
-FailureOr<int64_t> getIntConst(Value v) {
+FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
   if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
     if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
       return integer_attr.getValue().getSExtValue();
     }
   }
+  if (silent) {
+    return failure();
+  }
   return emitError(v.getLoc(), "Expected an integer constant");
 }
 
@@ -289,6 +293,31 @@ FailureOr<SmallVector<int64_t>> getIntConstsFromOperandRange(
   return res;
 }
 
+SmallVector<SmallVector<Value>> getDimIndices(OperandRange indices,
+                                              ArrayRef<int64_t> shape,
+                                              ImplicitLocOpBuilder& builder) {
+  CHECK_EQ(indices.size(), shape.size());
+  SmallVector<SmallVector<Value>> result(indices.size());
+  for (int dim = 0; dim < indices.size(); ++dim) {
+    auto& dim_idx = result[dim];
+    dim_idx.reserve(shape[dim]);
+    if (auto idx_const = getIntConst(indices[dim], /*silent=*/true);
+        succeeded(idx_const)) {
+      int64_t cst = idx_const.value();
+      for (int64_t off = 0; off < shape[dim]; ++off) {
+        dim_idx.push_back(IdxConst(cst + off, builder, builder.getLoc()));
+      }
+    } else {
+      for (int64_t off = 0; off < shape[dim]; ++off) {
+        dim_idx.push_back(builder.create<arith::AddIOp>(
+            indices[dim], IdxConst(off, builder, builder.getLoc())));
+      }
+    }
+  }
+  return result;
+}
+
+
 // Returns the first-level tiling of a (packed and tiled) memref value.
 FailureOr<std::array<int64_t, 2>> getMemRefTiling(
     TypedValue<MemRefType> value, const std::array<int64_t, 2> target_shape) {
@@ -1555,11 +1584,16 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
   FAILUREOR_ASSIGN_OR_RETURN(
       VectorType target_ty,
       getNativeVregType(vty.getElementType(), ctx.target_shape));
-  if (layout_out.implicit_dim() == VectorLayout::ImplicitDim::kMinor) {
-    return op.emitOpError("Not implemented");
+  if (vty.getRank() == 0) {
+    return op.emitOpError("Not implemented: scalar loads from vmem");
+  }
+  const bool is_1d = vty.getRank() == 1;
+  VectorLayout::ImplicitDim expected_dim =
+      is_1d ? VectorLayout::ImplicitDim::kSecondMinor
+            : VectorLayout::ImplicitDim::kNone;
+  if (layout_out.implicit_dim() != expected_dim) {
+    return op.emitOpError("Not implemented: unsupported layout");
   }
-  const bool is_1d =
-      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone;
   using Tiling = std::array<int64_t, 2>;  // To avoid comma in macro
   FAILUREOR_ASSIGN_OR_RETURN(
       Tiling memref_tiling,
@@ -1574,16 +1608,8 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
   }
   // TODO(apaszke): Check that loads are from vmem!
   FAILUREOR_ASSIGN_OR_RETURN(
-      const SmallVector<int64_t> indices,
-      getIntConstsFromOperandRange(load_op.getIndices()));
-  if (llvm::any_of(
-          llvm::zip_equal(indices, vty.getShape(), memref_ty.getShape()),
-          [](auto tup) {
-            auto [idx, n, extent] = tup;
-            return idx + n > extent;
-          })) {
-    return op.emitOpError("Reading out of bounds");
-  }
+      const SmallVector<int64_t> tile_indices,
+      getIntConstsFromOperandRange(load_op.getIndices().take_back(2 - is_1d)));
   const SmallVector<int64_t> implicit_shape =
       layout_out.implicitShape(vty.getShape());
   const int64_t ss = implicit_shape[implicit_shape.size() - 2];
@@ -1624,38 +1650,42 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
       layout_out.tileArrayShape(vty.getShape(), ctx.target_shape));
   const std::array<int64_t, 2> vreg_slice =
       layout_out.vregSlice(ctx.target_shape);
-  const int64_t num_dims = indices.size();
+  const int64_t num_dims = vty.getRank();
   const int64_t num_batch_dims = num_dims - (is_1d ? 1 : 2);
+  SmallVector<SmallVector<Value>> base_batch =
+      getDimIndices(load_op.getIndices().take_front(num_batch_dims),
+                    vty.getShape().take_front(num_batch_dims),
+                    builder);
   const absl::Status status =
       tiles.EachStatus([&](absl::Span<const int64_t> tile_idxs, Value * /*v*/) {
         CHECK_EQ(num_dims, tile_idxs.size());
-        SmallVector<int64_t> idxs(tile_idxs.size());
+        SmallVector<Value> idxs(tile_idxs.size());
         for (int64_t i = 0; i < num_batch_dims; ++i) {
-          idxs[i] = tile_idxs[i] + indices[i];
+          idxs[i] = base_batch[i][tile_idxs[i]];
         }
-        const int64_t base_l = indices[num_dims - 1];
+        const int64_t base_l = tile_indices.back();
         const int64_t lidx = tile_idxs[num_dims - 1];
-        idxs[num_dims - 1] = base_l + lidx * vreg_slice[1] - *offsets[1];
+        idxs[num_dims - 1] =
+            IdxConst(base_l + lidx * vreg_slice[1] - *offsets[1], builder,
+                     load_op->getLoc());
         if (!is_1d) {
-          const int64_t base_s = indices[num_dims - 2];
+          CHECK_EQ(tile_indices.size(), 2);
+          const int64_t base_s = tile_indices.front();
           const int64_t sidx = tile_idxs[num_dims - 2];
           idxs[num_dims - 2] =
-              base_s + sidx * vreg_slice[0] - offsets[0].value_or(0);
+              IdxConst(base_s + sidx * vreg_slice[0] - offsets[0].value_or(0),
+                       builder, load_op->getLoc());
         }
         CHECK(tile_idxs[num_dims - 1] + ctx.target_shape[1] <=
               memref_ty.getShape()[num_dims - 1]);
         std::unique_ptr<VRegDataBounds> bounds = layout_out.tileDataBounds(
             mlir_ctx, vty.getShape(), toArrayRef(tile_idxs), ctx.target_shape,
             /*allow_replicated =*/{true, false});
-        SmallVector<Value> idxs_vs(idxs.size());
-        for (int64_t i = 0; i < idxs.size(); ++i) {
-          idxs_vs[i] = IdxConst(idxs[i], builder, load_op->getLoc());
-        }
         Operation *tile;
         if (bounds->maskVariesAlong(Direction::kSublanes, ctx.target_shape)) {
           CHECK(offsets[0].has_value());
           tile = builder.create<tpu::LoadOp>(
-              target_ty, load_op.getBase(), idxs_vs,
+              target_ty, load_op.getBase(), idxs,
               bounds->getSublaneMask(mlir_ctx, ctx.target_shape),
               builder.getI32IntegerAttr(sublane_stride));
         } else {
@@ -1666,14 +1696,14 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
             return absl::UnimplementedError("");
           }
           tile = builder.create<vector::TransferReadOp>(
-              target_ty, load_op.getBase(), idxs_vs, load_map, padding,
+              target_ty, load_op.getBase(), idxs, load_map, padding,
              nullptr, nullptr);
         } else {
           const SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
           const auto sublane_mask_attr =
               DenseBoolArrayAttr::get(mlir_ctx, sublane_mask);
           tile = builder.create<tpu::LoadOp>(
-              target_ty, load_op.getBase(), idxs_vs, sublane_mask_attr,
+              target_ty, load_op.getBase(), idxs, sublane_mask_attr,
               builder.getI32IntegerAttr(sublane_stride));
         }
       }
@@ -2507,11 +2537,16 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
   vector::StoreOp store_op = cast<vector::StoreOp>(op);
   const VectorType ty = store_op.getValueToStore().getType();
   const VectorLayout &to_store_layout = *layouts_in.front();
-  if (to_store_layout.implicit_dim() == VectorLayout::ImplicitDim::kMinor) {
-    return op.emitOpError("Not implemented");
+  if (!ty.getRank()) {
+    return op.emitOpError("Not implemented: scalar stores to vmem");
+  }
+  const bool is_1d = ty.getRank() == 1;
+  VectorLayout::ImplicitDim expected_dim =
+      is_1d ? VectorLayout::ImplicitDim::kSecondMinor
+            : VectorLayout::ImplicitDim::kNone;
+  if (to_store_layout.implicit_dim() != expected_dim) {
+    return op.emitOpError("Not implemented: unsupported layout");
   }
-  const bool is_1d =
-      to_store_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone;
   using Tiling = std::array<int64_t, 2>;
   FAILUREOR_ASSIGN_OR_RETURN(
       const Tiling memref_tiling,
@@ -2525,19 +2560,23 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
     }
   }
   FAILUREOR_ASSIGN_OR_RETURN(
-      const SmallVector<int64_t> base_indices,
-      getIntConstsFromOperandRange(store_op.getIndices()));
+      const SmallVector<int64_t> tile_indices,
+      getIntConstsFromOperandRange(store_op.getIndices().take_back(2 - is_1d)));
   FAILUREOR_ASSIGN_OR_RETURN(
       xla::Array<Value> tiles,
       disassemble(ctx, builder, to_store_layout, store_op.getValueToStore()));
-  const int64_t ndims = base_indices.size();
+  const int64_t ndims = ty.getRank();
   const int64_t nbatchdims = is_1d ? ndims - 1 : ndims - 2;
-  const int64_t base_s = is_1d ? 0 : base_indices[ndims - 2];
-  const int64_t base_l = base_indices[ndims - 1];
+  const int64_t base_s = is_1d ? 0 : tile_indices.front();
+  const int64_t base_l = tile_indices.back();
   if (is_1d) {
     tiles.Reshape(
         to_store_layout.implicitShape(toArrayRef(tiles.dimensions())));
   }
+  SmallVector<SmallVector<Value>> base_batch =
+      getDimIndices(store_op.getIndices().take_front(nbatchdims),
+                    ty.getShape().take_front(nbatchdims),
+                    builder);
   const LayoutOffset sublane_offset = to_store_layout.offsets()[0];
   const LayoutOffset lane_offset = to_store_layout.offsets()[1];
   if (!sublane_offset.has_value() || !lane_offset.has_value()) {
@@ -2567,7 +2606,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
   auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
                                  store_op->getLoc());
   for (int64_t i = 0; i < nbatchdims; ++i) {
-    indices[i] = boundIdxConst(idx[i] + base_indices[i]);
+    indices[i] = base_batch[i][idx[i]];
   }
   if (!is_1d) {
     *(indices.end() - 2) =
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index cb39a81d820d..ccf10c09f0c1 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -849,17 +849,24 @@ class VectorLayoutInferer {
     SmallVector<Layout> in_layout(op->getNumOperands(), kNoLayout);
     CHECK_EQ(op->getNumOperands(), op.getIndices().size() + 1);
-    SmallVector<int64_t> indices;
-    indices.reserve(rank);
-    for (Value v : op.getIndices()) {
-      auto cst_op = v.getDefiningOp<arith::ConstantOp>();
-      TPU_CHECK_OP(cst_op, "only constant indices are supported");
-      indices.push_back(cast<IntegerAttr>(cst_op.getValue()).getInt());
-    }
-    for (int64_t i = 0; i < rank; ++i) {
-      TPU_CHECK_OP(indices[i] + res_ty.getDimSize(i) <= src_ty.getDimSize(i),
-                   "Loading elements out of bounds");
+    SmallVector<int64_t> tile_indices;
+    for (int i = rank - 1; i >= 0; --i) {
+      auto cst_op = op.getIndices()[i].getDefiningOp<arith::ConstantOp>();
+      if (cst_op) {
+        int64_t idx = cast<IntegerAttr>(cst_op.getValue()).getInt();
+        TPU_CHECK_OP(idx + res_ty.getDimSize(i) <= src_ty.getDimSize(i),
"Loading elements out of bounds"); + if (tile_indices.size() < 2) { + tile_indices.push_back(idx); + } + } else { + TPU_CHECK_OP( + tile_indices.size() == 2, + "Dynamic indices are not supported in the last two dimensions"); + } } + // We pushed the indices in reverse. + std::reverse(tile_indices.begin(), tile_indices.end()); if (rank == 0) { op.emitOpError("rank 0 vectors unsupported"); @@ -870,7 +877,8 @@ class VectorLayoutInferer { auto tile = tiling.front(); TPU_CHECK_OP(tile % target_shape_[1] == 0, "Unsupported tiling for 1D load"); - int64_t idx = indices.front(); + CHECK_EQ(tile_indices.size(), 1); + int64_t idx = tile_indices.front(); int64_t offset = idx % kVmemAlignment32; // TODO(apaszke): We could generate replicated loads for short values. setLayout(op, in_layout, @@ -878,8 +886,8 @@ class VectorLayoutInferer { ImplicitDim::kSecondMinor)); } else { // rank >= 2 TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads"); + CHECK_EQ(tile_indices.size(), 2); std::array, 2> offsets; - const auto tile_indices = ArrayRef(indices).take_back(2); const auto tile_src_shape = src_ty.getShape().take_back(2); const auto tile_res_shape = res_ty.getShape().take_back(2); const int64_t num_sublanes = tile_res_shape[0]; @@ -1140,17 +1148,24 @@ class VectorLayoutInferer { } auto tiling = *maybe_tiling; - SmallVector indices; - indices.reserve(rank); - for (Value v : op.getIndices()) { - auto cst_op = v.getDefiningOp(); - TPU_CHECK_OP(cst_op, "only constant indices are supported"); - indices.push_back(cast(cst_op.getValue()).getInt()); - } - for (int64_t i = 0; i < rank; ++i) { - TPU_CHECK_OP(indices[i] + store_ty.getDimSize(i) <= ref_ty.getDimSize(i), - "storing elements out of bounds"); + SmallVector tile_indices; + for (int i = rank - 1; i >= 0; --i) { + auto cst_op = op.getIndices()[i].getDefiningOp(); + if (cst_op) { + int64_t idx = cast(cst_op.getValue()).getInt(); + TPU_CHECK_OP(idx + store_ty.getDimSize(i) <= ref_ty.getDimSize(i), + "Loading elements out of bounds"); + if (tile_indices.size() < 2) { + tile_indices.push_back(idx); + } + } else { + TPU_CHECK_OP( + tile_indices.size() == 2, + "Dynamic indices are not supported in the last two dimensions"); + } } + // We pushed the indices in reverse. 
+    std::reverse(tile_indices.begin(), tile_indices.end());
 
     Layout store_layout;
     if (rank == 0) {
@@ -1162,14 +1177,15 @@ class VectorLayoutInferer {
       auto tile = tiling.front();
       TPU_CHECK_OP(tile % target_shape_[1] == 0,
                    "Unsupported 1D tiling for 1D store");
-      int64_t idx = indices.front();
+      CHECK_EQ(tile_indices.size(), 1);
+      int64_t idx = tile_indices.front();
       int64_t offset = idx % kVmemAlignment32;
       store_layout = VectorLayout(bitwidth, {0, offset}, {1, tile},
                                   ImplicitDim::kSecondMinor);
     } else {  // rank >= 2  // NOLINT(readability-else-after-return)
       TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store");
+      CHECK_EQ(tile_indices.size(), 2);
       std::array<std::optional<int64_t>, 2> offsets;
-      const auto tile_indices = ArrayRef(indices).take_back(2);
       const auto tile_ref_shape = ref_ty.getShape().take_back(2);
       const auto tile_store_shape = store_ty.getShape().take_back(2);
       const int64_t num_sublanes = tile_store_shape[0];
diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py
index 89c14cb87993..8fcc119ec4be 100644
--- a/jaxlib/mosaic/python/apply_vector_layout.py
+++ b/jaxlib/mosaic/python/apply_vector_layout.py
@@ -2509,34 +2509,30 @@ def _vector_load_rule(  # pylint: disable=missing-function-docstring
   memref_ty = ir.MemRefType(op.base.type)
   ty = ir.VectorType(op.result.type)
   target_ty = native_vreg_ty(ty.element_type)
-  if layout_out.implicit_dim == ImplicitDim.MINOR:
+  if len(ty.shape) == 0:
+    raise NotImplementedError
+  is_1d = len(ty.shape) == 1
+  expected_dim = ImplicitDim.SECOND_MINOR if is_1d else None
+  if layout_out.implicit_dim != expected_dim:
     raise NotImplementedError
-  is_1d = layout_out.implicit_dim is not None
   memref_tiling = get_memref_tiling(op.base)
   if layout_out.tiling != memref_tiling:
     # Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
     # TODO(b/295393167): need to support strided load for bitwidth < 32.
-    if layout_out.bitwidth != 32 or layout_out.tiling != (
-        1,
-        TARGET_SHAPE.lanes,
-    ):
+    lanes = TARGET_SHAPE.lanes
+    if layout_out.bitwidth != 32 or layout_out.tiling != (1, lanes):
       raise NotImplementedError
   # TODO(apaszke): Check that loads are from vmem!
-  indices = [get_int_const(v, "vector.load index") for v in op.indices]
-  for i, n, extent in zip(indices, ty.shape, memref_ty.shape):
-    if i + n > extent:
-      raise ValueError("reading out of bounds")
+  tile_indices = [
+      get_int_const(v, "vector.load index") for v in op.indices[-2:]
+  ]
   *_, ss, _ = layout_out.implicit_shape(ty.shape)
   sublane_stride = 1
   # The stride of load should be the number of sublanes in memref tile when
   # loaing a single sublane.
   if (
       layout_out.bitwidth == 32
-      and layout_out.tiling
-      == (
-          1,
-          TARGET_SHAPE.lanes,
-      )
+      and layout_out.tiling == (1, TARGET_SHAPE.lanes)
       and ss == 1
   ):
     sublane_stride = memref_tiling[0]
@@ -2558,10 +2554,13 @@ def _vector_load_rule(  # pylint: disable=missing-function-docstring
   if any((o or 0) > t for o, t in zip(offsets, tiling)):
     raise NotImplementedError
   if is_1d:
-    *base_batch, base_l = indices
+    (base_l,) = tile_indices
     base_s = 0
   else:
-    *base_batch, base_s, base_l = indices
+    base_s, base_l = tile_indices
+  base_batch = get_dim_indices(
+      op.indices[: -len(tile_indices)], ty.shape[: -len(tile_indices)]
+  )
   tiles = np.ndarray(layout_out.tile_array_shape(ty.shape), dtype=object)
   vreg_slice = layout_out.vreg_slice
   for tile_ixs in np.ndindex(tiles.shape):
@@ -2574,17 +2573,17 @@ def _vector_load_rule(  # pylint: disable=missing-function-docstring
           base_s + six * vreg_slice.sublanes - (s or 0),
           base_l + lix * vreg_slice.lanes - l,
       )
-    indices = (*(b + i for b, i in zip(base_batch, batch_ixs)), *tile)
-    assert indices[-1] + TARGET_SHAPE.lanes <= memref_ty.shape[-1]
+    indices = (
+        *(b[i] for b, i in zip(base_batch, batch_ixs)), *map(ix_cst, tile),
+    )
     bounds = layout_out.tile_data_bounds(
         ty.shape, tile_ixs, allow_replicated=TargetTuple(True, False))
-    indices_vs = list(map(ix_cst, indices))
     if bounds.mask_varies_along(SUBLANES):
       assert s is not REPLICATED  # Replicated loads should never go OOB
       tile = tpu.LoadOp(
           target_ty,
           op.base,
-          indices_vs,
+          indices,
           bounds.get_sublane_mask(),
           sublane_stride=sublane_stride,
       )
@@ -2593,7 +2592,7 @@ def _vector_load_rule(  # pylint: disable=missing-function-docstring
       if layout_out.bitwidth != 32:
         raise NotImplementedError
       tile = vector.TransferReadOp(
-          target_ty, op.base, indices_vs, load_map, padding)
+          target_ty, op.base, indices, load_map, padding)
     else:
       assert s is not REPLICATED
       sublane_mask = ir.DenseBoolArrayAttr.get(
@@ -2601,7 +2600,7 @@ def _vector_load_rule(  # pylint: disable=missing-function-docstring
       tile = tpu.LoadOp(
           target_ty,
           op.base,
-          indices_vs,
+          indices,
           sublane_mask,
           sublane_stride=sublane_stride,
       )
@@ -2617,26 +2616,30 @@ def _vector_store_rule(  # pylint: disable=missing-function-docstring
   assert all(ip is None for ip in other_layouts)
   assert layout_out is None
   ty = ir.VectorType(op.valueToStore.type)
-  if to_store_layout.implicit_dim == ImplicitDim.MINOR:
+  if len(ty.shape) == 0:
+    raise NotImplementedError
+  is_1d = len(ty.shape) == 1
+  expected_dim = ImplicitDim.SECOND_MINOR if is_1d else None
+  if to_store_layout.implicit_dim != expected_dim:
     raise NotImplementedError
-  is_1d = to_store_layout.implicit_dim is not None
   memref_tiling = get_memref_tiling(op.base)
   if to_store_layout.tiling != memref_tiling:
     # Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
     # TODO(b/295393167): need to support strided store for bitwidth < 32.
-    if to_store_layout.bitwidth != 32 or to_store_layout.tiling != (
-        1,
-        TARGET_SHAPE.lanes,
-    ):
+    lanes = TARGET_SHAPE.lanes
+    if to_store_layout.bitwidth != 32 or to_store_layout.tiling != (1, lanes):
       raise NotImplementedError
-  base_indices = [get_int_const(v, "vector.store index") for v in op.indices]
+  tile_indices = [get_int_const(v, "vector.store index") for v in op.indices[-2:]]
   tiles = disassemble(to_store_layout, op.valueToStore)
   if is_1d:
-    *base_batch, base_l = base_indices
+    (base_l,) = tile_indices
     base_s = 0
     tiles = tiles.reshape(to_store_layout.implicit_shape(tiles.shape))
   else:
-    *base_batch, base_s, base_l = base_indices
+    base_s, base_l = tile_indices
+  base_batch = get_dim_indices(
+      op.indices[: -len(tile_indices)], ty.shape[: -len(tile_indices)]
+  )
   sublane_offset, lane_offset = to_store_layout.offsets
   check(lane_offset is not REPLICATED and sublane_offset is not REPLICATED,
         "replicated layout disallowed in vector store")
@@ -2654,13 +2657,12 @@ def _vector_store_rule(  # pylint: disable=missing-function-docstring
     bounds = to_store_layout.tile_data_bounds(stored_shape, ixs)
     *batch_ixs, six, lix = ixs
     indices = (
-        *(b + i for b, i in zip(base_batch, batch_ixs)),
-        base_s + six * vreg_slice.sublanes - sublane_offset,
-        base_l + lix * vreg_slice.lanes - lane_offset,
+        *(b[i] for b, i in zip(base_batch, batch_ixs)),
+        ix_cst(base_s + six * vreg_slice.sublanes - sublane_offset),
+        ix_cst(base_l + lix * vreg_slice.lanes - lane_offset),
     )
     if is_1d:
       indices = (*indices[:-2], indices[-1])
-    indices = list(map(ix_cst, indices))
     sublane_mask = bounds.get_sublane_mask()
     masks_subelements = bounds.mask_varies_along(SUBELEMENTS)
     if bounds.mask_varies_along(LANES) or masks_subelements:
@@ -3470,3 +3472,30 @@ def get_memref_tiling(value: ir.Value) -> tuple[int, int]:
 def _round_up(x: int, to: int):
   assert x >= 0
   return ((x + to - 1) // to) * to
+
+
+def get_dim_indices(indices, shape) -> list[list[ValueLike]]:
+  dim_indices = []
+  index = ir.IndexType.get()
+  assert len(indices) == len(shape)
+  for dim_size, idx_val in zip(shape, indices):
+    idx_const = None
+    try:
+      idx_const = get_int_const(idx_val, "")
+    except ValueError:
+      pass
+    if idx_const is not None:
+      dim_indices.append(
+          [
+              arith.ConstantOp(index, ir.IntegerAttr.get(index, idx_const + i))
+              for i in range(dim_size)
+          ]
+      )
+    else:
+      dim_indices.append([
+          arith.AddIOp(
+              idx_val, arith.ConstantOp(index, ir.IntegerAttr.get(index, i))
+          )
+          for i in range(dim_size)
+      ])
+  return dim_indices
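
The index handling introduced above can be summarized outside of MLIR. The following is a simplified, illustrative Python model of what getDimIndices / get_dim_indices compute: for every dimension whose base index is a compile-time constant, the per-tile indices fold into plain integers, while a dynamic base index turns into one add per tile position (materialized as arith.addi in the real pass). The names SymbolicAdd and expand_dim_indices are invented for this sketch and do not appear in the patch.

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class SymbolicAdd:
  """Stands in for an arith.addi of a dynamic index and a constant offset."""
  base: str
  offset: int


Index = Union[int, SymbolicAdd]


def expand_dim_indices(bases, shape) -> list[list[Index]]:
  # One pre-computed index per position along each dimension.
  assert len(bases) == len(shape)
  result = []
  for base, dim_size in zip(bases, shape):
    if isinstance(base, int):  # constant base: fold the addition
      result.append([base + off for off in range(dim_size)])
    else:  # dynamic base: keep a symbolic add per offset
      result.append([SymbolicAdd(base, off) for off in range(dim_size)])
  return result


# A dynamic batch index "%i" and a constant batch index 4 over a (2, 3) grid:
print(expand_dim_indices(["%i", 4], [2, 3]))
# [[SymbolicAdd(base='%i', offset=0), SymbolicAdd(base='%i', offset=1)], [4, 5, 6]]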
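
On the inference side, the new loops accept dynamic indices only in the leading (batch) dimensions: walking dimensions from minor to major, they collect the constant indices of the last two dimensions and reject anything dynamic before both have been seen. A minimal stand-alone sketch of that rule, assuming None marks a dynamic index and using plain exceptions in place of TPU_CHECK_OP:

def collect_tile_indices(indices, value_shape, ref_shape) -> list[int]:
  # indices[i] is an int for a constant index and None for a dynamic one.
  assert len(indices) == len(value_shape) == len(ref_shape)
  tile_indices = []
  for i in reversed(range(len(indices))):
    idx = indices[i]
    if idx is not None:
      # Constant indices are bounds-checked; dynamic ones are not.
      if idx + value_shape[i] > ref_shape[i]:
        raise ValueError(f"accessing elements out of bounds in dim {i}")
      if len(tile_indices) < 2:
        tile_indices.append(idx)
    elif len(tile_indices) != 2:
      raise NotImplementedError(
          "Dynamic indices are not supported in the last two dimensions")
  tile_indices.reverse()  # collected minor-to-major above
  return tile_indices


# A dynamic (None) batch index is fine; the last two dims must be constant:
print(collect_tile_indices([None, 8, 0], (2, 8, 128), (16, 64, 128)))  # [8, 0]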
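
Each vreg's memory indices are then assembled from two sources: batch dimensions simply pick the pre-expanded value for their tile position, while the tiled second-minor and minor dimensions are still computed arithmetically from the constant base, the vreg slice and the layout offset, and only then wrapped in index constants. A rough sketch of that arithmetic; the vreg_slice and offsets arguments mirror layout_out.vreg_slice and layout_out.offsets, but the function itself is illustrative rather than part of the patch:

def vreg_indices(base_batch, base_s, base_l, tile_ixs, vreg_slice, offsets):
  # base_batch: per-batch-dimension index lists, as produced by the expansion
  #             above (one entry per tile position along that dimension).
  # tile_ixs:   position of the vreg in the tile grid, batch dims first.
  *batch_ixs, six, lix = tile_ixs
  sublane_off, lane_off = offsets
  return (
      *(b[i] for b, i in zip(base_batch, batch_ixs)),
      base_s + six * vreg_slice[0] - sublane_off,
      base_l + lix * vreg_slice[1] - lane_off,
  )


# One dynamic batch dimension of size 4, shown here as symbolic strings, and a
# vreg at tile position (3, 0, 1) of a 32-bit layout on an (8, 128) grid:
base_batch = [[f"i + {off}" for off in range(4)]]
print(vreg_indices(base_batch, 0, 0, (3, 0, 1), (8, 128), (0, 0)))
# ('i + 3', 0, 128)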
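
Finally, the take_back(2 - is_1d) / op.indices[-2:] slicing is what lets the same code path serve rank-1 and rank>=2 vectors: a 1D access only contributes a lane base, and the sublane base defaults to 0. A small illustrative helper (not from the patch) showing the split:

def split_indices(indices, is_1d):
  # Mirrors tile_indices = indices.take_back(2 - is_1d) plus the batch prefix.
  tile_indices = indices[-(2 - is_1d):]
  batch_indices = indices[: len(indices) - len(tile_indices)]
  base_l = tile_indices[-1]
  base_s = 0 if is_1d else tile_indices[0]
  return batch_indices, base_s, base_l


print(split_indices([7], is_1d=True))            # ([], 0, 7)
print(split_indices(["%i", 3, 5], is_1d=False))  # (['%i'], 3, 5)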