Skip to content

Commit

Permalink
feat(aten::prelu): Implement the multi-channel version of prelu and
Browse files Browse the repository at this point in the history
broadcasting checks

Signed-off-byL Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 27, 2020
1 parent 8bc4369 commit c066581
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 33 deletions.
5 changes: 4 additions & 1 deletion core/conversion/converters/converters.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ struct Weights {

inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
auto t_weights = Weights(ctx, t);
return ctx->net->addConstant(t_weights.shape, t_weights.data)->getOutput(0);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
const_layer->setName("[Freeze Tensor]");
return const_layer->getOutput(0);
}

} // namespace converters
Expand Down
30 changes: 23 additions & 7 deletions core/conversion/converters/impl/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,34 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
auto in = args[0].ITensor();
auto slopes = args[1].unwrapToTensor();

//if (slopes.numel() != 1) {
// auto in_dims = util::toVec(in.getDimensions());
// auto per_channel_shape = std::vector<int64_t>(in_dims.begin() + 2, in_dims.end());
// for ()
//}
bool to_reshape = false;
auto original_shape = in->getDimensions();
if (slopes.numel() != 1 && !util::broadcastable(in->getDimensions(), util::toDims(slopes.sizes()), /*multidirectional=*/false)) {
if (util::volume(in->getDimensions()) == util::volume(util::toDims(slopes.sizes()))) {
to_reshape = true;
LOG_DEBUG("Input shape is not broadcastable inserting shuffle layers to reshape to " << util::toDims(slopes.sizes()));
auto in_shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(in_shuffle, "Unable to create resize layer for aten::prelu input");
in_shuffle->setReshapeDimensions(util::toDims(slopes.sizes()));
in_shuffle->setName(std::string("[Reshape in to " + util::toStr(util::toDims(slopes.sizes())) + " for broadcasting]").c_str());
in = in_shuffle->getOutput(0);
}
}

auto slope_tensor = tensor_to_const(ctx, slopes);

auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor);
new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
auto out_tensor = new_layer->getOutput(0);

if (to_reshape) {
auto out_shuffle = ctx->net->addShuffle(*out_tensor);
TRTORCH_CHECK(out_shuffle, "Unable to create resize layer for aten::prelu output");
out_shuffle->setReshapeDimensions(original_shape);
out_shuffle->setName((std::string("[Reshape back to ") + util::toStr(original_shape) + std::string("]")).c_str());
out_tensor = out_shuffle->getOutput(0);
}

out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}
Expand Down
90 changes: 65 additions & 25 deletions core/util/trt_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,59 @@ namespace trtorch {
namespace core {
namespace util {

bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional) {
if (a == b) {
return true;
}

if (multidirectional) {
nvinfer1::Dims a_dims_eq;
nvinfer1::Dims b_dims_eq;
if (a.nbDims > b.nbDims) {
a_dims_eq = a;
b_dims_eq = toDimsPad(toVec(b), a.nbDims);
} else if (a.nbDims < b.nbDims) {
a_dims_eq = toDimsPad(toVec(a), b.nbDims);
b_dims_eq = b;
} else {
a_dims_eq = a;
b_dims_eq = b;
}

bool broadcastable = true;
for (int i = 0; i < a_dims_eq.nbDims; i++) {
if (b_dims_eq.d[i] == a_dims_eq.d[i] || (b_dims_eq.d[i] == 1 || a_dims_eq.d[i] == 1)) {
continue;
} else {
broadcastable = false;
break;
}
}
return broadcastable;
} else {
nvinfer1::Dims b_dims_eq;
if (a.nbDims > b.nbDims) {
b_dims_eq = toDimsPad(toVec(b), a.nbDims);
} else if (a.nbDims < b.nbDims) {
return false;
} else {
b_dims_eq = b;
}

bool broadcastable = true;
for (int i = 0; i < a.nbDims; i++) {
if (b_dims_eq.d[i] == a.d[i] || b_dims_eq.d[i] == 1) {
continue;
} else {
broadcastable = false;
break;
}
}
return broadcastable;
}
}


int64_t volume(const nvinfer1::Dims& d) {
return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
}
Expand All @@ -16,10 +69,7 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
return toDims(l);
}

if (pad_to > nvinfer1::Dims::MAX_DIMS) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
}
TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");

nvinfer1::Dims dims;
dims.nbDims = pad_to;
Expand All @@ -34,10 +84,8 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
}

nvinfer1::Dims toDims(c10::IntArrayRef l) {
if (l.size() > nvinfer1::Dims::MAX_DIMS) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
}
TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");

nvinfer1::Dims dims;
dims.nbDims = l.size();
for (size_t i = 0; i < l.size(); i++) {
Expand All @@ -47,10 +95,8 @@ nvinfer1::Dims toDims(c10::IntArrayRef l) {
}

nvinfer1::Dims toDims(c10::List<int64_t> l) {
if (l.size() > nvinfer1::Dims::MAX_DIMS) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
}
TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");

nvinfer1::Dims dims;
dims.nbDims = l.size();
for (size_t i = 0; i < l.size(); i++) {
Expand All @@ -65,10 +111,8 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
return toDims(l);
}

if (pad_to > nvinfer1::Dims::MAX_DIMS) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
}
TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");


nvinfer1::Dims dims;
dims.nbDims = pad_to;
Expand Down Expand Up @@ -109,7 +153,7 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) {
// acceptable range for pos is [0, d.nbDims]
TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds.");

nvinfer1::Dims dims;

int i = 0;
Expand Down Expand Up @@ -148,10 +192,8 @@ std::string toStr(nvinfer1::Dims d) {


nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l) {
if (l.size() != 2) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2");
}
TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2");

nvinfer1::DimsHW dims;
dims.nbDims = l.size();
for (size_t i = 0; i < l.size(); i++) {
Expand All @@ -161,10 +203,8 @@ nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l) {
}

nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l) {
if (l.size() != 2) {
//TODO: Handle this with exceptions or whatever
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2");
}
TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2");

nvinfer1::DimsHW dims;
dims.nbDims = l.size();
for (size_t i = 0; i < l.size(); i++) {
Expand Down
1 change: 1 addition & 0 deletions core/util/trt_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ namespace util {

int64_t volume(const nvinfer1::Dims& d);

bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional=true);
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
Expand Down

0 comments on commit c066581

Please sign in to comment.