diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index 43b30857ee..ccc72bb4cb 100644
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -17,6 +17,7 @@ cc_library(
         "impl/linear.cpp",
         "impl/pooling.cpp",
         "impl/reduce.cpp",
+        "impl/shuffle.cpp",
         "impl/softmax.cpp",
         "impl/unary.cpp",
     ],
diff --git a/core/conversion/converters/converters.h b/core/conversion/converters/converters.h
index 465c952ceb..6c151b061a 100644
--- a/core/conversion/converters/converters.h
+++ b/core/conversion/converters/converters.h
@@ -6,6 +6,7 @@
 #include "torch/csrc/jit/runtime/custom_operator.h"
 #include "ATen/core/function_schema.h"
 
+#include "core/util/prelude.h"
 #include "core/conversion/conversionctx/ConversionCtx.h"
 
 namespace trtorch {
diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
new file mode 100644
index 0000000000..6c262b89a9
--- /dev/null
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -0,0 +1,33 @@
+#include "core/conversion/converters/converters.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+static auto shuffle_registrations = RegisterNodeConversionPatterns()
+    .pattern({
+        "aten::reshape(Tensor self, int[] shape) -> (Tensor)",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto in = args[0].ITensor();
+            auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);
+
+            auto shuffle = ctx->net->addShuffle(*in);
+            TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+            shuffle->setReshapeDimensions(new_shape);
+            shuffle->setName(util::node_info(n).c_str());
+
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+            return true;
+        }
+    });
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
diff --git a/core/conversion/converters/impl/softmax.cpp b/core/conversion/converters/impl/softmax.cpp
index c138c759f0..35f6f04ef1 100644
--- a/core/conversion/converters/impl/softmax.cpp
+++ b/core/conversion/converters/impl/softmax.cpp
@@ -1,4 +1,3 @@
-#include "core/util/prelude.h"
 #include "core/conversion/converters/converters.h"
 
 namespace trtorch {
@@ -29,12 +28,7 @@
 
         auto softmax = ctx->net->addSoftMax(*in);
         TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n);
-
-        if (!softmax) {
-            LOG_ERROR("Unable to create softmax layer from node: " << *n);
-            return false;
-        }
-        LOG_WARNING("Disregarding dtype argument, please verify");
+        LOG_DEBUG("Disregarding dtype argument");
 
         if (shape.size() > 3) {
             softmax->setAxes(1 << (dim));
@@ -69,4 +63,4 @@ static auto softmax_registrations = RegisterNodeConversionPatterns()
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // trtorch
+} // namespace trtorch
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp
index 89214e5efd..2f6706c51a 100644
--- a/core/util/trt_util.cpp
+++ b/core/util/trt_util.cpp
@@ -59,6 +59,29 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) {
     return dims;
 }
 
+nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
+    if (l.size() > pad_to) {
+        LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad");
+        return toDims(l);
+    }
+
+    if (pad_to > nvinfer1::Dims::MAX_DIMS) {
+        //TODO: Handle this with exceptions or whatever
+        LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
+    }
+
+    nvinfer1::Dims dims;
+    dims.nbDims = pad_to;
+    for (size_t i = 0; i < pad_to - l.size(); i++) {
+        dims.d[i] = 1;
+    }
+
+    for (size_t i = pad_to - l.size(); i < pad_to; i++) {
+        dims.d[i] = l[i - (pad_to - l.size())];
+    }
+    return dims;
+}
+
 std::vector<int64_t> toVec(nvinfer1::Dims d) {
     std::vector<int64_t> dims;
     for (int i = 0; i < d.nbDims; i++) {
diff --git a/core/util/trt_util.h b/core/util/trt_util.h
index bf8ea5b224..09cf5ff418 100644
--- a/core/util/trt_util.h
+++ b/core/util/trt_util.h
@@ -78,6 +78,7 @@ namespace util {
 
 int64_t volume(const nvinfer1::Dims& d);
 nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
+nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
 nvinfer1::Dims toDims(c10::IntArrayRef l);
 nvinfer1::Dims toDims(c10::List<int64_t> l);
 nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);
diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD
index 8cbe9c4e68..39d978afa9 100644
--- a/tests/core/converters/BUILD
+++ b/tests/core/converters/BUILD
@@ -4,6 +4,10 @@ converter_test(
     name = "test_softmax"
 )
 
+converter_test(
+    name = "test_shuffle"
+)
+
 converter_test(
     name = "test_activation"
 )
@@ -36,6 +40,7 @@ test_suite(
     name = "test_converters",
     tests = [
         ":test_softmax",
+        ":test_shuffle",
         ":test_activation",
         ":test_pooling",
         ":test_unary",
diff --git a/tests/core/converters/test_shuffle.cpp b/tests/core/converters/test_shuffle.cpp
new file mode 100644
index 0000000000..459d271d05
--- /dev/null
+++ b/tests/core/converters/test_shuffle.cpp
@@ -0,0 +1,29 @@
+#include <string>
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "tests/util/util.h"
+#include "core/compiler.h"
+
+TEST(Converters, ATenReshapeConvertsCorrectly) {
+    const auto graph = R"IR(
+        graph(%0 : Tensor):
+            %1 : int = prim::Constant[value=3]()
+            %2 : int = prim::Constant[value=2]()
+            %3 : int[] = prim::ListConstruct(%1, %2)
+            %4 : Tensor = aten::reshape(%0, %3)
+            return (%4))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+    in = at::clone(in);
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+    auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
\ No newline at end of file
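
Note on the padding behavior added in trt_util.cpp above: util::toDimsPad left-pads a dimension list with 1s until it has pad_to entries, which is how the aten::reshape converter guarantees TensorRT a reshape target of at least 2 dimensions. The sketch below is an illustration only, not part of the change set: plain std::vector<int64_t> stands in for c10::List<int64_t> and nvinfer1::Dims so it compiles without the ATen or TensorRT headers, and pad_dims is a hypothetical helper name.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustration of the left-padding logic in toDimsPad: prepend 1s until the
// shape has pad_to entries; if it is already long enough, return it unchanged.
std::vector<int64_t> pad_dims(const std::vector<int64_t>& l, uint64_t pad_to) {
    if (l.size() >= pad_to) {
        return l;
    }
    std::vector<int64_t> dims(pad_to, 1);                   // leading dims filled with 1
    std::copy(l.begin(), l.end(), dims.end() - l.size());   // original dims go at the back
    return dims;
}

int main() {
    for (auto d : pad_dims({6}, 2)) {
        std::cout << d << " ";  // prints: 1 6
    }
    std::cout << std::endl;
    return 0;
}

Under this reading, a 1-D reshape target such as [6] reaches the shuffle layer as [1, 6], matching the pad_to value of 2 passed by the converter.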