We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9604029 commit 1fff96aCopy full SHA for 1fff96a
train.py
@@ -430,5 +430,8 @@ def main(args):
430
parser.add_argument('--num-workers', type=int, default=0)
431
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
432
args = parser.parse_args()
433
- args.device = f"cuda:{args.gpu}" if args.gpu else "cuda:0"
+ if torch.backends.mps.is_available():
434
+ args.device = "mps"
435
+ else:
436
+ args.device = f"cuda:{args.gpu}" if args.gpu else "cuda:0"
437
main(args)
0 commit comments