Skip to content

Commit

Permalink
Browse files: browse the repository at this point in the history
Signed-off-by: zhehuaichen <[email protected]>
  • Loading branch information
zhehuaichen committed Dec 22, 2023
1 parent 63131d0 commit 318f784
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions nemo/collections/multimodal/speechllm/models/speechllm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,10 @@ def loss_func(output_tensor):
loss_for_ub = self.loss_func(scaled_loss_mask, output_tensor)
else:
loss_for_ub = self.loss_func(loss_mask, output_tensor)
- # TODO(zhehuai): undo the following because of this error
- # /lustre/fsw/swdl/swdl-langspeech/zhehuaic/results/audio-text-llm-debug/debug0cff_kw_crossbi_inf_main2_sel_FC-GPT_llama_ast_en_de_ja_ls_lr5e-4wd1e-5_CosineAnnealing_warmup2000_minlr1e-4_gbs256_mbs8_ep200/error-4526374-0.out
- # self.log('raw_lm_loss', loss_for_ub, prog_bar=True, rank_zero_only=True, batch_size=1)
- # for k, v in aux_loss.items():
- # self.log(k, v, prog_bar=True, rank_zero_only=True, batch_size=1)
- # loss_for_ub += v
+ self.log('raw_lm_loss', loss_for_ub, prog_bar=True, rank_zero_only=True, batch_size=1)
+ for k, v in aux_loss.items():
+ self.log(k, v, prog_bar=True, rank_zero_only=True, batch_size=1)
+ loss_for_ub += v
if validation_step and not self.cfg.data.get('validation_drop_last', True):
num_valid_tokens_in_ub = batch['loss_mask'].sum()
if loss_for_ub.isnan():
Expand Down Expand Up @@ -1222,7 +1220,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
if metric_name == 'bleu':
metric_result = torch.Tensor(
[sacrebleu.corpus_bleu(deduplicated_outputs['preds'], [labels]).score]
- )
+ ).cuda()
else:
for pred, label in zip(deduplicated_outputs['preds'], labels):
_ = metric_fn(pred, label)
Expand Down

0 comments on commit 318f784

Please sign in to comment.