diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp index afe559ec4d..24eba16cc3 100644 --- a/core/runtime/runtime.cpp +++ b/core/runtime/runtime.cpp @@ -121,6 +121,14 @@ void multi_gpu_device_check() { } } +bool get_multi_device_safe_mode() { + return MULTI_DEVICE_SAFE_MODE; +} + +void set_multi_device_safe_mode(bool multi_device_safe_mode) { + MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; +} + namespace { static DeviceList cuda_device_list; } diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 5e5676b11e..8c9b33328a 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -38,6 +38,10 @@ std::vector execute_engine(std::vector inputs, c10::intr void multi_gpu_device_check(); +bool get_multi_device_safe_mode(); + +void set_multi_device_safe_mode(bool multi_device_safe_mode); + class DeviceList { using DeviceMap = std::unordered_map; DeviceMap device_list; diff --git a/tests/core/runtime/BUILD b/tests/core/runtime/BUILD index cd9f123b59..deddca4cfb 100644 --- a/tests/core/runtime/BUILD +++ b/tests/core/runtime/BUILD @@ -1,3 +1,5 @@ +load("//tests/core/runtime:runtime_test.bzl", "runtime_test") + package(default_visibility = ["//visibility:public"]) config_setting( @@ -6,3 +8,14 @@ config_setting( "define": "abi=pre_cxx11_abi", }, ) + +runtime_test( + name = "test_multi_device_safe_mode", +) + +test_suite( + name = "runtime_tests", + tests = [ + ":test_multi_device_safe_mode", + ], +) diff --git a/tests/core/runtime/runtime_test.bzl b/tests/core/runtime/runtime_test.bzl new file mode 100644 index 0000000000..2c4753cc9b --- /dev/null +++ b/tests/core/runtime/runtime_test.bzl @@ -0,0 +1,25 @@ +""" +runtime test macros +""" + +load("@rules_cc//cc:defs.bzl", "cc_test") + +def runtime_test(name, visibility = None): + """Macro to define a runtime test + + Args: + name: Name of test file + visibility: Visibility of the test target + """ + cc_test( + name = name, + srcs = [name + ".cpp"], + visibility = visibility, + deps = [ + "//tests/util", + "@googletest//:gtest_main", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }), + ) diff --git a/tests/core/runtime/test_multi_device_safe_mode.cpp b/tests/core/runtime/test_multi_device_safe_mode.cpp new file mode 100644 index 0000000000..a8c72cac7b --- /dev/null +++ b/tests/core/runtime/test_multi_device_safe_mode.cpp @@ -0,0 +1,8 @@ +#include "core/runtime/runtime.h" +#include "gtest/gtest.h" + +TEST(Runtime, MultiDeviceSafeMode) { + ASSERT_TRUE(!torch_tensorrt::core::runtime::get_multi_device_safe_mode()); + torch_tensorrt::core::runtime::set_multi_device_safe_mode(true); + ASSERT_TRUE(torch_tensorrt::core::runtime::get_multi_device_safe_mode()); +}