diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h index e20f3c66b782..381202c8744d 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h @@ -269,7 +269,8 @@ class BasePopulateParams { // Succced if `params` should be included in a "full" tuning space that // excludes those known to not yeild good performance on the problem described // in `info`. This function uses hardcoded heuristics. - virtual LogicalResult couldBePerformant(const PopulateParamsInfo &info, + virtual LogicalResult couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamType ¶ms) = 0; // Convert the provided InitParamType into an MLIR `Attribute`. @@ -316,7 +317,8 @@ class PopulateParams : public BasePopulateParams { const PopulateParamsInfo &info, const InitParamsNonAccel ¶ms) override; - LogicalResult couldBePerformant(const PopulateParamsInfo &info, + LogicalResult couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamsNonAccel ¶ms) override; int64_t calculatePaddingAmount(const InitParamsNonAccel ¶ms, @@ -357,7 +359,8 @@ class PopulateParamsAccel : public BasePopulateParams { const PopulateParamsInfo &info, const InitParamsAccel ¶ms) override; - LogicalResult couldBePerformant(const PopulateParamsInfo &info, + LogicalResult couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamsAccel ¶ms) override; virtual LogicalResult @@ -376,9 +379,10 @@ class PopulateParamsAccel : public BasePopulateParams { /// The actual implementation of couldBePerformant(), which shouldn't exist /// once we merge gridwise_gemm and gridwise_gemm_accel and thus flatten /// out the class heirachy in this file. - virtual LogicalResult specificCouldBePerformant(const InitParamsAccel ¶ms, - Type dataTypeA, - Type dataTypeB) = 0; + virtual LogicalResult specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel ¶ms) = 0; + }; // @@ -413,9 +417,9 @@ class PopulateParamsXDL : public PopulateParamsAccel { bool enableDPerWaveFiltering = true) override; protected: - LogicalResult specificCouldBePerformant(const InitParamsAccel ¶ms, - Type dataTypeA, - Type dataTypeB) override; + LogicalResult specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel ¶ms) override; }; // @@ -448,9 +452,9 @@ class PopulateParamsWmma : public PopulateParamsAccel { bool enableDPerWaveFiltering = true) override; protected: - LogicalResult specificCouldBePerformant(const InitParamsAccel ¶ms, - Type dataTypeA, - Type dataTypeB) override; + LogicalResult specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel ¶ms) override; }; } // namespace rock diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index b19d86aa1a7a..923ece9fe647 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Rock/utility/AmdArchDb.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" #include "mlir/Dialect/Rock/utility/math.h" +#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/Debug.h" @@ -212,9 +213,11 @@ PopulateParams::paramsProbablyValid(OpBuilder &b, } LogicalResult -PopulateParams::couldBePerformant(const PopulateParamsInfo &info, +PopulateParams::couldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, const InitParamsNonAccel ¶ms) { // Implement this if needed. + (void)b; (void)info; (void)params; return success(); @@ -336,9 +339,9 @@ PopulateParamsAccel::paramsProbablyValid(OpBuilder &b, } LogicalResult -PopulateParamsAccel::couldBePerformant(const PopulateParamsInfo &info, +PopulateParamsAccel::couldBePerformant(OpBuilder &b, const PopulateParamsInfo &info, const InitParamsAccel ¶ms) { - return specificCouldBePerformant(params, info.gemmAType, info.gemmBType); + return specificCouldBePerformant(b, info, params); } LogicalResult PopulateParamsAccel::obtainTuningParameters( @@ -693,12 +696,33 @@ PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA, } LogicalResult -PopulateParamsXDL::specificCouldBePerformant(const InitParamsAccel ¶ms, - Type dataTypeA, Type dataTypeB) { - // Implement this if needed. - (void)params; - (void)dataTypeA; - (void)dataTypeB; +PopulateParamsXDL::specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel ¶ms) { + Attribute params0 = getGemmParamsAttr(b, params); + RockAccelTuningParamAttrInterface accelParams0; + if (auto xdlopsParams0 = dyn_cast(params0)) { + auto xdlopsDerivedParams0 = XdlopsGemmDerivedParamsAttr::get(xdlopsParams0); + accelParams0 = xdlopsDerivedParams0; + } else { + accelParams0 = cast(params0); + } + auto accelEmitterPtr = accel::AccelEmitter::select( + info.gemmFeatures, info.gemmAType, info.gemmBType, StringRef(info.arch), accelParams0); + + if (!accelEmitterPtr) + return failure(); + + rock::accel::AccelEmitterParams accelParams = accelEmitterPtr->getParams(); + + int64_t numOutputVectorElements = accelParams.numOutputVectorElements(); + + // would be best to have register count be a part of arch, is not necessarily totalVGPRPerEu + if(numOutputVectorElements > 256) { + return failure(); + } + + return success(); } @@ -913,12 +937,13 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA, } LogicalResult -PopulateParamsWmma::specificCouldBePerformant(const InitParamsAccel ¶ms, - Type dataTypeA, Type dataTypeB) { +PopulateParamsWmma::specificCouldBePerformant(OpBuilder &b, + const PopulateParamsInfo &info, + const InitParamsAccel ¶ms) { // Implement this if needed. + (void)b; + (void)info; (void)params; - (void)dataTypeA; - (void)dataTypeB; return success(); } diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 113ee996e4ba..c652ce2227c1 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -272,7 +272,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace, b, info, gemmParams)) && (kind == TuningParamSetKind::Exhaustive || succeeded( - tuningInfo.couldBePerformant(info, gemmParams)))) + tuningInfo.couldBePerformant(b, info, gemmParams)))) newSpace->tuningRange.push_back( cast( tuningInfo.getGemmParamsAttr(b, gemmParams))); @@ -309,7 +309,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace, gemmParams)) && (kind == TuningParamSetKind::Exhaustive || succeeded( - tuningInfo.couldBePerformant(info, gemmParams)))) + tuningInfo.couldBePerformant(b, info, gemmParams)))) newSpace->tuningRange.push_back( cast( tuningInfo.getGemmParamsAttr(b, gemmParams))); @@ -340,7 +340,7 @@ void createGemmTuningRangeBF(TuningParamSet *newSpace, gemmParams)) && (kind == TuningParamSetKind::Exhaustive || succeeded( - tuningInfo.couldBePerformant(info, gemmParams)))) + tuningInfo.couldBePerformant(b, info, gemmParams)))) newSpace->tuningRange.push_back( cast( tuningInfo.getGemmParamsAttr(b, gemmParams)));