Skip to content

Commit

Permalink
Add lowering info logic
Browse files Browse the repository at this point in the history
Signed-off-by: Arvind Sridhar <[email protected]>
  • Loading branch information
ArvindSridhar authored and narendasan committed Aug 20, 2021
1 parent 02b23cb commit b4feb49
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion core/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct CompileSpec {
partitioning::PartitionInfo partition_info;
};

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);

Expand Down
2 changes: 1 addition & 1 deletion core/lowering/lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct LowerInfo {
bool disable_cse = false;
std::vector<std::string> forced_fallback_modules;
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
}
};

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
Expand Down
5 changes: 4 additions & 1 deletion cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,9 @@ struct TRTORCH_API CompileSpec {
/// A list of names of operations that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_ops;

/// A list of names of modules that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_modules;

/**
* @brief Construct a default Torch Fallback object, fallback will be off
*/
Expand Down Expand Up @@ -781,7 +784,7 @@ TRTORCH_API void dump_build_info();
*
* @returns bool: Method is supported by TRTorch
*/
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name, CompileSpec info);

/**
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
Expand Down
1 change: 1 addition & 0 deletions cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.partition_info.enabled = external.torch_fallback.enabled;
internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;

switch (external.device.device_type) {
case CompileSpec::Device::DeviceType::kDLA:
Expand Down
4 changes: 2 additions & 2 deletions cpp/api/src/trtorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace trtorch {
core::CompileSpec to_internal_compile_spec(CompileSpec external);
core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device);

bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
return core::CheckMethodOperatorSupport(module, method_name);
// Public-API wrapper: converts the external CompileSpec into the internal
// core::CompileSpec (via to_internal_compile_spec) and forwards the check to
// core::CheckMethodOperatorSupport. The `info` parameter was added in this
// commit so lowering info (e.g. forced_fallback_modules) reaches the core
// support check — presumably influencing which ops count as supported; the
// core-side use is not visible here, so confirm against core/compiler.cpp.
bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name, CompileSpec info) {
return core::CheckMethodOperatorSupport(module, method_name, to_internal_compile_spec(info));
}

std::string ConvertGraphToTRTEngine(
Expand Down

0 comments on commit b4feb49

Please sign in to comment.