diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index b24222be26..1f16f0f575 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -172,6 +172,53 @@ auto aten_registrations TORCHTRT_UNUSED = auto out_tensor = torch::full(args.at(n->input(0)).unwrapToIntList().vec(), scalar_value, options); return out_tensor; }}) + .evaluator( + {c10::Symbol::fromQualString("aten::full_like"), + // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, + // Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor) + [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { + // Override options related to layout and device for TensorRT + auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA); + auto input_tensor_var = args.at(n->input(0)); + + std::vector<int64_t> input_shape; + c10::ScalarType input_dtype; + + // Extract data type and shape of input tensor + if (input_tensor_var.isITensor()) { + auto tensor = input_tensor_var.ITensor(); + input_shape = util::toVec(tensor->getDimensions()); + input_dtype = util::TRTDataTypeToScalarType(tensor->getType()); + } else if (input_tensor_var.IValue()->isTensor()) { + auto tensor = input_tensor_var.unwrapToTensor(); + input_shape = tensor.sizes().vec(); + input_dtype = tensor.scalar_type(); + } else if (input_tensor_var.IValue()->isCustomClass()) { + auto tensor = input_tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor(); + input_shape = util::toVec(tensor->getDimensions()); + input_dtype = util::TRTDataTypeToScalarType(tensor->getType()); + } else { + TORCHTRT_THROW_ERROR( + "Invalid IValue type. IValue is not some class of torch::Tensor or nvinfer1::ITensor. 
Found: " + << input_tensor_var.IValue()->type()); + } + + // If specified, use third input arg to determine data type, otherwise default to input tensor data type + if (!args.at(n->input(2)).isNone() && !args.at(n->input(2)).IValue()->isNone()) { + options = options.dtype(c10::ScalarType(args.at(n->input(2)).unwrapToInt())); + } else { + options = options.dtype(input_dtype); + } + + // Generate full tensor with specified input options + auto scalar_value = args.at(n->input(1)).unwrapToScalar(); + auto out_tensor = torch::full(input_shape, scalar_value, options); + return out_tensor; + }, + EvalOptions().validSchemas( + {R"SIG(aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, + Layout? layout=None, Device? device=None, bool? pin_memory=None, + MemoryFormat? memory_format=None) -> (Tensor))SIG"})}) .evaluator( {c10::Symbol::fromQualString("aten::slice"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { @@ -821,4 +868,4 @@ auto aten_registrations TORCHTRT_UNUSED = } // namespace evaluators } // namespace conversion } // namespace core -} // namespace torch_tensorrt \ No newline at end of file +} // namespace torch_tensorrt diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index 78c60db859..16936ddd6a 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -83,6 +83,71 @@ TEST(Evaluators, FullEvaluatesCorrectly) { ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor())); } +TEST(Evaluators, FullLikeEvaluatesCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %9 : None = prim::Constant() + %13 : float = prim::Constant[value=1.3]() + %14 : int = prim::Constant[value=4]() + %35 : Device = prim::Constant[value="cuda:0"]() + %19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9) + return (%19))IR"; + + auto in = 
at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA}); + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in}); + + ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor())); + ASSERT_TRUE(jit_results[0].toTensor().dtype() == trt_results[0].toTensor().dtype()); +} + +TEST(Evaluators, FullLikeNewDtypeEvaluatesCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %9 : None = prim::Constant() + %13 : Scalar = prim::Constant[value=1]() + %14 : int = prim::Constant[value=11]() + %35 : Device = prim::Constant[value="cuda:0"]() + %19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9) + return (%19))IR"; + + auto in = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA}); + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in}); + + ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor())); + ASSERT_TRUE(jit_results[0].toTensor().dtype() == trt_results[0].toTensor().dtype()); +} + +TEST(Evaluators, FullLikeOldDtypeEvaluatesCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %9 : None = prim::Constant() + %13 : Scalar = prim::Constant[value=1.5]() + %35 : Device = prim::Constant[value="cuda:0"]() + %19 : Tensor = aten::full_like(%x.1, %13, %9, %9, %35, %9, %9) + return (%19))IR"; + + auto in = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA}).to(torch::kInt32); + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in}); + + 
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor())); + ASSERT_TRUE(jit_results[0].toTensor().dtype() == trt_results[0].toTensor().dtype()); +} + TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor):