From d7ee6e7ccb02e661ef924a5841717c3a0e4c6ce2 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Tue, 6 Aug 2024 18:46:13 -0600 Subject: [PATCH] fix: Fix the CUDAGraphs C++ runtime implementation Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .github/workflows/build-test-linux.yml | 63 +++- .github/workflows/build-test-windows.yml | 8 +- core/conversion/var/Var.cpp | 4 +- core/runtime/TRTEngine.cpp | 11 +- core/runtime/TRTEngine.h | 4 +- core/runtime/execute_engine.cpp | 320 ++++++++++-------- .../dynamo/conversion/impl/elementwise/ops.py | 4 +- .../runtime/_PythonTorchTensorRTModule.py | 318 +++++++++-------- py/torch_tensorrt/dynamo/types.py | 2 +- py/torch_tensorrt/runtime/cudagraphs.py | 4 +- .../runtime/multi_device_safe_mode.py | 8 +- ...gs.py => test_000_compilation_settings.py} | 0 ...er_utils.py => test_000_compiler_utils.py} | 0 ... test_000_convert_module_to_trt_engine.py} | 0 ..._runtime.py => test_000_python_runtime.py} | 0 tests/py/dynamo/runtime/test_001_streams.py | 66 ++++ .../dynamo/runtime/test_002_cudagraphs_cpp.py | 154 +++++++++ .../dynamo/runtime/test_002_cudagraphs_py.py | 164 +++++++++ ...e_init.py => test_002_lazy_engine_init.py} | 0 ...est_safe_mode.py => test_003_safe_mode.py} | 9 +- tests/py/dynamo/runtime/test_cudagraphs.py | 202 ----------- 21 files changed, 804 insertions(+), 537 deletions(-) rename tests/py/dynamo/runtime/{test_compilation_settings.py => test_000_compilation_settings.py} (100%) rename tests/py/dynamo/runtime/{test_compiler_utils.py => test_000_compiler_utils.py} (100%) rename tests/py/dynamo/runtime/{test_convert_module_to_trt_engine.py => test_000_convert_module_to_trt_engine.py} (100%) rename tests/py/dynamo/runtime/{test_python_runtime.py => test_000_python_runtime.py} (100%) create mode 100644 tests/py/dynamo/runtime/test_001_streams.py create mode 100644 tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py create mode 100644 tests/py/dynamo/runtime/test_002_cudagraphs_py.py rename tests/py/dynamo/runtime/{test_lazy_engine_init.py => test_002_lazy_engine_init.py} (100%) rename tests/py/dynamo/runtime/{test_safe_mode.py => test_003_safe_mode.py} (93%) delete mode 100644 tests/py/dynamo/runtime/test_cudagraphs.py diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 7476309266..2f04670583 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -8,9 +8,9 @@ on: - nightly - release/* tags: - # NOTE: Binary build pipelines should only get triggered on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ workflow_dispatch: jobs: @@ -84,9 +84,9 @@ jobs: popd pushd . 
cd tests/py/ts - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/ - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/ - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/ popd tests-py-dynamo-converters: @@ -114,7 +114,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ popd tests-py-dynamo-fe: @@ -142,8 +142,8 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py popd tests-py-dynamo-serde: @@ -171,7 +171,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py popd tests-py-torch-compile-be: @@ -199,9 +199,9 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo - python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/ - python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_complete_be_e2e_test_results.xml --ir torch_compile models/test_models.py - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py + python -m pytest -ra -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/ + python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_complete_be_e2e_test_results.xml --ir torch_compile models/test_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py popd tests-py-dynamo-core: @@ -229,9 +229,38 @@ jobs: export USE_HOST_DEPS=1 pushd . 
cd tests/py/dynamo - python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/ - python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/ - python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/ + python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml --ignore runtime/test_002_cudagraphs_py.py --ignore runtime/test_002_cudagraphs_cpp.py runtime/ + python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/ + python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/ + popd + + tests-py-dynamo-cudagraphs: + name: Test dynamo cudagraphs [Python] + needs: [generate-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh + uses: ./.github/workflows/linux-test.yml + with: + job-name: tests-py-dynamo-cudagraphs + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + export USE_HOST_DEPS=1 + pushd . + cd tests/py/dynamo + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_cudagraphs_cpp_test_results.xml runtime/test_002_cudagraphs_cpp.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_cudagraphs_py_test_results.xml runtime/test_002_cudagraphs_py.py popd tests-py-core: @@ -259,7 +288,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/core - python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . + python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . popd concurrency: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index b3f0352042..ee28c0314f 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -8,9 +8,9 @@ on: - nightly - release/* tags: - # NOTE: Binary build pipelines should only get triggered on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ workflow_dispatch: jobs: @@ -219,7 +219,7 @@ jobs: export USE_HOST_DEPS=1 pushd . 
cd tests/py/dynamo - python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/ + python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/ python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/ python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/ popd diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp index 0444663830..bb96812ac3 100644 --- a/core/conversion/var/Var.cpp +++ b/core/conversion/var/Var.cpp @@ -153,7 +153,7 @@ bool Var::isITensorList() { // Unpack the Var as a List and check if each entry is a custom class since // ITensors are stored in CustomClassHolder auto ival_list = ptr_.ivalue->toList(); - for (int i = 0; i < ival_list.size(); i++) { + for (size_t i = 0; i < ival_list.size(); i++) { if (!ival_list.get(i).isCustomClass()) { return false; } @@ -167,7 +167,7 @@ std::vector Var::unwrapToITensorList() { TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList"); auto ivalue_list = ptr_.ivalue->toList(); std::vector outputs; - for (int i = 0; i < ivalue_list.size(); i++) { + for (size_t i = 0; i < ivalue_list.size(); i++) { auto element = ivalue_list.get(i).toCustomClass()->tensor(); outputs.push_back(std::move(element)); } diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 27880ed302..b40e7c8413 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -71,15 +71,6 @@ TRTEngine::TRTEngine( multi_gpu_device_check(); set_rt_device(device_info); - // Set active stream to non-default stream - auto current_stream = c10::cuda::getCurrentCUDAStream(device_info.id); - if (current_stream == c10::cuda::getDefaultCUDAStream(device_info.id)) { - active_stream = c10::cuda::getStreamFromPool(false, device_info.id); - c10::cuda::setCurrentCUDAStream(active_stream); - } else { - active_stream = current_stream; - } - rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); name = slugify(mod_name); @@ -205,6 +196,7 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { + cudagraph.reset(); trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -253,6 +245,7 @@ void TRTEngine::set_profiling_paths() { enqueue_profile_path = std::filesystem::path{profile_path_prefix + "/" + name + "_enqueue_profile.trace"}.string(); trt_engine_profile_path = std::filesystem::path{profile_path_prefix + "/" + name + "_engine_exectuion_profile.trace"}.string(); + cuda_graph_debug_path = std::filesystem::path{profile_path_prefix + "/" + name + "_cudagraph.dot"}.string(); } std::string TRTEngine::to_str() const { diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 1c900e3f34..cffe3bf122 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -70,7 +70,8 @@ struct TRTEngine : torch::CustomClassHolder { // CUDAGraph-Related Functionality at::cuda::CUDAGraph cudagraph = {}; - at::cuda::CUDAStream active_stream = c10::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream(); std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; @@ -89,6 +90,7 @@ struct TRTEngine : torch::CustomClassHolder { std::string output_profile_path; std::string enqueue_profile_path; std::string trt_engine_profile_path; 
+ std::string cuda_graph_debug_path; std::mutex mu; std::unique_ptr trt_engine_profiler; }; diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 82b868d131..ef5585e723 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -1,3 +1,4 @@ +#include "ATen/cuda/CUDAEvent.h" #include "c10/cuda/CUDAGuard.h" #include "c10/cuda/CUDAStream.h" @@ -70,7 +71,7 @@ bool _cudagraphs_validate_shapes(std::vector inputs, c10::intrusive_ new_shape_key_ss << "("; auto sizes = input.sizes(); auto rank = input.sizes().size(); - for (auto i = 0; i < rank; i++) { + for (size_t i = 0; i < rank; i++) { new_shape_key_ss << sizes[i]; // For all but the final dimension in the shape key, add comma separator if (i < rank - 1) { @@ -107,139 +108,156 @@ std::vector execute_engine(std::vector inputs, c10::intr ss << " Output packing profile: " << compiled_engine->output_profile_path << std::endl; ss << " TRT enqueue profile: " << compiled_engine->enqueue_profile_path << std::endl; ss << " Engine execution profile: " << compiled_engine->trt_engine_profile_path << std::endl; + ss << " CUDA Graph trace: " << compiled_engine->cuda_graph_debug_path << std::endl; auto log_info = ss.str(); LOG_INFO("" << log_info); + compiled_engine->cudagraph.enable_debug_mode(); } // Whether cudagraphs needs to record the graph on this pass - bool need_cudagraphs_record = (CUDAGRAPHS_MODE && !_cudagraphs_validate_shapes(inputs, compiled_engine)); + bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine))); + + if (!CUDAGRAPHS_MODE) { + compiled_engine->cudagraph.reset(); + } // this is a buffer to store shape tensor input addresses throughout the runtime scope std::list> inputShapeTensorValues; - // Intialize outputs to be available throughout the succeeding scopes + // Intialize inputs and outputs to be available throughout the succeeding scopes + std::list formatted_inputs(compiled_engine->num_io.first); std::vector outputs(compiled_engine->num_io.second); - // If not in cudagraphs mode or a new cudagraphs recording is needed - // proceed with input validation and assignment of new I/O pointers for TRT - if (!CUDAGRAPHS_MODE || need_cudagraphs_record) { - if (MULTI_DEVICE_SAFE_MODE) { - std::unique_ptr device_profiler_guard; - if (compiled_engine->profile_execution) { - device_profiler_guard = - std::make_unique(compiled_engine->device_profile_path); + if (MULTI_DEVICE_SAFE_MODE) { + std::unique_ptr device_profiler_guard; + if (compiled_engine->profile_execution) { + device_profiler_guard = + std::make_unique(compiled_engine->device_profile_path); + } + + RTDevice curr_device = get_current_device(); + LOG_DEBUG("Current Device: " << curr_device); + + // Generic Target Device Prefix + std::string target_device = "cuda:"; + + if (is_switch_required(curr_device, compiled_engine->device_info)) { + // Scan through available CUDA devices and set the CUDA device context correctly + RTDevice device = + select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible); + set_rt_device(device); + + // Target device is new device + target_device += std::to_string(device.id); + + for (auto& in : inputs) { + in = in.to(torch::Device(target_device)); } + } else { + // Target device is current device + target_device += std::to_string(curr_device.id); + } - RTDevice curr_device = get_current_device(); - LOG_DEBUG("Current Device: " << curr_device); + // For each input, ensure its current device is the desired target 
device + for (size_t i = 0; i < inputs.size(); i++) { + at::Tensor* in = &inputs[i]; + std::string current_tensor_device = in->device().str(); + + // If current device string does not match target device, display warning and move tensor accordingly + if (current_tensor_device != target_device) { + LOG_WARNING( + "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device + << " but should be on " << target_device << ". This tensor is being moved by the runtime but " + << "for performance considerations, ensure your inputs are all on GPU " + << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " + << "warning persists."); + *in = in->to(torch::Device(target_device)); + } + } + } - // Generic Target Device Prefix - std::string target_device = "cuda:"; + { // Input Setup + std::unique_ptr input_profiler_guard; + if (compiled_engine->profile_execution) { + input_profiler_guard = + std::make_unique(compiled_engine->input_profile_path); + } - if (is_switch_required(curr_device, compiled_engine->device_info)) { - // Scan through available CUDA devices and set the CUDA device context correctly - RTDevice device = - select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible); - set_rt_device(device); + for (size_t i = 0; i < inputs.size(); i++) { + std::string name = compiled_engine->in_binding_names[i]; - // Update active stream based on new device - auto current_stream = c10::cuda::getCurrentCUDAStream(device.id); - if (current_stream == c10::cuda::getDefaultCUDAStream(device.id)) { - compiled_engine->active_stream = c10::cuda::getStreamFromPool(false, device.id); - c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream); - } else { - compiled_engine->active_stream = current_stream; - } + TORCHTRT_CHECK( + inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); - // Target device is new device - target_device += std::to_string(device.id); + auto expected_type = + util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); + TORCHTRT_CHECK( + inputs[i].dtype() == expected_type, + "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); + + auto dims = core::util::toDims(inputs[i].sizes()); + auto shape = core::util::toVec(dims); + LOG_DEBUG("Input Name: " << name << " Shape: " << dims); + + if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) { + // Shape tensor inputs are casted to int64 explicitly. 
+ // Refer to + // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 + auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64); + std::vector inputs_cpu_vec( + input_cpu.data_ptr(), input_cpu.data_ptr() + input_cpu.numel()); + inputShapeTensorValues.emplace_back(inputs_cpu_vec); + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), + "Error while setting the tensor address for shape inputs"); - for (auto& in : inputs) { - in = in.to(torch::Device(target_device)); + if (CUDAGRAPHS_MODE) { + // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers + compiled_engine->input_buffers[i] = input_cpu; } + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), + "Error while setting the tensor address for shape inputs"); + } else { - // Target device is current device - target_device += std::to_string(curr_device.id); - } + at::Tensor contig_input = inputs[i].view(shape).contiguous(); + formatted_inputs.emplace_back(std::move(contig_input)); - // For each input, ensure its current device is the desired target device - for (size_t i = 0; i < inputs.size(); i++) { - at::Tensor* in = &inputs[i]; - std::string current_tensor_device = in->device().str(); - - // If current device string does not match target device, display warning and move tensor accordingly - if (current_tensor_device != target_device) { - LOG_WARNING( - "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device - << " but should be on " << target_device << ". This tensor is being moved by the runtime but " - << "for performance considerations, ensure your inputs are all on GPU " - << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this " - << "warning persists."); - *in = in->to(torch::Device(target_device)); + if (need_cudagraphs_record) { + // Create a new persistent input buffer + compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone()); } - } - } - { - std::unique_ptr input_profiler_guard; - if (compiled_engine->profile_execution) { - input_profiler_guard = - std::make_unique(compiled_engine->input_profile_path); - } - for (size_t i = 0; i < inputs.size(); i++) { - std::string name = compiled_engine->in_binding_names[i]; TORCHTRT_CHECK( - inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); - auto expected_type = - util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); - TORCHTRT_CHECK( - inputs[i].dtype() == expected_type, - "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); - auto dims = core::util::toDims(inputs[i].sizes()); - auto shape = core::util::toVec(dims); - LOG_DEBUG("Input Name: " << name << " Shape: " << dims); - at::Tensor contig_input; - - if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) { - // Shape tensor inputs are casted to int64 explicitly. 
- // Refer to - // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 - auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64); - std::vector inputs_cpu_vec( - input_cpu.data_ptr(), input_cpu.data_ptr() + input_cpu.numel()); - inputShapeTensorValues.emplace_back(inputs_cpu_vec); + compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); + + if (CUDAGRAPHS_MODE) { + // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer + compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), - "Error while setting the tensor address for shape inputs"); - compiled_engine->input_buffers[i] = input_cpu; + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()), + "Error while setting the input tensor address for inputs"); } else { - // If in cudagraphs mode, the inputs must be cloned since the memory will be reused - // in subsequent replays of the graph - if (CUDAGRAPHS_MODE) { - contig_input = inputs[i].view(shape).contiguous().clone(); - compiled_engine->input_buffers[i] = contig_input; - } else { - contig_input = inputs[i].view(shape).contiguous(); - } + // Otherwise use the formatted buffer directly TORCHTRT_CHECK( - compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), contig_input.data_ptr()), + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()), "Error while setting the input tensor address for inputs"); } } - - // Check if input shapes can be inferred. - int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; - std::vector names(io_size); - int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); - TORCHTRT_CHECK( - nbNames == 0, - "The shapes of the inputs: " - << names - << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly"); } + // Check if input shapes can be inferred. + int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; + std::vector names(io_size); + int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); + TORCHTRT_CHECK( + nbNames == 0, + "The shapes of the inputs: " + << names + << " cannot be inferred. 
This could happen if the input tensor addresses/shapes haven't been configured correctly"); + } + + { // Output Setup std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { output_profiler_guard = @@ -253,63 +271,87 @@ std::vector execute_engine(std::vector inputs, c10::intr std::string name = compiled_engine->out_binding_names[pyt_idx]; auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str()); LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape); + auto dims = core::util::toVec(out_shape); auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous()); - // In cudagraphs mode, the allocated output buffers are stored for reuse + if (need_cudagraphs_record) { + // If we are recording the cuda graph then we need to update the persistent output buffer + compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); + } + if (CUDAGRAPHS_MODE) { - compiled_engine->output_buffers[pyt_idx] = outputs[pyt_idx]; + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress( + name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } else { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); } - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); } } - std::unique_ptr enqueue_profiler_guard; + auto current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart + + compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); + if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { + // Create a new stream if the engine stream is the default stream + compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + } else { + compiled_engine->engine_stream = compiled_engine->caller_stream; + } // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it. 
std::unique_lock lock(compiled_engine->mu); - if (!CUDAGRAPHS_MODE) { - // If not in cudagraphs mode, proceed with enqueueV3 as normal - compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream); - } else if (need_cudagraphs_record) { - // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph - - // Cudagraphs cannot record on the current stream, so use an alternate - c10::cuda::CUDAStream recording_stream = c10::cuda::getStreamFromPool(false, inputs[0].device().index()); - c10::cuda::CUDAStreamGuard guard(recording_stream); - - compiled_engine->exec_ctx->enqueueV3(recording_stream); - recording_stream.synchronize(); + { // Engine Execution (execute on engine stream) + c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); - compiled_engine->cudagraph.capture_begin(); - compiled_engine->exec_ctx->enqueueV3(recording_stream); - compiled_engine->cudagraph.capture_end(); + std::unique_ptr enqueue_profiler_guard; + if (compiled_engine->profile_execution) { + enqueue_profiler_guard = + std::make_unique(compiled_engine->enqueue_profile_path); + } - // Reset the stream to its original setting - guard.reset_stream(guard.original_stream()); + // Block engine stream until results are available on caller stream + at::cuda::CUDAEvent caller_exec_complete; + caller_exec_complete.record(compiled_engine->caller_stream); + caller_exec_complete.block(compiled_engine->engine_stream); + + if (!CUDAGRAPHS_MODE) { + // Direct execution uses the caller buffers directly + compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + } else { + if (need_cudagraphs_record) { + // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph + c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; + compiled_engine->cudagraph.capture_begin(); + compiled_engine->exec_ctx->enqueueV3(recording_stream); + compiled_engine->cudagraph.capture_end(); + + if (compiled_engine->profile_execution) { + compiled_engine->cudagraph.debug_dump(compiled_engine->cuda_graph_debug_path); + } + } - } else { - // If the cudagraph has already been recorded, copy the input buffers and replay it - for (auto i = 0; i < inputs.size(); i++) { - compiled_engine->input_buffers[i].copy_(inputs[i], true); + // Replay the CUDAGraph + compiled_engine->cudagraph.replay(); // Has a cudaDeviceSynchronize internally } - compiled_engine->cudagraph.replay(); - } + } // End engine exeuction (resets to caller stream) - std::vector model_outputs(compiled_engine->num_io.second); + // Block caller stream until engine execution is complete + at::cuda::CUDAEvent trt_exec_complete; + trt_exec_complete.record(compiled_engine->engine_stream); + trt_exec_complete.block(compiled_engine->caller_stream); - // In cudagraphs mode, the output buffers can be reused, so they must - // be cloned before providing them to the user to avoid data corruption if (CUDAGRAPHS_MODE) { - for (auto i = 0; i < compiled_engine->output_buffers.size(); i++) { - model_outputs[i] = compiled_engine->output_buffers[i].clone(); + // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) + for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { + outputs[o].copy_(compiled_engine->output_buffers[o], false); } - } else { - model_outputs = outputs; } if (compiled_engine->profile_execution) { @@ -318,7 +360,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->dump_engine_layer_info(); } - return model_outputs; + return outputs; } } 
// namespace runtime diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 40c4dfcb3e..c7502cf97e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,8 +1,6 @@ from typing import Optional, Union import numpy as np -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape -import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target @@ -24,6 +22,8 @@ from torch_tensorrt.fx.converters.converter_utils import broadcast from torch_tensorrt.fx.types import TRTTensor +import tensorrt as trt + def trunc_div( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 659f18af52..65495f29c2 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -2,6 +2,7 @@ import logging from contextlib import nullcontext +from tempfile import tempdir from typing import Any, Dict, List, Optional, Sequence, Tuple import torch @@ -71,10 +72,11 @@ def __init__( multi_gpu_device_check() self.name = name - self.input_buffers: List[torch.Tensor] = [] - self.output_buffers: List[torch.Tensor] = [] + self._input_buffers: List[torch.Tensor] = [] + self._output_buffers: List[torch.Tensor] = [] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None - self.active_stream: Optional[torch.cuda.Stream] = None + self._caller_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: Optional[torch.cuda.Stream] = None # TODO: Make the below a Dictionary {shape: cudagraph} self.shape_key: Optional[str] = None @@ -134,15 +136,6 @@ def setup_engine(self) -> None: if torch_tensorrt.runtime.get_cudagraphs_mode(): self.cudagraph = torch.cuda.CUDAGraph() - self.graph_capturer = torch.cuda.graphs.graph(self.cudagraph) - - # Set the active stream using the current device - current_stream = torch.cuda.current_stream() - if current_stream == torch.cuda.default_stream(): - self.active_stream = torch.cuda.Stream() - torch.cuda.set_stream(self.active_stream) - else: - self.active_stream = current_stream def _check_initialized(self) -> None: if not self.initialized: @@ -192,12 +185,17 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: result.__setstate__(self.__getstate__()) return result + def __del__(self) -> None: + if self.cudagraph: + self.cudagraph.reset() + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: # Ensure inputs are available in all scopes and cast symbolic integers to Tensors contiguous_inputs: List[torch.Tensor] = [ (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) for i in inputs ] + with ( torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") if self.profiling_enabled @@ -210,149 +208,152 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs) ) - # If cudagraphs is not enabled or the recorded graph shapes are either uninitialized or invalid - if not cudagraphs_enabled or need_cudagraphs_record: - # If in safe mode, check at each iteration for for whether a switch is required - if ( - torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE - ): - curr_device_id = torch.cuda.current_device() - curr_device_properties = torch.cuda.get_device_properties( - curr_device_id - ) - logger.debug(f"Current Device: cuda:{curr_device_id}") + if need_cudagraphs_record: + self._input_buffers = [None] * len(self.input_names) + self._output_buffers = [None] * len(self.output_names) - # If a switch is required, move all inputs to new device and set as active device - if _is_switch_required( + if not cudagraphs_enabled and self.cudagraph: + self.cudagraph.reset() + self.cudagraph = None + + # If in safe mode, check at each iteration for for whether a switch is required + if ( + torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE + ): + curr_device_id = torch.cuda.current_device() + curr_device_properties = torch.cuda.get_device_properties( + curr_device_id + ) + logger.debug(f"Current Device: cuda:{curr_device_id}") + + # If a switch is required, move all inputs to new device and set as active device + if _is_switch_required( + curr_device_id, + self.target_device_id, + curr_device_properties, + self.target_device_properties, + ): + device_id, _ = _select_rt_device( curr_device_id, self.target_device_id, - curr_device_properties, self.target_device_properties, - ): - device_id, _ = _select_rt_device( - curr_device_id, - self.target_device_id, - self.target_device_properties, - ) + ) - # Update current device - device = torch.device(device_id) - torch.cuda.set_device(device_id) + # Update current device + device = torch.device(device_id) + torch.cuda.set_device(device_id) - # Update current stream - current_stream = torch.cuda.current_stream(device) - if current_stream == torch.cuda.default_stream(device): - self.active_stream = torch.cuda.Stream(device) - torch.cuda.set_stream(self.active_stream) - else: - self.active_stream = current_stream + contiguous_inputs = [ + tensor.to(device) for tensor in contiguous_inputs + ] + logger.warning(f"Moved all input Tensors to cuda:{device_id}") + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." - contiguous_inputs = [ - tensor.to(device) for tensor in contiguous_inputs - ] - logger.warning(f"Moved all input Tensors to cuda:{device_id}") + for i, input_name in enumerate(self.input_names): + if not contiguous_inputs[i].is_cuda: + logger.warning( + f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " + "This tensor is being moved by the runtime but for performance considerations, " + "ensure your inputs are all on GPU and open an issue here " + "(https://github.com/pytorch/TensorRT/issues) if this warning persists." 
+ ) + contiguous_inputs = ( + contiguous_inputs[:i] + + [contiguous_inputs[i].cuda()] + + contiguous_inputs[i + 1 :] + ) - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessInputs" - ) - if self.profiling_enabled - else nullcontext() - ): - assert len(contiguous_inputs) == len( - self.input_names - ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." + assert ( + contiguous_inputs[i].dtype == self.input_dtypes[i] + ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - if cudagraphs_enabled: + if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory - contiguous_inputs = [i.clone() for i in contiguous_inputs] - - bindings = [] - for i, input_name in enumerate(self.input_names): - if not contiguous_inputs[i].is_cuda: - logger.warning( - f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " - "This tensor is being moved by the runtime but for performance considerations, " - "ensure your inputs are all on GPU and open an issue here " - "(https://github.com/pytorch/TensorRT/issues) if this warning persists." - ) - contiguous_inputs = ( - contiguous_inputs[:i] - + [contiguous_inputs[i].cuda()] - + contiguous_inputs[i + 1 :] - ) - - assert ( - contiguous_inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - - # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers - # as per TensorRT requirements - if self.engine.is_shape_inference_io(input_name): - # Shape tensor inputs are casted to int64 explicitly - # Currently Torch CPU pointers are not working; numpy pointers are used instead - # to refer to underlying memory - inputs_cpu = ( - contiguous_inputs[i] - .cpu() - .to(torch.int64) - .numpy() - .copy() - ) + self._input_buffers[i] = contiguous_inputs[i].clone() + + # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers + # as per TensorRT requirements + if self.engine.is_shape_inference_io(input_name): + # Shape tensor inputs are casted to int64 explicitly + # Currently Torch CPU pointers are not working; numpy pointers are used instead + # to refer to underlying memory + inputs_cpu = ( + contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() + ) + self.context.set_tensor_address( + input_name, inputs_cpu.ctypes.data + ) + else: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) + if cudagraphs_enabled: + self._input_buffers[i].copy_(contiguous_inputs[i]) self.context.set_tensor_address( - input_name, inputs_cpu.ctypes.data + input_name, self._input_buffers[i].data_ptr() ) - bindings.append(inputs_cpu.ctypes.data) else: - self.context.set_input_shape( - input_name, tuple(contiguous_inputs[i].shape) - ) self.context.set_tensor_address( input_name, contiguous_inputs[i].data_ptr() ) - bindings.append(contiguous_inputs[i].data_ptr()) - # Check if input shapes can be inferred. - uninferred_input_names = self.context.infer_shapes() - if uninferred_input_names: - logger.warning( - f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. 
\ - This could happen if the input tensor addresses/shapes haven't been configured correctly" - ) - - with ( - torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessOutputs" + # Check if input shapes can be inferred. + uninferred_input_names = self.context.infer_shapes() + if uninferred_input_names: + logger.warning( + f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ + This could happen if the input tensor addresses/shapes haven't been configured correctly" ) - if self.profiling_enabled - else nullcontext() - ): - # create output tensors - outputs: List[torch.Tensor] = [] - for i, output_name in enumerate(self.output_names): - shape = tuple(self.context.get_tensor_shape(output_name)) + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): + # create output tensors + outputs: List[torch.Tensor] = [] - if DYNAMIC_DIM in shape: - raise ValueError( - "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." - ) + for o, output_name in enumerate(self.output_names): + shape = tuple(self.context.get_tensor_shape(output_name)) - output = torch.empty( - size=shape, - dtype=self.output_dtypes[i].to(torch.dtype), - device=torch.cuda.current_device(), + if DYNAMIC_DIM in shape: + raise ValueError( + "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." ) - bindings.append(output.data_ptr()) - outputs.append(output) - # Assign tensor address appropriately - for idx in range(self.engine.num_io_tensors): - self.context.set_tensor_address( - self.engine.get_tensor_name(idx), bindings[idx] + output = torch.empty( + size=shape, + dtype=self.output_dtypes[o].to(torch.dtype), + device=torch.cuda.current_device(), ) + outputs.append(output) + + if need_cudagraphs_record: + self._output_buffers[o] = outputs[o].clone() + + if cudagraphs_enabled: + self.context.set_tensor_address( + output_name, self._output_buffers[o].data_ptr() + ) + else: + self.context.set_tensor_address( + output_name, outputs[o].data_ptr() + ) + with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:TensorRTRuntime" @@ -360,37 +361,56 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
if self.profiling_enabled else nullcontext() ): + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + else: + self._engine_stream = self._caller_stream - if not cudagraphs_enabled: - self.context.execute_async_v3(self.active_stream.cuda_stream) # type: ignore + self._engine_stream.wait_stream(self._caller_stream) - elif need_cudagraphs_record: - self.input_buffers = list(contiguous_inputs) - self.output_buffers = list(outputs) + with torch.cuda.stream(self._engine_stream): - graph_capturer_stream = self.graph_capturer.capture_stream + if cudagraphs_enabled: + if need_cudagraphs_record: + self.cudagraph = torch.cuda.CUDAGraph() - self.context.execute_async_v3(graph_capturer_stream.cuda_stream) - graph_capturer_stream.synchronize() + if self.profiling_enabled: + self.cudagraph.enable_debug_mode() - with self.graph_capturer: - self.context.execute_async_v3(graph_capturer_stream.cuda_stream) + with torch.cuda.graph( + self.cudagraph, stream=self._engine_stream + ): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) - else: - for idx, input_tensor in enumerate(inputs): - self.input_buffers[idx].copy_(input_tensor, non_blocking=True) + if self.profiling_enabled: + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + self.cudagraph.debug_dump( + f"{tempdir}/{self.name}_cudagraph.dot" + ) + + self.cudagraph.replay() # type: ignore + + else: + self.context.execute_async_v3(self._engine_stream.cuda_stream) - self.cudagraph.replay() # type: ignore + self._caller_stream.wait_stream(self._engine_stream) if cudagraphs_enabled: - model_outputs = tuple(output.clone() for output in self.output_buffers) - else: - model_outputs = tuple(outputs) + for idx, o in enumerate(outputs): + o.copy_(self._output_buffers[idx]) - if len(model_outputs) == 1: - return model_outputs[0] + if len(outputs) == 1: + return outputs[0] - return model_outputs + return outputs def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: """ diff --git a/py/torch_tensorrt/dynamo/types.py b/py/torch_tensorrt/dynamo/types.py index e91addcb86..a6f514e727 100644 --- a/py/torch_tensorrt/dynamo/types.py +++ b/py/torch_tensorrt/dynamo/types.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple +from typing import Sequence # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt diff --git a/py/torch_tensorrt/runtime/cudagraphs.py b/py/torch_tensorrt/runtime/cudagraphs.py index 56f8b82a73..9d1523ef2e 100644 --- a/py/torch_tensorrt/runtime/cudagraphs.py +++ b/py/torch_tensorrt/runtime/cudagraphs.py @@ -1,9 +1,7 @@ import logging -from importlib.util import find_spec from typing import Any import torch - import torch_tensorrt if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: @@ -21,7 +19,7 @@ def set_cudagraphs_mode(mode: bool) -> None: _PY_RT_CUDAGRAPHS = mode # Set new mode for C++ - if find_spec("torch_tensorrt._C") is not None: + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: torch.ops.tensorrt.set_cudagraphs_mode(mode) logger.info(f"Set Cudagraphs usage to {mode}") diff --git a/py/torch_tensorrt/runtime/multi_device_safe_mode.py b/py/torch_tensorrt/runtime/multi_device_safe_mode.py index 547868edf6..15b0967810 100644 --- a/py/torch_tensorrt/runtime/multi_device_safe_mode.py +++ b/py/torch_tensorrt/runtime/multi_device_safe_mode.py @@ -1,10 +1,10 @@ import logging -from importlib.util import find_spec from typing import 
Any import torch +import torch_tensorrt -if find_spec("torch_tensorrt._C") is not None: +if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: _PY_RT_MULTI_DEVICE_SAFE_MODE = torch.ops.tensorrt.get_multi_device_safe_mode() else: _PY_RT_MULTI_DEVICE_SAFE_MODE = False @@ -31,7 +31,7 @@ def __exit__(self, *args: Any) -> None: _PY_RT_MULTI_DEVICE_SAFE_MODE = self.old_mode # Set multi-device safe mode back to old mode in C++ - if find_spec("torch_tensorrt._C") is not None: + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: torch.ops.tensorrt.set_multi_device_safe_mode(self.old_mode) @@ -60,7 +60,7 @@ def set_multi_device_safe_mode(mode: bool) -> _MultiDeviceSafeModeContextManager _PY_RT_MULTI_DEVICE_SAFE_MODE = mode # Set new mode for C++ - if find_spec("torch_tensorrt._C") is not None: + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: torch.ops.tensorrt.set_multi_device_safe_mode(mode) logger.info(f"Set multi-device safe mode to {mode}") diff --git a/tests/py/dynamo/runtime/test_compilation_settings.py b/tests/py/dynamo/runtime/test_000_compilation_settings.py similarity index 100% rename from tests/py/dynamo/runtime/test_compilation_settings.py rename to tests/py/dynamo/runtime/test_000_compilation_settings.py diff --git a/tests/py/dynamo/runtime/test_compiler_utils.py b/tests/py/dynamo/runtime/test_000_compiler_utils.py similarity index 100% rename from tests/py/dynamo/runtime/test_compiler_utils.py rename to tests/py/dynamo/runtime/test_000_compiler_utils.py diff --git a/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py b/tests/py/dynamo/runtime/test_000_convert_module_to_trt_engine.py similarity index 100% rename from tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py rename to tests/py/dynamo/runtime/test_000_convert_module_to_trt_engine.py diff --git a/tests/py/dynamo/runtime/test_python_runtime.py b/tests/py/dynamo/runtime/test_000_python_runtime.py similarity index 100% rename from tests/py/dynamo/runtime/test_python_runtime.py rename to tests/py/dynamo/runtime/test_000_python_runtime.py diff --git a/tests/py/dynamo/runtime/test_001_streams.py b/tests/py/dynamo/runtime/test_001_streams.py new file mode 100644 index 0000000000..574db6611e --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_streams.py @@ -0,0 +1,66 @@ +import unittest + +import torch +import torch_tensorrt +from torch import nn +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + +INPUT_SIZE = (10, 10, 10) +TRIALS = 10 + + +class TestStreams(TestCase): + + def test_non_default_stream_exec(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + with torch.inference_mode(): + dtype = torch.half + device = torch.device("cuda", 0) + model = SampleModel().eval().to(device) + inputs = [torch_tensorrt.Input(shape=(1, 3, 5), dtype=dtype)] + + optimized_model = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={dtype}, + min_block_size=1, + device=device, + ) + + for i in range(100): + new_input = torch.randn((1, 3, 5), dtype=dtype, device=device) + + eager_output = model(new_input) + + stream = torch.cuda.Stream(device=device) + stream.wait_stream(torch.cuda.current_stream(device=device)) + with torch.cuda.stream(stream): + trt_output_with_stream = optimized_model(new_input) + torch.cuda.current_stream(device=device).wait_stream(stream) + + trt_output_without_stream = optimized_model(new_input) + + 
max_diff_w_stream = float( + torch.max(torch.abs(eager_output - trt_output_with_stream)) + ) + max_diff_wo_stream = float( + torch.max(torch.abs(eager_output - trt_output_without_stream)) + ) + self.assertAlmostEqual( + max_diff_w_stream, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Output using a non default calling stream does not match original model (trial: {i})", + ) + self.assertAlmostEqual( + max_diff_wo_stream, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Output using default stream as calling stream does not match original model (trial: {i})", + ) diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py b/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py new file mode 100644 index 0000000000..e4a95c73dc --- /dev/null +++ b/tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py @@ -0,0 +1,154 @@ +import itertools +import os +import unittest + +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + +INPUT_SIZE = (3, 16, 16) +TRIALS = 5 + + +class TestCudagraphs(TestCase): + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) + def test_cudagraphs_enabled_inference_cpp(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=False, + ) + + result_samples = [] + torch_results_samples = [] + with torch_tensorrt.runtime.enable_cudagraphs(): + for i in inputs: + result_samples.append(optimized_model(i).detach().cpu()) + torch_results_samples.append(fx_graph(i).detach().cpu()) + + for i, (optimized_model_results, torch_model_results) in enumerate( + zip(result_samples, torch_results_samples) + ): + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"CUDA Graph C++ TRT outputs don't match with the original model. 
(trial: {i})", + ) + + torch._dynamo.reset() + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) + def test_cudagraphs_enabled_fallback_inference_cpp(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=False, + ) + + result_samples = [] + torch_results_samples = [] + with torch_tensorrt.runtime.enable_cudagraphs(): + for i in inputs: + result_samples.append(optimized_model(i).detach().cpu()) + torch_results_samples.append(fx_graph(i).detach().cpu()) + + for i, (optimized_model_results, torch_model_results) in enumerate( + zip(result_samples, torch_results_samples) + ): + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})", + ) + + torch._dynamo.reset() + + @unittest.skipIf( + os.environ.get("CI_BUILD") == "1", + "Skipping test due to CI resource constraints", + ) + def test_cudagraphs_recapture_cpp(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + inputs = [ + TRIALS * [torch.randn(*(2 * (i + 1), 2 * (i + 1))).cuda()] + for i in range(TRIALS) + ] + inputs = list(itertools.chain.from_iterable(inputs)) + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=False, + ) + + result_samples = [] + torch_results_samples = [] + with torch_tensorrt.runtime.enable_cudagraphs(): + for i in inputs: + result_samples.append(optimized_model(i).detach().cpu()) + torch_results_samples.append(fx_graph(i).detach().cpu()) + + for i, (optimized_model_results, torch_model_results) in enumerate( + zip(result_samples, torch_results_samples) + ): + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"CUDA Graph Python TRT outputs don't match with the original model. 
(trial: {i})", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py new file mode 100644 index 0000000000..e5e340c500 --- /dev/null +++ b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py @@ -0,0 +1,164 @@ +import itertools +import os +import unittest + +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + +INPUT_SIZE = (3, 16, 16) +TRIALS = 5 + + +class TestCudagraphs(TestCase): + def test_cudagraphs_on(self): + torch_tensorrt.runtime.set_cudagraphs_mode(True) + self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) + + def test_cudagraphs_off(self): + torch_tensorrt.runtime.set_cudagraphs_mode(False) + self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) + + def test_cudagraphs_context(self): + with torch_tensorrt.runtime.enable_cudagraphs(): + self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) + self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) + + def test_cudagraphs_enabled_inference_python(self): + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + debug=True, + ) + + result_samples = [] + torch_results_samples = [] + with torch_tensorrt.runtime.enable_cudagraphs(): + for i in inputs: + result_samples.append(optimized_model(i).detach().cpu()) + torch_results_samples.append(fx_graph(i).detach().cpu()) + + for i, (optimized_model_results, torch_model_results) in enumerate( + zip(result_samples, torch_results_samples) + ): + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"CUDA Graph Python TRT outputs don't match with the original model. 
(trial: {i})", + ) + + torch._dynamo.reset() + + def test_cudagraphs_enabled_fallback_inference_python(self): + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.softmax((x + 2) * 7, dim=0) + + inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)] + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=True, + debug=True, + ) + + result_samples = [] + torch_results_samples = [] + with torch_tensorrt.runtime.enable_cudagraphs(): + for i in inputs: + result_samples.append(optimized_model(i).detach().cpu()) + torch_results_samples.append(fx_graph(i).detach().cpu()) + + for i, (optimized_model_results, torch_model_results) in enumerate( + zip(result_samples, torch_results_samples) + ): + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})", + ) + torch._dynamo.reset() + + @unittest.skipIf( + os.environ.get("CI_BUILD") == "1", + "Skipping test due to CI resource constraints", + ) + def test_cudagraphs_recapture_py(self): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + inputs = [ + TRIALS * [torch.randn(*(2 * (i + 1), 2 * (i + 1))).cuda()] + for i in range(TRIALS) + ] + inputs = list(itertools.chain.from_iterable(inputs)) + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs[0], + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=True, + ) + + result_samples = [] + torch_results_samples = [] + with torch_tensorrt.runtime.enable_cudagraphs(): + for i in inputs: + result_samples.append(optimized_model(i).detach().cpu()) + torch_results_samples.append(fx_graph(i).detach().cpu()) + + for i, (optimized_model_results, torch_model_results) in enumerate( + zip(result_samples, torch_results_samples) + ): + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"CUDA Graph Python TRT outputs don't match with the original model. 
(trial: {i})", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_lazy_engine_init.py b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py similarity index 100% rename from tests/py/dynamo/runtime/test_lazy_engine_init.py rename to tests/py/dynamo/runtime/test_002_lazy_engine_init.py diff --git a/tests/py/dynamo/runtime/test_safe_mode.py b/tests/py/dynamo/runtime/test_003_safe_mode.py similarity index 93% rename from tests/py/dynamo/runtime/test_safe_mode.py rename to tests/py/dynamo/runtime/test_003_safe_mode.py index 5842b3ddc5..0fde0773ed 100644 --- a/tests/py/dynamo/runtime/test_safe_mode.py +++ b/tests/py/dynamo/runtime/test_003_safe_mode.py @@ -7,10 +7,7 @@ from ..testing_utilities import DECIMALS_OF_AGREEMENT -@unittest.skipIf( - not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, - "Torch-TensorRT runtime is not available", -) +@unittest.skipIf(torch.cuda.device_count() == 1, "System does not have multiple GPUs") class TestSafeMode(TestCase): def test_multi_device_safe_mode_on(self): torch_tensorrt.runtime.set_multi_device_safe_mode(True) @@ -65,6 +62,10 @@ def forward(self, x): ) torch._dynamo.reset() + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) def test_multi_device_safe_mode_enabled_inference_cpp(self): torch_tensorrt.runtime.set_multi_device_safe_mode(True) diff --git a/tests/py/dynamo/runtime/test_cudagraphs.py b/tests/py/dynamo/runtime/test_cudagraphs.py deleted file mode 100644 index 4d922629c1..0000000000 --- a/tests/py/dynamo/runtime/test_cudagraphs.py +++ /dev/null @@ -1,202 +0,0 @@ -import unittest - -import torch -from torch.testing._internal.common_utils import TestCase, run_tests - -import torch_tensorrt - -from ..testing_utilities import DECIMALS_OF_AGREEMENT - - -class TestCudagraphs(TestCase): - def test_cudagraphs_on(self): - torch_tensorrt.runtime.set_cudagraphs_mode(True) - self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) - - def test_cudagraphs_off(self): - torch_tensorrt.runtime.set_cudagraphs_mode(False) - self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) - - def test_cudagraphs_context(self): - with torch_tensorrt.runtime.enable_cudagraphs(): - self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode()) - self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode()) - - def test_cudagraphs_enabled_inference_python(self): - torch_tensorrt.runtime.set_cudagraphs_mode(True) - - class SampleModel(torch.nn.Module): - def forward(self, x): - return torch.softmax((x + 2) * 7, dim=0) - - inputs = [ - torch.randn( - 3, - 5, - 7, - ).cuda() - ] - - fx_graph = torch.fx.symbolic_trace(SampleModel()) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - use_python_runtime=True, - ) - optimized_model_results = optimized_model(*inputs).detach().cpu() - torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"Safe Mode Python TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - @unittest.skipIf( - not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, - "Torch-TensorRT runtime is not available", - ) - def test_cudagraphs_enabled_inference_cpp(self): - class 
SampleModel(torch.nn.Module): - def forward(self, x): - return torch.softmax((x + 2) * 7, dim=0) - - inputs = [ - torch.randn( - 3, - 5, - 7, - ).cuda() - ] - - fx_graph = torch.fx.symbolic_trace(SampleModel()) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - use_python_runtime=False, - ) - - with torch_tensorrt.runtime.enable_cudagraphs(): - optimized_model_results = optimized_model(*inputs).detach().cpu() - - torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"Safe Mode C++ TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - def test_cudagraphs_enabled_fallback_inference_python(self): - torch_tensorrt.runtime.set_cudagraphs_mode(True) - - class SampleModel(torch.nn.Module): - def forward(self, x): - return torch.softmax((x + 2) * 7, dim=0) - - inputs = [ - torch.randn( - 3, - 5, - 7, - ).cuda() - ] - - fx_graph = torch.fx.symbolic_trace(SampleModel()) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - torch_executed_ops={"torch.ops.aten.mul.Tensor"}, - use_python_runtime=True, - ) - - with torch_tensorrt.runtime.enable_cudagraphs(): - optimized_model_results = optimized_model(*inputs).detach().cpu() - - torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"Safe Mode Python TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - @unittest.skipIf( - not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, - "Torch-TensorRT runtime is not available", - ) - def test_cudagraphs_enabled_fallback_inference_cpp(self): - class SampleModel(torch.nn.Module): - def forward(self, x): - return torch.softmax((x + 2) * 7, dim=0) - - inputs = [ - torch.randn( - 3, - 5, - 7, - ).cuda() - ] - - fx_graph = torch.fx.symbolic_trace(SampleModel()) - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - torch_executed_ops={"torch.ops.aten.mul.Tensor"}, - use_python_runtime=False, - ) - - with torch_tensorrt.runtime.enable_cudagraphs(): - optimized_model_results = optimized_model(*inputs).detach().cpu() - - torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - msg=f"Safe Mode C++ TRT outputs don't match with the original model.", - ) - torch._dynamo.reset() - - -if __name__ == "__main__": - run_tests()