From 014e3811e980a37b8360294ec2e7ec3e18021d79 Mon Sep 17 00:00:00 2001 From: inocsin Date: Fri, 19 Mar 2021 14:09:44 +0800 Subject: [PATCH] feat: support aten::arange converter Signed-off-by: inocsin --- core/conversion/evaluators/aten.cpp | 58 +++++++++- .../evaluators/test_aten_evaluators.cpp | 103 ++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 587bfe5b6e..5bb2c5cf9c 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -467,7 +467,63 @@ auto aten_registrations TRTORCH_UNUSED = LOG_WARNING("Warning from TorchScript: " << *warning); return {}; }, - EvalOptions()}); + EvalOptions()}) + .evaluator({c10::Symbol::fromQualString("aten::arange"), + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + // int end_scalar = 0; + // auto end_scalar = ceil(args.at(n->input(0)).unwrapToScalar()); + int input_size = n->inputs().size(); + int scalar_count = 0; + for (int i = 0; i < input_size; i++) { + if (args.at(n->input(i)).IValue()->isScalar()) { + scalar_count += 1; + } + } + if (scalar_count == 1) { + if (args.at(n->input(0)).IValue()->isInt()) { + int end_scalar = args.at(n->input(0)).unwrapToInt(); + return torch::arange(end_scalar); + } else if (args.at(n->input(0)).IValue()->isDouble()) { + float end_scalar = ceil(args.at(n->input(0)).unwrapToScalar().to()); + return torch::arange(end_scalar); + } + } else if (scalar_count == 2) { + if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble()) { + float start_scalar = args.at(n->input(0)).unwrapToScalar().to(); + float end_scalar = args.at(n->input(1)).unwrapToScalar().to(); + return torch::arange(start_scalar, end_scalar); + } else { + int start_scalar = args.at(n->input(0)).unwrapToInt(); + int end_scalar = args.at(n->input(1)).unwrapToInt(); + return torch::arange(start_scalar, end_scalar); + } + } else if (scalar_count == 3) { + if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble() || + args.at(n->input(2)).IValue()->isDouble()) { + float start_scalar = args.at(n->input(0)).unwrapToScalar().to(); + float end_scalar = args.at(n->input(1)).unwrapToScalar().to(); + float step_scalar = args.at(n->input(2)).unwrapToScalar().to(); + return torch::arange(start_scalar, end_scalar, step_scalar); + } else { + int start_scalar = args.at(n->input(0)).unwrapToInt(); + int end_scalar = args.at(n->input(1)).unwrapToInt(); + int step_scalar = args.at(n->input(2)).unwrapToInt(); + return torch::arange(start_scalar, end_scalar, step_scalar); + } + } else { + TRTORCH_THROW_ERROR( + "Invalid input argument size for aten::arange, input argument size: " << input_size); + } + return {}; + }, + EvalOptions().validSchemas({ + R"SIG(aten::arange(Scalar end, *, int? dtype=None, int? layout=None, + Device? device=None, bool? pin_memory=None) -> (Tensor))SIG", + R"SIG(aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, + Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG", + R"SIG(aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, + Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG", + })}); } // namespace } // namespace evaluators } // namespace conversion diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index d7a237184d..a142bd0700 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -75,4 +75,107 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) { auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in}); ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor())); +} + +TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : int = prim::Constant[value=51]() + %1 : None = prim::Constant() + %2 : Tensor = aten::arange(%0, %1, %1, %1, %1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); +} + +TEST(Evaluators, ATenArangeFloatEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : float = prim::Constant[value=51.2]() + %1 : None = prim::Constant() + %2 : Tensor = aten::arange(%0, %1, %1, %1, %1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); +} + +TEST(Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : int = prim::Constant[value=1]() + %1 : int = prim::Constant[value=51]() + %2 : None = prim::Constant() + %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); +} + +TEST(Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : float = prim::Constant[value=1.5]() + %1 : float = prim::Constant[value=51.2]() + %2 : None = prim::Constant() + %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); +} + +TEST(Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : int = prim::Constant[value=1]() + %1 : int = prim::Constant[value=51]() + %2 : int = prim::Constant[value=1]() + %3 : None = prim::Constant() + %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); +} + +TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : float = prim::Constant[value=1.2]() + %1 : float = prim::Constant[value=51.6]() + %2 : float = prim::Constant[value=1.5]() + %3 : None = prim::Constant() + %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); } \ No newline at end of file