Skip to content

Commit

Permalink
fix: fix the bug that introduces kLong Tensor in prim::NumToTensor
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Apr 11, 2022
1 parent 609a697 commit 2c3e1d9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
32 changes: 32 additions & 0 deletions core/conversion/evaluators/eval_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,38 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
}
}

at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device = at::kCPU) {
// This function is basically same with the one in
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float
// won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion
if (device == at::kCPU) {
if (s.isFloatingPoint()) {
LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kInt in scalar_to_tensor_util ");
return at::detail::scalar_tensor_static(s, at::kFloat, at::kCPU);
} else if (s.isComplex()) {
return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
} else if (s.isBoolean()) {
return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
} else {
AT_ASSERT(s.isIntegral(false));
LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util ");
return at::detail::scalar_tensor_static(s, at::kInt, at::kCPU);
}
}
if (s.isFloatingPoint()) {
LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kInt in scalar_to_tensor_util ");
return at::scalar_tensor(s, at::device(device).dtype(at::kFloat));
} else if (s.isBoolean()) {
return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
} else if (s.isComplex()) {
return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
} else {
AT_ASSERT(s.isIntegral(false));
LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util ");
return at::scalar_tensor(s, at::device(device).dtype(at::kInt));
}
}

template <typename DTYPE>
void storeLastDimension(
char* data,
Expand Down
2 changes: 2 additions & 0 deletions core/conversion/evaluators/eval_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ at::Tensor createTensorFromList(
const torch::jit::IValue& dtype,
const torch::jit::IValue& device);

at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device = at::kCPU);

} // namespace evaluators
} // namespace conversion
} // namespace core
Expand Down
2 changes: 1 addition & 1 deletion core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ auto prim_registrations =
}})
.evaluator({torch::jit::prim::NumToTensor,
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return at::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
return scalar_to_tensor_util(args.at(n->input(0)).IValue()->toScalar());
}})
.evaluator({torch::jit::prim::ListUnpack,
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
Expand Down

0 comments on commit 2c3e1d9

Please sign in to comment.