Skip to content

Commit

Permalink
fix: [collection] remove aten::__getitem__ and prim::ListConstruct
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <[email protected]>
  • Loading branch information
inocsin committed Apr 8, 2022
1 parent a206336 commit 47455d1
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 77 deletions.
15 changes: 0 additions & 15 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,21 +264,6 @@ auto aten_registrations TORCHTRT_UNUSED =
},
EvalOptions().validSchemas(
{"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
.evaluator({c10::Symbol::fromQualString("aten::__getitem__"),
            // Evaluates aten::__getitem__ on a constant list at conversion time:
            // returns list[idx], with Python-style negative indices handled by
            // normalizeIndex before the bounds check.
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
              auto idx = args.at(n->input(1)).unwrapToInt();

              const int64_t list_size = list.size();
              const int64_t normalized_idx = normalizeIndex(idx, list_size);
              // BUG FIX: the original condition used `||`, which is true for every
              // non-negative index regardless of list_size, so out-of-range indices
              // were never caught. Both bounds must hold simultaneously.
              TORCHTRT_CHECK(
                  normalized_idx >= 0 && normalized_idx < list_size,
                  "List index out of range (aten::__getitem__)");
              return list.get(normalized_idx);
            },
            EvalOptions().validSchemas({
                "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",
            })})
.evaluator({c10::Symbol::fromQualString("aten::append"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
Expand Down
62 changes: 0 additions & 62 deletions core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,68 +40,6 @@ auto prim_registrations =
auto outputVec = outputs->toList().vec();
return std::move(c10::ivalue::Tuple::create(outputVec));
}})
.evaluator({torch::jit::prim::ListConstruct,
// Evaluates prim::ListConstruct: materializes a c10::List from the node's
// inputs, choosing the concrete element type from the output's ListType so
// downstream evaluators/converters receive a correctly-typed IValue list.
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
const auto num_inputs = n->inputs().size();
// Fast path: every input is a compile-time constant (a plain IValue, no
// live ITensors), so we can build a strongly-typed list directly.
if (constTypesOnly(args)) {
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
if (torch::jit::IntType::get() == lt->getElementType()) {
// int[] list
c10::List<int64_t> list;
list.reserve(num_inputs);
for (auto in : n->inputs()) {
list.emplace_back(std::move(args.at(in).unwrapToInt()));
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
} else if (torch::jit::FloatType::get() == lt->getElementType()) {
// float[] list
c10::List<double> list;
list.reserve(num_inputs);
for (auto in : n->inputs()) {
list.emplace_back(std::move(args.at(in).unwrapToDouble()));
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
} else if (lt->getElementType() == torch::jit::BoolType::get()) {
// bool[] list
c10::List<bool> list;
list.reserve(num_inputs);
for (auto in : n->inputs()) {
list.emplace_back(std::move(args.at(in).unwrapToBool()));
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
} else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
// Tensor[] list.
// NOTE(review): inputs that are not IValues are silently skipped here,
// so the resulting list may be shorter than num_inputs — presumably
// intentional for mixed const/non-const inputs, but worth confirming.
c10::List<at::Tensor> list;
list.reserve(num_inputs);
for (auto in : n->inputs()) {
if (args.at(in).isIValue()) {
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
}
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
} else {
// Any other element type: fall back to a GenericList of raw IValues.
c10::TypePtr elementType = lt->getElementType();
auto list = c10::impl::GenericList(elementType);
list.reserve(num_inputs);
for (auto in : n->inputs()) {
list.emplace_back(std::move(*(args.at(in).IValue())));
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
}
} else {
// Mixed path: at least one input is a live ITensor (a TensorRT tensor).
// Build a GenericList where each ITensor is wrapped in a TensorContainer
// IValue so it can be carried through the IValue-based evaluator machinery;
// remaining constant inputs are unwrapped to at::Tensor.
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
c10::TypePtr elementType = lt->getElementType();
auto list = c10::impl::GenericList(elementType);
list.reserve(num_inputs);
for (auto in : n->inputs()) {
if (args.at(in).isITensor()) {
auto tensor_holder = TensorContainer();
tensor_holder.hold_tensor(args.at(in).ITensor());
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
list.emplace_back(std::move(ival));
} else {
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
}
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
}
}})
.evaluator({c10::Symbol::fromQualString("prim::dtype"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto input = args.at(n->input(0));
Expand Down

0 comments on commit 47455d1

Please sign in to comment.