diff --git a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py b/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py index 3c84f4b07..2838f9dc0 100644 --- a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py +++ b/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 +import os # DeepSpeed Team -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from torch.utils.data import Subset import re @@ -15,7 +16,9 @@ def __init__(self, output_path, seed, local_rank, dataset_name): self.output_path = output_path self.seed = seed self.local_rank = local_rank - if not dataset_name == 'local/jsonfile': + if os.path.exists(dataset_name): + self.raw_datasets = load_from_disk(dataset_name) + elif not dataset_name == 'local/jsonfile': self.raw_datasets = load_dataset(dataset_name) def get_train_data(self): diff --git a/applications/DeepSpeed-Chat/training/utils/model/reward_model.py b/applications/DeepSpeed-Chat/training/utils/model/reward_model.py index f11d8787a..60d063b18 100644 --- a/applications/DeepSpeed-Chat/training/utils/model/reward_model.py +++ b/applications/DeepSpeed-Chat/training/utils/model/reward_model.py @@ -29,15 +29,15 @@ def __init__(self, self.config.n_embd = self.config.hidden_size if hasattr( self.config, "hidden_size") else self.config.n_embd self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) - self.rwtranrsformer = base_model + self.rwtransformer = base_model self.PAD_ID = tokenizer.pad_token_id self.compute_fp32_loss = compute_fp32_loss def gradient_checkpointing_enable(self): - self.rwtranrsformer.gradient_checkpointing_enable() + self.rwtransformer.gradient_checkpointing_enable() def gradient_checkpointing_disable(self): - self.rwtranrsformer.gradient_checkpointing_disable() + self.rwtransformer.gradient_checkpointing_disable() def forward(self, input_ids=None, @@ -54,7 +54,7 @@ def forward(self, else: kwargs = dict(head_mask=head_mask) - transformer_outputs = self.rwtranrsformer( + transformer_outputs = self.rwtransformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -140,7 +140,7 @@ def forward_value(self, else: kwargs = dict(head_mask=head_mask) - transformer_outputs = self.rwtranrsformer( + transformer_outputs = self.rwtransformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask,