diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 6ca527b805..b40a620681 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -45,8 +45,15 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
       if (result) {
         // WARN: If the converter returns None then should pass through
         // but if repeated dep this section will get called each time
-        ctx->evaluated_value_map[eval_in] = std::move(result.value());
-        eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
+        auto val = result.value();
+        if (val.isCustomClass()){
+          auto cont = val.toCustomClass<TensorContainer>();
+          ctx->AssociateValueAndTensor(eval_in, cont->tensor());
+          eval_args[eval_in] = ctx->value_tensor_map[eval_in];
+        } else {
+          ctx->AssociateValueAndIValue(eval_in, val);
+          eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
+        }
       }
     } else {
       TRTORCH_THROW_ERROR(
@@ -374,6 +381,11 @@ void ConvertBlockToNetDef(
         } else {
           TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
         }
+      } else if (eval.value().isCustomClass()) {
+        auto container = eval.value().toCustomClass<TensorContainer>();
+        auto tensor = container->tensor();
+        LOG_DEBUG(ctx->logger, "Found the value to be an ITensor of shape: " << tensor->getDimensions());
+        ctx->AssociateValueAndTensor(n->output(0), tensor);
       } else if (!eval.value().isTensor()) {
         LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
         ctx->AssociateValueAndIValue(n->output(0), eval.value());
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 50d99459bd..2f2baece4e 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -216,9 +216,17 @@ auto aten_registrations TRTORCH_UNUSED =
         .evaluator({c10::Symbol::fromQualString("aten::append"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
-                      auto el = args.at(n->input(1)).IValue();
-                      list.push_back(std::move(*el));
+                      if (args.at(n->input(1)).isITensor()) {
+                        auto tensor_holder = TensorContainer();
+                        tensor_holder.hold_tensor(args.at(n->input(1)).ITensor());
+                        auto el = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                        list.push_back(std::move(el));
+                      } else {
+                        auto el = args.at(n->input(1)).IValue();
+                        list.push_back(std::move(*el));
+                      }
+
                       return list;
                     },
                     EvalOptions().validSchemas({
diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
index cc6b0912ce..10fa76f9c8 100644
--- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp
+++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -235,4 +235,85 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
   auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
 
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
-}
\ No newline at end of file
+}
+
+TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=0]()
+        %3 : Tensor[] = prim::ListConstruct(%0)
+        %4 : Tensor[] = aten::append(%3, %1)
+        %5 : Tensor = aten::cat(%4, %2)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 3}, {at::kCUDA});
+  auto in1 = at::randint(1, 10, {3, 3}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0, in1});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0, in1});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Evaluators, ATenAppendWithTensorEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int[] = prim::Constant[value=[3,3]]()
+        %2 : None = prim::Constant() # :0:0
+        %20 : Device = prim::Constant[value="cuda"]()
+        %3 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
+        %4 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
+        %5 : int = prim::Constant[value=0]()
+        %15 : int = prim::Constant[value=1]()
+        %6 : Tensor[] = prim::ListConstruct(%3)
+        %7 : Tensor[] = aten::append(%6, %4)
+        %8 : Tensor = aten::cat(%7, %5)
+        %9 : Tensor = aten::add(%8, %0, %15)
+        return (%9))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {6, 3}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int[] = aten::size(%0)
+        %2 : None = prim::Constant() # :0:0
+        %20 : Device = prim::Constant[value="cuda"]()
+        %3 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
+        %4 : int = prim::Constant[value=0]()
+        %5 : Tensor[] = prim::ListConstruct(%0)
+        %6 : Tensor[] = aten::append(%5, %3)
+        %7 : Tensor = aten::cat(%6, %4)
+        return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 3}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0});
+
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
\ No newline at end of file
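
Note (not part of the patch): the pattern the diff relies on is TensorContainer, TRTorch's torch::CustomClassHolder wrapper that lets a raw nvinfer1::ITensor* travel inside a torch::jit::IValue list between evaluators. A minimal sketch of that wrap/unwrap round trip follows; the header path and namespace for TensorContainer are assumptions based on the repo layout, while the hold_tensor()/tensor() accessors are taken from the diff itself.

// Sketch only, not part of the patch. Assumes TensorContainer lives in
// core/conversion/tensorcontainer/TensorContainer.h under
// trtorch::core::conversion (assumption; accessors are from the diff above).
#include "ATen/core/ivalue.h"
#include "NvInfer.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"

using trtorch::core::conversion::TensorContainer;

// Box a raw ITensor* so it is a legal c10::List<c10::IValue> element,
// mirroring the isITensor() branch of the aten::append evaluator.
c10::IValue wrap_itensor(nvinfer1::ITensor* trt_tensor) {
  auto holder = TensorContainer();
  holder.hold_tensor(trt_tensor);
  return c10::IValue(c10::make_intrusive<TensorContainer>(holder));
}

// Recover the ITensor* on the consuming side, mirroring the new
// isCustomClass() branches in EvaluateNode and ConvertBlockToNetDef.
nvinfer1::ITensor* unwrap_itensor(const c10::IValue& val) {
  TORCH_CHECK(val.isCustomClass(), "expected a TensorContainer custom class");
  return val.toCustomClass<TensorContainer>()->tensor();
}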