From 516986c58fe6b1b5b328a37784827cc6a8fd69da Mon Sep 17 00:00:00 2001
From: changqi1
Date: Sat, 25 Nov 2023 12:01:23 +0800
Subject: [PATCH] Optimize context perf.

---
 src/layers/mlp_llama.cpp     | 29 +++++++++++++++--------------
 tests/ut/layers_mlp_test.cpp | 21 +++++++++++++++------
 2 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/src/layers/mlp_llama.cpp b/src/layers/mlp_llama.cpp
index 33d0397a..a3b4af4d 100644
--- a/src/layers/mlp_llama.cpp
+++ b/src/layers/mlp_llama.cpp
@@ -22,38 +22,39 @@ namespace xft {
 void invokeMLPLLaMA(DataType dt, int numTokens, int hiddenSize, int intermediateSize, void *output, int outputStride,
         const void *input, int inputStride, const void *gateWeight, const void *upWeight, const void *downWeight) {
     static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
 
     if (dt == DataType::bf16) {
-        static std::unordered_map<std::string, std::tuple<DecoderContext *, LlamaMLP<bfloat16_t> *>> llama_mlp_hub;
+        static std::unordered_map<std::string, LlamaMLP<bfloat16_t> *> llama_mlp_hub;
 
-        // create hash key
+        static DecoderContext *ctx;
+        if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
+            delete ctx;
+            printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
+            ctx = new DecoderContext(1, hiddenSize, 1, 1, intermediateSize, "silu", 1e-6, 0, 0, 0, 0, 0, 1);
+        }
+
+        // Create the hash key from the weight addresses: if hiddenSize or intermediateSize changes, the weight memory pointers change as well.
         std::stringstream weights_addr;
         weights_addr << gateWeight << "_" << upWeight << "_" << downWeight;
-        std::string llama_mlp_key
-                = std::to_string(hiddenSize) + "_" + std::to_string(intermediateSize) + "_" + weights_addr.str();
-
-        DecoderContext *ctx;
+        std::string llama_mlp_key = weights_addr.str();
         LlamaMLP<bfloat16_t> *llama_mlp;
 
         auto it_created = llama_mlp_hub.find(llama_mlp_key);
         if (it_created == llama_mlp_hub.end()) {
             // LlamaMLP<bfloat16_t> &llama_mlp = LlamaMLP<bfloat16_t>::getInstance();
-            ctx = new DecoderContext(1, hiddenSize, 1, 1, intermediateSize, "silu", 1e-6, 0, 0, 0, 0, 0, 1);
             std::vector<float *> params {(float *)gateWeight, (float *)nullptr, (float *)upWeight, (float *)nullptr,
                     (float *)nullptr, (float *)nullptr, (float *)downWeight};
             llama_mlp = new LlamaMLP<bfloat16_t>;
             llama_mlp->setWeights(ctx, params, false);
-
-            std::tuple<DecoderContext *, LlamaMLP<bfloat16_t> *> value(ctx, llama_mlp);
-            llama_mlp_hub[llama_mlp_key] = value;
-            printf("create llama_mlp_key: %s\n", llama_mlp_key.c_str());
+            llama_mlp_hub[llama_mlp_key] = llama_mlp;
+            printf(">> create llama_mlp_key: %s\n", llama_mlp_key.c_str());
         } else {
-            ctx = std::get<0>(it_created->second);
-            llama_mlp = std::get<1>(it_created->second);
+            llama_mlp = it_created->second;
         }
 
+        // Serving models of different types simultaneously is not supported because they share the same DecoderContext.
+        std::lock_guard<std::mutex> lock(mutex);
         ctx->resize(1, numTokens, 0);
         llama_mlp->forward(ctx, (float *)const_cast<void *>(input), (float *)output, inputStride, outputStride, false);
     }
diff --git a/tests/ut/layers_mlp_test.cpp b/tests/ut/layers_mlp_test.cpp
index 741df17b..66868c13 100644
--- a/tests/ut/layers_mlp_test.cpp
+++ b/tests/ut/layers_mlp_test.cpp
@@ -26,6 +26,7 @@ static void matmul(int m, int n, int k, const float *A, const float *B, float *C
 #pragma omp parallel for collapse(2)
     for (int i = 0; i < m; ++i) {
         for (int j = 0; j < n; ++j) {
+            C[i * n + j] = 0.0f;
             for (int q = 0; q < k; ++q) {
                 C[i * n + j] += static_cast<float>(static_cast<bfloat16_t>(A[i * k + q]))
                         * static_cast<float>(static_cast<bfloat16_t>(B[q * n + j]));
@@ -85,7 +86,7 @@ static void compareMLPLLaMA(
     memset(refOutput, 0, numTokens * hiddenSize * sizeof(float));
 
     for (int i = 0; i < numTokens * hiddenSize; ++i) {
-        input[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
+        input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
     }
 
     if constexpr (std::is_same<T, bfloat16_t>::value) {
@@ -101,8 +102,9 @@ static void compareMLPLLaMA(
     }
 
     for (int i = 0; i < numTokens * hiddenSize; ++i) {
-        EXPECT_EQ(std::abs(refOutput[i] - ourOutput[i]) < 0.3 || std::abs((refOutput[i] - ourOutput[i]) / refOutput[i]) < 0.3,
-                true);
+        EXPECT_EQ(std::abs(refOutput[i] - ourOutput[i]) > 0.01
+                        && std::abs((refOutput[i] - ourOutput[i]) / refOutput[i]) > 0.01,
+                false);
     }
 
     free(input);
@@ -119,9 +121,9 @@ TEST(MLPLLaMA, bfloat16_t) {
     float *downW = (float *)aligned_alloc(64, intermediateSize * hiddenSize * sizeof(float));
 
     for (int i = 0; i < hiddenSize * intermediateSize; ++i) {
-        gateW[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
-        upW[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
-        downW[i] = static_cast<float>(1.0f * rand() / RAND_MAX - 0.5f);
+        gateW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        upW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
+        downW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
     }
 
     compareMLPLLaMA<bfloat16_t>(18, hiddenSize, intermediateSize, gateW, upW, downW);
@@ -129,8 +131,15 @@ TEST(MLPLLaMA, bfloat16_t) {
     compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(10, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(4, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(2, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(1, hiddenSize, intermediateSize, gateW, upW, downW);
     compareMLPLLaMA<bfloat16_t>(2, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(4, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(6, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
+    compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
 
     free(gateW);
     free(upW);
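
Reviewer note: the change above replaces the per-call cache key built from (hiddenSize, intermediateSize, weight addresses) with a key built only from the weight addresses, keeps a single static DecoderContext that is recreated when the sizes change, and narrows the mutex to the section that mutates that shared context and runs the forward pass. The sketch below is a minimal, self-contained illustration of that caching idiom, not code from the patch; the names Context, Layer and getOrCreateLayer are hypothetical stand-ins for DecoderContext and LlamaMLP<bfloat16_t>, and the actual math is omitted.

// Minimal sketch of the caching pattern used in invokeMLPLLaMA after this patch.
// Context, Layer and getOrCreateLayer are hypothetical stand-ins, not the xFT API.
#include <cstdio>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>

struct Context {
    int hiddenSize;
    int intermediateSize;
    int numTokens = 0;
    void resize(int tokens) { numTokens = tokens; }
};

struct Layer {
    const void *gate, *up, *down; // the real layer would hold packed bf16 weights
};

// Return a cached Layer for these weights, (re)creating the shared Context when
// the model shape changes -- the same flow as the patched invokeMLPLLaMA.
static Layer *getOrCreateLayer(Context *&ctx, int hiddenSize, int intermediateSize,
        const void *gateW, const void *upW, const void *downW) {
    static std::unordered_map<std::string, Layer *> hub;

    if (ctx == nullptr || ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize) {
        delete ctx; // deleting nullptr is a no-op
        ctx = new Context {hiddenSize, intermediateSize};
    }

    // Key by weight addresses only: a new shape implies newly allocated weights,
    // so the pointers (and therefore the key) change together with the shape.
    std::stringstream key;
    key << gateW << "_" << upW << "_" << downW;

    auto it = hub.find(key.str());
    if (it != hub.end()) return it->second;

    Layer *layer = new Layer {gateW, upW, downW};
    hub[key.str()] = layer;
    std::printf(">> created layer for key %s\n", key.str().c_str());
    return layer;
}

int main() {
    static std::mutex mutex;
    static Context *ctx = nullptr;

    float gate[8], up[8], down[8];
    Layer *layer = getOrCreateLayer(ctx, 4, 8, gate, up, down);

    // Only the part that mutates the shared context (and would run the kernel)
    // holds the lock, mirroring where the patch moves the lock_guard.
    {
        std::lock_guard<std::mutex> lock(mutex);
        ctx->resize(/*tokens=*/16);
        (void)layer; // layer->forward(ctx, ...) would go here in the real code
    }
    return 0;
}

As the added comment in the patch notes, the price of sharing one DecoderContext is that models with different hidden/intermediate sizes cannot be served at the same time: each shape change deletes and recreates the context that every cached layer uses.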