Skip to content

Commit ec0b27f

Browse files
committed
enable fp16 for all inference scripts by default
1 parent c83ae7a commit ec0b27f

File tree

7 files changed

+124
-87
lines changed

7 files changed

+124
-87
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ python inference.py --source <source-wav>
4545
--semi-tone-shift 0 # pitch shift in semitones for singing voice conversion
4646
--checkpoint <path-to-checkpoint>
4747
--config <path-to-config>
48+
--fp16 True
4849
```
4950
where:
5051
- `source` is the path to the speech file to convert to reference voice
@@ -58,11 +59,11 @@ where:
5859
- `semi-tone-shift` is the pitch shift in semitones for singing voice conversion, default is 0
5960
- `checkpoint` is the path to the model checkpoint if you have trained or fine-tuned your own model; leave blank to auto-download the default model from Hugging Face (`seed-uvit-whisper-small-wavenet` if `f0-condition` is `False`, otherwise `seed-uvit-whisper-base`)
6061
- `config` is the path to the model config if you have trained or fine-tuned your own model; leave blank to auto-download the default config from Hugging Face
61-
62+
- `fp16` is the flag to use float16 inference, default is True
6263

6364
Voice Conversion Web UI:
6465
```bash
65-
python app_vc.py --checkpoint <path-to-checkpoint> --config <path-to-config>
66+
python app_vc.py --checkpoint <path-to-checkpoint> --config <path-to-config> --fp16 True
6667
```
6768
- `checkpoint` is the path to the model checkpoint if you have trained or fine-tuned your own model; leave blank to auto-download the default model from Hugging Face (`seed-uvit-whisper-small-wavenet`)
6869
- `config` is the path to the model config if you have trained or fine-tuned your own model; leave blank to auto-download the default config from Hugging Face
@@ -71,7 +72,7 @@ Then open the browser and go to `http://localhost:7860/` to use the web interfac
7172

7273
Singing Voice Conversion Web UI:
7374
```bash
74-
python app_svc.py --checkpoint <path-to-checkpoint> --config <path-to-config>
75+
python app_svc.py --checkpoint <path-to-checkpoint> --config <path-to-config> --fp16 True
7576
```
7677
- `checkpoint` is the path to the model checkpoint if you have trained or fine-tuned your own model; leave blank to auto-download the default model from Hugging Face (`seed-uvit-whisper-base`)
7778
- `config` is the path to the model config if you have trained or fine-tuned your own model; leave blank to auto-download the default config from Hugging Face

app.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
269269
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
270270
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
271271
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
272-
# Voice Conversion
273-
vc_target = inference_module.cfm.inference(cat_condition,
274-
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
275-
mel2, style2, None, diffusion_steps,
276-
inference_cfg_rate=inference_cfg_rate)
277-
vc_target = vc_target[:, :, mel2.size(-1):]
278-
vc_wave = bigvgan_fn(vc_target)[0]
272+
with torch.autocast(device_type=device.type, dtype=torch.float16):
273+
# Voice Conversion
274+
vc_target = inference_module.cfm.inference(cat_condition,
275+
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
276+
mel2, style2, None, diffusion_steps,
277+
inference_cfg_rate=inference_cfg_rate)
278+
vc_target = vc_target[:, :, mel2.size(-1):]
279+
vc_wave = bigvgan_fn(vc_target)[0]
279280
if processed_frames == 0:
280281
if is_last_chunk:
281282
output_wave = vc_wave[0].cpu().numpy()

app_svc.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torchaudio
66
import librosa
7-
from modules.commons import build_model, load_checkpoint, recursive_munch
7+
from modules.commons import build_model, load_checkpoint, recursive_munch, str2bool
88
import yaml
99
from hf_utils import load_custom_model_from_hf
1010
import numpy as np
@@ -13,28 +13,21 @@
1313
# Load model and configuration
1414
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1515

16-
# Load additional modules
17-
from modules.campplus.DTDNN import CAMPPlus
18-
19-
campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
20-
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
21-
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
22-
campplus_model.eval()
23-
campplus_model.to(device)
24-
25-
from modules.audio import mel_spectrogram
26-
16+
fp16 = False
2717
def load_models(args):
28-
global sr, hop_length
18+
global sr, hop_length, fp16
19+
fp16 = args.fp16
20+
print(f"Using device: {device}")
21+
print(f"Using fp16: {fp16}")
2922
# f0 conditioned model
3023
if args.checkpoint_path is None or args.checkpoint_path == "":
3124
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
3225
"DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema_v2.pth",
3326
"config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
3427
else:
28+
print(f"Using custom checkpoint: {args.checkpoint_path}")
3529
dit_checkpoint_path = args.checkpoint_path
3630
dit_config_path = args.config_path
37-
3831
config = yaml.safe_load(open(dit_config_path, "r"))
3932
model_params = recursive_munch(config["model_params"])
4033
model_params.dit_type = 'DiT'
@@ -336,13 +329,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
336329
chunk_f0 = interpolated_shifted_f0_alt[:, processed_frames:processed_frames + max_source_window]
337330
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
338331
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
339-
# Voice Conversion
340-
vc_target = inference_module.cfm.inference(cat_condition,
341-
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
342-
mel2, style2, None, diffusion_steps,
343-
inference_cfg_rate=inference_cfg_rate)
344-
vc_target = vc_target[:, :, mel2.size(-1):]
345-
vc_wave = vocoder_fn(vc_target).squeeze().cpu()
332+
with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
333+
# Voice Conversion
334+
vc_target = inference_module.cfm.inference(cat_condition,
335+
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
336+
mel2, style2, None, diffusion_steps,
337+
inference_cfg_rate=inference_cfg_rate)
338+
vc_target = vc_target[:, :, mel2.size(-1):]
339+
vc_wave = vocoder_fn(vc_target).squeeze().cpu()
346340
if vc_wave.ndim == 1:
347341
vc_wave = vc_wave.unsqueeze(0)
348342
if processed_frames == 0:
@@ -437,6 +431,7 @@ def main(args):
437431
parser = argparse.ArgumentParser()
438432
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
439433
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
440-
parser.add_argument("--share", type=bool, help="Whether to share url link", default=False)
434+
parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app")
435+
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
441436
args = parser.parse_args()
442437
main(args)

app_vc.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torchaudio
66
import librosa
7-
from modules.commons import build_model, load_checkpoint, recursive_munch
7+
from modules.commons import build_model, load_checkpoint, recursive_munch, str2bool
88
import yaml
99
from hf_utils import load_custom_model_from_hf
1010
import numpy as np
@@ -13,9 +13,12 @@
1313

1414
# Load model and configuration
1515
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16-
16+
fp16 = False
1717
def load_models(args):
18-
global sr, hop_length
18+
global sr, hop_length, fp16
19+
fp16 = args.fp16
20+
print(f"Using device: {device}")
21+
print(f"Using fp16: {fp16}")
1922
if args.checkpoint_path is None or args.checkpoint_path == "":
2023
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
2124
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
@@ -285,13 +288,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
285288
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
286289
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
287290
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
288-
# Voice Conversion
289-
vc_target = inference_module.cfm.inference(cat_condition,
290-
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
291-
mel2, style2, None, diffusion_steps,
292-
inference_cfg_rate=inference_cfg_rate)
293-
vc_target = vc_target[:, :, mel2.size(-1):]
294-
vc_wave = vocoder_fn(vc_target)[0]
291+
with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
292+
# Voice Conversion
293+
vc_target = inference_module.cfm.inference(cat_condition,
294+
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
295+
mel2, style2, None, diffusion_steps,
296+
inference_cfg_rate=inference_cfg_rate)
297+
vc_target = vc_target[:, :, mel2.size(-1):]
298+
vc_wave = vocoder_fn(vc_target)[0]
295299
if vc_wave.ndim == 1:
296300
vc_wave = vc_wave.unsqueeze(0)
297301
if processed_frames == 0:
@@ -380,6 +384,7 @@ def main(args):
380384
parser = argparse.ArgumentParser()
381385
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
382386
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
383-
parser.add_argument("--share", type=bool, help="Whether to share url link", default=False)
387+
parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app")
388+
parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True)
384389
args = parser.parse_args()
385390
main(args)

inference.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616

1717
import torchaudio
1818
import librosa
19-
import torchaudio.compliance.kaldi as kaldi
19+
from modules.commons import str2bool
2020

2121
from hf_utils import load_custom_model_from_hf
2222

2323

2424
# Load model and configuration
2525
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26-
26+
fp16 = False
2727
def load_models(args):
28+
global fp16
29+
fp16 = args.fp16
2830
if not args.f0_condition:
2931
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
3032
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
@@ -304,17 +306,17 @@ def main(args):
304306
cat_condition = torch.cat([prompt_condition, cond], dim=1)
305307

306308
time_vc_start = time.time()
307-
vc_target = model.cfm.inference(
308-
cat_condition,
309-
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
310-
mel2, style2, None, diffusion_steps,
311-
inference_cfg_rate=inference_cfg_rate)
312-
vc_target = vc_target[:, :, mel2.size(-1):]
313-
314-
315-
# Convert to waveform
316-
vc_wave = vocoder_fn(vc_target).squeeze() # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
317-
vc_wave = vc_wave[None, :]
309+
with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
310+
vc_target = model.cfm.inference(
311+
cat_condition,
312+
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
313+
mel2, style2, None, diffusion_steps,
314+
inference_cfg_rate=inference_cfg_rate)
315+
vc_target = vc_target[:, :, mel2.size(-1):]
316+
317+
# Convert to waveform
318+
vc_wave = vocoder_fn(vc_target).squeeze() # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
319+
vc_wave = vc_wave[None, :].float()
318320
time_vc_end = time.time()
319321
print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")
320322

@@ -332,10 +334,11 @@ def main(args):
332334
parser.add_argument("--diffusion-steps", type=int, default=30)
333335
parser.add_argument("--length-adjust", type=float, default=1.0)
334336
parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
335-
parser.add_argument("--f0-condition", type=bool, default=False)
336-
parser.add_argument("--auto-f0-adjust", type=bool, default=True)
337+
parser.add_argument("--f0-condition", type=str2bool, default=True)
338+
parser.add_argument("--auto-f0-adjust", type=str2bool, default=True)
337339
parser.add_argument("--semi-tone-shift", type=int, default=0)
338340
parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
339341
parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
342+
parser.add_argument("--fp16", type=str2bool, default=True)
340343
args = parser.parse_args()
341344
main(args)

modules/commons.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@
55
from torch.nn import functional as F
66
from munch import Munch
77
import json
8-
8+
import argparse
9+
10+
def str2bool(v):
    """Convert a command-line style value to a bool.

    Written to be passed as the ``type=`` callable of an ``argparse``
    argument, so flags can be given as ``--flag true`` / ``--flag 0``.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string. Strings are
            matched case-insensitively, ignoring surrounding whitespace:
            "yes", "true", "t", "y", "1" mean True;
            "no", "false", "f", "n", "0" mean False.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is neither a bool nor one of
            the recognized strings.
    """
    if isinstance(v, bool):
        return v
    # Anything that is not a string (None, numbers, ...) has no .lower();
    # report it as a normal argument error instead of an AttributeError.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    # Normalize once: tolerate stray whitespace and any letter case.
    token = v.strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
919

1020
class AttrDict(dict):
1121
def __init__(self, *args, **kwargs):

0 commit comments

Comments
 (0)