From 47455d11172fd8095b4a1784d48c703f121ba755 Mon Sep 17 00:00:00 2001 From: inocsin Date: Fri, 8 Apr 2022 18:26:41 +0800 Subject: [PATCH] fix: [collection] remove aten::__getitem__ and prim::ListConstruct Signed-off-by: inocsin --- core/conversion/evaluators/aten.cpp | 15 ------- core/conversion/evaluators/prim.cpp | 62 ----------------------------- 2 files changed, 77 deletions(-) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 30cdeaa46a..fde9e71e66 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -264,21 +264,6 @@ auto aten_registrations TORCHTRT_UNUSED = }, EvalOptions().validSchemas( {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) - .evaluator({c10::Symbol::fromQualString("aten::__getitem__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto list = args.at(n->input(0)).IValue()->to>(); - auto idx = args.at(n->input(1)).unwrapToInt(); - - const int64_t list_size = list.size(); - const int64_t normalized_idx = normalizeIndex(idx, list_size); - TORCHTRT_CHECK( - normalized_idx >= 0 || normalized_idx < list_size, - "List index out of range (aten::__getitem__)"); - return list.get(normalized_idx); - }, - EvalOptions().validSchemas({ - "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))", - })}) .evaluator({c10::Symbol::fromQualString("aten::append"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { auto list = args.at(n->input(0)).IValue()->to>(); diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index 5c7209a9f9..7146159a5a 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -40,68 +40,6 @@ auto prim_registrations = auto outputVec = outputs->toList().vec(); return std::move(c10::ivalue::Tuple::create(outputVec)); }}) - .evaluator({torch::jit::prim::ListConstruct, - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - const auto num_inputs = n->inputs().size(); - if (constTypesOnly(args)) { - c10::ListTypePtr lt = n->output()->type()->expect(); - if (torch::jit::IntType::get() == lt->getElementType()) { - c10::List list; - list.reserve(num_inputs); - for (auto in : n->inputs()) { - list.emplace_back(std::move(args.at(in).unwrapToInt())); - } - return c10::optional(std::move(torch::jit::IValue(list))); - } else if (torch::jit::FloatType::get() == lt->getElementType()) { - c10::List list; - list.reserve(num_inputs); - for (auto in : n->inputs()) { - list.emplace_back(std::move(args.at(in).unwrapToDouble())); - } - return c10::optional(std::move(torch::jit::IValue(list))); - } else if (lt->getElementType() == torch::jit::BoolType::get()) { - c10::List list; - list.reserve(num_inputs); - for (auto in : n->inputs()) { - list.emplace_back(std::move(args.at(in).unwrapToBool())); - } - return c10::optional(std::move(torch::jit::IValue(list))); - } else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) { - c10::List list; - list.reserve(num_inputs); - for (auto in : n->inputs()) { - if (args.at(in).isIValue()) { - list.emplace_back(std::move(args.at(in).unwrapToTensor())); - } - } - return c10::optional(std::move(torch::jit::IValue(list))); - } else { - c10::TypePtr elementType = lt->getElementType(); - auto list = c10::impl::GenericList(elementType); - list.reserve(num_inputs); - for (auto in : n->inputs()) { - list.emplace_back(std::move(*(args.at(in).IValue()))); - } - return c10::optional(std::move(torch::jit::IValue(list))); - } - } else { - c10::ListTypePtr lt = n->output()->type()->expect(); - c10::TypePtr elementType = lt->getElementType(); - auto list = c10::impl::GenericList(elementType); - list.reserve(num_inputs); - for (auto in : n->inputs()) { - if (args.at(in).isITensor()) { - auto tensor_holder = TensorContainer(); - tensor_holder.hold_tensor(args.at(in).ITensor()); - auto ival = c10::IValue(std::move(c10::make_intrusive(tensor_holder))); - list.emplace_back(std::move(ival)); - } else { - list.emplace_back(std::move(args.at(in).unwrapToTensor())); - } - } - return c10::optional(std::move(torch::jit::IValue(list))); - } - }}) .evaluator({c10::Symbol::fromQualString("prim::dtype"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { auto input = args.at(n->input(0));