Skip to content

Commit

Permalink
fix typo infer_seq_lenght -> infer_seq_length (NVIDIA#9370)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: Marc Romeyn <[email protected]>
Signed-off-by: Boxiang Wang <[email protected]>
  • Loading branch information
2 people authored and BoxiangW committed Jun 5, 2024
1 parent 6cca873 commit 07b8a7b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nemo/lightning/megatron_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def forward(
_forward_step = forward_step or self.forward_step
_loss_reduction = loss_reduction or self.loss_reduction
_micro_batch_size: int = micro_batch_size or self.infer_micro_batch_size(data)
_seq_length: int = seq_length or self.infer_seq_lenght(data)
_seq_length: int = seq_length or self.infer_seq_length(data)
_num_microbatches: int = num_microbatches or self.infer_num_microbatches(data)

pipeline = self.pipeline
Expand Down Expand Up @@ -396,7 +396,7 @@ def infer_micro_batch_size(self, data: Union[DataT, Iterator[DataT], List[Iterat

raise ValueError("Cannot infer `micro_batch_size` from data, please specify it manually")

def infer_seq_lenght(self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]) -> int:
def infer_seq_length(self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]) -> int:
if hasattr(data, "seq_length"):
return data.seq_length
if hasattr(data, "data_config"):
Expand All @@ -406,10 +406,10 @@ def infer_seq_lenght(self, data: Union[DataT, Iterator[DataT], List[Iterator[Dat
# TODO: Check if at least 2 dims
return data.size(1)
elif isinstance(data, dict):
return self.infer_seq_lenght(next(iter(data.values())))
return self.infer_seq_length(next(iter(data.values())))
elif isinstance(data, (list, tuple)) and len(data) > 0:
_tensor: Tensor = data[0]
return self.infer_seq_lenght(_tensor)
return self.infer_seq_length(_tensor)

raise ValueError("Cannot infer `seq_length` from data, please specify it manually")

Expand Down

0 comments on commit 07b8a7b

Please sign in to comment.