Skip to content

Commit

Permalink
feat(aten::sqrt): Adding support for sqrt evaluators
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 28, 2021
1 parent 50f012e commit 6aaba3b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
19 changes: 19 additions & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,25 @@ auto aten_registrations TRTORCH_UNUSED =
"aten::floor.int(int a) -> (int)",
"aten::floor.float(float a) -> (int)",
})})
// Compile-time evaluator for aten::sqrt: folds sqrt of an int or float
// scalar constant into a double result; any other IValue type is an error.
.evaluator({c10::Symbol::fromQualString("aten::sqrt"),
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              // Hoist the single operand lookup so the type test and the
              // unwrap both hit the same Var.
              auto& operand = args.at(n->input(0));
              if (operand.IValue()->isInt()) {
                // Widen to double before sqrt so the int overload is never picked.
                return std::sqrt(static_cast<double>(operand.unwrapToInt()));
              }
              if (operand.IValue()->isDouble()) {
                return std::sqrt(operand.unwrapToDouble());
              }
              TRTORCH_THROW_ERROR(
                  "Unimplemented data type for aten::sqrt evaluator: "
                  << operand.IValue()->type()->str());
              return {};
            },
            EvalOptions().validSchemas({
                "aten::sqrt.int(int a) -> (float)",
                "aten::sqrt.float(float a) -> (float)",
            })})
.evaluator({c10::Symbol::fromQualString("aten::warn"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto warning = args.at(n->input(0)).IValue();
Expand Down
31 changes: 31 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,37 @@ TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

// The TRTorch evaluator for aten::sqrt on an int constant must produce the
// same value the TorchScript interpreter does (sqrt(9) == 3.0, exact).
TEST(Evaluators, SqrtIntEvaluatesCorrectly) {
  const auto source = R"IR(
      graph():
        %1 : int = prim::Constant[value=9]()
        %2 : float = aten::sqrt(%1)
        return (%2))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, parsed.get());

  auto interpreter_results = trtorch::tests::util::EvaluateGraphJIT(parsed, {});
  auto evaluator_results = trtorch::tests::util::EvaluateGraph(parsed->block(), {});

  ASSERT_TRUE(interpreter_results[0] == evaluator_results[0]);
}

// Same check as the int case but through the float schema: TRTorch's
// aten::sqrt evaluation of a float constant must match the interpreter.
TEST(Evaluators, SqrtFloatEvaluatesCorrectly) {
  const auto source = R"IR(
      graph():
        %1 : float = prim::Constant[value=9.0]()
        %2 : float = aten::sqrt(%1)
        return (%2))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, parsed.get());

  auto interpreter_results = trtorch::tests::util::EvaluateGraphJIT(parsed, {});
  auto evaluator_results = trtorch::tests::util::EvaluateGraph(parsed->block(), {});

  ASSERT_TRUE(interpreter_results[0] == evaluator_results[0]);
}

TEST(Evaluators, ATenCloneEvaluatesCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
Expand Down

0 comments on commit 6aaba3b

Please sign in to comment.