From 2c3e1d922c30797a5b9db5b1721a09079316cf73 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Mon, 11 Apr 2022 12:02:42 -0700
Subject: [PATCH] fix: fix the bug that introduces kLong Tensor in prim::NumToTensor

Signed-off-by: Bo Wang
---
 core/conversion/evaluators/eval_util.cpp | 32 ++++++++++++++++++++++++
 core/conversion/evaluators/eval_util.h   |  2 ++
 core/conversion/evaluators/prim.cpp      |  2 +-
 3 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index cb294fc1c8..77e9715212 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -119,6 +119,38 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
   }
 }
 
+at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device = at::kCPU) {
+  // This function is basically the same as the one in
+  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h; the difference is that Float and Int
+  // won't be upgraded to kDouble or kLong here, since we don't support those two types in conversion
+  if (device == at::kCPU) {
+    if (s.isFloatingPoint()) {
+      LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util");
+      return at::detail::scalar_tensor_static(s, at::kFloat, at::kCPU);
+    } else if (s.isComplex()) {
+      return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
+    } else if (s.isBoolean()) {
+      return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
+    } else {
+      AT_ASSERT(s.isIntegral(false));
+      LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util");
+      return at::detail::scalar_tensor_static(s, at::kInt, at::kCPU);
+    }
+  }
+  if (s.isFloatingPoint()) {
+    LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util");
+    return at::scalar_tensor(s, at::device(device).dtype(at::kFloat));
+  } else if (s.isBoolean()) {
+    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
+  } else if (s.isComplex()) {
+    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
+  } else {
+    AT_ASSERT(s.isIntegral(false));
+    LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util");
+    return at::scalar_tensor(s, at::device(device).dtype(at::kInt));
+  }
+}
+
 template <typename DTYPE>
 void storeLastDimension(
     char* data,
diff --git a/core/conversion/evaluators/eval_util.h b/core/conversion/evaluators/eval_util.h
index 0a8b563e3c..ecec1ab210 100644
--- a/core/conversion/evaluators/eval_util.h
+++ b/core/conversion/evaluators/eval_util.h
@@ -13,6 +13,8 @@ at::Tensor createTensorFromList(
     const torch::jit::IValue& dtype,
     const torch::jit::IValue& device);
 
+at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device = at::kCPU);
+
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index 5c7209a9f9..7ef3332d9b 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -31,7 +31,7 @@ auto prim_registrations =
         }})
         .evaluator({torch::jit::prim::NumToTensor,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      return at::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
+                      return scalar_to_tensor_util(args.at(n->input(0)).IValue()->toScalar());
                     }})
         .evaluator({torch::jit::prim::ListUnpack,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
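A minimal standalone sketch (illustrative only, not part of the patch) of the promotion behavior this change works around: stock `at::scalar_to_tensor` promotes integral scalars to kLong and floating-point scalars to kDouble, neither of which the conversion stage supports, whereas the new `scalar_to_tensor_util` truncates them to kInt and kFloat. It assumes only a standard ATen build and does not call the helper itself:

```cpp
// Repro of the dtype promotion that previously leaked kLong tensors
// out of the prim::NumToTensor evaluator.
#include <ATen/ATen.h>
#include <ATen/ScalarOps.h>
#include <iostream>

int main() {
  at::Scalar i(static_cast<int64_t>(7)); // e.g. a Python int reaching NumToTensor
  at::Scalar d(0.5);                     // e.g. a Python float

  // Stock ATen behavior: integral scalars become kLong, floating kDouble.
  std::cout << at::scalar_to_tensor(i).scalar_type() << "\n"; // Long
  std::cout << at::scalar_to_tensor(d).scalar_type() << "\n"; // Double

  // The patched evaluator instead emits kInt / kFloat for these inputs
  // (with a truncation warning), since conversion rejects kLong / kDouble.
  return 0;
}
```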