diff --git a/router/src/batch_types.rs b/router/src/batch_types.rs
index e572b713..e77e1944 100644
--- a/router/src/batch_types.rs
+++ b/router/src/batch_types.rs
@@ -32,7 +32,7 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {
         let generated_count = entry.generated_tokens;
         Self::update_stats(
             &stats,
-            entry.input_length + generated_count as usize,
+            entry.input_length + entry.prefix_length + generated_count as usize,
             (entry.request.parameters.max_new_tokens - generated_count) as usize,
         )
     }
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 82ca23f4..48dffb62 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -39,6 +39,7 @@ use crate::pb::fmaas::StopReason::{
     Cancelled, EosToken, Error, MaxTokens, NotFinished, StopSequence, TimeLimit, TokenLimit
 };
 use crate::pb::fmaas::token_info::TopToken;
+use crate::validation::RequestSize;
 
 /// Batcher
 #[derive(Clone)]
@@ -96,6 +97,7 @@ impl Batcher {
     pub(crate) async fn infer(
         &self,
         input_length: usize,
+        prefix_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
@@ -103,7 +105,7 @@
 
         // Try to add the request to the queue
         self.enqueue_request(vec![
-            Entry::new(request, input_length, Some(response_tx), None),
+            Entry::new(request, input_length, prefix_length, Some(response_tx), None),
         ])?;
 
         // Await on the response from the background task
@@ -117,14 +119,14 @@
     // Add a batch of new requests to the queue and return a vec of futures that will generate the text
     pub(crate) async fn infer_batch(
         &self,
-        requests: Vec<(usize, GenerateRequest)>,
+        requests: Vec<(RequestSize, GenerateRequest)>,
     ) -> Result<Vec<Map<Receiver<Result<InferResponse, InferError>>,
         impl FnOnce(Result<Result<InferResponse, InferError>, RecvError>) -> Result<InferResponse, InferError> + '_>>, InferError> {
         let mut response_chans= vec![];
 
         let entries: Vec<Entry> = requests.into_iter()
-            .map(|(input_length, request)| {
+            .map(|(request_size, request)| {
                 // One shot channel to communicate with the background batching task
                 let (response_tx, response_rx) = oneshot::channel();
                 response_chans.push(response_rx
@@ -134,7 +136,7 @@
                     })
                 );
 
-                Entry::new(request, input_length, Some(response_tx), None)
+                Entry::new(request, request_size.input_length, request_size.prefix_length, Some(response_tx), None)
             }).collect();
 
         // Try to add the request to the queue
@@ -147,6 +149,7 @@
     pub(crate) async fn infer_stream<T, C>(
         &self,
         input_length: usize,
+        prefix_length: usize,
         request: GenerateRequest,
         result_map: fn (Result<InferResponse, InferError>) -> T,
         on_drop: fn (&C, u32, StopReason, Option<u64>, Option<Times>, String, Option<InferError>),
@@ -170,7 +173,7 @@
 
         // Try to add the request to the queue
         self.enqueue_request(vec![
-            Entry::new(request, input_length, None, Some(response_tx)),
+            Entry::new(request, input_length, prefix_length, None, Some(response_tx)),
         ])?;
 
         Ok(ResponseStream {
diff --git a/router/src/grpc_server.rs b/router/src/grpc_server.rs
index 0e650b7b..cf1288ea 100644
--- a/router/src/grpc_server.rs
+++ b/router/src/grpc_server.rs
@@ -25,7 +25,7 @@ use crate::server::ServerState;
 use unicode_truncate::UnicodeTruncateStr;
 use crate::pb::fmaas::model_info_response::ModelKind;
 use crate::tokenizer::AsyncTokenizer;
-use crate::validation::ValidationError;
+use crate::validation::{RequestSize, ValidationError};
 
 /// Whether to fail if sampling parameters are provided in greedy-mode requests
 /// or to silently ignore them.
@@ -127,18 +127,18 @@ impl GenerationService for GenerationServicer {
 
         if batch_size == 1 {
             // Single request case
-            let (input_length, request) = valids.into_iter().next().unwrap();
-            self.state.batcher.infer(input_length, request)
+            let (request_size, request) = valids.into_iter().next().unwrap();
+            self.state.batcher.infer(request_size.input_length, request_size.prefix_length, request)
                 .map_ok(|response| {
                     log_response(
-                        &response.times, input_length, response.gen_token_count, response.reason,
+                        &response.times, request_size.input_length, response.gen_token_count, response.reason,
                         &response.output_text, start_time, "single", "Request", response.request_id
                     );
                     vec![response.into()]
                 }).await
         } else {
             // Batch size > 1
-            let input_tokens = valids.iter().map(|r| r.0).collect::<Vec<usize>>();
+            let input_tokens = valids.iter().map(|r| r.0.input_length).collect::<Vec<usize>>();
             match self.state.batcher.infer_batch(valids).await {
                 Ok(response_chans) => {
                     try_join_all(response_chans.into_iter().zip(input_tokens).enumerate()
@@ -198,13 +198,13 @@
         )?;
 
         // Validate request
-        let (input_length, validated_request) = self
+        let (request_size, validated_request) = self
             .validate(sr.prefix_id, sr.params, vec![req.text], start_time)
             .await?
             .pop().unwrap();
 
         let stream = self.state.batcher
-            .infer_stream(input_length, validated_request, |r| match r {
+            .infer_stream(request_size.input_length, request_size.prefix_length, validated_request, |r| match r {
                 Ok(resp) => Ok(resp.into()),
                 Err(err) => Err(Status::from_error(Box::new(err))),
             }, |ctx, count, reason, request_id, times, out, err| {
@@ -222,7 +222,7 @@
                 }
             }, StreamContext {
                 span: Span::current(),
-                input_token_count: input_length,
+                input_token_count: request_size.input_length,
                 start_time,
                 _permit: permit,
             })
@@ -297,7 +297,7 @@
         parameters: Option<Parameters>,
         inputs: Vec<String>,
         start_time: Instant,
-    ) -> Result<Vec<(usize, GenerateRequest)>, Status> {
+    ) -> Result<Vec<(RequestSize, GenerateRequest)>, Status> {
         match convert_params(parameters, self.state.default_include_stop_seqs) {
             Ok(params) => self.state.validation.validate(
                 prefix_id, params, inputs
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 2c8e0241..cfafc362 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -32,6 +32,8 @@ pub(crate) struct Entry {
     pub stream_tx: Option<UnboundedSender<Result<InferResponse, InferError>>>,
     /// Number of tokens in the input
     pub input_length: usize,
+    /// Number of virtual tokens in the prefix, if one is specified
+    pub prefix_length: usize,
    /// Instant when this entry was queued
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch (queue end time)
@@ -52,6 +54,7 @@
     pub(crate) fn new(
         request: GenerateRequest,
         input_length: usize,
+        prefix_length: usize,
         response_tx: Option<Sender<Result<InferResponse, InferError>>>,
         stream_tx: Option<UnboundedSender<Result<InferResponse, InferError>>>,
     ) -> Self {
@@ -60,6 +63,7 @@
             response_tx,
             stream_tx,
             input_length,
+            prefix_length,
             input_tokens: vec![],
             queue_time: Instant::now(),
             batch_time: None,
@@ -265,7 +269,9 @@
                     break
                 }
 
-                let input_len = entry.input_length;
+                // For the purposes of deciding if a request can fit into a batch,
+                // the input length needs to take the length of the prefix into account as well
+                let input_len = entry.input_length + entry.prefix_length;
                 let output_len = entry.request.parameters.max_new_tokens as usize;
                 let next_stats = <B as BatchType>::update_stats(
                     &batch_stats, input_len, output_len
@@ -289,7 +295,7 @@
                 let generated_count = e.generated_tokens as usize;
                 t.insert((
                     e.request.parameters.max_new_tokens as usize - generated_count,
-                    e.input_length + generated_count,
+                    e.input_length + e.prefix_length + generated_count,
                     t.len(),
                 ));
             }
diff --git a/router/src/server.rs b/router/src/server.rs
index d0c6bf01..522ae732 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -91,7 +91,7 @@ async fn generate(
     // Validate request
     //let details = req.0.parameters.details;
     let GenerateRequest {inputs, prefix_id, parameters} = req.0;
-    let (input_length, validated_request) =
+    let (request_size, validated_request) =
         state.validation.validate(
             prefix_id, parameters, vec![inputs]
         ).await.map_err(|err| {
@@ -102,7 +102,7 @@
     // Inference
     let response = state
         .batcher
-        .infer(input_length, validated_request)
+        .infer(request_size.input_length, request_size.prefix_length, validated_request)
         .await
         .map_err(|err| {
             tracing::error!("{err}");
diff --git a/router/src/validation.rs b/router/src/validation.rs
index afb8522d..934f356c 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -26,6 +26,11 @@ pub struct Validation {
     prefix_cache: Cache,
 }
 
+pub struct RequestSize {
+    pub(crate) input_length: usize,
+    pub(crate) prefix_length: usize
+}
+
 impl Validation {
     pub(crate) fn new(
         tokenizer: AsyncTokenizer,
@@ -55,7 +60,7 @@
         prefix_id: Option<String>,
         params: GenerateParameters,
         inputs: Vec<String>,
-    ) -> Result<Vec<(usize, GenerateRequest)>, ValidationError> {
+    ) -> Result<Vec<(RequestSize, GenerateRequest)>, ValidationError> {
         let min_new_tokens = params.min_new_tokens as usize;
         let max_new_tokens = params.max_new_tokens as usize;
 
@@ -165,7 +170,10 @@
                 }
 
                 Ok((
-                    input_length,
+                    RequestSize {
+                        input_length,
+                        prefix_length
+                    },
                     GenerateRequest {
                         prefix_id: prefix_id.clone(),
                         inputs: input,
@@ -173,10 +181,10 @@
                     }
                 ))
             }
-        }).collect::<Result<Vec<(usize, GenerateRequest)>, ValidationError>>().map(|results| {
+        }).collect::<Result<Vec<(RequestSize, GenerateRequest)>, ValidationError>>().map(|results| {
             // Only record these for successful validation
-            for (input_length, _) in &results {
-                metrics::histogram!("tgi_request_input_length", *input_length as f64);
+            for (request_size, _) in &results {
+                metrics::histogram!("tgi_request_input_length", request_size.input_length as f64);
                 metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
             }
             results
diff --git a/server/tests/test_prompt_cache.py b/server/tests/test_prompt_cache.py
index 858e95a6..5c92ed8a 100644
--- a/server/tests/test_prompt_cache.py
+++ b/server/tests/test_prompt_cache.py
@@ -23,24 +23,31 @@ INTEGRATION_TESTS_DIR = os.path.join(REPO_ROOT, "integration_tests")
 
+@pytest.fixture(autouse=True)
+def temp_prompt_store(tmp_path):
+    # Unless overridden by another fixture, sets the prefix store path to some temp dir
+    with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH", tmp_path):
+        yield
+
+
 @pytest.fixture()
-def temp_prompt_store():
+def integration_test_prompts():
     with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH",
                Path(os.path.join(INTEGRATION_TESTS_DIR, "prompt_prefixes"))):
         yield
 
 
 @pytest.fixture()
-def tiny_starcoder_decoder_prompt(temp_prompt_store):
+def tiny_starcoder_decoder_prompt(integration_test_prompts):
     return "tiny_starcoder"
 
 
 @pytest.fixture()
-def tiny_raw_llama_peft_adapter_prompt(temp_prompt_store):
+def tiny_raw_llama_peft_adapter_prompt(integration_test_prompts):
     return "tinyllama_peft_adapter_raw"
 
 
 @pytest.fixture()
-def tiny_llama_peft_adapter_prompt(temp_prompt_store):
+def tiny_llama_peft_adapter_prompt(integration_test_prompts):
     return "tinyllama_peft_adapter"
 
 
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 1ca0fcc0..d1ac9c24 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -25,7 +25,7 @@ def serve(
     max_sequence_length: int = 2048,
     max_new_tokens: int = 1024,
     max_batch_size: int = 12,
-    batch_safety_margin: int = 20,
+    batch_safety_margin: int = typer.Option(20, help="Integer from 0-100, a percentage of free GPU memory to hold back as a safety margin to avoid OOM"),
     revision: Optional[str] = None,
     sharded: bool = False,
     cuda_process_memory_fraction: float = 1.0,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 46990b01..0df15ac8 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -10,6 +10,7 @@
 
 from transformers import PreTrainedModel
 
+import text_generation_server.prompt_cache
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.inference_engine.engine import BaseInferenceEngine
 from text_generation_server.pb import generate_pb2
@@ -44,7 +45,8 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype, max_seq_leng
         # Check whether model supports position_ids
         self.use_position_ids = "position_ids" in inspect.signature(self.model.forward).parameters
 
-        prompt_prefix_supported = self._setup_prompt_encoder()
+        # Short-circuit: Don't set up the prompt encoder if the prompt cache is not set
+        prompt_prefix_supported = self.prompt_cache_set() and self._setup_prompt_encoder()
 
         if prompt_prefix_supported:
             # Set up prefix cache
@@ -184,6 +186,10 @@ def get_indices_to_keep(
                 next_batch_keep_indices.append(i)
         return next_batch_keep_indices
 
+    @staticmethod
+    def prompt_cache_set() -> bool:
+        return text_generation_server.prompt_cache.PREFIX_STORE_PATH is not None
+
     def _setup_prompt_encoder(self) -> bool:
         try:
             self.word_embeddings = self.model.get_input_embeddings()
diff --git a/server/text_generation_server/prompt_cache.py b/server/text_generation_server/prompt_cache.py
index de1a8718..fb65ffd7 100644
--- a/server/text_generation_server/prompt_cache.py
+++ b/server/text_generation_server/prompt_cache.py
@@ -9,7 +9,8 @@
 
 import torch
 
-PREFIX_STORE_PATH = Path(os.getenv("PREFIX_STORE_PATH", "prompt_prefixes"))
+_PREFIX_STORE_PATH_STR = os.getenv("PREFIX_STORE_PATH", None)
+PREFIX_STORE_PATH = Path(_PREFIX_STORE_PATH_STR) if _PREFIX_STORE_PATH_STR else None
 
 VALID_PREFIX_ID_PATTERN = re.compile("[/\\w\\-]+")
 PROMPT_CACHE_SIZE_MB = int(os.getenv("PROMPT_CACHE_SIZE_MB", "512"))
diff --git a/server/text_generation_server/utils/memory_characterizer.py b/server/text_generation_server/utils/memory_characterizer.py
index 8ebefd16..15806442 100644
--- a/server/text_generation_server/utils/memory_characterizer.py
+++ b/server/text_generation_server/utils/memory_characterizer.py
@@ -1,4 +1,5 @@
 from text_generation_server.pb import generate_pb2
+from text_generation_server.prompt_cache import PROMPT_CACHE_SIZE_MB
 import numpy as np
 import torch
 import torch.cuda
@@ -8,7 +9,7 @@
 
 # Set the memory estimation method to auto, manual or off. If PT2C is used, auto will be forced.
 ESTIMATE_MEMORY = os.getenv("ESTIMATE_MEMORY", "auto")
-assert ESTIMATE_MEMORY in ["auto", "manual", "off"]
+assert ESTIMATE_MEMORY in ["auto", "manual", "off"], "valid options for ESTIMATE_MEMORY are auto, manual, and off"
 # Select the batch size that is used to run the tests. The idea is to make the
 # batch large enough so that the measurement is more accurate, i.e. improve signal to
 # noise ratio. If set too large it could prevent the estimator from finding a quadratic
@@ -332,7 +333,17 @@ def find_baseline(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(self.model.device)
         self.baseline = torch.cuda.max_memory_allocated(self.model.device)
-        self.free_memory, _ = torch.cuda.mem_get_info(self.model.device)
+        device_free_memory, _ = torch.cuda.mem_get_info(self.model.device)
+
+        # If the model contains a prefix cache, then reduce available memory by the max size of the cache
+        if self.model.prefix_cache is not None:
+            max_cache_size = PROMPT_CACHE_SIZE_MB * 1024 * 1024
+            print(f"Prefix cache enabled, reducing available memory by {max_cache_size}")
+            self.free_memory = device_free_memory - max_cache_size
+        else:
+            print("Prefix cache disabled, using all available memory")
+            self.free_memory = device_free_memory
+
         print("Baseline: %d, Free memory: %d" % (self.baseline, self.free_memory))
 
     def find_upper_bound(self):
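
The server-side changes above make the prompt cache strictly opt-in: PREFIX_STORE_PATH no longer defaults to "prompt_prefixes", Model only sets up the prompt encoder when that path is configured, and the memory characterizer holds back PROMPT_CACHE_SIZE_MB of GPU memory whenever a prefix cache exists. Below is a minimal sketch of that gating logic, assuming only the module-level names shown in the patch; the helpers prompt_cache_enabled and usable_free_memory are illustrative and are not part of the repository.

import os
from pathlib import Path

# Mirrors prompt_cache.py after this patch: the path is None unless the
# PREFIX_STORE_PATH environment variable is set (no implicit "prompt_prefixes" default).
_prefix_store_path_str = os.getenv("PREFIX_STORE_PATH", None)
PREFIX_STORE_PATH = Path(_prefix_store_path_str) if _prefix_store_path_str else None
PROMPT_CACHE_SIZE_MB = int(os.getenv("PROMPT_CACHE_SIZE_MB", "512"))


def prompt_cache_enabled() -> bool:
    # Hypothetical stand-in for Model.prompt_cache_set(): the prompt encoder and
    # prefix cache are only set up when a prefix store path is configured.
    return PREFIX_STORE_PATH is not None


def usable_free_memory(device_free_memory: int, has_prefix_cache: bool) -> int:
    # Hypothetical stand-in for the find_baseline() change: when a prefix cache is
    # present, reserve its maximum size so the batch-size estimate does not hand
    # the cache's headroom to regular batches.
    if has_prefix_cache:
        return device_free_memory - PROMPT_CACHE_SIZE_MB * 1024 * 1024
    return device_free_memory

With no PREFIX_STORE_PATH in the environment, prompt_cache_enabled() returns False, the prompt encoder setup is short-circuited, and the estimator uses the full free-memory figure, matching the "Prefix cache disabled" branch in find_baseline().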