CUDA out of memory when quantizing llama3.1-405b on 80GiBx8 H100 instance #36

Open
sfc-gh-zhwang opened this issue Aug 7, 2024 · 2 comments

Comments

@sfc-gh-zhwang

```
Some parameters are on the meta device device because they were offloaded to the cpu.
Quantizing weights:   0%|          | 0/1771 [00:00<?, ?it/s]
Quantizing weights:   9%|▉         | 160/1771 [00:00<00:01, 1008.71it/s]
Traceback (most recent call last):
  File "/home/corvo/quantize.py", line 25, in <module>
    main()
  File "/home/corvo/quantize.py", line 20, in main
    model.quantize(examples)
  File "/home/corvo/AutoFP8/auto_fp8/modeling.py", line 113, in quantize
    quantize_weights(self.model, self.quantize_config)
  File "/home/corvo/AutoFP8/auto_fp8/quantize.py", line 237, in quantize_weights
    quant_weight, weight_scale = per_tensor_quantize(linear.weight)
  File "/home/corvo/AutoFP8/auto_fp8/quantize.py", line 56, in per_tensor_quantize
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 1 has a total capacity of 79.11 GiB of which 4.88 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 78.50 GiB is allocated by PyTorch, and 1.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```

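The 512.00 MiB request in the error is consistent with the `tensor * scale` temporary for a single square attention projection. As a rough check, assuming the weight being quantized is bf16 and 16384 × 16384 (Llama 3.1 405B's hidden size as listed in its public config; the traceback doesn't say which layer failed):

```python
# Assumption: Llama 3.1 405B hidden size, taken from the public model config.
HIDDEN_SIZE = 16384
BYTES_PER_BF16 = 2


def temp_mib(rows: int, cols: int, bytes_per_elem: int) -> float:
    """Size in MiB of a dense temporary with the given shape and element size."""
    return rows * cols * bytes_per_elem / 2**20


# A hidden x hidden projection weight (e.g. q_proj or o_proj) held in bf16.
print(temp_mib(HIDDEN_SIZE, HIDDEN_SIZE, BYTES_PER_BF16))
```

Since only 1.18 MiB is reserved but unallocated, fragmentation isn't the problem here, so the `expandable_segments:True` hint in the message is unlikely to help: GPU 1 is simply full.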
@sfc-gh-zhwang
Author

Code is below:

```python
import argparse
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig


def main():
    parser = argparse.ArgumentParser(description='Quantize a pre-trained model using AutoFP8')
    parser.add_argument('--model', type=str, required=True, help='Directory of the pre-trained model')
    parser.add_argument('--output', type=str, required=True, help='Directory to save the quantized model')
    args = parser.parse_args()

    # For dynamic activation scales, there is no need for calibration examples
    examples = []

    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="dynamic"
    )

    model = AutoFP8ForCausalLM.from_pretrained(args.model, quantize_config=quantize_config)
    model.quantize(examples)
    model.save_quantized(args.output)


if __name__ == "__main__":
    main()
```

@mgoin
Member

mgoin commented Aug 7, 2024

Hey @sfc-gh-zhwang, AutoFP8 doesn't support quantizing a model that large if the whole model can't fit in the GPU memory of a single node. We recommend using llm-compressor for this; here is an example: https://huggingface.co/neuralmagic/Meta-Llama-3.1-405B-Instruct-FP8-dynamic#creation
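A back-of-the-envelope check of why the unquantized weights alone exceed this node's GPU memory, assuming a nominal 405 × 10⁹ parameters stored at 2 bytes each (bf16) and the 79.11 GiB per-GPU capacity reported in the OOM message. It ignores activations, the CUDA context and quantization scales, so the real shortfall is larger than shown:

```python
PARAMS = 405e9  # nominal parameter count (assumption; the exact count differs slightly)
GIB = 2**30


def size_gib(num_params: float, bytes_per_param: float) -> float:
    """Weight storage in GiB for a given parameter count and element size."""
    return num_params * bytes_per_param / GIB


weights_bf16 = size_gib(PARAMS, 2)  # unquantized checkpoint
weights_fp8 = size_gib(PARAMS, 1)   # FP8 weights only, scales ignored
node_gpu = 8 * 79.11                # 8 GPUs x per-GPU capacity from the error

print(f"bf16 weights: {weights_bf16:.0f} GiB")
print(f"fp8 weights:  {weights_fp8:.0f} GiB")
print(f"8 GPUs:       {node_gpu:.0f} GiB")
print("bf16 fits on GPUs:", weights_bf16 < node_gpu)
```

The FP8 weights would fit, but the bf16 model has to be loaded before it can be quantized, which is where the extra room is needed.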

```python
import torch

from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import (  # noqa
    calculate_offload_device_map,
    custom_offload_device_map,
)

# FP8 scheme: static per-channel weight scales and dynamic per-token
# activation scales for every Linear layer except lm_head.
recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: float
                        strategy: channel
                        dynamic: false
                        symmetric: true
                    input_activations:
                        num_bits: 8
                        type: float
                        strategy: token
                        dynamic: true
                        symmetric: true
                    targets: ["Linear"]
"""

model_stub = "meta-llama/Meta-Llama-3.1-405B-Instruct"
model_name = model_stub.split("/")[-1]

# Spread the model over the 8 GPUs and offload whatever doesn't fit to CPU.
# No memory is reserved for Hessians: this recipe has no Hessian-based (e.g. GPTQ) step.
device_map = calculate_offload_device_map(
    model_stub, reserve_for_hessians=False, num_gpus=8, torch_dtype=torch.float16
)

model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.float16, device_map=device_map
)

output_dir = f"./{model_name}-FP8-dynamic"

# Apply the recipe in one shot (no calibration dataset is passed, since activation
# scales are dynamic) and save the model in compressed form.
oneshot(
    model=model,
    recipe=recipe,
    output_dir=output_dir,
    save_compressed=True,
    tokenizer=AutoTokenizer.from_pretrained(model_stub),
)
```
