From 6aaba3bbd51704314edc7ae259ce1658d52c874e Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 24 Jun 2021 15:00:23 -0700 Subject: [PATCH] feat(aten::sqrt): Adding support for sqrt evaluators Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/evaluators/aten.cpp | 19 ++++++++++++ .../evaluators/test_aten_evaluators.cpp | 31 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 514d247ff9..c3c4f5b02b 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -540,6 +540,25 @@ auto aten_registrations TRTORCH_UNUSED = "aten::floor.int(int a) -> (int)", "aten::floor.float(float a) -> (int)", })}) + .evaluator({c10::Symbol::fromQualString("aten::sqrt"), + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + if (args.at(n->input(0)).IValue()->isInt()) { + auto a = args.at(n->input(0)).unwrapToInt(); + return std::sqrt(static_cast(a)); + } else if (args.at(n->input(0)).IValue()->isDouble()) { + auto a = args.at(n->input(0)).unwrapToDouble(); + return std::sqrt(a); + } else { + TRTORCH_THROW_ERROR( + "Unimplemented data type for aten::sqrt evaluator: " + << args.at(n->input(0)).IValue()->type()->str()); + return {}; + } + }, + EvalOptions().validSchemas({ + "aten::sqrt.int(int a) -> (float)", + "aten::sqrt.float(float a) -> (float)", + })}) .evaluator({c10::Symbol::fromQualString("aten::warn"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { auto warning = args.at(n->input(0)).IValue(); diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index 4b71fcda4f..c01aa73650 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -357,6 +357,37 @@ TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) { ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Evaluators, SqrtIntEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=9]() + %2 : float = aten::sqrt(%1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, SqrtFloatEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : float = prim::Constant[value=9.0]() + %2 : float = aten::sqrt(%1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} TEST(Evaluators, ATenCloneEvaluatesCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):