Skip to content

Commit

Permalink
feat: support aten::extend evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <[email protected]>
  • Loading branch information
ruoqianguo authored and narendasan committed Feb 24, 2022
1 parent b798c7f commit 33c523d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
20 changes: 20 additions & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,26 @@ auto aten_registrations TORCHTRT_UNUSED =
EvalOptions().validSchemas({
"aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
})})
.evaluator({c10::Symbol::fromQualString("aten::extend"),
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              // aten::extend.t(t[](a!) self, t[] other) -> ():
              // mutates `self` in place by appending every element of `other`.
              // The op's schema returns nothing, so the evaluator yields an
              // empty optional on success.
              if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
                auto self = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
                auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
                // Snapshot the size up front; c10::List is a reference type, so
                // if self and other alias the same list this avoids looping forever.
                const int64_t other_size = other.size();

                for (int64_t i = 0; i < other_size; i++) {
                  self.push_back(other.get(i));
                }
              } else {
                TORCHTRT_THROW_ERROR(
                    "Unimplemented data type for aten::extend.t evaluator: "
                    << args.at(n->input(0)).IValue()->type()->str() << ", "
                    << args.at(n->input(1)).IValue()->type()->str());
              }
              // BUGFIX: the original lambda fell off the end of a non-void
              // function on the success path (undefined behavior). Return an
              // empty optional to match the schema's `-> ()`.
              return {};
            },
EvalOptions().validSchemas({
    "aten::extend.t(t[](a!) self, t[] other) -> ()",
})})
.evaluator({c10::Symbol::fromQualString("aten::neg"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto el = args.at(n->input(0)).unwrapToInt();
Expand Down
26 changes: 26 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,32 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, ATenExtendEvaluatesCorrectly) {
  // Builds two single-tensor lists, extends the first with the second, then
  // concatenates the extended list along dim 0 so the evaluator's effect is
  // observable in the graph output.
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %2 : int = prim::Constant[value=0]()
        %3 : Tensor[] = prim::ListConstruct(%0)
        %4 : Tensor[] = prim::ListConstruct(%1)
        aten::extend(%3, %4)
        %5 : Tensor = aten::cat(%3, %2)
        return (%5))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // Different leading dims (3 vs 5) so cat produces a 8x4 result only if
  // extend actually appended the second list's element.
  auto input0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
  auto input1 = at::randint(1, 10, {5, 4}, {at::kCUDA});

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_graph, static_params, {input0, input1});

  // Fresh params for the TRT path: RunGraph may have consumed the first set.
  static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, static_params, {input0, input1});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}

TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
Expand Down

0 comments on commit 33c523d

Please sign in to comment.