Skip to content

Commit

Permalink
Restore cublas argument. (injector)
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Aug 29, 2024
1 parent 9ae34f4 commit e84105e
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions zluda_inject/src/bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ static NVCUDA_DLL: &'static str = "nvcuda.dll";
static NVML_DLL: &'static str = "nvml.dll";
static NVAPI_DLL: &'static str = "nvapi64.dll";
static NVOPTIX_DLL: &'static str = "optix.6.6.0.dll";
static CUBLAS_DLL: &'static str = "cublas64.dll";

include!("../../zluda_redirect/src/payload_guid.rs");

Expand Down Expand Up @@ -59,6 +60,10 @@ struct ProgramArguments {
#[argh(option)]
nvoptix: Option<PathBuf>,

/// DLL to be injected instead of system cublas64.dll. If not provided, no injection will take place
#[argh(option)]
cublas: Option<PathBuf>,

/// display the version of ZLUDA
#[argh(switch)]
#[allow(dead_code)]
Expand Down Expand Up @@ -108,6 +113,9 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
if let Some(ref nvoptix) = environment.nvoptix_path_zero_terminated {
dlls_to_inject.push(nvoptix.as_ptr() as _);
}
if let Some(ref cublas) = environment.cublas_path_zero_terminated {
dlls_to_inject.push(cublas.as_ptr() as _);
}
os_call!(
detours_sys::DetourCreateProcessWithDllsW(
ptr::null(),
Expand Down Expand Up @@ -185,6 +193,7 @@ struct NormalizedArguments {
nvml_path: PathBuf,
nvapi_path: Option<PathBuf>,
nvoptix_path: Option<PathBuf>,
cublas_path: Option<PathBuf>,
redirect_path: PathBuf,
winapi_command_line_zero_terminated: Vec<u16>,
}
Expand All @@ -199,6 +208,7 @@ impl NormalizedArguments {
let nvml_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nvml, NVML_DLL)?;
let nvapi_path = prog_args.nvapi.map(Self::get_absolute_path).transpose()?;
let nvoptix_path = prog_args.nvoptix.map(Self::get_absolute_path).transpose()?;
let cublas_path = prog_args.cublas.map(Self::get_absolute_path).transpose()?;
let winapi_command_line_zero_terminated =
construct_command_line(std::iter::once(prog_args.exe).chain(prog_args.args));
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
Expand All @@ -210,6 +220,7 @@ impl NormalizedArguments {
nvml_path,
nvapi_path,
nvoptix_path,
cublas_path,
redirect_path,
winapi_command_line_zero_terminated,
})
Expand Down Expand Up @@ -269,6 +280,7 @@ struct Environment {
nvml_path_zero_terminated: String,
nvapi_path_zero_terminated: Option<String>,
nvoptix_path_zero_terminated: Option<String>,
cublas_path_zero_terminated: Option<String>,
redirect_path_zero_terminated: String,
winapi_command_line_zero_terminated: Vec<u16>,
_temp_dir: TempDir,
Expand Down Expand Up @@ -321,6 +333,14 @@ impl Environment {
)?))
})
.transpose()?;
let cublas_path_zero_terminated = args
.cublas_path
.map(|cublas| {
Ok::<_, io::Error>(Self::zero_terminate(Self::copy_to_correct_name(
cublas, &_temp_dir, CUBLAS_DLL,
)?))
})
.transpose()?;
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
Ok(Self {
nccl_path_zero_terminated,
Expand All @@ -329,6 +349,7 @@ impl Environment {
nvml_path_zero_terminated,
nvapi_path_zero_terminated,
nvoptix_path_zero_terminated,
cublas_path_zero_terminated,
redirect_path_zero_terminated,
winapi_command_line_zero_terminated: args.winapi_command_line_zero_terminated,
_temp_dir,
Expand Down

0 comments on commit e84105e

Please sign in to comment.