Skip to content

Commit

Permalink
Handle in HF AutoTokenizer with pad_token=None (NVIDIA#8068)
Browse files Browse the repository at this point in the history
* check if none before encode special token

Signed-off-by: Huiying Li <[email protected]>

* handle when pad_id does not exist for hf Autotokenizer

Signed-off-by: Huiying Li <[email protected]>

* refactor pad_id assignment to use getattr for cleaner code readability

Signed-off-by: Huiying Li <[email protected]>

---------

Signed-off-by: Huiying Li <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
  • Loading branch information
HuiyingLi and yaoyu-33 authored Feb 27, 2024
1 parent 564abb4 commit 60fc43f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
14 changes: 14 additions & 0 deletions nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,20 @@ def vocab(self):

@property
def pad_id(self):
if getattr(self, 'pad_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'pad_token')])[0]

@property
def bos_id(self):
if getattr(self, 'bos_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'bos_token')])[0]

@property
def eos_id(self):
if getattr(self, 'eos_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'eos_token')])[0]

@property
Expand All @@ -235,18 +241,26 @@ def eod(self):

@property
def sep_id(self):
if getattr(self, 'sep_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'sep_token')])[0]

@property
def cls_id(self):
if getattr(self, 'cls_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'cls_token')])[0]

@property
def unk_id(self):
if getattr(self, 'unk_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'unk_token')])[0]

@property
def mask_id(self):
if getattr(self, 'mask_token') is None:
return None
return self.tokens_to_ids([getattr(self, 'mask_token')])[0]

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def main():
output_bin_files[key],
impl=args.dataset_impl,
chunk_size=args.chunk_size,
pad_id=tokenizer.pad_id if hasattr(tokenizer, "pad_id") else 0,
pad_id=tokenizer.pad_id if getattr(tokenizer, "pad_id", None) is not None else 0,
retrieval_db=args.retrieval_db,
vocab_size=tokenizer.vocab_size,
stride=args.chunk_stride_size,
Expand Down

0 comments on commit 60fc43f

Please sign in to comment.