Implement prefix caching #95

Merged · 10 commits · Apr 11, 2024
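This PR adds prompt prefix caching. Finished prompt KV caches are retained in a new `PrefixCacheManager`, keyed by their token ids, so a later request whose prompt matches a cached sequence exactly, or extends a cached prefix, can skip recomputing the matched tokens. Caches beyond the on-device budget (`n_on_device`, currently hardcoded to 1) are evicted to CPU memory and promoted back to the device on a hit. The engine registers and looks up caches, the scheduler gains a fast path for already-prefilled sequences, and `Sequence` gains `prefill`/`prefill_subset` constructors.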
26 changes: 26 additions & 0 deletions mistralrs-core/src/engine/mod.rs
@@ -15,6 +15,7 @@ use tracing::warn;
use crate::{
get_mut_arcmutex, handle_pipeline_forward_error, handle_seq_error, handle_seq_error_stateaware,
pipeline::Pipeline,
prefix_cacher::{MatchingCache, PrefixCacheManager},
request::Request,
response::{
ChatCompletionResponse, Choice, ChunkChoice, Delta, Logprobs, Response, ResponseLogprob,
@@ -35,6 +36,7 @@ pub struct Engine {
id: usize,
truncate_sequence: bool,
no_kv_cache: bool,
prefix_cacher: PrefixCacheManager,
}

impl Engine {
@@ -45,13 +47,15 @@ impl Engine {
truncate_sequence: bool,
no_kv_cache: bool,
) -> Self {
let device = get_mut_arcmutex!(pipeline).device().clone();
Self {
rx,
pipeline,
scheduler: Scheduler::new(method),
id: 0,
truncate_sequence,
no_kv_cache,
prefix_cacher: PrefixCacheManager::new(device, 1), // TODO(EricLBuehler): not have this hardcoded
}
}

@@ -109,6 +113,16 @@ impl Engine {
seq.prompt_timestamp = Some(now);
}

for seq in scheduled
.prompt
.iter_mut()
.take(self.prefix_cacher.n_on_device)
{
self.prefix_cacher.add_sequence(seq);
}
// Evict all the other seqs
handle_pipeline_forward_error!("evict", self.prefix_cacher.evict_to_cpu(), &mut scheduled.prompt, pipeline, 'lp);

let before_sample = Instant::now();
Self::sample_seqs(&mut *pipeline, &mut scheduled.prompt, logits);
let sampling_time = before_sample.elapsed().as_millis();
@@ -401,6 +415,10 @@ impl Engine {
warn!("Prompt for request {} was {} tokens over the model maximum length. The last {} tokens were truncated to make space for generation.", request.id, currently_over, prompt_len - prompt.len());
}
}
let prefill_cache = handle_seq_error!(
self.prefix_cacher.search_for_matching_cache(&prompt),
request.response
);

let topk = request
.sampling_params
@@ -484,6 +502,14 @@ impl Engine {
now.as_secs(),
recognizer.clone(),
);
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
match prefill_cache {
MatchingCache::Verbatim(cache) => seq.prefill(cache),
MatchingCache::Subset(cache, toks) => seq.prefill_subset(cache, toks),
}
} else {
seq
};
self.id += 1;
self.scheduler.add_seq(seq);
}
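Each scheduling step registers at most `n_on_device` prompt sequences with the cacher and evicts the rest to the CPU. The eviction order (implemented in `prefix_cacher.rs` below) relies on `IndexMap` preserving insertion order, so draining from the front removes the oldest caches first. A self-contained toy sketch of that drain behavior, assuming only the `indexmap` crate the PR already depends on:

```rust
use indexmap::IndexMap;

fn main() {
    // Insertion order stands in for cache age: first inserted = oldest.
    let mut caches: IndexMap<&str, u32> = IndexMap::new();
    caches.insert("oldest", 0);
    caches.insert("middle", 1);
    caches.insert("newest", 2);

    // Keep at most `n_on_device` entries resident; drain the rest from the front.
    let n_on_device = 1;
    let n_evict = caches.len().saturating_sub(n_on_device);
    let evicted: Vec<&str> = caches.drain(0..n_evict).map(|(k, _)| k).collect();

    assert_eq!(evicted, ["oldest", "middle"]);
    assert_eq!(caches.keys().next(), Some(&"newest"));
}
```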
1 change: 1 addition & 0 deletions mistralrs-core/src/lib.rs
@@ -21,6 +21,7 @@ mod aici;
mod engine;
mod models;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
107 changes: 107 additions & 0 deletions mistralrs-core/src/prefix_cacher.rs
@@ -0,0 +1,107 @@
use candle_core::{Device, Result};
use indexmap::IndexMap;

use crate::{models::LayerCaches, sequence::Sequence};

pub struct PrefixCacheManager {
caches: IndexMap<Vec<u32>, LayerCaches>,
cpu_caches: IndexMap<Vec<u32>, LayerCaches>,
device: Device,
pub n_on_device: usize,
}

#[derive(Clone)]
pub enum MatchingCache {
Verbatim(LayerCaches),
Subset(LayerCaches, Vec<u32>),
}

impl PrefixCacheManager {
pub fn new(device: Device, n_on_device: usize) -> Self {
PrefixCacheManager {
caches: IndexMap::new(),
cpu_caches: IndexMap::new(),
device,
n_on_device,
}
}

    /// Add a sequence's cache, always keeping it on the device. If a new sequence cannot
    /// be allocated later due to a memory shortage, some caches will be evicted.
pub fn add_sequence(&mut self, seq: &mut Sequence) {
self.caches
.insert(seq.get_toks().to_vec(), seq.cache().clone());
}

    /// Evict caches to the CPU, oldest first, until at most `n_on_device` sequences
    /// remain on the device. Returns the number of evicted sequences.
    pub fn evict_to_cpu(&mut self) -> Result<usize> {
        // Evict from the front: `IndexMap` preserves insertion order, so the first
        // entries are the oldest. `saturating_sub` guards against having fewer
        // caches than the on-device budget.
        let n_evict = self.caches.len().saturating_sub(self.n_on_device);
        for (ids, cache) in self.caches.drain(0..n_evict) {
            let mut new_cache = Vec::new();
            for layer in cache {
                if let Some((ref q, ref k)) = layer {
                    new_cache.push(Some((
                        q.to_device(&Device::Cpu)?,
                        k.to_device(&Device::Cpu)?,
                    )));
                } else {
                    new_cache.push(None);
                }
            }
            self.cpu_caches.insert(ids, new_cache);
        }
        Ok(n_evict)
    }

pub fn promote_into_device_cache(
&mut self,
toks: Vec<u32>,
cache: &LayerCaches,
) -> Result<LayerCaches> {
let mut new_cache = Vec::new();
for layer in cache {
if let Some((ref q, ref k)) = layer {
new_cache.push(Some((
q.to_device(&self.device)?,
k.to_device(&self.device)?,
)));
} else {
new_cache.push(None);
}
}
// Load it into the cache
self.caches.insert(toks, new_cache.clone());
Ok(new_cache)
}

    /// Search for a cache matching `toks`, either exactly (verbatim) or covering a prefix of it.
pub fn search_for_matching_cache(&mut self, toks: &[u32]) -> Result<Option<MatchingCache>> {
if let Some(cache) = self.caches.get(toks) {
Ok(Some(MatchingCache::Verbatim(cache.clone())))
} else if let Some(cache) = self.cpu_caches.get(toks).cloned() {
Ok(Some(MatchingCache::Verbatim(
self.promote_into_device_cache(toks.to_vec(), &cache)?,
)))
} else {
            // Look for cached token ids that form a proper prefix of `toks`
            // (exact matches were handled above).
            for (ids, cache) in &self.caches {
                if toks.len() > ids.len() && toks[..ids.len()] == ids[..] {
                    return Ok(Some(MatchingCache::Subset(
                        cache.clone(),
                        toks[ids.len()..].to_vec(),
                    )));
                }
            }
            for (ids, cache) in self.cpu_caches.clone() {
                if toks.len() > ids.len() && toks[..ids.len()] == ids[..] {
                    return Ok(Some(MatchingCache::Subset(
                        // Key the promoted cache by the tokens it actually covers.
                        self.promote_into_device_cache(ids.clone(), &cache)?,
                        toks[ids.len()..].to_vec(),
                    )));
                }
            }
Ok(None)
}
}
}
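A minimal usage sketch of the manager's API. This is hypothetical glue, not code from the PR, and assumes it lives inside `mistralrs-core` where the private `prefix_cacher` module is visible:

```rust
use candle_core::Result;

use crate::prefix_cacher::{MatchingCache, PrefixCacheManager};

// Hypothetical helper: report how many prompt tokens still need a forward pass.
fn plan_prefill(cacher: &mut PrefixCacheManager, prompt: &[u32]) -> Result<usize> {
    Ok(match cacher.search_for_matching_cache(prompt)? {
        // Exact hit: the whole prompt is covered; go straight to completion.
        Some(MatchingCache::Verbatim(_cache)) => 0,
        // Prefix hit: only the uncached tail of the prompt must be computed.
        Some(MatchingCache::Subset(_cache, remaining)) => remaining.len(),
        // Miss: run the full prompt as usual.
        None => prompt.len(),
    })
}
```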
7 changes: 6 additions & 1 deletion mistralrs-core/src/scheduler.rs
@@ -66,7 +66,12 @@ impl<Backer: FcfsBacker> Scheduler<Backer> {
}

pub fn add_seq(&mut self, seq: Sequence) {
self.waiting.add(seq);
if seq.is_running() {
            // Prefill case: the sequence arrived with a restored KV cache and is already running.
self.running.push(seq);
} else {
self.waiting.add(seq);
}
}

/// Schedule all sequences based on their state and the available space.
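The effect: a sequence restored from the prefix cache enters `add_seq` already in a running state (`RunningCompletion` or `RunningPrefillPrompt`, per the `sequence.rs` changes below), so it bypasses the waiting queue instead of queuing behind fresh prompts.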
32 changes: 30 additions & 2 deletions mistralrs-core/src/sequence.rs
@@ -8,6 +8,7 @@ use std::{
use crate::aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx};
use crate::{
get_mut_group,
models::LayerCaches,
response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT},
sampler::{Logprobs, Sampler},
ChatCompletionResponse, ChatCompletionUsage,
@@ -41,6 +42,7 @@ pub enum SequenceState {
RunningCompletion,
Waiting,
Error,
RunningPrefillPrompt,
}

#[derive(Clone)]
@@ -62,11 +64,12 @@
responder: Sender<Response>,
response_index: usize,
creation_time: u64,
prefill_prompt_toks: Option<Vec<u32>>,

// Cache
scaling_cache: Option<Tensor>,
cache: Vec<Option<(Tensor, Tensor)>>,
xlora_cache: Option<Vec<Option<(Tensor, Tensor)>>>,
cache: LayerCaches,
xlora_cache: Option<LayerCaches>,

// Mutables
tokens: Vec<u32>,
@@ -131,9 +134,29 @@ impl Sequence {
response_index,
creation_time,
recognizer,
prefill_prompt_toks: None,
}
}

pub fn prefill(mut self, cache: LayerCaches) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!")
.as_millis();
self.prompt_timestamp = Some(now);

self.cache = cache;
self.set_state(SequenceState::RunningCompletion);
self
}

pub fn prefill_subset(mut self, cache: LayerCaches, toks: Vec<u32>) -> Self {
self.cache = cache;
self.prefill_prompt_toks = Some(toks);
self.set_state(SequenceState::RunningPrefillPrompt);
self
}

pub fn len(&self) -> usize {
self.tokens.len()
}
@@ -145,6 +168,7 @@
pub fn is_running(&self) -> bool {
self.state.get() == SequenceState::RunningCompletion
|| self.state.get() == SequenceState::RunningPrompt
|| self.state.get() == SequenceState::RunningPrefillPrompt
}

pub fn is_completion(&self) -> bool {
@@ -153,13 +177,17 @@

pub fn is_prompt(&self) -> bool {
self.state.get() == SequenceState::RunningPrompt
|| self.state.get() == SequenceState::RunningPrefillPrompt
}

pub fn is_waiting(&self) -> bool {
self.state.get() == SequenceState::Waiting
}

pub fn get_toks(&self) -> &[u32] {
if let Some(toks) = &self.prefill_prompt_toks {
return toks;
}
&self.tokens
}

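The `get_toks` override above is the crux of subset prefill: once `prefill_prompt_toks` is set, the model sees only the uncached tail of the prompt. A self-contained toy mirroring just that behavior (field names copied from the diff; everything else stripped away):

```rust
// Toy sequence keeping only the two fields relevant to subset prefill.
struct ToySequence {
    tokens: Vec<u32>,
    prefill_prompt_toks: Option<Vec<u32>>,
}

impl ToySequence {
    fn get_toks(&self) -> &[u32] {
        if let Some(toks) = &self.prefill_prompt_toks {
            return toks;
        }
        &self.tokens
    }
}

fn main() {
    // Prompt [1, 2, 3, 4] whose first three tokens were found in a cached prefix.
    let seq = ToySequence {
        tokens: vec![1, 2, 3, 4],
        prefill_prompt_toks: Some(vec![4]),
    };
    // Only the uncached tail is exposed for the prompt forward pass.
    assert_eq!(seq.get_toks(), &[4]);
}
```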