Skip to content

Commit

Permalink
feat: support aten::eq.str evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Ruoqian Guo <[email protected]>
  • Loading branch information
ruoqianguo committed Oct 29, 2021
1 parent f99a6ca commit 5643972
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
"aten::eq.bool(bool a, bool b) -> (bool)",
"aten::eq.int(int a, int b) -> (bool)",
"aten::eq.float(float a, float b) -> (bool)",
"aten::eq.str(str a, str b) -> (bool)",
"aten::eq.int_float(int a, float b) -> (bool)",
"aten::eq.float_int(float a, int b) -> (bool)",
}));
Expand Down
11 changes: 11 additions & 0 deletions core/conversion/evaluators/eval_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
} \
} else if (args.at(n->input(0)).IValue()->isString()) { \
auto a = args.at(n->input(0)).unwrapToString(); \
if (args.at(n->input(1)).IValue()->isString()) { \
auto b = args.at(n->input(1)).unwrapToString(); \
return operation; \
} else { \
TRTORCH_THROW_ERROR( \
"Unimplemented data type for " \
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
return {}; \
} \
} else { \
TRTORCH_THROW_ERROR( \
"Unimplemented data type for " \
Expand Down
2 changes: 2 additions & 0 deletions core/conversion/var/Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class Var : torch::CustomClassHolder {
double unwrapToDouble();
bool unwrapToBool(bool default_val);
bool unwrapToBool();
std::string unwrapToString(std::string default_val);
std::string unwrapToString();
c10::Scalar unwrapToScalar(c10::Scalar default_val);
c10::Scalar unwrapToScalar();
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);
Expand Down
1 change: 1 addition & 0 deletions core/conversion/var/Var_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ DEFINE_UNWRAP_TO(at::Tensor, Tensor)
DEFINE_UNWRAP_TO(int64_t, Int)
DEFINE_UNWRAP_TO(double, Double)
DEFINE_UNWRAP_TO(bool, Bool)
DEFINE_UNWRAP_TO(std::string, String)
DEFINE_UNWRAP_TO(c10::Scalar, Scalar)
DEFINE_UNWRAP_TO(c10::List<int64_t>, IntList)
DEFINE_UNWRAP_TO(c10::List<double>, DoubleList)
Expand Down
34 changes: 34 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,5 +506,39 @@ TEST(Evaluators, ATenIsFloatingPointEvaluatesFalseCorrectly) {
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in_trt});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, EqStrResultIsTrueEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : str = prim::Constant[value="res3"]()
%2 : str = prim::Constant[value="res3"]()
%3 : bool = aten::eq(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : str = prim::Constant[value="res3"]()
%2 : str = prim::Constant[value="res4"]()
%3 : bool = aten::eq(%1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit 5643972

Please sign in to comment.