merge ChatGLM2Attention::forward into Attention::forward (#86)
add an epsilon parameter to LayerNorm::forward to align its signature with RmsNorm (LayerNorm does not use this parameter)

extend qk_shape from 4 to 5 elements in attention.h so that the key head count (key_head_num) can be passed into rotary_embedding_chatglm2 for multi-query attention.
a3213105 authored Dec 1, 2023
1 parent d5e576d commit 2830926
Showing 6 changed files with 41 additions and 159 deletions.
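
For orientation, here is a minimal, self-contained C++ sketch of the extended qk_shape layout that the hunks below introduce. The concrete values are made up; in the real code they come from DecoderContext and the [startQHead, endQHead) / [startKVHead, endKVHead) ranges of the current split.

#include <cstdio>

int main() {
    // Hypothetical example values (the real ones come from DecoderContext).
    int batchSize = 1, inputSeqLen = 8, attHeadSize = 128;
    int qheads = 32; // query heads handled by this split
    int kheads = 2;  // key/value heads (multi-query attention)

    // The old layout had 4 elements; the new 5th element carries the KV head
    // count so the ChatGLM2 rotary embedding can also rotate the key heads
    // that follow the query heads in the packed QKV buffer.
    int qk_shape[5] = {batchSize, inputSeqLen, qheads, attHeadSize, kheads};

    std::printf("rotary embedding rotates %d heads of size %d\n",
                qk_shape[2] + qk_shape[4], qk_shape[3]);
    return 0;
}
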
2 changes: 1 addition & 1 deletion include/layers_norm.h
@@ -34,7 +34,7 @@ class LayerNorm {

// input and output are in shape of (rows, normSize)
// TODO: column-wise parallel
- void forward(const float *input, float *output, int rows, int iStride = -1, int oStride = -1);
+ void forward(const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-5);

private:
int normSize;
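
The new trailing epsilon parameter has a default, so existing call sites keep compiling; per the commit message it exists only to align the signature with RmsNorm, presumably so either norm can be used as the NORM_CLS template argument of Attention. A small stand-alone sketch of that idea (FakeLayerNorm, FakeRmsNorm, and the placeholder math are invented for illustration; they are not the real classes):

#include <cstdio>
#include <vector>

// Invented stand-ins with the same signature shape as the patched norms.
struct FakeLayerNorm {
    void forward(const float *in, float *out, int rows, int iStride = -1,
                 int oStride = -1, float epsilon = 1e-5f) {
        (void)iStride; (void)oStride;
        (void)epsilon; // like LayerNorm, this one ignores epsilon
        for (int i = 0; i < rows; ++i) out[i] = in[i]; // placeholder math
    }
};

struct FakeRmsNorm {
    void forward(const float *in, float *out, int rows, int iStride = -1,
                 int oStride = -1, float epsilon = 1e-6f) {
        (void)iStride; (void)oStride;
        for (int i = 0; i < rows; ++i) out[i] = in[i] / (1.0f + epsilon); // placeholder math
    }
};

// Mirrors how a NORM_CLS is used: one call shape works for both norms.
template <typename NORM_CLS>
void normForward(NORM_CLS &norm, const float *in, float *out, int rows, float eps) {
    norm.forward(in, out, rows, -1, -1, eps);
}

int main() {
    std::vector<float> in(4, 2.0f), out(4, 0.0f);
    FakeLayerNorm ln;
    FakeRmsNorm rn;
    normForward(ln, in.data(), out.data(), 4, 1e-5f);
    normForward(rn, in.data(), out.data(), 4, 1e-5f);
    std::printf("out[0] = %f\n", out[0]);
    return 0;
}
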
8 changes: 5 additions & 3 deletions src/layers/attention.h
@@ -189,6 +189,7 @@ class Attention {
auto &resultBuffer1 = imBuffer;
auto &resultBuffer2 = ctx->tmpBuf;

+ float epsilon = ctx->epsilon;
// init group_qkvBuffer
int attHeadSize = ctx->attHeadSize;
int qkvRows = ctx->batchSize * inputSeqLen;
@@ -210,7 +211,7 @@
if (doLnBefore) {
TimeLine t1("input.layer_norm");
norm.forward(inputBuffer.Data(), resultBuffer1.Data(), inputBuffer.Rows(), inputBuffer.Stride(),
- resultBuffer1.Stride());
+ resultBuffer1.Stride(), epsilon);
}
#ifdef DEBUG
dbg.debugPrint("layer norm:\n");
@@ -237,8 +238,9 @@

// Apply post operations on query and key
TimeLine t3("QKPO");
- int heads = this->endQHead - this->startQHead;
- int qk_shape[4] = {ctx->batchSize, ctx->inputSeqLen, heads, ctx->attHeadSize};
+ int qheads = this->endQHead - this->startQHead;
+ int kheads = this->endKVHead - this->startKVHead;
+ int qk_shape[5] = {ctx->batchSize, ctx->inputSeqLen, qheads, ctx->attHeadSize, kheads};
if (positionIds != nullptr) {
qkpo.forward(query.Data(), key.Data(), query.Stride(), key.Stride(), qk_shape, positionIds);
} else if (ctx->maxPosEmbed > 0) {
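
For context, a short sketch of the packed-QKV column arithmetic that this multi-query path relies on; the head counts are made up, and the formulas mirror the q_cols/kv_cols computation that previously lived in the ChatGLM2-specific file removed below.

#include <cstdio>

int main() {
    // Hypothetical multi-query configuration (illustrative values only).
    int attHeadSize = 128, qheads = 32, kheads = 2;

    // Each row of the packed buffer is laid out as [ Q | K | V ].
    int q_cols  = qheads * attHeadSize;   // query block
    int kv_cols = kheads * attHeadSize;   // key block (and value block)
    int qkCols  = q_cols + kv_cols;       // rotary post-op covers Q and K
    int qkvCols = qkCols + kv_cols;       // full row width including V

    std::printf("q_cols=%d kv_cols=%d qkCols=%d qkvCols=%d\n",
                q_cols, kv_cols, qkCols, qkvCols);
    std::printf("heads seen by the rotary post-op: %d\n", qheads + kheads);
    return 0;
}
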
154 changes: 0 additions & 154 deletions src/layers/attn_chatglm2.h
@@ -24,158 +24,4 @@ class ChatGLM2Attention : public Attention<WeiT, QKPO_CLS, NORM_CLS, INPUT_AS_RE
: Attention<WeiT, QKPO_CLS, NORM_CLS, INPUT_AS_RESID>(layerId, ctx) {}
virtual ~ChatGLM2Attention() {}

public:
template <typename KVCacheT>
void forward(DecoderContext *ctx, float *input, float *output, const float *attnMask,
KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue, int inputSeqLen, int pastSeqLen,
bool useSelfAttn, bool doLnBefore, bool returnAttn, bool returnKVs, bool forPT = true,
int *positionIds = nullptr) {
if (forPT) {
printf("For better perf, need to manage cached key/vaues by ourself, PyTorch extension is not supported "
"any more.\n");
exit(-1);
}

KVCacheT *presentKeys = presentKey.getData();
KVCacheT *presentValues = presentValue.getData();

hpj::Matrix<float> inputBuffer(input, ctx->batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize);
hpj::Matrix<float> outBuffer(output, ctx->batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize);

auto hiddenSize = ctx->hiddenSize;
auto &qkvMatMul = ctx->qkvMatMul;
auto &resultBuffer1 = (ctx->numSplit == 1 ? outBuffer : ctx->normBuf);
auto &resultBuffer2 = ctx->tmpBuf;
float epsilon = ctx->epsilon;

// init group_qkvBuffer
int attHeadSize = ctx->attHeadSize;
int qkvRows = ctx->batchSize * inputSeqLen;
// multi query attention
int q_cols = (this->endQHead - this->startQHead) * attHeadSize;
int kv_cols = (this->endKVHead - this->startKVHead) * attHeadSize;
int qkCols = q_cols + kv_cols;
int qkvCols = qkCols + kv_cols;

int qkvStride = qkvCols;
hpj::Matrix<float> qkvGroupMatMul(qkvMatMul.Data(), qkvRows, qkvCols, qkvStride);

#ifdef DEBUG
this->dbg.debugPrint("---- GLM2 DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn);
this->dbg.debugPrint("input [%d, %d, %d]:\n", ctx->batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize);
this->dbg.dumpMatrix(inputBuffer);
#endif

if (doLnBefore) {
TimeLine t1("input.layer_norm");
this->norm.forward(inputBuffer.Data(), resultBuffer1.Data(), inputBuffer.Rows(), inputBuffer.Stride(),
resultBuffer1.Stride(), epsilon);
}
#ifdef DEBUG
this->dbg.debugPrint(
"layer norm [%d, %d, %d]:\n", ctx->batchSize * inputSeqLen, ctx->hiddenSize, ctx->hiddenSize);
this->dbg.dumpMatrix(resultBuffer1);
#endif
// Query, Key, Value computed together
TimeLine t2("QKV.linear");
DecoderUtil::dense(resultBuffer1, this->qkvWeight, this->qkvWeightScale, this->qkvWeightZero, this->qkvBias,
qkvGroupMatMul);
t2.release();

#ifdef DEBUG
this->dbg.debugPrint("dense [%d, %d, %d]:\n", ctx->batchSize * inputSeqLen, ctx->hiddenSize, qkvCols);
this->dbg.dumpMatrix(qkvGroupMatMul);
#endif

// Apply post operations on query and key
TimeLine t3("QKPO");
if (positionIds != nullptr) {
this->qkpo.forward(qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->batchSize, inputSeqLen, qkCols,
attHeadSize, positionIds);
} else {
std::vector<int> position_ids(ctx->inputSeqLen);
if (inputSeqLen == 1) {
position_ids[0] = pastSeqLen;
} else {
std::iota(position_ids.begin(), position_ids.end(), pastSeqLen);
}
this->qkpo.forward(qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->batchSize, inputSeqLen, qkCols,
attHeadSize, position_ids.data());
}
t3.release();

#ifdef DEBUG
this->dbg.debugPrint("qkpo [%d, %d, %d]:\n", ctx->batchSize * inputSeqLen, ctx->hiddenSize, qkvCols);
this->dbg.dumpMatrix(qkvGroupMatMul);
#endif

// this->expand_to_qkv(qkvMatMul, qkvGroupMatMul, qkvRows, q_cols, kv_cols, kvHeadNum);
// printf("q_cols=%d, kv_cols=%d, qk_cols=%d\n", q_cols, kv_cols, qkCols);
hpj::Matrix<float> query(qkvGroupMatMul, 0, inputBuffer.Rows(), 0, q_cols);
hpj::Matrix<float> key(qkvGroupMatMul, 0, inputBuffer.Rows(), q_cols, kv_cols);
hpj::Matrix<float> value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kv_cols);

#ifdef DEBUG
this->dbg.debugPrint("Q [%d, %d]:\n", query.Rows(), query.Cols());
this->dbg.dumpMatrix(query);
this->dbg.debugPrint("K [%d, %d]:\n", key.Rows(), key.Cols());
this->dbg.dumpMatrix(key);
this->dbg.debugPrint("V [%d, %d]:\n", value.Rows(), value.Cols());
this->dbg.dumpMatrix(value);
#endif

if (this->getScalingCoeff() != 0) { ctx->attFactor = this->getScalingCoeff(); }
TimeLine t4("MHA");
if constexpr (!INPUT_AS_RESID) {
auto presult = resultBuffer1.Data();
int rows = resultBuffer1.Rows(), cols = resultBuffer1.Cols(), stride = resultBuffer1.Stride();
resultBuffer1.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride());
inputBuffer.Assign(presult, rows, cols, stride);
}
if (ctx->inputSeqLen > 1024 && pastSeqLen == 0)
this->flashAttention(
ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
else
this->fusedAttention(ctx, query, key, value, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
t4.release();
hpj::Matrix<float> attnSplit(resultBuffer1.Data(), resultBuffer1.Rows(), resultBuffer1.Cols() / ctx->numSplit,
resultBuffer1.Stride());

#ifdef DEBUG
this->dbg.debugPrint("attention_%d (softmax * value):\n", ctx->splitIdx);
this->dbg.dumpMatrix(attnSplit);
#endif

TimeLine t5("Output");
// Output/projection in attention, only add the input in the first split
if (ctx->splitIdx == 0) {
float gamma = this->getResidentialScale();

// denseWithScaledSum should be enough, but since its performance has not been verified,
// denseWithSum is still used here
if (gamma == 1) {
DecoderUtil::denseWithSum(attnSplit, this->attnOutputWeight, this->attnOutputWeightScale,
this->attnOutputWeightZero, this->attnOutputBias, inputBuffer, resultBuffer2);
} else {
DecoderUtil::denseWithScaledSum(attnSplit, this->attnOutputWeight, this->attnOutputWeightScale,
this->attnOutputWeightZero, this->attnOutputBias, gamma, inputBuffer, resultBuffer2);
}
} else {
DecoderUtil::dense(attnSplit, this->attnOutputWeight, this->attnOutputWeightScale,
this->attnOutputWeightZero, this->attnOutputBias, resultBuffer2);
}
t5.release();

#ifdef DEBUG
this->dbg.debugPrint("attention output/projection: [%d, %d] (%d)\n", resultBuffer2.Rows(), resultBuffer2.Cols(),
resultBuffer2.Stride());
this->dbg.dumpMatrix(resultBuffer2);
#endif

if (!doLnBefore) {
TimeLine t6("result.layer_norm");
this->norm.forward(resultBuffer2.Data(), resultBuffer1.Data(), resultBuffer2.Rows(), resultBuffer2.Stride(),
resultBuffer1.Stride());
}
}
};
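
With forward() gone, the header is reduced to a thin subclass that inherits the merged Attention::forward. Below is a rough reconstruction of what remains, inferred from the unchanged context lines above rather than copied from the repository; the Attention and DecoderContext stubs are added only so the sketch compiles on its own.

struct DecoderContext; // stub: the real type lives elsewhere in the project

// Stub base class; the real Attention::forward now also covers the ChatGLM2 path.
template <typename WeiT, typename QKPO_CLS, typename NORM_CLS, bool INPUT_AS_RESID>
class Attention {
public:
    Attention(int /*layerId*/, DecoderContext * /*ctx*/) {}
};

template <typename WeiT, typename QKPO_CLS, typename NORM_CLS, bool INPUT_AS_RESID>
class ChatGLM2Attention : public Attention<WeiT, QKPO_CLS, NORM_CLS, INPUT_AS_RESID> {
public:
    ChatGLM2Attention(int layerId, DecoderContext *ctx)
            : Attention<WeiT, QKPO_CLS, NORM_CLS, INPUT_AS_RESID>(layerId, ctx) {}
    virtual ~ChatGLM2Attention() {}
    // No forward() override any more: it is inherited from Attention, which now
    // threads ctx->epsilon into the norm and passes the KV head count through
    // the 5-element qk_shape.
};

int main() { return 0; }
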
2 changes: 1 addition & 1 deletion src/layers/layer_norm.cpp
@@ -42,7 +42,7 @@ void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) {

// input and output are in shape of (rows, normSize)
// TODO: column-wise parallel
- void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride) {
+ void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
TimeLine t("LayerNorm.forward");
const float *pgamma = weights;
const float *pbeta = weights + normSize;
33 changes: 33 additions & 0 deletions src/layers/rotary_embedding_chatglm2.cpp
@@ -115,3 +115,36 @@ void ChatGLM2RotaryEmbedding::forward(float *buf, int bufStride, int batch_size,
}
}
}

void ChatGLM2RotaryEmbedding::forward(
float *query, float *key, int qStride, int kStride, const int *qk_shape, const int *position_ids) {
int dim = inv_freq_size * 2;
REQUIRES(dim == qk_shape[3], "Incorrect shape, last dimension is not the head size.");
const int batch_size = qk_shape[0];
const int seq_len = qk_shape[1];
const int head_num = qk_shape[2] + qk_shape[4];
const int half = inv_freq_size;

#pragma omp parallel for
for (int head = 0; head < head_num; ++head) {
int off = head * dim;
for (int bs = 0; bs < batch_size; ++bs) {
for (int seq = 0; seq < seq_len; ++seq) {
float *p1 = query + off;

int pos = position_ids[seq];
float *pcos = emb_cos + pos * dim;
float *psin = emb_sin + pos * dim;

#pragma omp simd
for (int i = 0; i < half; i += 2) {
auto t1 = p1[i];
p1[i] = p1[i] * pcos[i] - p1[i + 1] * psin[i];
p1[i + 1] = p1[i + 1] * pcos[i] + t1 * psin[i];
}
off += qStride;
}
}
}
}
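
The inner #pragma omp simd loop above is the usual interleaved rotary update, applied to the first half of each head (half == inv_freq_size). A tiny stand-alone sketch of just that pairwise rotation follows; the head size and the single shared angle are made up for illustration.

#include <cmath>
#include <cstdio>

// Rotate interleaved pairs (x[i], x[i+1]), mirroring the loop in the new
// ChatGLM2RotaryEmbedding::forward overload.
static void rotatePairs(float *x, const float *pcos, const float *psin, int half) {
    for (int i = 0; i < half; i += 2) {
        float t = x[i];
        x[i]     = x[i] * pcos[i] - x[i + 1] * psin[i];
        x[i + 1] = x[i + 1] * pcos[i] + t * psin[i];
    }
}

int main() {
    // One head of size 8: rotary touches only the first half (indices 0..3).
    float head[8] = {1, 0, 1, 0, 5, 6, 7, 8};
    float pcos[4], psin[4];
    const float theta = 0.5f; // one shared angle, just for the demo
    for (int i = 0; i < 4; ++i) { pcos[i] = std::cos(theta); psin[i] = std::sin(theta); }

    rotatePairs(head, pcos, psin, 4);
    for (float v : head) std::printf("%.3f ", v); // last four values stay untouched
    std::printf("\n");
    return 0;
}

Note that when Attention::forward calls this overload, query and key are adjacent column blocks of the same packed QKV matrix with identical strides, which appears to be why the loop can walk qk_shape[2] + qk_shape[4] heads starting from the query pointer while the key and kStride arguments go unused.
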

1 change: 1 addition & 0 deletions src/layers/rotary_embedding_chatglm2.h
@@ -38,6 +38,7 @@ class ChatGLM2RotaryEmbedding {
void forward(float *buf, int bufStride, int batch_size, int seq_len, int qk_size,
int hidden_size_per_attention_head, const int *position_ids);

+ void forward(float *query, float *key, int qStride, int kStride, const int *qk_shape, const int *position_ids);
private:
void glm2CalEmb();

