From 5630edddae66dc8adabea7cf3ef3a412d114f642 Mon Sep 17 00:00:00 2001 From: Djordje Ramic Date: Thu, 6 Jun 2024 01:06:30 +0000 Subject: [PATCH] Working on threadwise unification --- .../mlir/Dialect/Rock/IR/AccelEmitter.h | 56 ++++- .../mlir/Dialect/Rock/IR/FmaInsnGroup.h | 39 ++++ mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 28 ++- mlir/lib/Dialect/Rock/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Rock/IR/FmaInsnGroup.cpp | 36 ++++ mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 8 + .../Transforms/BlockwiseGemmToThreadwise.cpp | 11 +- .../Rock/Transforms/GemmToGridwise.cpp | 4 +- .../Transforms/GridwiseGemmToBlockwise.cpp | 7 +- .../Transforms/ThreadwiseGemmLowering.cpp | 168 ++++++++++++++- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 194 +++++++++++++----- 11 files changed, 479 insertions(+), 73 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Rock/IR/FmaInsnGroup.h create mode 100644 mlir/lib/Dialect/Rock/IR/FmaInsnGroup.cpp diff --git a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h index 658d306fda87..fa2601289b35 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h +++ b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h @@ -25,6 +25,7 @@ #ifndef MLIR_LIB_DIALECT_ROCK_TRANSFORMS_MLIR_ACCEL_EMITTER_H #define MLIR_LIB_DIALECT_ROCK_TRANSFORMS_MLIR_ACCEL_EMITTER_H +#include "mlir/Dialect/Rock/IR/FmaInsnGroup.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Rock/IR/MfmaInsnGroup.h" #include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" @@ -86,7 +87,7 @@ struct AccelEmitter { /// Select the right accelerator based on the set of features and architecture static std::unique_ptr select(GemmFeatures features, Type dataTypeA, Type dataTypeB, StringRef arch, - RockAccelTuningParamAttrInterface tuningParams); + RockTuningParamAttrInterface tuningParams); /// Emit the actual intrinsic in the threadwise operation virtual void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, @@ -150,15 +151,15 @@ struct AccelEmitter { virtual ~AccelEmitter() {} - enum AccelEmitterKind { AEK_MFMAEmitter, AEK_WMMAEmitter }; + enum AccelEmitterKind { AEK_MFMAEmitter, AEK_WMMAEmitter, AEK_FMAEmitter}; AccelEmitterKind getKind() const { return kind; } protected: - AccelEmitter(StringRef arch, RockAccelTuningParamAttrInterface tuningParams, + AccelEmitter(StringRef arch, RockTuningParamAttrInterface tuningParams, AccelEmitterParams accelEmitterParams, AccelEmitterKind kind); - RockAccelTuningParamAttrInterface tuningParams; + RockTuningParamAttrInterface tuningParams; AccelEmitterParams accelEmitterParams; int64_t waveSize; @@ -170,7 +171,7 @@ struct AccelEmitter { struct MfmaEmitter : public AccelEmitter { MfmaEmitter(MfmaInsnGroup mfmaGroup, StringRef arch, - RockAccelTuningParamAttrInterface tuningParams); + RockTuningParamAttrInterface tuningParams); void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB, Value bufferC, ValueRange regCOffset) override; @@ -208,7 +209,7 @@ struct MfmaEmitter : public AccelEmitter { /// Initialize the emitter parameters for mfma AccelEmitterParams initAccelEmitterParams(MfmaInsnGroup mfmaGroup, - RockAccelTuningParamAttrInterface tuningParams); + RockTuningParamAttrInterface tuningParams); MfmaInsnGroup mfmaGroup; }; @@ -217,7 +218,7 @@ struct MfmaEmitter : public AccelEmitter { struct WmmaEmitter : public AccelEmitter { WmmaEmitter(WmmaInsn wmmaInsn, StringRef arch, - RockAccelTuningParamAttrInterface tuningParams); + RockTuningParamAttrInterface tuningParams); void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB, Value bufferC, ValueRange regCOffset) override; @@ -249,11 +250,50 @@ struct WmmaEmitter : public AccelEmitter { /// Initialize the emitter parameters for wmma AccelEmitterParams initAccelEmitterParams(WmmaInsn wmmaInsn, - RockAccelTuningParamAttrInterface tuningParams); + RockTuningParamAttrInterface tuningParams); // Specifc wmma parameters WmmaInsn wmmaInsn; }; + +// Accel emitter implementation for fma + +struct FmaEmitter : public AccelEmitter { + + FmaEmitter(FmaInsn fmaInsn, StringRef arch, + RockTuningParamAttrInterface tuningParams); + + void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB, + Value bufferC, ValueRange regCOffset) override; + + virtual Value + wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer, + int64_t blockSize, int64_t dInCopyPerThread, + StringRef dName, bool rotateDWithK, + bool doSplitKAcrossThreadsFirst = false) const override; + + virtual RegsAsMatrixSubTiles createAccelGemmOperandTransforms( + OpBuilder &b, Location loc, int64_t kIters, + ArrayRef bidGridLengths, int64_t blockSize, + int64_t dInCopyPerThread, StringRef dName, bool isKContigousDim, + bool rotateDWithK, + bool doSplitKAcrossThreadsFirst = false) const override; + + RegsAsMatrixSubTiles computeOutputTransforms( + OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize, + ArrayRef bidGridLengths, int64_t inMPerThread, + int64_t inNPerThread, bool doSwapThreadIterSubDimsForM = false, + bool doSwapThreadIterSubDimsForN = false) override; + +private: + // Initialize the emitter parameters for fma + AccelEmitterParams + initAccelEmitterParams(FmaInsn fmaInsn, + RockTuningParamAttrInterface tuningParams); + + // Specific fma parameters + FmaInsn fmaInsn; +}; } // namespace accel } // namespace rock } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Rock/IR/FmaInsnGroup.h b/mlir/include/mlir/Dialect/Rock/IR/FmaInsnGroup.h new file mode 100644 index 000000000000..f72ca0fcbea1 --- /dev/null +++ b/mlir/include/mlir/Dialect/Rock/IR/FmaInsnGroup.h @@ -0,0 +1,39 @@ +//===- FmaInsnGroup.h - MLIR to C++ for Rock conversion +//---------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// This file implements code selection logic for Fma instructions. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_FMA_INSN_GROUP_H +#define MLIR_FMA_INSN_GROUP_H + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace rock { + +struct FmaInsn { + Type argTypeA; + Type argTypeB; + Type retType; + + public: + static FailureOr select(Type elementTypeA, Type elementTypeB, StringRef arch); +}; + + + +} // namespace rock +} // namespace mlir + +#endif // MLIR_FMA_INSN_GROUP_H \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 53d36a513629..4b064ffa9689 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -390,9 +390,10 @@ def Rock_GridwiseGemmOp : MemRefRankOf:$b, MemRefRankOf:$c, Rock_GemmFeaturesAttr:$features, + StrAttr:$arch, I32Attr:$numCU, I32Attr:$gridSize, - Rock_GeneralGemmParamsAttr:$params)> { + RockTuningParamAttrInterface:$params)> { let summary = "Gridwise GEMM"; let description = [{ The `rock.gridwise_gemm` op computes gridwise GEMM. @@ -415,7 +416,7 @@ def Rock_GridwiseGemmAccelOp : StoreMethodAttr:$storeMethod, I32Attr:$blockSize, I32Attr:$gridSize, - RockAccelTuningParamAttrInterface:$params)> { + RockTuningParamAttrInterface:$params)> { let summary = "Gridwise GEMM accelerated version"; let description = [{ The `rock.gridwise_gemm` op computes gridwise GEMM with acceleration. @@ -1116,7 +1117,9 @@ def Rock_BlockwiseGemmOp: I32Attr:$inNPerThread, UnitAttr:$rotateMWithK, UnitAttr:$rotateNWithK, - Rock_GeneralGemmParamsAttr:$params + StrAttr:$arch, + Rock_GemmFeaturesAttr:$features, + RockTuningParamAttrInterface:$params )> { let summary = "Blockwise GEMM non accelerated version"; let description = [{ @@ -1168,7 +1171,7 @@ def Rock_BlockwiseGemmAccelOp: StrAttr:$arch, Rock_GemmFeaturesAttr:$features, I32Attr:$blockSize, - RockAccelTuningParamAttrInterface:$params)>{ + RockTuningParamAttrInterface:$params)>{ let summary = "Blockwise GEMM accelerated version"; let description = [{ The `rock.block_gemm_v2` op does GEMM at workgroup (block) level. @@ -1215,7 +1218,7 @@ def Rock_ThreadwiseAccelGemmOp: Arg, "dest register view C", [MemRead, MemWrite]>:$matrixC, Variadic:$computeIndices, StrAttr:$arch, Rock_GemmFeaturesAttr:$features, - RockAccelTuningParamAttrInterface:$params)> { + RockTuningParamAttrInterface:$params)> { let summary = "Accelerated GEMM"; let description = [{ The `rock.accel_gemm` op is an abstraction of doing GEMM based on an accelerator. @@ -1229,6 +1232,21 @@ def Rock_ThreadwiseAccelGemmOp: }]; let hasVerifier = 1; } +// threadwise_gemmv2 +def Rock_ThreadwiseGemmOpv2: + Rock_Op<"threadwise_gemmv2">, + Arguments<(ins Arg, "source register view A", [MemRead]>:$matrixA, + Arg, "source register view B", [MemRead]>:$matrixB, + Arg, "dest register view C", [MemRead, MemWrite]>:$matrixC, Variadic:$computeIndices, + StrAttr:$arch, + Rock_GemmFeaturesAttr:$features, + RockTuningParamAttrInterface:$params)> { + let assemblyFormat = [{ + $matrixC `+` `` `=` $matrixA `*` $matrixB `at` `[` $computeIndices `]` `features` `=` $features attr-dict + `:` type($matrixC) `+` `` `=` type($matrixA) `*` type($matrixB) + }]; + let hasVerifier = 1; +} // blockwise_broadcasting_reduction def Rock_BlockwiseBroadcastReduceOp: diff --git a/mlir/lib/Dialect/Rock/IR/CMakeLists.txt b/mlir/lib/Dialect/Rock/IR/CMakeLists.txt index 64d44c8a8bbe..1be6d347a449 100644 --- a/mlir/lib/Dialect/Rock/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/IR/CMakeLists.txt @@ -10,6 +10,7 @@ add_rocmlir_dialect_library(MLIRRockOps RockWriterOpInterface.cpp MfmaInsnGroup.cpp WmmaInsnGroup.cpp + FmaInsnGroup.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rock diff --git a/mlir/lib/Dialect/Rock/IR/FmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/FmaInsnGroup.cpp new file mode 100644 index 000000000000..1d1ab7149f78 --- /dev/null +++ b/mlir/lib/Dialect/Rock/IR/FmaInsnGroup.cpp @@ -0,0 +1,36 @@ +#include "mlir/Dialect/Rock/IR/FmaInsnGroup.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Rock/utility/AmdArchDb.h" +#include "mlir/Dialect/Rock/utility/math.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include + +#define DEBUG_TYPE "rock-fma-insn-group" + +using namespace mlir; +using namespace mlir::rock; + +static Type getRetType(Type inputType) { + Builder b(inputType.getContext()); + if (inputType.isInteger(8)) + return b.getI32Type(); + + return b.getF32Type();; +} + +FailureOr FmaInsn::select(mlir::Type elementTypeA, mlir::Type elementTypeB, StringRef arch ){ + LLVM_DEBUG(llvm::dbgs() << "Invoke FMA group selection:\n" + << "elementTypeA: " << elementTypeA << "\n" + << "elementTypeB: " << elementTypeB << "\n" + << "arch: " << arch << "\n"); + + Type retType = getRetType(elementTypeA); + + return FmaInsn{elementTypeA, elementTypeB, retType}; +} \ No newline at end of file diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 74538f9a8e02..98326e42463d 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -1759,6 +1759,14 @@ LogicalResult ThreadwiseGemmOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ThreadwiseGemmOpv2 +//===----------------------------------------------------------------------===// +LogicalResult ThreadwiseGemmOpv2::verify() { + //TO-DO + return success(); +} + //===----------------------------------------------------------------------===// // ThreadwiseAccelGemmOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index ddb74f79cbf0..9fb81f63f971 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -187,7 +187,7 @@ struct BlockwiseGemmRewritePattern int64_t mC = bufferCType.getShape()[0]; int64_t nC = bufferCType.getShape()[1]; - GeneralGemmParamsAttr params = op.getParams(); + GeneralGemmParamsAttr params = op.getParams().cast(); uint32_t blockSize = params.getBlockSize(); int64_t kPerThread = params.getKPerThread(); int64_t mPerThread = params.getMPerThread(); @@ -382,8 +382,8 @@ struct BlockwiseGemmRewritePattern Value reshapedBRegisters = reshapeBuffer( b, loc, threadBAllocOp, {"k", "n", "kpack"}, {kPerThread, nC, kPack}); // Actually do the gemm - this goes inside the look over kOffset - b.create(loc, reshapedARegisters, reshapedBRegisters, - op.getMatrixC()); + b.create(loc, reshapedARegisters, reshapedBRegisters, op.getMatrixC(), + ValueRange{zeroConstantOp,zeroConstantOp,zeroConstantOp,zeroConstantOp}, op.getArchAttr(), op.getFeaturesAttr(), op.getParamsAttr()); return success(); } @@ -402,7 +402,7 @@ struct BlockwiseGemmAccelRewritePattern Location loc = op.getLoc(); StringAttr arch = op.getArchAttr(); - RockAccelTuningParamAttrInterface tuningParams = op.getParams(); + RockAccelTuningParamAttrInterface tuningParams = op.getParams().cast(); int64_t kpackPerBlock = tuningParams.getKpackPerBlock(); int64_t mPerWave = tuningParams.getMPerWave(); int64_t nPerWave = tuningParams.getNPerWave(); @@ -503,7 +503,8 @@ struct BlockwiseGemmAccelRewritePattern Value viewC = accelEmitterPtr->generateThreadwiseViewBufferC( b, loc, adaptor.getMatrixC()); Value k = kLoop.getInductionVar(); - b.create(loc, viewA, viewB, viewC, + + b.create(loc, viewA, viewB, viewC, ValueRange{i, j, k}, arch, op.getFeaturesAttr(), tuningParams); } diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index ca8ea84af932..464c876a6684 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp @@ -230,10 +230,10 @@ GemmRewritePattern::matchAndRewrite(GemmOp op, GemmOpAdaptor adaptor, rw.create( loc, a, b, accumulator, op.getArchAttr(), numCUAttr, op.getFeaturesAttr(), op.getStoreMethodAttr(), blockSize, gridSize, - params.cast()); + params.cast()); } else { rw.create(loc, a, b, accumulator, op.getFeaturesAttr(), - numCUAttr, gridSize, + op.getArchAttr(), numCUAttr, gridSize, params.cast()); } diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index c8b2a773e0ca..b379b42a23dc 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -305,7 +305,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern { // Obtain critical tuning parameters. uint32_t gridSize = op.getGridSize(); - GeneralGemmParamsAttr tuningParams = op.getParams(); + GeneralGemmParamsAttr tuningParams = op.getParams().cast(); int64_t kpack = tuningParams.getKpack(); // TODO: kPerBlock, as defined in parameter selection etc, // is in units of kPack, not individual k. This should be changed @@ -639,7 +639,8 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern { b.getI32IntegerAttr(copyMPerThread), b.getI32IntegerAttr(copyNPerThread), rotateMWithK ? b.getUnitAttr() : nullptr, - rotateNWithK ? b.getUnitAttr() : nullptr, op.getParamsAttr()); + rotateNWithK ? b.getUnitAttr() : nullptr, + op.getArchAttr(), op.getFeaturesAttr(), op.getParamsAttr()); // LDS barrier. // This barrier prevents halo part of outputs having weird values. @@ -2427,7 +2428,7 @@ struct GridwiseGemmAccelRewritePattern StringRef arch = op.getArch(); uint32_t blockSize = op.getBlockSize(); uint32_t gridSize = op.getGridSize(); - RockAccelTuningParamAttrInterface tuningParams = op.getParams(); + RockAccelTuningParamAttrInterface tuningParams = op.getParams().cast(); int64_t kpack = tuningParams.getKpack(); // TODO: kPerBlock, as defined in parameter selection etc, // is in units of kPack, not individual k. This should be changed diff --git a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp index 933286a31a32..72b2e62e9f76 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp @@ -89,6 +89,168 @@ struct ThreadwiseReadIntoRewritePattern ConversionPatternRewriter &b) const final; }; +//===----------------------------------------------------------------------===// +// ThreadwiseGemmV2 lowering. +//===----------------------------------------------------------------------===// +struct ThreadwiseGemmv2RewritePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + TransformMapAttr normalizeView(OpBuilder &b, Location loc, + ArrayRef names, + ArrayRef sizes, bool ignoreStartShape, + DenseSet ignoreNames) const { + unsigned pos = 0; + unsigned newPos = 0; + TopDownTMBuilder td(b, names, sizes, loc); + // Convert the normalizedView to a real view by ignoring + // the names contained in `ignoreNames` and letting the rest pass through + for (pos = 0; pos < names.size(); pos++) { + if (ignoreNames.contains(names[pos])) { + td.ignore(names[pos]); + } else { + td.passThrough({newPos++}, {pos}); + } + } + return td.get(); + } + + LogicalResult matchAndRewrite(ThreadwiseGemmOpv2 op, + ThreadwiseGemmOpv2Adaptor adaptor, + ConversionPatternRewriter &b) const override { + + bool isMfma = rock::bitEnumContainsAll(op.getFeatures(), GemmFeatures::mfma); + bool isWmma = rock::bitEnumContainsAll(op.getFeatures(), GemmFeatures::wmma); + + Location loc = op.getLoc(); + + Value bufferA = adaptor.getMatrixA(); + Value bufferB = adaptor.getMatrixB(); + Value bufferC = adaptor.getMatrixC(); + + auto dataTypeA = adaptor.getMatrixA().getType().cast().getElementType(); + auto dataTypeB = adaptor.getMatrixB().getType().cast().getElementType(); + + auto bufferAShape = op.getMatrixA().getType().getShape(); + auto bufferCShape = op.getMatrixC().getType().getShape(); + + if (dataTypeA.isa()) { + dataTypeA = dataTypeA.cast().getElementType(); + } + if (dataTypeB.isa()) { + dataTypeB = dataTypeB.cast().getElementType(); + } + + auto emitter = rock::accel::AccelEmitter::select( + op.getFeatures(), dataTypeA, dataTypeB, op.getArch(), op.getParams()); + + if (!emitter) + return emitError(loc) + << "Failed to select any accelerator instruction.\n"; + + rock::accel::AccelEmitterParams params = emitter->getParams(); + Type argTypeA = params.argTypeA; + Type argTypeB = params.argTypeB; + + // llvm::outs() << "\n**** ARGTYPE: " << argTypeA << "\n"; + + TransformMapAttr normalizedViewA, normalizedViewB, normalizedViewC; + + // Sizes of the [i,j,k] axis + int64_t i = *(bufferCShape.end() - 2); + int64_t j = bufferCShape.back(); + int64_t k = bufferAShape.back(); + + if(isMfma || isWmma){ + normalizedViewA = + normalizeView(b, loc, {"i", "j", "k"}, {i, j, k}, true, {"j"}); + normalizedViewB = + normalizeView(b, loc, {"i", "j", "k"}, {i, j, k}, true, {"i"}); + normalizedViewC = normalizeView( + b, loc, {"i", "j", "k"}, {i, j, k}, false, {"k"}); + } else { + int64_t kPack = bufferAShape[2]; + + SmallVector dimensions = {k, i, j, kPack}; + SmallVector strides = {1, 1, 1, 1}; + + normalizedViewA = normalizeView(b, loc, {"k", "i", "j", "kpack"}, dimensions, true, {"j"}); + normalizedViewB = normalizeView(b, loc, {"k", "i", "j", "kpack"}, dimensions, true, {"i"}); + normalizedViewC = normalizeView(b, loc, {"k", "i", "j", "kpack"}, dimensions, true, {"k", "kpack"}); + } + + auto [rawBufferA, bufferViewA, sourceANeeds64BitIdx] = + untransform(b, bufferA, normalizedViewA); + auto [rawBufferB, bufferViewB, sourceBNeeds64BitIdx] = + untransform(b, bufferB, normalizedViewB); + auto [rawBufferC, bufferViewC, dstNeeds64BitIdx] = + untransform(b, bufferC, normalizedViewC); + assert(!sourceANeeds64BitIdx && "Registers shouldn't need 64-bit indexing"); + assert(!sourceBNeeds64BitIdx && "Registers shouldn't need 64-bit indexing"); + assert(!dstNeeds64BitIdx && "Registers shouldn't need 64-bit indexing"); + + // Loop properties + auto computeStart = llvm::to_vector(op.getComputeIndices()); + + if(isMfma || isWmma){ + + // Emit the loop + auto accelLoop = b.create( + loc, ArrayRef{computeStart, computeStart, computeStart}, + ArrayRef{bufferViewA, bufferViewB, bufferViewC}, + /*bounds=*/ArrayRef{1, 1, 1}, + /*strides=*/ArrayRef{1, 1, 1}, + /*forceUnroll=*/true, /*useIndexDiffs=*/true); + { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(accelLoop.getBody()); + auto coordsA = accelLoop.getLowerCoords(/*domain=*/0); + auto coordsB = accelLoop.getLowerCoords(/*domain=*/1); + auto coordsC = accelLoop.getLowerCoords(/*domain=*/2); + + Value argA = b.create(loc, argTypeA, rawBufferA, coordsA); + Value argB = b.create(loc, argTypeB, rawBufferB, coordsB); + // llvm::outs() << "\n*** argA: " << argA << "\n" + // << "\n*** argB: " << argB << "\n"; + emitter->emitThreadwiseLoop(b, loc, argA, argB, rawBufferC, coordsC); + } + b.eraseOp(op); + return success(); + } else { + + int64_t k = bufferAShape.back(); + int64_t i = *(bufferCShape.end() - 2); + int64_t kPack = bufferAShape[2]; + int64_t j = bufferCShape.back(); + + SmallVector dimensions = {k, i, j, kPack}; + + auto gemmLoop = b.create( + loc, ArrayRef{computeStart, computeStart, computeStart}, + ArrayRef{bufferViewA, bufferViewB, bufferViewC}, dimensions, + /*strides=*/std::nullopt, /*forceUnroll=*/true, + /*useIndexDiffs=*/false); + + { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(gemmLoop.getBody()); + + auto coordsA = gemmLoop.getLowerCoords(/*domain=*/0); + auto coordsB = gemmLoop.getLowerCoords(/*domain=*/1); + auto coordsC = gemmLoop.getLowerCoords(/*domain=*/2); + + Value argA = b.create(loc, argTypeA, rawBufferA, coordsA); + Value argB = b.create(loc, argTypeB, rawBufferB, coordsB); + + emitter->emitThreadwiseLoop(b, loc, argA, argB, rawBufferC, coordsC); + + b.eraseOp(op); + } + return success(); + } + } +}; + //===----------------------------------------------------------------------===// // ThreadwiseGemm lowering. //===----------------------------------------------------------------------===// @@ -234,7 +396,7 @@ struct ThreadwiseAccelGemmRewritePattern ConversionPatternRewriter &b) const override { Location loc = op.getLoc(); - RockAccelTuningParamAttrInterface tuningParams = op.getParams(); + RockAccelTuningParamAttrInterface tuningParams = op.getParams().cast(); auto dataTypeA = adaptor.getMatrixA().getType().cast().getElementType(); @@ -796,14 +958,14 @@ void RockThreadwiseGemmLoweringPass::runOnOperation() { } ConversionTarget target(*ctx); - target.addIllegalOp(); + target.addIllegalOp(); target.addLegalDialect(); target.addLegalOp(); RewritePatternSet patterns(ctx); - patterns.add( + patterns.add( ctx); if (failed(applyPartialConversion(op, target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 4a73941d2c22..cf6e93705c27 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Rock/IR/AccelEmitter.h" +#include "mlir/Dialect/Rock/Tuning/GeneralGemmBlockStructure.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Rock/utility/AmdArchDb.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" @@ -38,7 +39,7 @@ using namespace mlir::rock::accel; // ************************ AccelEmitter::AccelEmitter(StringRef arch, - RockAccelTuningParamAttrInterface tuningParams, + RockTuningParamAttrInterface tuningParams, AccelEmitterParams accelEmitterParams, AccelEmitterKind kind) : tuningParams(tuningParams), accelEmitterParams(accelEmitterParams), @@ -130,27 +131,110 @@ Value AccelEmitter::generateThreadwiseViewBufferC(PatternRewriter &b, return viewC; } +// ************************** +// Fma accelerator interface +// ************************** + +FmaEmitter::FmaEmitter(FmaInsn fmaInsn, StringRef arch, + RockTuningParamAttrInterface tuningParams) + : AccelEmitter{arch, tuningParams, + initAccelEmitterParams(fmaInsn, tuningParams), AccelEmitterKind::AEK_FMAEmitter}, + fmaInsn(fmaInsn) {} + +AccelEmitterParams FmaEmitter::initAccelEmitterParams( + FmaInsn fmaInsn, RockTuningParamAttrInterface rawTuningParams) { + AccelEmitterParams params; + + auto tuningParams = rawTuningParams.dyn_cast(); + + params.argTypeA = fmaInsn.argTypeA; + params.argTypeB = fmaInsn.argTypeB; + + //TO-DO + + return params; +} + +void FmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB, + Value bufferC, ValueRange regCOffset){ + + Type dataType = fmaInsn.argTypeA; + + int64_t loadKpackLen = 1; + auto abType = VectorType::get(loadKpackLen, dataType); + + Value zeroConst = b.createOrFold(loc, 0); + Value cVal = b.create(loc, dataType, bufferC, regCOffset); + Value aVector = b.create(loc, abType, argA); + Value bVector = b.create(loc, abType, argB); + Value cVector = b.create(loc, abType, cVal); + + Value result; + if (dataType.isa()) { + Value mul = b.create(loc, aVector, bVector); + result = b.create(loc, mul, cVector); + result = b.create(loc, result, zeroConst); + } else if (dataType.isa()) { + result = b.create(loc, aVector, bVector, cVector); + result = b.create(loc, result, zeroConst); + } else { + llvm_unreachable("Validation should make this ints or floats only"); + } + + b.create(loc, result, bufferC, regCOffset); +} + +Value FmaEmitter::wrapLDSBufferForLoad(OpBuilder &b, Location loc, + Value buffer, int64_t blockSize, + int64_t dInCopyPerThread, + StringRef dName, bool rotateDWithK, + bool doSplitKAcrossThreadsFirst) const { + + //TO-DO +} + +RegsAsMatrixSubTiles FmaEmitter::createAccelGemmOperandTransforms( + OpBuilder &b, Location loc, int64_t kIters, + ArrayRef bidGridLengths, int64_t blockSize, + int64_t dInCopyPerThread, StringRef dName, bool isKContigousDim, + bool rotateDWithK, bool doSplitKAcrossThreadsFirst) const { + + //TO-DO + } + +RegsAsMatrixSubTiles FmaEmitter::computeOutputTransforms( + OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize, + ArrayRef bidGridLengths, int64_t inMPerThread, + int64_t inNPerThread, bool doSwapThreadIterSubDimsForM, + bool doSwapThreadIterSubDimsForN){ + + //TO-DO +} + + // ************************** // Mfma accelerator interface // ************************** MfmaEmitter::MfmaEmitter(MfmaInsnGroup mfmaGroup, StringRef arch, - RockAccelTuningParamAttrInterface tuningParams) + RockTuningParamAttrInterface tuningParams) : AccelEmitter{arch, tuningParams, initAccelEmitterParams(mfmaGroup, tuningParams), AccelEmitterKind::AEK_MFMAEmitter}, mfmaGroup{mfmaGroup} {} AccelEmitterParams MfmaEmitter::initAccelEmitterParams( - MfmaInsnGroup mfmaGroup, RockAccelTuningParamAttrInterface tuningParams) { + MfmaInsnGroup mfmaGroup, RockTuningParamAttrInterface tuningParams) { AccelEmitterParams params; MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); // Extract relevant tuning parameters - int64_t kpackPerBlock = tuningParams.getKpackPerBlock(); - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); - int64_t kPack = tuningParams.getKpack(); + XdlopsGemmDerivedParamsAttr mfmaParams = + tuningParams.cast(); + int64_t kpackPerBlock = mfmaParams.getKpackPerBlock(); + int64_t mPerWave = mfmaParams.getMPerWave(); + int64_t nPerWave = mfmaParams.getNPerWave(); + int64_t kPack = mfmaParams.getKpack(); int64_t K = kpackPerBlock * kPack; // Accelerator parameters @@ -247,10 +331,12 @@ RegsAsMatrixSubTiles MfmaEmitter::computeOutputTransforms( bool doSwapThreadIterSubDimsForN) { // Extract relevant tuning parameters - int64_t mPerBlock = tuningParams.getMPerBlock(); - int64_t nPerBlock = tuningParams.getNPerBlock(); - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); + XdlopsGemmDerivedParamsAttr mfmaParams = + tuningParams.cast(); + int64_t mPerBlock = mfmaParams.getMPerBlock(); + int64_t nPerBlock = mfmaParams.getNPerBlock(); + int64_t mPerWave = mfmaParams.getMPerWave(); + int64_t nPerWave = mfmaParams.getNPerWave(); // Extract relevant emitter parameters int64_t mRepeats = accelEmitterParams.mRepeats; @@ -492,12 +578,14 @@ Value MfmaEmitter::wrapLDSBufferForLoad(OpBuilder &b, Location loc, StringRef otherWaveDim = dName == "m" ? "wave_n" : "wave_m"; // Extract relevant tuning parameters - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); - int64_t kPerBlock = tuningParams.getKpackPerBlock(); - int64_t mPerBlock = tuningParams.getMPerBlock(); - int64_t nPerBlock = tuningParams.getNPerBlock(); - int64_t kPack = tuningParams.getKpack(); + XdlopsGemmDerivedParamsAttr mfmaParams = + tuningParams.cast(); + int64_t mPerWave = mfmaParams.getMPerWave(); + int64_t nPerWave = mfmaParams.getNPerWave(); + int64_t kPerBlock = mfmaParams.getKpackPerBlock(); + int64_t mPerBlock = mfmaParams.getMPerBlock(); + int64_t nPerBlock = mfmaParams.getNPerBlock(); + int64_t kPack = mfmaParams.getKpack(); // Extract relevant emitter parameters MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); @@ -639,12 +727,14 @@ RegsAsMatrixSubTiles MfmaEmitter::createAccelGemmOperandTransforms( dName == "m" ? bidGridLengths[1] : bidGridLengths[2]; // Extract relevant tuning parameters - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); - int64_t kPackPerBlock = tuningParams.getKpackPerBlock(); - int64_t mPerBlock = tuningParams.getMPerBlock(); - int64_t nPerBlock = tuningParams.getNPerBlock(); - int64_t kPack = tuningParams.getKpack(); + XdlopsGemmDerivedParamsAttr mfmaParams = + tuningParams.cast(); + int64_t mPerWave = mfmaParams.getMPerWave(); + int64_t nPerWave = mfmaParams.getNPerWave(); + int64_t kPackPerBlock = mfmaParams.getKpackPerBlock(); + int64_t mPerBlock = mfmaParams.getMPerBlock(); + int64_t nPerBlock = mfmaParams.getNPerBlock(); + int64_t kPack = mfmaParams.getKpack(); // Extract relevant emitter parameters MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); @@ -922,19 +1012,20 @@ LogicalResult MfmaEmitter::validateAcceleratorProperties() { // ************************** WmmaEmitter::WmmaEmitter(WmmaInsn wmmaInsn, StringRef arch, - RockAccelTuningParamAttrInterface tuningParams) + RockTuningParamAttrInterface tuningParams) : AccelEmitter{arch, tuningParams, initAccelEmitterParams(wmmaInsn, tuningParams), AccelEmitterKind::AEK_WMMAEmitter}, wmmaInsn(wmmaInsn) {} AccelEmitterParams WmmaEmitter::initAccelEmitterParams( - WmmaInsn wmmaInsn, RockAccelTuningParamAttrInterface tuningParams) { + WmmaInsn wmmaInsn, RockTuningParamAttrInterface tuningParams) { AccelEmitterParams params; // Extract relevant tuning parameters - int64_t kpackPerBlock = tuningParams.getKpackPerBlock(); - int64_t kPack = tuningParams.getKpack(); + WmmaGemmParamsAttr wmmaParams = tuningParams.cast(); + int64_t kpackPerBlock = wmmaParams.getKpackPerBlock(); + int64_t kPack = wmmaParams.getKpack(); params.mRepeats = wmmaInsn.mRepeats; params.nRepeats = wmmaInsn.nRepeats; @@ -959,12 +1050,13 @@ Value WmmaEmitter::wrapLDSBufferForLoad(OpBuilder &b, Location loc, bool doSplitKAcrossThreadsFirst) const { // Extract relevant tuning parameters - int64_t mPerBlock = tuningParams.getMPerBlock(); - int64_t nPerBlock = tuningParams.getNPerBlock(); - int64_t kPerBlock = tuningParams.getKpackPerBlock(); - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); - int64_t kPack = tuningParams.getKpack(); + WmmaGemmParamsAttr wmmaParams = tuningParams.cast(); + int64_t mPerBlock = wmmaParams.getMPerBlock(); + int64_t nPerBlock = wmmaParams.getNPerBlock(); + int64_t kPerBlock = wmmaParams.getKpackPerBlock(); + int64_t mPerWave = wmmaParams.getMPerWave(); + int64_t nPerWave = wmmaParams.getNPerWave(); + int64_t kPack = wmmaParams.getKpack(); // Extract relevant emitter parameters int64_t inputLen = wmmaInsn.inputLen; @@ -1051,12 +1143,13 @@ RegsAsMatrixSubTiles WmmaEmitter::createAccelGemmOperandTransforms( dName == "m" ? bidGridLengths[1] : bidGridLengths[2]; // Extract relevant tuning parameters - int64_t mPerBlock = tuningParams.getMPerBlock(); - int64_t nPerBlock = tuningParams.getNPerBlock(); - int64_t kPackPerBlock = tuningParams.getKpackPerBlock(); - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); - int64_t kPack = tuningParams.getKpack(); + WmmaGemmParamsAttr wmmaParams = tuningParams.cast(); + int64_t mPerBlock = wmmaParams.getMPerBlock(); + int64_t nPerBlock = wmmaParams.getNPerBlock(); + int64_t kPackPerBlock = wmmaParams.getKpackPerBlock(); + int64_t mPerWave = wmmaParams.getMPerWave(); + int64_t nPerWave = wmmaParams.getNPerWave(); + int64_t kPack = wmmaParams.getKpack(); // Extract relevant emitter parameters int64_t inputLen = wmmaInsn.inputLen; @@ -1305,10 +1398,11 @@ RegsAsMatrixSubTiles WmmaEmitter::computeOutputTransforms( bool doSwapThreadIterSubDimsForN) { // Extract relevant tuning parameters - int64_t mPerBlock = tuningParams.getMPerBlock(); - int64_t nPerBlock = tuningParams.getNPerBlock(); - int64_t mPerWave = tuningParams.getMPerWave(); - int64_t nPerWave = tuningParams.getNPerWave(); + WmmaGemmParamsAttr wmmaParams = tuningParams.cast(); + int64_t mPerBlock = wmmaParams.getMPerBlock(); + int64_t nPerBlock = wmmaParams.getNPerBlock(); + int64_t mPerWave = wmmaParams.getMPerWave(); + int64_t nPerWave = wmmaParams.getNPerWave(); // Extract relevant emitter parameters int64_t mRepeats = accelEmitterParams.mRepeats; @@ -1471,7 +1565,7 @@ RegsAsMatrixSubTiles WmmaEmitter::computeOutputTransforms( std::unique_ptr AccelEmitter::select(GemmFeatures features, Type dataTypeA, Type dataTypeB, StringRef arch, - RockAccelTuningParamAttrInterface tuningParams) { + RockTuningParamAttrInterface tuningParams) { bool isMfma = rock::bitEnumContainsAll(features, GemmFeatures::mfma); bool isWmma = rock::bitEnumContainsAll(features, GemmFeatures::wmma); if (isMfma) { @@ -1486,15 +1580,21 @@ AccelEmitter::select(GemmFeatures features, Type dataTypeA, Type dataTypeB, tuningParams); } else if (isWmma) { int64_t waveSize = rock::lookupArchInfo(arch).waveSize; + WmmaGemmParamsAttr wmmaParams = tuningParams.cast(); auto maybeWmmaInsnGroup = WmmaInsn::select(dataTypeA, dataTypeB, waveSize, - tuningParams.getMPerWave(), - tuningParams.getNPerWave()); + wmmaParams.getMPerWave(), + wmmaParams.getNPerWave()); if (failed(maybeWmmaInsnGroup)) { return nullptr; } return std::make_unique(*maybeWmmaInsnGroup, arch, tuningParams); } else { - return nullptr; + auto fmaTuningParams = tuningParams.cast(); + auto maybeFmaInsnGroup = FmaInsn::select(dataTypeA, dataTypeB, arch); + if (failed(maybeFmaInsnGroup)) { + return nullptr; + } + return std::make_unique(*maybeFmaInsnGroup, arch, tuningParams); } }