diff --git a/.gitignore b/.gitignore index 37f80372..0345144e 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ dist/ /3rdparty/onednn /3rdparty/cmdline /3rdparty/sentencepiece -/3rdparty/ig \ No newline at end of file +/3rdparty/xdnn diff --git a/CMakeLists.txt b/CMakeLists.txt index 1fc30c67..f16a97c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,14 +56,14 @@ endif() include("cmake/mklml.cmake") include("cmake/onednn.cmake") -include("cmake/ig.cmake") +include("cmake/xdnn.cmake") include(GNUInstallDirs) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/mklml/include) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/include) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/include) -include_directories(${CMAKE_SOURCE_DIR}/3rdparty/ig) +include_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn) include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}/src/kernels) include_directories(${CMAKE_SOURCE_DIR}/src/layers) @@ -75,18 +75,18 @@ include_directories(${CMAKE_SOURCE_DIR}/src/common) link_directories(${CMAKE_SOURCE_DIR}/src/kernels) link_directories(${CMAKE_SOURCE_DIR}/3rdparty/mklml/lib) link_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/src) -link_directories(${CMAKE_SOURCE_DIR}/3rdparty/ig) +link_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn) set(3RDPART_LIB_LIST "MPI::MPI_CXX" "ccl" "dnnl" "numa") -set(DEPEND_LIST "onednn" "mklml" "ig_lib") +set(DEPEND_LIST "onednn" "mklml" "xdnn_lib") option(BUILD_WITH_SHARED_LIBS "Build with shared libraries" OFF) if(BUILD_WITH_SHARED_LIBS) message("Building with shared libraries.") - list(APPEND 3RDPART_LIB_LIST "ig") + list(APPEND 3RDPART_LIB_LIST "xdnn") else() message("Building with static libraries.") - list(APPEND 3RDPART_LIB_LIST "ig_static") + list(APPEND 3RDPART_LIB_LIST "xdnn_static") endif() # Enable AVX512_FP16 optimization diff --git a/cmake/ig.cmake b/cmake/xdnn.cmake similarity index 85% rename from cmake/ig.cmake rename to cmake/xdnn.cmake index 2f88bd43..a8f28149 100644 --- a/cmake/ig.cmake +++ b/cmake/xdnn.cmake @@ -25,11 +25,11 @@ project(dependency NONE) include(ExternalProject) # cmake-format: off -ExternalProject_Add(ig_lib - URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/ig_v1.1.tar.gz - URL_HASH MD5=47e5a2cd021caad2b1367c0b71dff2e7 +ExternalProject_Add(xdnn_lib + URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.1.tar.gz + URL_HASH MD5=b49bf8808d66ea75cfba80a406c9a587 TIMEOUT 60 - SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/ig + SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/xdnn CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" diff --git a/include/layers_norm.h b/include/layers_norm.h index c0a92c45..75d1778e 100644 --- a/include/layers_norm.h +++ b/include/layers_norm.h @@ -14,13 +14,15 @@ // ============================================================================ #pragma once +#include "dtype.h" + namespace xft { -void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows, +void invokeLayerNorm(DataType dt, void *output, const void *input, const void *gamma, const void *beta, const int rows, const int size, int iStride = -1, int oStride = -1, const float epsilon = 1e-5); -void invokeRmsNorm(float *output, const float *input, const float *weight, int rows, int cols, int iStride = -1, - int oStride = -1, float epsilon = 1e-6); +void invokeRmsNorm(DataType dt, void *output, const 
void *input, const void *weight, int rows, int cols, + int iStride = -1, int oStride = -1, float epsilon = 1e-6); // Layer normalization: only support the norm along last dimension class LayerNorm { diff --git a/src/kernels/layernorm_kernels.cpp b/src/kernels/layernorm_kernels.cpp index e5600de3..630419ca 100644 --- a/src/kernels/layernorm_kernels.cpp +++ b/src/kernels/layernorm_kernels.cpp @@ -15,17 +15,14 @@ #include #include "bfloat16.h" +#include "dtype.h" #include "float16.h" +#include "intrinsic_ext.h" +#include "layernorm_kernels.h" #include "my_types.h" namespace xft { -template -struct LayerNormWeight { - const T *gamma = nullptr; - const T *beta = nullptr; -}; - void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows, const int size, int iStride, int oStride, const float epsilon) { @@ -79,4 +76,67 @@ void invokeLayerNorm(float *output, const float *input, const float *gamma, cons } } } + +void invokeLayerNorm(bfloat16_t *output, const bfloat16_t *input, const bfloat16_t *gamma, const bfloat16_t *beta, + const int rows, const int size, int iStride, int oStride, const float epsilon) { + + if (iStride == -1) iStride = size; + if (oStride == -1) oStride = size; + +#pragma omp parallel for + for (int r = 0; r < rows; ++r) { + const bfloat16_t *px = input + r * iStride; + bfloat16_t *py = output + r * oStride; + + float sum = 0; + float squareSum = 0; + + __m512 vsum = _mm512_set1_ps(0); + __m512 vsqare = _mm512_set1_ps(0); + + for (int col = 0; col < size; col += 16) { + int remain = size - col; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + + // SUM(x) + __m512 vx = _mm512_maskz_loadu_pbh(mask, px + col); + vsum = _mm512_add_ps(vsum, vx); + + // SUM(x*x) + __m512 tmp = _mm512_mul_ps(vx, vx); + vsqare = _mm512_add_ps(vsqare, tmp); + } + + sum = _mm512_reduce_add_ps(vsum); + squareSum = _mm512_reduce_add_ps(vsqare); + + // Mean + float mean = sum / size; + __m512 vmean = _mm512_set1_ps(mean); + + // Variance + float var = 1 / sqrt(squareSum / size - mean * mean + epsilon); + __m512 vvar = _mm512_set1_ps(var); + + for (int col = 0; col < size; col += 16) { + int remain = size - col; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + + __m512 vx = _mm512_maskz_loadu_pbh(mask, px + col); + __m512 vgamma = _mm512_maskz_loadu_pbh(mask, gamma + col); + __m512 vbeta = _mm512_maskz_loadu_pbh(mask, beta + col); + __m512 vy = (vx - vmean) * vgamma * vvar + vbeta; + _mm512_mask_storeu_pbh(py + col, mask, vy); + } + } +} + +void invokeLayerNorm(DataType dt, void *output, const void *input, const void *gamma, const void *beta, const int rows, + const int size, int iStride, int oStride, const float epsilon) { + if (dt == DataType::bf16) { + invokeLayerNorm((bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)gamma, + (const bfloat16_t *)beta, rows, size, iStride, oStride, epsilon); + } +} + } // namespace xft \ No newline at end of file diff --git a/src/kernels/layernorm_kernels.h b/src/kernels/layernorm_kernels.h new file mode 100644 index 00000000..a6a03aee --- /dev/null +++ b/src/kernels/layernorm_kernels.h @@ -0,0 +1,38 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#pragma once + +#include + +#include "bfloat16.h" +#include "dtype.h" +#include "float16.h" +#include "my_types.h" + +namespace xft { + +template +struct LayerNormWeight { + const T *gamma = nullptr; + const T *beta = nullptr; +}; + +void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows, + const int size, int iStride = -1, int oStride = -1, const float epsilon = 1e-5); + +void invokeLayerNorm(bfloat16_t *output, const bfloat16_t *input, const bfloat16_t *gamma, const bfloat16_t *beta, + const int rows, const int size, int iStride = -1, int oStride = -1, const float epsilon = 1e-5); + +} // namespace xft \ No newline at end of file diff --git a/src/kernels/rmsnorm_kernels.cpp b/src/kernels/rmsnorm_kernels.cpp index aa8bca0a..25cd9913 100644 --- a/src/kernels/rmsnorm_kernels.cpp +++ b/src/kernels/rmsnorm_kernels.cpp @@ -15,8 +15,11 @@ #include #include "bfloat16.h" +#include "dtype.h" #include "float16.h" +#include "intrinsic_ext.h" #include "my_types.h" +#include "rmsnorm_kernels.h" namespace xft { @@ -71,4 +74,65 @@ void invokeRmsNorm(float *output, const float *input, const float *weight, int r } } // end for rows } + +void invokeRmsNorm(bfloat16_t *output, const bfloat16_t *input, const bfloat16_t *weight, int rows, int cols, + int iStride, int oStride, float epsilon) { + int size = cols; + + if (iStride == -1) iStride = cols; + if (oStride == -1) oStride = cols; + +#pragma omp parallel for + for (int r = 0; r < rows; ++r) { + const bfloat16_t *px = input + r * iStride; + bfloat16_t *py = output + r * oStride; + + float squareSum = 0; + + __m512 vsqare = _mm512_set1_ps(0); + + int col = 0; + for (; col + 15 < size; col += 16) { + // SUM(x*x) + __m512 vx = _mm512_loadu_pbh(px + col); + __m512 tmp = _mm512_mul_ps(vx, vx); + vsqare = _mm512_add_ps(vsqare, tmp); + } + if (col < size) { + __mmask16 mask = (1 << (size - col)) - 1; + __m512 vx = _mm512_maskz_loadu_pbh(mask, px + col); + __m512 tmp = _mm512_mul_ps(vx, vx); + vsqare = _mm512_add_ps(vsqare, tmp); + } + + squareSum = _mm512_reduce_add_ps(vsqare); + + // Variance + float var = 1 / sqrt(squareSum / size + epsilon); + __m512 vvar = _mm512_set1_ps(var); + + for (col = 0; col + 15 < size; col += 16) { + __m512 vx = _mm512_loadu_pbh(px + col); + __m512 vw = _mm512_loadu_pbh(weight + col); + __m512 vy = vx * vvar * vw; + _mm512_storeu_pbh(py + col, vy); + } + if (col < size) { + __mmask16 mask = (1 << (size - col)) - 1; + __m512 vx = _mm512_maskz_loadu_pbh(mask, px + col); + __m512 vw = _mm512_maskz_loadu_pbh(mask, weight + col); + __m512 vy = vx * vvar * vw; + _mm512_mask_storeu_pbh(py + col, mask, vy); + } + } // end for rows +} + +void invokeRmsNorm(DataType dt, void *output, const void *input, const void *weight, int rows, int cols, int iStride, + int oStride, float epsilon) { + if (dt == DataType::bf16) { + invokeRmsNorm((bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)weight, rows, cols, iStride, + oStride, epsilon); + } +} + } // namespace xft \ No newline at 
end of file diff --git a/src/kernels/rmsnorm_kernels.h b/src/kernels/rmsnorm_kernels.h new file mode 100644 index 00000000..51869a54 --- /dev/null +++ b/src/kernels/rmsnorm_kernels.h @@ -0,0 +1,31 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ +#pragma once + +#include + +#include "bfloat16.h" +#include "float16.h" +#include "my_types.h" + +namespace xft { + +void invokeRmsNorm(float *output, const float *input, const float *weight, int rows, int cols, int iStride = -1, + int oStride = -1, float epsilon = 1e-6); + +void invokeRmsNorm(bfloat16_t *output, const bfloat16_t *input, const bfloat16_t *weight, int rows, int cols, + int iStride = -1, int oStride = -1, float epsilon = 1e-6); + +} // namespace xft \ No newline at end of file diff --git a/src/layers/attention.h b/src/layers/attention.h index 1c352a08..7651a692 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -496,9 +496,9 @@ class Attention { C = result.Row(b * ctx->inputSeqLen + startSeq) + i * ctx->attHeadSize; if constexpr (std::is_same_v) { - ig_sgemm_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc); + xdnn_sgemm_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc); } else if constexpr (std::is_same_v) { - ig_sgemm_f32f16f32_single_thread(false, false, m, n, k, 1.0f, A, lda, B, ldb, 0.0f, C, ldc); + xdnn_sgemm_f32f16f32_single_thread(false, false, m, n, k, 1.0f, A, lda, (const XDNN_FP16 *)B, ldb, 0.0f, C, ldc); } #ifdef DEBUG diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index 9d3be033..26551c3a 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================ -#pragma once +#include + #include #include -#include +#include "layernorm_kernels.h" #include "layers_norm.h" #include "timeline.h" @@ -45,7 +46,7 @@ void LayerNorm::forward(const float *input, float *output, int rows, int iStride TimeLine t("LayerNorm.forward"); const float *pgamma = weights; const float *pbeta = weights + normSize; - xft::invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride); + invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride); } } // namespace xft \ No newline at end of file diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index f035a60e..9acca4d0 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================ -#pragma once #include #include #include #include "layers_norm.h" +#include "rmsnorm_kernels.h" #include "timeline.h" namespace xft { @@ -41,7 +41,7 @@ void RmsNorm::setWeight(const float *w, const float *, int size) { // input and output are in shape of (rows, normSize) void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) { TimeLine t("RmsNorm.forward"); - xft::invokeRmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); + invokeRmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon); } } // namespace xft \ No newline at end of file diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 6cc7901c..be6efcb6 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -542,7 +542,7 @@ class DecoderUtil { // C = A * B // bTranspose: B need to be transposed or not - // ig_sgemm_single_thread(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + // xdnn_sgemm_single_thread(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); static void sgemm(const float* A, const float* B, float* C, int m, int n, int k, bool transa, bool transb) { int lda = (transa ? m : k); diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index c51e6b40..7843382c 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -15,13 +15,7 @@ #pragma once #include "bfloat16.h" #include "float16.h" -#include "ig_bgemm_f32bf16f32.h" -#include "ig_hgemm_f32f16f32.h" -#include "ig_hgemm_f16f16f32.h" -#include "ig_hgemm_f32i8f32.h" -#include "ig_sgemm.h" -#include "ig_sgemm_f32f16f32.h" -#include "ig_sgemm_f32i8f32.h" +#include "xdnn.h" #include "my_types.h" #include "oneapi/dnnl/dnnl.hpp" #include "oneapi/dnnl/dnnl_config.h" @@ -216,12 +210,12 @@ class MMHelper { zeroWeight.Resize(colsPerSplit); } #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - ig_sgemm_f32i8f32_quantize(trans, colsPerSplit, rows, + xdnn_sgemm_f32i8f32_quantize(trans, colsPerSplit, rows, trans ? (src + rows * splitOffset) : (src + splitOffset), trans ? rows : cols, 0.9999f, quantizedWeight.Data(), trans ? rows : colsPerSplit, scaleWeight.Data(), zeroWeight.Data()); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - ig_hgemm_f32i8f32_quantize(trans, colsPerSplit, rows, + xdnn_hgemm_f32i8f32_quantize(trans, colsPerSplit, rows, trans ? (src + rows * splitOffset) : (src + splitOffset), trans ? rows : cols, 0.9999f, quantizedWeight.Data(), trans ? rows : colsPerSplit, scaleWeight.Data(), zeroWeight.Data()); @@ -241,12 +235,12 @@ class MMHelper { zeroWeight.Resize(cols); } #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - ig_sgemm_f32i8f32_quantize(trans, cols, rowsPerSplit, + xdnn_sgemm_f32i8f32_quantize(trans, cols, rowsPerSplit, trans ? (src + splitOffset) : (src + splitOffset * cols), trans ? rows : cols, 0.9999f, quantizedWeight.Data(), trans ? rowsPerSplit : cols, scaleWeight.Data(), zeroWeight.Data()); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - ig_hgemm_f32i8f32_quantize(trans, cols, rowsPerSplit, + xdnn_hgemm_f32i8f32_quantize(trans, cols, rowsPerSplit, trans ? (src + splitOffset) : (src + splitOffset * cols), trans ? rows : cols, 0.9999f, quantizedWeight.Data(), trans ? 
rowsPerSplit : cols, scaleWeight.Data(), zeroWeight.Data()); @@ -293,16 +287,16 @@ class MMHelper { // FP32 if constexpr (std::is_same_v) { weight.Resize(K, N); - ig_sgemm_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); + xdnn_sgemm_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); } // FP16 else if constexpr (std::is_same_v) { weight.Resize(K, N); #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - ig_sgemm_f32f16f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); + xdnn_sgemm_f32f16f32_packb(trans, N, K, (const XDNN_FP16 *)src.Data(), src.Stride(), (XDNN_FP16 *)weight.Data()); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - ig_hgemm_f32f16f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); + xdnn_hgemm_f32f16f32_packb(trans, N, K, (const XDNN_FP16 *)src.Data(), src.Stride(), (XDNN_FP16 *)weight.Data()); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -317,9 +311,9 @@ class MMHelper { weight.Resize(amx_rows, amx_cols); memset(weight.Data(), 0, amx_rows * amx_cols * sizeof(bfloat16_t)); #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - ig_sgemm_f32bf16f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data(), 16, 64); + xdnn_sgemm_f32bf16f32_packb(trans, N, K, (const XDNN_BF16 *)src.Data(), src.Stride(), (XDNN_BF16 *)weight.Data(), 16, 64); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) - ig_bgemm_f32bf16f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data(), 16, 64); + xdnn_bgemm_f32bf16f32_packb(trans, N, K, (const XDNN_BF16 *)src.Data(), src.Stride(), (XDNN_BF16 *)weight.Data(), 16, 64); #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -330,9 +324,9 @@ class MMHelper { else if constexpr (std::is_same_v) { weight.Resize(K, N); #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - ig_sgemm_f32i8f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); + xdnn_sgemm_f32i8f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - ig_hgemm_f32i8f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); + xdnn_hgemm_f32i8f32_packb(trans, N, K, src.Data(), src.Stride(), weight.Data()); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -345,18 +339,18 @@ class MMHelper { const float *scaleB, const float *zeroB, float beta, float *C, int ldc) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute"); - ig_sgemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_sgemm_compute"); + xdnn_sgemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute"); - ig_sgemm_f32f16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_sgemm_f32f16f32_compute"); + xdnn_sgemm_f32f16f32_compute(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute"); - ig_hgemm_f32f16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_hgemm_f32f16f32_compute"); + xdnn_hgemm_f32f16f32_compute(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -366,15 +360,15 @@ class MMHelper { // BF16 else if 
constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute"); - ig_sgemm_f32bf16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_sgemm_f32bf16f32_compute"); + xdnn_sgemm_f32bf16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute"); - ig_amx_sgemm_f32bf16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute"); + onednn_amx_sgemm_f32bf16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); } else { - TimeLine t("ig_bgemm_f32bf16f32_compute"); - ig_bgemm_f32bf16f32_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_bgemm_f32bf16f32_compute"); + xdnn_bgemm_f32bf16f32_compute(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -385,11 +379,11 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute"); - ig_sgemm_f32i8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); + TimeLine t("xdnn_sgemm_f32i8f32_compute"); + xdnn_sgemm_f32i8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute"); - ig_hgemm_f32i8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); + TimeLine t("xdnn_hgemm_f32i8f32_compute"); + xdnn_hgemm_f32i8f32_compute(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -403,18 +397,18 @@ class MMHelper { const float *bias) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute_biasadd"); - ig_sgemm_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_sgemm_compute_biasadd"); + xdnn_sgemm_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute_biasadd"); - ig_sgemm_f32f16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_sgemm_f32f16f32_compute_biasadd"); + xdnn_sgemm_f32f16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute_biasadd"); - ig_hgemm_f32f16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_hgemm_f32f16f32_compute_biasadd"); + xdnn_hgemm_f32f16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -424,15 +418,15 @@ class MMHelper { // BF16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute_biasadd"); - ig_sgemm_f32bf16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_sgemm_f32bf16f32_compute_biasadd"); + xdnn_sgemm_f32bf16f32_compute_biasadd(transA, M, N, K, 
alpha, A, lda, packedB, beta, C, ldc, bias); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_biasadd"); - ig_amx_sgemm_f32bf16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_biasadd"); + onednn_amx_sgemm_f32bf16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); } else { - TimeLine t("ig_bgemm_f32bf16f32_compute_biasadd"); - ig_bgemm_f32bf16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_bgemm_f32bf16f32_compute_biasadd"); + xdnn_bgemm_f32bf16f32_compute_biasadd(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -443,12 +437,12 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute_biasadd"); - ig_sgemm_f32i8f32_compute_biasadd( + TimeLine t("xdnn_sgemm_f32i8f32_compute_biasadd"); + xdnn_sgemm_f32i8f32_compute_biasadd( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute_biasadd"); - ig_hgemm_f32i8f32_compute_biasadd( + TimeLine t("xdnn_hgemm_f32i8f32_compute_biasadd"); + xdnn_hgemm_f32i8f32_compute_biasadd( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); @@ -463,18 +457,18 @@ class MMHelper { const float *bias) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute_biasadd_relu"); - ig_sgemm_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_sgemm_compute_biasadd_relu"); + xdnn_sgemm_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute_biasadd_relu"); - ig_sgemm_f32f16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_sgemm_f32f16f32_compute_biasadd_relu"); + xdnn_sgemm_f32f16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute_biasadd_relu"); - ig_hgemm_f32f16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_hgemm_f32f16f32_compute_biasadd_relu"); + xdnn_hgemm_f32f16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -484,16 +478,16 @@ class MMHelper { // BF16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute_biasadd_relu"); - ig_sgemm_f32bf16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_sgemm_f32bf16f32_compute_biasadd_relu"); + xdnn_sgemm_f32bf16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_biasadd_relu"); - 
ig_amx_sgemm_f32bf16f32_compute_biasadd_relu( + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu"); + onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu( transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); } else { - TimeLine t("ig_bgemm_f32bf16f32_compute_biasadd_relu"); - ig_bgemm_f32bf16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias); + TimeLine t("xdnn_bgemm_f32bf16f32_compute_biasadd_relu"); + xdnn_bgemm_f32bf16f32_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -504,12 +498,12 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute_biasadd_relu"); - ig_sgemm_f32i8f32_compute_biasadd_relu( + TimeLine t("xdnn_sgemm_f32i8f32_compute_biasadd_relu"); + xdnn_sgemm_f32i8f32_compute_biasadd_relu( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute_biasadd_relu"); - ig_hgemm_f32i8f32_compute_biasadd_relu( + TimeLine t("xdnn_hgemm_f32i8f32_compute_biasadd_relu"); + xdnn_hgemm_f32i8f32_compute_biasadd_relu( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); @@ -523,18 +517,18 @@ class MMHelper { const WeiT *packedB, const float *scaleB, const float *zeroB, float beta, float *C, int ldc) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute_silu"); - ig_sgemm_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_sgemm_compute_silu"); + xdnn_sgemm_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute_silu"); - ig_sgemm_f32f16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_sgemm_f32f16f32_compute_silu"); + xdnn_sgemm_f32f16f32_compute_silu(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute_silu"); - ig_hgemm_f32f16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_hgemm_f32f16f32_compute_silu"); + xdnn_hgemm_f32f16f32_compute_silu(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -544,15 +538,15 @@ class MMHelper { // BF16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute_silu"); - ig_sgemm_f32bf16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_sgemm_f32bf16f32_compute_silu"); + xdnn_sgemm_f32bf16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_silu"); - ig_amx_sgemm_f32bf16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_silu"); + onednn_amx_sgemm_f32bf16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); } else { - TimeLine 
t("ig_bgemm_f32bf16f32_compute_silu"); - ig_bgemm_f32bf16f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc); + TimeLine t("xdnn_bgemm_f32bf16f32_compute_silu"); + xdnn_bgemm_f32bf16f32_compute_silu(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -563,11 +557,11 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute_silu"); - ig_sgemm_f32i8f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); + TimeLine t("xdnn_sgemm_f32i8f32_compute_silu"); + xdnn_sgemm_f32i8f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute_silu"); - ig_hgemm_f32i8f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); + TimeLine t("xdnn_hgemm_f32i8f32_compute_silu"); + xdnn_hgemm_f32i8f32_compute_silu(transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -581,18 +575,18 @@ class MMHelper { const float *res, int ldres) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute_resmul"); - ig_sgemm_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); + TimeLine t("xdnn_sgemm_compute_resmul"); + xdnn_sgemm_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute_resmul"); - ig_sgemm_f32f16f32_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); + TimeLine t("xdnn_sgemm_f32f16f32_compute_resmul"); + xdnn_sgemm_f32f16f32_compute_resmul(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, res, ldres); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute_resmul"); - ig_hgemm_f32f16f32_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); + TimeLine t("xdnn_hgemm_f32f16f32_compute_resmul"); + xdnn_hgemm_f32f16f32_compute_resmul(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, res, ldres); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -602,16 +596,16 @@ class MMHelper { // BF16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute_resmul"); - ig_sgemm_f32bf16f32_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); + TimeLine t("xdnn_sgemm_f32bf16f32_compute_resmul"); + xdnn_sgemm_f32bf16f32_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_resmul"); - ig_amx_sgemm_f32bf16f32_compute_resmul( + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_resmul"); + onednn_amx_sgemm_f32bf16f32_compute_resmul( transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); } else { - TimeLine t("ig_bgemm_f32bf16f32_compute_resmul"); - ig_bgemm_f32bf16f32_compute_resmul(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, res, ldres); + TimeLine 
t("xdnn_bgemm_f32bf16f32_compute_resmul"); + xdnn_bgemm_f32bf16f32_compute_resmul(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, res, ldres); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -622,12 +616,12 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute_resmul"); - ig_sgemm_f32i8f32_compute_resmul( + TimeLine t("xdnn_sgemm_f32i8f32_compute_resmul"); + xdnn_sgemm_f32i8f32_compute_resmul( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, res, ldres); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute_resmul"); - ig_hgemm_f32i8f32_compute_resmul( + TimeLine t("xdnn_hgemm_f32i8f32_compute_resmul"); + xdnn_hgemm_f32i8f32_compute_resmul( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, res, ldres); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); @@ -642,20 +636,20 @@ class MMHelper { const float *bias, const float *res, int ldres) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute_residential"); - ig_sgemm_compute_residential(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); + TimeLine t("xdnn_sgemm_compute_residential"); + xdnn_sgemm_compute_residential(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute_residential"); - ig_sgemm_f32f16f32_compute_residential( - transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); + TimeLine t("xdnn_sgemm_f32f16f32_compute_residential"); + xdnn_sgemm_f32f16f32_compute_residential( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias, res, ldres); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute_residential"); - ig_hgemm_f32f16f32_compute_residential( - transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); + TimeLine t("xdnn_hgemm_f32f16f32_compute_residential"); + xdnn_hgemm_f32f16f32_compute_residential( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias, res, ldres); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -665,18 +659,18 @@ class MMHelper { // BF16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute_residential"); - ig_sgemm_f32bf16f32_compute_residential( + TimeLine t("xdnn_sgemm_f32bf16f32_compute_residential"); + xdnn_sgemm_f32bf16f32_compute_residential( transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_residential"); - ig_amx_sgemm_f32bf16f32_compute_residential( + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_residential"); + onednn_amx_sgemm_f32bf16f32_compute_residential( transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); } else { - TimeLine t("ig_bgemm_f32bf16f32_compute_residential"); - ig_bgemm_f32bf16f32_compute_residential( - transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); + TimeLine t("xdnn_bgemm_f32bf16f32_compute_residential"); + xdnn_bgemm_f32bf16f32_compute_residential( + transA, M, N, K, alpha, A, lda, (const 
XDNN_BF16 *)packedB, beta, C, ldc, bias, res, ldres); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -687,12 +681,12 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute_residential"); - ig_sgemm_f32i8f32_compute_residential( + TimeLine t("xdnn_sgemm_f32i8f32_compute_residential"); + xdnn_sgemm_f32i8f32_compute_residential( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, res, ldres); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute_residential"); - ig_hgemm_f32i8f32_compute_residential( + TimeLine t("xdnn_hgemm_f32i8f32_compute_residential"); + xdnn_hgemm_f32i8f32_compute_residential( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, res, ldres); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); @@ -707,20 +701,20 @@ class MMHelper { const float *bias, float gamma, float *res, int ldres) { // FP32 if constexpr (std::is_same_v) { - TimeLine t("ig_sgemm_compute_resext"); - ig_sgemm_compute_resext(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, gamma, res, ldres); + TimeLine t("xdnn_sgemm_compute_resext"); + xdnn_sgemm_compute_resext(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, gamma, res, ldres); } // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - TimeLine t("ig_sgemm_f32f16f32_compute_resext"); - ig_sgemm_f32f16f32_compute_resext( - transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, gamma, res, ldres); + TimeLine t("xdnn_sgemm_f32f16f32_compute_resext"); + xdnn_sgemm_f32f16f32_compute_resext( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias, gamma, res, ldres); #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) - TimeLine t("ig_hgemm_f32f16f32_compute_resext"); - ig_hgemm_f32f16f32_compute_resext( - transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, gamma, res, ldres); + TimeLine t("xdnn_hgemm_f32f16f32_compute_resext"); + xdnn_hgemm_f32f16f32_compute_resext( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias, gamma, res, ldres); #else printf("%s:%d: Need to define WEIGHT_ONLY_FP16 kernel data type.\n", __FILE__, __LINE__); exit(-1); @@ -730,24 +724,24 @@ class MMHelper { // BF16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 - TimeLine t("ig_sgemm_f32bf16f32_compute_resext"); - ig_sgemm_f32bf16f32_compute_resext( + TimeLine t("xdnn_sgemm_f32bf16f32_compute_resext"); + xdnn_sgemm_f32bf16f32_compute_resext( transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, gamma, res, ldres); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > USE_AMX_M) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_residential"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_residential"); #pragma omp parallel for collapse(2) for (int i = 0; i < M; ++i) { for (int j = 0; j < N; ++j) { res[i * ldres + j] = res[i * ldres + j] * gamma; } } - ig_amx_sgemm_f32bf16f32_compute_residential( + onednn_amx_sgemm_f32bf16f32_compute_residential( transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres); } else { - TimeLine t("ig_bgemm_f32bf16f32_compute_resext"); - ig_bgemm_f32bf16f32_compute_resext( - transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, gamma, res, ldres); + TimeLine t("xdnn_bgemm_f32bf16f32_compute_resext"); + xdnn_bgemm_f32bf16f32_compute_resext( 
+ transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias, gamma, res, ldres); } #else printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__); @@ -758,12 +752,12 @@ class MMHelper { // INT8 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_INT8 - TimeLine t("ig_sgemm_f32i8f32_compute_resext"); - ig_sgemm_f32i8f32_compute_resext( + TimeLine t("xdnn_sgemm_f32i8f32_compute_resext"); + xdnn_sgemm_f32i8f32_compute_resext( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, gamma, res, ldres); #elif defined(AVX512_FP16_WEIGHT_ONLY_INT8) - TimeLine t("ig_hgemm_f32i8f32_compute_resext"); - ig_hgemm_f32i8f32_compute_resext( + TimeLine t("xdnn_hgemm_f32i8f32_compute_resext"); + xdnn_hgemm_f32i8f32_compute_resext( transA, M, N, K, alpha, A, lda, packedB, scaleB, zeroB, beta, C, ldc, bias, gamma, res, ldres); #else printf("%s:%d: Need to define WEIGHT_ONLY_INT8 kernel data type.\n", __FILE__, __LINE__); @@ -805,10 +799,10 @@ class MMHelper { return key; } - static void ig_amx_sgemm_f32bf16f32_compute(bool transA, int M, int N, int K, float alpha, const float *A, int lda, + static void onednn_amx_sgemm_f32bf16f32_compute(bool transA, int M, int N, int K, float alpha, const float *A, int lda, const bfloat16_t *packedB, float beta, float *C, int ldc) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute"); - TimeLine t1("ig_amx_sgemm_f32bf16f32_compute.create_primitive"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute"); + TimeLine t1("onednn_amx_sgemm_f32bf16f32_compute.create_primitive"); using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; @@ -854,7 +848,7 @@ class MMHelper { t1.release(); // Executions. - TimeLine t2("ig_amx_sgemm_f32bf16f32_compute.execute_primitive"); + TimeLine t2("onednn_amx_sgemm_f32bf16f32_compute.execute_primitive"); // Reorder #pragma omp parallel for for (int i = 0; i < M; ++i) { @@ -865,10 +859,10 @@ class MMHelper { get_dnnl_stream().wait(); } - static void ig_amx_sgemm_f32bf16f32_compute_biasadd(bool transA, int M, int N, int K, float alpha, const float *A, + static void onednn_amx_sgemm_f32bf16f32_compute_biasadd(bool transA, int M, int N, int K, float alpha, const float *A, int lda, const bfloat16_t *packedB, float beta, float *C, int ldc, const float *bias) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_biasadd"); - TimeLine t1("ig_amx_sgemm_f32bf16f32_compute_biasadd.create_primitive"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_biasadd"); + TimeLine t1("onednn_amx_sgemm_f32bf16f32_compute_biasadd.create_primitive"); using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; @@ -918,7 +912,7 @@ class MMHelper { t1.release(); // Executions. 
- TimeLine t2("ig_amx_sgemm_f32bf16f32_compute_biasadd.execute_primitive"); + TimeLine t2("onednn_amx_sgemm_f32bf16f32_compute_biasadd.execute_primitive"); // Reorder #pragma omp parallel for for (int i = 0; i < M; ++i) { @@ -929,10 +923,10 @@ class MMHelper { get_dnnl_stream().wait(); } - static void ig_amx_sgemm_f32bf16f32_compute_biasadd_relu(bool transA, int M, int N, int K, float alpha, + static void onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu(bool transA, int M, int N, int K, float alpha, const float *A, int lda, const bfloat16_t *packedB, float beta, float *C, int ldc, const float *bias) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_biasadd_relu"); - TimeLine t1("ig_amx_sgemm_f32bf16f32_compute_biasadd_relu.create_primitive"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu"); + TimeLine t1("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu.create_primitive"); using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; @@ -991,7 +985,7 @@ class MMHelper { t1.release(); // Executions. - TimeLine t2("ig_amx_sgemm_f32bf16f32_compute_biasadd_relu.execute_primitive"); + TimeLine t2("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu.execute_primitive"); // Reorder #pragma omp parallel for for (int i = 0; i < M; ++i) { @@ -1002,10 +996,10 @@ class MMHelper { get_dnnl_stream().wait(); } - static void ig_amx_sgemm_f32bf16f32_compute_silu(bool transA, int M, int N, int K, float alpha, const float *A, + static void onednn_amx_sgemm_f32bf16f32_compute_silu(bool transA, int M, int N, int K, float alpha, const float *A, int lda, const bfloat16_t *packedB, float beta, float *C, int ldc) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_silu"); - TimeLine t1("ig_amx_sgemm_f32bf16f32_compute_silu.create_primitive"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_silu"); + TimeLine t1("onednn_amx_sgemm_f32bf16f32_compute_silu.create_primitive"); using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; @@ -1059,7 +1053,7 @@ class MMHelper { t1.release(); // Executions. - TimeLine t2("ig_amx_sgemm_f32bf16f32_compute_silu.execute_primitive"); + TimeLine t2("onednn_amx_sgemm_f32bf16f32_compute_silu.execute_primitive"); // Reorder #pragma omp parallel for for (int i = 0; i < M; ++i) { @@ -1071,10 +1065,10 @@ class MMHelper { } // TODO: may be error - static void ig_amx_sgemm_f32bf16f32_compute_resmul(bool transA, int M, int N, int K, float alpha, const float *A, + static void onednn_amx_sgemm_f32bf16f32_compute_resmul(bool transA, int M, int N, int K, float alpha, const float *A, int lda, const bfloat16_t *packedB, float beta, float *C, int ldc, const float *res, int ldres) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_resmul"); - TimeLine t1("ig_amx_sgemm_f32bf16f32_compute_resmul.create_primitive"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_resmul"); + TimeLine t1("onednn_amx_sgemm_f32bf16f32_compute_resmul.create_primitive"); using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; @@ -1143,7 +1137,7 @@ class MMHelper { t1.release(); // Executions. 
- TimeLine t2("ig_amx_sgemm_f32bf16f32_compute_resmul.execute_primitive"); + TimeLine t2("onednn_amx_sgemm_f32bf16f32_compute_resmul.execute_primitive"); // Reorder #pragma omp parallel for for (int i = 0; i < M; ++i) { @@ -1154,11 +1148,11 @@ class MMHelper { get_dnnl_stream().wait(); } - static void ig_amx_sgemm_f32bf16f32_compute_residential(bool transA, int M, int N, int K, float alpha, + static void onednn_amx_sgemm_f32bf16f32_compute_residential(bool transA, int M, int N, int K, float alpha, const float *A, int lda, const bfloat16_t *packedB, float beta, float *C, int ldc, const float *bias, const float *res, int ldres) { - TimeLine t("ig_amx_sgemm_f32bf16f32_compute_residential"); - TimeLine t1("ig_amx_sgemm_f32bf16f32_compute_residential.create_primitive"); + TimeLine t("onednn_amx_sgemm_f32bf16f32_compute_residential"); + TimeLine t1("onednn_amx_sgemm_f32bf16f32_compute_residential.create_primitive"); using namespace dnnl; using tag = memory::format_tag; using dt = memory::data_type; @@ -1230,7 +1224,7 @@ class MMHelper { t1.release(); // Executions. - TimeLine t2("ig_amx_sgemm_f32bf16f32_compute_bias_residential.execute_primitive"); + TimeLine t2("onednn_amx_sgemm_f32bf16f32_compute_bias_residential.execute_primitive"); // Reorder #pragma omp parallel for for (int i = 0; i < M; ++i) { diff --git a/tests/ut/small_gemm_test.cpp b/tests/ut/small_gemm_test.cpp index 84c97be6..faa3b5b7 100644 --- a/tests/ut/small_gemm_test.cpp +++ b/tests/ut/small_gemm_test.cpp @@ -14,8 +14,7 @@ // ============================================================================ #include -#include "ig_sgemm_f32f16f32.h" -#include "ig_sgemm.h" +#include "xdnn.h" #include "float16.h" #include "gtest/gtest.h"