Skip to content

Commit

Permalink
fix: concurrent downloads, ref pykeio#322
Browse files Browse the repository at this point in the history
(cherry picked from commit eb51646)

Conflicts:
	ort-sys/build.rs
	ort-sys/src/internal/mod.rs
	src/session/builder/impl_commit.rs
  • Loading branch information
decahedron1 authored and qryxip committed Nov 24, 2024
1 parent 69ac04c commit d61f7d6
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 8 deletions.
48 changes: 40 additions & 8 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use std::{
<<<<<<< HEAD
env, fs,
path::{Path, PathBuf},
process::Command
||||||| parent of eb51646 (fix: concurrent downloads, ref #322)
env, fs,
path::{Path, PathBuf}
=======
env, fs, io,
path::{Path, PathBuf}
>>>>>>> eb51646 (fix: concurrent downloads, ref #322)
};

#[allow(unused)]
Expand All @@ -12,9 +20,9 @@ const ORT_ENV_SYSTEM_LIB_PROFILE: &str = "ORT_LIB_PROFILE";

const DIST_TABLE: &str = include_str!("dist.txt");

#[path = "src/internal/dirs.rs"]
mod dirs;
use self::dirs::cache_dir;
#[path = "src/internal/mod.rs"]
mod internal;
use self::internal::dirs::cache_dir;

#[cfg(feature = "download-binaries")]
fn fetch_file(source_url: &str) -> Vec<u8> {
Expand Down Expand Up @@ -416,21 +424,45 @@ fn prepare_libort_dir() -> (PathBuf, bool) {

let (prebuilt_url, prebuilt_hash) = dist.unwrap();

let mut cache_dir = cache_dir()
let bin_extract_dir = cache_dir()
.expect("could not determine cache directory")
.join("dfbin")
.join(target)
.join(prebuilt_hash);
if fs::create_dir_all(&cache_dir).is_err() {
cache_dir = env::var("OUT_DIR").unwrap().into();
}

<<<<<<< HEAD
let ort_extract_dir = prebuilt_url.split('/').last().unwrap().strip_suffix(".tgz").unwrap();
let lib_dir = cache_dir.join(ort_extract_dir);
||||||| parent of eb51646 (fix: concurrent downloads, ref #322)
let lib_dir = cache_dir.join(ORT_EXTRACT_DIR);
=======
let lib_dir = bin_extract_dir.join(ORT_EXTRACT_DIR);
>>>>>>> eb51646 (fix: concurrent downloads, ref #322)
if !lib_dir.exists() {
let downloaded_file = fetch_file(prebuilt_url);
assert!(verify_file(&downloaded_file, prebuilt_hash), "hash of downloaded ONNX Runtime binary does not match!");
extract_tgz(&downloaded_file, &cache_dir);

let mut temp_extract_dir = bin_extract_dir
.parent()
.unwrap()
.join(format!("tmp.{}_{prebuilt_hash}", self::internal::random_identifier()));
let mut should_rename = true;
if fs::create_dir_all(&temp_extract_dir).is_err() {
temp_extract_dir = env::var("OUT_DIR").unwrap().into();
should_rename = false;
}
extract_tgz(&downloaded_file, &temp_extract_dir);
if should_rename {
match std::fs::rename(&temp_extract_dir, bin_extract_dir) {
Ok(()) => {}
Err(e) => match e.kind() {
io::ErrorKind::AlreadyExists | io::ErrorKind::DirectoryNotEmpty => {
let _ = fs::remove_dir_all(temp_extract_dir);
}
_ => panic!("failed to extract downloaded binaries: {e}")
}
}
}
}

static_link_prerequisites(true);
Expand Down
19 changes: 19 additions & 0 deletions ort-sys/src/internal/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
use std::hash::{BuildHasher, Hasher, RandomState};

pub mod dirs;
<<<<<<< HEAD

#[cfg(feature = "download-binaries")]
include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs"));
||||||| parent of eb51646 (fix: concurrent downloads, ref #322)
=======

/// Generates a 12-character identifier made of ASCII letters and digits.
///
/// The 64-bit starting state is taken from a freshly constructed [`RandomState`] hasher. Each character is then
/// produced by advancing that state with one xorshift step (shifts of 13, 7 and 17) and using the result to index
/// into a 62-character alphabet.
pub fn random_identifier() -> String {
	const ALPHABET: &[u8; 62] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
	const IDENTIFIER_LEN: usize = 12;

	let mut seed = RandomState::new().build_hasher().finish();
	let mut identifier = String::with_capacity(IDENTIFIER_LEN);
	for _ in 0..IDENTIFIER_LEN {
		// One xorshift round; the updated state doubles as the next output value.
		seed ^= seed << 13;
		seed ^= seed >> 7;
		seed ^= seed << 17;
		identifier.push(ALPHABET[seed as usize % ALPHABET.len()] as char);
	}
	identifier
}
>>>>>>> eb51646 (fix: concurrent downloads, ref #322)
225 changes: 225 additions & 0 deletions src/session/builder/impl_commit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#[cfg(feature = "fetch-models")]
use std::fmt::Write;
use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc};

use super::SessionBuilder;
use crate::{
AsPointer,
environment::get_environment,
error::{Error, ErrorCode, Result},
execution_providers::apply_execution_providers,
memory::Allocator,
ortsys,
session::{InMemorySession, Input, Output, Session, SharedSessionInner, dangerous}
};

impl SessionBuilder {
/// Downloads a pre-trained ONNX model from the given URL and builds the session.
///
/// The model is cached in the `models` subdirectory of the cache directory (falling back to the current working
/// directory if that subdirectory cannot be created), under a file name that is the hex-encoded SHA-256 digest of
/// the URL. If that file already exists, the download is skipped.
///
/// The response body is first written to a uniquely named temporary file and only renamed to the final file name
/// once the whole body has been written, so a concurrent caller checking for the cached file does not pick up a
/// partially written download. The temporary file is removed again if the download fails.
///
/// # Errors
/// Returns an error if:
/// - the fallback working directory cannot be determined;
/// - the HTTP request fails, or the response has no parseable `Content-Length` header;
/// - the temporary file cannot be created, written or flushed;
/// - the number of bytes received differs from `Content-Length`;
/// - the temporary file cannot be renamed into place and no cached model exists afterwards;
/// - [`SessionBuilder::commit_from_file`] fails for the downloaded model.
///
/// # Panics
/// Panics if the cache directory cannot be determined.
#[cfg(feature = "fetch-models")]
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
	/// Streams `reader` into a new file at `dest`, failing unless exactly `expected_len` bytes were written and
	/// flushed.
	fn download_to_file(reader: &mut impl std::io::Read, dest: &Path, expected_len: u64) -> Result<()> {
		let file = std::fs::File::create(dest).map_err(|e| Error::new(format!("Failed to create model file: {e}")))?;
		let mut writer = std::io::BufWriter::new(file);

		let bytes_io_count = std::io::copy(reader, &mut writer).map_err(Error::wrap)?;
		if bytes_io_count != expected_len {
			return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {expected_len}")));
		}

		// Dropping a `BufWriter` discards flush errors; `into_inner` flushes and reports them instead.
		writer
			.into_inner()
			.map_err(|e| Error::new(format!("Failed to write model file: {}", e.error())))?;
		Ok(())
	}

	let mut download_dir = ort_sys::internal::dirs::cache_dir()
		.expect("could not determine cache directory")
		.join("models");
	if std::fs::create_dir_all(&download_dir).is_err() {
		download_dir = std::env::current_dir().map_err(|e| Error::new(format!("Failed to obtain current working directory: {e}")))?;
	}

	let url = model_url.as_ref();
	// The cache key is the hex-encoded SHA-256 digest of the URL.
	let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
		let _ = write!(&mut s, "{:02x}", b);
		s
	});
	let model_filepath = download_dir.join(&model_filename);
	let downloaded_path = if model_filepath.exists() {
		tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
		model_filepath
	} else {
		tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");

		let resp = ureq::get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;

		// Parsed as `u64` so that large models are not rejected on targets where `usize` is 32 bits.
		let len = resp
			.header("Content-Length")
			.and_then(|s| s.parse::<u64>().ok())
			.ok_or_else(|| Error::new(String::from("Missing Content-Length header")))?;
		tracing::info!(len, "Downloading {} bytes", len);

		let mut reader = resp.into_reader();
		// Each download gets its own temporary file so that concurrent downloads do not write to the same path.
		let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));

		if let Err(e) = download_to_file(&mut reader, &temp_filepath, len) {
			// Best-effort cleanup; the download error is the one worth reporting.
			let _ = std::fs::remove_file(&temp_filepath);
			return Err(e);
		}

		match std::fs::rename(&temp_filepath, &model_filepath) {
			Ok(()) => model_filepath,
			Err(e) => {
				let _ = std::fs::remove_file(&temp_filepath);
				// The final path is only ever produced by renaming a completed download, so if it exists now,
				// another download finished first and its file can be used.
				if e.kind() == std::io::ErrorKind::AlreadyExists || model_filepath.exists() {
					model_filepath
				} else {
					return Err(Error::new(format!("Failed to download model: {e}")));
				}
			}
		}
	};

	self.commit_from_file(downloaded_path)
}

/// Loads an ONNX model from a file and builds the session.
///
/// Returns an error with [`ErrorCode::NoSuchFile`] if the given path does not exist. Errors from obtaining the
/// environment, applying its execution providers, and extracting the model's input/output metadata are propagated
/// to the caller.
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
where
P: AsRef<Path>
{
let model_filepath = model_filepath_ref.as_ref();
// Report a missing file with a dedicated error code before anything is handed to the runtime.
if !model_filepath.exists() {
return Err(Error::new_with_code(ErrorCode::NoSuchFile, format!("File at `{}` does not exist", model_filepath.display())));
}

// NOTE(review): presumably converts the path into the character encoding the C API expects on this platform —
// confirm against `crate::util::path_to_os_char`.
let model_path = crate::util::path_to_os_char(model_filepath);

// Apply the environment's execution providers to this builder.
let env = get_environment()?;
apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?;

// Only disable per-session threads when the environment has a global thread pool and this builder has not opted
// out of it.
if env.has_global_threadpool && !self.no_global_thread_pool {
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
}

// Use the prepacked-weights variant of session creation when a container was configured on the builder.
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
if let Some(prepacked_weights) = self.prepacked_weights.as_ref() {
ortsys![unsafe CreateSessionWithPrepackedWeightsContainer(env.ptr(), model_path.as_ptr(), self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; nonNull(session_ptr)];
} else {
ortsys![unsafe CreateSession(env.ptr(), model_path.as_ptr(), self.ptr(), &mut session_ptr)?; nonNull(session_ptr)];
}

// NOTE(review): relies on the `nonNull(session_ptr)` clause of the `ortsys!` calls above having rejected a null
// pointer — confirm against the macro definition.
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };

// Create a session-specific allocator when memory info was configured; otherwise use the default allocator.
let allocator = match &self.memory_info {
Some(info) => {
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)];
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
}
None => Allocator::default()
};

// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, &allocator, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
.collect::<Result<Vec<Output>>>()?;

// Move the builder-owned resources (operator domains, prepacked weights, thread manager) into the session as
// type-erased boxes. NOTE(review): presumably so they stay alive as long as the native session — confirm.
let mut extras: Vec<Box<dyn Any>> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>).collect();
if let Some(prepacked_weights) = self.prepacked_weights.take() {
extras.push(Box::new(prepacked_weights) as Box<dyn Any>);
}
if let Some(thread_manager) = self.thread_manager.take() {
extras.push(Box::new(thread_manager) as Box<dyn Any>);
}

Ok(Session {
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_extras: extras,
_environment: env
}),
inputs,
outputs
})
}

/// Load an ONNX graph from memory and commit the session, returning an [`InMemorySession`] whose lifetime is tied
/// to `model_bytes`.
///
/// Before committing, the `session.use_ort_model_bytes_directly` and `session.use_ort_model_bytes_for_initializers`
/// config entries are set to `"1"`, which enables zero-copy deserialization for models in `.ort` format.
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array).
///
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
	// Enable zero-copy deserialization for models in `.ort` format.
	for key in ["session.use_ort_model_bytes_directly", "session.use_ort_model_bytes_for_initializers"] {
		self.add_config_entry(key, "1")?;
	}

	self.commit_from_memory(model_bytes)
		.map(|session| InMemorySession { session, phantom: PhantomData })
}

/// Load an ONNX graph from memory and commit the session.
///
/// The pointer and length of `model_bytes` are passed to the runtime's create-session-from-array call. Errors from
/// obtaining the environment, applying its execution providers, and extracting the model's input/output metadata
/// are propagated to the caller.
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> {
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();

// Apply the environment's execution providers to this builder.
let env = get_environment()?;
apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?;

// Only disable per-session threads when the environment has a global thread pool and this builder has not opted
// out of it.
if env.has_global_threadpool && !self.no_global_thread_pool {
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?];
}

// The C API receives the model as an untyped pointer plus a byte length.
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
let model_data_length = model_bytes.len();
// Use the prepacked-weights variant of session creation when a container was configured on the builder.
if let Some(prepacked_weights) = self.prepacked_weights.as_ref() {
ortsys![
unsafe CreateSessionFromArrayWithPrepackedWeightsContainer(env.ptr(), model_data, model_data_length, self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?;
nonNull(session_ptr)
];
} else {
ortsys![
unsafe CreateSessionFromArray(env.ptr(), model_data, model_data_length, self.ptr(), &mut session_ptr)?;
nonNull(session_ptr)
];
}

// NOTE(review): relies on the `nonNull(session_ptr)` clause of the `ortsys!` calls above having rejected a null
// pointer — confirm against the macro definition.
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };

// Create a session-specific allocator when memory info was configured; otherwise use the default allocator.
let allocator = match &self.memory_info {
Some(info) => {
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)];
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
}
None => Allocator::default()
};

// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, &allocator, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
.collect::<Result<Vec<Output>>>()?;

// Move the builder-owned resources (operator domains, prepacked weights, thread manager) into the session as
// type-erased boxes. NOTE(review): presumably so they stay alive as long as the native session — confirm.
let mut extras: Vec<Box<dyn Any>> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>).collect();
if let Some(prepacked_weights) = self.prepacked_weights.take() {
extras.push(Box::new(prepacked_weights) as Box<dyn Any>);
}
if let Some(thread_manager) = self.thread_manager.take() {
extras.push(Box::new(thread_manager) as Box<dyn Any>);
}

let session = Session {
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_extras: extras,
_environment: env
}),
inputs,
outputs
};
Ok(session)
}
}

0 comments on commit d61f7d6

Please sign in to comment.