Skip to content

Commit

Permalink
feat(aten::cat): Implements aten::cat and completes support for SSD
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 26, 2020
1 parent 619e345 commit c2d3a6e
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 3 deletions.
3 changes: 2 additions & 1 deletion core/conversion/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ cc_library(
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/conversion/arg",
"//core/conversion/var",
"//core/conversion/tensorcontainer",
"//core/conversion/conversionctx",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
Expand Down
47 changes: 47 additions & 0 deletions core/conversion/converters/impl/concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {
auto cat_registrations = RegisterNodeConversionPatterns()
.pattern({
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto ts = args[0].IValue()->toListRef();
auto dim = args[1].unwrapToInt();

std::vector<nvinfer1::ITensor*> tensors;
for (auto t : ts) {
std::cout << t << std::endl;
if (t.isTensor()) {
auto torch_tensor = t.toTensor();
auto t_weights = Weights(ctx, torch_tensor);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
tensors.push_back(const_layer->getOutput(0));
} else {
auto cont = t.toCustomClass<TensorContainer>();
tensors.push_back(cont->tensor());
}
}

auto cat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
cat_layer->setAxis(static_cast<int>(dim));
auto cat_out = ctx->AssociateValueAndTensor(n->outputs()[0], cat_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << cat_out->getDimensions());

return true;
}
});
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

8 changes: 6 additions & 2 deletions core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ auto prim_registrations = RegisterNodeEvaluators()
auto list = c10::impl::GenericList(elementType);
list.reserve(num_inputs);
for (auto in : n->inputs()) {
auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
list.emplace_back(std::move(x));
if (args.at(in).isITensor()) {
auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
list.emplace_back(std::move(x));
} else {
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
}
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
}
Expand Down
4 changes: 4 additions & 0 deletions tests/core/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ converter_test(
name = "test_batch_norm"
)

converter_test(
name = "test_concat"
)

converter_test(
name = "test_conv_deconv"
)
Expand Down
53 changes: 53 additions & 0 deletions tests/core/converters/test_concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=0]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in1 = at::randint(1, 10, {5}, {at::kCUDA});
auto in2 = at::randint(1, 10, {5}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(5)):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=0]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in1 = at::randint(1, 10, {5}, {at::kCUDA});
auto in2 = at::randint(1, 10, {5}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});

params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 comments on commit c2d3a6e

Please sign in to comment.