Skip to content

Commit

Permalink
[DEPENDENCIES] Update LLVM to 17.0.0 (c5dede880d17) and port changes. (
Browse files Browse the repository at this point in the history
…triton-lang#1668)

This depends on a [pending LLVM
release](ptillet/triton-llvm-releases#10).

* Implement setCalleeFromCallable in CallOp.
* Cast type to ShapedType for various getters.
* Improve TritonDialect::materializeConstant due to breaking change in
constructor of arith::ConstantOp.
* Add OpaqueProperties argument in inferReturnTypes.

Co-authored-by: Philippe Tillet <[email protected]>
  • Loading branch information
ingomueller-net and ptillet authored May 16, 2023
1 parent fc4fe08 commit 4bb941e
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 19 deletions.
5 changes: 5 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,11 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
}];

let assemblyFormat = [{
Expand Down
14 changes: 8 additions & 6 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,21 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
auto retShapedType = retType.cast<ShapedType>();
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
if (dyn_cast<RankedTensorType>(retType)) {
if (dyn_cast<RankedTensorType>(retShapedType)) {
assert(value);
if (value.getElementType().isInteger(1) && value.isSplat())
// Workaround until https://reviews.llvm.org/D133743 is included.
value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
value =
DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
else
// This is a hack. We just want to add encoding
value = value.reshape(retType);
value = value.reshape(retShapedType);
}
addNamedAttrs(
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
adaptor.getAttributes());
addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retShapedType, value),
adaptor.getAttributes());
return success();
}
};
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,5 @@ void TritonDialect::initialize() {
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
8 changes: 4 additions & 4 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
//-- TransOp --
mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the input
auto argTy = operands[0].getType().cast<RankedTensorType>();
Expand All @@ -376,7 +376,7 @@ mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
//-- DotOp --
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the accumulator
auto accTy = operands[2].getType().cast<RankedTensorType>();
Expand Down Expand Up @@ -444,7 +444,7 @@ void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,

mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
for (auto arg : operands) {
auto argTy = arg.getType().cast<RankedTensorType>();
Expand Down Expand Up @@ -551,7 +551,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
auto arg = operands[0];
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ bool isBroadcastConstantCombinable(Attribute value) {
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
Value bcast_res) {

Type resType = bcast_res.getType();
auto resType = bcast_res.getType().cast<ShapedType>();
DenseElementsAttr res;
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
res =
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ struct RewritedInfo {
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);

// Set zero padding value
Attribute attr =
TypedAttr attr =
elementType.isIntOrIndex()
? builder.getIntegerAttr(elementType, 0).cast<Attribute>()
: builder.getFloatAttr(elementType, 0).cast<Attribute>();
? builder.getIntegerAttr(elementType, 0).cast<TypedAttr>()
: builder.getFloatAttr(elementType, 0).cast<TypedAttr>();

// Float NaN padding case
if (padding.value() == triton::PaddingOption::PAD_NAN) {
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,8 +1228,8 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
auto ty = op->getResultTypes().front().cast<ShapedType>();
auto newRet = SplatElementsAttr::get(ty, ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
SmallVector<Type, 1> newTypes;
auto success = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
newOp->getAttrDictionary(), newOp->getPropertiesStorage(),
newOp->getRegions(), newTypes);
if (succeeded(success)) {
for (size_t i = 0; i < newTypes.size(); i++)
newOp->getResult(i).setType(newTypes[i]);
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_llvm_package_info():
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
version = "llvm-17.0.0-f733b4fb9b8b"
version = "llvm-17.0.0-c5dede880d17"
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

Expand Down

0 comments on commit 4bb941e

Please sign in to comment.