Add code for DDP tutorial series [PR 1 / 3] #1067
Conversation
Looks good, some minor feedback
- Use argparse where applicable to make it clear to users which knobs they can turn
- Remove code duplication where possible
I really enjoyed reading the examples; the type annotations made them super easy to follow.
@malfet @rohan-varma would it be worth also adding distributed tests to examples? The older examples target an older version of PyTorch, and I'm worried the same will happen to this script.
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
Should these be passed in with argparse?
Not for this tutorial. For single-node training, localhost is effectively the only value MASTER_ADDR takes. I talk about this in the video too.
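For reference, the enclosing setup function presumably looks something like the sketch below, based on the docstring shown in the diff (the port value is illustrative; any free port works):

import os
import torch
from torch.distributed import init_process_group

def ddp_setup(rank: int, world_size: int):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"  # single node: the coordinating process is local
    os.environ["MASTER_PORT"] = "12355"      # illustrative; any free port on the machine
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)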
def main(rank: int, world_size: int, save_every: int, total_epochs: int):
    ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size=32)
Make the batch size an argument to the script.
if __name__ == "__main__":
    import sys
    total_epochs = int(sys.argv[1])
Replace with argparse
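A minimal sketch of that replacement (flag names and defaults here are illustrative, not from the PR):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="simple distributed training job")
    parser.add_argument("total_epochs", type=int, help="Total epochs to train the model")
    parser.add_argument("save_every", type=int, help="How often to save a checkpoint")
    parser.add_argument("--batch_size", default=32, type=int,
                        help="Input batch size on each device (default: 32)")
    args = parser.parse_args()
    # args.total_epochs, args.save_every, and args.batch_size then get threaded
    # through to main / mp.spawn in place of the raw sys.argv values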
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
TIL to(1) is the same as to(torch.device("cuda:1"))
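A quick check of that equivalence (assuming at least two visible CUDA devices):

import torch

x = torch.randn(4)
# An integer device index is shorthand for the corresponding CUDA device
assert x.to(1).device == x.to(torch.device("cuda:1")).device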
@@ -0,0 +1,101 @@
import torch
There seems to be quite a bit of code duplication between this script and the ones before it. Not a dealbreaker for a tutorial per se, but I think it would make it clearer for readers what changes from one script to another: you'd have the base utils and then build out each example on top of the last.
This is by design. The tutorial is structured to show the exact diff when moving from one script to another (see https://github.com/pytorch/tutorials/blob/5fb19241ada89db8ace17faea2371447de28146b/beginner_source/ddp_multigpu.rst#diff-for-single_gpupy-vs-multigpupy for example). The pages on pytorch/tutorials plus the videos that walk through these scripts explain the diff.
@@ -0,0 +1,26 @@
import torch
Add "data" to the name of the script.
LGTM overall, some minor questions.
Also curious whether we want to add bash / SLURM scripts and clear instructions on how to do multi-node launches on AWS clusters.
    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = torch.nn.CrossEntropyLoss()(output, targets)
nit: use F.cross_entropy since this is being used in a functional style
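A standalone check that the two spellings compute the same thing:

import torch
import torch.nn.functional as F

output = torch.randn(4, 10)           # fake logits
targets = torch.randint(0, 10, (4,))  # fake class labels
loss_module = torch.nn.CrossEntropyLoss()(output, targets)  # current style
loss_functional = F.cross_entropy(output, targets)          # suggested style
assert torch.allclose(loss_module, loss_functional)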
    def _save_checkpoint(self, epoch):
        ckp = self.model.module.state_dict()
        torch.save(ckp, "checkpoint.pt")
PATH = "checkpoint.pt"
torch.save(ckp, PATH)
print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")
    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
For multi-node training on N nodes, this would result in N checkpoints being saved; is that the desired behavior?
IIUC, after a restart, if global rank 0 happens to land on a different node, it would not be able to find the checkpoint. That's why I'm conditioning on local_rank == 0 instead of global_rank == 0.
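A sketch of that conditioning for the multinode case, assuming the usual torchrun-provided environment variables; epoch and save_every come from the training loop above, and save_snapshot is a hypothetical stand-in for the Trainer method:

import os

local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
global_rank = int(os.environ["RANK"])       # rank across all nodes
# Saving from local_rank == 0 leaves a copy on every node, so a restart can
# load the snapshot no matter which node global rank 0 lands on:
if local_rank == 0 and epoch % save_every == 0:
    save_snapshot(epoch)  # hypothetical helper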
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
nit: this is not needed; DDP can move the inputs, and the way we do it could potentially achieve some overlap and be more performant.
Actually, I guess that's still needed for targets, since targets are not inputs to the DDP model.
Regarding "DDP can move the inputs": can you share more? I can add this to the tutorial notes.
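If I follow the suggestion, the loop would look roughly like the sketch below. This is a sketch only: ddp_model, train_data, and gpu_id are stand-ins, and it assumes (per the review comment, not verified here) that DDP constructed with device_ids moves forward-pass inputs to the device itself:

import torch.nn.functional as F

for source, targets in train_data:
    targets = targets.to(gpu_id)  # targets never pass through DDP's forward, so move them explicitly
    output = ddp_model(source)    # per the review, DDP can move 'source' to the device itself
    loss = F.cross_entropy(output, targets)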
    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = torch.nn.CrossEntropyLoss()(output, targets)
Same as above: use F.cross_entropy.
        snapshot = {}
        snapshot["MODEL_STATE"] = self.model.module.state_dict()
        snapshot["EPOCHS_RUN"] = epoch
        torch.save(snapshot, "snapshot.pt")
snapshot_path?
Hah, I had this instance var but forgot to use it here.
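The fix would presumably look like this (method name assumed; self.snapshot_path set in __init__):

    def _save_snapshot(self, epoch):
        snapshot = {}
        snapshot["MODEL_STATE"] = self.model.module.state_dict()
        snapshot["EPOCHS_RUN"] = epoch
        torch.save(snapshot, self.snapshot_path)  # instance var instead of the hardcoded "snapshot.pt"
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")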
@@ -0,0 +1 @@
torch>=1.11.0
Since 1.12 is out, should we just use that?
I wouldn't want to limit the audience / force an upgrade just for the tutorial... that being said, I'm probably not using anything specific to 1.11 either.
We can upgrade to the highest version; that's better.
    import sys
    total_epochs = int(sys.argv[1])
    save_every = int(sys.argv[2])
    main(save_every, total_epochs)
Is there an accompanying launcher / SLURM script?
Yep, in PR #1068.
        return self.data[index]


class MyRandomDataset(Dataset):
This class appears unused?
Ah yes, thanks for catching this.
First PR for the DDP tutorial series. This code accompanies the tutorials staged at pytorch/tutorials#2049.
This PR includes code for single-GPU, multi-GPU, and multi-node training. Each training script builds on the previous one, allowing users to identify what changes when moving from one paradigm to another.