Add code for DDP tutorial series [PR 1 / 3] #1067


Status: Merged (6 commits, Sep 22, 2022)

Conversation

subramen (Contributor):

First PR for the DDP tutorial series. This code accompanies the tutorials staged at pytorch/tutorials#2049.

This PR includes code for single-GPU, multi-GPU, and multi-node training. Each training script builds on the previous one, allowing users to identify what changes when moving from one paradigm to another.

netlify bot commented Sep 21, 2022:

Deploy Preview for pytorch-examples-preview canceled. Latest commit: 730fdbb. Latest deploy log: https://app.netlify.com/sites/pytorch-examples-preview/deploys/632b723cdaf9400008fe8126

subramen requested a review from msaroufim on September 21, 2022 at 13:30
subramen changed the title from "Add code for DDP tutorial series [PR 1/N]" to "Add code for DDP tutorial series [PR 1 / 3]" on Sep 21, 2022
msaroufim (Member) left a comment:

Looks good, some minor feedback:

  • Use argparse where applicable so it's clear to users which knobs they can turn
  • Remove code duplication where possible

I really enjoyed reading the examples; the type annotations made them feel super easy to follow.

@malfet @rohan-varma would it be worth also adding distributed tests to examples? The older examples work on an older version of PyTorch, and I'm worried the same will happen to these scripts.

rank: Unique identifier of each process
world_size: Total number of processes
"""
os.environ["MASTER_ADDR"] = "localhost"
Member:

Should these be passed in with argparse?

subramen (author):

Not for this tutorial. For single-node training, localhost is effectively the only value MASTER_ADDR takes. I talk about this in the video too.
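
For context, a minimal sketch of what a single-node ddp_setup along these lines typically looks like. Only the MASTER_ADDR line appears in the quoted hunk, so the port value, the backend choice, and the set_device call are assumptions:

    import os
    import torch
    from torch.distributed import init_process_group

    def ddp_setup(rank: int, world_size: int) -> None:
        """
        rank: Unique identifier of each process
        world_size: Total number of processes
        """
        os.environ["MASTER_ADDR"] = "localhost"  # single node: the coordinating process runs on this machine
        os.environ["MASTER_PORT"] = "12355"      # any free port works; this value is an assumption
        init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)              # bind this process to its own GPU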

def main(rank: int, world_size: int, save_every: int, total_epochs: int):
ddp_setup(rank, world_size)
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size=32)
Member:

make batch size an argument to script
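
The body of prepare_dataloader is not part of the quoted hunk; a hedged sketch of a DDP-ready version with batch_size exposed as a parameter, so it can be wired to a CLI flag as suggested:

    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.distributed import DistributedSampler

    def prepare_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=batch_size,                # taken from the command line instead of hard-coding 32
            pin_memory=True,
            shuffle=False,                        # the sampler handles shuffling across processes
            sampler=DistributedSampler(dataset),  # each rank sees a distinct shard of the data
        )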


if __name__ == "__main__":
import sys
total_epochs = int(sys.argv[1])
Member:

Replace with argparse
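
A minimal argparse-based entry point along the lines of this suggestion; the flag names, the batch_size knob, and the extended main signature are illustrative assumptions, not necessarily what the PR ended up with:

    import argparse
    import torch
    import torch.multiprocessing as mp

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="simple distributed training job")
        parser.add_argument("total_epochs", type=int, help="total epochs to train the model")
        parser.add_argument("save_every", type=int, help="how often to save a checkpoint")
        parser.add_argument("--batch_size", type=int, default=32, help="input batch size on each device")
        args = parser.parse_args()

        world_size = torch.cuda.device_count()
        # assumes main is extended to accept batch_size as its last parameter
        mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)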

save_every: int,
) -> None:
self.gpu_id = gpu_id
self.model = model.to(gpu_id)
Member:

TIL to(1) is the same as to(torch.device("cuda:1"))
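
A quick illustration of the equivalence noted here (needs at least two visible GPUs to run):

    import torch

    model = torch.nn.Linear(8, 2)
    a = model.to(1)                        # an int is interpreted as a CUDA device index
    b = model.to(torch.device("cuda:1"))   # explicit device object; same placement
    assert next(a.parameters()).device == next(b.parameters()).device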

@@ -0,0 +1,101 @@
import torch
Member:

There seems to be quite a bit of code duplication between this script and the ones before it. Not a dealbreaker for a tutorial per se, but I think factoring out the shared pieces would make it clearer to readers what changes from one script to another: you'd have the base utils and then build each example on top of the last.

subramen (author):

This is by design. The tutorial is structured to show the exact diff when moving from one script to another (see https://github.com/pytorch/tutorials/blob/5fb19241ada89db8ace17faea2371447de28146b/beginner_source/ddp_multigpu.rst#diff-for-single_gpupy-vs-multigpupy for example). The pages on pytorch/tutorials and the videos that walk through these scripts explain the diff.

@@ -0,0 +1,26 @@
import torch
Member:

Add data to name of script

rohan-varma (Member) left a comment:

LGTM overall, some minor questions.

Also curious whether we want to add bash / SLURM scripts and clear instructions on how to launch multi-node training on AWS clusters.
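
For reference, a hedged sketch of the torchrun-style setup a multi-node script would rely on; the launch itself would be something like torchrun --nnodes=2 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=<host>:<port> multinode.py ..., with the actual SLURM/bash scripts landing in a follow-up PR. The function body below is an assumption based on how torchrun populates the environment:

    import os
    import torch
    from torch.distributed import init_process_group

    def ddp_setup() -> None:
        # torchrun sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK
        # on every worker, so no explicit rank/world_size arguments are needed here
        init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))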

def _run_batch(self, source, targets):
self.optimizer.zero_grad()
output = self.model(source)
loss = torch.nn.CrossEntropyLoss()(output, targets)
Member:

nit: use F.cross_entropy since this is being used in a functional style
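
A sketch of the suggested change; the backward/step lines are the natural continuation of the method and are not part of the quoted hunk:

    import torch.nn.functional as F

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)  # same computation as torch.nn.CrossEntropyLoss()(output, targets)
        loss.backward()
        self.optimizer.step()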


def _save_checkpoint(self, epoch):
ckp = self.model.module.state_dict()
torch.save(ckp, "checkpoint.pt")
Member:

PATH = "checkpoint.pt"
torch.save(ckp, PATH)
print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

def train(self, max_epochs: int):
for epoch in range(max_epochs):
self._run_epoch(epoch)
if self.gpu_id == 0 and epoch % self.save_every == 0:
Member:

for multi-node on N nodes, this would result in N checkpoints being saved, is that the desired behavior?

subramen (author):

IIUC, after a restart, if global rank 0 happens to land on a different node, it would not be able to find the checkpoint. That's why I'm conditioning on local_rank == 0 instead of global_rank == 0.
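
A small sketch of the distinction being discussed, assuming a torchrun-style environment; the helper name is illustrative:

    import os

    def should_save_checkpoint(epoch: int, save_every: int) -> bool:
        local_rank = int(os.environ["LOCAL_RANK"])  # rank of this process within its node
        # Saving when local_rank == 0 writes one checkpoint per node, so after a restart
        # the file is present no matter which node ends up hosting global rank 0.
        # Conditioning on the global rank (os.environ["RANK"]) would save on a single node only.
        return local_rank == 0 and epoch % save_every == 0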

b_sz = len(next(iter(self.train_data))[0])
print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
for source, targets in self.train_data:
source = source.to(self.gpu_id)
Member:

nit: this is not needed; DDP can move the inputs, and the way it does so could potentially achieve some overlap and be more performant

Member:

actually, I guess that's needed for targets, since targets are not input into the DDP model

subramen (author):

"DDP can move the inputs"

Can you share more? I can add this to the tutorial notes.

def _run_batch(self, source, targets):
self.optimizer.zero_grad()
output = self.model(source)
loss = torch.nn.CrossEntropyLoss()(output, targets)
Member:

same, F.cross_entropy

snapshot = {}
snapshot["MODEL_STATE"] = self.model.module.state_dict()
snapshot["EPOCHS_RUN"] = epoch
torch.save(snapshot, "snapshot.pt")
Member:

snapshot_path?

subramen (author):

Hah, I had this instance var but forgot to use it here.
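
Roughly the fix being agreed on here; the method name and the print line are assumptions, with snapshot_path being the instance variable mentioned above:

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)  # use the instance var instead of the hard-coded "snapshot.pt"
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")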

@@ -0,0 +1 @@
torch>=1.11.0
Member:

since 1.12 is out, should we just use that?

subramen (author):

I wouldn't want to limit the audience / force an upgrade just for the tutorial... that being said, I'm probably not using anything specific to 1.11 either.

Member:

We can upgrade to the highest version; that's better.

import sys
total_epochs = int(sys.argv[1])
save_every = int(sys.argv[2])
main(save_every, total_epochs)
Member:

is there an accompanying launcher / SLURM script?

subramen (author):

yep, in PR #1068

return self.data[index]


class MyRandomDataset(Dataset):
Member:

this class appears unused?

subramen (author):

Ah yes, thanks for catching this

msaroufim merged commit f45e418 into main on Sep 22, 2022.