Merge remote-tracking branch 'upstream/release'
dchourasia authored and heyselbi committed Jun 21, 2024
2 parents 1f6a325 + d5340ca commit 503da48
Showing 6 changed files with 108 additions and 51 deletions.
2 changes: 2 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -35,6 +35,7 @@ def get_model(
dtype_str: str,
quantize: Optional[str],
max_sequence_length: Optional[int],
memory_scaling_model: Optional[int] = None,
) -> Model:
dtype = get_torch_dtype(dtype_str)
model_path = get_model_path(model_name, revision)
@@ -74,6 +75,7 @@ def get_model(
dtype, quantize,
model_config,
max_sequence_length=max_sequence_length,
memory_scaling_model=memory_scaling_model,
)

if FLASH_ATTENTION:
@@ -434,6 +434,9 @@ def __init__(self, config, weights):
weights=weights,
)

def get_kv_cache_block_size(self, block_size: int) -> int:
return block_size * self.model.num_key_value_heads * self.model.head_size * 2

def get_input_embeddings(self) -> nn.Module:
return self.model.embed_tokens

@@ -407,6 +407,9 @@ def __init__(self, config, weights):
config, prefix="transformer.wte", weights=weights
)

def get_kv_cache_block_size(self, block_size: int) -> int:
return block_size * self.transformer.head_size * 2

def get_input_embeddings(self) -> nn.Module:
return self.transformer.wte

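The two hunks above add per-model get_kv_cache_block_size helpers: with grouped or multi-head attention the per-layer element count for one block is block_size * num_key_value_heads * head_size * 2 (the factor of 2 covers the key and value tensors), while the second model's helper omits the head count, consistent with a single shared key/value head (multi-query attention). Below is a minimal sketch of the arithmetic that the paged model later builds on; the config numbers are illustrative assumptions, not values taken from this diff.

# Sketch only: per-block KV cache size, mirroring the helpers added above.
# All config values below are assumed (Llama-7B-like), not part of this commit.
block_size = 16             # tokens per paged-attention cache block
num_key_value_heads = 32    # assumed
head_size = 128             # assumed
num_hidden_layers = 32      # assumed
dtype_size = 2              # bytes per element for float16

# per-layer elements for one block: keys + values for block_size tokens
kv_cache_block_size = block_size * num_key_value_heads * head_size * 2
# total bytes for one block across all layers (matches the later cache_block_size computation)
cache_block_size = dtype_size * num_hidden_layers * kv_cache_block_size
print(cache_block_size)     # 8388608 bytes, i.e. 8 MiB per block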
79 changes: 49 additions & 30 deletions server/text_generation_server/models/paged_causal_lm.py
@@ -18,16 +18,14 @@
from text_generation_server.utils.token_types import TokenInfo, InputTokens
from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, get_input_tokens_info
from text_generation_server.utils.paged import (
load_speculator,
prepare_inputs_without_speculation,
prepare_inputs_with_speculation,
process_outputs_with_speculation,
prepare_inputs_for_prefill
)
from text_generation_server.inference_engine import get_inference_engine_class

# HF name or path to speculator model (None means no speculation will be used)
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# we will only do speculation if the batch size is <= this parameter
SPECULATOR_MAX_BATCH_SIZE = int(os.getenv("SPECULATOR_MAX_BATCH_SIZE", "16"))

@@ -277,6 +275,7 @@ def __init__(
quantize: Optional[str],
model_config: Union[Any] = None,
max_sequence_length: Optional[int] = None,
memory_scaling_model: Optional["MemoryScalingModel"] = None,
):
model_path = get_model_path(model_name, revision)

@@ -300,27 +299,41 @@ def __init__(

from fms_extras.utils.cache.paged import PagedKVCacheManager

if SPECULATOR_NAME is not None:
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
speculator_revision = os.getenv("SPECULATOR_REVISION", None)
speculator_model_path = get_model_path(SPECULATOR_NAME, speculator_revision)
print_rank_n(f"Loading speculator model from: {speculator_model_path}")
# load speculator
self.speculator = load_speculator(self.device, dtype)

if self.speculator is not None:
print_rank_n(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
kwargs = {
"pretrained_model_name_or_path": speculator_model_path,
"local_files_only": True,
"torch_dtype": dtype,
}
with self.device:
self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
self.speculator.to(device=self.device)
else:
self.speculator = None

block_size = 16

if KV_CACHE_MANAGER_NUM_GPU_BLOCKS is not None:
total_num_gpu_blocks = int(KV_CACHE_MANAGER_NUM_GPU_BLOCKS)
else:
total_num_gpu_blocks = None
# Firstly, let's compute the size of a cache block in bytes
kv_cache_block_size = self.model.get_kv_cache_block_size(block_size)
total_size = model_config.num_hidden_layers * kv_cache_block_size
dtype_size = torch.tensor([], dtype=dtype).element_size()
cache_block_size = dtype_size * total_size
# We then use our memory scaling model to determine the fraction of the prefill memory
# usage that is due to cache blocks (as opposed to the other stuff needed for forward):
pf_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]
# We can then do the same for the next token (decoding) step:
nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
# In general we know that the next token phase can use many more cache blocks
# relative to the prefill phase (i.e., nt_cache_block_ratio > pf_cache_block_ratio).
# Thus, we need to allocate enough cache blocks to handle the more extreme case:
total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
# This creates an issue though: if we then try to perform a large prefill, we will
# certainly have enough cache blocks available, but we may not have enough memory left over
# to allocate the other data structures needed during a forward pass.
# To overcome this, we can increase the batch_safety_margin a bit to ensure that:
# free_memory * (1.0-batch_safety_margin/100-0.05) * (1.0-pf_cache_block_ratio) <
# free_memory * (1.0-nt_cache_block_ratio)
# This should ensure that our prefill batches can never get so big as to cause OOM.
recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
if memory_scaling_model.safety_margin < recommend_safety_margin:
print(f"WARN: We recommend increasing the value of BATCH_SAFETY_MARGIN to: {recommend_safety_margin}")

self.kv_cache_manager = PagedKVCacheManager(
model_config.num_hidden_layers,
@@ -331,8 +344,11 @@ def __init__(
dtype=dtype,
device=self.device,
total_num_gpu_blocks=total_num_gpu_blocks,
block_size=block_size,
)

self.memory_scaling_model = memory_scaling_model

# log number of free blocks at init
print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))

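To make the block-budget arithmetic in __init__ above concrete, here is a worked sketch. The attribute names (free_memory, linear_fit_params, next_token_params) come from the diff; every numeric value is an illustrative assumption.

# Worked sketch of the cache-block budget above; all numbers are illustrative assumptions.
free_memory = 40 * 1024**3               # assumed free GPU memory: 40 GiB
cache_block_size = 8 * 1024**2           # assumed bytes per cache block: 8 MiB (see earlier sketch)
block_size = 16

# Hypothetical per-token memory slopes from the fitted scaling model:
prefill_bytes_per_token = 4 * 1024**2        # stands in for memory_scaling_model.linear_fit_params[0]
next_token_bytes_per_token = 0.75 * 1024**2  # stands in for memory_scaling_model.next_token_params[1]

pf_cache_block_ratio = cache_block_size / block_size / prefill_bytes_per_token      # 0.125
nt_cache_block_ratio = cache_block_size / block_size / next_token_bytes_per_token   # ~0.667

# Budget enough blocks for the decode phase, which is the more cache-heavy case:
total_num_gpu_blocks = int(nt_cache_block_ratio * free_memory // cache_block_size)  # 3413 blocks

# Margin needed so a maximal prefill still leaves room for its non-cache allocations:
recommend_safety_margin = 5 + int(100 * (1.0 - (1.0 - nt_cache_block_ratio) / (1.0 - pf_cache_block_ratio)))
print(total_num_gpu_blocks, recommend_safety_margin)   # 3413, 66 with these assumed numbers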
@@ -413,12 +429,18 @@ def _prefill(
)

t0 = time.time_ns()
output = self.model(
input_ids,
position_ids=position_ids,
cache_data=cache_data,
return_embeds=True,
)
try:
output = self.model(
input_ids,
position_ids=position_ids,
cache_data=cache_data,
return_embeds=True,
)
except:
# If something goes wrong during the forward pass, we still need to set the sequence ids
# TODO: it would be better to fix the forward method to avoid the possibility of partial failures
batch.sequence_ids = cache_data.sequence_ids
raise
t_forward_ns = time.time_ns()-t0
logits, embeds = output

@@ -603,10 +625,7 @@ def generate_token(
)
else:
bsize = batch.input_ids.shape[0]

tokens_remaining = 0
for i in range(len(batch.total_lengths)):
tokens_remaining += batch.total_lengths[i] - batch.input_lengths[i]
weight = sum(batch.total_lengths) * self.memory_scaling_model.next_token_params[1]

spec_ind = []
for i, sample in enumerate(batch.next_token_chooser.do_sample):
@@ -618,7 +637,7 @@
len(spec_ind) > 0 and
bsize <= SPECULATOR_MAX_BATCH_SIZE and
batch.next_token_chooser.repetition_processor is None and
tokens_remaining < 0.25*len(self.kv_cache_manager.free_blocks)*self.kv_cache_manager.block_size
(weight/self.memory_scaling_model.weight_limit) <= 0.75
)

if speculate:
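The generate_token hunk above also replaces the free-block heuristic for gating speculation with a memory-based one: the batch's projected decode-time memory (weight) must stay at or below 75% of the measured weight_limit. A small sketch with assumed numbers follows; the attribute names are from the diff, the values are not.

# Sketch of the new speculation gate; attribute names are from the diff, numbers are assumed.
total_lengths = [256, 512, 1024]                 # assumed per-sequence total lengths in the batch
next_token_bytes_per_token = 0.75 * 1024**2      # stands in for memory_scaling_model.next_token_params[1]
weight_limit = 4 * 1024**3                       # stands in for memory_scaling_model.weight_limit

weight = sum(total_lengths) * next_token_bytes_per_token
allow_speculation = (weight / weight_limit) <= 0.75   # ~0.33 here, so speculation stays enabled

The batch-size, sampling, and repetition-penalty checks from the diff still apply on top of this condition.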
45 changes: 24 additions & 21 deletions server/text_generation_server/server.py
@@ -56,7 +56,7 @@ async def func_with_log(*args, **kwargs):


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModelPB):
def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModel):
self.cache = cache
self.model = model
self.server_urls = server_urls
@@ -81,7 +81,7 @@ async def ModelInfo(self, request: generate_pb2.ModelInfoRequest, context) -> ge
if isinstance(self.model, Seq2SeqLM) else ModelInfoResponse.ModelType.CAUSAL_LM,
eos_token=getattr(self.model.tokenizer, 'model_eos_token_id', self.model.tokenizer.eos_token_id),
batch_padding=not isinstance(self.model, FlashCausalLM),
memory_scaling_model=self.memory_scaling_model,
memory_scaling_model=self.memory_scaling_model.as_pb(),
)

@log_rpc_handler_errors
@@ -244,8 +244,9 @@ def _free_paged_sequences(self, batch: "Batch", completed_ids: Optional[List[int
]
else:
return
self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

if sequence_ids_to_free is not None:
self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

def serve(
model_name: str,
@@ -273,6 +274,22 @@ async def serve_inner(
batch_safety_margin: int,
sharded: bool = False,
):
if quantize not in [None, "gptq", "bitsandbytes"]:
raise ValueError(f"Unrecognized quantization method specified: {quantize}")

if quantize is None and dtype_str == "int8":
print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
quantize = "bitsandbytes"

cuda_available = torch.cuda.is_available()

# Default dtype based on device if not provided
if dtype_str is None:
dtype_str = "float16" if cuda_available else "float32"

if quantize is not None and not cuda_available:
raise ValueError("Quantization requires CUDA")

if ESTIMATE_MEMORY == "auto" and PAGED_ATTENTION:
# fit memory model using flash model in separate process (ensures GPU memory is entirely cleaned up)
from text_generation_server.utils.paged import fit_memory_scaling_model
@@ -286,6 +303,8 @@ async def serve_inner(
proc.start()
memory_scaling_model_ext = q_out.get()
proc.join()
else:
memory_scaling_model_ext = None

unix_socket_template = "unix://{}-{}"
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -296,28 +315,12 @@ async def serve_inner(
]
local_url = server_urls[local_rank]

if quantize not in [None, "gptq", "bitsandbytes"]:
raise ValueError(f"Unrecognized quantization method specified: {quantize}")

# Default dtype based on device if not provided
if dtype_str is None:
dtype_str = "float16" if torch.cuda.is_available() else "float32"

if quantize is None and dtype_str == "int8":
print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
quantize = "bitsandbytes"

cuda_available = torch.cuda.is_available()

if quantize is not None and not cuda_available:
raise ValueError("Quantization requires CUDA")

# Set the fraction of cuda/gpu mem available to this process, then load the model
if cuda_available and cuda_process_memory_fraction < 1:
torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)

model = get_model(
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length, memory_scaling_model_ext,
)

device = model.engine.get_device()
@@ -424,7 +427,7 @@ def estimate_memory():

server = aio.server()
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls, memory_scaling_model.as_pb()), server
TextGenerationService(model, Cache(), server_urls, memory_scaling_model), server
)
# SERVICE_NAMES = (
# generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
27 changes: 27 additions & 0 deletions server/text_generation_server/utils/paged.py
@@ -5,12 +5,37 @@

from fms_extras.models.speculator import flatten_batch, apply_index_map

# HF name or path to speculator model (None means no speculation will be used)
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# speculator revision
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)

# number of candidates during speculation
SPECULATOR_N_CANDIDATES = os.getenv("SPECULATOR_N_CANDIDATES", None)

# number of candidates per head
SPECULATOR_TOP_K_TOKENS_PER_HEAD = os.getenv("SPECULATOR_TOP_K_TOKENS_PER_HEAD", None)

def load_speculator(device, dtype):

if SPECULATOR_NAME is not None:
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
from text_generation_server.utils.hub import get_model_path
from text_generation_server.utils import print_rank_n
speculator_model_path = get_model_path(SPECULATOR_NAME, SPECULATOR_REVISION)
print_rank_n(f"Loading speculator model from: {speculator_model_path}")
kwargs = {
"pretrained_model_name_or_path": speculator_model_path,
"local_files_only": True,
"torch_dtype": dtype,
}
with device:
speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
speculator.to(device=device)
return speculator
else:
return None

def fit_memory_scaling_model(
model_name: str,
@@ -38,6 +63,8 @@ def fit_memory_scaling_model(
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
)

speculator = load_speculator(model.device, model.dtype)

memory_scaling_model = Estimator.build_from_env(
model,
batch_safety_margin,
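A minimal usage sketch of the load_speculator helper factored out above; it assumes SPECULATOR_NAME is set in the environment and a CUDA device is available, neither of which this diff guarantees.

# Sketch: how the shared helper is consumed by both call sites in this commit.
import torch
from text_generation_server.utils.paged import load_speculator

device = torch.device("cuda", 0)                # assumed; callers pass the model's own device
speculator = load_speculator(device, torch.float16)
if speculator is None:
    print("SPECULATOR_NAME not set; speculation disabled")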
