diff --git a/ort-sys/build.rs b/ort-sys/build.rs index dc0672c..23e94a8 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -1,5 +1,5 @@ use std::{ - env, fs, + env, fs, io, path::{Path, PathBuf}, process::Command }; @@ -12,9 +12,9 @@ const ORT_ENV_SYSTEM_LIB_PROFILE: &str = "ORT_LIB_PROFILE"; const DIST_TABLE: &str = include_str!("dist.txt"); -#[path = "src/internal/dirs.rs"] -mod dirs; -use self::dirs::cache_dir; +#[path = "src/internal/mod.rs"] +mod internal; +use self::internal::dirs::cache_dir; #[cfg(feature = "download-binaries")] fn fetch_file(source_url: &str) -> Vec { @@ -416,21 +416,40 @@ fn prepare_libort_dir() -> (PathBuf, bool) { let (prebuilt_url, prebuilt_hash) = dist.unwrap(); - let mut cache_dir = cache_dir() + let bin_extract_dir = cache_dir() .expect("could not determine cache directory") .join("dfbin") .join(target) .join(prebuilt_hash); - if fs::create_dir_all(&cache_dir).is_err() { - cache_dir = env::var("OUT_DIR").unwrap().into(); - } let ort_extract_dir = prebuilt_url.split('/').last().unwrap().strip_suffix(".tgz").unwrap(); - let lib_dir = cache_dir.join(ort_extract_dir); + let lib_dir = bin_extract_dir.join(ort_extract_dir); if !lib_dir.exists() { let downloaded_file = fetch_file(prebuilt_url); assert!(verify_file(&downloaded_file, prebuilt_hash), "hash of downloaded ONNX Runtime binary does not match!"); - extract_tgz(&downloaded_file, &cache_dir); + + let mut temp_extract_dir = bin_extract_dir + .parent() + .unwrap() + .join(format!("tmp.{}_{prebuilt_hash}", self::internal::random_identifier())); + let mut should_rename = true; + if fs::create_dir_all(&temp_extract_dir).is_err() { + temp_extract_dir = env::var("OUT_DIR").unwrap().into(); + should_rename = false; + } + extract_tgz(&downloaded_file, &temp_extract_dir); + if should_rename { + match std::fs::rename(&temp_extract_dir, &bin_extract_dir) { + Ok(()) => {} + Err(e) => { + if bin_extract_dir.exists() { + let _ = fs::remove_dir_all(temp_extract_dir); + } else { + panic!("failed to extract downloaded binaries: {e}"); + } + } + } + } } static_link_prerequisites(true); diff --git a/ort-sys/src/internal/mod.rs b/ort-sys/src/internal/mod.rs index 7516f3c..6cb610a 100644 --- a/ort-sys/src/internal/mod.rs +++ b/ort-sys/src/internal/mod.rs @@ -1,4 +1,16 @@ +use std::hash::{BuildHasher, Hasher, RandomState}; + pub mod dirs; -#[cfg(feature = "download-binaries")] -include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs")); +pub fn random_identifier() -> String { + let mut state = RandomState::new().build_hasher().finish(); + std::iter::repeat_with(move || { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + state + }) + .take(12) + .map(|i| b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[i as usize % 62] as char) + .collect() +} diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index ee08e63..660e76f 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -10,6 +10,9 @@ #[doc(hidden)] pub mod internal; +#[cfg(feature = "download-binaries")] +include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs")); + pub const ORT_API_VERSION: u32 = 17; pub use std::ffi::{c_char, c_int, c_ulong, c_ulonglong, c_ushort, c_void};