-
Notifications
You must be signed in to change notification settings - Fork 440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove pad_max_tiles in CLIP #1836
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1836
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 6f555db with merge base 4107cc4 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two comments but otherwise fine.
@@ -32,6 +32,7 @@ tokenizer: | |||
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform | |||
path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model | |||
image_size: 560 | |||
max_seq_len: 8192 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you limiting to 8k?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not for the model but for this dataset. There are a few rows that are very long and cause a huge memory spike, so this limits those rows. Most of the rows are under 8k
@@ -355,6 +359,13 @@ def padded_collate_tiled_images_and_mask( | |||
for sample in batch | |||
for image in sample["encoder_input"]["images"] | |||
) | |||
if pad_max_tiles is not None: | |||
if pad_max_tiles < max_num_tiles: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_num_tiles
is such a misleading name.
Summary: Following changes in torchtune: - pytorch/torchtune#1836 - pytorch/torchtune#1853 Update ET downstream and remove pad-max-tiles from preprocess. Pull Request resolved: #6295 Test Plan: With AOTI tests commented out (not working atm): ``` python -m unittest examples/models/llama3_2_vision/preprocess/test_preprocess.py ... ---------------------------------------------------------------------- Ran 4 tests in 21.129s OK ``` Reviewed By: larryliu0820 Differential Revision: D64481012 Pulled By: lucylq fbshipit-source-id: e822c235c5555e0682d181c4c482dec7c170c96e
Context
What is the purpose of this PR? Is it to
Currently we have pad_max_tiles in the CLIP transform that will pad the image to 4 tiles. The CLIP model doesn't mask out the padding tiles, so by default we always added padding tiles in case the model relied on these extra tokens. Though in testing, I find the model completely ignores the padding tiles and it's not necessary to include them unless needed for the batch. Apart from this, I found that doing the pad_max_tiles in the CLIP transform instead of padded_collate_tiled_images_and_mask leads to a subtle bug where the cross attention mask is not aware, downstream, of which tiles are padding and should be masked and which shouldn't be.
Changelog
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
)pytest tests
pytest tests -m integration_test
This shows three runs for
tune run full_finetune_single_device --config llama3_2_vision/11B_full_single_device
. The "original" run is what's currently in main, the "new_pad_to_4" is with the fixed padding but adding extra padding tiles, and the "new_pad_to_batch" only pads to the max tiles in the batch. All three of these have the same loss, suggesting that the padding isn't important for model quality, but pad_to_batch gets almost double the qps on the ocr dataset since the cross attention sequence lengths can be much shorter.