fix(aten::size, other aten evaluators): Removes aten::size converter in
favor of an evaluator. Also fixes a bunch of bugs with the evaluators

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Jun 3, 2020
1 parent 2cc3226 commit c83447e
Showing 4 changed files with 43 additions and 48 deletions.
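A note on the design this commit leans on: evaluators compute a node's output as a constant IValue at conversion time, while converters emit TensorRT layers. For an op like aten::size, whose result usually feeds integer arithmetic rather than the engine itself, folding the value during conversion avoids emitting layers at all. Below is a minimal, self-contained sketch of that registry pattern; the names NodeEvaluator and registry and the int-only signature are illustrative, not the actual TRTorch API.

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Each op symbol maps to a lambda that folds the node to a constant.
using NodeEvaluator = std::function<int64_t(const std::vector<int64_t>&)>;

int main() {
  std::map<std::string, NodeEvaluator> registry;
  registry["aten::mul"] = [](const std::vector<int64_t>& args) { return args[0] * args[1]; };
  registry["aten::sub"] = [](const std::vector<int64_t>& args) { return args[0] - args[1]; };

  // Folding aten::mul(3, 4) at conversion time yields the constant 12,
  // so no engine layer is needed for it.
  std::cout << registry["aten::mul"]({3, 4}) << std::endl;  // 12
  return 0;
}

The real RegisterNodeEvaluators().evaluator({...}) chain in the diff below does the same thing over full IValues, keyed by c10::Symbol and guarded by validSchemas.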
1 change: 0 additions & 1 deletion core/conversion/converters/BUILD
@@ -25,7 +25,6 @@ cc_library(
         "impl/matrix_multiply.cpp",
         "impl/pooling.cpp",
         "impl/reduce.cpp",
-        "impl/shape.cpp",
         "impl/shuffle.cpp",
         "impl/softmax.cpp",
         "impl/unary.cpp",
32 changes: 0 additions & 32 deletions core/conversion/converters/impl/shape.cpp

This file was deleted.

54 changes: 41 additions & 13 deletions core/conversion/evaluators/aten.cpp
@@ -30,44 +30,44 @@ auto aten_registrations = RegisterNodeEvaluators()
       // aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
         auto options = torch::TensorOptions()
-          .dtype(c10::ScalarType(args.at(&(n->output()[1])).unwrapToInt()))
+          .dtype(c10::ScalarType(args.at(n->output(1)).unwrapToInt()))
           .layout(torch::kStrided)
           .device(torch::kCUDA);
 
-        auto out_tensor = torch::zeros(args.at(&(n->input()[0])).unwrapToIntList().vec(), options);
+        auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
         return out_tensor;
       }
     }).evaluator({
       c10::Symbol::fromQualString("aten::mul"),
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        auto a = args.at(&(n->input()[0])).unwrapToInt();
-        auto b = args.at(&(n->input()[1])).unwrapToInt();
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
         return a * b;
       },
       EvalOptions().validSchemas({"aten::mul.int(int a, int b) -> (int)"})
     }).evaluator({
       c10::Symbol::fromQualString("aten::sub"),
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        auto a = args.at(&(n->input()[0])).unwrapToInt();
-        auto b = args.at(&(n->input()[1])).unwrapToInt();
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
         return a - b;
       },
       EvalOptions().validSchemas({"aten::sub.int(int a, int b) -> (int)"})
     }).evaluator({
       c10::Symbol::fromQualString("aten::__round_to_zero_floordiv"),
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        auto a = args.at(&(n->input()[0])).unwrapToInt();
-        auto b = args.at(&(n->input()[1])).unwrapToInt();
+        auto a = args.at(n->input(0)).unwrapToInt();
+        auto b = args.at(n->input(1)).unwrapToInt();
         return a / b;
       },
       EvalOptions().validSchemas({"aten::__round_to_zero_floordiv(int a, int b) -> (int)"})
     }).evaluator({
       c10::Symbol::fromQualString("aten::slice"),
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
-        int64_t start = args.at(&(n->input()[0])).unwrapToInt();
-        int64_t end = args.at(&(n->input()[0])).unwrapToInt();
-        int64_t step = args.at(&(n->input()[0])).unwrapToInt();
+        c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+        int64_t start = args.at(n->input(1)).unwrapToInt();
+        int64_t end = args.at(n->input(2)).unwrapToInt();
+        int64_t step = args.at(n->input(3)).unwrapToInt();
 
         const int64_t list_size = list.size();
 
@@ -96,10 +96,38 @@ auto aten_registrations = RegisterNodeEvaluators()
     }).evaluator({
       c10::Symbol::fromQualString("aten::len"),
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        c10::List<c10::IValue> list = args.at(&(n->input()[0])).IValue()->to<c10::List<c10::IValue>>();
+        c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
         return static_cast<int64_t>(list.size());
       },
       EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})
+    }).evaluator({
+      c10::Symbol::fromQualString("aten::size"),
+      [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+        LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size");
+        auto tensor_var = args.at(n->input(0));
+        if (n->inputs().size() == 1) {
+          if (tensor_var.isITensor()) {
+            auto tensor = tensor_var.ITensor();
+            return util::toVec(tensor->getDimensions());
+          } else {
+            auto tensor = tensor_var.unwrapToTensor();
+            return tensor.sizes();
+          }
+        } else {
+          auto dim = args.at(n->input(1)).unwrapToInt();
+          if (tensor_var.isITensor()) {
+            auto tensor = tensor_var.ITensor();
+            return util::toVec(tensor->getDimensions())[dim];
+          } else {
+            auto tensor = tensor_var.unwrapToTensor();
+            return tensor.sizes()[dim];
+          }
+        }
+      },
+      EvalOptions().validSchemas({
+        "aten::size(Tensor self) -> (int[])",
+        "aten::size.int(Tensor self, int dim) -> (int)"
+      })
     });
 }
 } // namespace evaluators
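One fix above is easy to miss: the aten::slice evaluator previously read start, end, and step from input 0 (the list itself); it now reads them from inputs 1 through 3. The normalization that follows list_size sits in the collapsed region between the two hunks, so the sketch below fills it in with assumed Python-style clamping; treat those details as an assumption, not the evaluator's verbatim logic.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Slice an int list with start/end/step. The Python-style clamping is an
// assumption standing in for the collapsed region of the diff.
std::vector<int64_t> eval_slice(const std::vector<int64_t>& list,
                                int64_t start, int64_t end, int64_t step) {
  const int64_t list_size = static_cast<int64_t>(list.size());
  if (start < 0) start += list_size;
  if (end < 0) end += list_size;
  start = std::min(std::max<int64_t>(start, 0), list_size);
  end = std::min(std::max<int64_t>(end, 0), list_size);

  std::vector<int64_t> out;
  for (int64_t i = start; i < end; i += step) {
    out.push_back(list[i]);
  }
  return out;
}

int main() {
  // slice([5, 6, 7, 8], start=1, end=3, step=1) == [6, 7]
  assert((eval_slice({5, 6, 7, 8}, 1, 3, 1) == std::vector<int64_t>{6, 7}));
  return 0;
}

For the new aten::size evaluator itself, the static-tensor branch returns exactly what at::Tensor::sizes() reports. A small libtorch check of both registered schemas (assuming a local libtorch build):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto t = torch::zeros({2, 3, 4});
  // aten::size(Tensor self) -> (int[])
  std::cout << t.sizes() << std::endl;     // [2, 3, 4]
  // aten::size.int(Tensor self, int dim) -> (int), here dim=1
  std::cout << t.sizes()[1] << std::endl;  // 3
  return 0;
}

The ITensor branch instead reads the TensorRT dimensions via getDimensions(), which can hold placeholder values under dynamic shape; that is what the LOG_WARNING is guarding against.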
4 changes: 2 additions & 2 deletions core/conversion/evaluators/prim.cpp
@@ -29,7 +29,7 @@ auto prim_registrations = RegisterNodeEvaluators()
     }).evaluator({
       torch::jit::prim::NumToTensor,
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        return at::scalar_to_tensor(args.at(&(n->output()[0])).IValue()->toScalar());
+        return at::scalar_to_tensor(args.at(n->output(0)).IValue()->toScalar());
       }
     }).evaluator({
       torch::jit::prim::ListConstruct,
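The prim::NumToTensor change is the same &(n->output()[0]) to n->output(0) cleanup. What the evaluator returns is a 0-dim tensor wrapping the scalar; a quick libtorch illustration (again assuming a local libtorch build):

#include <torch/torch.h>
#include <iostream>

int main() {
  // What prim::NumToTensor produces for an integer scalar: a 0-dim tensor.
  auto t = at::scalar_to_tensor(c10::Scalar(static_cast<int64_t>(7)));
  std::cout << t.dim() << std::endl;            // 0
  std::cout << t.item<int64_t>() << std::endl;  // 7
  return 0;
}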
@@ -105,7 +105,7 @@ auto prim_registrations = RegisterNodeEvaluators()
     }).evaluator({
       c10::Symbol::fromQualString("prim::min"),
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        auto a = args.at(&(n->input()[0])).unwrapToIntList();
+        auto a = args.at(n->input(0)).unwrapToIntList();
         int64_t min = std::numeric_limits<int64_t>::max();
 
         for (size_t i = 0; i < a.size(); i++) {
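The prim::min evaluator, whose loop body is partially collapsed above, is a linear scan for the smallest element of an int list. A standalone equivalent, with the update inside the loop marked as an assumption since that region is not shown:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Linear scan matching the prim::min evaluator; the comparison-and-update
// step is an assumption, as the loop body is collapsed in the diff.
int64_t eval_prim_min(const std::vector<int64_t>& a) {
  int64_t min = std::numeric_limits<int64_t>::max();
  for (size_t i = 0; i < a.size(); i++) {
    if (a[i] < min) {
      min = a[i];
    }
  }
  return min;
}

int main() {
  assert((eval_prim_min({4, 2, 7}) == 2));
  return 0;
}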
