[ Tensor ] Support NHWC for dot, add/multiply_strided and other ops
This PR includes changes to Tensor and TensorDim to support NHWC
computation for dot, add_strided, multiply_strided, cat, split,
and transpose. It also includes unit tests that evaluate the new paths.
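
Below is a minimal usage sketch of the NHWC path (hypothetical sizes; the
four-argument constructor and setValue appear in the diff below, while the
single-operand dot/multiply_strided overloads are assumptions here and may
differ from the actual signatures in tensor.h):

    #include <tensor.h>

    using namespace nntrainer;

    void nhwc_smoke_test() {
      // Arguments stay in logical (batch, channel, height, width) order;
      // only the in-memory layout is channels-last.
      Tensor a(1, 4, 2, 3, Tformat::NHWC, Tdatatype::FP32);
      Tensor b(1, 4, 2, 3, Tformat::NHWC, Tdatatype::FP32);

      a.setValue(1.0f);
      b.setValue(2.0f);

      Tensor prod = a.multiply_strided(b); // element-wise, honoring NHWC strides
      Tensor out = a.dot(b, false, true);  // GEMM with the second operand transposed
    }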

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Adwaith Anand <[email protected]>
Signed-off-by: Manohara HK <[email protected]>
Signed-off-by: jijoong.moon <[email protected]>
adwaith-a authored and jijoongmoon committed Jul 17, 2023
1 parent 4356071 commit 4bea951
Showing 9 changed files with 5,844 additions and 449 deletions.
4 changes: 2 additions & 2 deletions api/ccapi/include/tensor_dim.h
@@ -80,8 +80,8 @@ class TensorDim {
* @param dims std::initializer_list
* @param fm format NCHW | NHWC
*
* formats of {w}, {h, w}, {c, h, w}, {b, c, h, w} for the NCHW are accepted
* formats of {c}, {w, c}, {h, w, c}, {b, h, w, c} for the NHWC are accepted
formats of {w}, {h, w}, {c, h, w}, {b, c, h, w} are accepted for both
NCHW and NHWC
*/
TensorDim(std::initializer_list<size_t> dims,
TensorType t_type_ = TensorType());
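
For illustration, a sketch of how the updated constructor is meant to be
called (assumes TensorType exposes a (Format, DataType) constructor, as used
elsewhere in the tree):

    #include <tensor_dim.h>

    using ml::train::TensorDim;

    void tensor_dim_formats() {
      // The same logical {b, c, h, w} list is used for both formats; only the
      // in-memory layout differs.
      TensorDim nchw({2, 3, 4, 5}); // defaults to NCHW
      TensorDim nhwc({2, 3, 4, 5},
                     TensorDim::TensorType(TensorDim::Format::NHWC,
                                           TensorDim::DataType::FP32));
      // In both cases: channel() == 3, height() == 4, width() == 5.
    }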
879 changes: 676 additions & 203 deletions nntrainer/tensor/tensor.cpp

Large diffs are not rendered by default.

172 changes: 120 additions & 52 deletions nntrainer/tensor/tensor.h
@@ -56,10 +56,6 @@
namespace nntrainer {

using TensorDim = ml::train::TensorDim;

/**
* @brief NHWC is WIP
*/
using Tformat = ml::train::TensorDim::Format;
using Tdatatype = ml::train::TensorDim::DataType;

@@ -123,9 +119,9 @@ class Tensor {
/**
* @brief Constructor of Tensor
* @param[in] d0 Batch of Tensor
* @param[in] d1 Channel (NCHW) or Height (NHWC)
* @param[in] d2 Height (NCHW) or Width (NHWC)
* @param[in] d3 Width (NCHW) or Channel (NHWC)
* @param[in] d1 Channel
* @param[in] d2 Height
* @param[in] d3 Width
*/
Tensor(size_t d0, size_t d1, size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
Tdatatype d_type = Tdatatype::FP32) :
@@ -173,9 +169,9 @@ class Tensor {

/**
* @brief Constructor of Tensor
* @param[in] d1 Channel (NCHW) or Height (NHWC)
* @param[in] d2 Height (NCHW) or Width (NHWC)
* @param[in] d3 Width (NCHW) or Channel (NHWC)
* @param[in] d1 Channel
* @param[in] d2 Height
* @param[in] d3 Width
*/
Tensor(size_t d1, size_t d2, size_t d3, ml::train::TensorDim::TensorType t_type) :
Tensor(1, d1, d2, d3, t_type){};
@@ -185,35 +181,46 @@ class Tensor {
* @param[in] d2 Height (NCHW) or Width (NHWC)
* @param[in] d3 Width (NCHW) or Channel (NHWC)
*/
Tensor(size_t d2, size_t d3,ml::train::TensorDim::TensorType t_type) :
Tensor(1, 1, d2, d3, t_type){};

Tensor(size_t d2, size_t d3, ml::train::TensorDim::TensorType t_type) :
Tensor(1, (t_type.format == Tformat::NCHW) ? 1 : d3,
(t_type.format == Tformat::NCHW) ? d2 : 1,
(t_type.format == Tformat::NCHW) ? d3 : d2, t_type){};
/**
* @brief Constructor of Tensor with just Width or Channel
* @param[in] d3 Width (NCHW) or Channel (NHWC)
*/
explicit Tensor(size_t d3, ml::train::TensorDim::TensorType t_type) :
Tensor(1, 1, 1, d3, t_type){};
Tensor(1, (t_type.format == Tformat::NCHW) ? 1 : d3, 1,
(t_type.format == Tformat::NCHW) ? d3 : 1, t_type){};
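
A quick worked example of the mapping these ternaries perform (hypothetical
sizes; assumes TensorType can be brace-initialized from {format, data type}):

    Tensor t2(4, 8, {Tformat::NHWC, Tdatatype::FP32});
    // NCHW would yield (b=1, c=1, h=4, w=8);
    // NHWC yields      (b=1, c=8, h=1, w=4), i.e. d2 is width, d3 is channel.
    Tensor t1(8, {Tformat::NHWC, Tdatatype::FP32});
    // NCHW would yield (b=1, c=1, h=1, w=8);
    // NHWC yields      (b=1, c=8, h=1, w=1), i.e. d3 is channel.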


/**
* @brief Constructor of Tensor
* @param[in] d data for the Tensor
* @param[in] d data for the Tensor. The format must be set accordingly.
*/
Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d) {

Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
Tformat fm) {

Review comment by @djeong20 (Jul 18, 2023): default argument is needed (e.g., Tformat fm = Tformat::NCHW)


if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
throw std::out_of_range(
"[Tensor] trying to initialize Tensor from empty vector");
}
// For NCHW, d is nested as d[batch][channel][height][width];
// for NHWC, d is nested as d[batch][height][width][channel].
dim.setTensorDim(0, d.size());
if (fm == Tformat::NCHW) {
dim.setTensorDim(1, d[0].size());
dim.setTensorDim(2, d[0][0].size());
dim.setTensorDim(3, d[0][0][0].size());
} else {
dim.setTensorDim(2, d[0].size());
dim.setTensorDim(3, d[0][0].size());
dim.setTensorDim(1, d[0][0][0].size());
}

dim.batch(d.size());
dim.channel(d[0].size());
dim.height(d[0][0].size());
dim.width(d[0][0][0].size());
strides = dim.computeStrides();

MemoryData* mem_data = new MemoryData((void *)(new float[dim.getDataLen()]()));
MemoryData *mem_data =
new MemoryData((void *)(new float[dim.getDataLen()]()));
data = std::shared_ptr<MemoryData>(mem_data, [](MemoryData *mem_data) {
delete[] mem_data->getAddr<float>();
});
@@ -223,44 +230,62 @@ class Tensor {

setDataType(Tdatatype::FP32);

for (unsigned int i = 0; i < dim.batch(); ++i)
for (unsigned int j = 0; j < dim.channel(); ++j)
for (unsigned int k = 0; k < dim.height(); ++k)
for (unsigned int l = 0; l < dim.width(); ++l) {
this->setValue(i, j, k, l, d[i][j][k][l]);
}
// For NCHW, d is nested as d[batch][channel][height][width];
// for NHWC, d is nested as d[batch][height][width][channel].
if (fm == Tformat::NCHW) {
for (unsigned int i = 0; i < batch(); ++i)
for (unsigned int j = 0; j < channel(); ++j)
for (unsigned int k = 0; k < height(); ++k)
for (unsigned int l = 0; l < width(); ++l)
this->setValue(i, j, k, l, d[i][j][k][l]);
} else {
for (unsigned int i = 0; i < batch(); ++i)
for (unsigned int j = 0; j < height(); ++j)
for (unsigned int k = 0; k < width(); ++k)
for (unsigned int l = 0; l < channel(); ++l)
this->setValue(i, l, j, k, d[i][j][k][l]);
}
};
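
To make the NHWC branch concrete, a small sketch of the expected nesting
(getValue(b, c, h, w) is assumed to be the existing element accessor):

    // NHWC input: outer-to-inner nesting is batch -> height -> width -> channel.
    std::vector<std::vector<std::vector<std::vector<float>>>> d = {
      {                        // batch 0
        {                      // height 0
          {0.f, 1.f, 2.f},     // width 0, channels 0..2
          {3.f, 4.f, 5.f},     // width 1, channels 0..2
        },
      },
    };
    Tensor t(d, Tformat::NHWC);      // dim becomes (b=1, c=3, h=1, w=2)
    // t.getValue(0, 2, 0, 1) == 5.f (b=0, c=2, h=0, w=1)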

/**
* @brief Constructor of Tensor
* @note This constructor copies the vector again; needs refactoring
* @param[in] d data for the Tensor
* @param[in] d data for the Tensor. The format must be set accordingly.
*/
Tensor(std::vector<std::vector<std::vector<float>>> const &d) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
Tensor(std::vector<std::vector<std::vector<float>>> const &d,
Tformat fm = Tformat::NCHW) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, fm){};

/**
* @brief Constructor of Tensor
* @note This constructor copies the vector again; needs refactoring
* @param[in] d data for the Tensor with batch size one
*/
Tensor(std::vector<std::vector<float>> const &d) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
Tensor(std::vector<std::vector<float>> const &d, Tformat fm = Tformat::NCHW) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, fm){};

Tensor(std::vector<std::vector<std::vector<std::vector<__fp16>>>> const &d) {
Tensor(std::vector<std::vector<std::vector<std::vector<__fp16>>>> const &d,
Tformat fm) {

Review comment by @djeong20 (Jul 18, 2023): Tformat fm = Tformat::NCHW (i.e., a default argument is needed here as well)


if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
throw std::out_of_range(
"[Tensor] trying to initialize Tensor from empty vector");
}

dim.batch(d.size());
dim.channel(d[0].size());
dim.height(d[0][0].size());
dim.width(d[0][0][0].size());
strides = dim.computeStrides();
dim.setTensorDim(0, d.size());
if (fm == Tformat::NCHW) {
dim.setTensorDim(1, d[0].size());
dim.setTensorDim(2, d[0][0].size());
dim.setTensorDim(3, d[0][0][0].size());
} else {
dim.setTensorDim(2, d[0].size());
dim.setTensorDim(3, d[0][0].size());
dim.setTensorDim(1, d[0][0][0].size());
}

MemoryData* mem_data = new MemoryData((void *)(new __fp16[dim.getDataLen()]()));
MemoryData *mem_data =
new MemoryData((void *)(new __fp16[dim.getDataLen()]()));
data = std::shared_ptr<MemoryData>(mem_data, [](MemoryData *mem_data) {
delete[] mem_data->getAddr<__fp16>();
});
@@ -270,28 +295,41 @@ class Tensor {

setDataType(Tdatatype::FP16);

for (unsigned int i = 0; i < dim.batch(); ++i)
for (unsigned int j = 0; j < dim.channel(); ++j)
for (unsigned int k = 0; k < dim.height(); ++k)
for (unsigned int l = 0; l < dim.width(); ++l)
this->setValue(i, j, k, l, d[i][j][k][l]);
// For NCHW, d is nested as d[batch][channel][height][width];
// for NHWC, d is nested as d[batch][height][width][channel].
if (fm == Tformat::NCHW) {
for (unsigned int i = 0; i < batch(); ++i)
for (unsigned int j = 0; j < channel(); ++j)
for (unsigned int k = 0; k < height(); ++k)
for (unsigned int l = 0; l < width(); ++l)
this->setValue(i, j, k, l, d[i][j][k][l]);
} else {
for (unsigned int i = 0; i < batch(); ++i)
for (unsigned int j = 0; j < height(); ++j)
for (unsigned int k = 0; k < width(); ++k)
for (unsigned int l = 0; l < channel(); ++l)
this->setValue(i, l, j, k, d[i][j][k][l]);
}
};

/**
* @brief Constructor of Tensor
* @note This constructor copies the vector again; needs refactoring
* @param[in] d data for the Tensor
*/
Tensor(std::vector<std::vector<std::vector<__fp16>>> const &d) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
Tensor(std::vector<std::vector<std::vector<__fp16>>> const &d,
Tformat fm = Tformat::NCHW) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, fm){};

/**
* @brief Constructor of Tensor
* @note This constructor copies the vector again; needs refactoring
* @param[in] d data for the Tensor with batch size one
*/
Tensor(std::vector<std::vector<__fp16>> const &d) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}){};
Tensor(std::vector<std::vector<__fp16>> const &d,
Tformat fm = Tformat::NCHW) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, fm){};

/**
* @brief Copy constructor of Tensor.
@@ -1214,6 +1252,15 @@ class Tensor {
*/
void print(std::ostream &out) const;

/**
* @brief Print the tensor elements
* @param[in] out output stream
* @param[in] opt print formatting option; opt=0 pretty-prints the data,
* any other value prints the raw data
*/
void print_(std::ostream &out, uint opt = 0) const;

/**
* @brief Get size of current tensor
* @retval unsigned int size of the current tensor
@@ -1645,7 +1692,28 @@ class Tensor {
*/
inline size_t getIndex(unsigned int b, unsigned int c, unsigned int h,
unsigned int w) const noexcept {
return (b * strides[0] + c * strides[1] + h * strides[2] + w * strides[3]);
if (getFormat() == Tformat::NCHW)
return (b * strides[0] + c * strides[1] + h * strides[2] +
w * strides[3]);
else
return (b * strides[0] + h * strides[1] + w * strides[2] +
c * strides[3]);
}
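
A standalone sketch of the arithmetic the NHWC branch performs (strides for
NHWC are {H*W*C, W*C, C, 1}, as computed in tensor_dim.cpp below):

    #include <array>
    #include <cstddef>

    // Element (b, c, h, w) of an NHWC tensor lives at offset
    // b*H*W*C + h*W*C + w*C + c.
    size_t nhwc_index(size_t b, size_t c, size_t h, size_t w,
                      size_t C, size_t H, size_t W) {
      const std::array<size_t, 4> strides = {H * W * C, W * C, C, 1};
      return b * strides[0] + h * strides[1] + w * strides[2] + c * strides[3];
    }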

/**
* @brief Check if two given axes are contiguous in memory for the current format
* @param[in] n logical index of the first axis (0: batch, 1: channel, 2: height, 3: width)
* @param[in] np1 logical index of the axis expected to directly follow n in memory
* @retval true if axis np1 is laid out immediately after axis n
*/
bool checkContinuous(unsigned int n, unsigned int np1) const {
std::vector<unsigned int> continuous_order_nhwc = {0, 3, 1, 2};
bool continuous = false;
if (getFormat() == Tformat::NHWC) {
if (continuous_order_nhwc[np1] == continuous_order_nhwc[n] + 1)
continuous = true;
} else {
if (n + 1 == np1)
continuous = true;
}
return continuous;
}

/**
41 changes: 11 additions & 30 deletions nntrainer/tensor/tensor_dim.cpp
@@ -60,17 +60,9 @@ TensorDim::TensorDim(std::initializer_list<size_t> dims, TensorType t_type_) :
}
}

// TensorDim::TensorDim(std::initializer_list<size_t> dims, TensorDim::Format fm,
// TensorDim::DataType d_type) :
// TensorDim(dims, TensorType{fm, d_type}) {}

TensorDim::TensorDim(const std::array<size_t, 3> &shapes, TensorType t_type_) :
TensorDim({shapes[0], shapes[1], shapes[2]}, t_type_) {}

// TensorDim::TensorDim(const std::array<size_t, 3> &shapes, TensorDim::Format fm,
// TensorDim::DataType d_type) :
// TensorDim({shapes[0], shapes[1], shapes[2]}, TensorType{fm, d_type}) {}

TensorDim::TensorDim(size_t d0, size_t d1, size_t d2, size_t d3,
TensorType t_type_,
const std::bitset<MAXDIM> &eff_dim_flag_,
@@ -218,38 +210,23 @@ void swap(TensorDim &lhs, TensorDim &rhs) noexcept {

size_t TensorDim::batch() const { return dim[0]; };

size_t TensorDim::channel() const {
return t_type.format == Format::NCHW ? dim[1] : dim[3];
};
size_t TensorDim::channel() const { return dim[1]; };

size_t TensorDim::height() const {
return t_type.format == Format::NCHW ? dim[2] : dim[1];
};
size_t TensorDim::height() const { return dim[2]; };

size_t TensorDim::width() const {
return t_type.format == Format::NCHW ? dim[3] : dim[2];
};
size_t TensorDim::width() const { return dim[3]; };

size_t TensorDim::getDataLen() const { return len; };

size_t TensorDim::getFeatureLen() const { return feature_len; };

void TensorDim::batch(size_t b) { setTensorDim(0, b); }

void TensorDim::channel(size_t c) {
uint i = (t_type.format == Format::NCHW) ? 1 : 3;
setTensorDim(i, c);
}
void TensorDim::channel(size_t c) { setTensorDim(1, c); }

void TensorDim::height(size_t h) {
uint i = (t_type.format == Format::NCHW) ? 2 : 1;
setTensorDim(i, h);
}
void TensorDim::height(size_t h) { setTensorDim(2, h); }

void TensorDim::width(size_t w) {
uint i = (t_type.format == Format::NCHW) ? 3 : 2;
setTensorDim(i, w);
}
void TensorDim::width(size_t w) { setTensorDim(3, w); }

const size_t *TensorDim::getDim() const { return dim; }

@@ -324,7 +301,10 @@ const size_t &TensorDim::operator[](const unsigned int index) const {
}

std::array<size_t, TensorDim::MAXDIM> TensorDim::computeStrides() const {
return {dim[1] * dim[2] * dim[3], dim[2] * dim[3], dim[3], 1};
if (getFormat() == TensorDim::Format::NCHW)
return {dim[1] * dim[2] * dim[3], dim[2] * dim[3], dim[3], 1};
else
return {height() * channel() * width(), width() * channel(), channel(), 1};
}
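
For reference, a quick worked example of the two stride sets (illustrative values):

    // For dim = (b=2, c=3, h=4, w=5):
    //   NCHW: {c*h*w, h*w, w, 1} = {60, 20, 5, 1}
    //   NHWC: {h*c*w, w*c, c, 1} = {60, 15, 3, 1}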

void TensorDim::reverse() { std::reverse(dim, dim + MAXDIM); }
@@ -355,6 +335,7 @@ std::vector<int> TensorDim::getEffectiveDimension(bool dynamic) const {
bool TensorDim::is_dynamic() const { return dyn_dim_flag.any(); }

std::ostream &operator<<(std::ostream &out, TensorDim const &d) {

std::string type_ =
(d.getDataType() == ml::train::TensorDim::DataType::FP16) ? "FP16" : "FP32";
std::string format_ =