Skip to content

Commit

Permalink
feat: support aten::format evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <[email protected]>
  • Loading branch information
ruoqianguo committed Nov 15, 2021
1 parent da15fa5 commit 3a33d33
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
18 changes: 17 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,23 @@ auto aten_registrations TORCHTRT_UNUSED =
},
EvalOptions().validSchemas({
R"SIG(aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!)))SIG",
})});
})})
.evaluator({c10::Symbol::fromQualString("aten::format"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
int64_t input_num = n->inputs().size();
std::vector<torch::jit::IValue> stack;
for (auto v : n->inputs()) {
stack.push_back(*args.at(v).IValue());
}
stack.push_back(input_num);
auto& ops = torch::jit::getAllOperatorsFor(c10::Symbol::fromQualString("aten::format"));
auto& aten_format = ops.front();
aten_format->getOperation()(stack);
std::string output;
torch::jit::pop(stack, output);
return output;
},
EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})});
} // namespace
} // namespace evaluators
} // namespace conversion
Expand Down
34 changes: 34 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,4 +579,38 @@ TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, AtenFormatEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%x_1 : Tensor, %x_2 : Tensor):
%0 : int = prim::Constant[value=1]()
%1 : str = prim::Constant[value="res{}_{}_"]()
%2 : int = prim::Constant[value=5]()
%2.1 : int = prim::Constant[value=2]()
%3 : str = prim::Constant[value="res5_2_"]()
%4 : str = aten::format(%1, %2, %2.1)
%5 : bool = aten::eq(%3, %4)
%y : Tensor = prim::If(%5)
block0():
%194 : Tensor = aten::add(%x_1, %x_2, %0)
-> (%194)
block1():
%195 : Tensor = aten::sub(%x_1, %x_2, %0)
-> (%195)
return (%y))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
auto in1 = in0.clone();

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 comments on commit 3a33d33

Please sign in to comment.