Skip to content

Commit

Permalink
[API] Export layernorm and rmsnorm function API.
Browse files Browse the repository at this point in the history
  • Loading branch information
changqi1 committed Oct 24, 2023
1 parent 1990f20 commit 5b86e05
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 34 deletions.
9 changes: 7 additions & 2 deletions include/layers_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

namespace xft {

// Exported free-function entry point for layer normalization on float buffers.
// Reads from `input` and writes to `output`; `gamma` and `beta` are the
// normalization parameters (presumably per-element scale and shift — TODO
// confirm against the kernel definition).
// NOTE(review): `rows` / `size` look like the row count and the length of the
// normalized dimension — confirm against the kernel definition.
// `iStride` / `oStride` default to -1; presumably -1 means "fall back to `size`
// as the row stride" — verify in the implementation.
// `epsilon` defaults to 1e-5.
// Default arguments are declared here only; the out-of-line definition must
// not repeat them.
void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows,
const int size, int iStride = -1, int oStride = -1, const float epsilon = 1e-5);

// Exported free-function entry point for RMS normalization on float buffers.
// Reads from `input` and writes to `output`; `weight` is the normalization
// weight (presumably one value per column — TODO confirm against the kernel
// definition).
// NOTE(review): `rows` / `cols` look like the row count and the length of the
// normalized dimension — confirm against the kernel definition.
// `iStride` / `oStride` default to -1; presumably -1 means "fall back to `cols`
// as the row stride" — verify in the implementation.
// `epsilon` defaults to 1e-6.
// Default arguments are declared here only; the out-of-line definition must
// not repeat them.
void invokeRmsNorm(float *output, const float *input, const float *weight, int rows, int cols, int iStride = -1,
int oStride = -1, float epsilon = 1e-6);

// Layer normalization: only supports normalization along the last dimension
class LayerNorm {
public:
Expand Down Expand Up @@ -30,8 +36,7 @@ class RmsNorm {
void setWeight(const float *w, const float *, int size);

// input and output are in shape of (rows, normSize)
void forward(
const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6);
void forward(const float *input, float *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6);

private:
int normSize;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#pragma once

#include <immintrin.h>

#include "bfloat16.h"
Expand All @@ -14,19 +12,6 @@ struct LayerNormWeight {
const T *beta = nullptr;
};

// Generic fallback for element types that have no dedicated layer-norm kernel.
//
// Any instantiation of this primary template reports the unsupported type on
// stdout and terminates the process with exit(-1). Supported element types are
// expected to provide an explicit specialization that replaces this body.
//
// Parameters mirror the specializations so call sites are type-agnostic:
//   output, input  - destination and source buffers (unused here)
//   gamma, beta    - normalization parameters (unused here)
//   rows, size     - problem dimensions (unused here)
//   iStride,
//   oStride        - row strides, default -1 (unused here)
//   epsilon        - numerical-stability term, default 1e-5 (unused here)
template <typename T>
void invokeLayerNorm(T *output, const T *input, const T *gamma, const T *beta, const int rows, const int size,
        int iStride = -1, int oStride = -1, const float epsilon = 1e-5) {
    // The previous `if constexpr` split on float16_t / bfloat16_t / int8_t had
    // byte-identical arms, so a single unconditional path is equivalent and
    // no longer drags those type names into this template.
    printf("Type %s not supported!\n", typeid(T).name());
    exit(-1);
}

template <>
void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, const int rows,
const int size, int iStride, int oStride, const float epsilon) {

Expand Down
15 changes: 0 additions & 15 deletions src/kernels/rmsnorm_kernels.h → src/kernels/rmsnorm_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#pragma once

#include <immintrin.h>

#include "bfloat16.h"
Expand All @@ -8,19 +6,6 @@

namespace xft {

// Generic fallback for element types that have no dedicated RMS-norm kernel.
//
// Any instantiation of this primary template reports the unsupported type on
// stdout and terminates the process with exit(-1). Supported element types are
// expected to provide an explicit specialization that replaces this body.
//
// Parameters mirror the specializations so call sites are type-agnostic:
//   output, input  - destination and source buffers (unused here)
//   weight         - normalization weight (unused here)
//   rows, cols     - problem dimensions (unused here)
//   iStride,
//   oStride        - row strides, default -1 (unused here)
//   epsilon        - numerical-stability term, default 1e-6 (unused here)
template <typename T>
void invokeRmsNorm(T *output, const T *input, const T *weight, int rows, int cols, int iStride = -1, int oStride = -1,
        const float epsilon = 1e-6) {
    // The previous `if constexpr` split on float16_t / bfloat16_t / int8_t had
    // byte-identical arms, so a single unconditional path is equivalent and
    // no longer drags those type names into this template.
    printf("Type %s not supported!\n", typeid(T).name());
    exit(-1);
}

template <>
void invokeRmsNorm(float *output, const float *input, const float *weight, int rows, int cols, int iStride, int oStride,
float epsilon) {
int size = cols;
Expand Down
1 change: 0 additions & 1 deletion src/layers/layer_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include <cstring>
#include <immintrin.h>

#include "layernorm_kernels.h"
#include "layers_norm.h"
#include "timeline.h"

Expand Down
1 change: 0 additions & 1 deletion src/layers/rms_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <cstring>

#include "layers_norm.h"
#include "rmsnorm_kernels.h"
#include "timeline.h"

namespace xft {
Expand Down

0 comments on commit 5b86e05

Please sign in to comment.