This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0ac4a0b
Handle 3D tensors in cuDNN legacy API
Vladimir Cherepanov committed Dec 9, 2021
1 parent 8db4e88 commit 0ac4a0b
Showing 4 changed files with 31 additions and 29 deletions.
src/common/cuda/cudnn_cxx.cc (9 changes: 0 additions & 9 deletions)

@@ -112,15 +112,6 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
   return ret;
 }
 
-std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
-                                   const std::vector<int64_t>& dims) {
-  CHECK_EQ(order.size(), dims.size());
-  std::vector<int64_t> ret(dims.size(), 1);
-  for (size_t i = dims.size() - 1; i--;)
-    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
-  return ret;
-}
-
 std::vector<Descriptor> GetPlans(cudnnBackendHeurMode_t h_mode,
                                  cudnnHandle_t handle,
                                  const Descriptor& op_graph,
src/common/cuda/cudnn_cxx.h (10 changes: 8 additions & 2 deletions)

@@ -244,8 +244,14 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
                                      cudnnBackendDescriptorType_t type);
 
 // Order sets layout, as a permutation of dims, with N,C,<spacial dims> being identity.
-std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
-                                   const std::vector<int64_t>& dims);
+template <typename T>
+std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
+  CHECK_EQ(order.size(), dims.size());
+  std::vector<T> ret(dims.size(), 1);
+  for (size_t i = dims.size() - 1; i--;)
+    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
+  return ret;
+}
 
 // Given an engine config's `notes`, return whether that config is compatible, i.e. does
 // the config have all of the required notes and none of the notes that are being excluded.
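
A quick illustration of the permutation convention in the comment above: a minimal standalone sketch of the templated PackedStrides in use. The function body is copied from the diff, with dmlc's CHECK_EQ swapped for a plain assert so it compiles on its own; the shape and layout orders are made-up example values, not part of the commit.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Copied from cudnn_cxx.h above; CHECK_EQ replaced by assert for standalone use.
template <typename T>
std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
  assert(order.size() == dims.size());
  std::vector<T> ret(dims.size(), 1);
  for (size_t i = dims.size() - 1; i--;)
    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
  return ret;
}

int main() {
  // Logical dims stay in N,C,H,W order; here a made-up 2x3x4x5 tensor.
  std::vector<int64_t> dims{2, 3, 4, 5};
  // NCHW: the identity permutation.
  auto nchw = PackedStrides(std::vector<size_t>{0, 1, 2, 3}, dims);
  // NHWC: memory order is N,H,W,C, so C is innermost.
  auto nhwc = PackedStrides(std::vector<size_t>{0, 2, 3, 1}, dims);
  for (auto s : nchw) std::cout << s << ' ';  // prints: 60 20 5 1
  std::cout << "| ";
  for (auto s : nhwc) std::cout << s << ' ';  // prints: 60 1 15 3
  std::cout << '\n';
  return 0;
}

Templating on T is also what lets the legacy path in cudnn_ops.cc below compute strides directly in int, the element type cudnnSetTensorNdDescriptor takes, instead of computing in int64_t and narrowing afterwards.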
src/operator/cudnn_ops.cc (22 changes: 5 additions & 17 deletions)

@@ -29,7 +29,6 @@
 
 #include <dmlc/parameter.h>
 
-#include <algorithm>
 #include <cstdlib>
 #include <iomanip>
 #include <iterator>

@@ -79,10 +78,6 @@ size_t LayoutInfo::ChannelIdx() const {
   return channel_last ? 1 + n_space_dims : 1;
 }
 
-std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const {
-  return PackedStrides(Order(), dims);
-}
-
 LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) {
   static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{
       {mshadow::kNCW, {1, false}},

@@ -165,14 +160,8 @@ Descriptor MakeTensorDesc(int64_t uid,
   for (size_t i = 0; i < dims.size(); ++i)
     dims[i] = blob.shape_[rev_order[i]];
   auto strides = li.Strides(dims);
-  if (li.n_space_dims == 1 && expand_1d) {
-    dims.insert(dims.begin() + 2, 1);
-    std::vector<size_t> order(dims.size());
-    std::iota(order.begin(), order.end(), 0);
-    if (li.channel_last)
-      std::rotate(order.begin() + 1, order.begin() + 2, order.end());
-    strides = PackedStrides(order, dims);
-  }
+  if (expand_1d)
+    li.ExpandIf1d(&dims, &strides);
   return MakeTensorDesc(
       uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual);
 }

@@ -803,9 +792,8 @@ void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const Layo
   auto rev_order = ReverseOrder(li.Order());
   for (size_t i = 0; i < dims.size(); ++i)
     dims[i] = blob.shape_[rev_order[i]];
-  auto strides64 = li.Strides(std::vector<int64_t>(dims.begin(), dims.end()));
-  std::vector<int> strides(strides64.begin(), strides64.end());
-
+  auto strides = li.Strides(dims);
+  li.ExpandIf1d(&dims, &strides);
   auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
   CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
 }

@@ -817,7 +805,7 @@ void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc,
   dims[1] = blob.shape_[0];
   std::vector<int> strides(dims.size(), 1);
   strides[0] = blob.shape_[0];
-
+  li.ExpandIf1d(&dims, &strides);
   auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
   CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
 }
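
The net effect on the legacy path: SetLegacyTensor now computes strides directly via the templated Strides and expands single-spatial-dim shapes to 4-d, since the legacy cudnnSetTensorNdDescriptor API does not accept descriptors with fewer than four dimensions. The sketch below is illustrative, not from the commit: it hard-codes the dims and strides the new code would compute for a hypothetical NCW blob of shape {2, 3, 5}.

#include <cudnn.h>
#include <vector>

int main() {
  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  // NCW {2, 3, 5} expanded to {2, 3, 1, 5}; packed strides for the 4-d shape.
  std::vector<int> dims{2, 3, 1, 5};
  std::vector<int> strides{15, 5, 5, 1};
  cudnnStatus_t status = cudnnSetTensorNdDescriptor(
      desc, CUDNN_DATA_FLOAT, static_cast<int>(dims.size()), dims.data(), strides.data());
  // A 3-d descriptor here would be rejected; the singleton dim keeps it valid.
  cudnnDestroyTensorDescriptor(desc);
  return status == CUDNN_STATUS_SUCCESS ? 0 : 1;
}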
src/operator/cudnn_ops.h (19 changes: 18 additions & 1 deletion)

@@ -29,6 +29,7 @@
 
 #include <mxnet/op_attr_types.h>
 
+#include <algorithm>
 #include <mutex>
 #include <tuple>
 #include <unordered_map>

@@ -89,7 +90,23 @@ struct LayoutInfo {
 
   std::vector<size_t> Order() const;
   size_t ChannelIdx() const;
-  std::vector<int64_t> Strides(const std::vector<int64_t>& dims) const;
+
+  template <typename T>
+  std::vector<T> Strides(const std::vector<T>& dims) const {
+    return cudnn_cxx::PackedStrides(Order(), dims);
+  }
+
+  template <typename T>
+  void ExpandIf1d(std::vector<T>* dims, std::vector<T>* strides) const {
+    if (n_space_dims != 1)
+      return;
+    dims->insert(dims->begin() + 2, 1);
+    std::vector<size_t> order(dims->size());
+    std::iota(order.begin(), order.end(), 0);
+    if (channel_last)
+      std::rotate(order.begin() + 1, order.begin() + 2, order.end());
+    *strides = cudnn_cxx::PackedStrides(order, *dims);
+  }
 };
 
 LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout);
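
To make the expansion concrete, here is a self-contained sketch of the new ExpandIf1d logic, rewritten as a free function over an explicit channel_last flag (the real version is a LayoutInfo member and bails out unless n_space_dims == 1); PackedStrides is inlined again with the CHECK_EQ dropped, and the tensor shape is a made-up example.

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

template <typename T>
std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
  std::vector<T> ret(dims.size(), 1);
  for (size_t i = dims.size() - 1; i--;)
    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
  return ret;
}

// Mirrors LayoutInfo::ExpandIf1d for the n_space_dims == 1 case.
void ExpandIf1d(bool channel_last, std::vector<int>* dims, std::vector<int>* strides) {
  dims->insert(dims->begin() + 2, 1);        // logical N,C,W -> N,C,1,W
  std::vector<size_t> order(dims->size());
  std::iota(order.begin(), order.end(), 0);  // identity = channel-first layout
  if (channel_last)                          // memory order N,H,W,C
    std::rotate(order.begin() + 1, order.begin() + 2, order.end());
  *strides = PackedStrides(order, *dims);    // recompute strides for the 4-d shape
}

int main() {
  // Made-up NCW tensor: N=2, C=3, W=5, packed strides {15, 5, 1}.
  std::vector<int> dims{2, 3, 5};
  std::vector<int> strides{15, 5, 1};
  ExpandIf1d(/*channel_last=*/false, &dims, &strides);
  // dims is now {2, 3, 1, 5} and strides {15, 5, 5, 1}: 4-d, as the legacy API needs.
  for (auto d : dims) std::cout << d << ' ';
  std::cout << "| ";
  for (auto s : strides) std::cout << s << ' ';
  std::cout << '\n';
  return 0;
}

For a channel-last (NWC) input the rotated order {0, 2, 3, 1} yields NHWC-style strides, so the inserted singleton H dim lands in the right memory position in either layout.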
