Skip to content

Commit

Permalink
feat(aten::add): adding string concat evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 8, 2022
1 parent 828d120 commit 65dbf90
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,22 @@ auto aten_registrations TORCHTRT_UNUSED =
auto a = args.at(n->input(0)).unwrapToDouble();
auto b = args.at(n->input(1)).unwrapToDouble();
return a + b;
} else if (args.at(n->input(0)).IValue()->isString()) {
auto a = args.at(n->input(0)).unwrapToString();
auto b = args.at(n->input(1)).unwrapToString();
return a + b;
} else {
TORCHTRT_THROW_ERROR(
"Unimplemented data type for aten::add evaluator: "
<< args.at(n->input(0)).IValue()->type()->str());
return {};
}
},
EvalOptions().validSchemas(
{"aten::add.int(int a, int b) -> (int)", "aten::add.float(float a, float b) -> (float)"})})
EvalOptions().validSchemas({
"aten::add.int(int a, int b) -> (int)",
"aten::add.float(float a, float b) -> (float)",
"aten::add.str(str a, str b) -> (str)"
})})
.evaluator({c10::Symbol::fromQualString("aten::add_"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (args.at(n->input(0)).IValue()->isList()) {
Expand Down

0 comments on commit 65dbf90

Please sign in to comment.