From 95e619f3a4d55fc97b668021d34fd732a5dfdc36 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 19 Feb 2024 21:42:32 -0500 Subject: [PATCH] allow to load safetensors from a byte array --- dfdx-core/src/nn_traits/mod.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs index 52203373..869e1047 100644 --- a/dfdx-core/src/nn_traits/mod.rs +++ b/dfdx-core/src/nn_traits/mod.rs @@ -173,6 +173,21 @@ pub trait LoadSafeTensors { ) -> Result<(), safetensors::SafeTensorError> { self.load_safetensors_with(path, false, &mut core::convert::identity) } + fn load_safetensors_from_bytes_with String>( + &mut self, + bytes: &[u8], + skip_missing: bool, + key_map: &mut F, + ) -> Result<(), safetensors::SafeTensorError> { + let tensors = safetensors::SafeTensors::deserialize(&bytes)?; + self.read_safetensors_with("", &tensors, skip_missing, key_map) + } + fn load_safetensors_from_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), safetensors::SafeTensorError> { + self.load_safetensors_from_bytes_with(bytes, false, &mut core::convert::identity) + } fn read_safetensors_with String>( &mut self,