[pull] main from IBM:main #85

Merged (2 commits) on Jun 18, 2024

2 changes: 2 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -35,6 +35,7 @@ def get_model(
dtype_str: str,
quantize: Optional[str],
max_sequence_length: Optional[int],
memory_scaling_model: Optional[int] = None,
) -> Model:
dtype = get_torch_dtype(dtype_str)
model_path = get_model_path(model_name, revision)
@@ -74,6 +75,7 @@ def get_model(
dtype, quantize,
model_config,
max_sequence_length=max_sequence_length,
memory_scaling_model=memory_scaling_model,
)

if FLASH_ATTENTION:
@@ -434,6 +434,9 @@ def __init__(self, config, weights):
weights=weights,
)

def get_kv_cache_block_size(self, block_size: int) -> int:
return block_size * self.model.num_key_value_heads * self.model.head_size * 2

def get_input_embeddings(self) -> nn.Module:
return self.model.embed_tokens

@@ -407,6 +407,9 @@ def __init__(self, config, weights):
config, prefix="transformer.wte", weights=weights
)

def get_kv_cache_block_size(self, block_size: int) -> int:
return block_size * self.transformer.head_size * 2

def get_input_embeddings(self) -> nn.Module:
return self.transformer.wte

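The two hunks above add a `get_kv_cache_block_size` hook to model-specific classes: the first sizes a block as `block_size * num_key_value_heads * head_size * 2`, while the second (presumably a multi-query model with a single key/value head per layer) drops the head-count factor. In both cases the trailing `* 2` accounts for storing keys and values. As a rough sketch of how `paged_causal_lm.py` turns this hook into a per-block byte count, using hypothetical 7B-class dimensions rather than values taken from this PR:

```python
import torch

# Hypothetical model dimensions (illustrative only, not from this PR).
num_hidden_layers = 32
num_key_value_heads = 32
head_size = 128
block_size = 16          # tokens per cache block, as hard-coded in the diff
dtype = torch.float16

# Per-layer elements in one block: tokens * kv-heads * head-dim * 2 (keys + values).
kv_cache_block_size = block_size * num_key_value_heads * head_size * 2

# paged_causal_lm.py then multiplies by the layer count and the dtype element size.
total_size = num_hidden_layers * kv_cache_block_size
dtype_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for float16
cache_block_size = dtype_size * total_size

print(f"bytes per cache block: {cache_block_size}")  # 8,388,608 bytes = 8 MiB here
```

With these made-up dimensions each 16-token block costs about 8 MiB across all layers; that `cache_block_size` is what drives the block-count arithmetic further down in `paged_causal_lm.py`.
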
79 changes: 49 additions & 30 deletions server/text_generation_server/models/paged_causal_lm.py
@@ -18,16 +18,14 @@
from text_generation_server.utils.token_types import TokenInfo, InputTokens
from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, get_input_tokens_info
from text_generation_server.utils.paged import (
load_speculator,
prepare_inputs_without_speculation,
prepare_inputs_with_speculation,
process_outputs_with_speculation,
prepare_inputs_for_prefill
)
from text_generation_server.inference_engine import get_inference_engine_class

# HF name or path to speculator model (None means no speculation will be used)
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# we will only do speculation if the batch size is <= this parameter
SPECULATOR_MAX_BATCH_SIZE = int(os.getenv("SPECULATOR_MAX_BATCH_SIZE", "16"))

@@ -277,6 +275,7 @@ def __init__(
quantize: Optional[str],
model_config: Union[Any] = None,
max_sequence_length: Optional[int] = None,
memory_scaling_model: Optional["MemoryScalingModel"] = None,
):
model_path = get_model_path(model_name, revision)

@@ -300,27 +299,41 @@

from fms_extras.utils.cache.paged import PagedKVCacheManager

if SPECULATOR_NAME is not None:
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
speculator_revision = os.getenv("SPECULATOR_REVISION", None)
speculator_model_path = get_model_path(SPECULATOR_NAME, speculator_revision)
print_rank_n(f"Loading speculator model from: {speculator_model_path}")
# load speculator
self.speculator = load_speculator(self.device, dtype)

if self.speculator is not None:
print_rank_n(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
kwargs = {
"pretrained_model_name_or_path": speculator_model_path,
"local_files_only": True,
"torch_dtype": dtype,
}
with self.device:
self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
self.speculator.to(device=self.device)
else:
self.speculator = None

block_size = 16

if KV_CACHE_MANAGER_NUM_GPU_BLOCKS is not None:
total_num_gpu_blocks = int(KV_CACHE_MANAGER_NUM_GPU_BLOCKS)
else:
total_num_gpu_blocks = None
# Firstly, let's compute the size of a cache block in bytes
kv_cache_block_size = self.model.get_kv_cache_block_size(block_size)
total_size = model_config.num_hidden_layers * kv_cache_block_size
dtype_size = torch.tensor([], dtype=dtype).element_size()
cache_block_size = dtype_size * total_size
# We then use our memory scaling model to determine the fraction of the prefill memory
# usage that is due to cache blocks (as opposed to the other stuff needed for forward):
pf_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]
# We can then do the same for the next token (decoding) step:
nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
# In general we know that the next token phase can use many more cache blocks
# relative to the prefill phase (i.e., nt_cache_block_ratio > pf_cache_block_ratio).
# Thus, we need to allocate enough cache blocks to handle the more extreme case:
total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
# This creates an issue though, because if we then try to perform a large prefill, while we
# will certainly have enough cache blocks available, we may not have enough memory leftover
# to allocate the other data structures needed during a forward pass.
# To overcome this, we can set the batch_safety_margin a bit higher to ensure that:
# free_memory * (1.0-batch_safety_margin/100-0.05) * (1.0-pf_cache_block_ratio) <
# free_memory * (1.0-nt_cache_block_ratio)
# This should ensure that our prefill batches can never get so big as to cause OOM.
recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
if memory_scaling_model.safety_margin < recommend_safety_margin:
print(f"WARN: We recommend increasing the value of BATCH_SAFETY_MARGIN to: {recommend_safety_margin}")

self.kv_cache_manager = PagedKVCacheManager(
model_config.num_hidden_layers,
Expand All @@ -331,8 +344,11 @@ def __init__(
dtype=dtype,
device=self.device,
total_num_gpu_blocks=total_num_gpu_blocks,
block_size=block_size,
)

self.memory_scaling_model = memory_scaling_model

# log number of free blocks at init
print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))

@@ -413,12 +429,18 @@ def _prefill(
)

t0 = time.time_ns()
output = self.model(
input_ids,
position_ids=position_ids,
cache_data=cache_data,
return_embeds=True,
)
try:
output = self.model(
input_ids,
position_ids=position_ids,
cache_data=cache_data,
return_embeds=True,
)
except:
# if something goes wrong during forward, we still need to set the sequence ids
#TODO it would be better to fix the forward method to avoid possibility of partial failures
batch.sequence_ids = cache_data.sequence_ids
raise
t_forward_ns = time.time_ns()-t0
logits, embeds = output

@@ -603,10 +625,7 @@ def generate_token(
)
else:
bsize = batch.input_ids.shape[0]

tokens_remaining = 0
for i in range(len(batch.total_lengths)):
tokens_remaining += batch.total_lengths[i] - batch.input_lengths[i]
weight = sum(batch.total_lengths) * self.memory_scaling_model.next_token_params[1]

spec_ind = []
for i, sample in enumerate(batch.next_token_chooser.do_sample):
@@ -618,7 +637,7 @@
len(spec_ind) > 0 and
bsize <= SPECULATOR_MAX_BATCH_SIZE and
batch.next_token_chooser.repetition_processor is None and
tokens_remaining < 0.25*len(self.kv_cache_manager.free_blocks)*self.kv_cache_manager.block_size
(weight/self.memory_scaling_model.weight_limit) <= 0.75
)

if speculate:
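A side note on the reworked speculation gate above: the batch's decode-phase memory weight now decides whether speculation is worthwhile, replacing the old free-block heuristic. Read in isolation, the visible part of the condition amounts to roughly the predicate below. This is an illustrative restatement only; the hunk hides some surrounding context (for example how `spec_ind` is filled from the `do_sample` flags), so treat the parameter names as assumptions:

```python
from typing import Sequence

def should_speculate(
    total_lengths: Sequence[int],       # batch.total_lengths
    decode_bytes_per_token: float,      # memory_scaling_model.next_token_params[1]
    weight_limit: float,                # memory_scaling_model.weight_limit
    batch_size: int,
    num_spec_eligible: int,             # len(spec_ind) in the diff
    has_repetition_penalty: bool,
    max_batch_size: int = 16,           # SPECULATOR_MAX_BATCH_SIZE default
) -> bool:
    """Rough restatement of the speculate condition in generate_token."""
    weight = sum(total_lengths) * decode_bytes_per_token
    return (
        num_spec_eligible > 0
        and batch_size <= max_batch_size
        and not has_repetition_penalty
        and weight / weight_limit <= 0.75
    )
```
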
45 changes: 24 additions & 21 deletions server/text_generation_server/server.py
@@ -56,7 +56,7 @@ async def func_with_log(*args, **kwargs):


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModelPB):
def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModel):
self.cache = cache
self.model = model
self.server_urls = server_urls
@@ -81,7 +81,7 @@ async def ModelInfo(self, request: generate_pb2.ModelInfoRequest, context) -> ge
if isinstance(self.model, Seq2SeqLM) else ModelInfoResponse.ModelType.CAUSAL_LM,
eos_token=getattr(self.model.tokenizer, 'model_eos_token_id', self.model.tokenizer.eos_token_id),
batch_padding=not isinstance(self.model, FlashCausalLM),
memory_scaling_model=self.memory_scaling_model,
memory_scaling_model=self.memory_scaling_model.as_pb(),
)

@log_rpc_handler_errors
@@ -244,8 +244,9 @@ def _free_paged_sequences(self, batch: "Batch", completed_ids: Optional[List[int
]
else:
return
self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

if sequence_ids_to_free is not None:
self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

def serve(
model_name: str,
@@ -273,6 +274,22 @@ async def serve_inner(
batch_safety_margin: int,
sharded: bool = False,
):
if quantize not in [None, "gptq", "bitsandbytes"]:
raise ValueError(f"Unrecognized quantization method specified: {quantize}")

if quantize is None and dtype_str == "int8":
print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
quantize = "bitsandbytes"

cuda_available = torch.cuda.is_available()

# Default dtype based on device if not provided
if dtype_str is None:
dtype_str = "float16" if cuda_available else "float32"

if quantize is not None and not cuda_available:
raise ValueError("Quantization requires CUDA")

if ESTIMATE_MEMORY == "auto" and PAGED_ATTENTION:
# fit memory model using flash model in separate process (ensures GPU memory is entirely cleaned up)
from text_generation_server.utils.paged import fit_memory_scaling_model
@@ -286,6 +303,8 @@
proc.start()
memory_scaling_model_ext = q_out.get()
proc.join()
else:
memory_scaling_model_ext = None

unix_socket_template = "unix://{}-{}"
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -296,28 +315,12 @@
]
local_url = server_urls[local_rank]

if quantize not in [None, "gptq", "bitsandbytes"]:
raise ValueError(f"Unrecognized quantization method specified: {quantize}")

# Default dtype based on device if not provided
if dtype_str is None:
dtype_str = "float16" if torch.cuda.is_available() else "float32"

if quantize is None and dtype_str == "int8":
print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
quantize = "bitsandbytes"

cuda_available = torch.cuda.is_available()

if quantize is not None and not cuda_available:
raise ValueError("Quantization requires CUDA")

# Set the fraction of cuda/gpu mem available to this process, then load the model
if cuda_available and cuda_process_memory_fraction < 1:
torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)

model = get_model(
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length, memory_scaling_model_ext,
)

device = model.engine.get_device()
@@ -424,7 +427,7 @@ def estimate_memory():

server = aio.server()
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls, memory_scaling_model.as_pb()), server
TextGenerationService(model, Cache(), server_urls, memory_scaling_model), server
)
# SERVICE_NAMES = (
# generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
27 changes: 27 additions & 0 deletions server/text_generation_server/utils/paged.py
@@ -5,12 +5,37 @@

from fms_extras.models.speculator import flatten_batch, apply_index_map

# HF name or path to speculator model (None means no speculation will be used)
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# speculator revision
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)

# number of candidates during speculation
SPECULATOR_N_CANDIDATES = os.getenv("SPECULATOR_N_CANDIDATES", None)

# number of candidates per head
SPECULATOR_TOP_K_TOKENS_PER_HEAD = os.getenv("SPECULATOR_TOP_K_TOKENS_PER_HEAD", None)

def load_speculator(device, dtype):

if SPECULATOR_NAME is not None:
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
from text_generation_server.utils.hub import get_model_path
from text_generation_server.utils import print_rank_n
speculator_model_path = get_model_path(SPECULATOR_NAME, SPECULATOR_REVISION)
print_rank_n(f"Loading speculator model from: {speculator_model_path}")
kwargs = {
"pretrained_model_name_or_path": speculator_model_path,
"local_files_only": True,
"torch_dtype": dtype,
}
with device:
speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
speculator.to(device=device)
return speculator
else:
return None

def fit_memory_scaling_model(
model_name: str,
@@ -38,6 +63,8 @@ def fit_memory_scaling_model(
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
)

speculator = load_speculator(model.device, model.dtype)

memory_scaling_model = Estimator.build_from_env(
model,
batch_safety_margin,
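
With loading factored out into `load_speculator`, the speculator is configured entirely through environment variables that `paged.py` reads at import time. A minimal sketch of exercising the helper directly; the model id and device are placeholders, nothing here is prescribed by the PR:

```python
import os
import torch

# Hypothetical configuration; must be set before the module below is imported,
# because it reads these variables at import time.
os.environ["SPECULATOR_NAME"] = "some-org/some-mlp-speculator"  # placeholder id
os.environ["SPECULATOR_REVISION"] = "main"

from text_generation_server.utils.paged import load_speculator

speculator = load_speculator(torch.device("cuda", 0), torch.float16)
if speculator is None:
    print("speculation disabled (SPECULATOR_NAME not set)")
```

`SPECULATOR_N_CANDIDATES` and `SPECULATOR_TOP_K_TOKENS_PER_HEAD` are defined here as well but are not consumed in the hunks shown above.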