diff --git a/include/layers_norm.h b/include/layers_norm.h index e93be7c8..8235cff5 100644 --- a/include/layers_norm.h +++ b/include/layers_norm.h @@ -2,6 +2,12 @@ namespace xft { +void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows, + const int size, int iStride = -1, int oStride = -1, const float epsilon = 1e-5); + +void invokeRmsNorm(float *output, const float *input, const float *weight, int rows, int cols, int iStride = -1, + int oStride = -1, float epsilon = 1e-6); + // Layer normalization: only support the norm along last dimension class LayerNorm { public: @@ -30,8 +36,7 @@ class RmsNorm { void setWeight(const float *w, const float *, int size); // input and output are in shape of (rows, normSize) - void forward( - const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6); + void forward(const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6); private: int normSize; diff --git a/src/kernels/layernorm_kernels.h b/src/kernels/layernorm_kernels.cpp similarity index 79% rename from src/kernels/layernorm_kernels.h rename to src/kernels/layernorm_kernels.cpp index 077afb54..1c2ac6b3 100644 --- a/src/kernels/layernorm_kernels.h +++ b/src/kernels/layernorm_kernels.cpp @@ -1,5 +1,3 @@ -#pragma once - #include #include "bfloat16.h" @@ -14,19 +12,6 @@ struct LayerNormWeight { const T *beta = nullptr; }; -template -void invokeLayerNorm(T *output, const T *input, const T *gamma, const T *beta, const int rows, const int size, - int iStride = -1, int oStride = -1, const float epsilon = 1e-5) { - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - printf("Type %s not supported!\n", typeid(T).name()); - exit(-1); - } else { - printf("Type %s not supported!\n", typeid(T).name()); - exit(-1); - } -} - -template <> void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows, const int size, int iStride, int oStride, const float epsilon) { diff --git a/src/kernels/rmsnorm_kernels.h b/src/kernels/rmsnorm_kernels.cpp similarity index 78% rename from src/kernels/rmsnorm_kernels.h rename to src/kernels/rmsnorm_kernels.cpp index 350468c7..a8642d05 100644 --- a/src/kernels/rmsnorm_kernels.h +++ b/src/kernels/rmsnorm_kernels.cpp @@ -1,5 +1,3 @@ -#pragma once - #include #include "bfloat16.h" @@ -8,19 +6,6 @@ namespace xft { -template -void invokeRmsNorm(T *output, const T *input, const T *weight, int rows, int cols, int iStride = -1, int oStride = -1, - const float epsilon = 1e-6) { - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - printf("Type %s not supported!\n", typeid(T).name()); - exit(-1); - } else { - printf("Type %s not supported!\n", typeid(T).name()); - exit(-1); - } -} - -template <> void invokeRmsNorm(float *output, const float *input, const float *weight, int rows, int cols, int iStride, int oStride, float epsilon) { int size = cols; diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp index 9d3f976e..db9feb7f 100644 --- a/src/layers/layer_norm.cpp +++ b/src/layers/layer_norm.cpp @@ -3,7 +3,6 @@ #include #include -#include "layernorm_kernels.h" #include "layers_norm.h" #include "timeline.h" diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp index c80f0d53..202e74c2 100644 --- a/src/layers/rms_norm.cpp +++ b/src/layers/rms_norm.cpp @@ -5,7 +5,6 @@ #include #include "layers_norm.h" -#include "rmsnorm_kernels.h" #include "timeline.h" namespace xft {