diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 30cdeaa46a..018b565421 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -342,6 +342,10 @@ auto aten_registrations TORCHTRT_UNUSED = auto a = args.at(n->input(0)).unwrapToDouble(); auto b = args.at(n->input(1)).unwrapToDouble(); return a + b; + } else if (args.at(n->input(0)).IValue()->isString()) { + auto a = args.at(n->input(0)).unwrapToString(); + auto b = args.at(n->input(1)).unwrapToString(); + return a + b; } else { TORCHTRT_THROW_ERROR( "Unimplemented data type for aten::add evaluator: " @@ -349,8 +353,11 @@ auto aten_registrations TORCHTRT_UNUSED = return {}; } }, - EvalOptions().validSchemas( - {"aten::add.int(int a, int b) -> (int)", "aten::add.float(float a, float b) -> (float)"})}) + EvalOptions().validSchemas({ + "aten::add.int(int a, int b) -> (int)", + "aten::add.float(float a, float b) -> (float)", + "aten::add.str(str a, str b) -> (str)" + })}) .evaluator({c10::Symbol::fromQualString("aten::add_"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isList()) {