diff --git a/core/ir/ir.h b/core/ir/ir.h index 87bae52645..8e64aa4e10 100644 --- a/core/ir/ir.h +++ b/core/ir/ir.h @@ -2,6 +2,7 @@ #include #include "NvInfer.h" +// #include "trtorch.h" namespace trtorch { namespace core { @@ -18,6 +19,17 @@ struct InputRange { InputRange(std::vector min_shape, std::vector opt_shape, std::vector max_shape); }; +// struct Input{ +// Input(std::vector shape); +// Input(std::vector min_shape, std::vector opt_shape, std::vector max_shape); +// Input(std::vector shape, DataType dtype=DataType::kFloat32); +// Input(std::vector min_shape, std::vector opt_shape, std::vector max_shape, DataType dtype=DataType::kFloat32); +// nvinfer1::Dims min; +// nvinfer1::Dims max; +// nvinfer1::Dims opt; +// nvinfer1::DataType dtype; +// } + } // namespace ir } // namespace core -} // namespace trtorch \ No newline at end of file +} // namespace trtorch diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index a1156a1338..043ba41b58 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -191,6 +191,59 @@ struct TRTORCH_API CompileSpec { Value value; }; + /** + * @brief A struct to hold Input of a network. + * This struct has all the info (shape, dtype, name, memory_format) of an input tensor. + * The shape field in this struct can either hold a single vector representing an input shape, + * signifying a static input shape or a set of three input shapes representing + * the min, optiminal and max input shapes allowed for the engine. + * dtype : This can take values among values supported by trtorch::DataType + */ + struct TRTORCH_API Input { + /// Minimum acceptable input size into the engine + std::vector min; + /// Optimal input size into the engine (gets best performace) + std::vector opt; + /// Maximum acceptable input size into the engine + std::vector max; + /// Data type of the input + DataType dtype; + + /** + * @brief Construct a new Input Range object for static input size from + * vector + * + * @param opt + */ + Input(std::vector opt, DataType dtype=DataType::kFloat); + /** + * @brief Construct a new Input Range object static input size from + * c10::ArrayRef (the type produced by tensor.sizes()) + * + * @param opt + */ + Input(c10::ArrayRef opt, DataType dtype=DataType::kFloat); + /** + * @brief Construct a new Input Range object dynamic input size from vectors + * for min, opt, and max supported sizes + * + * @param min + * @param opt + * @param max + */ + Input(std::vector min, std::vector opt, std::vector max, DataType dtype=DataType::kFloat); + /** + * @brief Construct a new Input Range object dynamic input size from + * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max + * supported sizes + * + * @param min + * @param opt + * @param max + */ + Input(c10::ArrayRef min, c10::ArrayRef opt, c10::ArrayRef max, DataType dtype=DataType::kFloat); + }; + /** * Emum for selecting engine capability */ diff --git a/cpp/api/src/compile_spec.cpp b/cpp/api/src/compile_spec.cpp index aea4cfd68f..29a6a9adea 100644 --- a/cpp/api/src/compile_spec.cpp +++ b/cpp/api/src/compile_spec.cpp @@ -71,6 +71,33 @@ CompileSpec::CompileSpec(std::vector> fixed_sizes) { } } +/* ====== DEFINE INPUTS CLASS MEMBERS ======*/ +CompileSpec::Input::Input(std::vector opt) { + this->opt = opt; + this->min = opt; + this->max = opt; +} + +CompileSpec::Input::Input(c10::IntArrayRef opt) { + this->opt = core::util::toVec(opt); + this->min = core::util::toVec(opt); + this->max = core::util::toVec(opt); +} + +CompileSpec::Input::Input(std::vector min, std::vector opt, std::vector max) { + this->opt = opt; + this->min = min; + this->max = max; +} + +CompileSpec::Input::Input(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) { + this->opt = core::util::toVec(opt); + this->min = core::util::toVec(min); + this->max = core::util::toVec(max); +} + +/* ==========================================*/ + core::ir::InputRange to_internal_input_range(CompileSpec::InputRange i) { return core::ir::InputRange(i.min, i.opt, i.max); }