Accelerate first token gen with BF16-gemm MHA and concat-Silu MLP #106
Conversation
abenmao commented Nov 30, 2023 (edited)
- BF16-based flash attention, enabled when prompt_len >= 1024
- Concat-SiLU gate+up projection (see the sketch after this list)
- Added the Intel MKL library in CMake
- In addition, the environment variable ENABLE_CBLAS_MLP can be set to 1 (replacing downProj with the cblas kernel) when prompt_len >= 2048. The default is 0.
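For context, a minimal sketch of the concat-SiLU idea: instead of running two separate GEMMs for the gate and up projections, the two weight matrices are stored concatenated column-wise so a single GEMM produces both halves, then SiLU is applied to the gate half and multiplied with the up half. All names below (catSiluMlp, catW) are illustrative, not the PR's actual API, and the naive triple loop stands in for the real GEMM.

#include <cmath>
#include <vector>

// out = SiLU(x * Wgate) . (x * Wup), with Wgate|Wup stored as one
// [hidden, 2*im] matrix so one GEMM covers both projections.
void catSiluMlp(const float *x, const float *catW, float *out,
                int rows, int hidden, int im) {
    std::vector<float> buf(rows * 2 * im, 0.0f);
    // Single GEMM over the concatenated weight instead of two GEMMs.
    for (int r = 0; r < rows; ++r)
        for (int k = 0; k < hidden; ++k)
            for (int c = 0; c < 2 * im; ++c)
                buf[r * 2 * im + c] += x[r * hidden + k] * catW[k * 2 * im + c];

    // SiLU(gate) * up, reading both halves from the one output buffer.
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < im; ++c) {
            float g = buf[r * 2 * im + c];       // gate half
            float u = buf[r * 2 * im + im + c];  // up half
            out[r * im + c] = (g / (1.0f + std::exp(-g))) * u;
        }
}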
Force-pushed from 21f180b to 6652c16 (compare)
src/utils/decoder_util.h
Outdated
}
template <typename T>
static void single_thread_cvt2bf16_inplace(T *buf, int m, int n, int stride) {
    if (!std::is_same_v<T, bfloat16_t>)
Shall we check if T is float?
You are right, done~
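For reference, a minimal sketch of what such a guard might look like, assuming the conversion should only run for float buffers; the body is a placeholder, not the PR's actual fix:

#include <type_traits>

// Illustrative sketch only: restrict the in-place float->bf16 conversion
// to float inputs; any other T compiles to a no-op.
template <typename T>
static void single_thread_cvt2bf16_inplace(T *buf, int m, int n, int stride) {
    if constexpr (std::is_same_v<T, float>) {
        // ... convert the m x n float tile (row stride `stride`) to bf16 ...
    }
    // Non-float T (e.g. already bfloat16_t): nothing to convert.
}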
src/utils/decoder_util.h
Outdated
for (int j = 0; j < headNum; j++) {
    int srcOffEachLine = j * seqLen * headSize;
    int dstOffEachHead = j * headSize;
static inline __m512 dilExpKernel(__m512 vecSrc) {
Is this a vectorized version of exp? If so, there is already one named vexp.
Changed to use our vexp function.
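A minimal sketch of what reusing a shared vectorized exp looks like in a softmax row pass. The real vexp lives elsewhere in the repo with an unknown exact signature; the scalar stand-in below only makes this sketch self-contained.

#include <cmath>
#include <immintrin.h>

// Stand-in for the repo's vexp (assumed name/signature); scalar fallback.
static inline __m512 vexp(__m512 x) {
    alignas(64) float tmp[16];
    _mm512_store_ps(tmp, x);
    for (int i = 0; i < 16; ++i) tmp[i] = std::exp(tmp[i]);
    return _mm512_load_ps(tmp);
}

// Exponentiate a row in place via the shared vexp instead of a private
// exp kernel; tail elements are handled with a write mask.
static void expRowInplace(float *row, int n, float rowMax) {
    __m512 vmax = _mm512_set1_ps(rowMax);
    for (int i = 0; i < n; i += 16) {
        __mmask16 k = (n - i >= 16) ? (__mmask16)0xffff
                                    : (__mmask16)((1u << (n - i)) - 1);
        __m512 v = _mm512_maskz_loadu_ps(k, row + i);
        v = vexp(_mm512_sub_ps(v, vmax)); // exp(x - max) for stable softmax
        _mm512_mask_storeu_ps(row + i, k, v);
    }
}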
src/layers/attention.h
Outdated
@@ -878,7 +914,8 @@ class Attention {
     }

     virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) {
-        return attnMask + bId * srcLen * tgtLen;
+        return attnMask;
Why is this different from the original?
It was just for performance tuning; I forgot to roll it back. Comments added.
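For context, a short sketch of the indexing the original line implements, assuming the mask is laid out contiguously as [batch][srcLen][tgtLen] (the helper name is illustrative):

#include <cstddef>

// With a [batch][srcLen][tgtLen] mask buffer, each batch's sub-mask starts
// bId * srcLen * tgtLen floats into the buffer, which is exactly what the
// original `attnMask + bId * srcLen * tgtLen` returns.
const float *maskForBatch(const float *attnMask, int bId, int srcLen, int tgtLen) {
    return attnMask + (size_t)bId * srcLen * tgtLen;
}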
-    int minBlk = (nth >= batchSize * numQHead ? 256 : 512);
+    int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2)));
     int srcBlk = std::min(minBlk, srcLen);
     int tgtBlk = std::min(minBlk, tgtLen);
Any design principle here? If so, would you add a comment?
The current block size is just derived from practical experience. Added comments.
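As a quick worked example of that expression: `(int)std::pow(2, int(std::log2(srcLen / 2)))` rounds srcLen/2 down to the nearest power of two, so srcLen = 1024 gives minBlk = 512 while srcLen = 1000 gives minBlk = 256. A tiny standalone check:

#include <cmath>
#include <cstdio>

int main() {
    for (int srcLen : {512, 1000, 1024, 2048, 3000}) {
        // Round srcLen/2 down to the nearest power of two.
        int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2)));
        std::printf("srcLen=%d -> minBlk=%d\n", srcLen, minBlk);
    }
}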
src/layers/attention.h
Outdated
float **preSum = thrPtrBuf;
float **sum = thrPtrBuf + nth;
float **preMax = thrPtrBuf + nth * 2;
float **max = thrPtrBuf + nth * 3;
float **qkArr = thrPtrBuf + nth * 4;
float **expQkvArr = thrPtrBuf + nth * 5;
float **qArr = thrPtrBuf + nth * 6;

thrBuf = (float *)malloc(sizeof(float) * nth * arrStride);
Would it be better to use SimpleMemPool to get the buffer? (SimpleMemPool maintains the buffer, so the next layer can use it directly.)
Done~
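A self-contained sketch of the idea behind that suggestion, not the repo's actual SimpleMemPool implementation: buffers keyed by name, grown on demand, and reused across layers instead of a malloc per call. The class and buffer names here are illustrative.

#include <cstdlib>
#include <string>
#include <unordered_map>

class TinyPool {
public:
    static TinyPool &instance() {
        static TinyPool pool;
        return pool;
    }
    // Return a named buffer of at least `bytes`, reusing prior allocations.
    void *getBuffer(const std::string &name, size_t bytes) {
        auto &slot = slots_[name];
        if (slot.size < bytes) { // grow only when needed
            slot.ptr = std::realloc(slot.ptr, bytes);
            slot.size = bytes;
        }
        return slot.ptr;
    }
private:
    struct Slot { void *ptr = nullptr; size_t size = 0; };
    std::unordered_map<std::string, Slot> slots_;
};

// Usage replacing the raw malloc in the diff (pool owns the memory,
// so there is no matching free() at the call site):
// float *thrBuf = (float *)TinyPool::instance().getBuffer(
//         "flash_attn_thr_buf", sizeof(float) * nth * arrStride);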
src/layers/mlp_chatglm2.h
Outdated
memcpy(upW + i * colSplit, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float));
weightPTR += intermediateSize;

int enable = (getenv("ENABLE_CAT_MLP") ? atoi(getenv("ENABLE_CAT_MLP")) : 1);
Why do we provide an option to disable this?
Leaving the option in temporarily for future performance tuning. We can probably delete it after a while.
A better approach might be a static variable, so getenv is not called every time. (And if it is used multiple times in the file, we could declare a global static variable; see the sketch below.)
I added global vars in mlp_llama.cpp. Please review~
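A minimal sketch of the caching suggestion; the helper name is illustrative:

#include <cstdlib>

// Cache the flag in a function-local static so getenv/atoi run once,
// on first use, instead of on every forward pass.
static int catMlpEnabled() {
    static const int enable = [] {
        const char *v = std::getenv("ENABLE_CAT_MLP");
        return v ? std::atoi(v) : 1; // default on, matching the diff
    }();
    return enable;
}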
@abenmao Have we checked the output of this PR for long prompts?
Force-pushed from 6652c16 to 3d9f90b (compare)
Yes, we have checked the output for long prompts on those models.
dbg.dumpMatrix(gateWeight);
dbg.debugPrint("gate output:\n");
dbg.dumpMatrix(imBuffer);
dbg.debugPrint("gateWeight:\n");
Does this need formatting?
Looks like I made a mistake; the code is already formatted.
Force-pushed from 3d9f90b to 3321e78 (compare)
Will merge after the build server check passes.