feat: adjust rotary embed and avoid cuda graphs of size 2 and smaller
drbh committed Dec 6, 2024
1 parent 77d1045 commit 60b9c18
Showing 5 changed files with 48 additions and 9 deletions.
@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The correct answer is: blue",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1733445131,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native",
"usage": {
"completion_tokens": 7,
"prompt_tokens": 27,
"total_tokens": 34
}
}
2 changes: 1 addition & 1 deletion integration-tests/models/test_flash_qwen2_vl_warmup.py
@@ -5,7 +5,7 @@
def flash_qwen2_vl_handle(launcher):
with launcher(
"Qwen/Qwen2-VL-2B-Instruct",
max_input_tokens=40,
max_input_length=40,
max_batch_prefill_tokens=50,
max_total_tokens=51,
) as handle:
6 changes: 6 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -29,6 +29,7 @@
BloomForCausalLM,
)
from text_generation_server.models.globals import ATTENTION
import text_generation_server.models.globals as globals
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -1208,6 +1209,11 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == QWEN2_VL:
# TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
logger.warning(
"Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
)
globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
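The new branch above mutates the module-level CUDA_GRAPHS list before the Qwen2-VL model is built, so warmup never captures graphs for batch sizes 1 or 2. A minimal sketch of the filter, assuming TGI's usual default graph sizes of 1, 2, 4, 8, 16 and 32 (the starting list is an assumption, not taken from this diff):

# Sketch of the batch-size filter applied for Qwen2-VL in get_model() above.
# The starting list is an assumed default, not shown in this diff.
cuda_graphs = [1, 2, 4, 8, 16, 32]

# Keep only batch sizes strictly greater than 2, mirroring
# list(filter(lambda x: x > 2, globals.CUDA_GRAPHS)).
cuda_graphs = [bs for bs in cuda_graphs if bs > 2]

print(cuda_graphs)  # [4, 8, 16, 32]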
@@ -138,7 +138,12 @@ def forward(
dim=-1,
)

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
self.rotary_emb(
query,
torch.select(kv, dim=1, index=0),
cos[: query.shape[0], ...],
sin[: query.shape[0], ...],
)

if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
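The rotary change above slices the precomputed cos/sin position tensors down to the number of query tokens in the current forward pass, so their leading dimension matches query before the embedding is applied. A standalone sketch of that shape alignment, using a simplified rotate-half formulation rather than the in-place rotary kernel TGI actually calls; all sizes below are hypothetical:

import torch

# Hypothetical sizes for illustration only.
num_query_tokens, num_heads, head_dim = 3, 4, 8
query = torch.randn(num_query_tokens, num_heads, head_dim)

# cos/sin may be precomputed for more positions than this pass uses.
cos_full = torch.randn(10, 1, head_dim)
sin_full = torch.randn(10, 1, head_dim)

# Truncate to the tokens actually present, mirroring cos[: query.shape[0], ...].
cos = cos_full[: query.shape[0], ...]
sin = sin_full[: query.shape[0], ...]

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Shapes now broadcast: (3, 1, 8) against (3, 4, 8).
query_rot = query * cos + rotate_half(query) * sin
print(query_rot.shape)  # torch.Size([3, 4, 8])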
16 changes: 9 additions & 7 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -56,11 +56,11 @@
MEM_POOL,
ATTENTION,
BLOCK_SIZE,
CUDA_GRAPHS,
REQUEST_LOGPROBS,
TGI_WIGGLE_ROOM,
get_adapter_to_index,
)
import text_generation_server.models.globals as globals
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -1629,8 +1629,8 @@ def warmup(
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS
elif globals.CUDA_GRAPHS is not None:
tuning_sequences = globals.CUDA_GRAPHS
else:
tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

@@ -1669,20 +1669,22 @@ def warmup(
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
)

if CUDA_GRAPHS:
if globals.CUDA_GRAPHS:
try:
log_master(
logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
logger.info,
f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
)
# Warmup cuda graphs
for bs in CUDA_GRAPHS:
for bs in globals.CUDA_GRAPHS:
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
except torch.cuda.OutOfMemoryError:
logger.exception("Decode cuda graph warmup failed")
else:
log_master(
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
logger.info,
f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
)

assert max_input_tokens is not None
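The flash_causal_lm.py changes replace the from-import of CUDA_GRAPHS with attribute access on the globals module. That matters because a from-import copies the binding at import time, so the Qwen2-VL filtering done later in get_model() would be invisible to warmup(); reading globals.CUDA_GRAPHS at call time always sees the current list. A standalone sketch of that Python binding behaviour, using a stand-in module object rather than the real text_generation_server.models.globals:

import types

# Stand-in for text_generation_server.models.globals (hypothetical module).
fake_globals = types.ModuleType("fake_globals")
fake_globals.CUDA_GRAPHS = [1, 2, 4, 8]

# `from fake_globals import CUDA_GRAPHS` would copy the current binding...
snapshot = fake_globals.CUDA_GRAPHS

# ...so a later reassignment (as get_model() now does for Qwen2-VL) is not
# reflected through the copied name.
fake_globals.CUDA_GRAPHS = [bs for bs in fake_globals.CUDA_GRAPHS if bs > 2]

print(snapshot)                  # [1, 2, 4, 8]  -> stale copy
print(fake_globals.CUDA_GRAPHS)  # [4, 8]        -> what warmup() reads via the module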
