[perf] improve next token latency when (#threads >= 2 * #heads) by sharding the head into multiple splits #70

Merged · 2 commits · Nov 24, 2023
24 changes: 24 additions & 0 deletions src/common/aligned_type.h
@@ -0,0 +1,24 @@
#pragma once

#include <cstddef>
#include <type_traits>

template <typename T, std::size_t Alignment>
struct AlignedType {
alignas(Alignment) T data;

// Default constructor
AlignedType() = default;

// Constructor to initialize with a value of type T
explicit AlignedType(const T &value) : data(value) {}

// Conversion operator to convert AlignedType to T
operator T() const { return data; }

// Overload the assignment operator to assign a value of type T
AlignedType &operator=(const T &value) {
data = value;
return *this;
}
};
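
AlignedType pads each element out to the requested alignment (sizeof becomes a multiple of Alignment), so an array of them spaces per-thread slots apart and reduces false sharing between concurrent writers; in this PR it is used with 32-byte alignment for the per-split softmax stats in attention.h. A minimal usage sketch, not part of this change — the OpenMP counter loop and the kMaxThreads constant are illustrative assumptions:

#include <cstdio>
#include <omp.h>
#include "aligned_type.h"

int main() {
    constexpr int kMaxThreads = 128;              // assumed upper bound for the sketch
    AlignedType<long, 64> perThread[kMaxThreads]; // one 64-byte-aligned slot per thread
    int nthreads = 1;

#pragma omp parallel
    {
        int tid = omp_get_thread_num();
#pragma omp single
        nthreads = omp_get_num_threads();
        perThread[tid] = 0L;                      // operator=(const T &)
        for (int i = 0; i < 1000000; ++i)
            perThread[tid] = perThread[tid] + 1;  // operator T() then operator=
    }

    long total = 0;
    for (int t = 0; t < nthreads; ++t) total += perThread[t];
    std::printf("total = %ld\n", total);          // nthreads * 1000000
    return 0;
}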
229 changes: 197 additions & 32 deletions src/layers/attention.h
@@ -15,13 +15,15 @@
#pragma once
#include <numeric>

#include "aligned_type.h"
#include "bfloat16.h"
#include "debugger.h"
#include "decoder_util.h"
#include "float16.h"
#include "gemm_kernel_ext.h"
#include "kvcache_tensor.h"
#include "matmul_helper.h"
#include "simple_mem_pool.h"
#include "transformer_ctx.h"
#include "transformer_util.h"

@@ -271,7 +273,8 @@ class Attention {
inputBuffer.Assign(presult, rows, cols, stride);
}
if (ctx->inputSeqLen > 256 && pastSeqLen == 0)
flashAttention(ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
flashAttention(
ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
else
fusedAttention(ctx, query, key, value, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
t4.release();
@@ -364,10 +367,15 @@ class Attention {
int scoreBufSize = batchSize * responsibleHeads * ctx->inputSeqLen * ctx->inputSeqLen;
bool scoreBufByThread = (ctx->numThreads * mBlockSize * (pastSeqLen + ctx->inputSeqLen) <= scoreBufSize);

// For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
// When M dimension is split, also multiple tasks per copy, so do copy separately
// If total tasks are too small (compared to total thread number), need to shard the head
bool shardHead = (ctx->inputSeqLen == 1) && (ctx->numThreads >= batchSize * responsibleHeads * 2);

// Need to copy current key/values to cache separately if:
// (1) group attention (#kvHeads != #qHeads)
// (2) the M dimension is split, so there are multiple tasks per copy
// (3) the head is sharded, which also means multiple tasks per copy
bool kvCopied = false;
if (ctx->kvHeadNum < ctx->attHeadNum || mBlockSize != ctx->inputSeqLen) {
if (ctx->kvHeadNum < ctx->attHeadNum || mBlockSize != ctx->inputSeqLen || shardHead) {
#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b) {
for (int i = 0; i < (this->endKVHead - this->startKVHead); ++i) {
@@ -394,6 +402,11 @@
kvCopied = true;
}

// Separate implementation when the head is sharded
if (shardHead) {
return crossAttnShardHead(ctx, query, key, value, result, presentKey, presentValue, attnMask, pastSeqLen);
}

#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b) {
for (int i = 0; i < responsibleHeads; ++i) {
@@ -437,8 +450,8 @@ class Attention {
const int queryLen = ctx->inputSeqLen;
const int keyLen = pastSeqLen + ctx->inputSeqLen;

small_gemm_transb(
getMask(attnMask, b, i, queryLen, keyLen) + startSeq * keyLen, A, B, C, m, n, k, lda, ldb, ldc);
small_gemm_transb(getMask(attnMask, b, i, queryLen, keyLen) + startSeq * keyLen, A, B, C, m, n, k,
lda, ldb, ldc);

#ifdef DEBUG
if (b == 0 && i == 0) {
@@ -498,7 +511,8 @@ class Attention {
if constexpr (std::is_same_v<KVCacheT, float>) {
xdnn_sgemm_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc);
} else if constexpr (std::is_same_v<KVCacheT, float16_t>) {
xdnn_sgemm_f32f16f32_single_thread(false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc);
xdnn_sgemm_f32f16f32_single_thread(
false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc);
}

#ifdef DEBUG
@@ -514,6 +528,160 @@
} // end for b
}
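
A quick worked example of the sharding trigger and split sizing used in crossAttnShardHead below — the concrete numbers (48 threads, batch 1, 8 heads, 1000 cached tokens) are made up for illustration:

#include <cstdio>

int main() {
    // Assumed example values, not taken from this PR
    int numThreads = 48, batchSize = 1, responsibleHeads = 8, inputSeqLen = 1;
    int pastSeqLen = 999;
    int N = pastSeqLen + inputSeqLen;                             // 1000 keys in total

    bool shardHead = (inputSeqLen == 1) && (numThreads >= batchSize * responsibleHeads * 2); // true
    int splits = numThreads / (batchSize * responsibleHeads);     // 6 splits per head
    int nb = (N + splits - 1) / splits;                           // 167 keys per split
    int lastSplit = N - (splits - 1) * nb;                        // 165 keys in the last split
    int totalTasks = batchSize * responsibleHeads * splits;       // 48 tasks, one per thread

    std::printf("shardHead=%d splits=%d nb=%d last=%d tasks=%d\n",
            (int)shardHead, splits, nb, lastSplit, totalTasks);
    return 0;
}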

// When there are very few heads, shard each head into multiple splits to use more threads
template <typename KVCacheT>
void crossAttnShardHead(DecoderContext *ctx, hpj::Matrix<float> &query, hpj::Matrix<float> &key,
hpj::Matrix<float> &value, hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey,
KVCacheTensor<KVCacheT> &presentValue, const float *attnMask, int pastSeqLen) {
const int responsibleHeads = this->endQHead - this->startQHead;
const int batchSize = ctx->batchSize;
const int groupNum = ctx->attHeadNum / ctx->kvHeadNum;

int N = pastSeqLen + ctx->inputSeqLen;
int splits = ctx->numThreads / (batchSize * responsibleHeads);
int nb = (N + splits - 1) / splits;

REQUIRES(splits > 1, "Do not call me when splits=%d", splits);

// max(xi), sum(exp(xi)), finish_tag for each split
int totalTasks = batchSize * responsibleHeads * splits;
AlignedType<std::tuple<float, float, float>, 32> splitInfo[totalTasks];
for (int i = 0; i < totalTasks; ++i) {
std::get<1>(splitInfo[i].data) = 0;
std::get<2>(splitInfo[i].data) = 0;
}

float *shardedOut = (float *)SimpleMemPool::instance().getBuffer(
"shardedOutput", totalTasks * ctx->attHeadSize * sizeof(float));

#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b) {
for (int i = 0; i < responsibleHeads; ++i) {
for (int s = 0; s < splits; ++s) {
int headStartIdx = b * responsibleHeads * splits + i * splits;
int threadIdx = b * responsibleHeads * splits + i * splits + s;

// Q * K
int nOff = s * nb;
auto keyMatInfo = presentKey.getHead(b, i / groupNum);
int m = 1;
int k = ctx->attHeadSize;
int n = (s < splits - 1 ? nb : N - nOff);
int lda = query.Stride();
int ldb = keyMatInfo.second;
int strideC = pastSeqLen > 0 ? (N + 15) / 16 * 16 : ctx->inputSeqLen;
int ldc = strideC;
auto A = query.Row(b * ctx->inputSeqLen) + i * ctx->attHeadSize;
auto B = keyMatInfo.first + nOff * ldb;
auto C = ctx->qkScores + (b * responsibleHeads + i) * ctx->inputSeqLen * strideC + nOff;

const int queryLen = ctx->inputSeqLen;
const int keyLen = N;

small_gemm_transb(getMask(attnMask, b, i, queryLen, keyLen), A, B, C, m, n, k, lda, ldb, ldc);

#ifdef DEBUG
if (b == 0 && i == 0 && s == splits - 1) {
dbg.debugPrint("Q * K, first head (some value may not be ready):\n");
auto p = ctx->qkScores;
dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0] * ctx->attFactor, p[1] * ctx->attFactor,
p[2] * ctx->attFactor, p[keyLen - 3] * ctx->attFactor, p[keyLen - 2] * ctx->attFactor,
p[keyLen - 1] * ctx->attFactor);
}
#endif

// Softmax and the stats info
auto info = DecoderUtil::softmaxWithStats(
ctx, C, getMask(attnMask, b, i, queryLen, keyLen) + nOff, n);
std::get<0>(splitInfo[threadIdx].data) = info.first;
std::get<1>(splitInfo[threadIdx].data) = info.second;

#ifdef DEBUG
if (b == 0 && i == 0 && s == splits - 1) {
dbg.debugPrint("Softmax(Q * K), first head (some value may not be ready):\n");
auto p = ctx->qkScores;
dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0], p[1], p[2], p[keyLen - 3], p[keyLen - 2],
p[keyLen - 1]);
}
#endif

// Softmax * V
auto valueMatInfo = presentValue.getHead(b, i / groupNum);
std::swap(k, n);
lda = strideC;
ldb = valueMatInfo.second;
ldc = result.Stride();
A = C;
B = valueMatInfo.first + nOff * ldb;
C = (s == 0 ? result.Row(b * ctx->inputSeqLen) + i * ctx->attHeadSize
: &shardedOut[threadIdx * ctx->attHeadSize]);

if constexpr (std::is_same_v<KVCacheT, float>) {
xdnn_sgemm_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc);
} else if constexpr (std::is_same_v<KVCacheT, float16_t>) {
xdnn_sgemm_f32f16f32_single_thread(
false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc);
}

std::get<2>(splitInfo[threadIdx].data) = 1; // set finished flag

// Wait for all splits of this head to finish, then reduce the result
// First find the global max, then rescale each split's contribution on both the numerator and the denominator
if (s == 0) {
float realMax = std::get<0>(splitInfo[threadIdx].data);
for (int idx = headStartIdx + 1; idx < headStartIdx + splits; ++idx) {
while (std::get<2>(splitInfo[idx].data) == 0) {
_mm_pause();
}
if (std::get<0>(splitInfo[idx].data) > realMax) {
realMax = std::get<0>(splitInfo[idx].data);
}
}

float realSum = 0;
for (int idx = headStartIdx; idx < headStartIdx + splits; ++idx) {
float splitMax = std::get<0>(splitInfo[idx].data);
float splitSum = std::get<1>(splitInfo[idx].data);
float revFactor = std::exp(splitMax - realMax); // revise factor
std::get<2>(splitInfo[idx].data) = revFactor; // borrow finish flag for revise factor
realSum += splitSum * revFactor;
}

float *p = C;
#pragma simd
for (int t = 0; t < ctx->attHeadSize; ++t) {
int idx = threadIdx;
float splitMax = std::get<0>(splitInfo[idx].data);
float splitSum = std::get<1>(splitInfo[idx].data);
float revFactor = std::get<2>(splitInfo[idx].data);
C[t] = p[t] * revFactor * (splitSum / realSum);
}

for (int idx = headStartIdx + 1; idx < headStartIdx + splits; ++idx) {
float *p = &shardedOut[idx * ctx->attHeadSize];
#pragma simd
for (int t = 0; t < ctx->attHeadSize; ++t) {
float splitMax = std::get<0>(splitInfo[idx].data);
float splitSum = std::get<1>(splitInfo[idx].data);
float revFactor = std::get<2>(splitInfo[idx].data);
C[t] += p[t] * revFactor * (splitSum / realSum);
}
}
}

#ifdef DEBUG
if (b == 0 && i == 0 && s == 0) {
dbg.debugPrint("Softmax(Q * K) * V, first head:\n");
auto p = C;
dbg.debugPrint("%f, %f, %f ... %f %f %f\n", p[0], p[1], p[2], p[ctx->attHeadSize - 3],
p[ctx->attHeadSize - 2], p[ctx->attHeadSize - 1]);
}
#endif
} // end for s
} // end for i
} // end for b
}
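
The reduction at the end of crossAttnShardHead uses the standard split-softmax identity: each split keeps its local max m_j and local sum s_j of exp(x - m_j), and its locally normalized partial output is rescaled by exp(m_j - M) * s_j / S, where M is the global max and S = sum_j s_j * exp(m_j - M). A small self-contained check of that identity — the scores, values, and headSize = 1 are made up for illustration; this is not code from the PR:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> x = {0.3f, 2.1f, -0.7f, 1.4f, 0.9f, -1.2f}; // attention scores
    const std::vector<float> v = {1.0f, -2.0f, 0.5f, 3.0f, -1.0f, 2.0f}; // values (headSize = 1)
    const int splits = 2, nb = 3; // two splits of three keys each

    // Reference: plain softmax(x) . v
    float M = x[0];
    for (float xi : x) M = std::max(M, xi);
    float S = 0.f, ref = 0.f;
    for (size_t i = 0; i < x.size(); ++i) S += std::exp(x[i] - M);
    for (size_t i = 0; i < x.size(); ++i) ref += std::exp(x[i] - M) / S * v[i];

    // Per-split stats: local max m[j], local sum s[j], locally normalized output o[j]
    float m[2], s[2], o[2];
    for (int j = 0; j < splits; ++j) {
        m[j] = x[j * nb];
        s[j] = 0.f;
        o[j] = 0.f;
        for (int i = 1; i < nb; ++i) m[j] = std::max(m[j], x[j * nb + i]);
        for (int i = 0; i < nb; ++i) s[j] += std::exp(x[j * nb + i] - m[j]);
        for (int i = 0; i < nb; ++i) o[j] += std::exp(x[j * nb + i] - m[j]) / s[j] * v[j * nb + i];
    }

    // Merge, mirroring the code above: revFactor = exp(m[j] - realMax), weight = s[j] / realSum
    float realMax = std::max(m[0], m[1]);
    float realSum = s[0] * std::exp(m[0] - realMax) + s[1] * std::exp(m[1] - realMax);
    float merged = 0.f;
    for (int j = 0; j < splits; ++j)
        merged += o[j] * std::exp(m[j] - realMax) * (s[j] / realSum);

    std::printf("reference = %f, merged = %f\n", ref, merged); // the two values should match
    return 0;
}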

template <typename KVCacheT>
void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpRes,
hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
@@ -529,19 +697,19 @@
int srcLen = ctx->inputSeqLen;
int tgtLen = pastSeqLen + srcLen;

float *transQKV = (float*)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize);
float *transQKV = (float *)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize);

DecoderUtil::transposeQKV(qkvMatMul.Data(), transQKV, batchSize, srcLen, respQHeads, respKVHeads, headSize);

float *query = transQKV;
float *key = transQKV + batchSize * respQHeads * srcLen * headSize;
float *value = transQKV + batchSize * (respQHeads + respKVHeads) * srcLen * headSize;

scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads,
respKVHeads, headSize, tmpRes.Data());
DecoderUtil::transposeAttnResult(tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize,
result.Stride());
scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
headSize, tmpRes.Data());
DecoderUtil::transposeAttnResult(
tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize, result.Stride());

// For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
// When M dimension is split, also multiple tasks per copy, so do copy separately
#pragma omp parallel for collapse(3)
@@ -571,9 +739,8 @@ class Attention {
}

// scaled dot-product attention: bmm1 + softmax + bmm2
void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask,
float scale, int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize,
float* output) {
void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output) {
// output = trans(softmax(query * trans(key)) * value)
int nth = omp_get_max_threads();
int minBlk = (nth >= batchSize * numQHead ? 256 : 512);
@@ -607,23 +774,21 @@ class Attention {
for (int j = 0; j < numQHead; ++j) {
for (int m = 0; m < srcLen; m += srcBlk) {
int tid = omp_get_thread_num();
int tgtOff =
i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize;
const float* k = key + tgtOff;
const float* v = value + tgtOff;
const float* attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;
int tgtOff = i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize;
const float *k = key + tgtOff;
const float *v = value + tgtOff;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;

int qRealBlk = std::min(srcBlk, srcLen - m);
int srcOff =
i * numQHead * tgtLen * headSize + j * tgtLen * headSize;
const float* q = query + srcOff + m * headSize;
float* out = output + srcOff + m * headSize;
int srcOff = i * numQHead * tgtLen * headSize + j * tgtLen * headSize;
const float *q = query + srcOff + m * headSize;
float *out = output + srcOff + m * headSize;

// reset out
for (int ii = 0; ii < qRealBlk; ++ii) {
#pragma omp simd
for (int jj = 0; jj < headSize; ++jj) {
out[ii * headSize + jj] = 0; // reset output
out[ii * headSize + jj] = 0; // reset output
}
}
// reset sum
@@ -638,20 +803,20 @@
for (int b = 0; b < tgtLen; b += tgtBlk) {
int kvRealBlk = std::min(tgtBlk, tgtLen - b);
// TODO: mask out
const float* kBlk = k + b * headSize;
const float* vBlk = v + b * headSize;
const float *kBlk = k + b * headSize;
const float *vBlk = v + b * headSize;

DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk,
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid],
expQkvArr[tid], out);
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid],
out);
}
}
}
}
free(thrPtrBuf);
free(thrBuf);
return;
}
}
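
DecoderUtil::incrementalTileAttention is not expanded in this diff; from its parameters (preSum, sum, preMax, max, refac) it is assumed to perform the usual online-softmax tile update, carrying a running max and sum across key/value tiles and rescaling the partial output whenever a new tile raises the max. A single-row sketch of that update rule with made-up numbers and headSize = 1 — an assumption about the kernel's behavior, not its actual code:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> scores = {0.2f, 1.7f, -0.4f, 0.9f, 2.3f, -1.1f}; // q . k for one row
    const std::vector<float> values = {1.0f, 0.5f, -2.0f, 3.0f, -1.0f, 2.5f}; // headSize = 1
    const size_t tgtBlk = 2; // process the keys two at a time

    float preMax = -INFINITY, preSum = 0.f, out = 0.f;
    for (size_t b = 0; b < scores.size(); b += tgtBlk) {
        size_t end = std::min(b + tgtBlk, scores.size());

        float tileMax = -INFINITY;
        for (size_t i = b; i < end; ++i) tileMax = std::max(tileMax, scores[i]);
        float newMax = std::max(preMax, tileMax);

        // Rescale the running numerator/denominator to the new max, fold in the tile, re-normalize
        float refac = std::exp(preMax - newMax);
        float newSum = preSum * refac;
        out *= preSum * refac; // back to an un-normalized numerator at the new scale
        for (size_t i = b; i < end; ++i) {
            float e = std::exp(scores[i] - newMax);
            newSum += e;
            out += e * values[i];
        }
        out /= newSum;
        preMax = newMax;
        preSum = newSum;
    }
    std::printf("online result = %f\n", out); // equals softmax(scores) . values
    return 0;
}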

private:
std::pair<int, int> getTaskRange(int N, int splits, int splitIdx) {
@@ -692,7 +857,7 @@ class Attention {
return 0; // 0 means using the default value
}

virtual const float* getMask(const float* attnMask, int bId, int hId, int srcLen, int tgtLen) {
virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) {
return attnMask + bId * srcLen * tgtLen;
}
