diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs index 309d8b88..b6065295 100644 --- a/zluda_inject/src/bin.rs +++ b/zluda_inject/src/bin.rs @@ -29,6 +29,7 @@ static NVCUDA_DLL: &'static str = "nvcuda.dll"; static NVML_DLL: &'static str = "nvml.dll"; static NVAPI_DLL: &'static str = "nvapi64.dll"; static NVOPTIX_DLL: &'static str = "optix.6.6.0.dll"; +static CUBLAS_DLL: &'static str = "cublas64.dll"; include!("../../zluda_redirect/src/payload_guid.rs"); @@ -59,6 +60,10 @@ struct ProgramArguments { #[argh(option)] nvoptix: Option, + /// DLL to be injected instead of system cublas64.dll. If not provided, no injection will take place + #[argh(option)] + cublas: Option, + /// display the version of ZLUDA #[argh(switch)] #[allow(dead_code)] @@ -108,6 +113,9 @@ pub fn main_impl() -> Result<(), Box> { if let Some(ref nvoptix) = environment.nvoptix_path_zero_terminated { dlls_to_inject.push(nvoptix.as_ptr() as _); } + if let Some(ref cublas) = environment.cublas_path_zero_terminated { + dlls_to_inject.push(cublas.as_ptr() as _); + } os_call!( detours_sys::DetourCreateProcessWithDllsW( ptr::null(), @@ -185,6 +193,7 @@ struct NormalizedArguments { nvml_path: PathBuf, nvapi_path: Option, nvoptix_path: Option, + cublas_path: Option, redirect_path: PathBuf, winapi_command_line_zero_terminated: Vec, } @@ -199,6 +208,7 @@ impl NormalizedArguments { let nvml_path = Self::get_absolute_path_or_default(¤t_exe, prog_args.nvml, NVML_DLL)?; let nvapi_path = prog_args.nvapi.map(Self::get_absolute_path).transpose()?; let nvoptix_path = prog_args.nvoptix.map(Self::get_absolute_path).transpose()?; + let cublas_path = prog_args.cublas.map(Self::get_absolute_path).transpose()?; let winapi_command_line_zero_terminated = construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args)); let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); @@ -210,6 +220,7 @@ impl NormalizedArguments { nvml_path, nvapi_path, nvoptix_path, + cublas_path, redirect_path, winapi_command_line_zero_terminated, }) @@ -269,6 +280,7 @@ struct Environment { nvml_path_zero_terminated: String, nvapi_path_zero_terminated: Option, nvoptix_path_zero_terminated: Option, + cublas_path_zero_terminated: Option, redirect_path_zero_terminated: String, winapi_command_line_zero_terminated: Vec, _temp_dir: TempDir, @@ -321,6 +333,14 @@ impl Environment { )?)) }) .transpose()?; + let cublas_path_zero_terminated = args + .cublas_path + .map(|cublas| { + Ok::<_, io::Error>(Self::zero_terminate(Self::copy_to_correct_name( + cublas, &_temp_dir, CUBLAS_DLL, + )?)) + }) + .transpose()?; let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); Ok(Self { nccl_path_zero_terminated, @@ -329,6 +349,7 @@ impl Environment { nvml_path_zero_terminated, nvapi_path_zero_terminated, nvoptix_path_zero_terminated, + cublas_path_zero_terminated, redirect_path_zero_terminated, winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated, _temp_dir,