[API] Add LLaMA attention API. #378

Merged: 9 commits, May 11, 2024
Changes from 7 commits
21 changes: 21 additions & 0 deletions include/layers_attention.h
@@ -1,3 +1,17 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once

#include "dtype.h"
@@ -44,4 +58,11 @@ void invokeAttention(DataType dt,
const int batch_size, const int *token_lens, const void *kcache, const void *vcache, int *kvcache_shape,
int *block_tables, int *block_nums, int *context_lens, int layer_id, bool is_prefill, int *slot_mapping);

void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
int maxPositions, int maxPosEmbed, int maxSeqLength, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, void *output, int outputStride, const void *input, int inputStride, const float *ln1Gamma,
const float *ln1Beta, const void *queryWeight, const void *keyWeight, const void *valueWeight,
const void *attnOutWeight, const void *queryBias = nullptr, const void *keyBias = nullptr,
const void *valueBias = nullptr, const void *attnOutBias = nullptr);

} // namespace xft
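For orientation, here is a minimal calling sketch of the new API. It mirrors the unit test added later in this PR; the helper name runLlamaAttentionOnce and the concrete sizes are illustrative, and all buffers are assumed to be caller-allocated row-major float arrays.

// Sketch: one prefill call of xft::invokeAttentionLLaMA (names and sizes are illustrative).
#include "layers_attention.h"

void runLlamaAttentionOnce(const float *ln1Gamma, const float *ln1Beta, const float *queryWeight,
        const float *keyWeight, const float *valueWeight, const float *attnOutWeight, const float *input,
        float *output) {
    const int batchSize = 1, inputSeqLen = 18, hiddenSize = 4096;
    const int attHeadNum = 32, kvHeadNum = 32, attHeadDim = hiddenSize / attHeadNum;
    const int maxPositions = 4096, maxPosEmbed = 4096, maxSeqLength = 4096;
    const int pastSeqLen = 0, currentSeqLen = inputSeqLen, step = 0; // step 0 == prefill

    // The trailing bias arguments default to nullptr, matching LLaMA checkpoints without QKV/output biases.
    xft::invokeAttentionLLaMA(xft::DataType::bf16, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum,
            maxPositions, maxPosEmbed, maxSeqLength, pastSeqLen, currentSeqLen, step, hiddenSize,
            output, hiddenSize, input, hiddenSize, ln1Gamma, ln1Beta, queryWeight, keyWeight, valueWeight,
            attnOutWeight);
}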
179 changes: 179 additions & 0 deletions src/layers/attention.cpp
@@ -0,0 +1,179 @@
// Copyright (c) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#include "attention.h"
#include "kvcache_manager.h"
#include "layers_attention.h"
#include "rms_norm.h"

#include <unordered_map>

namespace xft {

void invokeAttentionLLaMA(DataType dt, int batchSize, int inputSeqLen, int attHeadDim, int attHeadNum, int kvHeadNum,
int maxPositions, int maxPosEmbed, int maxSeqLength, int pastSeqLen, int currentSeqLen, int step,
int hiddenSize, void *output, int outputStride, const void *input, int inputStride, const float *ln1Gamma,
const float *ln1Beta, const void *queryWeight, const void *keyWeight, const void *valueWeight,
const void *attnOutWeight, const void *queryBias, const void *keyBias, const void *valueBias,
const void *attnOutBias) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

auto prepareAttnMask = [&](DecoderContext *ctx, int step) {
int seqLen = ctx->inputSeqLen;
int accSeqLen = pastSeqLen + currentSeqLen;
float *mask = nullptr;

auto getAttnMask = [](int sizeRequired) {
static float *attnMask;
static int maskSize = 0;
if (maskSize < sizeRequired) {
if (attnMask) free(attnMask);
attnMask = (float *)xft::alloc(sizeRequired * sizeof(float));
maskSize = sizeRequired;
}
return attnMask;
};

if (step == 0) {
int sizeRequired = ctx->batchSize * seqLen * seqLen;
mask = getAttnMask(sizeRequired);
for (int b = 0; b < ctx->batchSize; ++b) {
auto pmask = mask + b * seqLen * seqLen;
for (int i = 0; i < seqLen; ++i) {
memset(pmask + i * seqLen, 0, (i + 1) * sizeof(float)); // lower-left triangle (past and self) is 0
std::fill_n(pmask + i * seqLen + i + 1, seqLen - i - 1, std::numeric_limits<float>::lowest());
}
}
} else if (seqLen > 1) {
int sizeRequired = ctx->batchSize * accSeqLen * seqLen;
mask = getAttnMask(sizeRequired);
for (int b = 0; b < ctx->batchSize; ++b) {
auto pmask = mask + b * accSeqLen * seqLen;
int pastLen = accSeqLen - seqLen;
for (int i = 0; i < seqLen; ++i) {
memset(pmask + i * accSeqLen, 0, (pastLen + i + 1) * sizeof(float));
std::fill_n(pmask + i * accSeqLen + pastLen + i + 1, seqLen - i - 1,
std::numeric_limits<float>::lowest());
}
}
} else {
int sizeRequired = ctx->batchSize * accSeqLen;
mask = getAttnMask(sizeRequired);
memset(mask, 0, ctx->batchSize * accSeqLen * sizeof(float)); // all elements are 0
}

return mask;
};

if (dt == DataType::bf16) {
static std::unordered_map<std::string, Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm> *>
llama_attention_hub;

static DecoderContext *ctx;
static KVCacheManager<float> *kvCacheMgr;

if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
maxPositions, maxPosEmbed, maxSeqLength, 0, 1);
ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
kvCacheMgr = new KVCacheManager<float>(1);
}

// Create the hash key from the weight addresses: if the hidden size or intermediate size changes, the weight memory pointers change as well.
std::stringstream weights_addr;
weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight;
std::string llama_attention_key = weights_addr.str();
Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm> *llama_attention;

auto it_created = llama_attention_hub.find(llama_attention_key);
if (it_created == llama_attention_hub.end()) {
llama_attention = new Attention<bfloat16_t, LlamaRotaryEmbedding, RmsNorm>(0, ctx);
llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
(float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
(float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, ln1Gamma,
ln1Beta, false);
llama_attention_hub[llama_attention_key] = llama_attention;
printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
} else {
llama_attention = it_created->second;
}

ctx->resize(batchSize, inputSeqLen, pastSeqLen);
hpj::Matrix<float> actBuffers;
actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
float *attnMask = prepareAttnMask(ctx, step);

int workers = 1;
int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
KVCacheTensor<float> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float> &presentValue = kvCacheMgr->getValue(0);

llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
presentValue, inputSeqLen, pastSeqLen, step == 0, false, nullptr);
} else if (dt == DataType::fp16) {
static std::unordered_map<std::string, Attention<float16_t, LlamaRotaryEmbedding, RmsNorm> *>
llama_attention_hub;

static DecoderContext *ctx;
static KVCacheManager<float> *kvCacheMgr;

if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0,
maxPositions, maxPosEmbed, maxSeqLength, 0, 1);
ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
kvCacheMgr = new KVCacheManager<float>(1);
}

// Create the hash key from the weight addresses: if the hidden size or intermediate size changes, the weight memory pointers change as well.
std::stringstream weights_addr;
weights_addr << queryWeight << "_" << keyWeight << "_" << valueWeight;
std::string llama_attention_key = weights_addr.str();
Attention<float16_t, LlamaRotaryEmbedding, RmsNorm> *llama_attention;

auto it_created = llama_attention_hub.find(llama_attention_key);
if (it_created == llama_attention_hub.end()) {
llama_attention = new Attention<float16_t, LlamaRotaryEmbedding, RmsNorm>(0, ctx);
llama_attention->setWeights(ctx, (float *)queryWeight, nullptr, nullptr, (float *)queryBias,
(float *)keyWeight, nullptr, nullptr, (float *)keyBias, (float *)valueWeight, nullptr, nullptr,
(float *)valueBias, (float *)attnOutWeight, nullptr, nullptr, (float *)attnOutBias, ln1Gamma,
ln1Beta, false);
llama_attention_hub[llama_attention_key] = llama_attention;
printf(">> create llama_attention_key: %s\n", llama_attention_key.c_str());
} else {
llama_attention = it_created->second;
}

ctx->resize(batchSize, inputSeqLen, pastSeqLen);
hpj::Matrix<float> actBuffers;
actBuffers.Resize(batchSize * inputSeqLen * 2, hiddenSize);
float *attnMask = prepareAttnMask(ctx, step);

int workers = 1;
int headsPerSplit = (ctx->kvHeadNum + workers - 1) / workers;
kvCacheMgr->resize(maxPositions, batchSize, headsPerSplit, attHeadDim);
KVCacheTensor<float> &presentKey = kvCacheMgr->getKey(0);
KVCacheTensor<float> &presentValue = kvCacheMgr->getValue(0);

llama_attention->forward(ctx, (float *)input, actBuffers.Data(), (float *)output, attnMask, presentKey,
presentValue, inputSeqLen, pastSeqLen, step == 0, false, nullptr);
}
}

} // namespace xft
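The prepareAttnMask lambda above covers three cases: the prefill step builds a per-batch causal (lower-triangular) mask, a multi-token continuation prepends zeros for the cached past, and single-token decode uses an all-zero mask over the accumulated sequence. A standalone sketch of the prefill case, for illustration only (buildPrefillMask is not part of the PR):

// Illustration of the step-0 (prefill) mask layout: row i allows positions 0..i and
// blocks future positions with the lowest representable float.
#include <algorithm>
#include <cstring>
#include <limits>
#include <vector>

std::vector<float> buildPrefillMask(int batchSize, int seqLen) {
    std::vector<float> mask(static_cast<size_t>(batchSize) * seqLen * seqLen);
    for (int b = 0; b < batchSize; ++b) {
        float *pmask = mask.data() + static_cast<size_t>(b) * seqLen * seqLen;
        for (int i = 0; i < seqLen; ++i) {
            std::memset(pmask + i * seqLen, 0, (i + 1) * sizeof(float)); // visible: past and self
            std::fill_n(pmask + i * seqLen + i + 1, seqLen - i - 1,
                    std::numeric_limits<float>::lowest()); // blocked: future positions
        }
    }
    return mask;
}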
7 changes: 3 additions & 4 deletions src/layers/rotary_embedding.cpp
@@ -22,13 +22,12 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(DecoderContext *ctx) {
const std::string emb_cos_str = "emb_cos";
const std::string emb_sin_str = "emb_sin";

// dim: equals to head size
ctx->GetAttr("size_per_head", &this->dim);
ctx->GetAttr("max_pos_seq_len", &this->max_position_embeddings, 2048);
this->dim = ctx->attHeadSize;
this->max_position_embeddings = ctx->maxPosEmbed;
ctx->GetAttr("rope_theta", &this->base, 10000);
ctx->GetAttr("rope_type", &this->rope_type, std::to_string(-1));

if (this->rope_type == "linear")
if (this->rope_type == "linear")
ctx->GetAttr("scaling_factor", &this->scaling_factor, 1.0f);

inv_freq_size = (dim + 1) / 2;
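This hunk makes the constructor take the head size and maximum position embedding directly from the context instead of reading them as attributes. For context, LLaMA-style rotary embeddings derive an inverse-frequency table from that head size and the rope_theta base; a generic sketch of the standard formula follows (makeInvFreq is illustrative, not code from this file):

// Standard RoPE inverse frequencies: inv_freq[i] = base^(-2i / dim), for i in [0, (dim + 1) / 2).
#include <cmath>
#include <vector>

std::vector<float> makeInvFreq(int dim, float base = 10000.0f) {
    const int invFreqSize = (dim + 1) / 2; // matches inv_freq_size above
    std::vector<float> invFreq(invFreqSize);
    for (int i = 0; i < invFreqSize; ++i) {
        invFreq[i] = 1.0f / std::pow(base, 2.0f * i / dim);
    }
    return invFreq;
}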
2 changes: 1 addition & 1 deletion src/models/model_factory.h
@@ -109,4 +109,4 @@ class DecoderRegister {
MODEL(IMPLEMENT, CLASS, NAME)

#define REGISTER_MODEL(CLASS, NAME) \
MODEL(REGISTER, CLASS, NAME)
MODEL(REGISTER, CLASS, NAME)
129 changes: 129 additions & 0 deletions tests/ut/layers_attention_test.cpp
@@ -0,0 +1,129 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <type_traits>

#include "bfloat16.h"
#include "float16.h"
#include "layers_attention.h"
#include "gtest/gtest.h"

template <typename T>
static void compareAttentionLLaMA(int step, int batchSize, int inputSeqLen, int pastSeqLen, int currentSeqLen,
int attHeadDim, int attHeadNum, int kvHeadNum, int maxPositions, int maxPosEmbed, int maxSeqLength,
int hiddenSize, const float *ln1Gamma, const float *ln1Beta, const void *queryWeight, const void *keyWeight,
const void *valueWeight, const void *attnOutWeight) {
// Create input
float *input = (float *)aligned_alloc(64, batchSize * inputSeqLen * hiddenSize * sizeof(float));
float *ourOutput = (float *)aligned_alloc(64, batchSize * inputSeqLen * hiddenSize * sizeof(float));
memset(ourOutput, 0, batchSize * inputSeqLen * hiddenSize * sizeof(float));

for (int i = 0; i < batchSize * inputSeqLen * hiddenSize; ++i) {
input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
}

xft::DataType dt = xft::DataType::unknown;
if constexpr (std::is_same<T, bfloat16_t>::value) {
dt = xft::DataType::bf16;
} else if constexpr (std::is_same<T, float16_t>::value) {
dt = xft::DataType::fp16;
} else {
printf("Unsupported data type\n");
GTEST_FAIL();
return;
}

auto start = std::chrono::high_resolution_clock::now();
invokeAttentionLLaMA(dt, batchSize, inputSeqLen, attHeadDim, attHeadNum, kvHeadNum, maxPositions, maxPosEmbed,
maxSeqLength, pastSeqLen, currentSeqLen, step, hiddenSize, (void *)ourOutput, hiddenSize,
(const void *)input, hiddenSize, ln1Gamma, ln1Beta, (const void *)queryWeight, (const void *)keyWeight,
(const void *)valueWeight, (const void *)attnOutWeight);
auto end = std::chrono::high_resolution_clock::now();
float during_time = std::chrono::duration<float>(end - start).count();
printf("[ RUNTIME ] XFT::invokeAttentionLLaMA %.6f sec\n", during_time);

free(input);
free(ourOutput);
}

template <typename T>
void test_AttentionLLaMA(void) {
int maxPosEmbed = 4096;
int maxPositions = maxPosEmbed;
int maxSeqLength = maxPosEmbed;
int hiddenSize = 4096;
int attHeadNum = 32;
int attHeadDim = hiddenSize / attHeadNum;
int kvHeadNum = 32;
int qSize = attHeadDim * attHeadNum;
int kvSize = attHeadDim * kvHeadNum;

float *ln1Gamma = (float *)aligned_alloc(64, hiddenSize * sizeof(float));
float *ln1Beta = (float *)aligned_alloc(64, hiddenSize * sizeof(float));
float *qkvProj = (float *)aligned_alloc(64, hiddenSize * (qSize + 2 * kvSize) * sizeof(float));
float *oProj = (float *)aligned_alloc(64, hiddenSize * hiddenSize * sizeof(float));

for (int i = 0; i < hiddenSize; ++i) {
ln1Gamma[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
ln1Beta[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
}

for (int i = 0; i < hiddenSize * (qSize + 2 * kvSize); ++i) {
qkvProj[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
}

for (int i = 0; i < hiddenSize * hiddenSize; ++i) {
oProj[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
}

int step = 0;
int batchSize = 1;
int inputSeqLen = 18;
int pastSeqLen = 0;
int currentSeqLen = inputSeqLen;
int nextTokenNum = 1;

compareAttentionLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum,
kvHeadNum, maxPositions, maxPosEmbed, maxSeqLength, hiddenSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
qkvProj + kvSize, oProj);
pastSeqLen += inputSeqLen;
currentSeqLen = nextTokenNum;
compareAttentionLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum,
kvHeadNum, maxPositions, maxPosEmbed, maxSeqLength, hiddenSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
qkvProj + kvSize, oProj);
pastSeqLen += nextTokenNum;
compareAttentionLLaMA<T>(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum,
kvHeadNum, maxPositions, maxPosEmbed, maxSeqLength, hiddenSize, ln1Gamma, ln1Beta, qkvProj, qkvProj + qSize,
qkvProj + kvSize, oProj);

free(ln1Gamma);
free(ln1Beta);
free(qkvProj);
free(oProj);
}

TEST(AttentionLLaMA, bfloat16_t) {
test_AttentionLLaMA<bfloat16_t>();
}

TEST(AttentionLLaMA, float16_t) {
test_AttentionLLaMA<float16_t>();
}

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}