
Commit

use online tokenizer to avoid oom
youkaichao committed Jan 1, 2024
1 parent 8e4cdd8 commit 17c8243
Showing 1 changed file with 39 additions and 39 deletions.
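
In short, the commit stops tokenizing every sample while building chosen_dataset, reject_dataset, and prompt_dataset and instead stores the raw strings, passing the tokenizer and max_seq_len into PromptDataset so each sample is tokenized on the fly in __getitem__. That way the full set of padded, fixed-length tensors never sits in memory at once, which is what caused the OOM. Below is a minimal, hypothetical sketch of the same pattern, reduced to the train_phase == 1 case; the class name, model name, and example text are assumptions for illustration only, not the actual DeepSpeed-Chat implementation (the real change is the diff that follows).

# Minimal sketch (hypothetical names) of the online-tokenization pattern,
# reduced to the train_phase == 1 case.
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class LazyPromptDataset(Dataset):
    # Keeps raw strings in memory and tokenizes one sample at a time in
    # __getitem__, instead of pre-building padded tensors for the whole dataset.

    def __init__(self, sentences, tokenizer, max_seq_len):
        self.sentences = sentences      # list[str], cheap to hold in memory
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        # Only this sample is tokenized and padded now; nothing else is materialized.
        tokens = self.tokenizer(self.sentences[idx],
                                max_length=self.max_seq_len,
                                padding="max_length",
                                truncation=True,
                                return_tensors="pt")
        input_ids = tokens["input_ids"].squeeze(0)
        attention_mask = tokens["attention_mask"].squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids
        }


if __name__ == "__main__":
    # Example usage; the model name is only an assumption for illustration.
    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
    ds = LazyPromptDataset(["Human: hi\n\nAssistant: hello"], tok, max_seq_len=32)
    print(ds[0]["input_ids"].shape)  # torch.Size([32])

The trade-off is repeating the per-sample tokenization work on every pass over the data (typically inside DataLoader worker processes) in exchange for keeping only the current batch's tensors in host memory.
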
applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py (78 changes: 39 additions & 39 deletions)
@@ -132,13 +132,15 @@ def get_raw_dataset_split_index(local_rank,
 class PromptDataset(Dataset):
 
     def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
-                 pad_token_id, train_phase) -> None:
+                 pad_token_id, train_phase, tokenizer, max_seq_len) -> None:
         super().__init__()
         self.prompt_dataset = prompt_dataset
         self.chosen_dataset = chosen_dataset
         self.reject_dataset = reject_dataset
         self.pad_token_id = pad_token_id
         self.train_phase = train_phase
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
 
     def __len__(self):
         length = len(self.chosen_dataset)
@@ -148,16 +150,41 @@ def __len__(self):
 
     def __getitem__(self, idx):
         if self.train_phase == 1:
+            sentence = self.chosen_dataset[idx]
+            tokenized_sentence = self.tokenizer(sentence,
+                                                max_length=self.max_seq_len,
+                                                padding="max_length",
+                                                truncation=True,
+                                                return_tensors="pt")
+            tokenized_sentence["input_ids"] = tokenized_sentence["input_ids"].squeeze(0)
+            tokenized_sentence["attention_mask"] = tokenized_sentence["attention_mask"].squeeze(0)
             return {
-                "input_ids": self.chosen_dataset[idx]["input_ids"],
-                "attention_mask": self.chosen_dataset[idx]["attention_mask"],
-                "labels": self.chosen_dataset[idx]["input_ids"]
+                "input_ids": tokenized_sentence["input_ids"],
+                "attention_mask": tokenized_sentence["attention_mask"],
+                "labels": tokenized_sentence["input_ids"]
             }
         elif self.train_phase == 2:
-            return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
-                self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"]
+            chosen_sentence = self.chosen_dataset[idx]
+            reject_sentence = self.reject_dataset[idx]
+            chosen_token = self.tokenizer(chosen_sentence,
+                                          max_length=self.max_seq_len,
+                                          padding="max_length",
+                                          truncation=True,
+                                          return_tensors="pt")
+            reject_token = self.tokenizer(reject_sentence,
+                                          max_length=self.max_seq_len,
+                                          padding="max_length",
+                                          truncation=True,
+                                          return_tensors="pt")
+            return chosen_token["input_ids"], chosen_token["attention_mask"], \
+                reject_token["input_ids"], reject_token["attention_mask"]
         elif self.train_phase == 3:
-            return self.prompt_dataset[idx]["input_ids"],self.prompt_dataset[idx]["attention_mask"], \
+            prompt_sentence = self.prompt_dataset[idx]
+            prompt_token = self.tokenizer(prompt_sentence, return_tensors="pt")
+            for key_word in ["input_ids", "attention_mask"]:
+                prompt_token[key_word] = prompt_token[
+                    key_word].squeeze(0).flip(0)
+            return prompt_token["input_ids"], prompt_token["attention_mask"], \
                 self.pad_token_id
 
 
@@ -173,16 +200,7 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                 tmp_data)  # the accept response
             if chosen_sentence is not None:
                 chosen_sentence += end_of_conversation_token
-                chosen_token = tokenizer(chosen_sentence,
-                                         max_length=max_seq_len,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(
-                    0)
-                chosen_token["attention_mask"] = chosen_token[
-                    "attention_mask"].squeeze(0)
-                chosen_dataset.append(chosen_token)
+                chosen_dataset.append(chosen_sentence)
         print(
             f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
         )
@@ -197,23 +215,8 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
             if chosen_sentence is not None and reject_sentence is not None:
                 chosen_sentence += end_of_conversation_token  # the accept response
                 reject_sentence += end_of_conversation_token
-                chosen_token = tokenizer(chosen_sentence,
-                                         max_length=max_seq_len,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                reject_token = tokenizer(reject_sentence,
-                                         max_length=max_seq_len,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                chosen_token["input_ids"] = chosen_token["input_ids"]
-                chosen_token["attention_mask"] = chosen_token["attention_mask"]
-                chosen_dataset.append(chosen_token)
-
-                reject_token["input_ids"] = reject_token["input_ids"]
-                reject_token["attention_mask"] = reject_token["attention_mask"]
-                reject_dataset.append(reject_token)
+                chosen_dataset.append(chosen_sentence)
+                reject_dataset.append(reject_sentence)
         print(
             f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
         )
@@ -226,17 +229,14 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
             if prompt is not None:
                 prompt_token = tokenizer(prompt, return_tensors="pt")
                 if prompt_token["input_ids"].size()[-1] <= max_seq_len:
-                    for key_word in ["input_ids", "attention_mask"]:
-                        prompt_token[key_word] = prompt_token[
-                            key_word].squeeze(0).flip(0)
-                    prompt_dataset.append(prompt_token)
+                    prompt_dataset.append(prompt)
                 else:
                     filtered += 1
         print(f'Creating dataset {raw_dataset.dataset_name_clean} '
               f'for {train_phase=} size={len(prompt_dataset)} {filtered=}')
 
     return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
-                         tokenizer.pad_token_id, train_phase)
+                         tokenizer.pad_token_id, train_phase, tokenizer, max_seq_len)
 
 
 def create_dataset(local_rank, dataset_name, data_split, output_path,
