diff --git a/ort-sys/build.rs b/ort-sys/build.rs index dc0672c..151b946 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -1,7 +1,15 @@ use std::{ +<<<<<<< HEAD env, fs, path::{Path, PathBuf}, process::Command +||||||| parent of eb51646 (fix: concurrent downloads, ref #322) + env, fs, + path::{Path, PathBuf} +======= + env, fs, io, + path::{Path, PathBuf} +>>>>>>> eb51646 (fix: concurrent downloads, ref #322) }; #[allow(unused)] @@ -12,9 +20,9 @@ const ORT_ENV_SYSTEM_LIB_PROFILE: &str = "ORT_LIB_PROFILE"; const DIST_TABLE: &str = include_str!("dist.txt"); -#[path = "src/internal/dirs.rs"] -mod dirs; -use self::dirs::cache_dir; +#[path = "src/internal/mod.rs"] +mod internal; +use self::internal::dirs::cache_dir; #[cfg(feature = "download-binaries")] fn fetch_file(source_url: &str) -> Vec { @@ -416,21 +424,45 @@ fn prepare_libort_dir() -> (PathBuf, bool) { let (prebuilt_url, prebuilt_hash) = dist.unwrap(); - let mut cache_dir = cache_dir() + let bin_extract_dir = cache_dir() .expect("could not determine cache directory") .join("dfbin") .join(target) .join(prebuilt_hash); - if fs::create_dir_all(&cache_dir).is_err() { - cache_dir = env::var("OUT_DIR").unwrap().into(); - } +<<<<<<< HEAD let ort_extract_dir = prebuilt_url.split('/').last().unwrap().strip_suffix(".tgz").unwrap(); let lib_dir = cache_dir.join(ort_extract_dir); +||||||| parent of eb51646 (fix: concurrent downloads, ref #322) + let lib_dir = cache_dir.join(ORT_EXTRACT_DIR); +======= + let lib_dir = bin_extract_dir.join(ORT_EXTRACT_DIR); +>>>>>>> eb51646 (fix: concurrent downloads, ref #322) if !lib_dir.exists() { let downloaded_file = fetch_file(prebuilt_url); assert!(verify_file(&downloaded_file, prebuilt_hash), "hash of downloaded ONNX Runtime binary does not match!"); - extract_tgz(&downloaded_file, &cache_dir); + + let mut temp_extract_dir = bin_extract_dir + .parent() + .unwrap() + .join(format!("tmp.{}_{prebuilt_hash}", self::internal::random_identifier())); + let mut should_rename = true; + if fs::create_dir_all(&temp_extract_dir).is_err() { + temp_extract_dir = env::var("OUT_DIR").unwrap().into(); + should_rename = false; + } + extract_tgz(&downloaded_file, &temp_extract_dir); + if should_rename { + match std::fs::rename(&temp_extract_dir, bin_extract_dir) { + Ok(()) => {} + Err(e) => match e.kind() { + io::ErrorKind::AlreadyExists | io::ErrorKind::DirectoryNotEmpty => { + let _ = fs::remove_dir_all(temp_extract_dir); + } + _ => panic!("failed to extract downloaded binaries: {e}") + } + } + } } static_link_prerequisites(true); diff --git a/ort-sys/src/internal/mod.rs b/ort-sys/src/internal/mod.rs index 7516f3c..45c67ba 100644 --- a/ort-sys/src/internal/mod.rs +++ b/ort-sys/src/internal/mod.rs @@ -1,4 +1,23 @@ +use std::hash::{BuildHasher, Hasher, RandomState}; + pub mod dirs; +<<<<<<< HEAD #[cfg(feature = "download-binaries")] include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs")); +||||||| parent of eb51646 (fix: concurrent downloads, ref #322) +======= + +pub fn random_identifier() -> String { + let mut state = RandomState::new().build_hasher().finish(); + std::iter::repeat_with(move || { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + state + }) + .take(12) + .map(|i| b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[i as usize % 62] as char) + .collect() +} +>>>>>>> eb51646 (fix: concurrent downloads, ref #322) diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs new file mode 100644 index 0000000..a31d27d --- /dev/null +++ b/src/session/builder/impl_commit.rs @@ -0,0 +1,225 @@ +#[cfg(feature = "fetch-models")] +use std::fmt::Write; +use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc}; + +use super::SessionBuilder; +use crate::{ + AsPointer, + environment::get_environment, + error::{Error, ErrorCode, Result}, + execution_providers::apply_execution_providers, + memory::Allocator, + ortsys, + session::{InMemorySession, Input, Output, Session, SharedSessionInner, dangerous} +}; + +impl SessionBuilder { + /// Downloads a pre-trained ONNX model from the given URL and builds the session. + #[cfg(feature = "fetch-models")] + #[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] + pub fn commit_from_url(self, model_url: impl AsRef) -> Result { + let mut download_dir = ort_sys::internal::dirs::cache_dir() + .expect("could not determine cache directory") + .join("models"); + if std::fs::create_dir_all(&download_dir).is_err() { + download_dir = std::env::current_dir().expect("Failed to obtain current working directory"); + } + + let url = model_url.as_ref(); + let model_filename = ::digest(url).into_iter().fold(String::new(), |mut s, b| { + let _ = write!(&mut s, "{:02x}", b); + s + }); + let model_filepath = download_dir.join(&model_filename); + let downloaded_path = if model_filepath.exists() { + tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); + model_filepath + } else { + tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); + + let resp = ureq::get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?; + + let len = resp + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .expect("Missing Content-Length header"); + tracing::info!(len, "Downloading {} bytes", len); + + let mut reader = resp.into_reader(); + let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier())); + + let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file"); + let mut writer = std::io::BufWriter::new(f); + + let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?; + if bytes_io_count != len as u64 { + return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}"))); + } + + drop(writer); + + match std::fs::rename(&temp_filepath, &model_filepath) { + Ok(()) => model_filepath, + Err(e) => match e.kind() { + std::io::ErrorKind::AlreadyExists => { + let _ = std::fs::remove_file(temp_filepath); + model_filepath + } + _ => return Err(Error::new(format!("Failed to download model: {e}"))) + } + } + }; + + self.commit_from_file(downloaded_path) + } + + /// Loads an ONNX model from a file and builds the session. + pub fn commit_from_file

(mut self, model_filepath_ref: P) -> Result + where + P: AsRef + { + let model_filepath = model_filepath_ref.as_ref(); + if !model_filepath.exists() { + return Err(Error::new_with_code(ErrorCode::NoSuchFile, format!("File at `{}` does not exist", model_filepath.display()))); + } + + let model_path = crate::util::path_to_os_char(model_filepath); + + let env = get_environment()?; + apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?; + + if env.has_global_threadpool && !self.no_global_thread_pool { + ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; + } + + let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); + if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { + ortsys![unsafe CreateSessionWithPrepackedWeightsContainer(env.ptr(), model_path.as_ptr(), self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; nonNull(session_ptr)]; + } else { + ortsys![unsafe CreateSession(env.ptr(), model_path.as_ptr(), self.ptr(), &mut session_ptr)?; nonNull(session_ptr)]; + } + + let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; + + let allocator = match &self.memory_info { + Some(info) => { + let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); + ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)]; + unsafe { Allocator::from_raw_unchecked(allocator_ptr) } + } + None => Allocator::default() + }; + + // Extract input and output properties + let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; + let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; + let inputs = (0..num_input_nodes) + .map(|i| dangerous::extract_input(session_ptr, &allocator, i)) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) + .collect::>>()?; + + let mut extras: Vec> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box).collect(); + if let Some(prepacked_weights) = self.prepacked_weights.take() { + extras.push(Box::new(prepacked_weights) as Box); + } + if let Some(thread_manager) = self.thread_manager.take() { + extras.push(Box::new(thread_manager) as Box); + } + + Ok(Session { + inner: Arc::new(SharedSessionInner { + session_ptr, + allocator, + _extras: extras, + _environment: env + }), + inputs, + outputs + }) + } + + /// Load an ONNX graph from memory and commit the session + /// For `.ort` models, we enable `session.use_ort_model_bytes_directly`. + /// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array). + /// + /// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that + /// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). + pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result> { + // Enable zero-copy deserialization for models in `.ort` format. + self.add_config_entry("session.use_ort_model_bytes_directly", "1")?; + self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?; + + let session = self.commit_from_memory(model_bytes)?; + + Ok(InMemorySession { session, phantom: PhantomData }) + } + + /// Load an ONNX graph from memory and commit the session. + pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result { + let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); + + let env = get_environment()?; + apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?; + + if env.has_global_threadpool && !self.no_global_thread_pool { + ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; + } + + let model_data = model_bytes.as_ptr().cast::(); + let model_data_length = model_bytes.len(); + if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { + ortsys![ + unsafe CreateSessionFromArrayWithPrepackedWeightsContainer(env.ptr(), model_data, model_data_length, self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; + nonNull(session_ptr) + ]; + } else { + ortsys![ + unsafe CreateSessionFromArray(env.ptr(), model_data, model_data_length, self.ptr(), &mut session_ptr)?; + nonNull(session_ptr) + ]; + } + + let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; + + let allocator = match &self.memory_info { + Some(info) => { + let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); + ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)]; + unsafe { Allocator::from_raw_unchecked(allocator_ptr) } + } + None => Allocator::default() + }; + + // Extract input and output properties + let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; + let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; + let inputs = (0..num_input_nodes) + .map(|i| dangerous::extract_input(session_ptr, &allocator, i)) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| dangerous::extract_output(session_ptr, &allocator, i)) + .collect::>>()?; + + let mut extras: Vec> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box).collect(); + if let Some(prepacked_weights) = self.prepacked_weights.take() { + extras.push(Box::new(prepacked_weights) as Box); + } + if let Some(thread_manager) = self.thread_manager.take() { + extras.push(Box::new(thread_manager) as Box); + } + + let session = Session { + inner: Arc::new(SharedSessionInner { + session_ptr, + allocator, + _extras: extras, + _environment: env + }), + inputs, + outputs + }; + Ok(session) + } +}