Add code for DDP tutorial series [PR 3 / 3] (pytorch#1069)
* Adds files for minGPT training with DDP
* filtered-clone, update script path, update readme
* add refs to karpathy's repo
* add training data
* add AMP training
* delete raw data file, update index.rst
* Update gpt2_train_cfg.yaml
Showing 12 changed files with 40,635 additions and 0 deletions.
@@ -0,0 +1,12 @@
# minGPT-DDP

Code accompanying the tutorial at https://pytorch.org/tutorials/intermediate/ddp_minGPT.html for training a GPT-like model with Distributed Data Parallel (DDP) in PyTorch.

Files marked with an asterisk (*) are adapted from the minGPT repo (https://github.com/karpathy/minGPT).

- [trainer.py](mingpt/trainer.py) includes the Trainer class that runs the distributed training iterations on the model with the provided dataset.
- [model.py *](mingpt/model.py) defines the model architecture.
- [char_dataset.py *](mingpt/char_dataset.py) contains the `Dataset` class for a character-level dataset.
- [gpt2_train_cfg.yaml](mingpt/gpt2_train_cfg.yaml) contains the configurations for data, model, optimizer, and the training run.
- [main.py](mingpt/main.py) is the entry point to the training job. It sets up the DDP process group, reads all the configurations, and runs the training job; a minimal setup sketch follows this list.
- [slurm/](mingpt/slurm) contains files for setting up an AWS cluster and the slurm script to run multi-node training.
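
As a rough illustration of the entry-point flow described above, the sketch below initializes a DDP process group and wraps a stand-in model. It is a minimal sketch, not the tutorial's actual `main.py`: the `ddp_setup` helper name and the toy model are assumptions, and the real job reads its settings from `gpt2_train_cfg.yaml`.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_setup():
    # torchrun supplies RANK, LOCAL_RANK, and WORLD_SIZE via the environment,
    # so init_process_group can use the default env:// rendezvous.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    ddp_setup()
    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.Linear(8, 8).to(local_rank)  # stand-in for the GPT model
    model = DDP(model, device_ids=[local_rank])
    # ... build the dataset, optimizer, and Trainer here, then run training ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launched with, e.g., `torchrun --standalone --nproc_per_node=4 main.py`, each spawned process picks up its rank from the environment and joins the same process group.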
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import Dataset
import fsspec
from dataclasses import dataclass

"""
Adapted from https://github.com/karpathy/minGPT/blob/master/projects/chargpt/chargpt.py
"""

@dataclass
class DataConfig:
    path: str = None
    block_size: int = None
    train_split: float = None
    truncate: float = 1.0

class CharDataset(Dataset):

    def __init__(self, data_cfg: DataConfig):
        # Read the raw text from any fsspec-supported location (local path, S3, etc.),
        # optionally truncating it to a fraction of its full length.
        data = fsspec.open(data_cfg.path).open().read().decode('utf-8')
        data = data[: int(len(data) * data_cfg.truncate)]

        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('Data has %d characters, %d unique.' % (data_size, vocab_size))

        # Character-to-index and index-to-character lookup tables.
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = data_cfg.block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        # x is the input sequence; y is the same sequence shifted left by one,
        # so the model learns to predict the next character at every position
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
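
For context, here is a hedged sketch of how this dataset might be consumed inside the DDP job. The file path, block size, and batch size are made-up placeholders (the real values come from `gpt2_train_cfg.yaml`), and the snippet assumes the process group has already been initialized.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Hypothetical config values for illustration only.
cfg = DataConfig(path='data/input.txt', block_size=128, train_split=0.9, truncate=1.0)
dataset = CharDataset(cfg)

# DistributedSampler partitions the dataset across DDP ranks so each
# process sees a disjoint shard of the samples in every epoch.
loader = DataLoader(dataset, batch_size=64, sampler=DistributedSampler(dataset))

for x, y in loader:
    # x: (batch, block_size) input tokens; y: the same sequences shifted by one.
    break
```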