diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 6f1ee9cbe5f1..123ee2d56760 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -28,8 +28,6 @@ limitations under the License. #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" @@ -79,7 +77,7 @@ FailureOr> RectangularVregBounds::getVectorMask( DenseBoolArrayAttr RectangularVregBounds::getSublaneMask( MLIRContext* mlir_ctx, const std::array target_shape) const { - llvm::SmallVector sublane_mask(target_shape[0], false); + SmallVector sublane_mask(target_shape[0], false); for (int64_t i = starts_[0]; i < ends_[0]; ++i) { sublane_mask[i] = true; } @@ -178,7 +176,7 @@ class SingleRowVRegBounds : public VRegDataBounds { const int64_t end_sublane = llvm::divideCeil( llvm::divideCeil(stop_offset_, layout_.packing()), target_shape[1]); - llvm::SmallVector sublane_mask(target_shape[0], false); + SmallVector sublane_mask(target_shape[0], false); for (int64_t i = start_sublane; i < end_sublane; ++i) { sublane_mask[i] = true; } @@ -382,7 +380,7 @@ class TiledRectangularVregBounds : public VRegDataBounds { DenseBoolArrayAttr getSublaneMask( MLIRContext* mlir_ctx, const std::array target_shape) const override { - llvm::SmallVector mask(target_shape[0], false); + SmallVector mask(target_shape[0], false); const int64_t start = start_offsets_[0] / layout_.packing(); const int64_t end = llvm::divideCeil(end_offsets_[0], layout_.packing()); const int64_t sublanes_per_tile = layout_.sublanesPerTile(target_shape); @@ -403,8 +401,7 @@ class TiledRectangularVregBounds : public VRegDataBounds { std::array end_offsets_; }; -mlir::ParseResult parseOffset(llvm::StringRef* data, - std::optional* result) { +mlir::ParseResult parseOffset(StringRef* data, std::optional* result) { int64_t int_result; if (data->consume_front("*")) { *result = std::nullopt; @@ -441,21 +438,21 @@ bool VectorLayout::hasNativeTiling( return tiling_ == nativeTiling(bitwidth_, target_shape); } -llvm::SmallVector VectorLayout::implicitShape( +SmallVector VectorLayout::implicitShape( ArrayRef shape) const { CHECK(!shape.empty()); switch (implicit_dim_) { case ImplicitDim::kNone: - return llvm::SmallVector(shape); + return SmallVector(shape); case ImplicitDim::kMinor: { - llvm::SmallVector implicit_shape; + SmallVector implicit_shape; implicit_shape.reserve(shape.size() + 1); implicit_shape.append(shape.begin(), shape.end()); implicit_shape.push_back(1); return implicit_shape; } case ImplicitDim::kSecondMinor: { - llvm::SmallVector implicit_shape; + SmallVector implicit_shape; implicit_shape.reserve(shape.size() + 1); implicit_shape.append(shape.begin(), std::prev(shape.end())); implicit_shape.push_back(1); @@ -465,11 +462,11 @@ llvm::SmallVector VectorLayout::implicitShape( } } -llvm::SmallVector VectorLayout::tileArrayImplicitShape( +SmallVector VectorLayout::tileArrayImplicitShape( const ArrayRef shape, const std::array target_shape) const { const std::array vreg_slice = vregSlice(target_shape); - llvm::SmallVector tiles_shape = implicitShape(shape); + SmallVector tiles_shape = implicitShape(shape); tiles_shape[tiles_shape.size() - 2] = llvm::divideCeil( offsets_[0].value_or(0) + tiles_shape[tiles_shape.size() - 2], vreg_slice[0]); @@ -479,10 +476,10 @@ llvm::SmallVector VectorLayout::tileArrayImplicitShape( return tiles_shape; } -llvm::SmallVector VectorLayout::tileArrayShape( +SmallVector VectorLayout::tileArrayShape( const ArrayRef shape, const std::array target_shape) const { - llvm::SmallVector tiles_shape = + SmallVector tiles_shape = tileArrayImplicitShape(shape, target_shape); // Remove the implicit dimension --- it's always of size 1. switch (implicit_dim_) { @@ -521,11 +518,11 @@ std::unique_ptr VectorLayout::tileDataBounds( break; } - const llvm::SmallVector tiles_implicit_shape = + const SmallVector tiles_implicit_shape = tileArrayImplicitShape(full_shape, target_shape); const int64_t ns = tiles_implicit_shape[tiles_implicit_shape.size() - 2]; const int64_t nl = tiles_implicit_shape[tiles_implicit_shape.size() - 1]; - const llvm::SmallVector implicit_shape = implicitShape(full_shape); + const SmallVector implicit_shape = implicitShape(full_shape); const int64_t is = implicit_shape[implicit_shape.size() - 2]; const int64_t il = implicit_shape[implicit_shape.size() - 1]; @@ -718,8 +715,8 @@ std::optional VectorLayout::join(const VectorLayout& l, return VectorLayout(l.bitwidth_, offsets, l.tiling_, l.implicit_dim_); } -std::optional VectorLayout::parse(llvm::StringRef* data) { - llvm::StringRef local(*data); +std::optional VectorLayout::parse(StringRef* data) { + StringRef local(*data); int8_t bitwidth; LayoutOffsets offsets; std::array tiling; @@ -797,7 +794,7 @@ std::optional parseLayout(mlir::AsmParser& parser) { if (layout_str == "none") { return kNoLayout; } - llvm::StringRef ref(layout_str); + StringRef ref(layout_str); if (auto layout = VectorLayout::parse(&ref); ref.empty()) { return *layout; } diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 21a4c28a6b23..e23310f968e9 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -26,8 +26,6 @@ limitations under the License. #include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -270,10 +268,10 @@ class VectorLayout { return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]}; } - llvm::SmallVector implicitShape(ArrayRef shape) const; + SmallVector implicitShape(ArrayRef shape) const; private: - llvm::SmallVector tileArrayImplicitShape( + SmallVector tileArrayImplicitShape( ArrayRef shape, std::array target_shape) const; public: @@ -288,7 +286,7 @@ class VectorLayout { // // Args: // shape: The shape of the full vector this layout applies to. - llvm::SmallVector tileArrayShape( + SmallVector tileArrayShape( ArrayRef shape, std::array target_shape) const; // Returns the bounds of the given tile that hold useful data. @@ -383,7 +381,7 @@ class VectorLayout { const VectorLayout &r, ArrayRef shape); - static std::optional parse(llvm::StringRef *data); + static std::optional parse(StringRef *data); // Check conditions that depend on the target shape. Invariants that are // independent of it are checked in the constructor. diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 7191ea0b28f2..abc7cc595cd6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/hash/hash.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" @@ -103,7 +104,7 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) { return {}; } - llvm::SmallVector tiles; + SmallVector tiles; int64_t size; while (succeeded(parser.parseOptionalLParen())) { xla::Tile &tile = tiles.emplace_back(); @@ -121,7 +122,7 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) { tile.add_dimensions(size); } } - llvm::SmallVector tile_strides; + SmallVector tile_strides; int64_t stride; if (failed(parser.parseComma())) { return {}; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 8db6a4bef75c..be22d809c3ee 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -32,7 +33,7 @@ namespace tpu { LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op, PatternRewriter &rewriter) { RollVectorsOp roll_op = - llvm::dyn_cast_or_null(op.getOperand().getDefiningOp()); + dyn_cast_or_null(op.getOperand().getDefiningOp()); if (!roll_op) { return failure(); } @@ -150,8 +151,8 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op, int target_index = target_shape.size() - 1; auto old_layout = dyn_cast(layout_ty.getLayout()); auto target_strides = old_layout.getTileStrides(); - llvm::SmallVector tile_strides(target_strides.begin(), - target_strides.end()); + SmallVector tile_strides(target_strides.begin(), + target_strides.end()); // We want to remove all strides that correspond to squeezed dimensions and // update the corresponding output layout. while (source_index >= 0 || target_index >= 0) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 5aa4d07b088d..7ba68c2a6d21 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -12,9 +12,7 @@ #include #include -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" @@ -3020,7 +3018,7 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, // replicated result ) { // First, insert the new singleton lane dimension. - llvm::SmallVector s(src_shape); + SmallVector s(src_shape); s.push_back(1); xla::Array dst_vregs_local( layout_out.tileArrayShape(s, ctx.target_shape)); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 32ac7e8ab9a0..18dc1fe9ea0a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -25,9 +25,7 @@ limitations under the License. #include #include -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -144,8 +142,7 @@ class VectorLayoutInferer { has_vector_io |= r.getType().isa(); } if (!has_vector_io && any_op.getRegions().empty()) { - llvm::SmallVector in_layout(any_op.getNumOperands(), - kNoLayout); + SmallVector in_layout(any_op.getNumOperands(), kNoLayout); if (any_op.getNumResults() == 0) { setInLayout(&any_op, in_layout); } else if (any_op.getNumResults() == 1) { @@ -412,7 +409,7 @@ class VectorLayoutInferer { auto then_yield = op.thenBlock()->getTerminator(); TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(), "scf if results and then branch yield operands do not match"); - llvm::SmallVector result_layout; + SmallVector result_layout; result_layout.reserve(then_yield->getNumOperands()); for (const auto &operand : then_yield->getOperands()) { if (operand.getType().isSignlessIntOrIndexOrFloat()) { @@ -482,7 +479,7 @@ class VectorLayoutInferer { op->getNumOperands() == 3 + op.getNumResults(), "expected num_operands is equal to 3 + num_results in scf.for"); - llvm::SmallVector in_layouts; + SmallVector in_layouts; in_layouts.reserve(op->getNumOperands()); in_layouts.push_back(kNoLayout); // Lower bound. in_layouts.push_back(kNoLayout); // Upper bound. diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 42e5344eb762..00996d8862aa 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -73,7 +73,7 @@ struct LinalgVectorizationPass // We do not want to apply the vector patterns above to the ops that are // unrelated to the original linalg op. - llvm::SmallVector linalgOps; + SmallVector linalgOps; func.walk([&](linalg::LinalgOp op) { linalgOps.push_back(op); }); if (failed(applyOpPatternsAndFold(linalgOps, std::move(patterns)))) { return signalPassFailure(); diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index d48280e222a2..9ae5f9a59619 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -6,7 +6,6 @@ #include #include -#include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -63,12 +62,12 @@ FailureOr getTypeBitwidth(Type ty) { } template -llvm::ArrayRef> toArrayRef(absl::Span span) { - return llvm::ArrayRef>(span.data(), span.size()); +ArrayRef> toArrayRef(absl::Span span) { + return ArrayRef>(span.data(), span.size()); } template -llvm::ArrayRef> toArrayRef(std::array array) { - return llvm::ArrayRef>(array.data(), array.size()); +ArrayRef> toArrayRef(std::array array) { + return ArrayRef>(array.data(), array.size()); } inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder,