Add multinode notebook for llama (#2652)

* Add multinode notebook for llama
* fix formatting issues

1 parent (516cef4), commit 557069e
Showing 3 changed files with 916 additions and 0 deletions.
...-models/system/finetune/Llama-notebooks/multinode-text-classification/download-dataset.py
(58 additions, 0 deletions)
# import libraries to parse command line arguments, handle paths, and write JSON
import argparse
import json
import os

parser = argparse.ArgumentParser()
# add an argument to specify a dataset name to download
parser.add_argument(
    "--dataset", type=str, default="dair-ai/emotion", help="dataset name"
)
# add an argument to specify a dataset subset name to download
parser.add_argument(
    "--dataset_subset", type=str, default="split", help="dataset subset name"
)
# add an argument to specify the directory to download the dataset to
parser.add_argument(
    "--download_dir",
    type=str,
    default="data",
    help="directory to download the dataset to",
)
args = parser.parse_args()

# create the download directory if it does not exist
if not os.path.exists(args.download_dir):
    os.makedirs(args.download_dir)


# import hugging face datasets library
from datasets import load_dataset, get_dataset_split_names

for split in get_dataset_split_names(args.dataset):
    # load the split of the dataset
    dataset = load_dataset(args.dataset, split=split)
    # save the split of the dataset to the download directory as a json lines file
    dataset.to_json(os.path.join(args.download_dir, f"{split}.jsonl"))
    # print dataset features
    print(dataset.features)

# get label2id and id2label mapping

# get any split of data
split = get_dataset_split_names(args.dataset)[0]
dataset = load_dataset(args.dataset, split=split)

labels = dataset.features["label"].names

id2label = {}
label2id = {}

for i, label in enumerate(labels):
    id2label[i] = label
    label2id[label] = i

label_mapping = {"id2label": id2label, "label2id": label2id}

with open(os.path.join(args.download_dir, "label.json"), "w") as f:
    json.dump(label_mapping, f)
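
For reference, a typical invocation of this script would look something like the following (the dataset name and directory are just the argparse defaults above, not anything this commit mandates):

    python download-dataset.py --dataset dair-ai/emotion --download_dir data

Run this way, it writes one <split>.jsonl file per split into data/ (for dair-ai/emotion, assuming its current split layout: train.jsonl, validation.jsonl, test.jsonl), plus a label.json holding the mapping, roughly {"id2label": {"0": "sadness", "1": "joy", ...}, "label2id": {"sadness": 0, ...}} (note that json.dump serializes the integer keys of id2label as strings).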