diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp
index 52eb17a854..6357f4922e 100644
--- a/core/conversion/converters/impl/unary.cpp
+++ b/core/conversion/converters/impl/unary.cpp
@@ -34,6 +34,36 @@ auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().
       return true;
     }});
 
+auto log2_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::log2(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       const static float ln2 = 0.693147180559945309; // same constant onnx uses
+       auto in = args[0].ITensorOrFreeze(ctx);
+       auto tensor_type = util::TRTDataTypeToScalarType(in->getType());
+       if (in->getType() == nvinfer1::DataType::kINT32) {
+         // pytorch implicitly casts to float for aten::log2(int)
+         in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
+         tensor_type = at::kFloat;
+       }
+
+       auto log_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kLOG);
+       TORCHTRT_CHECK(log_layer, "Unable to create log layer from node: " << *n);
+       log_layer->setName((util::node_info(n) + "_log").c_str());
+
+       std::vector<int64_t> ln2_dims(in->getDimensions().nbDims, 1);
+       auto ln2_tensor = at::full(ln2_dims, ln2, at::TensorOptions().dtype(tensor_type));
+       auto ln2_itensor = converters::tensor_to_const(ctx, ln2_tensor);
+
+       auto div_layer = add_elementwise(
+           ctx,
+           nvinfer1::ElementWiseOperation::kDIV,
+           log_layer->getOutput(0),
+           ln2_itensor,
+           (util::node_info(n) + "_div").c_str());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], div_layer->getOutput(0));
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       return true;
+     }});
+
 auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
     {"aten::logical_not(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
diff --git a/tests/core/conversion/converters/test_unary.cpp b/tests/core/conversion/converters/test_unary.cpp
index 2c10e40b42..858fdc69a2 100644
--- a/tests/core/conversion/converters/test_unary.cpp
+++ b/tests/core/conversion/converters/test_unary.cpp
@@ -47,6 +47,21 @@ TEST(Converters, ATenReciprocalIntConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
 }
 
+TEST(Converters, ATenLog2IntConvertsCorrectly) {
+  const auto graph = gen_test_graph("log2");
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::tensor({1, 2, 7, 25, 50}, {at::kCUDA}).to(torch::kInt32);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
 TEST(Converters, ATenSignConvertsCorrectly) {
   const auto graph = gen_test_graph("sign");
   auto g = std::make_shared<torch::jit::Graph>();
@@ -129,6 +144,7 @@ test_unary(abs, Abs);
 test_unary(floor, Floor);
 test_unary(reciprocal, Reciprocal);
 test_unary(log, Log);
+test_unary(log2, Log2);
 test_unary(ceil, Ceil);
 test_unary(sqrt, Sqrt);
 test_unary(exp, Exp);
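
Note (not part of the patch): the converter lowers aten::log2 via the change-of-base identity log2(x) = ln(x) / ln(2), since TensorRT's unary layer provides a natural-log operation (kLOG) but no base-2 variant; the division by the ln2 constant is done with a broadcasted 1-filled constant tensor. The standalone sketch below (plain C++, no TensorRT) checks that identity against std::log2 on the same input values the new test uses; it is an illustration of the math only, not code from the patch.

// Standalone sketch: verify log2(x) == ln(x) / ln(2), the identity
// the converter relies on, using the converter's ln2 constant and
// the test's input values {1, 2, 7, 25, 50}.
#include <cmath>
#include <cstdio>

int main() {
  const float ln2 = 0.693147180559945309f; // constant from the converter
  const float inputs[] = {1.0f, 2.0f, 7.0f, 25.0f, 50.0f};
  for (float x : inputs) {
    float via_ln = std::log(x) / ln2; // mirrors the kLOG + kDIV lowering
    float direct = std::log2(x);      // reference result
    std::printf("x=%5.1f  ln(x)/ln2=%.6f  log2(x)=%.6f  |diff|=%.2e\n",
                x, via_ln, direct, std::fabs(via_ln - direct));
  }
  return 0;
}

The residual differences are on the order of float rounding error, which is why the new test asserts almostEqual rather than the exactlyEqual used by integer-preserving converters such as reciprocal above.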