diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 2f6706c51a..604514ff11 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -82,6 +82,30 @@ nvinfer1::Dims toDimsPad(c10::List l, uint64_t pad_to) { return dims; } +nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) { + nvinfer1::Dims dims; + + int j = 0; + bool pad_dims_done = false; + + for (int i = 0; i < d.nbDims; i++) { + if (d.d[i] == 1 && !pad_dims_done) { + // skip over unecessary dimension + continue; + } else { + dims.d[j] = d.d[i]; + j++; + + // keep all other dimensions (don't skip over them) + pad_dims_done = true; + } + } + + dims.nbDims = j; + + return dims; +} + std::vector toVec(nvinfer1::Dims d) { std::vector dims; for (int i = 0; i < d.nbDims; i++) { diff --git a/core/util/trt_util.h b/core/util/trt_util.h index 09cf5ff418..34f1164b47 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -79,6 +79,7 @@ int64_t volume(const nvinfer1::Dims& d); nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to); nvinfer1::Dims toDimsPad(c10::List l, uint64_t pad_to); +nvinfer1::Dims unpadDims(const nvinfer1::Dims& d); nvinfer1::Dims toDims(c10::IntArrayRef l); nvinfer1::Dims toDims(c10::List l); nvinfer1::DimsHW toDimsHW(c10::List l);