Commit d91085d

Add code for DDP tutorial series [PR 3 / 3] (#1069)
* Adds files for minGPT training with DDP
* filtered-clone, update script path, update readme
* add refs to karpathy's repo
* add training data
* add AMP training
* delete raw data file, update index.rst
* Update gpt2_train_cfg.yaml
1 parent 84b7588 commit d91085d

File tree

12 files changed: +40635 additions, 0 deletions


distributed/minGPT-ddp/README.md

Lines changed: 12 additions & 0 deletions
# minGPT-DDP

Code accompanying the tutorial at https://pytorch.org/tutorials/intermediate/ddp_minGPT.html for training a GPT-like model with Distributed Data Parallel (DDP) in PyTorch.

Files marked with an asterisk (*) are adapted from the minGPT repo (https://github.com/karpathy/minGPT).

- [trainer.py](mingpt/trainer.py) includes the Trainer class that runs the distributed training iterations on the model with the provided dataset.
- [model.py *](mingpt/model.py) defines the model architecture.
- [char_dataset.py *](mingpt/char_dataset.py) contains the `Dataset` class for a character-level dataset.
- [gpt2_train_cfg.yaml](mingpt/gpt2_train_cfg.yaml) contains the configurations for data, model, optimizer, and the training run.
- [main.py](mingpt/main.py) is the entry point to the training job. It sets up the DDP process group, reads all the configurations, and runs the training job; a minimal sketch of that setup follows this list.
- [slurm/](mingpt/slurm) contains files for setting up an AWS cluster and the Slurm script to run multi-node training.
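As context for how these pieces fit together, here is a minimal, hypothetical sketch of the process-group setup an entry point like main.py performs when launched with torchrun. The function and variable names are illustrative, not the tutorial's actual code; only the `torch.distributed` calls are real API.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_setup():
    # torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE for each process,
    # so init_process_group can read everything from the environment.
    dist.init_process_group(backend="nccl")  # assumes one GPU per process
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    ddp_setup()
    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.Linear(8, 8).to(local_rank)  # stand-in for the GPT model
    model = DDP(model, device_ids=[local_rank])
    # ... build the dataset, sampler, and Trainer, then run the epochs ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

A job launched as `torchrun --standalone --nproc_per_node=4 main.py` spawns four processes and sets RANK, LOCAL_RANK, and WORLD_SIZE for each, which is what lets `init_process_group` pick up its configuration from the environment.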
distributed/minGPT-ddp/mingpt/char_dataset.py

Lines changed: 43 additions & 0 deletions
```python
import fsspec
import torch
from dataclasses import dataclass
from torch.utils.data import Dataset

"""
Adapted from https://github.com/karpathy/minGPT/blob/master/projects/chargpt/chargpt.py
"""

@dataclass
class DataConfig:
    path: str = None
    block_size: int = None
    train_split: float = None
    truncate: float = 1.0


class CharDataset(Dataset):

    def __init__(self, data_cfg: DataConfig):
        # read the raw text; fsspec lets the path be local or remote
        data = fsspec.open(data_cfg.path).open().read().decode('utf-8')
        data = data[: int(len(data) * data_cfg.truncate)]

        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('Data has %d characters, %d unique.' % (data_size, vocab_size))

        # character <-> integer lookup tables
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = data_cfg.block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        # every position that still has a full block after it is a sample
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        # x is the input sequence; y is the same sequence shifted left by one
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
```
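For illustration, here is a hypothetical way to wire this dataset into a DDP run; the path, batch size, and block size are made up, and the tutorial's actual wiring lives in main.py and trainer.py.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# assumes the DDP process group has already been initialized (see the
# ddp_setup sketch above) and that 'data/input.txt' is a plain text file
data_cfg = DataConfig(path='data/input.txt', block_size=128, train_split=0.9)
dataset = CharDataset(data_cfg)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,                        # DistributedSampler does the shuffling
    sampler=DistributedSampler(dataset),  # each rank gets a distinct index shard
)

x, y = dataset[0]  # x: 128 input tokens; y: the same tokens shifted by one
```

Because `__len__` counts every valid starting offset, consecutive indices yield overlapping windows of text, the standard character-level language-modeling setup used in minGPT.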
