Skip to content

Commit

Permalink
feat(//core/conversion/evaluators): aten::pow support
Browse files Browse the repository at this point in the history
Adds support for the following aten::pow variants in the evaluator
library

```
 	"aten::pow.int(int a, int b) -> (float)",
        "aten::pow.float(float a, float b) -> (float)",
        "aten::pow.int_float(int a, float b) -> (float)",
        "aten::pow.float_int(float a, int b) -> (float)",
```

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 8, 2022
1 parent c5c5c47 commit c4fdfcb
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 1 deletion.
15 changes: 14 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <math.h>

#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
Expand Down Expand Up @@ -98,10 +100,21 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
"aten::ge.float_int(float a, int b) -> (bool)",
}));

DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
pow,
"aten::pow",
pow(a,b),
std::set<std::string>({
"aten::pow.int(int a, int b) -> (float)",
"aten::pow.float(float a, float b) -> (float)",
"aten::pow.int_float(int a, float b) -> (float)",
"aten::pow.float_int(float a, int b) -> (float)",
}));

DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
and,
"aten::__and__",
a&& b,
a && b,
bool,
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
Expand Down
47 changes: 47 additions & 0 deletions core/conversion/evaluators/eval_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,53 @@
}, \
EvalOptions().validSchemas(schemas)});

#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_kind), \
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
if (args.at(n->input(0)).IValue()->isInt()) { \
auto a = args.at(n->input(0)).unwrapToInt(); \
if (args.at(n->input(1)).IValue()->isInt()) { \
auto b = args.at(n->input(1)).unwrapToInt(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isDouble()) { \
auto b = args.at(n->input(1)).unwrapToDouble(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isBool()) { \
auto b = args.at(n->input(1)).unwrapToBool(); \
return operation; \
} else { \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
} \
} else if (args.at(n->input(0)).IValue()->isDouble()) { \
auto a = args.at(n->input(0)).unwrapToDouble(); \
if (args.at(n->input(1)).IValue()->isInt()) { \
auto b = args.at(n->input(1)).unwrapToInt(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isDouble()) { \
auto b = args.at(n->input(1)).unwrapToDouble(); \
return operation; \
} else if (args.at(n->input(1)).IValue()->isBool()) { \
auto b = args.at(n->input(1)).unwrapToBool(); \
return operation; \
} else { \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
} \
} else { \
TORCHTRT_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
return {}; \
} \
}, \
EvalOptions().validSchemas(schemas)});

#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_name), \
Expand Down
68 changes: 68 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,5 +726,73 @@ TEST(Evaluators, RangeLengthNegEvaluatesCorrectly) {
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, PowIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : int = prim::Constant[value=9]()
%2 : int = prim::Constant[value=4]()
%3 : float = aten::pow(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, PowFloatEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : float = prim::Constant[value=9.5]()
%2 : float = prim::Constant[value=4.5]()
%3 : float = aten::pow(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, PowIntFloatEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : int = prim::Constant[value=9]()
%2 : float = prim::Constant[value=4.5]()
%3 : float = aten::pow(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : float = prim::Constant[value=9.5]()
%2 : int = prim::Constant[value=4]()
%3 : float = aten::pow(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit c4fdfcb

Please sign in to comment.