diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index ad403fc500..514d247ff9 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -383,6 +383,31 @@ auto aten_registrations TRTORCH_UNUSED =
                 "aten::Float.int(int a) -> float",
                 "aten::Float.bool(bool a) -> float",
             })})
+      .evaluator({c10::Symbol::fromQualString("aten::Int"),
+                  // Evaluates aten::Int at conversion time: unwraps an int/float/bool input and truncates it to int.
+                  [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                    if (args.at(n->input(0)).IValue()->isInt()) {
+                      auto a = args.at(n->input(0)).unwrapToInt();
+                      return (int)a;
+                    } else if (args.at(n->input(0)).IValue()->isDouble()) {
+                      auto a = args.at(n->input(0)).unwrapToDouble();
+                      return (int)a;
+                    } else if (args.at(n->input(0)).IValue()->isBool()) {
+                      auto a = args.at(n->input(0)).unwrapToBool();
+                      return (int)a;
+                    } else {
+                      TRTORCH_THROW_ERROR(
+                          "Unimplemented data type for aten::Int evaluator: "
+                          << args.at(n->input(0)).IValue()->type()->str());
+                      return {};
+                    }
+                  },
+                  EvalOptions().validSchemas({
+                      "aten::Int.Scalar(Scalar a) -> int",
+                      "aten::Int.int(int a) -> int",
+                      "aten::Int.bool(bool a) -> int",
+                      "aten::Int.float(float a) -> int",
+                  })})
       .evaluator({c10::Symbol::fromQualString("aten::__not__"),
                   [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                     auto el = args.at(n->input(0)).unwrapToBool();
diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
index a29e453d1c..4b71fcda4f 100644
--- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp
+++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -399,4 +399,20 @@ TEST(Evaluators, ATenCopyEvaluatesCorrectly) {
   auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
 
   ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+}
+
+TEST(Evaluators, IntFloatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : float = prim::Constant[value=9.3]()
+        %2 : int = aten::Int(%1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }
\ No newline at end of file