From eb51646860bc2fc5728d879bdf542c0537acc6ea Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 23 Nov 2024 12:30:45 -0600 Subject: [PATCH] fix: concurrent downloads, ref #322 --- ort-sys/build.rs | 38 ++++++++++++++++++++++-------- ort-sys/src/internal/mod.rs | 15 ++++++++++++ src/session/builder/impl_commit.rs | 22 +++++++++++++---- 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 8d99cd7f..525590a6 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -1,5 +1,5 @@ use std::{ - env, fs, + env, fs, io, path::{Path, PathBuf} }; @@ -16,9 +16,9 @@ const ORT_EXTRACT_DIR: &str = "onnxruntime"; const DIST_TABLE: &str = include_str!("dist.txt"); -#[path = "src/internal/dirs.rs"] -mod dirs; -use self::dirs::cache_dir; +#[path = "src/internal/mod.rs"] +mod internal; +use self::internal::dirs::cache_dir; #[cfg(feature = "download-binaries")] fn fetch_file(source_url: &str) -> Vec { @@ -450,20 +450,38 @@ fn prepare_libort_dir() -> (PathBuf, bool) { let (prebuilt_url, prebuilt_hash) = dist.unwrap(); - let mut cache_dir = cache_dir() + let bin_extract_dir = cache_dir() .expect("could not determine cache directory") .join("dfbin") .join(target) .join(prebuilt_hash); - if fs::create_dir_all(&cache_dir).is_err() { - cache_dir = env::var("OUT_DIR").unwrap().into(); - } - let lib_dir = cache_dir.join(ORT_EXTRACT_DIR); + let lib_dir = bin_extract_dir.join(ORT_EXTRACT_DIR); if !lib_dir.exists() { let downloaded_file = fetch_file(prebuilt_url); assert!(verify_file(&downloaded_file, prebuilt_hash), "hash of downloaded ONNX Runtime binary does not match!"); - extract_tgz(&downloaded_file, &cache_dir); + + let mut temp_extract_dir = bin_extract_dir + .parent() + .unwrap() + .join(format!("tmp.{}_{prebuilt_hash}", self::internal::random_identifier())); + let mut should_rename = true; + if fs::create_dir_all(&temp_extract_dir).is_err() { + temp_extract_dir = env::var("OUT_DIR").unwrap().into(); + should_rename = false; + } + extract_tgz(&downloaded_file, &temp_extract_dir); + if should_rename { + match std::fs::rename(&temp_extract_dir, bin_extract_dir) { + Ok(()) => {} + Err(e) => match e.kind() { + io::ErrorKind::AlreadyExists | io::ErrorKind::DirectoryNotEmpty => { + let _ = fs::remove_dir_all(temp_extract_dir); + } + _ => panic!("failed to extract downloaded binaries: {e}") + } + } + } } static_link_prerequisites(true); diff --git a/ort-sys/src/internal/mod.rs b/ort-sys/src/internal/mod.rs index 16ec3672..6cb610ab 100644 --- a/ort-sys/src/internal/mod.rs +++ b/ort-sys/src/internal/mod.rs @@ -1 +1,16 @@ +use std::hash::{BuildHasher, Hasher, RandomState}; + pub mod dirs; + +pub fn random_identifier() -> String { + let mut state = RandomState::new().build_hasher().finish(); + std::iter::repeat_with(move || { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + state + }) + .take(12) + .map(|i| b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[i as usize % 62] as char) + .collect() +} diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs index f586c169..a31d27d0 100644 --- a/src/session/builder/impl_commit.rs +++ b/src/session/builder/impl_commit.rs @@ -30,7 +30,7 @@ impl SessionBuilder { let _ = write!(&mut s, "{:02x}", b); s }); - let model_filepath = download_dir.join(model_filename); + let model_filepath = download_dir.join(&model_filename); let downloaded_path = if model_filepath.exists() { tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); model_filepath @@ -46,16 +46,28 @@ impl SessionBuilder { tracing::info!(len, "Downloading {} bytes", len); let mut reader = resp.into_reader(); + let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier())); - let f = std::fs::File::create(&model_filepath).expect("Failed to create model file"); + let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file"); let mut writer = std::io::BufWriter::new(f); let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?; - if bytes_io_count == len as u64 { - model_filepath - } else { + if bytes_io_count != len as u64 { return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}"))); } + + drop(writer); + + match std::fs::rename(&temp_filepath, &model_filepath) { + Ok(()) => model_filepath, + Err(e) => match e.kind() { + std::io::ErrorKind::AlreadyExists => { + let _ = std::fs::remove_file(temp_filepath); + model_filepath + } + _ => return Err(Error::new(format!("Failed to download model: {e}"))) + } + } }; self.commit_from_file(downloaded_path)