Skip to content

Commit

Permalink
replace util::arrToTensor with tensor_to_const and remove addSliceInput
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <[email protected]>
  • Loading branch information
ruoqianguo committed Feb 9, 2021
1 parent 309b701 commit bf5718d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 79 deletions.
1 change: 0 additions & 1 deletion core/conversion/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ cc_library(
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/util:converter_util",
"//core/conversion/var",
"//core/conversion/tensorcontainer",
"//core/conversion/conversionctx",
Expand Down
43 changes: 13 additions & 30 deletions core/conversion/converters/impl/expand.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/converter_util.h"
#include "core/util/prelude.h"
#include "core/util/trt_util.h"
#include "torch/torch.h"
Expand All @@ -16,25 +15,12 @@ namespace converters {
namespace impl {
namespace {

void addSliceInput(nvinfer1::Dims& dims, int idx, ConversionCtx* ctx, nvinfer1::ISliceLayer* slice) {
int32_t rank = static_cast<int32_t>(dims.nbDims);
int32_t* tmp = new int32_t[rank];
for (int i = 0; i < rank; i++)
tmp[i] = dims.d[i];
const nvinfer1::Dims d{1, {rank}};
const nvinfer1::Weights w{nvinfer1::DataType::kINT32, tmp, rank};
auto t = ctx->net->addConstant(d, w)->getOutput(0);
slice->setInput(idx, *t);
}

nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
if (max_rank - old_rank > 0) {
int32_t* tmp = new int32_t[max_rank - old_rank];
for (int i = 0; i < (max_rank - old_rank); i++)
tmp[i] = 1;
auto max_rank_tensor = util::arrToTensor(tmp, max_rank - old_rank, ctx);
torch::Tensor thOne = torch::tensor(std::vector<int32_t>(max_rank - old_rank, 1), torch::kInt32);
auto one_tensor = tensor_to_const(ctx, thOne);
auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
nvinfer1::ITensor* const args[2] = {max_rank_tensor, in_shape_tensor};
nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor};
return ctx->net->addConcatenation(args, 2)->getOutput(0);
} else { // max_rank - old_rank == 0
return ctx->net->addShape(*tensor)->getOutput(0);
Expand Down Expand Up @@ -166,7 +152,6 @@ bool add_expand_dynamic(

// Dimensions are right alignment. Eg: an input of [3, 1] and max_rank = 4, the result of concat is [1, 1, 3, 1]
auto new_input_shape_tensor = concat(max_rank, input_rank, ctx, in);
// LOG_DEBUG("Expand layer output tensor shape: " << new_output_shape_tensor->getDimensions());
auto new_output_shape_tensor = expandedDimsTensor;

// Add a reshape layer to expand dims
Expand All @@ -176,6 +161,8 @@ bool add_expand_dynamic(
// Start the slicing from beginning of tensor since this is an expand layer
std::vector<int64_t> start_vec(max_rank, 0);
nvinfer1::Dims starts_dim = util::toDims(c10::IntArrayRef(start_vec));
at::Tensor thStart = torch::tensor(util::toVec(starts_dim), torch::kInt32);
auto starts = tensor_to_const(ctx, thStart);

// compute sizes = max(x,y).
auto sizes =
Expand All @@ -186,18 +173,17 @@ bool add_expand_dynamic(

// Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
// min(1, sub(input_shape, 1))
int32_t* one_vector_tmp = new int32_t[1];
one_vector_tmp[0] = 1;
auto one_vector = util::arrToTensor(one_vector_tmp, 1, ctx);
auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_vector, nvinfer1::ElementWiseOperation::kSUB)
torch::Tensor thOne = torch::tensor({1}, torch::kInt32);
auto one_tensor = tensor_to_const(ctx, thOne);
auto x_sub_one = ctx->net->addElementWise(*new_input_shape_tensor, *one_tensor, nvinfer1::ElementWiseOperation::kSUB)
->getOutput(0);
auto strides = ctx->net->addElementWise(*one_vector, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
auto strides = ctx->net->addElementWise(*one_tensor, *x_sub_one, nvinfer1::ElementWiseOperation::kMIN)->getOutput(0);
nvinfer1::Dims strides_dim{-1, {}};
strides_dim.nbDims = max_rank;

// Slice layer does the expansion in TRT. Desired output size is specified by expandedDimsTensor
// Slice layer does the expansion in TRT. Desired output size is specified by sizes input at index 2.
auto slice = ctx->net->addSlice(*shuffle->getOutput(0), starts_dim, sizes_dim, strides_dim);
addSliceInput(starts_dim, 1, ctx, slice);
slice->setInput(1, *starts);
slice->setInput(2, *sizes);
slice->setInput(3, *strides);

Expand All @@ -219,11 +205,8 @@ auto expand_registrations TRTORCH_UNUSED =
auto expandedDims = util::toDims(expanded_size);
LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
if (ctx->input_is_dynamic) {
int expanded_size_rank = static_cast<int>(expanded_size.size());
int32_t* tmp = new int32_t[expanded_size_rank];
for (int i = 0; i < expanded_size_rank; i++)
tmp[i] = expanded_size[i];
auto expandedDimsTensor = util::arrToTensor(tmp, expanded_size_rank, ctx);
at::Tensor thExpanded_size = torch::tensor(expanded_size.vec(), torch::kInt32);
auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
} else {
return add_expand(ctx, n, in, expandedDims);
Expand Down
20 changes: 1 addition & 19 deletions core/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,6 @@ cc_library(
})
)

cc_library(
name = "converter_util",
hdrs = [
"converter_util.h",
],
srcs = [
"converter_util.cpp"
],
deps = [
"//core/conversion/conversionctx"
]+ select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
)


load("@rules_pkg//:pkg.bzl", "pkg_tar")

pkg_tar(
Expand All @@ -112,7 +95,6 @@ pkg_tar(
"//core/util:Exception.h",
"//core/util:prelude.h",
"//core/util:jit_util.h",
"//core/util:trt_util.h",
"//core/util:converter_util.h"
"//core/util:trt_util.h"
],
)
15 changes: 0 additions & 15 deletions core/util/converter_util.cpp

This file was deleted.

14 changes: 0 additions & 14 deletions core/util/converter_util.h

This file was deleted.

0 comments on commit bf5718d

Please sign in to comment.