Add code for DDP tutorial series [PR 3 / 3] (pytorch#1069)
* Adds files for minGPT training with DDP

* filtered-clone, update script path, update readme

* add refs to karpathy's repo

* add training data

* add AMP training

* delete raw data file, update index.rst

* Update gpt2_train_cfg.yaml
subramen authored Sep 26, 2022
1 parent 84b7588 commit d91085d
Showing 12 changed files with 40,635 additions and 0 deletions.
12 changes: 12 additions & 0 deletions distributed/minGPT-ddp/README.md
@@ -0,0 +1,12 @@
# minGPT-DDP

Code accompanying the tutorial at https://pytorch.org/tutorials/intermediate/ddp_minGPT.html for training a GPT-like model with Distributed Data Parallel (DDP) in PyTorch.

Files marked with an asterisk (*) are adapted from the minGPT repo (https://github.com/karpathy/minGPT).

- [trainer.py](mingpt/trainer.py) includes the Trainer class that runs the distributed training iterations on the model with the provided dataset.
- [model.py *](mingpt/model.py) defines the model architecture.
- [char_dataset.py *](mingpt/char_dataset.py) contains the `Dataset` class for a character-level dataset.
- [gpt2_train_cfg.yaml](mingpt/gpt2_train_cfg.yaml) contains the configurations for data, model, optimizer and training run.
- [main.py](mingpt/main.py) is the entry point to the training job. It sets up the DDP process group, reads all the configurations, and runs the training job.
- [slurm/](mingpt/slurm) contains files for setting up an AWS cluster and the slurm script to run multinode training.
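
As a rough illustration of what "sets up the DDP process group" involves, here is a minimal sketch of the first step: reading the per-worker variables that `torchrun` exports (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`). The helper name `read_dist_env` is hypothetical and is not taken from `main.py`. A real entry point would typically follow this with a call to `torch.distributed.init_process_group`, which is left out here so the snippet runs without PyTorch or GPUs.

```python
import os


def read_dist_env(env=os.environ):
    """Collect the rank variables that torchrun exports for each worker.

    Hypothetical helper for illustration; falls back to single-process
    defaults when the variables are absent.
    """
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # rank within this node, usually the GPU index
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of worker processes
    }


if __name__ == "__main__":
    # Outside of torchrun this prints the single-process defaults.
    print(read_dist_env())
```

The local rank is the value normally used to pick the device for the current process, while the global rank is what decides which worker handles logging or checkpointing.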
43 changes: 43 additions & 0 deletions distributed/minGPT-ddp/mingpt/char_dataset.py
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import Dataset
import fsspec
from dataclasses import dataclass

"""
Adapted from https://github.com/karpathy/minGPT/blob/master/projects/chargpt/chargpt.py
"""

@dataclass
class DataConfig:
    path: str = None
    block_size: int = None
    train_split: float = None
    truncate: float = 1.0

class CharDataset(Dataset):

    def __init__(self, data_cfg: DataConfig): #data_path: str, block_size):
        data = fsspec.open(data_cfg.path).open().read().decode('utf-8')
        data = data[ : int(len(data) * data_cfg.truncate)]

        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('Data has %d characters, %d unique.' % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = data_cfg.block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
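
To make the indexing in `__getitem__` concrete, the sketch below reproduces the same windowing with plain Python lists instead of tensors, so it runs without PyTorch or `fsspec`. The toy corpus, the `block_size` of 4, and the helper names `get_item` and `decode` are assumptions chosen for illustration; they are not part of the commit.

```python
# Toy corpus and block size, chosen only for illustration.
text = "hello world"
block_size = 4

# Same vocabulary construction as CharDataset: sorted unique characters.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def get_item(idx):
    """Return (input ids, target ids) for the window starting at idx."""
    chunk = text[idx:idx + block_size + 1]   # block_size + 1 characters
    ids = [stoi[c] for c in chunk]
    return ids[:-1], ids[1:]                 # targets are inputs shifted by one


def decode(ids):
    """Map a list of integer ids back to a string."""
    return "".join(itos[i] for i in ids)


num_samples = len(text) - block_size         # mirrors CharDataset.__len__
x, y = get_item(0)
print(decode(x), "->", decode(y))            # prints: hell -> ello
```

Each target position holds the character that follows the corresponding input position, which is why every window needs `block_size + 1` characters and why `__len__` subtracts `block_size` from the corpus length.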