
hack: work around Clippy Arc Send/Sync lint
philpax committed Aug 27, 2023
1 parent f2bfe0f commit 1c9efac
Showing 12 changed files with 64 additions and 44 deletions.
13 changes: 11 additions & 2 deletions crates/ggml/src/context.rs
@@ -56,6 +56,13 @@ impl PartialEq for ContextInner {
impl Eq for ContextInner {}
impl ContextInner {
pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc<Self> {
// This context can only be used from one thread at a time - which is why
// it doesn't implement `Send`/`Sync` - but higher-level abstractions may
// layer their own synchronization on top of it to provide thread-safety
// guarantees. To ensure that we don't break those, we still use an `Arc`
// here.
// TODO: check if this is correct?
#[allow(clippy::arc_with_non_send_sync)]
Arc::new(Self {
ptr: NonNull::new(ptr).expect("Should not be null"),
offloaded_tensors: Default::default(),
@@ -118,7 +125,9 @@ impl PartialEq for ContextStorage {
impl Eq for ContextStorage {}

impl Context {
/// Creates a new [Context] with the given storage..
// See explanation in [`ContextInner::new`].
#[allow(clippy::arc_with_non_send_sync)]
/// Creates a new [Context] with the given storage.
pub fn new(storage: ContextStorage) -> Self {
let init_params = match &storage {
ContextStorage::Buffer(buffer) => sys::ggml_init_params {
@@ -296,7 +305,7 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// Repeats the `a` tensor along the first dimension of the `b` tensor.
pub fn op_repeat(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_repeat(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
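A note on the lint being allowed above: `clippy::arc_with_non_send_sync` fires when an `Arc` wraps a type that is neither `Send` nor `Sync`, because the atomic reference counting then buys nothing a plain `Rc` wouldn't. A minimal sketch of the pattern, with an illustrative `RawContext` standing in for the ggml context wrapper (not code from this repository):

use std::{ptr::NonNull, sync::Arc};

// A type holding a raw pointer is neither `Send` nor `Sync` by default.
struct RawContext {
    ptr: NonNull<u8>, // stand-in for `*mut ggml_sys::ggml_context`
}

fn share(ctx: RawContext) -> Arc<RawContext> {
    // Clippy warns that `Arc` is pointless for a `!Send + !Sync` payload;
    // the `#[allow]` records that higher-level wrappers are expected to
    // supply their own synchronization.
    #[allow(clippy::arc_with_non_send_sync)]
    let shared = Arc::new(ctx);
    shared
}

fn main() {
    let shared = share(RawContext { ptr: NonNull::dangling() });
    assert!(!shared.ptr.as_ptr().is_null());
}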
12 changes: 8 additions & 4 deletions crates/llm-base/src/inference_session.rs
@@ -8,8 +8,8 @@ use tracing::{instrument, log};
use ggml::accelerator::metal::MetalContext;

use crate::{
mulf, util, InferenceParameters, Model, ModelParameters, OutputRequest, Prompt, TokenId,
TokenUtf8Buffer, TokenizationError,
mulf, util, InferenceParameters, Model, ModelContext, ModelParameters, OutputRequest, Prompt,
TokenId, TokenUtf8Buffer, TokenizationError,
};

// The size of a scratch buffer used for inference. This is used for temporary
@@ -148,6 +148,10 @@ impl InferenceSession {
ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024);
}

// TODO: revisit this with `Rc`, maybe? We should be able to prove that the session
// context is only accessed from one thread at a time, but I've already spent enough
// time on this as-is.
#[allow(clippy::arc_with_non_send_sync)]
let session_ctx = Arc::new(ggml::Context::new_with_allocate(context_byte_size));

// Initialize key + value memory tensors
@@ -215,7 +219,7 @@ impl InferenceSession {
/// Compute a model (possibly building a graph in the provided closure when called for the first time and/or when the parameters have changed)
pub fn compute<F>(
&mut self,
#[allow(unused_variables)] model_context: Arc<Context>,
#[allow(unused_variables)] model_context: ModelContext,
input_tokens: &[TokenId],
builder: F,
) -> GraphOutputs
@@ -242,7 +246,7 @@
#[cfg(feature = "metal")]
{
if let Some(ref mut metal_context) = self.metal_context {
metal_context.add_context(model_context);
metal_context.add_context(model_context.0);
}
}

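The TODO above weighs `Rc` against `Arc`; the trade-off can be shown in isolation. A hedged sketch, with `NotThreadSafe` as an illustrative stand-in for `ggml::Context` (again, not code from this repository):

use std::{rc::Rc, sync::Arc};

// A raw pointer opts the type out of `Send`/`Sync`, as with `ggml::Context`.
struct NotThreadSafe(*mut u8);

fn main() {
    // `Rc` is the idiomatic choice for single-threaded shared ownership,
    // and is what the TODO contemplates:
    let rc = Rc::new(NotThreadSafe(std::ptr::null_mut()));

    // `Arc` also compiles, but pays for atomic reference counting without
    // enabling cross-thread use - exactly what the lint flags:
    #[allow(clippy::arc_with_non_send_sync)]
    let arc = Arc::new(NotThreadSafe(std::ptr::null_mut()));

    assert!(rc.0.is_null());
    assert!(arc.0.is_null());
}

The commit keeps `Arc` and leaves the `Rc` idea as a TODO.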
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
@@ -35,7 +35,7 @@ pub use loader::{
};
pub use lora::{LoraAdapter, LoraParameters};
pub use memmap2::Mmap;
pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest};
pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest};
pub use quantize::{quantize, QuantizeError, QuantizeProgress};
pub use regex::Regex;
pub use tokenizer::{
14 changes: 9 additions & 5 deletions crates/llm-base/src/loader.rs
@@ -5,11 +5,12 @@ use std::{
fs::File,
io::{BufRead, BufReader, Read, Seek, SeekFrom},
path::{Path, PathBuf},
sync::Arc,
};

use crate::{
util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId,
Tokenizer, TokenizerLoadError, TokenizerSource,
util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelContext, ModelParameters,
TokenId, Tokenizer, TokenizerLoadError, TokenizerSource,
};
pub use ggml::{format::FormatMagic, ContainerType};
use ggml::{
@@ -398,7 +399,7 @@ pub trait TensorLoader<E: std::error::Error> {
/// Gets a tensor from the loader.
fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
/// Finish loading the model, returning the context.
fn finish(self) -> Context;
fn finish(self) -> ModelContext;
}

/// Load a GGML model from the `path` and configure it per the `params`. The status
@@ -676,8 +677,11 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
Ok(tensor)
}

fn finish(self) -> Context {
self.context
fn finish(self) -> ModelContext {
// We can ignore this warning, as it's OK to share this particular
// context around, since it is immutable.
#[allow(clippy::arc_with_non_send_sync)]
ModelContext(Arc::new(self.context))
}
}

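For illustration, the new `finish` contract reduced to a self-contained sketch - the types here are stand-ins for the real loader machinery, not the actual definitions:

use std::sync::Arc;

struct Context(*mut u8); // stand-in for `ggml::Context`
struct ModelContext(Arc<Context>);

trait TensorLoader {
    // Finishing a load hands the context over as a shared, read-only handle.
    fn finish(self) -> ModelContext;
}

struct MmapLoader {
    context: Context,
}

impl TensorLoader for MmapLoader {
    fn finish(self) -> ModelContext {
        // Ownership of the context moves into the `Arc` here; from this
        // point on it is shared immutably, which is the stated reason the
        // lint can be allowed.
        #[allow(clippy::arc_with_non_send_sync)]
        let context = Arc::new(self.context);
        ModelContext(context)
    }
}

fn main() {
    let loader = MmapLoader { context: Context(std::ptr::null_mut()) };
    let ModelContext(context) = loader.finish();
    assert!(context.0.is_null());
}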
11 changes: 11 additions & 0 deletions crates/llm-base/src/model/mod.rs
@@ -5,6 +5,7 @@ use std::{
fmt::Debug,
io::{BufRead, Write},
path::{Path, PathBuf},
sync::Arc,
};

use ggml::accelerator::Backend;
@@ -263,3 +264,13 @@ pub struct OutputRequest {
/// `n_batch * n_embd`.
pub embeddings: Option<Vec<f32>>,
}

/// Contains the GGML context for a [`Model`]. Implements `Send` and `Sync`
/// to allow models to be transferred freely between threads; this is made
/// possible by the context being effectively inert after creation, so that
/// it cannot be modified across threads.
#[derive(Clone)]
#[allow(clippy::arc_with_non_send_sync)]
pub struct ModelContext(pub(crate) Arc<ggml::Context>);
unsafe impl Send for ModelContext {}
unsafe impl Sync for ModelContext {}
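What the two `unsafe impl`s buy can be shown in miniature: a model holding the wrapped context can be moved to another thread, which a bare `Arc<ggml::Context>` would not permit. A sketch under the same assumption as the doc comment above (the context is never mutated after creation); `FakeModel` and the other names are illustrative:

use std::{sync::Arc, thread};

struct Context(*mut u8); // `!Send + !Sync`, like `ggml::Context`

#[derive(Clone)]
struct ModelContext(Arc<Context>);

// SAFETY: mirrors the argument above - the context is effectively inert
// after creation, so sharing it across threads is claimed to be sound.
unsafe impl Send for ModelContext {}
unsafe impl Sync for ModelContext {}

struct FakeModel {
    context: ModelContext,
}

fn main() {
    #[allow(clippy::arc_with_non_send_sync)]
    let context = ModelContext(Arc::new(Context(std::ptr::null_mut())));
    let model = FakeModel { context };

    // Without the `unsafe impl`s this would fail to compile:
    // "`*mut u8` cannot be sent between threads safely".
    thread::spawn(move || {
        assert!((model.context.0).0.is_null());
    })
    .join()
    .unwrap();
}

Whether the inertness claim actually holds is exactly what the `TODO: check if this is correct?` earlier in the commit is asking.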
8 changes: 3 additions & 5 deletions crates/models/bloom/src/lib.rs
@@ -2,13 +2,11 @@
//! for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom)
@@ -37,7 +35,7 @@ pub struct Bloom {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Bloom {}
@@ -101,7 +99,7 @@ impl KnownModel for Bloom {
output_norm_bias,
output,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 3 additions & 5 deletions crates/models/falcon/src/lib.rs
@@ -7,14 +7,12 @@
//! supported. It is currently only available as a preview.
#![deny(missing_docs)]

use std::sync::Arc;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The Falcon model. Ref: [Technology Innovation Institute](https://huggingface.co/tiiuae)
@@ -39,7 +37,7 @@ pub struct Falcon {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Falcon {}
@@ -138,7 +136,7 @@ impl KnownModel for Falcon {
output_norm_b,
lm_head,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 3 additions & 5 deletions crates/models/gpt2/src/lib.rs
@@ -1,14 +1,12 @@
//! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
@@ -38,7 +36,7 @@ pub struct Gpt2 {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Gpt2 {}
@@ -123,7 +121,7 @@ impl KnownModel for Gpt2 {
wte,
wpe,
lm_head,
context: Arc::new(context),
context,
})
}

8 changes: 4 additions & 4 deletions crates/models/gptj/src/lib.rs
@@ -1,14 +1,14 @@
//! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, sync::Arc};
use std::error::Error;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
};

/// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b)
@@ -35,7 +35,7 @@ pub struct GptJ {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for GptJ {}
@@ -117,7 +117,7 @@ impl KnownModel for GptJ {
lmh_g,
lmh_b,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 4 additions & 4 deletions crates/models/gptneox/src/lib.rs
@@ -2,14 +2,14 @@
//! This crate also supports the [RedPajama](https://www.together.xyz/blog/redpajama) GPT-NeoX model.
#![deny(missing_docs)]

use std::{error::Error, sync::Arc};
use std::error::Error;

use ggml::Tensor;
use llm_base::{
ggml,
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
};

/// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox)
@@ -35,7 +35,7 @@ pub struct GptNeoX {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for GptNeoX {}
@@ -137,7 +137,7 @@ impl KnownModel for GptNeoX {
wte,
lmh_g,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 4 additions & 4 deletions crates/models/llama/src/lib.rs
@@ -1,13 +1,13 @@
//! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, sync::Arc};
use std::error::Error;

use llm_base::{
ggml::{self},
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer,
};

/// The LLaMA model. Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/)
@@ -31,7 +31,7 @@ pub struct Llama {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Llama {}
@@ -125,7 +125,7 @@ impl KnownModel for Llama {
norm,
output,
layers,
context: Arc::new(context),
context,
})
}

8 changes: 3 additions & 5 deletions crates/models/mpt/src/lib.rs
@@ -1,14 +1,12 @@
//! An implementation of [MPT](https://huggingface.co/mosaicml) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::sync::Arc;

use ggml::Tensor;
use llm_base::{
ggml::{self},
model::{common, HyperparametersWriteError},
util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError,
ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
ModelContext, ModelParameters, OutputRequest, Regex, TokenId, Tokenizer,
};

/// The MosaicML Pretrained Transformer (MPT) model. Ref: [Mosaic ML](https://www.mosaicml.com/blog/mpt-7b)
@@ -31,7 +29,7 @@ pub struct Mpt {
layers: Vec<Layer>,

// must be kept alive for the model
context: Arc<ggml::Context>,
context: ModelContext,
}

unsafe impl Send for Mpt {}
@@ -78,7 +76,7 @@ impl KnownModel for Mpt {
wte,
norm,
layers,
context: Arc::new(context),
context,
})
}

