llava-next Fp8 (#209)
Signed-off-by: yuanwu <[email protected]>
Co-authored-by: Thanaji Rao Thakkalapelli <[email protected]>
Co-authored-by: regisss <[email protected]>
3 people authored Aug 26, 2024
1 parent 55d60a1 commit 2985503
Showing 3 changed files with 50 additions and 6 deletions.
46 changes: 43 additions & 3 deletions README.md
@@ -28,7 +28,9 @@ limitations under the License.
- [LLama 7b FP8 on 1 Gaudi2 card](#llama-7b-fp8-on-1-gaudi2-card)
- [LLama 70b BF16 on 8 Gaudi2 card](#llama-70b-bf16-on-8-gaudi2-card)
- [LLama 70b FP8 on 8 Gaudi2 card](#llama-70b-fp8-on-8-gaudi2-card)
- [LLava-next 7B BF16 on 1 Gaudi2 card](#llava-next-7b-bf16-on-1-gaudi2-card)
- [Llava-next](#llava-next)
- [llava-v1.6-mistral-7b-hf BF16 on 1 Gaudi2 card](#llava-v16-mistral-7b-hf-bf16-on-1-gaudi2-card)
- [llava-v1.6-mistral-7b-hf FP8 on 1 Gaudi2 card](#llava-v16-mistral-7b-hf-fp8-on-1-gaudi2-card)
- [Environment variables](#environment-variables)
- [Profiler](#profiler)

@@ -264,8 +266,9 @@ docker run -p 8080:80 \
--sharded true \
--num-shard 8
```
### Llava-next
### LLava-next 7B BF16 on 1 Gaudi2 card
#### llava-v1.6-mistral-7b-hf BF16 on 1 Gaudi2 card
An image usually accounts for roughly 2000 input tokens; for example, an image of size 512x512 is represented by 2800 tokens. `max-input-tokens` must therefore be larger than the number of tokens associated with the image, otherwise the image may be truncated. We set `BASE_IMAGE_TOKENS=2048` as the default number of image tokens; this is the minimum value of `max-input-tokens`, and you can override the environment variable `BASE_IMAGE_TOKENS` to change it. The warmup generates graphs with input lengths from `BASE_IMAGE_TOKENS` to `max-input-tokens`. For Llava-next 7B, `max-batch-prefill-tokens` is set to 16384; the prefill batch size follows from `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`.
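With the values used in the command below (`max-input-tokens` = 4096 and `max-batch-prefill-tokens` = 16384), this works out to a prefill batch size of 16384 / 4096 = 4.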
@@ -281,7 +284,44 @@ docker run -p 8080:80 \
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
-e HF_HUB_ENABLE_HF_TRANSFER=1 \
-e HUGGING_FACE_HUB_TOKEN=$hf_token \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
--cap-add=sys_nice \
--ipc=host \
ghcr.io/huggingface/tgi-gaudi:2.0.1 \
--model-id $model \
--max-input-tokens 4096 \
--max-batch-prefill-tokens 16384 \
--max-total-tokens 8192
```
Send a simple request:
```bash
curl -N 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\n\n","parameters":{"max_new_tokens":16, "seed": 42}}' \
-H 'Content-Type: application/json'
```
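The same streaming request can also be sent programmatically. Below is a minimal Python sketch (illustrative only, not part of this commit) that assumes the server started above is listening on 127.0.0.1:8080 and that the `requests` package is installed; the event handling follows TGI's `/generate_stream` server-sent-events format.
```python
# Minimal streaming client for the curl request above (illustrative, not part
# of this commit). Assumes the container started above is serving on port 8080.
import json

import requests

payload = {
    "inputs": (
        "![](https://huggingface.co/datasets/huggingface/documentation-images/"
        "resolve/main/transformers/rabbit.png)What is this a picture of?\n\n"
    ),
    "parameters": {"max_new_tokens": 16, "seed": 42},
}

with requests.post(
    "http://127.0.0.1:8080/generate_stream", json=payload, stream=True
) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        # TGI streams server-sent events; token payloads are prefixed with "data:".
        if line.startswith(b"data:"):
            event = json.loads(line[len(b"data:"):])
            print(event["token"]["text"], end="", flush=True)
print()
```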
Multi-card Llava-next inference is currently not supported.
#### llava-v1.6-mistral-7b-hf FP8 on 1 Gaudi2 card
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
hf_token=YOUR_ACCESS_TOKEN # HF access token
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e HABANA_VISIBLE_DEVICES=all \
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
-e HF_HUB_ENABLE_HF_TRANSFER=1 \
-e HUGGING_FACE_HUB_TOKEN=$hf_token \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e USE_FLASH_ATTENTION=true \
-e FLASH_ATTENTION_RECOMPUTE=true \
--cap-add=sys_nice \
--ipc=host \
ghcr.io/huggingface/tgi-gaudi:2.0.1 \
@@ -86,6 +86,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
):

if token_idx is not None:
@@ -107,6 +108,8 @@ output_hidden_states=output_hidden_states,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=use_flash_attention,
)

logits = outputs[0]
@@ -145,7 +148,7 @@ def prepare_inputs_for_generation(
**kwargs,
)
else:

use_flash_attention = kwargs.get("use_flash_attention", False)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
@@ -166,7 +169,7 @@ batch_size, num_patches, num_channels, height, width = pixel_values.shape
batch_size, num_patches, num_channels, height, width = pixel_values.shape
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
image_features = self.vision_tower(
reshaped_pixel_values, output_hidden_states=True
reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
)

selected_image_feature = image_features.hidden_states[vision_feature_layer]
@@ -279,6 +282,7 @@ def prepare_inputs_for_generation(
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"use_flash_attention": use_flash_attention,
}
)
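The hunks above thread a `use_flash_attention` flag through the Llava-next modeling code: it is read once from the generation kwargs in `prepare_inputs_for_generation`, forwarded to the vision tower, kept in the prepared model inputs, and passed to the language model with `flash_attention_recompute` tied to the same value. The self-contained toy sketch below illustrates only that pattern; the class and stub methods are placeholders, not the real TGI-Gaudi code.
```python
# Toy illustration of the kwarg-threading pattern in the diff above.
# ToyLlavaNext and its stub submodules are placeholders, not the real code.
from typing import Optional


class ToyLlavaNext:
    def vision_tower(self, pixel_values, output_hidden_states=True, use_flash_attention=False):
        # Stub: the real vision tower returns hidden states for the selected layer.
        return {"n_images": len(pixel_values), "use_flash_attention": use_flash_attention}

    def language_model(self, input_ids, use_flash_attention=False, flash_attention_recompute=False):
        # Stub: the real language model returns logits.
        return {
            "n_tokens": len(input_ids),
            "use_flash_attention": use_flash_attention,
            "flash_attention_recompute": flash_attention_recompute,
        }

    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, **kwargs):
        # The flag is read once from the generation kwargs ...
        use_flash_attention = kwargs.get("use_flash_attention", False)
        if pixel_values is not None:
            # ... forwarded to the vision tower when images are present ...
            self.vision_tower(
                pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
            )
        # ... and kept in the model inputs so forward() sees it on every step.
        return {"input_ids": input_ids, "use_flash_attention": use_flash_attention}

    def forward(self, input_ids, use_flash_attention: Optional[bool] = False):
        # flash_attention_recompute is tied to the same flag in this commit.
        return self.language_model(
            input_ids,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=use_flash_attention,
        )


if __name__ == "__main__":
    model = ToyLlavaNext()
    inputs = model.prepare_inputs_for_generation(
        [1, 2, 3], pixel_values=[[0.0]], use_flash_attention=True
    )
    print(model.forward(**inputs))
```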

2 changes: 1 addition & 1 deletion server/text_generation_server/models/vlm_causal_lm.py
@@ -591,7 +591,7 @@ def __init__(
"return_dict": True,
}

if model.config.model_type in ["llama", "mistral"]:
if model.config.model_type in ["llama", "mistral", "llava_next"]:
kwargs["attn_softmax_bf16"] = True
kwargs["trim_logits"] = True

