
Commit

[Attention] Improve next token latency when (#threads >= 2 * #heads) by sharding the head into multiple splits (#70)

* support head split for cross attention

* format the code
pujiang2018 authored Nov 24, 2023
1 parent 1a522a2 commit 7430fe5
Showing 4 changed files with 380 additions and 95 deletions.
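The core of the change: during next-token decoding (inputSeqLen == 1), when there are at least twice as many threads as batch * heads tasks, each head's key/value sequence is sharded into several splits. Every split computes a local softmax (tracking its own max and exp-sum) plus a partial output, and the splits are then merged into the exact global softmax result. The sketch below is illustrative only (not part of the commit; all names are made up) and shows the merge rule this style of head sharding relies on:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Per-split statistics: local max, local sum of exponentials, and the partial
// output computed with locally normalized softmax weights.
struct SplitResult {
    float localMax;
    float localSum;
    std::vector<float> partial; // one value per output dimension (headSize)
};

// Local softmax over one split's scores, then partial output = probs * values.
static SplitResult computeSplit(const std::vector<float> &scores,
        const std::vector<std::vector<float>> &values, int headSize) {
    SplitResult r;
    r.localMax = scores[0];
    for (float x : scores)
        r.localMax = std::max(r.localMax, x);
    r.localSum = 0.0f;
    std::vector<float> e(scores.size());
    for (std::size_t i = 0; i < scores.size(); ++i) {
        e[i] = std::exp(scores[i] - r.localMax);
        r.localSum += e[i];
    }
    r.partial.assign(headSize, 0.0f);
    for (std::size_t i = 0; i < scores.size(); ++i)
        for (int t = 0; t < headSize; ++t)
            r.partial[t] += (e[i] / r.localSum) * values[i][t];
    return r;
}

// Merge: weight every split by exp(localMax - globalMax) * localSum / globalSum,
// which recovers the exact softmax over the full (unsplit) score row.
static std::vector<float> mergeSplits(const std::vector<SplitResult> &splits, int headSize) {
    float globalMax = splits[0].localMax;
    for (const auto &s : splits)
        globalMax = std::max(globalMax, s.localMax);
    float globalSum = 0.0f;
    for (const auto &s : splits)
        globalSum += s.localSum * std::exp(s.localMax - globalMax);
    std::vector<float> out(headSize, 0.0f);
    for (const auto &s : splits) {
        float w = s.localSum * std::exp(s.localMax - globalMax) / globalSum;
        for (int t = 0; t < headSize; ++t)
            out[t] += s.partial[t] * w;
    }
    return out;
}

int main() {
    // Toy single-query example: headSize = 2, key/value length 4 cut into two splits.
    SplitResult s0 = computeSplit({1.0f, 2.0f}, {{1.0f, 0.0f}, {0.0f, 1.0f}}, 2);
    SplitResult s1 = computeSplit({0.5f, 3.0f}, {{1.0f, 1.0f}, {2.0f, 2.0f}}, 2);
    std::vector<float> out = mergeSplits({s0, s1}, 2);
    // Matches softmax({1.0, 2.0, 0.5, 3.0}) applied to the four value rows.
    std::printf("%f %f\n", out[0], out[1]);
    return 0;
}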
24 changes: 24 additions & 0 deletions src/common/aligned_type.h
@@ -0,0 +1,24 @@
#pragma once

#include <cstddef>
#include <type_traits>

template <typename T, std::size_t Alignment>
struct AlignedType {
alignas(Alignment) T data;

// Default constructor
AlignedType() = default;

// Constructor to initialize with a value of type T
explicit AlignedType(const T &value) : data(value) {}

// Conversion operator to convert AlignedType to T
operator T() const { return data; }

// Overload the assignment operator to assign a value of type T
AlignedType &operator=(const T &value) {
data = value;
return *this;
}
};
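aligned_type.h simply pads a value to a chosen alignment. In this commit it is used in attention.h to give every (batch, head, split) task its own 32-byte-aligned slot for the per-split (max, exp-sum, finished-flag) tuple, so slots written concurrently by different threads are less likely to share a cache line. A minimal usage sketch (illustrative only, assuming the header above is on the include path):

#include <tuple>
#include "aligned_type.h"

// One padded slot per split; each worker thread writes only its own slot.
using SplitStats = AlignedType<std::tuple<float, float, float>, 32>;

int main() {
    SplitStats slots[4]; // e.g. one per split of a single head
    std::get<0>(slots[0].data) = 2.5f; // per-split max
    std::get<1>(slots[0].data) = 7.0f; // per-split sum of exponentials
    std::get<2>(slots[0].data) = 1.0f; // finished flag (later reused as the revise factor)
    return 0;
}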
229 changes: 197 additions & 32 deletions src/layers/attention.h
@@ -15,13 +15,15 @@
#pragma once
#include <numeric>

#include "aligned_type.h"
#include "bfloat16.h"
#include "debugger.h"
#include "decoder_util.h"
#include "float16.h"
#include "gemm_kernel_ext.h"
#include "kvcache_tensor.h"
#include "matmul_helper.h"
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"

@@ -271,7 +273,8 @@ class Attention {
inputBuffer.Assign(presult, rows, cols, stride);
}
if (ctx->inputSeqLen > 256 && pastSeqLen == 0)
flashAttention(ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
flashAttention(
ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
else
fusedAttention(ctx, query, key, value, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
t4.release();
@@ -364,10 +367,15 @@ class Attention {
int scoreBufSize = batchSize * responsibleHeads * ctx->inputSeqLen * ctx->inputSeqLen;
bool scoreBufByThread = (ctx->numThreads * mBlockSize * (pastSeqLen + ctx->inputSeqLen) <= scoreBufSize);

// For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
// When M dimension is split, also multiple tasks per copy, so do copy separately
// If total tasks are too few (compared to the total thread number), need to shard the head
bool shardHead = (ctx->inputSeqLen == 1) && (ctx->numThreads >= batchSize * responsibleHeads * 2);

// Need to copy current key/values to cache separately if:
// (1) group attention (#kvHeads != #qHeads)
// (2) the M dimension is split (multiple tasks per copy)
// (3) the head is sharded (also multiple tasks per copy)
bool kvCopied = false;
if (ctx->kvHeadNum < ctx->attHeadNum || mBlockSize != ctx->inputSeqLen) {
if (ctx->kvHeadNum < ctx->attHeadNum || mBlockSize != ctx->inputSeqLen || shardHead) {
#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b) {
for (int i = 0; i < (this->endKVHead - this->startKVHead); ++i) {
@@ -394,6 +402,11 @@
kvCopied = true;
}

// Separate impl. when the head is sharded
if (shardHead) {
return crossAttnShardHead(ctx, query, key, value, result, presentKey, presentValue, attnMask, pastSeqLen);
}

#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b) {
for (int i = 0; i < responsibleHeads; ++i) {
@@ -437,8 +450,8 @@
const int queryLen = ctx->inputSeqLen;
const int keyLen = pastSeqLen + ctx->inputSeqLen;

small_gemm_transb(
getMask(attnMask, b, i, queryLen, keyLen) + startSeq * keyLen, A, B, C, m, n, k, lda, ldb, ldc);
small_gemm_transb(getMask(attnMask, b, i, queryLen, keyLen) + startSeq * keyLen, A, B, C, m, n, k,
lda, ldb, ldc);

#ifdef DEBUG
if (b == 0 && i == 0) {
@@ -498,7 +511,8 @@
if constexpr (std::is_same_v<KVCacheT, float>) {
xdnn_sgemm_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc);
} else if constexpr (std::is_same_v<KVCacheT, float16_t>) {
xdnn_sgemm_f32f16f32_single_thread(false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc);
xdnn_sgemm_f32f16f32_single_thread(
false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc);
}

#ifdef DEBUG
@@ -514,6 +528,160 @@
} // end for b
}

// When there are very few heads, shard each head to use more resources
template <typename KVCacheT>
void crossAttnShardHead(DecoderContext *ctx, hpj::Matrix<float> &query, hpj::Matrix<float> &key,
hpj::Matrix<float> &value, hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey,
KVCacheTensor<KVCacheT> &presentValue, const float *attnMask, int pastSeqLen) {
const int responsibleHeads = this->endQHead - this->startQHead;
const int batchSize = ctx->batchSize;
const int groupNum = ctx->attHeadNum / ctx->kvHeadNum;

int N = pastSeqLen + ctx->inputSeqLen;
int splits = ctx->numThreads / (batchSize * responsibleHeads);
int nb = (N + splits - 1) / splits;

REQUIRES(splits > 1, "Do not call me when splits=%d", splits);

// max(xi), sum(exp(xi)), finish_tag for each split
int totalTasks = batchSize * responsibleHeads * splits;
AlignedType<std::tuple<float, float, float>, 32> splitInfo[totalTasks];
for (int i = 0; i < totalTasks; ++i) {
std::get<1>(splitInfo[i].data) = 0;
std::get<2>(splitInfo[i].data) = 0;
}

float *shardedOut = (float *)SimpleMemPool::instance().getBuffer(
"shardedOutput", totalTasks * ctx->attHeadSize * sizeof(float));

#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b) {
for (int i = 0; i < responsibleHeads; ++i) {
for (int s = 0; s < splits; ++s) {
int headStartIdx = b * responsibleHeads * splits + i * splits;
int threadIdx = b * responsibleHeads * splits + i * splits + s;

// Q * K
int nOff = s * nb;
auto keyMatInfo = presentKey.getHead(b, i / groupNum);
int m = 1;
int k = ctx->attHeadSize;
int n = (s < splits - 1 ? nb : N - nOff);
int lda = query.Stride();
int ldb = keyMatInfo.second;
int strideC = pastSeqLen > 0 ? (N + 15) / 16 * 16 : ctx->inputSeqLen;
int ldc = strideC;
auto A = query.Row(b * ctx->inputSeqLen) + i * ctx->attHeadSize;
auto B = keyMatInfo.first + nOff * ldb;
auto C = ctx->qkScores + (b * responsibleHeads + i) * ctx->inputSeqLen * strideC + nOff;

const int queryLen = ctx->inputSeqLen;
const int keyLen = N;

small_gemm_transb(getMask(attnMask, b, i, queryLen, keyLen), A, B, C, m, n, k, lda, ldb, ldc);

#ifdef DEBUG
if (b == 0 && i == 0 && s == splits - 1) {
dbg.debugPrint("Q * K, first head (some value may not be ready):\n");
auto p = ctx->qkScores;
dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0] * ctx->attFactor, p[1] * ctx->attFactor,
p[2] * ctx->attFactor, p[keyLen - 3] * ctx->attFactor, p[keyLen - 2] * ctx->attFactor,
p[keyLen - 1] * ctx->attFactor);
}
#endif

// Softmax and the stats info
auto info = DecoderUtil::softmaxWithStats(
ctx, C, getMask(attnMask, b, i, queryLen, keyLen) + nOff, n);
std::get<0>(splitInfo[threadIdx].data) = info.first;
std::get<1>(splitInfo[threadIdx].data) = info.second;

#ifdef DEBUG
if (b == 0 && i == 0 && s == splits - 1) {
dbg.debugPrint("Softmax(Q * K), first head (some value may not be ready):\n");
auto p = ctx->qkScores;
dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0], p[1], p[2], p[keyLen - 3], p[keyLen - 2],
p[keyLen - 1]);
}
#endif

// Softmax * V
auto valueMatInfo = presentValue.getHead(b, i / groupNum);
std::swap(k, n);
lda = strideC;
ldb = valueMatInfo.second;
ldc = result.Stride();
A = C;
B = valueMatInfo.first + nOff * ldb;
C = (s == 0 ? result.Row(b * ctx->inputSeqLen) + i * ctx->attHeadSize
: &shardedOut[threadIdx * ctx->attHeadSize]);

if constexpr (std::is_same_v<KVCacheT, float>) {
xdnn_sgemm_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc);
} else if constexpr (std::is_same_v<KVCacheT, float16_t>) {
xdnn_sgemm_f32f16f32_single_thread(
false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc);
}

std::get<2>(splitInfo[threadIdx].data) = 1; // set finished flag

// Wait for all splits of this head to finish, then reduce the result
// First get the global max value, then rescale each split's contribution on both the numerator and the denominator
if (s == 0) {
float realMax = std::get<0>(splitInfo[threadIdx].data);
for (int idx = headStartIdx + 1; idx < headStartIdx + splits; ++idx) {
while (std::get<2>(splitInfo[idx].data) == 0) {
_mm_pause();
}
if (std::get<0>(splitInfo[idx].data) > realMax) {
realMax = std::get<0>(splitInfo[idx].data);
}
}

float realSum = 0;
for (int idx = headStartIdx; idx < headStartIdx + splits; ++idx) {
float splitMax = std::get<0>(splitInfo[idx].data);
float splitSum = std::get<1>(splitInfo[idx].data);
float revFactor = std::exp(splitMax - realMax); // revise factor
std::get<2>(splitInfo[idx].data) = revFactor; // borrow finish flag for revise factor
realSum += splitSum * revFactor;
}

float *p = C;
#pragma simd
for (int t = 0; t < ctx->attHeadSize; ++t) {
int idx = threadIdx;
float splitMax = std::get<0>(splitInfo[idx].data);
float splitSum = std::get<1>(splitInfo[idx].data);
float revFactor = std::get<2>(splitInfo[idx].data);
C[t] = p[t] * revFactor * (splitSum / realSum);
}

for (int idx = headStartIdx + 1; idx < headStartIdx + splits; ++idx) {
float *p = &shardedOut[idx * ctx->attHeadSize];
#pragma simd
for (int t = 0; t < ctx->attHeadSize; ++t) {
float splitMax = std::get<0>(splitInfo[idx].data);
float splitSum = std::get<1>(splitInfo[idx].data);
float revFactor = std::get<2>(splitInfo[idx].data);
C[t] += p[t] * revFactor * (splitSum / realSum);
}
}
}

#ifdef DEBUG
if (b == 0 && i == 0 && s == 0) {
dbg.debugPrint("Softmax(Q * K) * V, first head:\n");
auto p = C;
dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0], p[1], p[2], p[ctx->attHeadSize - 3],
p[ctx->attHeadSize - 2], p[ctx->attHeadSize - 1]);
}
#endif
} // end for s
} // end for i
} // end for b
}

template <typename KVCacheT>
void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpRes,
hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
@@ -529,19 +697,19 @@
int srcLen = ctx->inputSeqLen;
int tgtLen = pastSeqLen + srcLen;

float *transQKV = (float*)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize);
float *transQKV = (float *)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize);

DecoderUtil::transposeQKV(qkvMatMul.Data(), transQKV, batchSize, srcLen, respQHeads, respKVHeads, headSize);

float *query = transQKV;
float *key = transQKV + batchSize * respQHeads * srcLen * headSize;
float *value = transQKV + batchSize * (respQHeads + respKVHeads) * srcLen * headSize;

scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads,
respKVHeads, headSize, tmpRes.Data());
DecoderUtil::transposeAttnResult(tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize,
result.Stride());
scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
headSize, tmpRes.Data());
DecoderUtil::transposeAttnResult(
tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize, result.Stride());

// For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
// When M dimension is split, also multiple tasks per copy, so do copy separately
#pragma omp parallel for collapse(3)
@@ -571,9 +739,8 @@
}

// scaled dot-product attention: bmm1 + softmax + bmm2
void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask,
float scale, int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize,
float* output) {
void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output) {
// output = trans(softmax(query * trans(key)) * value)
int nth = omp_get_max_threads();
int minBlk = (nth >= batchSize * numQHead ? 256 : 512);
@@ -607,23 +774,21 @@
for (int j = 0; j < numQHead; ++j) {
for (int m = 0; m < srcLen; m += srcBlk) {
int tid = omp_get_thread_num();
int tgtOff =
i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize;
const float* k = key + tgtOff;
const float* v = value + tgtOff;
const float* attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;
int tgtOff = i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize;
const float *k = key + tgtOff;
const float *v = value + tgtOff;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;

int qRealBlk = std::min(srcBlk, srcLen - m);
int srcOff =
i * numQHead * tgtLen * headSize + j * tgtLen * headSize;
const float* q = query + srcOff + m * headSize;
float* out = output + srcOff + m * headSize;
int srcOff = i * numQHead * tgtLen * headSize + j * tgtLen * headSize;
const float *q = query + srcOff + m * headSize;
float *out = output + srcOff + m * headSize;

// reset out
for (int ii = 0; ii < qRealBlk; ++ii) {
#pragma omp simd
for (int jj = 0; jj < headSize; ++jj) {
out[ii * headSize + jj] = 0; // reset output
out[ii * headSize + jj] = 0; // reset output
}
}
// reset sum
@@ -638,20 +803,20 @@
for (int b = 0; b < tgtLen; b += tgtBlk) {
int kvRealBlk = std::min(tgtBlk, tgtLen - b);
// TODO: mask out
const float* kBlk = k + b * headSize;
const float* vBlk = v + b * headSize;
const float *kBlk = k + b * headSize;
const float *vBlk = v + b * headSize;

DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk,
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid],
expQkvArr[tid], out);
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid],
out);
}
}
}
}
free(thrPtrBuf);
free(thrBuf);
return;
}
}

private:
std::pair<int, int> getTaskRange(int N, int splits, int splitIdx) {
@@ -692,7 +857,7 @@
return 0; // 0 means using the default value
}

virtual const float* getMask(const float* attnMask, int bId, int hId, int srcLen, int tgtLen) {
virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) {
return attnMask + bId * srcLen * tgtLen;
}

