|
4 | 4 | import torch |
5 | 5 | import torchaudio |
6 | 6 | import librosa |
7 | | -from modules.commons import build_model, load_checkpoint, recursive_munch |
| 7 | +from modules.commons import build_model, load_checkpoint, recursive_munch, str2bool |
8 | 8 | import yaml |
9 | 9 | from hf_utils import load_custom_model_from_hf |
10 | 10 | import numpy as np |
|
13 | 13 | # Load model and configuration |
14 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
15 | 15 |
|
16 | | -# Load additional modules |
17 | | -from modules.campplus.DTDNN import CAMPPlus |
18 | | - |
19 | | -campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None) |
20 | | -campplus_model = CAMPPlus(feat_dim=80, embedding_size=192) |
21 | | -campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu")) |
22 | | -campplus_model.eval() |
23 | | -campplus_model.to(device) |
24 | | - |
25 | | -from modules.audio import mel_spectrogram |
26 | | - |
| 16 | +fp16 = False |
27 | 17 | def load_models(args): |
28 | | - global sr, hop_length |
| 18 | + global sr, hop_length, fp16 |
| 19 | + fp16 = args.fp16 |
| 20 | + print(f"Using device: {device}") |
| 21 | + print(f"Using fp16: {fp16}") |
29 | 22 | # f0 conditioned model |
30 | 23 | if args.checkpoint_path is None or args.checkpoint_path == "": |
31 | 24 | dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC", |
32 | 25 | "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema_v2.pth", |
33 | 26 | "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml") |
34 | 27 | else: |
| 28 | + print(f"Using custom checkpoint: {args.checkpoint_path}") |
35 | 29 | dit_checkpoint_path = args.checkpoint_path |
36 | 30 | dit_config_path = args.config_path |
37 | | - |
38 | 31 | config = yaml.safe_load(open(dit_config_path, "r")) |
39 | 32 | model_params = recursive_munch(config["model_params"]) |
40 | 33 | model_params.dit_type = 'DiT' |
@@ -336,13 +329,14 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c |
336 | 329 | chunk_f0 = interpolated_shifted_f0_alt[:, processed_frames:processed_frames + max_source_window] |
337 | 330 | is_last_chunk = processed_frames + max_source_window >= cond.size(1) |
338 | 331 | cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1) |
339 | | - # Voice Conversion |
340 | | - vc_target = inference_module.cfm.inference(cat_condition, |
341 | | - torch.LongTensor([cat_condition.size(1)]).to(mel2.device), |
342 | | - mel2, style2, None, diffusion_steps, |
343 | | - inference_cfg_rate=inference_cfg_rate) |
344 | | - vc_target = vc_target[:, :, mel2.size(-1):] |
345 | | - vc_wave = vocoder_fn(vc_target).squeeze().cpu() |
| 332 | + with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32): |
| 333 | + # Voice Conversion |
| 334 | + vc_target = inference_module.cfm.inference(cat_condition, |
| 335 | + torch.LongTensor([cat_condition.size(1)]).to(mel2.device), |
| 336 | + mel2, style2, None, diffusion_steps, |
| 337 | + inference_cfg_rate=inference_cfg_rate) |
| 338 | + vc_target = vc_target[:, :, mel2.size(-1):] |
| 339 | + vc_wave = vocoder_fn(vc_target).squeeze().cpu() |
346 | 340 | if vc_wave.ndim == 1: |
347 | 341 | vc_wave = vc_wave.unsqueeze(0) |
348 | 342 | if processed_frames == 0: |
@@ -437,6 +431,7 @@ def main(args): |
437 | 431 | parser = argparse.ArgumentParser() |
438 | 432 | parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None) |
439 | 433 | parser.add_argument("--config-path", type=str, help="Path to the config file", default=None) |
440 | | - parser.add_argument("--share", type=bool, help="Whether to share url link", default=False) |
| 434 | + parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app") |
| 435 | + parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True) |
441 | 436 | args = parser.parse_args() |
442 | 437 | main(args) |
0 commit comments