Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate quick-tuning lists by conv and gemm #1675

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,19 +385,32 @@ class PopulateParamsAccel : public BasePopulateParams<InitParamsAccel> {
// Xdlops interface
//
class PopulateParamsXDL : public PopulateParamsAccel {
static constexpr size_t nInitParameters = 40;
static constexpr size_t nInitParametersConv = 20;
// Initial tuning parameters for forward convolution and backward
// convolution.
static const InitParamsAccel initParameters[nInitParameters];
static const InitParamsAccel initParametersConv[nInitParametersConv];

static constexpr size_t nInitParametersFp16 = 40;
static constexpr size_t nInitParametersFp16Conv = 20;
// Tuning parameters for fp16/bf16 convolutions.
static const InitParamsAccel initParametersFp16[nInitParametersFp16];
static const InitParamsAccel initParametersFp16Conv[nInitParametersFp16Conv];

static constexpr size_t nInitParametersForward8Bit = 40;
static constexpr size_t nInitParametersForward8BitConv = 20;
// Tuning parameters for i8 convolutions.
static const InitParamsAccel
initParametersForward8Bit[nInitParametersForward8Bit];
initParametersForward8BitConv[nInitParametersForward8BitConv];

static constexpr size_t nInitParametersGemm = 20;
// Initial tuning parameters for gemm.
static const InitParamsAccel initParametersGemm[nInitParametersGemm];

static constexpr size_t nInitParametersFp16Gemm = 20;
// Tuning parameters for fp16/bf16 gemm.
static const InitParamsAccel initParametersFp16Gemm[nInitParametersFp16Gemm];

static constexpr size_t nInitParametersForward8BitGemm = 20;
// Tuning parameters for i8 gemm.
static const InitParamsAccel
initParametersForward8BitGemm[nInitParametersForward8BitGemm];

public:
std::vector<InitParamsAccel>
Expand Down
309 changes: 169 additions & 140 deletions mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,138 +394,153 @@ PopulateParamsAccel::obtainTuningParameters(RockGemmWrapperInterface op,
/// Xdlops acceleration
// clang-format off
const InitParamsAccel
PopulateParamsXDL::initParameters[PopulateParamsXDL::nInitParameters] = {
PopulateParamsXDL::initParametersGemm[PopulateParamsXDL::nInitParametersGemm] = {
// M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore
{256, 256, 2, 128, 32, 4, 1, true, true},
{256, 64, 8, 128, 32, 1, 1, true, true},
{128, 128, 8, 64, 16, 4, 1, true, true},
{128, 128, 4, 128, 32, 4, 1, true, true},
{128, 128, 2, 32, 32, 8, 1, true, true},
{128, 64, 8, 64, 16, 1, 1, true, true},
{128, 64, 8, 32, 32, 4, 1, true, true},
{128, 64, 8, 32, 16, 1, 1, true, true},
{128, 64, 4, 32, 32, 4, 1, true, true},
{128, 64, 2, 128, 32, 4, 1, true, true},
{128, 32, 4, 128, 16, 4, 1, true, true},
{128, 16, 4, 32, 16, 8, 1, true, true},
{64, 256, 8, 64, 16, 4, 1, true, true},
{64, 128, 4, 64, 32, 1, 1, true, true},
{64, 128, 4, 64, 16, 4, 1, true , true},
{64, 128, 4, 32, 16, 4, 1, true, true},
{64, 128, 2, 32, 32, 8, 1, true, true},
{64, 64, 8, 32, 32, 4, 1, true, true},
{64, 64, 8, 16, 16, 4, 1, true, true},
{64, 64, 8, 32, 16, 4, 1, true, true},
{64, 64, 8, 16, 16, 8, 1, true, true},
{64, 64, 4, 32, 16, 4, 1, true, true},
{64, 64, 4, 16, 16, 8, 1, true, true},
{64, 64, 8, 64, 16, 8, 1, true, true},
{64, 32, 4, 32, 16, 8, 1, true, true},
{64, 32, 8, 16, 16, 4, 1, true, true},
{64, 32, 8, 16, 16, 4, 1, true, true},
{64, 16, 8, 16, 16, 8, 1, true, true},
{32, 128, 8, 32, 16, 1, 1, true, true},
{32, 128, 8, 16, 16, 4, 1, true , true},
{32, 64, 8, 32, 16, 4, 1, true, true},
{32, 64, 4, 32, 16, 4, 1, true, true},
{32, 32, 8, 16, 16, 8, 1, true, true},
{32, 32, 8, 16, 16, 4, 1, true, true},
{32, 16, 8, 16, 16, 8, 1, true, true},
{32, 16, 4, 16, 16, 8, 1, true, true},
{16, 32, 4, 16, 16, 4, 1, true, true},
{16, 32, 8, 16, 16, 8, 1, true, true},
{16, 16, 4, 16, 16, 4, 1, true, true},
{16, 16, 8, 16, 16, 8, 1, true, true}
{64,64,8,32,32,4,1,true,true},
{16,32,8,16,16,8,1,true,true},
{64,64,4,32,16,4,1,true,true},
{64,64,8,32,32,8,1,true,true},
{64,64,8,32,16,4,1,true,true},
{64,32,8,16,16,4,1,true,true},
{128,128,4,32,32,8,1,true,true},
{32,64,8,16,16,4,1,true,true},
{128,128,4,64,16,4,1,true,true},
{64,256,8,64,16,1,1,true,true},
{64,64,4,32,32,8,1,true,true},
{64,64,4,16,16,4,1,true,true},
{64,64,4,16,16,8,1,true,true},
{64,64,8,16,16,4,1,true,true},
{64,64,8,64,16,4,1,true,true},
{64,128,8,64,16,1,1,true,true},
{64,64,4,32,16,8,1,true,true},
{64,128,4,32,32,4,1,true,true},
{64,256,4,32,32,4,1,true,true},
{128,128,8,64,16,4,1,true,true}
};

const InitParamsAccel
PopulateParamsXDL::initParametersFp16[PopulateParamsXDL::nInitParametersFp16] = {
PopulateParamsXDL::initParametersConv[PopulateParamsXDL::nInitParametersConv] = {
// M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore
{128, 256, 8, 64, 32, 4, 1, true, true},
{128, 256, 4, 64, 32, 8, 1, true, true},
{128, 128, 8, 128, 32, 8, 1, true, true},
{128, 128, 8, 64, 32, 4, 1, true, true},
{128, 128, 8, 32, 32, 8, 1, true, true},
{128, 128, 8, 32, 16, 4, 1, true, true},
{128, 128, 4, 128, 32, 8, 1, true, true},
{128, 128, 4, 128, 16, 8, 1, true, true},
{128, 128, 4, 64, 32, 8, 1, true, true},
{128, 128, 4, 64, 16, 8, 1, true, true},
{128, 128, 4, 32, 32, 8, 1, true, true},
{128, 64, 4, 128, 16, 8, 1, true, true},
{128, 64, 4, 32, 32, 8, 1, true, true},
{128, 32, 8, 32 ,32 ,8, 1, true, true},
{64, 128, 4, 64, 16, 8, 1, true, true},
{64, 128, 8, 32, 32, 4, 1, true, true},
{64, 128, 8, 32, 16, 8, 1, true, true},
{64, 128, 8, 32, 16, 4, 1, true, true},
{64, 128, 8, 64, 32, 4, 1, true, true},
{64, 128, 4, 32, 16, 8, 1, true, true},
{64, 128, 4, 32, 32, 8, 1, true, true},
{64, 64, 8, 32, 32, 8, 1, true, true},
{64, 64, 8, 32, 32, 8, 1, true, true},
{64, 64, 8, 32, 16, 8, 1, true, true},
{64, 64, 8, 16, 16, 8, 1, true, true},
{64, 64, 4, 32, 32, 8, 1, true, true},
{64 ,64, 2, 32, 32, 4, 1, true, true},
{64, 32, 8, 32, 32, 8, 1, true, true},
{64, 32, 8, 32, 16, 8, 1, true, true},
{64, 16, 8, 16, 16, 8, 1, true, true},
{32, 128, 8, 32, 32, 4, 1, true, true},
{32, 64, 8, 32, 32, 8, 1, true, true},
{32, 64, 8, 32, 16, 4, 1, true, true},
{32, 32, 8, 32, 32, 4, 1, true, true},
{32, 32, 8, 16, 16, 8, 1, true, true},
{32, 16 ,8, 16, 16, 8, 1, true, true},
{16, 128, 4, 16, 16, 8, 1, true, true},
{16, 32, 8, 16, 16, 8, 1, true, true},
{16, 64, 8, 16, 16, 8, 1, true, true},
{16, 32, 8, 16 ,16 ,4, 1, true, true}
{64,32,8,16,16,4,1,true,true},
{32,16,4,16,16,8,1,true,true},
{64,64,4,32,32,8,1,true,true},
{64,16,8,16,16,8,1,true,true},
{64,64,8,32,32,8,1,true,true},
{64,32,4,16,16,4,1,true,true},
{64,32,4,16,16,8,1,true,true},
{32,32,8,16,16,4,1,true,true},
{64,64,8,32,32,4,1,true,true},
{32,64,8,32,16,4,1,true,true},
{64,16,4,16,16,8,1,true,true},
{64,16,8,16,16,4,1,true,true},
{64,32,4,32,16,4,1,true,true},
{32,64,8,16,16,4,1,true,true},
{64,32,8,16,16,8,1,true,true},
{64,32,8,32,16,8,1,true,true},
{64,64,4,64,16,4,1,true,true},
{64,64,8,32,16,4,1,true,true},
{64,32,8,32,16,4,1,true,true},
{32,64,8,32,32,4,1,true,true}
};

// Quick-tuning candidates for 16-bit (fp16/bf16) GEMM kernels on the XDL path.
// getTuningParameters() picks this table when opType is KernelType::Gemm and
// the bit width of dataTypeA is 16. The row count must stay equal to
// nInitParametersFp16Gemm, and the row order is kept exactly as authored.
const InitParamsAccel PopulateParamsXDL::initParametersFp16Gemm
    [PopulateParamsXDL::nInitParametersFp16Gemm] = {
  // Column legend:
  // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore
  { 64,  64,  8,  32,  32,  8, 1, true, true},
  { 64, 256,  4,  32,  32,  8, 1, true, true},
  {128, 128,  8, 128,  16,  8, 1, true, true},
  { 64,  16,  8,  16,  16,  8, 1, true, true},
  { 16,  64,  8,  16,  16,  4, 1, true, true},
  {256, 256,  4, 128,  16,  8, 1, true, true},
  {128,  32,  8,  32,  16,  8, 1, true, true},
  {128, 128,  4, 128,  32,  4, 1, true, true},
  { 16,  16,  8,  16,  16,  8, 1, true, true},
  { 32,  64,  8,  16,  16,  8, 1, true, true},
  {128, 256,  4,  64,  16,  8, 1, true, true},
  {128, 128,  4,  64,  32,  8, 1, true, true},
  { 64, 128,  4,  64,  32,  8, 1, true, true},
  { 64, 128,  4,  32,  32,  8, 1, true, true},
  {128, 128,  4,  32,  32,  8, 1, true, true},
  { 64, 128,  8,  32,  16,  4, 1, true, true},
  { 64, 128,  8,  32,  16,  8, 1, true, true},
  {128, 256,  2, 128,  32,  8, 1, true, true},
  {128, 128,  4, 128,  16,  8, 1, true, true},
  {128, 128,  4,  64,  16,  8, 1, true, true}
};

const InitParamsAccel
PopulateParamsXDL::initParametersForward8Bit[
PopulateParamsXDL::nInitParametersForward8Bit] = {
{128, 256, 8, 128, 16, 4, 1, true, true},
{128, 128, 16, 64, 32, 8, 1, true, true},
{128, 128, 8, 128, 16, 8, 1, true, true},
{128, 128, 8, 64, 16, 8, 1, true, true},
{128, 128, 8, 32, 16, 16, 1, true, true},
{128, 64, 32, 64, 32, 4, 1, true, true},
{128, 64, 8, 32, 32, 16, 1, true, true},
{128, 64, 8, 32, 16, 16, 1, true, true},
{128, 64, 4, 32, 16, 16, 1, true, true},
{64, 128, 32, 64, 32, 4, 1, true, true},
{64, 128, 16, 32, 16, 4, 1, true, true},
{64, 128, 8, 64, 16, 8, 1, true, true},
{64, 128, 4, 32, 16, 16, 1, true , true},
{64, 128, 8, 32, 16, 8, 1, true, true},
{64, 64, 16, 32, 32, 4, 1, true, true},
{64, 64, 8, 32, 32, 16, 1, true, true},
{64, 64, 8, 32, 16, 16, 1, true, true},
{64, 64, 4, 32, 16, 16, 1, true, true},
{64, 64, 4, 32, 16, 8, 1, true, true},
{64, 64, 16, 32, 16, 4, 1, true, true},
{64, 64, 16, 16, 16, 16, 1, true, true},
{64, 32, 16, 32, 16, 4, 1, true, true},
{64, 32, 8, 16, 16, 16, 1, true, true},
{64, 32, 8, 32, 16, 16, 1, true, true},
{64, 32, 8, 32, 16, 8, 1, true, true},
{64, 16, 8, 16, 16, 16, 1, true, true},
{32, 256, 16, 32, 32, 4, 1, true, true},
{32, 256, 4, 32, 16, 8, 1, true, true},
{32, 128, 32, 32, 16, 4, 1, true, true},
{32, 64, 32, 16, 16, 4, 1, true, true},
{32, 64, 16, 32, 16, 4, 1, true, true},
{32, 64, 8, 16, 16, 16, 1, true, true},
{32, 64, 4, 32, 16, 8, 1, true, true},
{32, 32, 32, 16, 16, 4, 1, true, true},
{32, 32, 16, 16, 16, 8, 1, true, true},
{32, 16, 16, 16, 16, 8, 1, true, true},
{16, 64, 16, 16, 16, 4, 1, true, true},
{16, 32, 16, 16, 16, 16, 1, true, true},
{16, 16, 32, 16, 16, 4, 1, true, true},
{16, 16, 16, 16, 16, 4, 1, true, true}
PopulateParamsXDL::initParametersFp16Conv[PopulateParamsXDL::nInitParametersFp16Conv] = {
// M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore
{64,64,8,32,32,8,1,true,true},
{64,64,2,32,32,8,1,true,true},
{128,128,8,64,16,8,1,true,true},
{64,32,8,16,16,8,1,true,true},
{32,32,8,16,16,8,1,true,true},
{64,256,4,32,32,8,1,true,true},
{256,64,4,64,32,8,1,true,true},
{16,32,8,16,16,4,1,true,true},
{32,128,8,32,32,8,1,true,true},
{64,64,8,16,16,8,1,true,true},
{128,16,8,32,16,8,1,true,true},
{64,128,4,32,32,8,1,true,true},
{32,64,8,16,16,8,1,true,true},
{64,32,8,32,32,8,1,true,true},
{16,64,4,16,16,8,1,true,true},
{128,128,4,64,16,8,1,true,true},
{64,64,8,32,16,8,1,true,true},
{128,128,4,32,32,8,1,true,true},
{64,128,4,16,16,8,1,true,true},
{128,128,4,64,32,8,1,true,true}
};

// Quick-tuning candidates for 8-bit GEMM kernels on the XDL path.
// getTuningParameters() picks this table when opType is KernelType::Gemm and
// the bit width of dataTypeA is 8. The row count must stay equal to
// nInitParametersForward8BitGemm, and the row order is kept exactly as
// authored.
const InitParamsAccel PopulateParamsXDL::initParametersForward8BitGemm
    [PopulateParamsXDL::nInitParametersForward8BitGemm] = {
  // Column legend (same layout as the other InitParamsAccel tables in this file):
  // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore
  { 64, 128,  8,  64,  16,  8, 1, true, true},
  { 32,  64,  8,  16,  16,  8, 1, true, true},
  { 32,  32, 16,  16,  16,  8, 1, true, true},
  { 64,  64, 16,  32,  32,  4, 1, true, true},
  { 64,  64, 16,  16,  16,  8, 1, true, true},
  { 64,  64, 16,  64,  16,  4, 1, true, true},
  { 16,  32, 32,  16,  16,  4, 1, true, true},
  { 16,  16,  4,  16,  16,  4, 1, true, true},
  { 64,  32,  8,  32,  16, 16, 1, true, true},
  {128, 256,  8, 128,  16,  4, 1, true, true},
  { 64,  16,  8,  16,  16, 16, 1, true, true},
  {128, 128,  8,  64,  16,  8, 1, true, true},
  {128, 128,  8, 128,  16,  8, 1, true, true},
  { 64,  64, 32,  32,  32,  4, 1, true, true},
  { 32,  32, 32,  16,  16,  4, 1, true, true},
  { 64,  16, 16,  16,  16, 16, 1, true, true},
  { 32,  64,  4,  32,  16,  8, 1, true, true},
  { 32,  16,  4,  16,  16,  4, 1, true, true},
  { 64, 128,  8,  64,  16,  4, 1, true, true},
  // NOTE(review): the only row in this table with N/wave > 32 (and kPack 1);
  // presumably genuine tuner output rather than a typo -- confirm.
  {256, 256,  8, 128, 128,  1, 1, true, true}
};

// Quick-tuning candidates for 8-bit non-GEMM (convolution) kernels on the XDL
// path. getTuningParameters() picks this table when opType is not
// KernelType::Gemm and the bit width of dataTypeA is 8. The row count must
// stay equal to nInitParametersForward8BitConv, and the row order is kept
// exactly as authored.
const InitParamsAccel PopulateParamsXDL::initParametersForward8BitConv
    [PopulateParamsXDL::nInitParametersForward8BitConv] = {
  // Column legend (same layout as the other InitParamsAccel tables in this file):
  // M/block N/block K/block M/wave N/wave kPack splitKFactor forceUnroll bCopyMore
  { 64,  64,  8,  32,  16, 16, 1, true, true},
  { 64,  64,  8,  32,  32, 16, 1, true, true},
  { 64, 128,  4,  32,  16, 16, 1, true, true},
  {128,  64,  4,  32,  16, 16, 1, true, true},
  { 64, 128,  4,  64,  16, 16, 1, true, true},
  { 64,  64,  4,  32,  16, 16, 1, true, true},
  {128,  64,  4,  64,  16, 16, 1, true, true},
  {128, 128,  4,  64,  16, 16, 1, true, true},
  { 64,  64,  4,  32,  32, 16, 1, true, true},
  { 64,  32,  8,  32,  16, 16, 1, true, true},
  { 64,  64, 16,  32,  16, 16, 1, true, true},
  { 64,  64,  4,  16,  16, 16, 1, true, true},
  {128, 128,  4,  64,  32, 16, 1, true, true},
  { 64,  32,  8,  16,  16, 16, 1, true, true},
  {128,  32,  4,  64,  16, 16, 1, true, true},
  { 32,  64,  8,  16,  16, 16, 1, true, true},
  { 64,  64, 16,  32,  32, 16, 1, true, true},
  { 64,  32,  4,  32,  16, 16, 1, true, true},
  {128,  64, 16,  32,  16,  8, 1, true, true},
  {128,  64, 16,  64,  32,  8, 1, true, true}
};
// clang-format on

Expand Down Expand Up @@ -660,15 +675,29 @@ std::vector<InitParamsAccel>
PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA,
Type dataTypeB, StringRef arch) const {
ArrayRef<InitParamsAccel> params;
switch (dataTypeA.getIntOrFloatBitWidth()) {
case 8:
params = {initParametersForward8Bit, nInitParametersForward8Bit};
break;
case 16:
params = {initParametersFp16, nInitParametersFp16};
break;
default:
params = {initParameters, nInitParameters};
if(opType == KernelType::Gemm){
switch (dataTypeA.getIntOrFloatBitWidth()) {
case 8:
params = {initParametersForward8BitGemm, nInitParametersForward8BitGemm};
break;
case 16:
params = {initParametersFp16Gemm, nInitParametersFp16Gemm};
break;
default:
params = {initParametersGemm, nInitParametersGemm};
}
}
else{
switch (dataTypeA.getIntOrFloatBitWidth()) {
case 8:
params = {initParametersForward8BitConv, nInitParametersForward8BitConv};
break;
case 16:
params = {initParametersFp16Conv, nInitParametersFp16Conv};
break;
default:
params = {initParametersConv, nInitParametersConv};
}
}
std::vector<InitParamsAccel> res;
// Only return valid XDLOp params
Expand Down Expand Up @@ -881,15 +910,15 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA,
Type dataTypeB, StringRef arch) const {
ArrayRef<InitParamsAccel> params;
std::vector<InitParamsAccel> res;
switch (dataTypeA.getIntOrFloatBitWidth()) {
case 8:
switch (dataTypeA.getIntOrFloatBitWidth()) {
case 8:
params = {initParametersForward8Bit, nInitParametersForward8Bit};
break;
case 16:
break;
case 16:
params = {initParametersFp16, nInitParametersFp16};
break;
default:
return res;
break;
default:
return res;
}
// Only return valid Wmma params
const int64_t waveSize = mlir::rock::lookupArchInfo(arch).waveSize;
Expand Down
Loading