Skip to content

Commit

Permalink
chore: Initial input_type api impl
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 authored and narendasan committed Jul 21, 2021
1 parent a3f4a3c commit 8e67e38
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
14 changes: 13 additions & 1 deletion core/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <vector>
#include "NvInfer.h"
// #include "trtorch.h"

namespace trtorch {
namespace core {
Expand All @@ -18,6 +19,17 @@ struct InputRange {
InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
};

// struct Input{
// Input(std::vector<int64_t> shape);
// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
// Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat32);
// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat32);
// nvinfer1::Dims min;
// nvinfer1::Dims max;
// nvinfer1::Dims opt;
// nvinfer1::DataType dtype;
// }

} // namespace ir
} // namespace core
} // namespace trtorch
} // namespace trtorch
53 changes: 53 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,59 @@ struct TRTORCH_API CompileSpec {
Value value;
};

/**
* @brief A struct to hold Input of a network.
* This struct has all the info (shape, dtype, name, memory_format) of an input tensor.
* The shape field in this struct can either hold a single vector representing an input shape,
* signifying a static input shape or a set of three input shapes representing
* the min, optiminal and max input shapes allowed for the engine.
* dtype : This can take values among values supported by trtorch::DataType
*/
struct TRTORCH_API Input {
/// Minimum acceptable input size into the engine
std::vector<int64_t> min;
/// Optimal input size into the engine (gets best performace)
std::vector<int64_t> opt;
/// Maximum acceptable input size into the engine
std::vector<int64_t> max;
/// Data type of the input
DataType dtype;

/**
* @brief Construct a new Input Range object for static input size from
* vector
*
* @param opt
*/
Input(std::vector<int64_t> opt, DataType dtype=DataType::kFloat);
/**
* @brief Construct a new Input Range object static input size from
* c10::ArrayRef (the type produced by tensor.sizes())
*
* @param opt
*/
Input(c10::ArrayRef<int64_t> opt, DataType dtype=DataType::kFloat);
/**
* @brief Construct a new Input Range object dynamic input size from vectors
* for min, opt, and max supported sizes
*
* @param min
* @param opt
* @param max
*/
Input(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max, DataType dtype=DataType::kFloat);
/**
* @brief Construct a new Input Range object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes
*
* @param min
* @param opt
* @param max
*/
Input(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max, DataType dtype=DataType::kFloat);
};

/**
* Emum for selecting engine capability
*/
Expand Down
27 changes: 27 additions & 0 deletions cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,33 @@ CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
}
}

/* ====== DEFINE INPUTS CLASS MEMBERS ======*/
CompileSpec::Input::Input(std::vector<int64_t> opt) {
this->opt = opt;
this->min = opt;
this->max = opt;
}

CompileSpec::Input::Input(c10::IntArrayRef opt) {
this->opt = core::util::toVec(opt);
this->min = core::util::toVec(opt);
this->max = core::util::toVec(opt);
}

CompileSpec::Input::Input(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
this->opt = opt;
this->min = min;
this->max = max;
}

CompileSpec::Input::Input(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
this->opt = core::util::toVec(opt);
this->min = core::util::toVec(min);
this->max = core::util::toVec(max);
}

/* ==========================================*/

core::ir::InputRange to_internal_input_range(CompileSpec::InputRange i) {
return core::ir::InputRange(i.min, i.opt, i.max);
}
Expand Down

0 comments on commit 8e67e38

Please sign in to comment.