diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index 78a7dd26b4..d238e52831 100644
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -33,7 +33,8 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
-        "//core/conversion/arg",
+        "//core/conversion/var",
+        "//core/conversion/tensorcontainer",
         "//core/conversion/conversionctx",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
diff --git a/core/conversion/converters/impl/concat.cpp b/core/conversion/converters/impl/concat.cpp
new file mode 100644
index 0000000000..c109772831
--- /dev/null
+++ b/core/conversion/converters/impl/concat.cpp
@@ -0,0 +1,47 @@
+#include "core/util/prelude.h"
+#include "core/conversion/converters/converters.h"
+#include "core/conversion/tensorcontainer/TensorContainer.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+auto cat_registrations = RegisterNodeConversionPatterns()
+    .pattern({
+        "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto ts = args[0].IValue()->toListRef();
+            auto dim = args[1].unwrapToInt();
+
+            std::vector<nvinfer1::ITensor*> tensors;
+            for (auto t : ts) {
+                std::cout << t << std::endl;
+                if (t.isTensor()) {
+                    auto torch_tensor = t.toTensor();
+                    auto t_weights = Weights(ctx, torch_tensor);
+                    auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
+                    tensors.push_back(const_layer->getOutput(0));
+                } else {
+                    auto cont = t.toCustomClass<TensorContainer>();
+                    tensors.push_back(cont->tensor());
+                }
+            }
+
+            auto cat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
+            cat_layer->setAxis(static_cast<int>(dim));
+            auto cat_out = ctx->AssociateValueAndTensor(n->outputs()[0], cat_layer->getOutput(0));
+
+            LOG_DEBUG("Output tensor shape: " << cat_out->getDimensions());
+
+            return true;
+        }
+    });
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
+
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index b303b3f8e4..932e1094c1 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -74,8 +74,12 @@ auto prim_registrations = RegisterNodeEvaluators()
                 auto list = c10::impl::GenericList(elementType);
                 list.reserve(num_inputs);
                 for (auto in : n->inputs()) {
-                    auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
-                    list.emplace_back(std::move(x));
+                    if (args.at(in).isITensor()) {
+                        auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
+                        list.emplace_back(std::move(x));
+                    } else {
+                        list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+                    }
                 }
                 return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
             }
diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD
index 2a93517e7d..49c88e963a 100644
--- a/tests/core/converters/BUILD
+++ b/tests/core/converters/BUILD
@@ -15,6 +15,10 @@ converter_test(
     name = "test_batch_norm"
 )
 
+converter_test(
+    name = "test_concat"
+)
+
 converter_test(
     name = "test_conv_deconv"
 )
diff --git a/tests/core/converters/test_concat.cpp b/tests/core/converters/test_concat.cpp
new file mode 100644
index 0000000000..359e15eadf
--- /dev/null
+++ b/tests/core/converters/test_concat.cpp
@@ -0,0 +1,53 @@
+#include <string>
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "tests/util/util.h"
+#include "core/compiler.h"
+
+TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1)
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in1 = at::randint(1, 10, {5}, {at::kCUDA});
+    auto in2 = at::randint(1, 10, {5}, {at::kCUDA});
+
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
+
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(5)):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1)
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in1 = at::randint(1, 10, {5}, {at::kCUDA});
+    auto in2 = at::randint(1, 10, {5}, {at::kCUDA});
+
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
+
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
\ No newline at end of file