diff --git a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
index 0b9d6df64..6c2ac9094 100644
--- a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
+++ b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
@@ -132,13 +132,15 @@ def get_raw_dataset_split_index(local_rank,
 class PromptDataset(Dataset):
 
     def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
-                 pad_token_id, train_phase) -> None:
+                 pad_token_id, train_phase, tokenizer, max_seq_len) -> None:
         super().__init__()
         self.prompt_dataset = prompt_dataset
         self.chosen_dataset = chosen_dataset
         self.reject_dataset = reject_dataset
         self.pad_token_id = pad_token_id
         self.train_phase = train_phase
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
 
     def __len__(self):
         length = len(self.chosen_dataset)
@@ -148,16 +150,41 @@ def __len__(self):
 
     def __getitem__(self, idx):
         if self.train_phase == 1:
+            sentence = self.chosen_dataset[idx]
+            tokenized_sentence = self.tokenizer(sentence,
+                                                max_length=self.max_seq_len,
+                                                padding="max_length",
+                                                truncation=True,
+                                                return_tensors="pt")
+            tokenized_sentence["input_ids"] = tokenized_sentence["input_ids"].squeeze(0)
+            tokenized_sentence["attention_mask"] = tokenized_sentence["attention_mask"].squeeze(0)
             return {
-                "input_ids": self.chosen_dataset[idx]["input_ids"],
-                "attention_mask": self.chosen_dataset[idx]["attention_mask"],
-                "labels": self.chosen_dataset[idx]["input_ids"]
+                "input_ids": tokenized_sentence["input_ids"],
+                "attention_mask": tokenized_sentence["attention_mask"],
+                "labels": tokenized_sentence["input_ids"]
             }
         elif self.train_phase == 2:
-            return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
-                self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"]
+            chosen_sentence = self.chosen_dataset[idx]
+            reject_sentence = self.reject_dataset[idx]
+            chosen_token = self.tokenizer(chosen_sentence,
+                                          max_length=self.max_seq_len,
+                                          padding="max_length",
+                                          truncation=True,
+                                          return_tensors="pt")
+            reject_token = self.tokenizer(reject_sentence,
+                                          max_length=self.max_seq_len,
+                                          padding="max_length",
+                                          truncation=True,
+                                          return_tensors="pt")
+            return chosen_token["input_ids"], chosen_token["attention_mask"], \
+                reject_token["input_ids"], reject_token["attention_mask"]
         elif self.train_phase == 3:
-            return self.prompt_dataset[idx]["input_ids"],self.prompt_dataset[idx]["attention_mask"], \
+            prompt_sentence = self.prompt_dataset[idx]
+            prompt_token = self.tokenizer(prompt_sentence, return_tensors="pt")
+            for key_word in ["input_ids", "attention_mask"]:
+                prompt_token[key_word] = prompt_token[
+                    key_word].squeeze(0).flip(0)
+            return prompt_token["input_ids"], prompt_token["attention_mask"], \
                 self.pad_token_id
 
 
@@ -173,16 +200,7 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                 tmp_data)  # the accept response
             if chosen_sentence is not None:
                 chosen_sentence += end_of_conversation_token
-                chosen_token = tokenizer(chosen_sentence,
-                                         max_length=max_seq_len,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(
-                    0)
-                chosen_token["attention_mask"] = chosen_token[
-                    "attention_mask"].squeeze(0)
-                chosen_dataset.append(chosen_token)
+                chosen_dataset.append(chosen_sentence)
         print(
             f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
         )
@@ -197,23 +215,8 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
             if chosen_sentence is not None and reject_sentence is not None:
                 chosen_sentence += end_of_conversation_token  # the accept response
                 reject_sentence += end_of_conversation_token
-                chosen_token = tokenizer(chosen_sentence,
-                                         max_length=max_seq_len,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                reject_token = tokenizer(reject_sentence,
-                                         max_length=max_seq_len,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                chosen_token["input_ids"] = chosen_token["input_ids"]
-                chosen_token["attention_mask"] = chosen_token["attention_mask"]
-                chosen_dataset.append(chosen_token)
-
-                reject_token["input_ids"] = reject_token["input_ids"]
-                reject_token["attention_mask"] = reject_token["attention_mask"]
-                reject_dataset.append(reject_token)
+                chosen_dataset.append(chosen_sentence)
+                reject_dataset.append(reject_sentence)
         print(
             f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
         )
@@ -226,17 +229,14 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
             if prompt is not None:
                 prompt_token = tokenizer(prompt, return_tensors="pt")
                 if prompt_token["input_ids"].size()[-1] <= max_seq_len:
-                    for key_word in ["input_ids", "attention_mask"]:
-                        prompt_token[key_word] = prompt_token[
-                            key_word].squeeze(0).flip(0)
-                    prompt_dataset.append(prompt_token)
+                    prompt_dataset.append(prompt)
                 else:
                     filtered += 1
         print(f'Creating dataset {raw_dataset.dataset_name_clean} '
               f'for {train_phase=} size={len(prompt_dataset)} {filtered=}')
 
     return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
-                         tokenizer.pad_token_id, train_phase)
+                         tokenizer.pad_token_id, train_phase, tokenizer, max_seq_len)
 
 
 def create_dataset(local_rank, dataset_name, data_split, output_path,
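
For reviewers: a minimal usage sketch (not part of the patch) showing how the reworked PromptDataset now stores raw strings and defers tokenization/padding to __getitem__. The model name, toy sentences, and max_seq_len below are illustrative assumptions, not values taken from the patch.

# Illustrative sketch only: exercising the patched phase-1 (SFT) path.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from dschat.utils.data.data_utils import PromptDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model

# chosen_dataset now holds plain strings; each sample is tokenized,
# padded, and truncated lazily when the DataLoader fetches it.
chosen_sentences = [
    "Human: Hello! Assistant: Hi there.<|endoftext|>",
    "Human: How are you? Assistant: Doing well, thanks.<|endoftext|>",
]
dataset = PromptDataset(prompt_dataset=[],
                        chosen_dataset=chosen_sentences,
                        reject_dataset=[],
                        pad_token_id=tokenizer.pad_token_id,
                        train_phase=1,
                        tokenizer=tokenizer,
                        max_seq_len=512)

batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch["input_ids"].shape)  # torch.Size([2, 512]) -- tokenized on the fly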