diff --git a/include/layers_attention.h b/include/layers_attention.h index 42a95f90..7752e18e 100644 --- a/include/layers_attention.h +++ b/include/layers_attention.h @@ -1,3 +1,17 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ #pragma once #include "dtype.h" @@ -44,4 +58,10 @@ void invokeAttention(DataType dt, const int batch_size, const int *token_lens, const void *kcache, const void *vcache, int *kvcache_shape, int *block_tables, int *block_nums, int *context_lens, int layer_id, bool is_prefill, int *slot_mapping); +void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum, + int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step, int hiddenSize, void *output, + int outputStride, const void *input, int inputStride, const void *queryWeight, const void *keyWeight, + const void *valueWeight, const void *attnOutWeight, const void *queryBias = nullptr, + const void *keyBias = nullptr, const void *valueBias = nullptr, const void *attnOutBias = nullptr); + } // namespace xft \ No newline at end of file diff --git a/src/layers/attention.cpp b/src/layers/attention.cpp new file mode 100644 index 00000000..9bef638e --- /dev/null +++ b/src/layers/attention.cpp @@ -0,0 +1,180 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================
+#include "attention.h"
+#include "kvcache_manager.h"
+#include "layers_attention.h"
+#include "rms_norm.h"
+
+#include <cstdio>
+#include <cstring>
+#include <limits>
+#include <mutex>
+#include <sstream>
+#include <unordered_map>
+
+namespace xft {
+
+void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
+        int maxPositions, int maxPosEmbed, int pastSeqLen, int currentSeqLen, int step, int hiddenSize, void *output,
+        int outputStride, const void *input, int inputStride, const void *queryWeight, const void *keyWeight,
+        const void *valueWeight, const void *attnOutWeight, const void *queryBias, const void *keyBias,
+        const void *valueBias, const void *attnOutBias) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    auto prepareAttnMask = [&](DecoderContext *ctx, int step) {
+        int seqLen = ctx->inputSeqLen;
+        int accSeqLen = pastSeqLen + currentSeqLen;
+        float *mask = nullptr;
+
+        auto getAttnMask = [](int sizeRequired) {
+            static float *attnMask;
+            static int maskSize = 0;
+            if (maskSize < sizeRequired) {
+                if (attnMask) free(attnMask);
+                attnMask = (float *)xft::alloc(sizeRequired * sizeof(float));
+                maskSize = sizeRequired;
+            }
+            return attnMask;
+        };
+
+        if (step == 0) {
+            int sizeRequired = ctx->batchSize * seqLen * seqLen;
+            mask = getAttnMask(sizeRequired);
+            for (int b = 0; b < ctx->batchSize; ++b) {
+                auto pmask = mask + b * seqLen * seqLen;
+                for (int i = 0; i < seqLen; ++i) {
+                    memset(pmask + i * seqLen, 0, (i + 1) * sizeof(float)); // bottom-left triangle is 0
+                    std::fill_n(pmask + i * seqLen + i + 1, seqLen - i - 1, std::numeric_limits<float>::lowest());
+                }
+            }
+        } else if (seqLen > 1) {
+            int sizeRequired = ctx->batchSize * accSeqLen * seqLen;
+            mask = getAttnMask(sizeRequired);
+            for (int b = 0; b < ctx->batchSize; ++b) {
+                auto pmask = mask + b * accSeqLen * seqLen;
+                int pastLen = accSeqLen - seqLen;
+                for (int i = 0; i < seqLen; ++i) {
+                    memset(pmask + i * accSeqLen, 0, (pastLen + i + 1) * sizeof(float));
+                    std::fill_n(pmask + i * accSeqLen + pastLen + i + 1, seqLen - i - 1,
+                            std::numeric_limits<float>::lowest());
+                }
+            }
+        } else {
+            int sizeRequired = ctx->batchSize * accSeqLen;
+            mask = getAttnMask(sizeRequired);
+            memset(mask, 0, ctx->batchSize * accSeqLen * sizeof(float)); // all elements are 0
+        }
+
+        return mask;
+    };
+
+    if (dt == DataType::bf16) {
+        static std::unordered_map<std::string, Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm> *>
+                llama_attention_hub;
+
+        static DecoderContext *ctx;
+        static KVCacheManager<float> *kvCacheMgr;
+
+        if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
+            if (ctx != nullptr) delete ctx;
+            printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
+            ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
+                    maxPositions, maxPosEmbed, -1, 0, 1);
+            ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+            kvCacheMgr = new KVCacheManager<float>(1);
+        }
+
+        // Create the hash key: if hiddenSize or intermediateSize changes, the weight memory pointers change as well.
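+        // Note: cached Attention objects live for the lifetime of the process and are looked up purely by the
+        // weight addresses below, so callers must keep the weight buffers alive and at stable addresses.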
+        std::stringstream weights_addr;
+        weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << dt
+                     << "_" << attHeadDim << "_" << attHeadNum << "_" << kvHeadNum;
+        std::string llama_attention_key = weights_addr.str();
+        Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm> *llama_attention;
+
+        auto it_created = llama_attention_hub.find(llama_attention_key);
+        if (it_created == llama_attention_hub.end()) {
+            llama_attention = new Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm>(0, ctx);
+            llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
+                    (float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
+                    (float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, false, nullptr,
+                    nullptr, false);
+            llama_attention_hub[llama_attention_key] = llama_attention;
+            printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
+        } else {
+            llama_attention = it_created->second;
+        }
+
+        ctx->resize(batchSize, inputSeqLen, pastSeqLen);
+        hpj::Matrix<float> actBuffers;
+        actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
+        float *attnMask = prepareAttnMask(ctx, step);
+
+        int workers = 1;
+        int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
+        kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
+        KVCacheTensor<float> &presentKey = kvCacheMgr->getKey(0);
+        KVCacheTensor<float> &presentValue = kvCacheMgr->getValue(0);
+
+        llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
+                presentValue, inputSeqLen, pastSeqLen, step == 0, false, false, nullptr);
+    } else if (dt == DataType::fp16) {
+        static std::unordered_map<std::string, Attention<float16_t, LlamaRotaryEmbedding, RmsNorm> *>
+                llama_attention_hub;
+
+        static DecoderContext *ctx;
+        static KVCacheManager<float> *kvCacheMgr;
+
+        if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
+            if (ctx != nullptr) delete ctx;
+            printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
+            ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
+                    maxPositions, maxPosEmbed, -1, 0, 1);
+            ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+            kvCacheMgr = new KVCacheManager<float>(1);
+        }
+
+        // Create the hash key: if hiddenSize or intermediateSize changes, the weight memory pointers change as well.
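+        // The fp16 path below mirrors the bf16 path above; only the element type of the cached Attention objects
+        // differs.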
+        std::stringstream weights_addr;
+        weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight << "_" << attnOutWeight << "_" << dt
+                     << "_" << attHeadDim << "_" << attHeadNum << "_" << kvHeadNum;
+        std::string llama_attention_key = weights_addr.str();
+        Attention<float16_t, LlamaRotaryEmbedding, RmsNorm> *llama_attention;
+
+        auto it_created = llama_attention_hub.find(llama_attention_key);
+        if (it_created == llama_attention_hub.end()) {
+            llama_attention = new Attention<float16_t, LlamaRotaryEmbedding, RmsNorm>(0, ctx);
+            llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
+                    (float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
+                    (float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, false, nullptr,
+                    nullptr, false);
+            llama_attention_hub[llama_attention_key] = llama_attention;
+            printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
+        } else {
+            llama_attention = it_created->second;
+        }
+
+        ctx->resize(batchSize, inputSeqLen, pastSeqLen);
+        hpj::Matrix<float> actBuffers;
+        actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
+        float *attnMask = prepareAttnMask(ctx, step);
+
+        int workers = 1;
+        int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
+        kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
+        KVCacheTensor<float> &presentKey = kvCacheMgr->getKey(0);
+        KVCacheTensor<float> &presentValue = kvCacheMgr->getValue(0);
+
+        llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
+                presentValue, inputSeqLen, pastSeqLen, step == 0, false, false, nullptr);
+    }
+}
+
+} // namespace xft
\ No newline at end of file
diff --git a/src/layers/attention.h b/src/layers/attention.h
index 568cbcfc..0e3977eb 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -76,7 +76,7 @@ class Attention {
             const float *queryBias, const OriWeiT *keyWeight, const float *keyScale, const float *keyZero,
             const float *keyBias, const OriWeiT *valueWeight, const float *valueScale, const float *valueZero,
             const float *valueBias, const OriWeiT *attnOutWeight, const float *attnOutScale, const float *attnOutZero,
-            const float *attnOutBias, const float *gamma1, const float *beta1, bool trans = true) {
+            const float *attnOutBias, bool doLNorm, const float *gamma1, const float *beta1, bool trans = true) {
         int hiddenSize = ctx->hiddenSize;
         int headSize = ctx->attHeadSize;
 
@@ -188,7 +188,8 @@ class Attention {
         }
 
         // LayerNorm
-        this->norm.setWeight(gamma1, beta1, hiddenSize);
+        if (doLNorm)
+            this->norm.setWeight(gamma1, beta1, hiddenSize);
     }
 
 #ifdef DEBUG
@@ -207,6 +208,7 @@ class Attention {
      * - useSelfAttn: use self attention or not, self attention is used to gen first token
      * - doLnBefore: Do layer norm before or not. If true, will do layer norm as the first step
      *               currently only support doLnBefore=true
+     * - doLnAfter: Do layer norm after or not. If true, will do layer norm as the last step
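+     *               (invokeAttentionLLaMA passes doLnAfter=false, so the standalone attention layer skips this final norm)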
      *  _________               _________               _________               _________               _________
      * |_________|------------->|_________|------------->|_________|------------->|_________|------------->|_________|
      *    layerNorm               QKV Linear                 MHA                    out Linear
@@ -215,7 +217,7 @@ class Attention {
     template <typename InT, typename ImT, typename OutT, typename KVCacheT>
     void forward(DecoderContext *ctx, InT *input, ImT *imBuf, OutT *output, const float *attnMask,
             KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue, int inputSeqLen, int pastSeqLen,
-            bool useSelfAttn, bool doLnBefore, int *positionIds = nullptr) {
+            bool useSelfAttn, bool doLnBefore, bool doLnAfter, int *positionIds = nullptr) {
         auto hiddenSize = ctx->hiddenSize;
         hpj::Matrix<InT> inputBuffer(input, ctx->batchSize * inputSeqLen, hiddenSize, hiddenSize);
 
@@ -382,7 +384,7 @@ class Attention {
         dbg.dumpMatrix(outBuffer);
 #endif
 
-        if (!doLnBefore) {
+        if (doLnAfter) {
             TimeLine t6("result.layer_norm");
             norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride());
 #ifdef DEBUG
diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h
index 38ae670e..3cbc9d55 100644
--- a/src/layers/decoder_layer.h
+++ b/src/layers/decoder_layer.h
@@ -86,7 +86,7 @@ class Decoder {
             bool trans = true) {
         attn.setWeights(ctx, queryWeight, queryScale, queryZero, queryBias, keyWeight, keyScale, keyZero, keyBias,
                 valueWeight, valueScale, valueZero, valueBias, attnOutWeight, attnOutScale, attnOutZero, attnOutBias,
-                ln1Gamma, ln1Beta, trans);
+                true, ln1Gamma, ln1Beta, trans);
 
         mlp.setWeights(ctx, fc1Weight, fc1Scales, fc1Zeros, fc1Bias, fc2Weight, fc2Scales, fc2Zeros, fc2Bias,
                 ln2Gamma, ln2Beta, fc3Weight, fc3Scales, fc3Zeros, trans);
@@ -102,7 +102,7 @@ class Decoder {
         static_assert(sizeof(ImT) >= sizeof(Ttarget), "Intermediate buffer is NOT big enough!");
 
         attn.forward(ctx, input, (Ttarget *)imBuf, output, attnMask, presentKey, presentValue, inputSeqLen, pastSeqLen,
-                useSelfAttn, doLnBefore, positionIds);
+                useSelfAttn, doLnBefore, false, positionIds);
     }
 
     template
diff --git a/src/layers/rotary_embedding.cpp b/src/layers/rotary_embedding.cpp
index 6e495d28..7241ccd3 100644
--- a/src/layers/rotary_embedding.cpp
+++ b/src/layers/rotary_embedding.cpp
@@ -22,13 +22,12 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) {
     const std::string emb_cos_str = "emb_cos";
     const std::string emb_sin_str = "emb_sin";
 
-    // dim: equals to head size
-    ctx->GetAttr("size_per_head", &this->dim);
-    ctx->GetAttr("max_pos_seq_len", &this->max_position_embeddings, 2048);
+    this->dim = ctx->attHeadSize;
+    this->max_position_embeddings = ctx->maxPosEmbed;
     ctx->GetAttr("rope_theta", &this->base, 10000);
     ctx->GetAttr("rope_type", &this->rope_type, std::to_string(-1));
 
-    if (this->rope_type == "linear")
+    if (this->rope_type == "linear")
         ctx->GetAttr("scaling_factor", &this->scaling_factor, 1.0f);
 
     inv_freq_size = (dim + 1) / 2;
diff --git a/src/models/model_factory.h b/src/models/model_factory.h
index d13a9310..2730347d 100644
--- a/src/models/model_factory.h
+++ b/src/models/model_factory.h
@@ -109,4 +109,4 @@ class DecoderRegister {
     MODEL(IMPLEMENT, CLASS, NAME)
 
 #define REGISTER_MODEL(CLASS, NAME) \
-    MODEL(REGISTER, CLASS, NAME)
\ No newline at end of file
+    MODEL(REGISTER, CLASS, NAME)
diff --git a/tests/ut/layers_attention_test.cpp b/tests/ut/layers_attention_test.cpp
new file mode 100644
index 00000000..942393c7
--- /dev/null
+++ b/tests/ut/layers_attention_test.cpp
@@ -0,0 +1,114 @@
+// Copyright (c) 2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+
+#include "bfloat16.h"
+#include "float16.h"
+#include "layers_attention.h"
+#include "gtest/gtest.h"
+
+template <typename T>
+static void compareAttentionLLaMA(int step, int batchSize, int inputSeqLen, int pastSeqLen, int currentSeqLen,
+        int attHeadDim, int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int hiddenSize,
+        const void *queryWeight, const void *keyWeight, const void *valueWeight, const void *attnOutWeight) {
+    // Create input
+    float *input = (float *)aligned_alloc(64, batchSize * inputSeqLen * hiddenSize * sizeof(float));
+    float *ourOutput = (float *)aligned_alloc(64, batchSize * inputSeqLen * hiddenSize * sizeof(float));
+    memset(ourOutput, 0, batchSize * inputSeqLen * hiddenSize * sizeof(float));
+
+    for (int i = 0; i < batchSize * inputSeqLen * hiddenSize; ++i) {
+        input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
+    }
+
+    xft::DataType dt = xft::DataType::unknown;
+    if constexpr (std::is_same<T, bfloat16_t>::value) {
+        dt = xft::DataType::bf16;
+    } else if constexpr (std::is_same<T, float16_t>::value) {
+        dt = xft::DataType::fp16;
+    } else {
+        printf("Unsupported data type\n");
+        GTEST_FAIL();
+        return;
+    }
+
+    auto start = std::chrono::high_resolution_clock::now();
+    invokeAttentionLLaMA(dt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum, maxPositions, maxPosEmbed,
+            pastSeqLen, currentSeqLen, step, hiddenSize, (void *)ourOutput, hiddenSize, (const void *)input, hiddenSize,
+            (const void *)queryWeight, (const void *)keyWeight, (const void *)valueWeight, (const void *)attnOutWeight);
+    auto end = std::chrono::high_resolution_clock::now();
+    float during_time = std::chrono::duration<float>(end - start).count();
+    printf("[ RUNTIME ] XFT::invokeAttentionLLaMA %.6f sec\n", during_time);
+
+    free(input);
+    free(ourOutput);
+}
+
+template <typename T>
+void test_AttentionLLaMA(void) {
+    int maxPosEmbed = 4096;
+    int maxPositions = maxPosEmbed;
+    int hiddenSize = 4096;
+    int attHeadNum = 32;
+    int attHeadDim = hiddenSize / attHeadNum;
+    int kvHeadNum = 32;
+    int qSize = attHeadDim * attHeadNum;
+    int kvSize = attHeadDim * kvHeadNum;
+
+    float *qkvProj = (float *)aligned_alloc(64, hiddenSize * (qSize + 2 * kvSize) * sizeof(float));
+    float *oProj = (float *)aligned_alloc(64, hiddenSize * hiddenSize * sizeof(float));
+
+    for (int i = 0; i < hiddenSize * (qSize + 2 * kvSize); ++i) {
+        qkvProj[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+    }
+
+    for (int i = 0; i < hiddenSize * hiddenSize; ++i) {
+        oProj[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+    }
+
+    int step = 0;
+    int batchSize = 1;
+    int inputSeqLen = 18;
+    int pastSeqLen = 0;
+    int currentSeqLen = inputSeqLen;
+    int nextTokenNum = 1;
+
+    compareAttentionLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum,
+            kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj);
+    pastSeqLen += inputSeqLen;
+    currentSeqLen = nextTokenNum;
+    compareAttentionLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum,
+            kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj);
+    pastSeqLen += nextTokenNum;
+    compareAttentionLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum,
+            kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj);
+
+    free(qkvProj);
+    free(oProj);
+}
+
+TEST(AttentionLLaMA, bfloat16_t) {
+    test_AttentionLLaMA<bfloat16_t>();
+}
+
+TEST(AttentionLLaMA, float16_t) {
+    test_AttentionLLaMA<float16_t>();
+}
+
+int main(int argc, char **argv) {
+    ::testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}
\ No newline at end of file
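Usage note: the snippet below is a minimal, self-contained sketch of how a caller might drive the new xft::invokeAttentionLLaMA entry point for a single prefill step. Buffer names, sizes and weight offsets are illustrative only (they loosely follow the unit test above; the weights are random), and error handling is omitted.

#include <cstdlib>
#include <vector>

#include "layers_attention.h"

int main() {
    const int batchSize = 1, seqLen = 16;
    const int headNum = 32, kvHeadNum = 32, headDim = 128;
    const int hiddenSize = headNum * headDim;
    const int qSize = headNum * headDim, kvSize = kvHeadNum * headDim;

    // Random fp32 projection weights: fused QKV block followed by the output projection.
    std::vector<float> qkv(hiddenSize * (qSize + 2 * kvSize));
    std::vector<float> oProj(hiddenSize * hiddenSize);
    for (auto &w : qkv) w = static_cast<float>(rand()) / RAND_MAX;
    for (auto &w : oProj) w = static_cast<float>(rand()) / RAND_MAX;

    std::vector<float> input(batchSize * seqLen * hiddenSize, 0.1f);
    std::vector<float> output(batchSize * seqLen * hiddenSize, 0.0f);

    // step == 0 with pastSeqLen == 0 selects the prefill path; a decode step would pass step > 0,
    // currentSeqLen == 1 and the accumulated pastSeqLen, reusing the internally managed KV cache.
    xft::invokeAttentionLLaMA(xft::DataType::bf16, batchSize, seqLen, headDim, headNum, kvHeadNum,
            /*maxPositions=*/4096, /*maxPosEmbed=*/4096, /*pastSeqLen=*/0, /*currentSeqLen=*/seqLen, /*step=*/0,
            hiddenSize, output.data(), hiddenSize, input.data(), hiddenSize, qkv.data(), qkv.data() + qSize,
            qkv.data() + qSize + kvSize, oProj.data());
    return 0;
}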