Skip to content

Commit 3c8858f

Browse files
add gpu selector flag to example python scripts
1 parent 258fde9 commit 3c8858f

File tree

4 files changed

+17
-6
lines changed

4 files changed

+17
-6
lines changed

app_svc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from pydub import AudioSegment
1212
import argparse
1313
# Load model and configuration
14-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1514

1615
fp16 = False
16+
device = None
1717
def load_models(args):
1818
global sr, hop_length, fp16
1919
fp16 = args.fp16
@@ -433,5 +433,8 @@ def main(args):
433433
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
434434
parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app")
435435
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
436+
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
436437
args = parser.parse_args()
438+
cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda"
439+
device = torch.device(cuda_target if torch.cuda.is_available() else "cpu")
437440
main(args)

app_vc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import argparse
1313

1414
# Load model and configuration
15-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1615
fp16 = False
16+
device = None
1717
def load_models(args):
1818
global sr, hop_length, fp16
1919
fp16 = args.fp16
@@ -386,5 +386,8 @@ def main(args):
386386
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
387387
parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app")
388388
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
389+
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
389390
args = parser.parse_args()
390-
main(args)
391+
cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda"
392+
device = torch.device(cuda_target if torch.cuda.is_available() else "cpu")
393+
main(args)

real-time-gui.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import torch
3131
from modules.commons import str2bool
3232
# Load model and configuration
33-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33+
device = None
3434

3535
flag_vc = False
3636

@@ -328,7 +328,7 @@ def printt(strr, *args):
328328

329329
class Config:
330330
def __init__(self):
331-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
331+
self.device = device
332332

333333

334334
if __name__ == "__main__":
@@ -1137,5 +1137,8 @@ def get_device_channels(self):
11371137
parser.add_argument("--checkpoint-path", type=str, default=None, help="Path to the model checkpoint")
11381138
parser.add_argument("--config-path", type=str, default=None, help="Path to the vocoder checkpoint")
11391139
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
1140+
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
11401141
args = parser.parse_args()
1142+
cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda"
1143+
device = torch.device(cuda_target if torch.cuda.is_available() else "cpu")
11411144
gui = GUI(args)

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919

2020

21-
2221
class Trainer:
2322
def __init__(self,
2423
config_path,
@@ -385,6 +384,7 @@ def main(args):
385384
max_epochs=args.max_epochs,
386385
save_interval=args.save_every,
387386
num_workers=args.num_workers,
387+
device=args.device
388388
)
389389
trainer.train()
390390

@@ -399,5 +399,7 @@ def main(args):
399399
parser.add_argument('--max-epochs', type=int, default=1000)
400400
parser.add_argument('--save-every', type=int, default=500)
401401
parser.add_argument('--num-workers', type=int, default=0)
402+
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
402403
args = parser.parse_args()
404+
args.device = f"cuda:{args.gpu}" if args.gpu else "cuda:0"
403405
main(args)

0 commit comments

Comments
 (0)