diff --git a/nntrainer/tensor/float_tensor.cpp b/nntrainer/tensor/float_tensor.cpp
index 7a644a56a5..5ba65578e2 100644
--- a/nntrainer/tensor/float_tensor.cpp
+++ b/nntrainer/tensor/float_tensor.cpp
@@ -507,6 +507,21 @@ Tensor &FloatTensor::add_strided(Tensor const &input, Tensor &output,
   return output;
 }
 
+int FloatTensor::add_i(Tensor const &m, Tensor &output, float const alpha) {
+  auto f = [&](const BroadcastInfo &e, const float *buf, const float *m_buf,
+               float *out_buf) {
+    saxpy(e.buffer_size, alpha, m_buf, e.strides[3], out_buf, strides[3]);
+  };
+
+  try {
+    apply_broadcast(m, f, output);
+  } catch (std::exception &err) {
+    ml_loge("%s %s", typeid(err).name(), err.what());
+    return ML_ERROR_INVALID_PARAMETER;
+  }
+  return ML_ERROR_NONE;
+}
+
 Tensor &FloatTensor::add(float const &value, Tensor &output) const {
   auto f = std::bind(std::plus<float>(), std::placeholders::_1, value);
   apply(f, output);
diff --git a/nntrainer/tensor/float_tensor.h b/nntrainer/tensor/float_tensor.h
index 7b7371e189..b91df6c20c 100644
--- a/nntrainer/tensor/float_tensor.h
+++ b/nntrainer/tensor/float_tensor.h
@@ -64,6 +64,12 @@ class FloatTensor : public TensorBase {
     std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
     Tformat fm);
 
+  /**
+   * @brief Construct a new FloatTensor object
+   * @param rhs TensorBase object to copy
+   */
+  FloatTensor(TensorBase &rhs) : TensorBase(rhs) {}
+
   /**
    * @brief Basic Destructor
    */
@@ -256,6 +262,11 @@ class FloatTensor : public TensorBase {
   Tensor &add_strided(Tensor const &input, Tensor &output,
                       const float beta) const override;
 
+  /**
+   * @copydoc Tensor::add_i(Tensor const &m, float const alpha)
+   */
+  int add_i(Tensor const &m, Tensor &output, float const alpha) override;
+
   /**
    * @copydoc Tensor::add(float const &value, Tensor &output)
    */
diff --git a/nntrainer/tensor/half_tensor.cpp b/nntrainer/tensor/half_tensor.cpp
index e29d3fd651..230eaf23f7 100644
--- a/nntrainer/tensor/half_tensor.cpp
+++ b/nntrainer/tensor/half_tensor.cpp
@@ -479,6 +479,22 @@ Tensor &HalfTensor::add_strided(Tensor const &input, Tensor &output,
   return output;
 }
 
+int HalfTensor::add_i(Tensor const &m, Tensor &output, float const alpha) {
+  auto f = [&](const BroadcastInfo &e, const _FP16 *buf, const _FP16 *m_buf,
+               _FP16 *out_buf) {
+    saxpy(e.buffer_size, alpha, m_buf, e.strides[3], out_buf, strides[3]);
+    /// @todo: saxpy is not valid for _FP16
+  };
+
+  try {
+    apply_broadcast(m, f, output);
+  } catch (std::exception &err) {
+    ml_loge("%s %s", typeid(err).name(), err.what());
+    return ML_ERROR_INVALID_PARAMETER;
+  }
+  return ML_ERROR_NONE;
+}
+
 Tensor &HalfTensor::add(float const &value, Tensor &output) const {
   auto f = std::bind(std::plus<_FP16>(), std::placeholders::_1,
                      static_cast<_FP16>(value));
diff --git a/nntrainer/tensor/half_tensor.h b/nntrainer/tensor/half_tensor.h
index 93333db472..0deccdfcae 100644
--- a/nntrainer/tensor/half_tensor.h
+++ b/nntrainer/tensor/half_tensor.h
@@ -63,6 +63,13 @@ class HalfTensor : public TensorBase {
   HalfTensor(std::vector<std::vector<std::vector<std::vector<_FP16>>>> const &d,
              Tformat fm);
 
+  /**
+   * @brief Construct a new HalfTensor object
+   *
+   * @param rhs TensorBase object to copy
+   */
+  HalfTensor(TensorBase &rhs) : TensorBase(rhs) {}
+
   /**
    * @brief Basic Destructor
    */
@@ -255,6 +262,11 @@ class HalfTensor : public TensorBase {
   Tensor &add_strided(Tensor const &input, Tensor &output,
                       const float beta) const override;
 
+  /**
+   * @copydoc Tensor::add_i(Tensor const &m, float const alpha)
+   */
+  int add_i(Tensor const &m, Tensor &output, float const alpha) override;
+
   /**
    * @copydoc Tensor::add(float const &value, Tensor &output)
    */
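Note on the add_i kernels above: the lambda handed to apply_broadcast issues one saxpy (BLAS axpy: y := alpha * x + y) per broadcast segment, with e.strides[3] and strides[3] as the innermost element strides of the operand and the accumulator. A minimal standalone sketch of the per-segment semantics, not the nntrainer kernel (saxpy_ref is a hypothetical name):

#include <cstddef>

// y[i * incy] += alpha * x[i * incx] for i in [0, n) -- the contract the
// broadcast lambda relies on for each contiguous segment.
static void saxpy_ref(std::size_t n, float alpha, const float *x,
                      std::size_t incx, float *y, std::size_t incy) {
  for (std::size_t i = 0; i < n; ++i)
    y[i * incy] += alpha * x[i * incx];
}

int main() {
  float out[4] = {1.0f, 2.0f, 3.0f, 4.0f};   // accumulator segment (out_buf)
  float m[4] = {10.0f, 20.0f, 30.0f, 40.0f}; // operand segment (m_buf)
  saxpy_ref(4, 0.5f, m, 1, out, 1);          // out becomes {6, 12, 18, 24}
  return 0;
}

This also makes the @todo in HalfTensor::add_i concrete: per the author's note, saxpy is not valid for _FP16, so the half-precision path will eventually need an FP16-aware axpy.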
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index b8dde7c1bc..8c25105ac2 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -10,6 +10,7 @@
  */
 
 #include <float_tensor.h>
+#include <lazy_tensor.h>
 #include <tensor.h>
 
 #ifdef ENABLE_FP16
@@ -100,6 +101,35 @@ Tensor::Tensor(
 }
 #endif
 
+Tensor::Tensor(const Tensor &rhs) {
+  if (rhs.getDataType() == Tdatatype::FP32) {
+    itensor = std::shared_ptr<FloatTensor>(new FloatTensor(*rhs.itensor),
+                                           std::default_delete<FloatTensor>());
+  } else if (rhs.getDataType() == Tdatatype::FP16) {
+#ifdef ENABLE_FP16
+    itensor = std::shared_ptr<HalfTensor>(new HalfTensor(*rhs.itensor),
+                                          std::default_delete<HalfTensor>());
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+}
+
+Tensor &Tensor::operator=(const Tensor &rhs) {
+  if (rhs.getDataType() == Tdatatype::FP32) {
+    itensor = std::shared_ptr<FloatTensor>(new FloatTensor(*rhs.itensor),
+                                           std::default_delete<FloatTensor>());
+  } else if (rhs.getDataType() == Tdatatype::FP16) {
+#ifdef ENABLE_FP16
+    itensor = std::shared_ptr<HalfTensor>(new HalfTensor(*rhs.itensor),
+                                          std::default_delete<HalfTensor>());
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  }
+  return *this;
+}
+
 bool Tensor::operator==(const Tensor &rhs) const {
   /// compares tensor information
   if (*itensor == *rhs.itensor) {
@@ -176,7 +206,7 @@ int Tensor::multiply_i_strided(Tensor const &m, const float beta) {
 }
 
 Tensor Tensor::multiply_strided(Tensor const &m, const float beta) const {
-  Tensor t;
+  Tensor t("", getFormat(), getDataType());
   return this->multiply_strided(m, t, beta);
 }
 
@@ -194,7 +224,7 @@ int Tensor::multiply_i(float const &value) {
 }
 
 Tensor Tensor::multiply(float const &value) const {
-  Tensor t;
+  Tensor t("", getFormat(), getDataType());
   return multiply(value, t);
 }
 
@@ -319,13 +349,7 @@ Tensor &Tensor::add(float const &value, Tensor &output) const {
 }
 
 int Tensor::add_i(Tensor const &m, float const alpha) {
-  try {
-    this->add(m, *this, alpha);
-  } catch (std::exception &err) {
-    ml_loge("%s %s", typeid(err).name(), err.what());
-    return ML_ERROR_INVALID_PARAMETER;
-  }
-  return ML_ERROR_NONE;
+  return itensor->add_i(m, *this, alpha);
 }
 
 Tensor Tensor::add(Tensor const &m, float const alpha) const {
@@ -536,6 +560,8 @@ void Tensor::cos(Tensor &out, float alpha) {
   itensor->cos(out, alpha);
 }
 
+LazyTensor Tensor::chain() const { return LazyTensor(*this); }
+
 float Tensor::l2norm() const { return itensor->l2norm(); }
 
 void Tensor::normalization_i() {
diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h
index 463da5aabe..2666217d76 100644
--- a/nntrainer/tensor/tensor.h
+++ b/nntrainer/tensor/tensor.h
@@ -13,6 +13,8 @@
 #define __TENSOR_H__
 #ifdef __cplusplus
 
+#define MAKE_SHARED_TENSOR(...) std::make_shared<nntrainer::Tensor>(__VA_ARGS__)
+
 #define CREATE_IF_EMPTY_DIMS(tensor, ...) \
   do {                                    \
     if (tensor.empty())                   \
@@ -26,6 +28,8 @@
 
 namespace nntrainer {
 
+class LazyTensor;
+
 /**
  * @class Tensor Class
  * @brief Tensor Class
@@ -213,7 +217,7 @@ class Tensor {
    * @brief Copy constructor of Tensor.
    * @param[in] Tensor &
    */
-  Tensor(const Tensor &rhs) = default;
+  Tensor(const Tensor &rhs);
 
   /**
    * @brief Move constructor of Tensor.
@@ -225,7 +229,7 @@ class Tensor {
    * @brief Copy assignment operator.
    * @param[in] rhs Tensor to be copied.
    */
-  Tensor &operator=(const Tensor &rhs) = default;
+  Tensor &operator=(const Tensor &rhs);
 
   /**
    * @brief Move assignment operator.
@@ -269,7 +273,7 @@ class Tensor {
         "Creating shared tensor of size bigger than tensor memory.");
     }
 
-    Tensor output;
+    Tensor output("", d.getFormat(), d.getDataType());
     output.setTensorVar(d, buf, offset);
     return output;
   };
@@ -941,6 +945,12 @@ class Tensor {
    */
   void cos(Tensor &out, float alpha = 1.0);
 
+  /**
+   * @brief Anchor a starting point to defer following evaluation
+   * @retval LazyTensor class that can be used with run();
+   */
+  LazyTensor chain() const;
+
   /**
    * @brief l2norm the Tensor elements
    * @retval Calculated l2norm
@@ -1439,6 +1449,8 @@ class Tensor {
     std::swap(lhs.itensor, rhs.itensor);
   }
 
+  static constexpr float epsilon = 1e-5;
+
 private:
   std::shared_ptr<TensorBase> itensor;
 
diff --git a/nntrainer/tensor/tensor_base.h b/nntrainer/tensor/tensor_base.h
index 2eb13c72e6..71949443d9 100644
--- a/nntrainer/tensor/tensor_base.h
+++ b/nntrainer/tensor/tensor_base.h
@@ -114,6 +114,21 @@ class TensorBase {
   TensorBase(const TensorDim &d, const void *buf = nullptr) :
     TensorBase(d, true) {}
 
+  /**
+   * @brief Copy constructor of TensorBase.
+   * @param[in] Tensor &
+   */
+  TensorBase(const TensorBase &rhs) {
+    dim = rhs.dim;
+    strides = rhs.strides;
+    contiguous = rhs.contiguous;
+    initializer = rhs.initializer;
+    name = rhs.name;
+    data = rhs.data;
+    offset = rhs.offset;
+    src_tensor = rhs.src_tensor;
+  }
+
   /**
    * @brief Comparison operator overload
    * @param[in] rhs Tensor to be compared with
@@ -263,6 +278,11 @@ class TensorBase {
   virtual Tensor &add_strided(Tensor const &input, Tensor &output,
                               const float beta) const = 0;
 
+  /**
+   * @copydoc Tensor::add_i(Tensor const &m, float const alpha)
+   */
+  virtual int add_i(Tensor const &m, Tensor &output, float const alpha) = 0;
+
   /**
    * @copydoc Tensor::add(float const &value, Tensor &output)
    */
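Taken together, the Tensor changes replace the defaulted copy operations with deep copies that re-create the concrete FloatTensor or HalfTensor behind itensor, route add_i through the per-type override, and expose chain() as the anchor for deferred evaluation. A hedged usage sketch, assuming a build with this patch applied; the add_i()/run() chaining on the LazyTensor side is an assumption based on the chain() doc comment ("can be used with run()"), and the shapes and values are illustrative only:

#include <lazy_tensor.h>
#include <tensor.h>

int main() {
  nntrainer::Tensor a(1, 1, 2, 2); // FP32 NCHW tensor
  a.setValue(1.0f);

  nntrainer::Tensor b = a;       // copy ctor now clones the FloatTensor impl
  int status = b.add_i(a, 2.0f); // b += 2 * a via FloatTensor::add_i;
                                 // returns ML_ERROR_NONE on success

  auto shared = MAKE_SHARED_TENSOR(b); // std::make_shared<nntrainer::Tensor>(b)

  // chain() anchors a LazyTensor; queued ops evaluate on run().
  // LazyTensor::add_i is assumed here from its in-place style interface.
  nntrainer::Tensor c = a.chain().add_i(b).run();

  return status == ML_ERROR_NONE ? 0 : 1;
}

One design note worth flagging for review: because operator= now rebuilds itensor unconditionally, a self-assignment check (this != &rhs) may be worth adding before the reallocation.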