Optimize context perf.
changqi1 committed Nov 25, 2023
1 parent ccf8f37 commit 516986c
Showing 2 changed files with 30 additions and 20 deletions.
29 changes: 15 additions & 14 deletions src/layers/mlp_llama.cpp
@@ -22,38 +22,39 @@ namespace xft {
 void invokeMLPLLaMA(DataType dt, int numTokens, int hiddenSize, int intermediateSize, void *output, int outputStride,
         const void *input, int inputStride, const void *gateWeight, const void *upWeight, const void *downWeight) {
     static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
     if (dt == DataType::bf16) {
-        static std::unordered_map<std::string, std::tuple<DecoderContext *, LlamaMLP<bfloat16_t> *>> llama_mlp_hub;
+        static std::unordered_map<std::string, LlamaMLP<bfloat16_t> *> llama_mlp_hub;
 
-        // create hash key
+        static DecoderContext *ctx;
+        if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
+            delete ctx;
+            printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
+            ctx = new DecoderContext(1, hiddenSize, 1, 1, intermediateSize, "silu", 1e-6, 0, 0, 0, 0, 0, 1);
+        }
+
+        // Create the hash key and value: if hiddenSize or intermediateSize changes, the weight memory pointers change as well.
         std::stringstream weights_addr;
         weights_addr << gateWeight << "_" << upWeight << "_" << downWeight;
-        std::string llama_mlp_key
-                = std::to_string(hiddenSize) + "_" + std::to_string(intermediateSize) + "_" + weights_addr.str();
-
-        DecoderContext *ctx;
+        std::string llama_mlp_key = weights_addr.str();
         LlamaMLP<bfloat16_t> *llama_mlp;
 
         auto it_created = llama_mlp_hub.find(llama_mlp_key);
         if (it_created == llama_mlp_hub.end()) {
             // LlamaMLP<bfloat16_t> &llama_mlp = LlamaMLP<bfloat16_t>::getInstance();
-            ctx = new DecoderContext(1, hiddenSize, 1, 1, intermediateSize, "silu", 1e-6, 0, 0, 0, 0, 0, 1);
             std::vector<float *> params {(float *)gateWeight, (float *)nullptr, (float *)upWeight, (float *)nullptr,
                     (float *)nullptr, (float *)nullptr, (float *)downWeight};
 
             llama_mlp = new LlamaMLP<bfloat16_t>;
             llama_mlp->setWeights(ctx, params, false);
 
-            std::tuple<DecoderContext *, LlamaMLP<bfloat16_t> *> value(ctx, llama_mlp);
-            llama_mlp_hub[llama_mlp_key] = value;
-            printf("create llama_mlp_key: %s\n", llama_mlp_key.c_str());
+            llama_mlp_hub[llama_mlp_key] = llama_mlp;
+            printf(">> create llama_mlp_key: %s\n", llama_mlp_key.c_str());
         } else {
-            ctx = std::get<0>(it_created->second);
-            llama_mlp = std::get<1>(it_created->second);
+            llama_mlp = it_created->second;
         }
 
+        // Serving models of different data types at the same time is unsupported because they share the same DecoderContext.
+        std::lock_guard<std::mutex> lock(mutex);
         ctx->resize(1, numTokens, 0);
         llama_mlp->forward(ctx, (float *)const_cast<void *>(input), (float *)output, inputStride, outputStride, false);
     }
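
The hunk above keeps a single static DecoderContext that is rebuilt only when hiddenSize or intermediateSize changes, and caches each LlamaMLP by the addresses of its weight buffers, so repeated calls with the same shapes only resize the context. The standalone C++ sketch below illustrates that caching pattern under stated assumptions: Context, Layer, and runCachedMLP are hypothetical stand-ins for illustration, not the repository's API, and for simplicity the lock guards the whole call rather than only the forward pass.

    #include <cstdio>
    #include <mutex>
    #include <sstream>
    #include <string>
    #include <unordered_map>

    // Hypothetical stand-ins for DecoderContext and LlamaMLP, kept minimal for illustration.
    struct Context {
        int hiddenSize, intermediateSize, numTokens = 0;
        Context(int hidden, int inter) : hiddenSize(hidden), intermediateSize(inter) {}
        void resize(int tokens) { numTokens = tokens; }
    };

    struct Layer {
        void setWeights(Context *, const void *, const void *, const void *) {}
        void forward(Context *, const void *, void *) {}
    };

    void runCachedMLP(int numTokens, int hiddenSize, int intermediateSize, const void *gateW, const void *upW,
            const void *downW, const void *input, void *output) {
        static std::mutex mutex;
        static Context *ctx = nullptr;                        // shared context, rebuilt only on a shape change
        static std::unordered_map<std::string, Layer *> hub;  // layers cached by weight addresses

        std::lock_guard<std::mutex> lock(mutex);  // simplification: guard the whole call

        if (ctx == nullptr || ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize) {
            delete ctx;
            ctx = new Context(hiddenSize, intermediateSize);
            printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
        }

        // The weight addresses identify a layer: the same buffers map to the same cached Layer.
        std::stringstream key;
        key << gateW << "_" << upW << "_" << downW;

        Layer *&layer = hub[key.str()];  // value-initialized to nullptr on first lookup
        if (layer == nullptr) {
            layer = new Layer();
            layer->setWeights(ctx, gateW, upW, downW);
        }

        ctx->resize(numTokens);  // only the token count varies between calls
        layer->forward(ctx, input, output);
    }

With this layout, the common case (same hidden and intermediate sizes, same weights) touches no allocation at all; only ctx->resize runs before the forward call.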
21 changes: 15 additions & 6 deletions tests/ut/layers_mlp_test.cpp
@@ -26,6 +26,7 @@ static void matmul(int m, int n, int k, const float *A, const float *B, float *C
     #pragma omp parallel for collapse(2)
     for (int i = 0; i < m; ++i) {
         for (int j = 0; j < n; ++j) {
+            C[i * n + j] = 0.0f;
             for (int q = 0; q < k; ++q) {
                 C[i * n + j] += static_cast<float>(static_cast<T>(A[i * k + q]))
                         * static_cast<float>(static_cast<T>(B[q * n + j]));
@@ -85,7 +86,7 @@ static void compareMLPLLaMA(
     memset(refOutput, 0, numTokens * hiddenSize * sizeof(float));
 
     for (int i = 0; i < numTokens * hiddenSize; ++i) {
-        input[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
+        input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
     }
 
     if constexpr (std::is_same<T, bfloat16_t>::value) {
@@ -101,8 +102,9 @@
     }
 
     for (int i = 0; i < numTokens * hiddenSize; ++i) {
-        EXPECT_EQ(std::abs(refOutput[i] - ourOutput[i]) < 0.3 || std::abs((refOutput[i] - ourOutput[i]) / refOutput[i]) < 0.3,
-                true);
+        EXPECT_EQ(std::abs(refOutput[i] - ourOutput[i]) > 0.01
+                        && std::abs((refOutput[i] - ourOutput[i]) / refOutput[i]) > 0.01,
+                false);
     }
 
     free(input);
@@ -119,18 +121,25 @@ TEST(MLPLLaMA, bfloat16_t) {
     float *downW = (float *)aligned_alloc(64, intermediateSize * hiddenSize * sizeof(float));
 
     for (int i = 0; i < hiddenSize * intermediateSize; ++i) {
-        gateW[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
-        upW[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
-        downW[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
+        gateW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        upW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        downW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
     }
 
     compareMLPLLaMA<bfloat16_t>(18, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(10, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(4, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(2, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(1, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(2, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(4, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(6, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
 
     free(gateW);
     free(upW);
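
The updated assertion in compareMLPLLaMA flags an element only when both the absolute error and the relative error exceed 0.01, i.e. an element passes when either error is within tolerance. A small C++ helper expressing the same rule could look like the sketch below; isClose is a hypothetical name used for illustration and is not part of the test file.

    #include <cmath>

    // Accept an element when either the absolute error or the relative error
    // (measured against the reference value) is within the tolerance.
    static bool isClose(float ref, float out, float tol = 0.01f) {
        float absErr = std::abs(ref - out);
        return absErr <= tol || (ref != 0.0f && absErr / std::abs(ref) <= tol);
    }

    // Usage inside the comparison loop: EXPECT_TRUE(isClose(refOutput[i], ourOutput[i]));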
