Skip to content

Commit

Permalink
fix: concurrent downloads, ref pykeio#322
Browse files: browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Nov 23, 2024
1 parent 548bfed commit eb51646
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
38 changes: 28 additions & 10 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
env, fs,
env, fs, io,
path::{Path, PathBuf}
};

Expand All @@ -16,9 +16,9 @@ const ORT_EXTRACT_DIR: &str = "onnxruntime";

const DIST_TABLE: &str = include_str!("dist.txt");

#[path = "src/internal/dirs.rs"]
mod dirs;
use self::dirs::cache_dir;
#[path = "src/internal/mod.rs"]
mod internal;
use self::internal::dirs::cache_dir;

#[cfg(feature = "download-binaries")]
fn fetch_file(source_url: &str) -> Vec<u8> {
Expand Down Expand Up @@ -450,20 +450,38 @@ fn prepare_libort_dir() -> (PathBuf, bool) {

let (prebuilt_url, prebuilt_hash) = dist.unwrap();

let mut cache_dir = cache_dir()
let bin_extract_dir = cache_dir()
.expect("could not determine cache directory")
.join("dfbin")
.join(target)
.join(prebuilt_hash);
if fs::create_dir_all(&cache_dir).is_err() {
cache_dir = env::var("OUT_DIR").unwrap().into();
}

let lib_dir = cache_dir.join(ORT_EXTRACT_DIR);
let lib_dir = bin_extract_dir.join(ORT_EXTRACT_DIR);
if !lib_dir.exists() {
let downloaded_file = fetch_file(prebuilt_url);
assert!(verify_file(&downloaded_file, prebuilt_hash), "hash of downloaded ONNX Runtime binary does not match!");
extract_tgz(&downloaded_file, &cache_dir);

let mut temp_extract_dir = bin_extract_dir
.parent()
.unwrap()
.join(format!("tmp.{}_{prebuilt_hash}", self::internal::random_identifier()));
let mut should_rename = true;
if fs::create_dir_all(&temp_extract_dir).is_err() {
temp_extract_dir = env::var("OUT_DIR").unwrap().into();
should_rename = false;
}
extract_tgz(&downloaded_file, &temp_extract_dir);
if should_rename {
match std::fs::rename(&temp_extract_dir, bin_extract_dir) {
Ok(()) => {}
Err(e) => match e.kind() {
io::ErrorKind::AlreadyExists | io::ErrorKind::DirectoryNotEmpty => {
let _ = fs::remove_dir_all(temp_extract_dir);
}
_ => panic!("failed to extract downloaded binaries: {e}")
}
}
}
}

static_link_prerequisites(true);
Expand Down
15 changes: 15 additions & 0 deletions ort-sys/src/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,16 @@
use std::hash::{BuildHasher, Hasher, RandomState};

pub mod dirs;

/// Produces a 12-character pseudo-random identifier made of ASCII letters and digits.
///
/// The generator is seeded from the finished value of a fresh [`RandomState`] hasher
/// (with nothing written to it) and then advanced with xorshift steps (`<< 13`, `>> 7`,
/// `<< 17`). Each resulting 64-bit state is reduced modulo the alphabet size to select
/// one character.
pub fn random_identifier() -> String {
    /// Characters an identifier may contain: `A-Z`, `a-z`, `0-9`.
    const ALPHABET: &[u8; 62] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    /// Number of characters in the generated identifier.
    const LENGTH: usize = 12;

    // Only the standard library is used as the randomness source here.
    let mut seed = RandomState::new().build_hasher().finish();

    let mut identifier = String::with_capacity(LENGTH);
    for _ in 0..LENGTH {
        // Advance the state before drawing, so the raw seed itself is never mapped to a character.
        seed ^= seed << 13;
        seed ^= seed >> 7;
        seed ^= seed << 17;
        identifier.push(ALPHABET[seed as usize % ALPHABET.len()] as char);
    }
    identifier
}
22 changes: 17 additions & 5 deletions src/session/builder/impl_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl SessionBuilder {
let _ = write!(&mut s, "{:02x}", b);
s
});
let model_filepath = download_dir.join(model_filename);
let model_filepath = download_dir.join(&model_filename);
let downloaded_path = if model_filepath.exists() {
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
model_filepath
Expand All @@ -46,16 +46,28 @@ impl SessionBuilder {
tracing::info!(len, "Downloading {} bytes", len);

let mut reader = resp.into_reader();
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));

let f = std::fs::File::create(&model_filepath).expect("Failed to create model file");
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file");
let mut writer = std::io::BufWriter::new(f);

let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?;
if bytes_io_count == len as u64 {
model_filepath
} else {
if bytes_io_count != len as u64 {
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}")));
}

drop(writer);

match std::fs::rename(&temp_filepath, &model_filepath) {
Ok(()) => model_filepath,
Err(e) => match e.kind() {
std::io::ErrorKind::AlreadyExists => {
let _ = std::fs::remove_file(temp_filepath);
model_filepath
}
_ => return Err(Error::new(format!("Failed to download model: {e}")))
}
}
};

self.commit_from_file(downloaded_path)
Expand Down

0 comments on commit eb51646

Please sign in to comment.