diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 60a8404a0..e549120a9 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -52,6 +52,9 @@ void caffe_scal(const int N, const Dtype alpha, Dtype *X); template void caffe_sqr(const int N, const Dtype* a, Dtype* y); +template +void caffe_sqrt(const int N, const Dtype* a, Dtype* y); + template void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y); diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 71c02274a..59625bc05 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -196,6 +196,16 @@ void caffe_sqr(const int n, const double* a, double* y) { vdSqr(n, a, y); } +template <> +void caffe_sqrt(const int n, const float* a, float* y) { + vsSqrt(n, a, y); +} + +template <> +void caffe_sqrt(const int n, const double* a, double* y) { + vdSqrt(n, a, y); +} + template <> void caffe_exp(const int n, const float* a, float* y) { vsExp(n, a, y);