diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h index e20f3c66b782..d89474172ac1 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h @@ -385,19 +385,32 @@ class PopulateParamsAccel : public BasePopulateParams { // Xdlops interface // class PopulateParamsXDL : public PopulateParamsAccel { - static constexpr size_t nInitParameters = 40; + static constexpr size_t nInitParametersConv = 20; // Initial tuning parameters for forward convolution and backward // convolution. - static const InitParamsAccel initParameters[nInitParameters]; + static const InitParamsAccel initParametersConv[nInitParametersConv]; - static constexpr size_t nInitParametersFp16 = 40; + static constexpr size_t nInitParametersFp16Conv = 20; // Tuning parameters for fp16/bf16 convolutions. - static const InitParamsAccel initParametersFp16[nInitParametersFp16]; + static const InitParamsAccel initParametersFp16Conv[nInitParametersFp16Conv]; - static constexpr size_t nInitParametersForward8Bit = 40; + static constexpr size_t nInitParametersForward8BitConv = 20; // Tuning parameters for i8 convolutions. static const InitParamsAccel - initParametersForward8Bit[nInitParametersForward8Bit]; + initParametersForward8BitConv[nInitParametersForward8BitConv]; + + static constexpr size_t nInitParametersGemm = 20; + // Initial tuning parameters for gemm (A element width not 8 or 16 bits). + static const InitParamsAccel initParametersGemm[nInitParametersGemm]; + + static constexpr size_t nInitParametersFp16Gemm = 20; + // Tuning parameters for gemm with 16-bit A elements (e.g. fp16/bf16). + static const InitParamsAccel initParametersFp16Gemm[nInitParametersFp16Gemm]; + + static constexpr size_t nInitParametersForward8BitGemm = 20; + // Tuning parameters for gemm with 8-bit A elements (e.g. i8). 
+ static const InitParamsAccel + initParametersForward8BitGemm[nInitParametersForward8BitGemm]; public: std::vector diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index bd4442049442..bf7d4b2e2705 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -394,138 +394,153 @@ PopulateParamsAccel::obtainTuningParameters(RockGemmWrapperInterface op, /// Xdlops acceleration // clang-format off const InitParamsAccel -PopulateParamsXDL::initParameters[PopulateParamsXDL::nInitParameters] = { +PopulateParamsXDL::initParametersGemm[PopulateParamsXDL::nInitParametersGemm] = { // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore - {256, 256, 2, 128, 32, 4, 1, true, true}, - {256, 64, 8, 128, 32, 1, 1, true, true}, - {128, 128, 8, 64, 16, 4, 1, true, true}, - {128, 128, 4, 128, 32, 4, 1, true, true}, - {128, 128, 2, 32, 32, 8, 1, true, true}, - {128, 64, 8, 64, 16, 1, 1, true, true}, - {128, 64, 8, 32, 32, 4, 1, true, true}, - {128, 64, 8, 32, 16, 1, 1, true, true}, - {128, 64, 4, 32, 32, 4, 1, true, true}, - {128, 64, 2, 128, 32, 4, 1, true, true}, - {128, 32, 4, 128, 16, 4, 1, true, true}, - {128, 16, 4, 32, 16, 8, 1, true, true}, - {64, 256, 8, 64, 16, 4, 1, true, true}, - {64, 128, 4, 64, 32, 1, 1, true, true}, - {64, 128, 4, 64, 16, 4, 1, true , true}, - {64, 128, 4, 32, 16, 4, 1, true, true}, - {64, 128, 2, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 32, 4, 1, true, true}, - {64, 64, 8, 16, 16, 4, 1, true, true}, - {64, 64, 8, 32, 16, 4, 1, true, true}, - {64, 64, 8, 16, 16, 8, 1, true, true}, - {64, 64, 4, 32, 16, 4, 1, true, true}, - {64, 64, 4, 16, 16, 8, 1, true, true}, - {64, 64, 8, 64, 16, 8, 1, true, true}, - {64, 32, 4, 32, 16, 8, 1, true, true}, - {64, 32, 8, 16, 16, 4, 1, true, true}, - {64, 32, 8, 16, 16, 4, 1, true, true}, - {64, 16, 8, 16, 16, 8, 1, true, true}, - {32, 128, 8, 32, 16, 1, 1, 
true, true}, - {32, 128, 8, 16, 16, 4, 1, true , true}, - {32, 64, 8, 32, 16, 4, 1, true, true}, - {32, 64, 4, 32, 16, 4, 1, true, true}, - {32, 32, 8, 16, 16, 8, 1, true, true}, - {32, 32, 8, 16, 16, 4, 1, true, true}, - {32, 16, 8, 16, 16, 8, 1, true, true}, - {32, 16, 4, 16, 16, 8, 1, true, true}, - {16, 32, 4, 16, 16, 4, 1, true, true}, - {16, 32, 8, 16, 16, 8, 1, true, true}, - {16, 16, 4, 16, 16, 4, 1, true, true}, - {16, 16, 8, 16, 16, 8, 1, true, true} + {64,64,8,32,32,4,1,true,true}, + {16,32,8,16,16,8,1,true,true}, + {64,64,4,32,16,4,1,true,true}, + {64,64,8,32,32,8,1,true,true}, + {64,64,8,32,16,4,1,true,true}, + {64,32,8,16,16,4,1,true,true}, + {128,128,4,32,32,8,1,true,true}, + {32,64,8,16,16,4,1,true,true}, + {128,128,4,64,16,4,1,true,true}, + {64,256,8,64,16,1,1,true,true}, + {64,64,4,32,32,8,1,true,true}, + {64,64,4,16,16,4,1,true,true}, + {64,64,4,16,16,8,1,true,true}, + {64,64,8,16,16,4,1,true,true}, + {64,64,8,64,16,4,1,true,true}, + {64,128,8,64,16,1,1,true,true}, + {64,64,4,32,16,8,1,true,true}, + {64,128,4,32,32,4,1,true,true}, + {64,256,4,32,32,4,1,true,true}, + {128,128,8,64,16,4,1,true,true} }; const InitParamsAccel -PopulateParamsXDL::initParametersFp16[PopulateParamsXDL::nInitParametersFp16] = { +PopulateParamsXDL::initParametersConv[PopulateParamsXDL::nInitParametersConv] = { // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore - {128, 256, 8, 64, 32, 4, 1, true, true}, - {128, 256, 4, 64, 32, 8, 1, true, true}, - {128, 128, 8, 128, 32, 8, 1, true, true}, - {128, 128, 8, 64, 32, 4, 1, true, true}, - {128, 128, 8, 32, 32, 8, 1, true, true}, - {128, 128, 8, 32, 16, 4, 1, true, true}, - {128, 128, 4, 128, 32, 8, 1, true, true}, - {128, 128, 4, 128, 16, 8, 1, true, true}, - {128, 128, 4, 64, 32, 8, 1, true, true}, - {128, 128, 4, 64, 16, 8, 1, true, true}, - {128, 128, 4, 32, 32, 8, 1, true, true}, - {128, 64, 4, 128, 16, 8, 1, true, true}, - {128, 64, 4, 32, 32, 8, 1, true, true}, - {128, 32, 8, 32 ,32 ,8, 1, 
true, true}, - {64, 128, 4, 64, 16, 8, 1, true, true}, - {64, 128, 8, 32, 32, 4, 1, true, true}, - {64, 128, 8, 32, 16, 8, 1, true, true}, - {64, 128, 8, 32, 16, 4, 1, true, true}, - {64, 128, 8, 64, 32, 4, 1, true, true}, - {64, 128, 4, 32, 16, 8, 1, true, true}, - {64, 128, 4, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 32, 8, 1, true, true}, - {64, 64, 8, 32, 16, 8, 1, true, true}, - {64, 64, 8, 16, 16, 8, 1, true, true}, - {64, 64, 4, 32, 32, 8, 1, true, true}, - {64 ,64, 2, 32, 32, 4, 1, true, true}, - {64, 32, 8, 32, 32, 8, 1, true, true}, - {64, 32, 8, 32, 16, 8, 1, true, true}, - {64, 16, 8, 16, 16, 8, 1, true, true}, - {32, 128, 8, 32, 32, 4, 1, true, true}, - {32, 64, 8, 32, 32, 8, 1, true, true}, - {32, 64, 8, 32, 16, 4, 1, true, true}, - {32, 32, 8, 32, 32, 4, 1, true, true}, - {32, 32, 8, 16, 16, 8, 1, true, true}, - {32, 16 ,8, 16, 16, 8, 1, true, true}, - {16, 128, 4, 16, 16, 8, 1, true, true}, - {16, 32, 8, 16, 16, 8, 1, true, true}, - {16, 64, 8, 16, 16, 8, 1, true, true}, - {16, 32, 8, 16 ,16 ,4, 1, true, true} + {64,32,8,16,16,4,1,true,true}, + {32,16,4,16,16,8,1,true,true}, + {64,64,4,32,32,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {64,64,8,32,32,8,1,true,true}, + {64,32,4,16,16,4,1,true,true}, + {64,32,4,16,16,8,1,true,true}, + {32,32,8,16,16,4,1,true,true}, + {64,64,8,32,32,4,1,true,true}, + {32,64,8,32,16,4,1,true,true}, + {64,16,4,16,16,8,1,true,true}, + {64,16,8,16,16,4,1,true,true}, + {64,32,4,32,16,4,1,true,true}, + {32,64,8,16,16,4,1,true,true}, + {64,32,8,16,16,8,1,true,true}, + {64,32,8,32,16,8,1,true,true}, + {64,64,4,64,16,4,1,true,true}, + {64,64,8,32,16,4,1,true,true}, + {64,32,8,32,16,4,1,true,true}, + {32,64,8,32,32,4,1,true,true} + }; + +const InitParamsAccel +PopulateParamsXDL::initParametersFp16Gemm[PopulateParamsXDL::nInitParametersFp16Gemm] = { + // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore + {64,64,8,32,32,8,1,true,true}, + 
{64,256,4,32,32,8,1,true,true}, + {128,128,8,128,16,8,1,true,true}, + {64,16,8,16,16,8,1,true,true}, + {16,64,8,16,16,4,1,true,true}, + {256,256,4,128,16,8,1,true,true}, + {128,32,8,32,16,8,1,true,true}, + {128,128,4,128,32,4,1,true,true}, + {16,16,8,16,16,8,1,true,true}, + {32,64,8,16,16,8,1,true,true}, + {128,256,4,64,16,8,1,true,true}, + {128,128,4,64,32,8,1,true,true}, + {64,128,4,64,32,8,1,true,true}, + {64,128,4,32,32,8,1,true,true}, + {128,128,4,32,32,8,1,true,true}, + {64,128,8,32,16,4,1,true,true}, + {64,128,8,32,16,8,1,true,true}, + {128,256,2,128,32,8,1,true,true}, + {128,128,4,128,16,8,1,true,true}, + {128,128,4,64,16,8,1,true,true} }; const InitParamsAccel -PopulateParamsXDL::initParametersForward8Bit[ - PopulateParamsXDL::nInitParametersForward8Bit] = { - {128, 256, 8, 128, 16, 4, 1, true, true}, - {128, 128, 16, 64, 32, 8, 1, true, true}, - {128, 128, 8, 128, 16, 8, 1, true, true}, - {128, 128, 8, 64, 16, 8, 1, true, true}, - {128, 128, 8, 32, 16, 16, 1, true, true}, - {128, 64, 32, 64, 32, 4, 1, true, true}, - {128, 64, 8, 32, 32, 16, 1, true, true}, - {128, 64, 8, 32, 16, 16, 1, true, true}, - {128, 64, 4, 32, 16, 16, 1, true, true}, - {64, 128, 32, 64, 32, 4, 1, true, true}, - {64, 128, 16, 32, 16, 4, 1, true, true}, - {64, 128, 8, 64, 16, 8, 1, true, true}, - {64, 128, 4, 32, 16, 16, 1, true , true}, - {64, 128, 8, 32, 16, 8, 1, true, true}, - {64, 64, 16, 32, 32, 4, 1, true, true}, - {64, 64, 8, 32, 32, 16, 1, true, true}, - {64, 64, 8, 32, 16, 16, 1, true, true}, - {64, 64, 4, 32, 16, 16, 1, true, true}, - {64, 64, 4, 32, 16, 8, 1, true, true}, - {64, 64, 16, 32, 16, 4, 1, true, true}, - {64, 64, 16, 16, 16, 16, 1, true, true}, - {64, 32, 16, 32, 16, 4, 1, true, true}, - {64, 32, 8, 16, 16, 16, 1, true, true}, - {64, 32, 8, 32, 16, 16, 1, true, true}, - {64, 32, 8, 32, 16, 8, 1, true, true}, - {64, 16, 8, 16, 16, 16, 1, true, true}, - {32, 256, 16, 32, 32, 4, 1, true, true}, - {32, 256, 4, 32, 16, 8, 1, true, true}, - {32, 128, 32, 32, 16, 4, 
1, true, true}, - {32, 64, 32, 16, 16, 4, 1, true, true}, - {32, 64, 16, 32, 16, 4, 1, true, true}, - {32, 64, 8, 16, 16, 16, 1, true, true}, - {32, 64, 4, 32, 16, 8, 1, true, true}, - {32, 32, 32, 16, 16, 4, 1, true, true}, - {32, 32, 16, 16, 16, 8, 1, true, true}, - {32, 16, 16, 16, 16, 8, 1, true, true}, - {16, 64, 16, 16, 16, 4, 1, true, true}, - {16, 32, 16, 16, 16, 16, 1, true, true}, - {16, 16, 32, 16, 16, 4, 1, true, true}, - {16, 16, 16, 16, 16, 4, 1, true, true} +PopulateParamsXDL::initParametersFp16Conv[PopulateParamsXDL::nInitParametersFp16Conv] = { + // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore + {64,64,8,32,32,8,1,true,true}, + {64,64,2,32,32,8,1,true,true}, + {128,128,8,64,16,8,1,true,true}, + {64,32,8,16,16,8,1,true,true}, + {32,32,8,16,16,8,1,true,true}, + {64,256,4,32,32,8,1,true,true}, + {256,64,4,64,32,8,1,true,true}, + {16,32,8,16,16,4,1,true,true}, + {32,128,8,32,32,8,1,true,true}, + {64,64,8,16,16,8,1,true,true}, + {128,16,8,32,16,8,1,true,true}, + {64,128,4,32,32,8,1,true,true}, + {32,64,8,16,16,8,1,true,true}, + {64,32,8,32,32,8,1,true,true}, + {16,64,4,16,16,8,1,true,true}, + {128,128,4,64,16,8,1,true,true}, + {64,64,8,32,16,8,1,true,true}, + {128,128,4,32,32,8,1,true,true}, + {64,128,4,16,16,8,1,true,true}, + {128,128,4,64,32,8,1,true,true} + }; + +const InitParamsAccel +PopulateParamsXDL::initParametersForward8BitGemm[ + PopulateParamsXDL::nInitParametersForward8BitGemm] = { + {64,128,8,64,16,8,1,true,true}, + {32,64,8,16,16,8,1,true,true}, + {32,32,16,16,16,8,1,true,true}, + {64,64,16,32,32,4,1,true,true}, + {64,64,16,16,16,8,1,true,true}, + {64,64,16,64,16,4,1,true,true}, + {16,32,32,16,16,4,1,true,true}, + {16,16,4,16,16,4,1,true,true}, + {64,32,8,32,16,16,1,true,true}, + {128,256,8,128,16,4,1,true,true}, + {64,16,8,16,16,16,1,true,true}, + {128,128,8,64,16,8,1,true,true}, + {128,128,8,128,16,8,1,true,true}, + {64,64,32,32,32,4,1,true,true}, + {32,32,32,16,16,4,1,true,true}, + 
{64,16,16,16,16,16,1,true,true}, + {32,64,4,32,16,8,1,true,true}, + {32,16,4,16,16,4,1,true,true}, + {64,128,8,64,16,4,1,true,true}, + {256,256,8,128,128,1,1,true,true} +}; + +const InitParamsAccel +PopulateParamsXDL::initParametersForward8BitConv[ + PopulateParamsXDL::nInitParametersForward8BitConv] = { +{64,64,8,32,16,16,1,true,true}, +{64,64,8,32,32,16,1,true,true}, +{64,128,4,32,16,16,1,true,true}, +{128,64,4,32,16,16,1,true,true}, +{64,128,4,64,16,16,1,true,true}, +{64,64,4,32,16,16,1,true,true}, +{128,64,4,64,16,16,1,true,true}, +{128,128,4,64,16,16,1,true,true}, +{64,64,4,32,32,16,1,true,true}, +{64,32,8,32,16,16,1,true,true}, +{64,64,16,32,16,16,1,true,true}, +{64,64,4,16,16,16,1,true,true}, +{128,128,4,64,32,16,1,true,true}, +{64,32,8,16,16,16,1,true,true}, +{128,32,4,64,16,16,1,true,true}, +{32,64,8,16,16,16,1,true,true}, +{64,64,16,32,32,16,1,true,true}, +{64,32,4,32,16,16,1,true,true}, +{128,64,16,32,16,8,1,true,true}, +{128,64,16,64,32,8,1,true,true} }; // clang-format on @@ -660,15 +675,29 @@ std::vector PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA, Type dataTypeB, StringRef arch) const { ArrayRef params; - switch (dataTypeA.getIntOrFloatBitWidth()) { - case 8: - params = {initParametersForward8Bit, nInitParametersForward8Bit}; - break; - case 16: - params = {initParametersFp16, nInitParametersFp16}; - break; - default: - params = {initParameters, nInitParameters}; + if(opType == KernelType::Gemm){ // Gemm: pick a *Gemm table by dataTypeA bit width. + switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: + params = {initParametersForward8BitGemm, nInitParametersForward8BitGemm}; + break; + case 16: + params = {initParametersFp16Gemm, nInitParametersFp16Gemm}; + break; + default: + params = {initParametersGemm, nInitParametersGemm}; + } + } + else{ // Any other kernel type uses the *Conv tables. 
+ switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: + params = {initParametersForward8BitConv, nInitParametersForward8BitConv}; + break; + case 16: + params = {initParametersFp16Conv, nInitParametersFp16Conv}; + break; + 
default: + params = {initParametersConv, nInitParametersConv}; + } } std::vector res; // Only return valid XDLOp params @@ -881,15 +910,15 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA, Type dataTypeB, StringRef arch) const { ArrayRef params; std::vector res; - switch (dataTypeA.getIntOrFloatBitWidth()) { - case 8: + switch (dataTypeA.getIntOrFloatBitWidth()) { + case 8: params = {initParametersForward8Bit, nInitParametersForward8Bit}; - break; - case 16: + break; + case 16: params = {initParametersFp16, nInitParametersFp16}; - break; - default: - return res; + break; + default: + return res; } // Only return valid Wmma params const int64_t waveSize = mlir::rock::lookupArchInfo(arch).waveSize;