Llama 3.2 Vision - 90B #1880

Merged
merged 16 commits into pytorch:main from 90b_llamav on Oct 29, 2024

Conversation

felipemello1 (Contributor) commented on Oct 22, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This PR adds Llama 3.2 Vision 90B to torchtune. A few issues were found:

  1. [MITIGATED - NEEDS FIX] 11B was using the Meta checkpointer. Switching to HF raised an error regarding the PEFT adapter. For now, we just save the adapter in Meta format.
  2. [MITIGATED - NEEDS ASYNC CKPT] Saving a checkpoint was timing out because rank 0 would stop working. Adding breakpoints solved it.
  3. [FIXED] Optimizer-in-backward doesn't work with MM (multimodal). Removed the option from the configs.
  4. [NEEDS FIX] QLoRA errors with nproc=8 but works with nproc=2. Set the config to 2.
  5. [NEEDS FIX] Saving a checkpoint in the recipe takes a long time (200-400 s). We should make it optional and add a save frequency. cc: @joecummings
Repro for issue 4 (QLoRA with nproc_per_node=8); it fails with the traceback below:

tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_2_vision/90B_qlora  metric_logger=torchtune.training.metric_logging.WandBLogger compile=True log_peak_memory_stats=True enable_activation_checkpointing=True max_steps_per_epoch=25 gradient_accumulation_steps=1 epochs=1 tokenizer.max_seq_len=2048 batch_size=6 num_warmup_steps=0 save_adapter_weights_only=True
[rank4]:     self._model = self._setup_model(
[rank4]:                   ^^^^^^^^^^^^^^^^^^
[rank4]:   File "/data/users/felipemello/torchtune/recipes/lora_finetune_distributed.py", line 456, in _setup_model
[rank4]:     training.shard_model(
[rank4]:   File "/data/users/felipemello/torchtune/torchtune/training/_distributed.py", line 674, in shard_model
[rank4]:     fully_shard(m, **fsdp_kwargs)
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank4]:     updated = func(inp_module, *args, **kwargs)
[rank4]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 132, in fully_shard
[rank4]:     state._fsdp_param_group = FSDPParamGroup(
[rank4]:                               ^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 114, in __init__
[rank4]:     self.fsdp_params = [
[rank4]:                        ^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 115, in <listcomp>
[rank4]:     FSDPParam(
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 231, in __init__
[rank4]:     self._init_sharded_param(param, device)
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]:     return func(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 335, in _init_sharded_param
[rank4]:     chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
[rank4]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_composable/fsdp/_fsdp_common.py", line 95, in _chunk_with_empty
[rank4]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank4]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/dtypes/nf4tensor.py", line 850, in __torch_function__
[rank4]:     return func(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank4]:     return fn(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/dtypes/nf4tensor.py", line 831, in __torch_dispatch__
[rank4]:     return NF4_OPS_TABLE[func](func, args, kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/dtypes/nf4tensor.py", line 195, in nf4_split
[rank4]:     inner_tensor.numel() % num_chunks == 0
[rank4]: AssertionError: quantization_factor.numel() not divisible by 8
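
For context, the failing check in torchao's nf4_split is a divisibility constraint: FSDP2's fully_shard chunks each parameter along dim 0 into world_size pieces, and every inner tensor of an NF4 tensor (including quantization_factor) must split evenly into that many chunks. A toy illustration of the constraint (the element count below is made up, not taken from the real model):

def nf4_inner_tensor_chunkable(inner_numel: int, world_size: int) -> bool:
    # Mirrors the assertion above: inner_tensor.numel() % num_chunks == 0,
    # where num_chunks is the FSDP shard world size (nproc_per_node here).
    return inner_numel % world_size == 0

# A hypothetical inner-tensor size that shards cleanly across 2 ranks but not 8,
# matching the observed "works with nproc=2, fails with nproc=8" behavior.
print(nf4_inner_tensor_chunkable(1_250, 2))  # True
print(nf4_inner_tensor_chunkable(1_250, 8))  # False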

Changelog

What are the changes made in this PR?
  • Add 90B model builders (llama3_2_vision_90b, lora_llama3_2_vision_90b), copied from the 11B builders with updated decoder parameters, plus the corresponding 90B configs.
  • Only register the optimizer-in-backward hook on parameters with requires_grad=True.
  • Add torch.distributed.barrier() calls and timing logs around checkpoint saving to avoid hangs with 70B+ models.

Checkpoint-saving logs from a 90B run:

INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Getting full model state dict took 1.25 secs
INFO:torchtune.utils._logging:Retrieving optimizer state dict...
INFO:torchtune.utils._logging:Getting optimizer state dict took 1.86 secs
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.04 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/adapter_-1.pt
WARNING:torchtune.utils._logging:Saving Llama3.2 Vision adapter weights to PEFT format is not supported, saving to torchtune format instead
WARNING:torchtune.utils._logging:PEFT integration for Llama3.2 Vision is not supported, skipping adapter config save
INFO:torchtune.utils._logging:Recipe checkpoint of size 0.09 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/recipe_state.pt
INFO:torchtune.utils._logging:Saving checkpoint took 0.37 secs
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time, depending on the size of your model. Getting full model state dict...
INFO:torchtune.utils._logging:Getting full model state dict took 142.20 secs
INFO:torchtune.utils._logging:Model checkpoint of size 4.60 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0001_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0002_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0003_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0004_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0005_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0006_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0007_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0008_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0009_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0010_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0011_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0012_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0013_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0014_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0015_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0016_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0017_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0018_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0019_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0020_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0021_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0022_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0023_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0024_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0025_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0026_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0027_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0028_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0029_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0030_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0031_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0032_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0033_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.97 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0034_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0035_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.66 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0036_0.pt
INFO:torchtune.utils._logging:Model checkpoint of size 4.88 GB saved to /tmp/Llama-3.2-90B-Vision-Instruct/hf_model_0037_0.pt
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.
INFO:torchtune.utils._logging:Saving checkpoint took 412.63 secs

Test plan

tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_2_vision/90B_lora  metric_logger=torchtune.training.metric_logging.WandBLogger compile=True log_peak_memory_stats=True enable_activation_checkpointing=True max_steps_per_epoch=25 gradient_accumulation_steps=1 epochs=1 tokenizer.max_seq_len=2048 batch_size=2 num_warmup_steps=0 save_adapter_weights_only=True


pytorch-bot bot commented Oct 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1880

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 612dd63 with merge base d3039da:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label on Oct 22, 2024
from torchtune.modules.tokenizers import parse_hf_tokenizer_json


def llama3_2_vision_transform(
felipemello1 (Contributor Author):

No changes here, just moved to the top.

image_size: int = 560
) -> DeepFusionModel:
""" Llama 3.2 Vision 11B model
image_size: int = 560,
felipemello1 (Contributor Author):

No changes, just the pre-commit hook reformatting.

def llama3_2_vision_transform(
path: str, max_seq_len: int = 8192, image_size: int = 560, special_tokens_path: Optional[str] = None, prompt_template: Optional[_TemplateType] = None
) -> Llama3VisionTransform:
def lora_llama3_2_vision_11b(
felipemello1 (Contributor Author):

No changes, just reordering of functions. Git thinks I am rewriting it; llama3_2_vision_transform is now at the top.

)


def llama3_2_vision_90b(
felipemello1 (Contributor Author):

Copied from 11B; updated the docstring and a couple of parameters.

The 11B decoder is:

decoder = llama3_2_vision_decoder(
    vocab_size=128_256,
    num_layers=32,
    fusion_interval=4,
    num_special_tokens=8,
    num_heads=32,
    num_kv_heads=8,
    embed_dim=4096,
    max_seq_len=131_072,
    encoder_max_seq_len=128_080,  # 20*6404
    rope_base=500000.0,
    intermediate_dim=14336,
)

The 90B decoder is:

decoder = llama3_2_vision_decoder(
    vocab_size=128_256,
    num_layers=100,
    fusion_interval=4,
    num_special_tokens=8,
    num_heads=64,
    num_kv_heads=8,
    embed_dim=8192,
    max_seq_len=131_072,
    encoder_max_seq_len=128_080,  # 20*6404
    rope_base=500000.0,
    intermediate_dim=28672,
)

The encoder is the same, except for decoder_embed_dim, which is 8192 instead of 4096.

Values taken from here: https://huggingface.co/meta-llama/Llama-3.2-90B-Vision/blob/main/config.json
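
As a quick sanity check (a hypothetical snippet, not part of the PR, and assuming the new builder is exported from torchtune.models.llama3_2_vision like the 11B one), the model can be instantiated on the meta device to confirm the parameter count lands near 90B without allocating real memory:

import torch

from torchtune.models.llama3_2_vision import llama3_2_vision_90b

# Meta-device init builds the module structure without allocating weights.
with torch.device("meta"):
    model = llama3_2_vision_90b()

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e9:.1f}B parameters")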

)


def lora_llama3_2_vision_11b(
def lora_llama3_2_vision_90b(
felipemello1 (Contributor Author):

Git is confused :(

lora_llama3_2_vision_11b is still at the top and was not replaced.

This function is a copy of lora_llama3_2_vision_11b.

@felipemello1 felipemello1 marked this pull request as ready for review October 26, 2024 03:27
@felipemello1 felipemello1 changed the title from "[WIP] Llama 3.2 Vision - 90B" to "Llama 3.2 Vision - 90B" on Oct 26, 2024
@@ -623,6 +644,9 @@ def save_checkpoint(
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
)
log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")

torch.distributed.barrier()
Contributor:

nit: Can we add a comment here for why this is necessary? Context will be lost to the ether soon.

self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
adapter_only=self._save_adapter_weights_only,
)
log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")

torch.distributed.barrier()
Contributor:

Same here

@@ -446,6 +448,8 @@ def get_full_optimizer_state_dict(
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
# without this, it may hang forever for +70B models.
torch.distributed.barrier()
Contributor:

If you have this here, why is it needed in the finetuning script?
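
For reference, the recipe-level barrier after save_checkpoint follows a common pattern: rank 0 does the slow write to disk while the other ranks wait, so nobody races ahead into a later collective and times out. A standalone sketch of that pattern (save_then_sync is a hypothetical helper, assuming an already-initialized process group; it is not the recipe code):

import torch
import torch.distributed as dist


def save_then_sync(state_dict: dict, path: str) -> None:
    # Rank 0 does the slow serialization; the barrier keeps the other ranks
    # from moving on until the checkpoint is fully written.
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
    dist.barrier()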

@@ -238,7 +238,8 @@ def optim_step(param) -> None:
            optim_dict[param].zero_grad()

        for p in model.parameters():
-           p.register_post_accumulate_grad_hook(optim_step)
+           if p.requires_grad:
Contributor:

🫡
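
For context on the hunk above: with optimizer-in-backward, each trainable parameter gets its own optimizer that is stepped from a post-accumulate-grad hook, and the new requires_grad guard makes sure hooks are only registered for trainable parameters (under LoRA most of the model is frozen and never receives gradients). A minimal self-contained sketch of the pattern (the toy model and SGD settings are stand-ins, not the recipe's setup):

import torch
from torch import nn

model = nn.Linear(16, 16)
optim_dict = {
    p: torch.optim.SGD([p], lr=1e-2)
    for p in model.parameters()
    if p.requires_grad
}


def optim_step(param: torch.Tensor) -> None:
    # Step and clear this parameter's optimizer as soon as its grad is ready,
    # so the gradient does not have to be kept around after the update.
    optim_dict[param].step()
    optim_dict[param].zero_grad()


for p in model.parameters():
    if p.requires_grad:  # frozen params never get grads, so don't hook them
        p.register_post_accumulate_grad_hook(optim_step)

# A normal backward pass now also applies the optimizer updates.
loss = model(torch.randn(4, 16)).sum()
loss.backward()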

ebsmothers (Contributor) left a comment:

Two main questions from me:

(1) Why do we have 90B full finetune for 4 devices? That should OOM, no?
(2) Bumping @joecummings's comments about use of torch.distributed.barrier(). It's not clear to me which of those we actually need (and why).

Stamping to unblock

felipemello1 (Contributor Author):

Fixed (1) here on GitHub. Will do (2) later. My /fwdproxy_client stopped working :S

@felipemello1 felipemello1 merged commit 1f5e21d into pytorch:main Oct 29, 2024
14 checks passed
@felipemello1 felipemello1 deleted the 90b_llamav branch October 29, 2024 22:17
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)