Skip to content

Add --device argument to run examples on a specific device #1288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
add --device for vae
  • Loading branch information
shink committed Sep 23, 2024
commit aeacb0e2654c82e441d7f9b0c053cfc36772b0c4
15 changes: 8 additions & 7 deletions vae/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ The main.py script accepts the following arguments:

```bash
optional arguments:
--batch-size input batch size for training (default: 128)
--epochs number of epochs to train (default: 10)
--no-cuda enables CUDA training
--mps enables GPU on macOS
--seed random seed (default: 1)
--log-interval how many batches to wait before logging training status
```
--batch-size N input batch size for training (default: 128)
--epochs EPOCHS number of epochs to train (default: 10)
--no-cuda disables CUDA training
--no-mps disables macOS GPU training
--device DEVICE backend name
--seed SEED random seed (default: 1)
--log-interval N how many batches to wait before logging training status
```
4 changes: 3 additions & 1 deletion vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
help='disables macOS GPU training')
parser.add_argument('--device', type=str, default='cpu',
help='backend device')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
Expand All @@ -32,7 +34,7 @@
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")
device = torch.device(args.device)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
Expand Down