Llama 3.2 Vision - 90B #1880
Conversation
Dr. CI: ✅ No failures as of commit 612dd63 with merge base d3039da. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1880.
```python
from torchtune.modules.tokenizers import parse_hf_tokenizer_json


def llama3_2_vision_transform(
```
no changes here, just put at the top
```diff
-    image_size: int = 560
+    image_size: int = 560,
 ) -> DeepFusionModel:
     """ Llama 3.2 Vision 11B model
```
no changes, just the pre-commit hook
```diff
-def llama3_2_vision_transform(
-    path: str, max_seq_len: int = 8192, image_size: int = 560, special_tokens_path: Optional[str] = None, prompt_template: Optional[_TemplateType] = None
-) -> Llama3VisionTransform:
+def lora_llama3_2_vision_11b(
```
no changes, just reordering of functions. Git thinks I am rewriting it; llama3_2_vision_transform is now at the top.
```python
)


def llama3_2_vision_90b(
```
copied from 11b. updated docstring + a couple of parameters.
11b is this:
```python
decoder = llama3_2_vision_decoder(
    vocab_size=128_256,
    num_layers=32,
    fusion_interval=4,
    num_special_tokens=8,
    num_heads=32,
    num_kv_heads=8,
    embed_dim=4096,
    max_seq_len=131_072,
    encoder_max_seq_len=128_080,  # 20*6404
    rope_base=500000.0,
    intermediate_dim=14336,
)
```
90b is this:
```python
decoder = llama3_2_vision_decoder(
    vocab_size=128_256,
    num_layers=100,
    fusion_interval=4,
    num_special_tokens=8,
    num_heads=64,
    num_kv_heads=8,
    embed_dim=8192,
    max_seq_len=131_072,
    encoder_max_seq_len=128_080,  # 20*6404
    rope_base=500000.0,
    intermediate_dim=28672,
)
```
encoder is the same, except for decoder_embed_dim, which is 8192 instead of 4096.
Values taken from here: https://huggingface.co/meta-llama/Llama-3.2-90B-Vision/blob/main/config.json
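To make the comparison concrete, here is a rough sketch of how the 90B builder assembles the pieces. This is a reconstruction, not the PR's exact code: the function name `llama3_2_vision_90b_sketch` is mine, and the import paths and encoder arguments (everything except `decoder_embed_dim`) are copied from my reading of the 11B builder and may have drifted from the merged version.

```python
# Sketch only: decoder values are the 90B numbers quoted above; encoder arguments
# other than decoder_embed_dim are assumed to match the 11B builder.
from torchtune.models.llama3_2_vision._component_builders import (
    llama3_2_vision_decoder,
    llama3_2_vision_encoder,
)
from torchtune.modules.model_fusion import DeepFusionModel


def llama3_2_vision_90b_sketch(image_size: int = 560) -> DeepFusionModel:
    encoder = llama3_2_vision_encoder(
        patch_size=14,
        num_heads=16,
        clip_embed_dim=1280,
        clip_num_layers=32,
        clip_hidden_states=[3, 7, 15, 23, 30],
        decoder_embed_dim=8192,  # the only encoder change: 8192 instead of 4096
        num_layers_projection=8,
        tile_size=image_size,
        max_num_tiles=4,
        in_channels=3,
    )
    decoder = llama3_2_vision_decoder(
        vocab_size=128_256,
        num_layers=100,
        fusion_interval=4,
        num_special_tokens=8,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        max_seq_len=131_072,
        encoder_max_seq_len=128_080,  # 20 * 6404
        rope_base=500000.0,
        intermediate_dim=28672,
    )
    return DeepFusionModel(decoder=decoder, encoder=encoder)
```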
```diff
 )


-def lora_llama3_2_vision_11b(
+def lora_llama3_2_vision_90b(
```
git is confused :(
lora_llama3_2_vision_11b is still at the top and was not replaced.
this function is a copy of lora_llama3_2_vision_11b
```diff
@@ -623,6 +644,9 @@ def save_checkpoint(
             epoch=epoch,
             intermediate_checkpoint=intermediate_checkpoint,
         )
+        log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+        torch.distributed.barrier()
```
nit: Can we add a comment here for why this is necessary? Context will be lost to the ether soon.
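For posterity, a minimal toy sketch of the pattern this barrier enforces (not the recipe code; `save_on_rank_zero` and its arguments are hypothetical): only rank 0 writes the checkpoint, and the barrier keeps the other ranks from running ahead while the write is in flight.

```python
# Toy illustration, not the torchtune recipe: rank 0 serializes the (already
# gathered) state dict, and the barrier stops other ranks from moving on --
# e.g. starting the next epoch or tearing down the process group -- while the
# save is still in progress. Per the author, skipping this can hang 70B+ runs.
import torch
import torch.distributed as dist


def save_on_rank_zero(state_dict: dict, path: str) -> None:
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
    # Every rank waits here until all ranks (including rank 0, after its save)
    # have arrived.
    dist.barrier()
```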
```python
        self._checkpointer.save_checkpoint(
            checkpoint_dict,
            epoch=epoch,
            intermediate_checkpoint=intermediate_checkpoint,
            adapter_only=self._save_adapter_weights_only,
        )
        log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")

        torch.distributed.barrier()
```
Same here
```diff
@@ -446,6 +448,8 @@ def get_full_optimizer_state_dict(
     for group_id, sharded_group in sharded_state.items():
         group_state = {}
         for attr, sharded_tensor in sharded_group.items():
+            # without this, it may hang forever for +70B models.
+            torch.distributed.barrier()
```
If you have this here, why is it needed in the finetuning script?
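For reference on the barrier discussion, here is roughly what that gather loop is doing (reconstructed, not the exact torchtune source; the function name and the `hasattr` check are mine): `DTensor.full_tensor()` is itself a collective all-gather, so every rank has to reach it in the same order, and the added barrier keeps ranks in lockstep before each gather.

```python
# Reconstruction of the gather loop, not the exact torchtune code. full_tensor()
# all-gathers a DTensor's shards into a plain tensor and is a collective, so all
# ranks must call it in the same order; the barrier (per the inline comment in
# the diff above) avoids hangs observed with 70B+ models.
import torch.distributed as dist


def gather_group_state(sharded_group: dict) -> dict:
    group_state = {}
    for attr, sharded_tensor in sharded_group.items():
        # Keep ranks in lockstep before the collective below.
        dist.barrier()
        if hasattr(sharded_tensor, "full_tensor"):
            group_state[attr] = sharded_tensor.full_tensor()
        else:
            group_state[attr] = sharded_tensor
    return group_state
```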
```diff
@@ -238,7 +238,8 @@ def optim_step(param) -> None:
         optim_dict[param].zero_grad()

     for p in model.parameters():
-        p.register_post_accumulate_grad_hook(optim_step)
+        if p.requires_grad:
+            p.register_post_accumulate_grad_hook(optim_step)
```
🫡
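For anyone reading this later, a self-contained sketch of the fused optimizer-in-backward pattern this diff guards (toy model and optimizer, not the recipe code): each trainable parameter gets its own optimizer that is stepped from a post-accumulate-grad hook, and the new `requires_grad` check skips frozen parameters, which in this sketch never accumulate grads and have no entry in `optim_dict`.

```python
# Toy version of the optimizer-in-backward pattern (not the recipe code).
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
model.bias.requires_grad_(False)  # pretend part of the model is frozen

# One optimizer per trainable parameter, stepped as soon as its grad is ready.
optim_dict = {
    p: torch.optim.SGD([p], lr=0.01) for p in model.parameters() if p.requires_grad
}


def optim_step(param) -> None:
    optim_dict[param].step()
    optim_dict[param].zero_grad()


for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optim_step)

# The hooks fire during backward as each gradient finishes accumulating.
loss = model(torch.randn(4, 16)).sum()
loss.backward()
```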
Two main questions from me:
(1) Why do we have a 90B full finetune config for 4 devices? That should OOM, no?
(2) Bumping @joecummings's comments about the use of torch.distributed.barrier(). It's not clear to me which of those we actually need (and why).
Stamping to unblock.

fixed (1) here on github. Will do (2) later. My /fwdproxy_client stopped working :S
Context
What is the purpose of this PR?
This PR adds Llama 3.2 Vision 90B to torchtune. A few issues were found along the way; see the review discussion above.
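For illustration only, here is what using the new builder might look like. The builder name matches the PR, but the import path and the meta-device initialization are my assumptions, mirroring how the 11B builder is exposed and how the distributed recipes initialize models.

```python
# Hypothetical usage sketch: assumes llama3_2_vision_90b is exported from
# torchtune.models.llama3_2_vision like the existing 11B builder. Meta-device
# init avoids materializing ~90B parameters just to inspect the architecture.
import torch
from torchtune.models.llama3_2_vision import llama3_2_vision_90b

with torch.device("meta"):
    model = llama3_2_vision_90b()

print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
```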
Changelog
What are the changes made in this PR?
* llama3_2_vision_90b and lora_llama3_2_vision_90b model builders
* logs added for checkpoint saving time, plus torch.distributed.barrier() calls to avoid hangs with 70B+ models
Test plan