Skip to content

Commit

Permalink
feat: support aten::Int
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <[email protected]>
  • Loading branch information
inocsin committed Jul 28, 2021
1 parent bd72677 commit 5bc977d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
24 changes: 24 additions & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,30 @@ auto aten_registrations TRTORCH_UNUSED =
"aten::Float.int(int a) -> float",
"aten::Float.bool(bool a) -> float",
})})
.evaluator({c10::Symbol::fromQualString("aten::Int"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (args.at(n->input(0)).IValue()->isInt()) {
auto a = args.at(n->input(0)).unwrapToInt();
return (int)a;
} else if (args.at(n->input(0)).IValue()->isDouble()) {
auto a = args.at(n->input(0)).unwrapToDouble();
return (int)a;
} else if (args.at(n->input(0)).IValue()->isBool()) {
auto a = args.at(n->input(0)).unwrapToBool();
return (int)a;
} else {
TRTORCH_THROW_ERROR(
"Unimplemented data type for aten::Int evaluator: "
<< args.at(n->input(0)).IValue()->type()->str());
return {};
}
},
EvalOptions().validSchemas({
"aten::Int.Scalar(Scalar a) -> int",
"aten::Int.int(int a) -> int",
"aten::Int.bool(bool a) -> int",
"aten::Int.float(float a) -> int",
})})
.evaluator({c10::Symbol::fromQualString("aten::__not__"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto el = args.at(n->input(0)).unwrapToBool();
Expand Down
16 changes: 16 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,20 @@ TEST(Evaluators, ATenCopyEvaluatesCorrectly) {
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});

ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}

TEST(Evaluators, IntFloatEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : float = prim::Constant[value=9.3]()
%2 : int = aten::Int(%1)
return (%2))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit 5bc977d

Please sign in to comment.