-
Notifications
You must be signed in to change notification settings - Fork 221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable loading ckpt for t0 finetuning #309
base: main
Are you sure you want to change the base?
Conversation
[0, 0, 0, 1, 0, 0, 0], | ||
[0, 0, 0, 1, 1, 0, 0], | ||
[0, 0, 0, 1, 1, 1, 0], | ||
[0, 0, 0, 0, 0, 0, 0]]]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[0, 0, 0, 0, 0, 0, 0]]]] | |
[0, 0, 0, 0, 0, 0, 1]]]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there is a 1
, because the last row & column is 100% padding
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hum I'm wondering if this doesn't screw something up. Essentially you're going to compute softmax on a row with only zeros ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The last row & last col are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask it doesn't matter I think.
Also it's a row with only -inf, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No you compute softmax, what should be the result of the softmax of a row full of masked out values .... It feels like that would return lots of Nans.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we fill it with -inf
?
And the softmax of a row where all values are the same is just 1/n
, no? Where would it cause NaNs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can try writing a test but I would be pretty sure that the actual results are 0. (with current kernel)
finetune_t0_non_causal_decoder.py
Outdated
loss_mask *= loss_on_targets_only * loss_on_non_pad_only | ||
|
||
attention_mask = get_packed_attention_mask( | ||
# Run non-causal decoder | ||
is_causal=False, | ||
causal_mask=~(causal_mask.bool()), | ||
is_causal=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's rename this file finetune_t0_causal_decoder
then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about just finetune_t0.py
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right but do we hardcode this everytime? I'd rather have this one be the script for causal decoder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added an argument prefixlm
[0, 0, 0, 1, 0, 0, 0], | ||
[0, 0, 0, 1, 1, 0, 0], | ||
[0, 0, 0, 1, 1, 1, 0], | ||
[0, 0, 0, 0, 0, 0, 0]]]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No you compute softmax, what should be the result of the softmax of a row full of masked out values .... It feels like that would return lots of Nans.
* Tmp lossseq * Efficient loss normalization * Reuse variable * Simplify division * Add norm_target_loss arg * Clarify loss on targets & remove kwarg * Loss mask is already float * Move norm to batch pipe * Reshape loss mask * Move view
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work! Some things I think shouldn't be in this PR.
group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true", | ||
help='Some objectives require us to sample loss_mask. This might introduce bias towards ' | ||
'specific positions. This option tries to un-bias the loss by reweighting loss on specific ' | ||
'positions based on how frequently we train on that position.' | ||
'This is mostly used for prefix_lm training') | ||
group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density") | ||
group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length") | ||
group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah actually let's remove that option. I don't think we've trained one successfully. We'll probably do as people have shown that it works but in another PR IMO.
@@ -274,8 +274,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True | |||
load_dir = getattr(args, load_arg) | |||
|
|||
if args.deepspeed: | |||
load_optimizer_states = False if args.no_load_optim else True | |||
loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states) | |||
load_optimizer_states = not args.no_load_optim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just use no_load_optim directly in the method
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True | |||
set_checkpoint_version(state_dict.get('checkpoint_version', 0)) | |||
|
|||
# Set iteration. | |||
if args.finetune or release: | |||
if args.finetune or release or args.reset_progress: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is it that we didn't set finetune to True?
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True | |||
# Check arguments. | |||
assert args.consumed_train_samples == 0 | |||
assert args.consumed_valid_samples == 0 | |||
if 'args' in state_dict: | |||
if 'args' in state_dict and not args.reset_progress: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment? Typically this is only used because the metadata loading mechanism screws with us.
@@ -399,7 +399,7 @@ def _build_index_mappings( | |||
shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy' | |||
|
|||
# Build the indexed mapping if not exist. | |||
if torch.distributed.get_rank() == 0: | |||
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Afaik you added this code; I think it was for running tests or sth
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Arf probably because I wanted to use the data loader only ... Maybe let's remove for now because we should be assuming that torch distributed is always initialized at least in Meg-DS IMO.
assert counts[0].item() == ( | ||
torch.distributed.get_world_size() // | ||
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) | ||
if torch.distributed.is_initialized(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
loss_mask = loss_mask.view(-1) | ||
loss_mask = fast_normalize(loss_mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe reshaping to the orignal structure is better API? It's better to bave the same shapes as label IMO (we still still do flatten everything)
@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self): | |||
last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0]) | |||
|
|||
|
|||
def test_finetune_t0_non_causal_decoder_get_batch_pipe(self): | |||
def test_finetune_t0_get_batch_pipe(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah let's make it so that the script is causal decoder specific. Let's figure out non causal decoder later on.
group.add_argument('--append-bos', action='store_true', | ||
help='Append a bos token to the end of a document.') | ||
group.add_argument('--prepend-space', action='store_true', | ||
help='Prepends a space to the beginning of a document') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a mention in which context it's useful, typically it is when you compute targets.
No description provided.