Skip to content

Commit 1fff96a

Browse files
committed
added mac support for training
1 parent 9604029 commit 1fff96a

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,5 +430,8 @@ def main(args):
430430
parser.add_argument('--num-workers', type=int, default=0)
431431
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
432432
args = parser.parse_args()
433-
args.device = f"cuda:{args.gpu}" if args.gpu else "cuda:0"
433+
if torch.backends.mps.is_available():
434+
args.device = "mps"
435+
else:
436+
args.device = f"cuda:{args.gpu}" if args.gpu else "cuda:0"
434437
main(args)

0 commit comments

Comments
 (0)