@@ -1,4 +1,7 @@
 import os
+
+import numpy as np
+
 os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
 import shutil
 import warnings
@@ -230,9 +233,18 @@ def adjust_f0_semitones(f0_sequence, n_semitones):
     factor = 2 ** (n_semitones / 12)
     return f0_sequence * factor
 
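+# Crossfade between consecutive output chunks: a cos^2 fade-out on the tail of
+# chunk1 and a matching sin^2 fade-in on the head of chunk2, so the two gains
+# sum to 1 at every sample; the first branch guards against a final chunk
+# shorter than the overlap.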
+def crossfade(chunk1, chunk2, overlap):
+    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+    if len(chunk2) < overlap:
+        chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+    else:
+        chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+    return chunk2
+
 @torch.no_grad()
 def main(args):
-    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args = load_models(args)
+    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, mel_fn, mel_fn_args = load_models(args)
     sr = mel_fn_args['sampling_rate']
     f0_condition = args.f0_condition
     auto_f0_adjust = args.auto_f0_adjust
@@ -246,36 +258,62 @@ def main(args):
     source_audio = librosa.load(source, sr=sr)[0]
     ref_audio = librosa.load(target_name, sr=sr)[0]
 
-    source_audio = source_audio[:sr * 30]
-    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-
-    ref_audio = ref_audio[:(sr * 30 - source_audio.size(-1))]
-    ref_audio = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
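+    # Chunked-inference setup: cap each generation window at 30 s of mel frames
+    # (sr // hop_length frames per second) and reserve a 16-frame overlap that
+    # is later crossfaded in the waveform domain between consecutive chunks.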
+    sr = 22050 if not f0_condition else 44100
+    hop_length = 256 if not f0_condition else 512
+    max_context_window = sr // hop_length * 30
+    overlap_frame_len = 16
+    overlap_wave_len = overlap_frame_len * hop_length
 
-    source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-    ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+    # Process audio
+    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
+    ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)
 
+    time_vc_start = time.time()
+    # Resample to 16 kHz for the semantic (Whisper) and speaker encoders
     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+    # if the source audio is shorter than 30 seconds, Whisper can handle it in one forward pass
+    if converted_waves_16k.size(-1) <= 16000 * 30:
+        S_alt = semantic_fn(converted_waves_16k)
+    else:
+        overlapping_time = 5  # 5 seconds
+        S_alt_list = []
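+        # Slide a 30 s window with a 5 s overlap: the tail of each chunk is kept
+        # as a buffer and prepended to the next one so Whisper sees continuous
+        # context across chunk boundaries.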
+        buffer = None
+        traversed_time = 0
+        while traversed_time < converted_waves_16k.size(-1):
+            if buffer is None:  # first chunk
+                chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
+            else:
+                chunk = torch.cat(
+                    [buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]],
+                    dim=-1)
+            S_alt = semantic_fn(chunk)
+            if traversed_time == 0:
+                S_alt_list.append(S_alt)
+            else:
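+                # Whisper emits 50 feature frames per second; drop the frames
+                # duplicating the 5 s overlap already produced by the previous chunk.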
+                S_alt_list.append(S_alt[:, 50 * overlapping_time:])
+            buffer = chunk[:, -16000 * overlapping_time:]
+            traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+        S_alt = torch.cat(S_alt_list, dim=1)
+
     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-    S_alt = semantic_fn(converted_waves_16k)
     S_ori = semantic_fn(ori_waves_16k)
 
-    mel = to_mel(source_audio.to(device).float())
-    mel2 = to_mel(ref_audio.to(device).float())
+    mel = mel_fn(source_audio.to(device).float())
+    mel2 = mel_fn(ref_audio.to(device).float())
 
     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
 
-    feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
+    feat2 = torchaudio.compliance.kaldi.fbank(ori_waves_16k,
                                               num_mel_bins=80,
                                               dither=0,
                                               sample_frequency=16000)
     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
     style2 = campplus_model(feat2.unsqueeze(0))
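+    # style2 is a global speaker (timbre) embedding of the reference computed by
+    # CAMPPlus; it is passed to model.cfm.inference below as the style condition.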
 
     if f0_condition:
-        F0_ori = f0_fn(ref_waves_16k[0], thred=0.03)
-        F0_alt = f0_fn(source_waves_16k[0], thred=0.03)
+        F0_ori = f0_fn(ori_waves_16k[0], thred=0.03)
+        F0_alt = f0_fn(converted_waves_16k[0], thred=0.03)
 
         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
@@ -288,6 +326,7 @@ def main(args):
         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
         median_log_f0_ori = torch.median(voiced_log_f0_ori)
         median_log_f0_alt = torch.median(voiced_log_f0_alt)
+
         # shift alt log f0 level to ori log f0 level
         shifted_log_f0_alt = log_f0_alt.clone()
         if auto_f0_adjust:
@@ -301,22 +340,53 @@ def main(args):
         shifted_f0_alt = None
 
     # Length regulation
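+    # (the length regulator resamples semantic features to the target mel length,
+    # optionally fusing F0; the returned codes and VQ losses are training
+    # outputs, unused at inference)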
-    cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-    prompt_condition, _, prompt_codes, commitment_loss, codebook_loss = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-    cat_condition = torch.cat([prompt_condition, cond], dim=1)
-
-    time_vc_start = time.time()
-    with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
-        vc_target = model.cfm.inference(
-            cat_condition,
-            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-            mel2, style2, None, diffusion_steps,
-            inference_cfg_rate=inference_cfg_rate)
-        vc_target = vc_target[:, :, mel2.size(-1):]
-
-    # Convert to waveform
-    vc_wave = vocoder_fn(vc_target).squeeze()  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
-    vc_wave = vc_wave[None, :].float()
+    cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths,
+                                                                            n_quantizers=3,
+                                                                            f0=shifted_f0_alt)
+    prompt_condition, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_ori,
+                                                                                        ylens=target2_lengths,
+                                                                                        n_quantizers=3,
+                                                                                        f0=F0_ori)
+
+    max_source_window = max_context_window - mel2.size(2)
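+    # The reference mel (mel2) is prepended to every chunk as a prompt, so it
+    # occupies part of the context window; only the remainder is available for
+    # source frames.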
+    # split source condition (cond) into chunks
+    processed_frames = 0
+    generated_wave_chunks = []
+    # generate chunk by chunk and stream the output
+    while processed_frames < cond.size(1):
+        chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+        is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+        cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+        with torch.autocast(device_type=device.type, dtype=torch.float16):
+            # Voice Conversion
+            vc_target = model.cfm.inference(cat_condition,
+                                            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                            mel2, style2, None, diffusion_steps,
+                                            inference_cfg_rate=inference_cfg_rate)
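+            # Drop the prompt portion of the output mel; keep only the frames
+            # generated for this source chunk.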
+            vc_target = vc_target[:, :, mel2.size(-1):]
+            vc_wave = vocoder_fn(vc_target).squeeze()
+            vc_wave = vc_wave[None, :]
+        if processed_frames == 0:
+            if is_last_chunk:
+                output_wave = vc_wave[0].cpu().numpy()
+                generated_wave_chunks.append(output_wave)
+                break
+            output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
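+            # Advance by the chunk length minus the overlap; the overlapping
+            # frames are regenerated with the next chunk and crossfaded in.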
+            processed_frames += vc_target.size(2) - overlap_frame_len
+        elif is_last_chunk:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            processed_frames += vc_target.size(2) - overlap_frame_len
+            break
+        else:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(),
+                                    overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+    vc_wave = torch.tensor(np.concatenate(generated_wave_chunks))[None, :].float()
     time_vc_end = time.time()
     print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")
 
@@ -334,8 +404,8 @@ def main(args):
     parser.add_argument("--diffusion-steps", type=int, default=30)
     parser.add_argument("--length-adjust", type=float, default=1.0)
     parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
-    parser.add_argument("--f0-condition", type=str2bool, default=True)
-    parser.add_argument("--auto-f0-adjust", type=str2bool, default=True)
+    parser.add_argument("--f0-condition", type=str2bool, default=False)
+    parser.add_argument("--auto-f0-adjust", type=str2bool, default=False)
     parser.add_argument("--semi-tone-shift", type=int, default=0)
     parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
     parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)