Skip to content

Commit

Permalink
feat(//core/conversion): Handle adding and wrapping ITensors as
Browse files Browse the repository at this point in the history
arguments of append and unwrapping singular ITensors as outputs of
evaluators

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 28, 2021
1 parent a7d2b5e commit a22e99b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 4 deletions.
16 changes: 14 additions & 2 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,15 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
if (result) {
// WARN: If the converter returns None then should pass through
// but if repeated dep this section will get called each time
ctx->evaluated_value_map[eval_in] = std::move(result.value());
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
auto val = result.value();
if (val.isCustomClass()){
auto cont = val.toCustomClass<TensorContainer>();
ctx->AssociateValueAndTensor(eval_in, cont->tensor());
eval_args[eval_in] = ctx->value_tensor_map[eval_in];
} else {
ctx->AssociateValueAndIValue(eval_in, val);
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
}
}
} else {
TRTORCH_THROW_ERROR(
Expand Down Expand Up @@ -374,6 +381,11 @@ void ConvertBlockToNetDef(
} else {
TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
}
} else if (eval.value().isCustomClass()) {
auto container = eval.value().toCustomClass<TensorContainer>();
auto tensor = container->tensor();
LOG_DEBUG(ctx->logger, "Found the value to be an ITensor of shape: " << tensor->getDimensions());
ctx->AssociateValueAndTensor(n->output(0), tensor);
} else if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
ctx->AssociateValueAndIValue(n->output(0), eval.value());
Expand Down
12 changes: 10 additions & 2 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,17 @@ auto aten_registrations TRTORCH_UNUSED =
.evaluator({c10::Symbol::fromQualString("aten::append"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
  // Evaluates aten::append(Tensor[] self, Tensor el) at conversion time.
  // c10::List has reference semantics, so pushing onto this local handle
  // mutates the underlying list held in the evaluated-value map.
  auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();

  if (args.at(n->input(1)).isITensor()) {
    // A TensorRT ITensor is not an IValue, so it cannot go into a
    // TorchScript list directly. Wrap it in the TensorContainer custom
    // class so it can travel through the list and be unwrapped later.
    auto tensor_holder = TensorContainer();
    tensor_holder.hold_tensor(args.at(n->input(1)).ITensor());
    // make_intrusive returns a prvalue; no std::move needed (it would
    // only pessimize copy elision).
    auto el = c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
    list.push_back(std::move(el));
  } else {
    // Plain IValue element (e.g. a frozen at::Tensor): append as-is.
    auto el = args.at(n->input(1)).IValue();
    list.push_back(std::move(*el));
  }

  return list;
},
EvalOptions().validSchemas({
Expand Down
81 changes: 81 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,85 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
  // Both list elements are graph inputs, so aten::append receives a TRT
  // ITensor and must wrap it before pushing onto the list.
  const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : Tensor[] = prim::ListConstruct(%0)
%4 : Tensor[] = aten::append(%3, %1)
%5 : Tensor = aten::cat(%4, %2)
return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto input0 = at::randint(1, 10, {3, 3}, {at::kCUDA});
  auto input1 = at::randint(1, 10, {3, 3}, {at::kCUDA});

  auto named_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(g, named_params, {input0, input1});

  // Fresh params for the TensorRT run, mirroring the JIT-run setup.
  named_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(g, named_params, {input0, input1});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}

TEST(Evaluators, ATenAppendWithTensorEvaluatesCorrectly) {
  // Both appended values come from aten::zeros, so the evaluator sees
  // plain IValue tensors rather than TRT ITensors.
  const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int[] = prim::Constant[value=[3,3]]()
%2 : None = prim::Constant() # :0:0
%20 : Device = prim::Constant[value="cuda"]()
%3 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
%4 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
%5 : int = prim::Constant[value=0]()
%15 : int = prim::Constant[value=1]()
%6 : Tensor[] = prim::ListConstruct(%3)
%7 : Tensor[] = aten::append(%6, %4)
%8 : Tensor = aten::cat(%7, %5)
%9 : Tensor = aten::add(%8, %0, %15)
return (%9))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // {6, 3} matches cat of two {3, 3} zero tensors along dim 0.
  auto input0 = at::randint(1, 10, {6, 3}, {at::kCUDA});

  auto named_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(g, named_params, {input0});

  // Fresh params for the TensorRT run, mirroring the JIT-run setup.
  named_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(g, named_params, {input0});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}

TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) {
  // Mixed case: the list starts from a graph input (ITensor) and an
  // aten::zeros result (IValue tensor) is appended to it.
  const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int[] = aten::size(%0)
%2 : None = prim::Constant() # :0:0
%20 : Device = prim::Constant[value="cuda"]()
%3 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
%4 : int = prim::Constant[value=0]()
%5 : Tensor[] = prim::ListConstruct(%0)
%6 : Tensor[] = aten::append(%5, %3)
%7 : Tensor = aten::cat(%6, %4)
return (%7))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto input0 = at::randint(1, 10, {3, 3}, {at::kCUDA});

  auto named_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(g, named_params, {input0});

  // Fresh params for the TensorRT run, mirroring the JIT-run setup.
  named_params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(g, named_params, {input0});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}

0 comments on commit a22e99b

Please sign in to comment.