diff --git a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py b/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py index 3c84f4b07..2838f9dc0 100644 --- a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py +++ b/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 +import os # DeepSpeed Team -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from torch.utils.data import Subset import re @@ -15,7 +16,9 @@ def __init__(self, output_path, seed, local_rank, dataset_name): self.output_path = output_path self.seed = seed self.local_rank = local_rank - if not dataset_name == 'local/jsonfile': + if os.path.exists(dataset_name): + self.raw_datasets = load_from_disk(dataset_name) + elif not dataset_name == 'local/jsonfile': self.raw_datasets = load_dataset(dataset_name) def get_train_data(self):