Fix OOM due to large prompt cache #39

Merged (1 commit) on Feb 27, 2024
2 changes: 1 addition & 1 deletion router/src/batch_types.rs
@@ -32,7 +32,7 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {
let generated_count = entry.generated_tokens;
Self::update_stats(
&stats,
entry.input_length + generated_count as usize,
entry.input_length + entry.prefix_length + generated_count as usize,
(entry.request.parameters.max_new_tokens - generated_count) as usize,
)
}
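The effect of this one-line change is that an in-flight entry's token count now includes its virtual prompt-prefix tokens. A rough Python sketch of the accounting, with hypothetical names (the real update_stats is the Rust trait method above):

# Illustrative only: effective size of one in-flight request.
def effective_token_count(input_length: int, prefix_length: int, generated_tokens: int) -> int:
    # Before this fix the prefix_length term was missing, so requests using a
    # prompt prefix looked smaller than the sequences actually held in memory.
    return input_length + prefix_length + generated_tokens

# Example: 10 input tokens, a 100-token prefix, and 5 generated tokens
# occupy 115 token positions, not 15.
assert effective_token_count(10, 100, 5) == 115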
13 changes: 8 additions & 5 deletions router/src/batcher.rs
@@ -39,6 +39,7 @@ use crate::pb::fmaas::StopReason::{
Cancelled, EosToken, Error, MaxTokens, NotFinished, StopSequence, TimeLimit, TokenLimit
};
use crate::pb::fmaas::token_info::TopToken;
use crate::validation::RequestSize;

/// Batcher
#[derive(Clone)]
@@ -96,14 +97,15 @@ impl Batcher {
pub(crate) async fn infer(
&self,
input_length: usize,
prefix_length: usize,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();

// Try to add the request to the queue
self.enqueue_request(vec![
Entry::new(request, input_length, Some(response_tx), None),
Entry::new(request, input_length, prefix_length, Some(response_tx), None),
])?;

// Await on the response from the background task
@@ -117,14 +119,14 @@
// Add a batch of new requests to the queue and return a vec of futures that will generate the text
pub(crate) async fn infer_batch(
&self,
requests: Vec<(usize, GenerateRequest)>,
requests: Vec<(RequestSize, GenerateRequest)>,
) -> Result<Vec<Map<Receiver<Result<InferResponse, ClientError>>,
impl FnOnce(Result<Result<InferResponse, ClientError>, RecvError>) -> Result<InferResponse, InferError> + '_>>, InferError> {

let mut response_chans= vec![];

let entries: Vec<Entry> = requests.into_iter()
.map(|(input_length, request)| {
.map(|(request_size, request)| {
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();
response_chans.push(response_rx
@@ -134,7 +136,7 @@
})
);

Entry::new(request, input_length, Some(response_tx), None)
Entry::new(request, request_size.input_length, request_size.prefix_length, Some(response_tx), None)
}).collect();

// Try to add the request to the queue
@@ -147,6 +149,7 @@
pub(crate) async fn infer_stream<T, C>(
&self,
input_length: usize,
prefix_length: usize,
request: GenerateRequest,
result_map: fn (Result<InferResponse, InferError>) -> T,
on_drop: fn (&C, u32, StopReason, Option<u64>, Option<Times>, String, Option<InferError>),
@@ -170,7 +173,7 @@

// Try to add the request to the queue
self.enqueue_request(vec![
Entry::new(request, input_length, None, Some(response_tx)),
Entry::new(request, input_length, prefix_length, None, Some(response_tx)),
])?;

Ok(ResponseStream {
18 changes: 9 additions & 9 deletions router/src/grpc_server.rs
@@ -25,7 +25,7 @@ use crate::server::ServerState;
use unicode_truncate::UnicodeTruncateStr;
use crate::pb::fmaas::model_info_response::ModelKind;
use crate::tokenizer::AsyncTokenizer;
use crate::validation::ValidationError;
use crate::validation::{RequestSize, ValidationError};

/// Whether to fail if sampling parameters are provided in greedy-mode requests
/// or to silently ignore them.
@@ -127,18 +127,18 @@ impl GenerationService for GenerationServicer {

if batch_size == 1 {
// Single request case
let (input_length, request) = valids.into_iter().next().unwrap();
self.state.batcher.infer(input_length, request)
let (request_size, request) = valids.into_iter().next().unwrap();
self.state.batcher.infer(request_size.input_length, request_size.prefix_length, request)
.map_ok(|response| {
log_response(
&response.times, input_length, response.gen_token_count, response.reason,
&response.times, request_size.input_length, response.gen_token_count, response.reason,
&response.output_text, start_time, "single", "Request", response.request_id
);
vec![response.into()]
}).await
} else {
// Batch size > 1
let input_tokens = valids.iter().map(|r| r.0).collect::<Vec<usize>>();
let input_tokens = valids.iter().map(|r| r.0.input_length).collect::<Vec<usize>>();
match self.state.batcher.infer_batch(valids).await {
Ok(response_chans) => {
try_join_all(response_chans.into_iter().zip(input_tokens).enumerate()
@@ -198,13 +198,13 @@
)?;

// Validate request
let (input_length, validated_request) = self
let (request_size, validated_request) = self
.validate(sr.prefix_id, sr.params, vec![req.text], start_time)
.await?
.pop().unwrap();

let stream = self.state.batcher
.infer_stream(input_length, validated_request, |r| match r {
.infer_stream(request_size.input_length, request_size.prefix_length, validated_request, |r| match r {
Ok(resp) => Ok(resp.into()),
Err(err) => Err(Status::from_error(Box::new(err))),
}, |ctx, count, reason, request_id, times, out, err| {
@@ -222,7 +222,7 @@
}
}, StreamContext {
span: Span::current(),
input_token_count: input_length,
input_token_count: request_size.input_length,
start_time,
_permit: permit,
})
@@ -297,7 +297,7 @@ impl GenerationServicer {
parameters: Option<Parameters>,
inputs: Vec<String>,
start_time: Instant,
) -> Result<Vec<(usize, GenerateRequest)>, Status> {
) -> Result<Vec<(RequestSize, GenerateRequest)>, Status> {
match convert_params(parameters, self.state.default_include_stop_seqs) {
Ok(params) => self.state.validation.validate(
prefix_id, params, inputs
10 changes: 8 additions & 2 deletions router/src/queue.rs
@@ -32,6 +32,8 @@ pub(crate) struct Entry {
pub stream_tx: Option<UnboundedSender<Result<InferResponse, ClientError>>>,
/// Number of tokens in the input
pub input_length: usize,
/// Number of virtual tokens in the prefix, if one is specified
pub prefix_length: usize,
/// Instant when this entry was queued
pub queue_time: Instant,
/// Instant when this entry was added to a batch (queue end time)
@@ -52,6 +54,7 @@ impl Entry {
pub(crate) fn new(
request: GenerateRequest,
input_length: usize,
prefix_length: usize,
response_tx: Option<Sender<Result<InferResponse, ClientError>>>,
stream_tx: Option<UnboundedSender<Result<InferResponse, ClientError>>>,
) -> Self {
@@ -60,6 +63,7 @@
response_tx,
stream_tx,
input_length,
prefix_length,
input_tokens: vec![],
queue_time: Instant::now(),
batch_time: None,
@@ -265,7 +269,9 @@ impl<B: BatchType> Queue<B> {
break
}

let input_len = entry.input_length;
// For the purposes of deciding if a request can fit into a batch,
// the input length needs to take the length of the prefix into account as well
let input_len = entry.input_length + entry.prefix_length;
let output_len = entry.request.parameters.max_new_tokens as usize;
let next_stats = <B>::update_stats(
&batch_stats, input_len, output_len
@@ -289,7 +295,7 @@
let generated_count = e.generated_tokens as usize;
t.insert((
e.request.parameters.max_new_tokens as usize - generated_count,
e.input_length + generated_count,
e.input_length + e.prefix_length + generated_count,
t.len(),
));
}
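Counting the prefix also changes whether a waiting request is admitted into the current batch. A simplified Python sketch of that admission check under a flat token budget, with hypothetical names and numbers (the real logic is the Rust code above and uses BatchType-specific statistics):

# Illustrative admission check under a fixed token budget.
def fits_in_batch(batch_tokens_in_use: int, input_length: int,
                  prefix_length: int, max_new_tokens: int,
                  token_budget: int) -> bool:
    # Worst-case footprint of the candidate request, prefix included.
    request_tokens = input_length + prefix_length + max_new_tokens
    return batch_tokens_in_use + request_tokens <= token_budget

# A request with a 100-token prefix no longer squeezes into an almost-full
# batch, which is what prevents the out-of-memory condition.
print(fits_in_batch(1800, 50, 0, 120, 2000))    # True without a prefix
print(fits_in_batch(1800, 50, 100, 120, 2000))  # False once the prefix is counted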
4 changes: 2 additions & 2 deletions router/src/server.rs
@@ -91,7 +91,7 @@ async fn generate(
// Validate request
//let details = req.0.parameters.details;
let GenerateRequest {inputs, prefix_id, parameters} = req.0;
let (input_length, validated_request) =
let (request_size, validated_request) =
state.validation.validate(
prefix_id, parameters, vec![inputs]
).await.map_err(|err| {
@@ -102,7 +102,7 @@
// Inference
let response = state
.batcher
.infer(input_length, validated_request)
.infer(request_size.input_length, request_size.prefix_length, validated_request)
.await
.map_err(|err| {
tracing::error!("{err}");
18 changes: 13 additions & 5 deletions router/src/validation.rs
@@ -26,6 +26,11 @@ pub struct Validation {
prefix_cache: Cache<String, usize, RandomState>,
}

pub struct RequestSize {
pub(crate) input_length: usize,
pub(crate) prefix_length: usize
}

impl Validation {
pub(crate) fn new(
tokenizer: AsyncTokenizer,
@@ -55,7 +60,7 @@
prefix_id: Option<String>,
params: GenerateParameters,
inputs: Vec<String>,
) -> Result<Vec<(usize, GenerateRequest)>, ValidationError> {
) -> Result<Vec<(RequestSize, GenerateRequest)>, ValidationError> {
let min_new_tokens = params.min_new_tokens as usize;
let max_new_tokens = params.max_new_tokens as usize;

@@ -165,18 +170,21 @@
}

Ok((
input_length,
RequestSize {
input_length,
prefix_length
},
GenerateRequest {
prefix_id: prefix_id.clone(),
inputs: input,
parameters,
}
))
}
}).collect::<Result<Vec<(usize, GenerateRequest)>, ValidationError>>().map(|results| {
}).collect::<Result<Vec<(RequestSize, GenerateRequest)>, ValidationError>>().map(|results| {
// Only record these for successful validation
for (input_length, _) in &results {
metrics::histogram!("tgi_request_input_length", *input_length as f64);
for (request_size, _) in &results {
metrics::histogram!("tgi_request_input_length", request_size.input_length as f64);
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
}
results
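Returning a small RequestSize struct instead of a bare usize forces callers to handle both lengths explicitly. A Python analogue of the returned shape, for illustration only:

from dataclasses import dataclass

@dataclass
class RequestSize:
    input_length: int    # tokens in the tokenized user input
    prefix_length: int   # virtual tokens added by the prompt prefix, 0 if none

# Downstream code can no longer forget the prefix when sizing a request.
size = RequestSize(input_length=42, prefix_length=100)
print(size.input_length + size.prefix_length)  # 142 effective prompt tokens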
15 changes: 11 additions & 4 deletions server/tests/test_prompt_cache.py
@@ -23,24 +23,31 @@
INTEGRATION_TESTS_DIR = os.path.join(REPO_ROOT, "integration_tests")


@pytest.fixture(autouse=True)
def temp_prompt_store(tmp_path):
# Unless overridden by another fixture, sets the prefix store path to some temp dir
with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH", tmp_path):
yield


@pytest.fixture()
def temp_prompt_store():
def integration_test_prompts():
with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH", Path(os.path.join(INTEGRATION_TESTS_DIR, "prompt_prefixes"))):
yield


@pytest.fixture()
def tiny_starcoder_decoder_prompt(temp_prompt_store):
def tiny_starcoder_decoder_prompt(integration_test_prompts):
return "tiny_starcoder"


@pytest.fixture()
def tiny_raw_llama_peft_adapter_prompt(temp_prompt_store):
def tiny_raw_llama_peft_adapter_prompt(integration_test_prompts):
return "tinyllama_peft_adapter_raw"


@pytest.fixture()
def tiny_llama_peft_adapter_prompt(temp_prompt_store):
def tiny_llama_peft_adapter_prompt(integration_test_prompts):
return "tinyllama_peft_adapter"


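The fixture change follows a common pytest pattern: an autouse fixture isolates every test in a temporary prefix store, and a separately named, opt-in fixture exposes the checked-in integration prefixes. A self-contained sketch of the same pattern, with a hypothetical module path:

from pathlib import Path
from unittest.mock import patch

import pytest

@pytest.fixture(autouse=True)
def isolated_prefix_store(tmp_path):
    # Default for every test: an empty, throwaway prefix store, so no test
    # silently depends on prefixes checked into the repository.
    with patch("my_server.prompt_cache.PREFIX_STORE_PATH", tmp_path):
        yield

@pytest.fixture()
def repo_prefix_store():
    # Opt-in: tests that request this fixture see the real prefix directory.
    with patch("my_server.prompt_cache.PREFIX_STORE_PATH", Path("integration_tests/prompt_prefixes")):
        yield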
2 changes: 1 addition & 1 deletion server/text_generation_server/cli.py
@@ -25,7 +25,7 @@ def serve(
max_sequence_length: int = 2048,
max_new_tokens: int = 1024,
max_batch_size: int = 12,
batch_safety_margin: int = 20,
batch_safety_margin: int = typer.Option(20, help="Integer from 0-100, a percentage of free GPU memory to hold back as a safety margin to avoid OOM"),
revision: Optional[str] = None,
sharded: bool = False,
cuda_process_memory_fraction: float = 1.0,
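Functionally the default is unchanged; typer.Option(20, help=...) only attaches help text so that --help documents the flag as a percentage of free GPU memory held back to avoid OOM. A minimal, standalone sketch of the same typer idiom (toy command, not the real serve entry point):

import typer

app = typer.Typer()

@app.command()
def serve(
    batch_safety_margin: int = typer.Option(
        20, help="Integer from 0-100, a percentage of free GPU memory to hold back"
    ),
):
    typer.echo(f"Holding back {batch_safety_margin}% of free GPU memory")

if __name__ == "__main__":
    app()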
8 changes: 7 additions & 1 deletion server/text_generation_server/models/model.py
@@ -10,6 +10,7 @@

from transformers import PreTrainedModel

import text_generation_server.prompt_cache
from text_generation_server.models.types import Batch, GenerateError
from text_generation_server.inference_engine.engine import BaseInferenceEngine
from text_generation_server.pb import generate_pb2
@@ -44,7 +45,8 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype, max_seq_leng
# Check whether model supports position_ids
self.use_position_ids = "position_ids" in inspect.signature(self.model.forward).parameters

prompt_prefix_supported = self._setup_prompt_encoder()
# Short-circuit: Don't set up the prompt encoder if the prompt cache is not set
prompt_prefix_supported = self.prompt_cache_set() and self._setup_prompt_encoder()

if prompt_prefix_supported:
# Set up prefix cache
@@ -184,6 +186,10 @@ def get_indices_to_keep(
next_batch_keep_indices.append(i)
return next_batch_keep_indices

@staticmethod
def prompt_cache_set() -> bool:
return text_generation_server.prompt_cache.PREFIX_STORE_PATH is not None

def _setup_prompt_encoder(self) -> bool:
try:
self.word_embeddings = self.model.get_input_embeddings()
3 changes: 2 additions & 1 deletion server/text_generation_server/prompt_cache.py
@@ -9,7 +9,8 @@

import torch

PREFIX_STORE_PATH = Path(os.getenv("PREFIX_STORE_PATH", "prompt_prefixes"))
_PREFIX_STORE_PATH_STR = os.getenv("PREFIX_STORE_PATH", None)
PREFIX_STORE_PATH = Path(_PREFIX_STORE_PATH_STR) if _PREFIX_STORE_PATH_STR else None

VALID_PREFIX_ID_PATTERN = re.compile("[/\\w\\-]+")
PROMPT_CACHE_SIZE_MB = int(os.getenv("PROMPT_CACHE_SIZE_MB", "512"))
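After this change, PREFIX_STORE_PATH is None unless the PREFIX_STORE_PATH environment variable is set explicitly, instead of defaulting to a relative prompt_prefixes directory; a None path is what the model code above uses to skip prompt-prefix support entirely. The idiom in isolation:

import os
from pathlib import Path

# Optional path from the environment: None means "feature disabled" rather
# than silently falling back to a relative default directory.
_raw = os.getenv("PREFIX_STORE_PATH", None)
PREFIX_STORE_PATH = Path(_raw) if _raw else None

if PREFIX_STORE_PATH is None:
    print("Prompt prefixes disabled")
else:
    print(f"Loading prompt prefixes from {PREFIX_STORE_PATH}")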
15 changes: 13 additions & 2 deletions server/text_generation_server/utils/memory_characterizer.py
@@ -1,4 +1,5 @@
from text_generation_server.pb import generate_pb2
from text_generation_server.prompt_cache import PROMPT_CACHE_SIZE_MB
import numpy as np
import torch
import torch.cuda
@@ -8,7 +9,7 @@

# Set the memory estimation method to auto, manual or off. If PT2C is used, auto will be forced.
ESTIMATE_MEMORY = os.getenv("ESTIMATE_MEMORY", "auto")
assert ESTIMATE_MEMORY in ["auto", "manual", "off"]
assert ESTIMATE_MEMORY in ["auto", "manual", "off"], "valid options for ESTIMATE_MEMORY are auto, manual, and off"
# Select the batch size that is used to run the tests. The idea is to make the
# batch large enough so that the measurement is more accurate, i.e. improve signal to
# noise ratio. If set too large it could prevent the estimator from finding a quadratic
@@ -332,7 +333,17 @@ def find_baseline(self):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(self.model.device)
self.baseline = torch.cuda.max_memory_allocated(self.model.device)
self.free_memory, _ = torch.cuda.mem_get_info(self.model.device)
device_free_memory, _ = torch.cuda.mem_get_info(self.model.device)

# If the model contains a prefix cache, then reduce available memory by the max size of the cache
if self.model.prefix_cache is not None:
max_cache_size = PROMPT_CACHE_SIZE_MB * 1024 * 1024
print(f"Prefix cache enabled, reducing available memory by {max_cache_size}")
self.free_memory = device_free_memory - max_cache_size
else:
print(f"Prefix cache disabled, using all available memory")
self.free_memory = device_free_memory

print("Baseline: %d, Free memory: %d" % (self.baseline, self.free_memory))

def find_upper_bound(self):
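This is the core of the OOM fix on the server side: the memory estimator no longer hands the batch sizing logic all of the device's free memory, but first reserves the prompt cache's maximum size. A simplified sketch of the arithmetic with made-up numbers (the real code reads free memory via torch.cuda.mem_get_info):

PROMPT_CACHE_SIZE_MB = 512  # assumed cache limit, matching the default above

def usable_free_memory(device_free_bytes: int, prefix_cache_enabled: bool) -> int:
    if not prefix_cache_enabled:
        return device_free_bytes
    # Hold back the worst-case prompt cache footprint so batches sized from
    # this figure cannot later collide with cached prefix tensors.
    return device_free_bytes - PROMPT_CACHE_SIZE_MB * 1024 * 1024

# Example: 20 GiB free with the cache enabled leaves roughly 19.5 GiB for batching.
print(usable_free_memory(20 * 1024**3, True))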