-
Notifications
You must be signed in to change notification settings - Fork 357
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Adding profiling support to the runtime
Signed-off-by: Naren Dasan <[email protected]>
- Loading branch information
1 parent
b70c913
commit 8a48100
Showing
24 changed files
with
238 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#pragma once | ||
#include <string> | ||
#include "NvInfer.h" | ||
|
||
namespace torch_tensorrt { | ||
namespace core { | ||
namespace runtime { | ||
|
||
struct CUDADevice { | ||
int64_t id; // CUDA device id | ||
int64_t major; // CUDA compute major version | ||
int64_t minor; // CUDA compute minor version | ||
nvinfer1::DeviceType device_type; | ||
std::string device_name; | ||
|
||
CUDADevice(); | ||
CUDADevice(int64_t gpu_id, nvinfer1::DeviceType device_type); | ||
CUDADevice(std::string serialized_device_info); | ||
~CUDADevice() = default; | ||
CUDADevice(const CUDADevice& other) = default; | ||
CUDADevice& operator=(const CUDADevice& other); | ||
std::string serialize(); | ||
std::string getSMCapability() const; | ||
friend std::ostream& operator<<(std::ostream& os, const CUDADevice& device); | ||
}; | ||
|
||
void set_cuda_device(CUDADevice& cuda_device); | ||
// Gets the current active GPU (DLA will not show up through this) | ||
CUDADevice get_current_device(); | ||
|
||
} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#pragma once | ||
#include <map> | ||
#include <memory> | ||
#include <mutex> | ||
#include <utility> | ||
#include "ATen/core/function_schema.h" | ||
#include "NvInfer.h" | ||
#include "core/util/prelude.h" | ||
#include "torch/custom_class.h" | ||
|
||
namespace torch_tensorrt { | ||
namespace core { | ||
namespace runtime { | ||
|
||
struct TRTEngine : torch::CustomClassHolder { | ||
// Each engine needs it's own runtime object | ||
std::shared_ptr<nvinfer1::IRuntime> rt; | ||
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine; | ||
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx; | ||
std::pair<uint64_t, uint64_t> num_io; | ||
std::string name; | ||
std::mutex mu; | ||
CUDADevice device_info; | ||
|
||
std::string execution_profile_path; | ||
std::string device_profile_path; | ||
std::string input_profile_path; | ||
std::string output_profile_path; | ||
std::string enqueue_profile_path; | ||
std::string profile_path = "/tmp"; | ||
|
||
std::unordered_map<uint64_t, uint64_t> in_binding_map; | ||
std::unordered_map<uint64_t, uint64_t> out_binding_map; | ||
|
||
#ifndef NDEBUG | ||
bool debug = true; | ||
#else | ||
bool debuf = false; | ||
#endif | ||
|
||
~TRTEngine() = default; | ||
TRTEngine(std::string serialized_engine, CUDADevice cuda_device); | ||
TRTEngine(std::vector<std::string> serialized_info); | ||
TRTEngine(std::string mod_name, std::string serialized_engine, CUDADevice cuda_device); | ||
TRTEngine& operator=(const TRTEngine& other); | ||
std::string to_str() const; | ||
void set_paths(); | ||
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine); | ||
// TODO: Implement a call method | ||
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs); | ||
}; | ||
|
||
} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
Oops, something went wrong.