Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into generic-reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Apr 4, 2023
2 parents 440a39d + 01a9318 commit 3dd30f7
Show file tree
Hide file tree
Showing 60 changed files with 3,109 additions and 3,046 deletions.
4 changes: 2 additions & 2 deletions bin/triton-translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}

llvm::LLVMContext llvmContext;
auto llvmir =
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), false /*isRocm*/);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
Expand Down
5 changes: 4 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
"device compute capability">,
Option<"isROCM", "is-rocm",
"bool", /*default*/"false",
"compile for ROCM-compatible LLVM">,
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ template <typename T> class OperationPass;
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = false);

} // namespace triton

Expand Down
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ class DialectInferLayoutInterface
virtual LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding,
Optional<Location> location) const = 0;
std::optional<Location> location) const = 0;

// Note: this function only verify operand encoding but doesn't infer result
// encoding
virtual LogicalResult
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const = 0;
std::optional<Location> location) const = 0;
};

} // namespace triton
Expand Down
18 changes: 17 additions & 1 deletion include/triton/Dialect/Triton/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,26 @@

namespace mlir {

unsigned getPointeeBitWidth(RankedTensorType tensorTy);
namespace triton {

bool isTensorPointerType(Type type);

unsigned getPointeeBitWidth(Type type);

Type getPointeeType(Type type);

Type getPointerType(Type type);

Type getElementTypeOfTensorPointerType(Type type);

Type getI1SameShape(Type type);

Type getI32SameShape(Type type);

Type getPointerTypeSameShape(Type type);

} // namespace triton

} // namespace mlir

#endif // TRITON_IR_TYPES_H_
6 changes: 4 additions & 2 deletions include/triton/Target/LLVMIR/LLVMIRTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ void addExternalLibs(mlir::ModuleOp &module,
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability);
mlir::ModuleOp module, int computeCapability,
bool isROCM);

// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
bool isROCM);

} // namespace triton
} // namespace mlir
Expand Down
6 changes: 3 additions & 3 deletions include/triton/Tools/Sys/GetPlatform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
#include <memory>
#include <string>

inline bool _isROCM = false;
inline void setROCM() { _isROCM = true; }
inline bool isROCM() { return _isROCM; }
// inline bool _isROCM = false;
// inline void setROCM() { _isROCM = true; }
// inline bool isROCM() { return _isROCM; }

#endif
13 changes: 11 additions & 2 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,16 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
// rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
// lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) *
// gcd(d_lhs, d_rhs)
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
auto elemSize = 1;
if constexpr (std::is_same_v<OpTy, triton::AddPtrOp>) {
// %ptr = addptr %lhs, %rhs
// is equivalent to
// %0 = mul %lhs, %elemSize
// %ptr = add %0, %rhs
elemSize = std::max<unsigned int>(
1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8);
}
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim) * elemSize);
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
Expand Down Expand Up @@ -910,7 +919,7 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
auto maxContig = axisInfo.getContiguity(order[0]);
auto elemNumBits = getPointeeBitWidth(tensorTy);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
unsigned alignment = std::min(maxMultiple, maxContig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,25 +106,29 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
Value kMatArr = kOrder == 1 ? s1 : s0;
Value nkMatArr = kOrder == 1 ? s0 : s1;

// matrix coordinate inside a CTA, the matrix layout is [2x2wpt] for A and
// [2wptx2] for B. e.g. Setting wpt=3, The data layout for A(kOrder=1) is
// |0 0 1 1 2 2| -> 0,1,2 are the warpids
// |0 0 1 1 2 2|
//
// for B(kOrder=0) is
// |0 0| -> 0,1,2 are the warpids
// |1 1|
// |2 2|
// Matrix coordinates inside a CTA,
// the matrix layout is [2wpt[0], 2] for A and [2, 2wpt[1]] for B.
// e.g., Setting wpt=4, the data layout for A(kOrder=1) is
// |0 0| -> 0,1,2,3 are the warpids
// |0 0|
// |1 1|
// |1 1|
// |2 2|
// |2 2|
// |3 3|
// |3 3|
//
// for B(kOrder=0) is
// |0 1 2 3 0 1 2 3| -> 0,1,2,3 are the warpids
// |0 1 2 3 0 1 2 3|
// Note, for each warp, it handles a 2x2 matrices, that is the coordinate
// address (s0,s1) annotates.

Value matOff[2];
matOff[kOrder ^ 1] =
add(mul(warpId, i32_val(warpOffStride)), // warp offset
mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp
add(mul(warpId, i32_val(warpOffStride)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(matArrStride))); // matrix offset inside a warp (kOrder=1)
matOff[kOrder] = kMatArr;

// Physical offset (before swizzling)
Expand All @@ -138,7 +142,13 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,

SmallVector<Value> offs(numPtrs);
Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
// To prevent out-of-bound access of B when wpt * 16 > tile_size.
// In such a case, we need to wrap around the offset of B.
// |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |
// |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) |
// ~~~~~~~ out-of-bound access
Value sOff = urem(add(sOffInMat, mul(sMatOff, i32_val(sMatShape))),
i32_val(tileShape[order[1]]));
for (int i = 0; i < numPtrs; ++i) {
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
cMatOffI = xor_(cMatOffI, phase);
Expand Down Expand Up @@ -631,12 +641,6 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
tensorTy.getShape().end());

// TODO[Superjomn]: transB cannot be accessed in ConvertLayoutOp.
bool transB = false;
if (transB) {
std::swap(shape[0], shape[1]);
}

int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;

Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct LoadStoreConversionBase {
if (!tensorTy)
return 1;
auto contiguity = getContiguity(ptr);
auto pointeeBitWidth = getPointeeBitWidth(tensorTy);
auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy);
// The maximum vector size is 128 bits on NVIDIA GPUs.
return std::min<unsigned>(128 / pointeeBitWidth, contiguity);
}
Expand Down
25 changes: 13 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ namespace {

class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM)
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addLegalDialect<LLVM::LLVMDialect>();
if (isROCM()) {
if (isROCM) {
addLegalDialect<ROCDL::ROCDLDialect>();
} else {
addLegalDialect<NVVM::NVVMDialect>();
Expand Down Expand Up @@ -135,10 +135,10 @@ struct FuncOpConversion : public FuncOpConversionBase {

class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
explicit TritonLLVMConversionTarget(MLIRContext &ctx, bool isROCM)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
if (isROCM()) {
if (isROCM) {
addLegalDialect<ROCDL::ROCDLDialect>();
} else {
addLegalDialect<NVVM::NVVMDialect>();
Expand All @@ -154,16 +154,16 @@ class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {

public:
explicit ConvertTritonGPUToLLVM(int computeCapability)
: computeCapability(computeCapability) {}
explicit ConvertTritonGPUToLLVM(int computeCapability, bool isROCM)
: computeCapability(computeCapability), isROCM(isROCM) {}

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMConversionTarget target(*context);
TritonLLVMConversionTarget target(*context, isROCM);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

/* preprocess */
Expand All @@ -181,7 +181,7 @@ class ConvertTritonGPUToLLVM
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
/*benefit=*/1);
Expand Down Expand Up @@ -225,7 +225,7 @@ class ConvertTritonGPUToLLVM
populatePatterns2(populateViewOpToLLVMPatterns);

// Native lowering patterns
if (isROCM()) {
if (isROCM) {
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
mlir::gpu::amd::HIP);
} else {
Expand All @@ -237,7 +237,7 @@ class ConvertTritonGPUToLLVM
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

if (isROCM()) {
if (isROCM) {
TritonGCNConversionTarget gcnTarget(*context);
RewritePatternSet gcnPatterns(context);
populateElementwiseOpToPTXPatterns(typeConverter, gcnPatterns,
Expand Down Expand Up @@ -272,6 +272,7 @@ class ConvertTritonGPUToLLVM
indexCache;

int computeCapability{};
bool isROCM{};

void initSharedMemory(size_t size,
TritonGPUToLLVMTypeConverter &typeConverter) {
Expand Down Expand Up @@ -464,8 +465,8 @@ namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability);
createConvertTritonGPUToLLVMPass(int computeCapability, bool isROCM) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability, isROCM);
}

} // namespace triton
Expand Down
10 changes: 5 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
addConversion([&](triton::PointerType type) -> std::optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
addConversion([&](RankedTensorType type) -> std::optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](mlir::Float8E4M3FNType type) -> llvm::Optional<Type> {
addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2Type type) -> llvm::Optional<Type> {
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
addConversion([&](BFloat16Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
Expand Down
Loading

0 comments on commit 3dd30f7

Please sign in to comment.