Skip to content

Commit

Permalink
Self-contained safetensor wrappers (huggingface#946)
Browse files Browse the repository at this point in the history
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
  • Loading branch information
LaurentMazare authored and EricLBuehler committed Oct 25, 2023
1 parent 9f1e49d commit 8a013ff
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 30 deletions.
22 changes: 13 additions & 9 deletions candle-core/src/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,18 @@ impl MmapedSafetensors {
}

pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}

pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
}

pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
Expand All @@ -333,15 +345,7 @@ impl MmapedSafetensors {
*index
}
};
self.safetensors[index].get().0.tensor(name)?.load(dev)
}

pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
Ok(self.safetensors[index].get().0.tensor(name)?)
}
}

Expand Down
26 changes: 6 additions & 20 deletions candle-examples/examples/t5/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,16 @@ impl T5ModelBuilder {
}

pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let weights = self
.weights_filename
.iter()
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}

pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let weights = self
.weights_filename
.iter()
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
Expand Down
43 changes: 42 additions & 1 deletion candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,32 @@ impl SimpleBackend for candle::npy::NpzTensors {
}
}

impl SimpleBackend for candle::safetensors::MmapedSafetensors {
fn get(
&self,
s: Shape,
name: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}

fn contains_tensor(&self, name: &str) -> bool {
self.get(name).is_ok()
}
}

impl<'a> VarBuilder<'a> {
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
let data = TensorData {
Expand Down Expand Up @@ -361,7 +387,7 @@ impl<'a> VarBuilder<'a> {
}

/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
/// files.
/// data.
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
Expand All @@ -376,6 +402,21 @@ impl<'a> VarBuilder<'a> {
Self::new(Box::new(tensors), dtype, dev.clone())
}

/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
/// files.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
paths: &[P],
dtype: DType,
dev: &Device,
) -> Result<Self> {
let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
}

/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;
Expand Down

0 comments on commit 8a013ff

Please sign in to comment.