From 164a1a6b75d8ee7d7e524eabe5894df5a8905bff Mon Sep 17 00:00:00 2001 From: Abhiram Iyer Date: Thu, 25 Jun 2020 17:03:17 -0700 Subject: [PATCH] feat(trt_util): from Naren, added unpadDims tool Signed-off-by: Abhiram Iyer Signed-off-by: Abhiram Iyer --- core/util/trt_util.cpp | 24 ++++++++++++++++++++++++ core/util/trt_util.h | 1 + 2 files changed, 25 insertions(+) diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 2f6706c51a..604514ff11 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -82,6 +82,30 @@ nvinfer1::Dims toDimsPad(c10::List l, uint64_t pad_to) { return dims; } +nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) { + nvinfer1::Dims dims; + + int j = 0; + bool pad_dims_done = false; + + for (int i = 0; i < d.nbDims; i++) { + if (d.d[i] == 1 && !pad_dims_done) { + // skip over unecessary dimension + continue; + } else { + dims.d[j] = d.d[i]; + j++; + + // keep all other dimensions (don't skip over them) + pad_dims_done = true; + } + } + + dims.nbDims = j; + + return dims; +} + std::vector toVec(nvinfer1::Dims d) { std::vector dims; for (int i = 0; i < d.nbDims; i++) { diff --git a/core/util/trt_util.h b/core/util/trt_util.h index 09cf5ff418..34f1164b47 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -79,6 +79,7 @@ int64_t volume(const nvinfer1::Dims& d); nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to); nvinfer1::Dims toDimsPad(c10::List l, uint64_t pad_to); +nvinfer1::Dims unpadDims(const nvinfer1::Dims& d); nvinfer1::Dims toDims(c10::IntArrayRef l); nvinfer1::Dims toDims(c10::List l); nvinfer1::DimsHW toDimsHW(c10::List l);