[Kernel] Make SelfAttention prepared for AMX_FP16; More balanced task split in Cross Attention #466

Merged: 2 commits, Jul 8, 2024
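This PR makes two changes to src/kernels/attention_kernels.h: it templates the self-attention kernels (selfAttention_SeparateCopy, selfAttention_FusedCopy, selfAttention) over the 16-bit element type, routing the AMX GEMM calls through a new small_amx_gemm_16bits_compute dispatcher so an FP16 path can be enabled once the xdnn kernel is available; and it restructures crossAttnByHead to parallelize over (kv head, batch, group) triples for a more balanced task split.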
src/kernels/attention_kernels.h: 218 changes (116 additions, 102 deletions)
@@ -99,9 +99,23 @@ void gemmSV(
}
}

+// T is bfloat16_t or float16_t
+// ldb is the K value during packing
+template <typename T>
+void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packedB, int ldb, T *C, int ldc) {
+    static_assert(std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float16_t>, "AMX gemm only supports BF16/FP16.");
+
+    if (std::is_same_v<T, bfloat16_t>) {
+        xdnn_small_amx_sgemm_bf16bf16bf16_compute(
+                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
+    } else {
+        //xdnn_small_amx_sgemm_f16f16f16_compute(m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
+    }
+}
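For reference, once xdnn exposes the AMX FP16 kernel, the dispatcher above is expected to complete roughly as in the sketch below. This is not merged code: it assumes xdnn_small_amx_sgemm_f16f16f16_compute keeps the signature shown in the commented-out call (note the extra ldb argument, which the BF16 kernel does not take), and it switches the type test to if constexpr so only the matching branch is instantiated.

```cpp
#include <type_traits> // bfloat16_t, float16_t, and the XDNN_* types come from the surrounding headers

// Sketch of the completed dispatcher, assuming the FP16 kernel lands with the commented-out signature.
template <typename T>
void small_amx_gemm_16bits_compute(int m, int n, int k, T *A, int lda, T *packedB, int ldb, T *C, int ldc) {
    static_assert(std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float16_t>, "AMX gemm only supports BF16/FP16.");

    if constexpr (std::is_same_v<T, bfloat16_t>) {
        // The BF16 kernel takes no ldb: K of the packed B is fixed at pack time
        xdnn_small_amx_sgemm_bf16bf16bf16_compute(
                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
    } else {
        // Assumed FP16 signature: ldb carries the K value used during packing
        xdnn_small_amx_sgemm_f16f16f16_compute(
                m, n, k, (XDNN_FP16 *)A, lda, (XDNN_FP16 *)packedB, ldb, (XDNN_FP16 *)C, ldc);
    }
}
```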

// Self attention while KV cache copy is separated
-template <bool fusedPack, typename Lambda1, typename Lambda2>
-void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
+template <bool fusedPack, typename T, typename Lambda1, typename Lambda2>
+void selfAttention_SeparateCopy(T *output, T *query, T *key, T *value, int qHeadNum,
int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
const Lambda2 &getVCache) {
@@ -126,8 +140,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
auto totalPackSize
= fusedPack ? threadNum * (kPackSize + vPackSize) : (batchSize * kvHeadNum) * (kPackSize + vPackSize);

-    bfloat16_t *packBuf
-            = (bfloat16_t *)SimpleMemPool::instance().getBuffer("kv_packing", totalPackSize * sizeof(bfloat16_t));
+    T *packBuf
+            = (T *)SimpleMemPool::instance().getBuffer("kv_packing", totalPackSize * sizeof(T));

// Copy key/value to cache and pack them
// If packing is not fused into computing, then pack it here
@@ -137,8 +151,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
for (int i = 0; i < kvHeadNum; ++i) {
const int tokens = tokenSizes[b];

-            bfloat16_t *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize);
-            bfloat16_t *packedV = packedB + kPackSize;
+            T *packedB = packBuf + (b * kvHeadNum + i) * (kPackSize + vPackSize);
+            T *packedV = packedB + kPackSize;

auto B = key + offsets[b] * kvStride + i * headSize;
for (int s = 0; s < tokens; ++s) {
@@ -181,8 +195,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_

// Prepare score buffer
auto maxScoreStride = (maxTokenSize + 31) / 32 * 32;
-    bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
-            "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t));
+    T *scores = (T *)SimpleMemPool::instance().getBuffer(
+            "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(T));

auto totalBlocks = blkEndIndex[batchSize - 1];
std::pair<int, int> packInfo[threadNum];
@@ -208,8 +222,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
int tid = omp_get_thread_num();
int kvHeadIdx = i / groupNum;
int locationIdx = (fusedPack ? tid : b * kvHeadNum + kvHeadIdx);
-            bfloat16_t *packedB = packBuf + locationIdx * (kPackSize + vPackSize);
-            bfloat16_t *packedV = packedB + kPackSize;
+            T *packedB = packBuf + locationIdx * (kPackSize + vPackSize);
+            T *packedV = packedB + kPackSize;

const int tokens = tokenSizes[b];
const int startSeq = mb * mBlockSize;
@@ -234,8 +248,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
}

// Causal mask (either with or without Alibi), use endSeq as N
-            xdnn_small_amx_sgemm_bf16bf16bf16_compute(
-                    m, endSeq, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
+            small_amx_gemm_16bits_compute(m, endSeq, k, A, lda, packedB, headSize, C, ldc);

#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
@@ -257,7 +270,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
} else {
DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
}
-                memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t));
+                memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(T));
}

#ifdef XFT_DEBUG
@@ -274,7 +287,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
lda = ldc;
ldc = oStride;
A = C;
-            C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize;
+            C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize;

if constexpr (fusedPack) {
if (packInfo[tid].first != b || packInfo[tid].second != kvHeadIdx) {
@@ -287,8 +300,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
}
}

-            xdnn_small_amx_sgemm_bf16bf16bf16_compute(
-                    m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
+            small_amx_gemm_16bits_compute(m, n, k, A, lda, packedV, tokens, C, ldc);

#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
Expand All @@ -301,8 +313,8 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
});
}
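Note that buffer sizing is not templated yet: selfAttention_FusedCopy below still calls xdnn_small_amx_sgemm_bf16bf16bf16_packb_size directly, so the FP16 path also needs a sizing dispatcher before it can run end to end. A minimal sketch, assuming a hypothetical xdnn_small_amx_sgemm_f16f16f16_packb_size that mirrors the BF16 helper's parameters (it is not part of this PR):

```cpp
#include <type_traits>

// Sketch: packed-B buffer size for either 16-bit type.
// xdnn_small_amx_sgemm_f16f16f16_packb_size is a hypothetical name, assumed to
// mirror the BF16 helper's (rows, cols, blockRows, blockCols) parameters.
template <typename T>
int small_amx_gemm_16bits_packb_size(int rows, int cols, int blockRows, int blockCols) {
    static_assert(std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float16_t>, "AMX gemm only supports BF16/FP16.");

    if constexpr (std::is_same_v<T, bfloat16_t>) {
        return xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(rows, cols, blockRows, blockCols);
    } else {
        return xdnn_small_amx_sgemm_f16f16f16_packb_size(rows, cols, blockRows, blockCols);
    }
}
```

With such a helper, the kPackSize/vPackSize computations below could become small_amx_gemm_16bits_packb_size<T>(maxTokenSize, headSize, 32, 32) and small_amx_gemm_16bits_packb_size<T>(headSize, maxTokenSize, 32, 32).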

-template <typename Lambda1, typename Lambda2>
-void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
+template <typename T, typename Lambda1, typename Lambda2>
+void selfAttention_FusedCopy(T *output, T *query, T *key, T *value, int qHeadNum,
int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
const Lambda2 &getVCache) {
@@ -331,11 +343,11 @@ void selfAttention_FusedCopy(bfloat16_t *
// Prepare buffers (packing buffer and score buffer)
const int kPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(maxTokenSize, headSize, 32, 32);
const int vPackSize = xdnn_small_amx_sgemm_bf16bf16bf16_packb_size(headSize, maxTokenSize, 32, 32);
-    bfloat16_t *packBuf = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
-            "kv_packing", threadNum * (kPackSize + vPackSize) * sizeof(bfloat16_t));
+    T *packBuf = (T *)SimpleMemPool::instance().getBuffer(
+            "kv_packing", threadNum * (kPackSize + vPackSize) * sizeof(T));
int maxScoreStride = (maxTokenSize + 31) / 32 * 32;
-    bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
-            "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t));
+    T *scores = (T *)SimpleMemPool::instance().getBuffer(
+            "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(T));

#ifdef XFT_DEBUG
printf("maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n", maxTokenSize, tokenSizes[0], offsets[0],
@@ -349,8 +361,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
const int tokens = tokenSizes[b];
const int mBlockNum = (tokens + mBlockSize - 1) / mBlockSize;

-        bfloat16_t *packedB = packBuf + tid * (kPackSize + vPackSize);
-        bfloat16_t *packedV = packedB + kPackSize;
+        T *packedB = packBuf + tid * (kPackSize + vPackSize);
+        T *packedV = packedB + kPackSize;

// Copy key/value to cache and pack them
auto B = key + offsets[b] * kvStride + i * headSize;
@@ -386,8 +398,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
auto A = query + (offsets[b] + startSeq) * qStride + i * headSize;
auto C = scores + tid * mBlockSize * maxScoreStride;

-        xdnn_small_amx_sgemm_bf16bf16bf16_compute(
-                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
+        small_amx_gemm_16bits_compute(
+                m, n, k, A, lda, packedB, headSize, C, ldc);

#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
@@ -408,7 +420,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
} else {
DecoderUtil::alibiSoftmax(C + seq * ldc, scale, alibiSlopes[i], elements);
}
-            memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t));
+            memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(T));
}

#ifdef XFT_DEBUG
Expand All @@ -425,10 +437,9 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
lda = ldc;
ldc = oStride;
A = C;
-        C = (bfloat16_t *)output + (offsets[b] + startSeq) * ldc + i * headSize;
+        C = (T *)output + (offsets[b] + startSeq) * ldc + i * headSize;

-        xdnn_small_amx_sgemm_bf16bf16bf16_compute(
-                m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
+        small_amx_gemm_16bits_compute(m, n, k, A, lda, packedV, tokens, C, ldc);

#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
@@ -443,8 +454,8 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
} // end for b
}
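Both kernels share the same per-row masked-softmax pattern: row seq of the score block has elements valid entries under the causal mask, softmax runs over just those entries, and the masked tail is zeroed so the following scores-times-V GEMM can be issued over the full row width. A reference sketch in plain float, with an explicit loop standing in for small_softmax_f32 and rows assumed to start at sequence position 0; the ALiBi variant additionally applies a per-head slope bias before normalizing:

```cpp
#include <algorithm>
#include <cmath>
#include <cstring>

// Reference row-wise masked softmax (float for clarity; the kernels above work on 16-bit scores).
// scores: rows x ldc; row `seq` has (seq + 1) valid entries; the tail up to `tokens` is zeroed.
void maskedSoftmaxRows(float *scores, int rows, int tokens, int ldc, float scale) {
    for (int seq = 0; seq < rows; ++seq) {
        float *row = scores + seq * ldc;
        const int elements = seq + 1; // causal mask: query seq attends to keys [0, seq]

        float maxVal = row[0] * scale;
        for (int i = 1; i < elements; ++i) maxVal = std::max(maxVal, row[i] * scale);

        float sum = 0.0f;
        for (int i = 0; i < elements; ++i) {
            row[i] = std::exp(row[i] * scale - maxVal); // subtract the max for numerical stability
            sum += row[i];
        }
        for (int i = 0; i < elements; ++i) row[i] /= sum;

        // Zero the masked tail so the scores * V GEMM over the full width is unaffected
        if (tokens > elements) std::memset(row + elements, 0, (tokens - elements) * sizeof(float));
    }
}
```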

-template <typename Lambda1, typename Lambda2>
-void selfAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bfloat16_t *value, int qHeadNum,
+template <typename T, typename Lambda1, typename Lambda2>
+void selfAttention(T *output, T *query, T *key, T *value, int qHeadNum,
int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache,
const Lambda2 &getVCache) {
@@ -700,91 +711,94 @@ void crossAttnByHead(T *output, const T *query, const T *key, const T *value, in
size_t scoreSizePerThr = 0;
for (int i = 0; i < batchSize; ++i) {
scoreSizePerThr = std::max(scoreSizePerThr, (size_t)inputSeqLens[i] * (inputSeqLens[i] + pastSeqLens[i]));
        inputOffsets[i] = (i > 0 ? inputOffsets[i - 1] + inputSeqLens[i - 1] : 0);
}

scoreSizePerThr = ALIGNED_SIZE(scoreSizePerThr, 16);
size_t scoreSize = scoreSizePerThr * threadNum;
float *scoreBuf = (float *)SimpleMemPool::instance().getBuffer("scoreBuf", sizeof(float) * scoreSize);

-    #pragma omp parallel for collapse(2)
-    for (int b = 0; b < batchSize; ++b) {
-        for (int i = 0; i < responsibleHeads; ++i) {
-            // Copy current key to cached keys (if needed)
-            int kvHdx = i / groupNum;
-            auto keyMatInfo = getKHead(b, kvHdx);
-            auto valueMat = getVHead(b, kvHdx);
-            bool bCopyCache = (i % groupNum == 0);
-
-            // Q * K
-            auto Q = query + inputOffsets[b] * qStride + i * headSize;
-            auto S = scoreBuf + omp_get_thread_num() * scoreSizePerThr;
-
-            const int queryLen = inputSeqLens[b];
-            const int keyLen = pastSeqLens[b] + inputSeqLens[b];
-
-            if (bCopyCache) {
-                int m = queryLen;
-                int n = keyLen;
-                int lda = qStride;
-                int ldc = keyLen;
-
-                // Copy to Key cache and compute Query * Key
-                auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize;
-                storeKVCache(keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
-
-                gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
-            } else {
-                // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
-                int m = queryLen;
-                int n = pastSeqLens[b];
-                int lda = qStride;
-                int ldc = keyLen;
-                gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
-
-                int ldb = kvStride;
-                auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize;
-                small_gemm_transb(Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc);
-            }
-
-            // Softmax(Q * K)
-            for (int seq = 0; seq < queryLen; ++seq) {
-                int elements = pastSeqLens[b] + seq + 1;
-                if (alibiSlopes == nullptr) {
-                    small_softmax_f32(S + seq * keyLen, scale, elements);
-                } else {
-                    DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements);
-                }
-                if (keyLen > elements) { memset(S + seq * keyLen + elements, 0, (keyLen - elements) * sizeof(float)); }
-            }
-
-            // Softmax * V
-            if (bCopyCache) {
-                // Copy current value to cached values
-                auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize;
-                storeKVCache(valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
-
-                int m = queryLen;
-                auto result = output + inputOffsets[b] * oStride + i * headSize;
-                gemmSV(S, valueMat, result, m, headSize, keyLen, keyLen, oStride);
-            } else {
-                // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
-                int m = queryLen;
-                float f32Out[m * headSize]; // accumulate in FP32
-                gemmSV(S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize);
-
-                auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize;
-                small_gemm(S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true);
-
-                // f32Out -> output
-                auto result = output + inputOffsets[b] * oStride + i * headSize;
-                for (int t = 0; t < m; ++t) {
-                    xft::copy(result + t * oStride, f32Out + t * headSize, headSize);
-                }
-            }
-        } // end for i
-    } // end for b
+    #pragma omp parallel for collapse(3)
+    for (int kvh = 0; kvh < kvHeadNum; ++kvh) {
+        for (int b = 0; b < batchSize; ++b) {
+            for (int groupOff = 0; groupOff < groupNum; ++groupOff) {
+                int i = kvh * groupNum + groupOff;
+
+                // Copy current key to cached keys (if needed)
+                int kvHdx = kvh;
+                auto keyMatInfo = getKHead(b, kvHdx);
+                auto valueMat = getVHead(b, kvHdx);
+                bool bCopyCache = (i % groupNum == 0);
+
+                // Q * K
+                auto Q = query + inputOffsets[b] * qStride + i * headSize;
+                auto S = scoreBuf + omp_get_thread_num() * scoreSizePerThr;
+
+                const int queryLen = inputSeqLens[b];
+                const int keyLen = pastSeqLens[b] + inputSeqLens[b];
+
+                if (bCopyCache) {
+                    int m = queryLen;
+                    int n = keyLen;
+                    int lda = qStride;
+                    int ldc = keyLen;
+
+                    // Copy to Key cache and compute Query * Key
+                    auto src = key + inputOffsets[b] * kvStride + kvHdx * headSize;
+                    storeKVCache(keyMatInfo, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
+
+                    gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
+                } else {
+                    // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
+                    int m = queryLen;
+                    int n = pastSeqLens[b];
+                    int lda = qStride;
+                    int ldc = keyLen;
+                    gemmQK(Q, keyMatInfo, S, m, n, headSize, lda, ldc);
+
+                    int ldb = kvStride;
+                    auto B = key + inputOffsets[b] * kvStride + kvHdx * headSize;
+                    small_gemm_transb(Q, B, S + n, m, inputSeqLens[b], headSize, lda, ldb, ldc);
+                }
+
+                // Softmax(Q * K)
+                for (int seq = 0; seq < queryLen; ++seq) {
+                    int elements = pastSeqLens[b] + seq + 1;
+                    if (alibiSlopes == nullptr) {
+                        small_softmax_f32(S + seq * keyLen, scale, elements);
+                    } else {
+                        DecoderUtil::alibiSoftmax(S + seq * keyLen, scale, alibiSlopes[i], elements);
+                    }
+                    if (keyLen > elements) { memset(S + seq * keyLen + elements, 0, (keyLen - elements) * sizeof(float)); }
+                }
+
+                // Softmax * V
+                if (bCopyCache) {
+                    // Copy current value to cached values
+                    auto src = value + inputOffsets[b] * kvStride + kvHdx * headSize;
+                    storeKVCache(valueMat, src, pastSeqLens[b], inputSeqLens[b], headSize, kvStride);
+
+                    int m = queryLen;
+                    auto result = output + inputOffsets[b] * oStride + i * headSize;
+                    gemmSV(S, valueMat, result, m, headSize, keyLen, keyLen, oStride);
+                } else {
+                    // Note: when KV cache is not copied by me, then 2 times gemm to avoid synchronization
+                    int m = queryLen;
+                    float f32Out[m * headSize]; // accumulate in FP32
+                    gemmSV(S, valueMat, f32Out, m, headSize, pastSeqLens[b], keyLen, headSize);
+
+                    auto B = value + inputOffsets[b] * kvStride + kvHdx * headSize;
+                    small_gemm(S + pastSeqLens[b], B, f32Out, m, headSize, m, keyLen, kvStride, headSize, true);
+
+                    // f32Out -> output
+                    auto result = output + inputOffsets[b] * oStride + i * headSize;
+                    for (int t = 0; t < m; ++t) {
+                        xft::copy(result + t * oStride, f32Out + t * headSize, headSize);
+                    }
+                }
+            } // end for groupOff
+        } // end for b
+    } // end for kvh
}
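The rebalanced loop above flattens kvHeadNum * batchSize * groupNum independent units for OpenMP. The unit count is essentially the same as the old collapse(2) over (batch, head), but the new iteration order interleaves batches at group granularity rather than whole-batch granularity, which evens out per-thread chunks when sequence lengths differ across the batch. Since i = kvh * groupNum + groupOff, the condition i % groupNum == 0 reduces to groupOff == 0, so exactly one unit per (kv head, batch) pair performs the KV cache copy. A tiny standalone check of that index math (the head counts are made up):

```cpp
#include <cstdio>

// Demonstrates the (kvh, b, groupOff) -> query-head mapping used by the collapse(3) loop above.
int main() {
    const int kvHeadNum = 2, batchSize = 2, groupNum = 4; // e.g. 8 query heads sharing 2 KV heads

    for (int kvh = 0; kvh < kvHeadNum; ++kvh)
        for (int b = 0; b < batchSize; ++b)
            for (int groupOff = 0; groupOff < groupNum; ++groupOff) {
                int i = kvh * groupNum + groupOff;     // query head index, as in the kernel
                bool bCopyCache = (i % groupNum == 0); // true exactly when groupOff == 0
                printf("kvh=%d b=%d head=%d copiesKV=%d\n", kvh, b, i, (int)bCopyCache);
            }
    return 0;
}
```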

// scaled dot-product attention: bmm1 + softmax + bmm2