Skip to content

Commit 2837a28

Browse files
committed
Update training & fine-tuning instructions
1 parent 904a65b commit 2837a28

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

app.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import os
2+
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
13
import gradio as gr
24
import torch
35
import torchaudio
@@ -28,6 +30,8 @@
2830
model[key].to(device)
2931
model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
3032

33+
print(f"cfm has {sum(p.numel() for p in model.cfm.parameters() if p.requires_grad)} trainable parameters")
34+
3135
# Load additional modules
3236
from modules.campplus.DTDNN import CAMPPlus
3337

@@ -75,7 +79,7 @@
7579
"num_mels": config['preprocess_params']['spect_params']['n_mels'],
7680
"sampling_rate": sr,
7781
"fmin": 0,
78-
"fmax": None,
82+
"fmax": None if config['preprocess_params']['spect_params'].get('fmax') == "None" else config['preprocess_params']['spect_params']['fmax'],
7983
"center": False
8084
}
8185
from modules.audio import mel_spectrogram
@@ -84,7 +88,7 @@
8488

8589
# f0 conditioned model
8690
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
87-
"DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
91+
"DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema_v2.pth",
8892
"config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
8993

9094
# Read the DiT config inside a context manager so the file handle is closed
# as soon as parsing finishes.
with open(dit_config_path, 'r') as dit_config_file:
    config = yaml.safe_load(dit_config_file)
@@ -114,7 +118,7 @@
114118
"num_mels": config['preprocess_params']['spect_params']['n_mels'],
115119
"sampling_rate": sr,
116120
"fmin": 0,
117-
"fmax": None,
121+
"fmax": None if config['preprocess_params']['spect_params'].get('fmax') == "None" else config['preprocess_params']['spect_params']['fmax'],
118122
"center": False
119123
}
120124
to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
@@ -124,6 +128,15 @@
124128
bigvgan_44k_model.remove_weight_norm()
125129
bigvgan_44k_model = bigvgan_44k_model.eval().to(device)
126130

131+
from modules.hifigan.generator import HiFTGenerator
132+
from modules.hifigan.f0_predictor import ConvRNNF0Predictor
133+
134+
# Read the HiFT generator config inside a context manager so the file handle
# is closed as soon as parsing finishes.
with open('configs/hifigan.yml', 'r') as hift_config_file:
    hift_config = yaml.safe_load(hift_config_file)
135+
hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
136+
hift_gen.load_state_dict(torch.load(hift_config['pretrained_model_path'], map_location='cpu'))
137+
hift_gen.eval()
138+
hift_gen.to(device)
139+
127140
def adjust_f0_semitones(f0_sequence, n_semitones):
    """Return ``f0_sequence`` scaled by the ratio for ``n_semitones`` semitones.

    The scale factor is ``2 ** (n_semitones / 12)``, so +12 doubles every
    value, -12 halves it, and 0 multiplies by 1.0.
    """
    factor = 2 ** (n_semitones / 12)
    return f0_sequence * factor
@@ -148,7 +161,7 @@ def crossfade(chunk1, chunk2, overlap):
148161
def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, f0_condition, auto_f0_adjust, pitch_shift):
149162
inference_module = model if not f0_condition else model_f0
150163
mel_fn = to_mel if not f0_condition else to_mel_f0
151-
bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
164+
bigvgan_fn = hift_gen if not f0_condition else bigvgan_44k_model
152165
sr = 22050 if not f0_condition else 44100
153166
# Load audio
154167
source_audio = librosa.load(source, sr=sr)[0]
@@ -289,6 +302,8 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
289302
inference_cfg_rate=inference_cfg_rate)
290303
vc_target = vc_target[:, :, mel2.size(-1):]
291304
vc_wave = bigvgan_fn(vc_target)[0]
305+
if vc_wave.ndim == 1:
306+
vc_wave = vc_wave.unsqueeze(0)
292307
if processed_frames == 0:
293308
if is_last_chunk:
294309
output_wave = vc_wave[0].cpu().numpy()

0 commit comments

Comments
 (0)