Skip to content

Commit

Permalink
Allow MHA plugin to run on SM_86 as well
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
rajeevsrao committed Apr 12, 2021
1 parent 77384c1 commit 3102875
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ QKVToContextInterleavedPlugin::QKVToContextInterleavedPlugin(
mSM = getSMVersion();
// variable sequence length is only supported with the fused MHA kernels
// we should not override mS!
assert((mSM == kSM_AMPERE || mSM == kSM_TURING || mSM == kSM_XAVIER)
assert((mSM == kSM_AMPERE_100 || mSM == kSM_AMPERE_10X || mSM == kSM_TURING || mSM == kSM_XAVIER)
&& "requesting maxSeqlen not compatible with GPU arch");
// the layout changes: SxB will be a combined \sum_i s_i and hdim will be the 2nd dimension instead of the third
mXmmaKernel = getXMMAKernelsV2(DATA_TYPE_INT8, mSM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace bert
{
static constexpr int32_t kSM_XAVIER = 72;
static constexpr int32_t kSM_TURING = 75;
static constexpr int32_t kSM_AMPERE = 80;
static constexpr int32_t kSM_AMPERE_100 = 80;
static constexpr int32_t kSM_AMPERE_10X = 86;

class QKVToContextInterleavedPlugin : public nvinfer1::IPluginV2DynamicExt
{
Expand Down

0 comments on commit 3102875

Please sign in to comment.