From f216d3f470b9715568acd933624b54baf0aecaa9 Mon Sep 17 00:00:00 2001 From: Abhiram Iyer Date: Mon, 29 Jun 2020 15:07:46 -0700 Subject: [PATCH] feat(): started to work on add_.t evaluator, doesn't work yet Signed-off-by: Abhiram Iyer Signed-off-by: Abhiram Iyer --- core/conversion/evaluators/aten.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 7202b0dae6..007dd259b0 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -10,6 +10,8 @@ #include "core/conversion/evaluators/evaluators.h" #include "core/conversion/evaluators/eval_macros.h" +// #include + namespace trtorch { namespace core { namespace conversion { @@ -243,6 +245,28 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators() "aten::add.int(int a, int b) -> (int)", "aten::add.float(float a, float b) -> (float)" }) + }).evaluator({ + c10::Symbol::fromQualString("aten::add_"), + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + LOG_DEBUG("aten::add_ evaluator is found"); + + // std::raise(SIGINT); + + if (args.at(n->input(0)).IValue()->isList()) { + auto a = args.at(n->input(0)).IValue()->to>(); + auto b = args.at(n->input(1)).IValue()->to>(); + + // incorrect syntax + // for (auto each : b) { + // a.push_back(each); + // } + + return a; + } + }, + EvalOptions().validSchemas({ + "aten::add_.t(t[](a!) self, t[] b) -> (t[])" + }) }).evaluator({ c10::Symbol::fromQualString("aten::mul"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional {