forked from pykeio/ort
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: concurrent downloads, ref pykeio#322
(cherry picked from commit eb51646) Conflicts: ort-sys/build.rs ort-sys/src/internal/mod.rs src/session/builder/impl_commit.rs
- Loading branch information
1 parent
69ac04c
commit d61f7d6
Showing
3 changed files
with
284 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,23 @@ | ||
use std::hash::{BuildHasher, Hasher, RandomState}; | ||
|
||
pub mod dirs; | ||
<<<<<<< HEAD | ||
|
||
#[cfg(feature = "download-binaries")] | ||
include!(concat!(env!("OUT_DIR"), "/downloaded_version.rs")); | ||
||||||| parent of eb51646 (fix: concurrent downloads, ref #322) | ||
======= | ||
|
||
pub fn random_identifier() -> String { | ||
let mut state = RandomState::new().build_hasher().finish(); | ||
std::iter::repeat_with(move || { | ||
state ^= state << 13; | ||
state ^= state >> 7; | ||
state ^= state << 17; | ||
state | ||
}) | ||
.take(12) | ||
.map(|i| b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"[i as usize % 62] as char) | ||
.collect() | ||
} | ||
>>>>>>> eb51646 (fix: concurrent downloads, ref #322) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
#[cfg(feature = "fetch-models")] | ||
use std::fmt::Write; | ||
use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc}; | ||
|
||
use super::SessionBuilder; | ||
use crate::{ | ||
AsPointer, | ||
environment::get_environment, | ||
error::{Error, ErrorCode, Result}, | ||
execution_providers::apply_execution_providers, | ||
memory::Allocator, | ||
ortsys, | ||
session::{InMemorySession, Input, Output, Session, SharedSessionInner, dangerous} | ||
}; | ||
|
||
impl SessionBuilder { | ||
/// Downloads a pre-trained ONNX model from the given URL and builds the session. | ||
#[cfg(feature = "fetch-models")] | ||
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] | ||
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> { | ||
let mut download_dir = ort_sys::internal::dirs::cache_dir() | ||
.expect("could not determine cache directory") | ||
.join("models"); | ||
if std::fs::create_dir_all(&download_dir).is_err() { | ||
download_dir = std::env::current_dir().expect("Failed to obtain current working directory"); | ||
} | ||
|
||
let url = model_url.as_ref(); | ||
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| { | ||
let _ = write!(&mut s, "{:02x}", b); | ||
s | ||
}); | ||
let model_filepath = download_dir.join(&model_filename); | ||
let downloaded_path = if model_filepath.exists() { | ||
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); | ||
model_filepath | ||
} else { | ||
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); | ||
|
||
let resp = ureq::get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?; | ||
|
||
let len = resp | ||
.header("Content-Length") | ||
.and_then(|s| s.parse::<usize>().ok()) | ||
.expect("Missing Content-Length header"); | ||
tracing::info!(len, "Downloading {} bytes", len); | ||
|
||
let mut reader = resp.into_reader(); | ||
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier())); | ||
|
||
let f = std::fs::File::create(&temp_filepath).expect("Failed to create model file"); | ||
let mut writer = std::io::BufWriter::new(f); | ||
|
||
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(Error::wrap)?; | ||
if bytes_io_count != len as u64 { | ||
return Err(Error::new(format!("Failed to download entire model; file only has {bytes_io_count} bytes, expected {len}"))); | ||
} | ||
|
||
drop(writer); | ||
|
||
match std::fs::rename(&temp_filepath, &model_filepath) { | ||
Ok(()) => model_filepath, | ||
Err(e) => match e.kind() { | ||
std::io::ErrorKind::AlreadyExists => { | ||
let _ = std::fs::remove_file(temp_filepath); | ||
model_filepath | ||
} | ||
_ => return Err(Error::new(format!("Failed to download model: {e}"))) | ||
} | ||
} | ||
}; | ||
|
||
self.commit_from_file(downloaded_path) | ||
} | ||
|
||
/// Loads an ONNX model from a file and builds the session. | ||
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session> | ||
where | ||
P: AsRef<Path> | ||
{ | ||
let model_filepath = model_filepath_ref.as_ref(); | ||
if !model_filepath.exists() { | ||
return Err(Error::new_with_code(ErrorCode::NoSuchFile, format!("File at `{}` does not exist", model_filepath.display()))); | ||
} | ||
|
||
let model_path = crate::util::path_to_os_char(model_filepath); | ||
|
||
let env = get_environment()?; | ||
apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?; | ||
|
||
if env.has_global_threadpool && !self.no_global_thread_pool { | ||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; | ||
} | ||
|
||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); | ||
if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { | ||
ortsys![unsafe CreateSessionWithPrepackedWeightsContainer(env.ptr(), model_path.as_ptr(), self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; nonNull(session_ptr)]; | ||
} else { | ||
ortsys![unsafe CreateSession(env.ptr(), model_path.as_ptr(), self.ptr(), &mut session_ptr)?; nonNull(session_ptr)]; | ||
} | ||
|
||
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; | ||
|
||
let allocator = match &self.memory_info { | ||
Some(info) => { | ||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); | ||
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)]; | ||
unsafe { Allocator::from_raw_unchecked(allocator_ptr) } | ||
} | ||
None => Allocator::default() | ||
}; | ||
|
||
// Extract input and output properties | ||
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; | ||
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; | ||
let inputs = (0..num_input_nodes) | ||
.map(|i| dangerous::extract_input(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Input>>>()?; | ||
let outputs = (0..num_output_nodes) | ||
.map(|i| dangerous::extract_output(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Output>>>()?; | ||
|
||
let mut extras: Vec<Box<dyn Any>> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>).collect(); | ||
if let Some(prepacked_weights) = self.prepacked_weights.take() { | ||
extras.push(Box::new(prepacked_weights) as Box<dyn Any>); | ||
} | ||
if let Some(thread_manager) = self.thread_manager.take() { | ||
extras.push(Box::new(thread_manager) as Box<dyn Any>); | ||
} | ||
|
||
Ok(Session { | ||
inner: Arc::new(SharedSessionInner { | ||
session_ptr, | ||
allocator, | ||
_extras: extras, | ||
_environment: env | ||
}), | ||
inputs, | ||
outputs | ||
}) | ||
} | ||
|
||
/// Load an ONNX graph from memory and commit the session | ||
/// For `.ort` models, we enable `session.use_ort_model_bytes_directly`. | ||
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array). | ||
/// | ||
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that | ||
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). | ||
pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> { | ||
// Enable zero-copy deserialization for models in `.ort` format. | ||
self.add_config_entry("session.use_ort_model_bytes_directly", "1")?; | ||
self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?; | ||
|
||
let session = self.commit_from_memory(model_bytes)?; | ||
|
||
Ok(InMemorySession { session, phantom: PhantomData }) | ||
} | ||
|
||
/// Load an ONNX graph from memory and commit the session. | ||
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> { | ||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); | ||
|
||
let env = get_environment()?; | ||
apply_execution_providers(&mut self, env.execution_providers.iter().cloned())?; | ||
|
||
if env.has_global_threadpool && !self.no_global_thread_pool { | ||
ortsys![unsafe DisablePerSessionThreads(self.ptr_mut())?]; | ||
} | ||
|
||
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>(); | ||
let model_data_length = model_bytes.len(); | ||
if let Some(prepacked_weights) = self.prepacked_weights.as_ref() { | ||
ortsys![ | ||
unsafe CreateSessionFromArrayWithPrepackedWeightsContainer(env.ptr(), model_data, model_data_length, self.ptr(), prepacked_weights.ptr().cast_mut(), &mut session_ptr)?; | ||
nonNull(session_ptr) | ||
]; | ||
} else { | ||
ortsys![ | ||
unsafe CreateSessionFromArray(env.ptr(), model_data, model_data_length, self.ptr(), &mut session_ptr)?; | ||
nonNull(session_ptr) | ||
]; | ||
} | ||
|
||
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; | ||
|
||
let allocator = match &self.memory_info { | ||
Some(info) => { | ||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); | ||
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr(), &mut allocator_ptr)?; nonNull(allocator_ptr)]; | ||
unsafe { Allocator::from_raw_unchecked(allocator_ptr) } | ||
} | ||
None => Allocator::default() | ||
}; | ||
|
||
// Extract input and output properties | ||
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; | ||
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; | ||
let inputs = (0..num_input_nodes) | ||
.map(|i| dangerous::extract_input(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Input>>>()?; | ||
let outputs = (0..num_output_nodes) | ||
.map(|i| dangerous::extract_output(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Output>>>()?; | ||
|
||
let mut extras: Vec<Box<dyn Any>> = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>).collect(); | ||
if let Some(prepacked_weights) = self.prepacked_weights.take() { | ||
extras.push(Box::new(prepacked_weights) as Box<dyn Any>); | ||
} | ||
if let Some(thread_manager) = self.thread_manager.take() { | ||
extras.push(Box::new(thread_manager) as Box<dyn Any>); | ||
} | ||
|
||
let session = Session { | ||
inner: Arc::new(SharedSessionInner { | ||
session_ptr, | ||
allocator, | ||
_extras: extras, | ||
_environment: env | ||
}), | ||
inputs, | ||
outputs | ||
}; | ||
Ok(session) | ||
} | ||
} |