diff --git a/examples/modified_cuda_samples/simpleDrvRuntimePTX/simpleDrvRuntimePTX.cpp b/examples/modified_cuda_samples/simpleDrvRuntimePTX/simpleDrvRuntimePTX.cpp
index da0c3cab..68bc605e 100644
--- a/examples/modified_cuda_samples/simpleDrvRuntimePTX/simpleDrvRuntimePTX.cpp
+++ b/examples/modified_cuda_samples/simpleDrvRuntimePTX/simpleDrvRuntimePTX.cpp
@@ -124,7 +124,7 @@ int main(int argc, char** argv)
 	// first search for the module path before we load the results
 	auto ptx_filename = create_ptx_file();
 
-	auto module = cuda::module::load_from_file(ptx_filename);
+	auto module = cuda::module::load_from_file(context, ptx_filename);
 	auto vecAdd_kernel = module.get_kernel("VecAdd_kernel");
 	auto dummy_kernel = module.get_kernel("dummy");
 
diff --git a/src/cuda/api/module.hpp b/src/cuda/api/module.hpp
index 20b56862..8e986a19 100644
--- a/src/cuda/api/module.hpp
+++ b/src/cuda/api/module.hpp
@@ -31,6 +31,8 @@ class kernel_t;
 
 namespace module {
 
+using handle_t = CUmodule;
+
 namespace detail_ {
 
 inline module_t construct(
@@ -59,23 +61,6 @@ ::std::string identify(const module_t &module);
 
 } // namespace detail_
 
-/**
- * Load a module from an appropriate compiled or semi-compiled file, allocating all
- * relevant resources for it.
- *
- * @param path of a cubin, PTX, or fatbin file constituting the module to be loaded.
- * @return the loaded module
- *
- * @note this covers cuModuleLoadFatBinary() even though that's not directly used
- */
-module_t load_from_file(const char *path, link::options_t link_options = {});
-
-module_t load_from_file(const ::std::string &path, link::options_t link_options = {});
-
-#if __cplusplus >= 201703L
-module_t load_from_file(const ::std::filesystem::path& path, link::options_t options = {});
-#endif
-
 /**
  * Create a CUDA driver module from raw module image data.
  *
@@ -195,37 +180,118 @@ namespace module {
 
 using handle_t = CUmodule;
 
-/**
-* Loads a populated module from a file on disk
-*
-* @param path Filesystem path of a fatbin, cubin or PTX file
-*
-* @todo: Do we really need link options here?
- * @todo: Make this take a context_t; and consider adding load_module methods to context_t
-*/
-inline module_t load_from_file(const char* path, link::options_t link_options)
+namespace detail_ {
+
+inline module_t load_from_file_in_current_context(
+	device::id_t current_context_device_id,
+	context::handle_t current_context_handle,
+	const char *path,
+	link::options_t link_options)
 {
 	handle_t new_module_handle;
 	auto status = cuModuleLoad(&new_module_handle, path);
 	throw_if_error(status, ::std::string("Failed loading a module from file ") + path);
-	bool do_take_ownership { true };
-	auto current_context_handle = context::current::detail_::get_handle();
-	auto current_device_id = context::detail_::get_device_id(current_context_handle);
-	return detail_::construct(current_device_id, current_context_handle, new_module_handle, link_options,
+	bool do_take_ownership{true};
+	return construct(
+		current_context_device_id,
+		current_context_handle,
+		new_module_handle,
+		link_options,
 		do_take_ownership);
 }
 
-inline module_t load_from_file(const ::std::string& path, link::options_t link_options)
+} // namespace detail_
+
+
+/**
+ * Load a module from an appropriate compiled or semi-compiled file, allocating all
+ * relevant resources for it.
+ *
+ * @param path of a cubin, PTX, or fatbin file constituting the module to be loaded.
+ * @param link_options options passed on, unchanged, to the construction of the returned module
+ * @return the loaded module
+ *
+ * @note this covers cuModuleLoadFatBinary() even though that's not directly used
+ * @note overloads taking a device load into that device's primary context; overloads
+ * taking neither a context nor a device use the current device
+ *
+ * @todo: consider adding load_module methods to context_t
+ * @todo: When switching to the C++17 standard, use string_view's instead of the const char*
+ * and std::string reference
+ */
+///@{
+inline module_t load_from_file(
+	const context_t& context,
+	const char* path,
+	link::options_t link_options = {})
+{
+	context::current::detail_::scoped_override_t set_context_for_this_scope(context.handle());
+	return detail_::load_from_file_in_current_context(
+		context.device_id(), context.handle(), path, link_options);
+}
+
+inline module_t load_from_file(
+	const context_t& context,
+	const ::std::string& path,
+	link::options_t link_options = {})
+{
+	return load_from_file(context, path.c_str(), link_options);
+}
+
+inline module_t load_from_file(
+	const device_t& device,
+	const char* path,
+	link::options_t link_options = {})
+{
+	auto pc = device.primary_context();
+	// NOTE(review): presumably this extra reference keeps the primary context alive for
+	// the returned module's lifetime - confirm that a matching decrease is performed.
+	device::primary_context::detail_::increase_refcount(device.id());
+	return load_from_file(pc, path, link_options);
+}
+
+inline module_t load_from_file(
+	const device_t& device,
+	const ::std::string& path,
+	link::options_t link_options = {})
 {
-	return load_from_file(path.c_str(), link_options);
+	return load_from_file(device, path.c_str(), link_options);
 }
 
+inline module_t load_from_file(
+	const char* path,
+	link::options_t link_options = {})
+{
+	return load_from_file(device::current::get(), path, link_options);
+}
+
+inline module_t load_from_file(
+	const ::std::string& path,
+	link::options_t link_options = {})
+{
+	return load_from_file(device::current::get(), path.c_str(), link_options);
+}
+
+
 #if __cplusplus >= 201703L
-inline module_t load_from_file(const ::std::filesystem::path& path)
+
+inline module_t load_from_file(
+	const device_t& device,
+	const ::std::filesystem::path& path,
+	link::options_t link_options = {})
+{
+	return load_from_file(device, path.c_str(), link_options);
+}
+
+inline module_t load_from_file(
+	const ::std::filesystem::path& path,
+	link::options_t link_options = {})
 {
-	return load_from_file(path.c_str());
+	return load_from_file(device::current::get(), path, link_options);
 }
+
 #endif
+///@}
 
 namespace detail_ {
 