Accelerate first token gen with BF16-gemm MHA and concat-Silu MLP #106

Merged: 1 commit into intel:main on Dec 6, 2023

Conversation

@abenmao (Contributor) commented Nov 30, 2023

  1. BF16-based flash attention, enabled when prompt_len >= 1024.
  2. Concatenated-SiLU gate + up projection (see the sketch after this list).
  3. Added the intel-mkl library to cmake.
  4. In addition, the environment variable ENABLE_CBLAS_MLP can be set to 1 to switch downProj to the cblas kernel when prompt_len >= 2048; the default is 0.
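
As a rough illustration of item 2, the sketch below shows the concatenated gate/up idea: one GEMM over a fused [gate | up] weight, then SiLU(gate) * up element-wise. All names and layouts here are illustrative assumptions, not the PR's actual code.

#include <cmath>
#include <vector>

// Illustrative only: input is M x K, catWeight is K x 2N laid out as [gate | up],
// output is M x N. A real kernel would use a tuned BF16 GEMM instead of these loops.
static void concatSiluMlp(const std::vector<float> &input, const std::vector<float> &catWeight,
        std::vector<float> &output, int M, int K, int N) {
    std::vector<float> catOut(static_cast<size_t>(M) * 2 * N, 0.0f);
    for (int i = 0; i < M; ++i)                      // one GEMM produces gate and up side by side
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < 2 * N; ++j)
                catOut[i * 2 * N + j] += input[i * K + k] * catWeight[k * 2 * N + j];
    for (int i = 0; i < M; ++i) {                    // SiLU(gate) * up, element-wise
        for (int j = 0; j < N; ++j) {
            float g = catOut[i * 2 * N + j];
            float u = catOut[i * 2 * N + N + j];
            output[i * N + j] = (g / (1.0f + std::exp(-g))) * u;
        }
    }
}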

@abenmao force-pushed the feature/layers/mha_bf16 branch 5 times, most recently from 21f180b to 6652c16 on December 4, 2023 08:30
}
template <typename T>
static void single_thread_cvt2bf16_inplace(T *buf, int m, int n, int stride) {
if (!std::is_same_v<T, bfloat16_t>)
Contributor: Shall we also check whether T is float?

@abenmao (Author): You are right. Done.
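
A minimal sketch of the suggested guard, with an illustrative stand-in for the project's bfloat16_t and a plain truncating conversion; it is not the PR's final code.

#include <cstdint>
#include <cstring>
#include <type_traits>

struct bfloat16_t { uint16_t value; };  // stand-in for the project's type, for illustration

// Only float input is converted in place; bf16 input is left untouched.
template <typename T>
static void single_thread_cvt2bf16_inplace(T *buf, int m, int n, int stride) {
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, bfloat16_t>,
            "expected float or bfloat16_t");
    if constexpr (std::is_same_v<T, float>) {
        for (int i = 0; i < m; ++i) {
            bfloat16_t *dst = reinterpret_cast<bfloat16_t *>(buf + i * stride);
            for (int j = 0; j < n; ++j) {
                uint32_t bits;
                std::memcpy(&bits, buf + i * stride + j, sizeof(bits));
                dst[j].value = static_cast<uint16_t>(bits >> 16);  // truncating float -> bf16
            }
        }
    }
}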

for (int j = 0; j < headNum; j++) {
int srcOffEachLine = j * seqLen * headSize;
int dstOffEachHead = j * headSize;
static inline __m512 dilExpKernel(__m512 vecSrc) {
Contributor: Is this a vectorized version of exp? If so, there is already one named vexp elsewhere in the codebase.

@abenmao (Author): Changed to use our vexp function.
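
For readers without the codebase at hand, a generic AVX-512 exp approximation in the same spirit is sketched below; it is not the repository's vexp and uses coarse Taylor coefficients purely for illustration.

#include <immintrin.h>

// Illustrative exp(x) = 2^k * exp(r), with k = round(x / ln2) and r = x - k * ln2;
// exp(r) is approximated by a short polynomial and 2^k is applied via scalef.
static inline __m512 approx_exp_ps(__m512 x) {
    const __m512 log2e = _mm512_set1_ps(1.442695041f);
    const __m512 ln2 = _mm512_set1_ps(0.6931471806f);
    __m512 k = _mm512_roundscale_ps(_mm512_mul_ps(x, log2e),
            _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m512 r = _mm512_fnmadd_ps(k, ln2, x);                     // r = x - k * ln2
    __m512 p = _mm512_set1_ps(1.0f / 24.0f);                    // Taylor series of exp(r) around 0
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f / 6.0f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(0.5f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));
    return _mm512_scalef_ps(p, k);                              // p * 2^k
}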

@@ -878,7 +914,8 @@ class Attention {
}

virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) {
return attnMask + bId * srcLen * tgtLen;
return attnMask;
Contributor: Why is this different from the original?

@abenmao (Author): It was just for performance tuning and I forgot to roll it back. Comments have been added.
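
For reference, a tiny sketch of what the original offset selects, assuming the mask is stored contiguously as [batchSize, srcLen, tgtLen] (layout assumed here for illustration):

#include <cstddef>

// Illustrative only: return batch bId's srcLen x tgtLen mask slab.
static inline const float *maskForBatch(const float *attnMask, int bId, int srcLen, int tgtLen) {
    return attnMask + static_cast<size_t>(bId) * srcLen * tgtLen;
}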

int minBlk = (nth >= batchSize * numQHead ? 256 : 512);
int srcBlk = std::min(minBlk, srcLen);
int tgtBlk = std::min(minBlk, tgtLen);
int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2)));
Contributor: Is there a design principle behind this? If so, could you add a comment?

@abenmao (Author): The current block size is derived from practical experience. Comments added.
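
A small sketch of the heuristic as written, rounding srcLen / 2 down to a power of two without the pow/log2 round trip; purely illustrative.

// Illustrative equivalent of (int)std::pow(2, (int)std::log2(srcLen / 2)) for srcLen >= 2:
// the largest power of two that does not exceed srcLen / 2.
static int roundDownPow2(int srcLen) {
    int blk = 1;
    while (blk * 2 <= srcLen / 2) blk *= 2;
    return blk;
}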

float **preSum = thrPtrBuf;
float **sum = thrPtrBuf + nth;
float **preMax = thrPtrBuf + nth * 2;
float **max = thrPtrBuf + nth * 3;
float **qkArr = thrPtrBuf + nth * 4;
float **expQkvArr = thrPtrBuf + nth * 5;
float **qArr = thrPtrBuf + nth * 6;

thrBuf = (float *)malloc(sizeof(float) * nth * arrStride);
Contributor: Would it be better to use SimpleMemPool to get the buffer? (SimpleMemPool maintains the buffer, so the next layer can use it directly.)

@abenmao (Author): Done.
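
A rough sketch of the buffer-reuse idea behind the suggestion: a named, grow-only cache so successive layers reuse one allocation. The class and method names here are illustrative, not SimpleMemPool's actual interface.

#include <cstdlib>
#include <string>
#include <unordered_map>

// Illustrative grow-only buffer cache; not the project's SimpleMemPool API.
class BufferCache {
public:
    static BufferCache &instance() {
        static BufferCache pool;
        return pool;
    }
    // Return a buffer of at least `bytes`, reusing the previous allocation when large enough.
    void *get(const std::string &name, size_t bytes) {
        Slot &slot = slots_[name];
        if (slot.size < bytes) {
            std::free(slot.ptr);
            size_t rounded = (bytes + 63) / 64 * 64;   // aligned_alloc needs a multiple of the alignment
            slot.ptr = std::aligned_alloc(64, rounded);
            slot.size = rounded;
        }
        return slot.ptr;
    }
private:
    struct Slot { void *ptr = nullptr; size_t size = 0; };
    std::unordered_map<std::string, Slot> slots_;
};

// Hypothetical usage in place of the raw malloc above:
// float *thrBuf = (float *)BufferCache::instance().get("flash_attn_thr_buf", sizeof(float) * nth * arrStride);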

memcpy(upW + i * colSplit, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float));
weightPTR += intermediateSize;

int enable = (getenv("ENABLE_CAT_MLP") ? atoi(getenv("ENABLE_CAT_MLP")) : 1);
Contributor: Why do we provide the option to disable this?

@abenmao (Author): The option is kept temporarily for future performance tuning; we can probably remove it later.

Contributor: A better approach might be a static variable, so we do not need to call getenv every time. (And if it is used multiple times in the file, we could declare a global static variable.)

@abenmao (Author): I added global variables in mlp_llama.cpp. Please review.
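
A minimal sketch of the suggested pattern: read the environment variable once and cache it, so getenv is not called on every layer. The wrapper name is illustrative; the variable name mirrors the PR's ENABLE_CAT_MLP.

#include <cstdlib>

// Read ENABLE_CAT_MLP once; later calls reuse the cached value (default: enabled).
static bool catMlpEnabled() {
    static const int enabled = [] {
        const char *env = std::getenv("ENABLE_CAT_MLP");
        return env ? std::atoi(env) : 1;
    }();
    return enabled != 0;
}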

@pujiang2018 (Contributor): @abenmao Have we checked the output of this PR for long tokens?

@abenmao force-pushed the feature/layers/mha_bf16 branch from 6652c16 to 3d9f90b on December 5, 2023 09:24
@abenmao (Author) commented Dec 5, 2023: Yes, we have checked the output with long prompts for those models.

dbg.dumpMatrix(gateWeight);
dbg.debugPrint("gate output:\n");
dbg.dumpMatrix(imBuffer);
dbg.debugPrint("gateWeight:\n");
Contributor: Does this need formatting?

Contributor: Looks like I made a mistake; the code is already formatted.

@abenmao force-pushed the feature/layers/mha_bf16 branch from 3d9f90b to 3321e78 on December 6, 2023 04:47
@pujiang2018 (Contributor): Will merge after the build server check passes.

@pujiang2018 merged commit 605e62e into intel:main on Dec 6, 2023
1 check passed