4545# Load additional modules
4646from modules .campplus .DTDNN import CAMPPlus
4747
48+ campplus_ckpt_path = load_custom_model_from_hf ("funasr/campplus" , "campplus_cn_common.bin" , config_filename = None )
4849campplus_model = CAMPPlus (feat_dim = 80 , embedding_size = 192 )
49- campplus_model .load_state_dict (torch .load (config [ 'model_params' ][ 'style_encoder' ][ 'campplus_path' ] , map_location = ' cpu' ))
50+ campplus_model .load_state_dict (torch .load (campplus_ckpt_path , map_location = " cpu" ))
5051campplus_model .eval ()
5152campplus_model .to (device )
5253
@@ -103,6 +104,7 @@ def main(args):
103104 diffusion_steps = args .diffusion_steps
104105 length_adjust = args .length_adjust
105106 inference_cfg_rate = args .inference_cfg_rate
107+ n_quantizers = args .n_quantizers
106108 source_audio = librosa .load (source , sr = sr )[0 ]
107109 ref_audio = librosa .load (target_name , sr = sr )[0 ]
108110 # decoded_wav = encodec_model.decoder(encodec_latent)
@@ -117,43 +119,53 @@ def main(args):
117119 source_waves_16k = torchaudio .functional .resample (source_audio , sr , 16000 )
118120 ref_waves_16k = torchaudio .functional .resample (ref_audio , sr , 16000 )
119121
120- S_alt = [
121- cosyvoice_frontend .extract_speech_token (source_waves_16k , )
122- ]
123- S_alt_lens = torch .LongTensor ([s [1 ] for s in S_alt ]).to (device )
124- S_alt = torch .cat ([torch .nn .functional .pad (s [0 ], (0 , max (S_alt_lens ) - s [0 ].size (1 ))) for s in S_alt ], dim = 0 )
125-
126- S_ori = [
127- cosyvoice_frontend .extract_speech_token (ref_waves_16k , )
128- ]
129- S_ori_lens = torch .LongTensor ([s [1 ] for s in S_ori ]).to (device )
130- S_ori = torch .cat ([torch .nn .functional .pad (s [0 ], (0 , max (S_ori_lens ) - s [0 ].size (1 ))) for s in S_ori ], dim = 0 )
122+ if speech_tokenizer_type == "cosyvoice" :
123+ S_alt = cosyvoice_frontend .extract_speech_token (source_waves_16k )[0 ]
124+ S_ori = cosyvoice_frontend .extract_speech_token (ref_waves_16k )[0 ]
125+ elif speech_tokenizer_type == "facodec" :
126+ converted_waves_24k = torchaudio .functional .resample (source_audio , sr , 24000 )
127+ wave_lengths_24k = torch .LongTensor ([converted_waves_24k .size (1 )]).to (converted_waves_24k .device )
128+ waves_input = converted_waves_24k .unsqueeze (1 )
129+ z = codec_encoder .encoder (waves_input )
130+ (quantized , codes ) = codec_encoder .quantizer (z , waves_input )
131+ S_alt = torch .cat ([codes [1 ], codes [0 ]], dim = 1 )
132+
133+ # S_ori should be extracted in the same way
134+ waves_24k = torchaudio .functional .resample (ref_audio , sr , 24000 )
135+ waves_input = waves_24k .unsqueeze (1 )
136+ z = codec_encoder .encoder (waves_input )
137+ (quantized , codes ) = codec_encoder .quantizer (z , waves_input )
138+ S_ori = torch .cat ([codes [1 ], codes [0 ]], dim = 1 )
131139
132140 mel = to_mel (source_audio .to (device ).float ())
133141 mel2 = to_mel (ref_audio .to (device ).float ())
134142
135- target = mel
136- target2 = mel2
137-
138- target_lengths = torch .LongTensor ([int (target .size (2 ) * length_adjust )]).to (target .device )
139- target2_lengths = torch .LongTensor ([target2 .size (2 )]).to (target2 .device )
143+ target_lengths = torch .LongTensor ([int (mel .size (2 ) * length_adjust )]).to (mel .device )
144+ target2_lengths = torch .LongTensor ([mel2 .size (2 )]).to (mel2 .device )
140145
141- feat2 = kaldi .fbank (ref_waves_16k ,
142- num_mel_bins = 80 ,
143- dither = 0 ,
144- sample_frequency = 16000 )
146+ feat2 = torchaudio . compliance . kaldi .fbank (ref_waves_16k ,
147+ num_mel_bins = 80 ,
148+ dither = 0 ,
149+ sample_frequency = 16000 )
145150 feat2 = feat2 - feat2 .mean (dim = 0 , keepdim = True )
146151 style2 = campplus_model (feat2 .unsqueeze (0 ))
147152
148- cond = model .length_regulator (S_alt , ylens = target_lengths )[0 ]
149- prompt_condition = model .length_regulator (S_ori , ylens = target2_lengths )[0 ]
153+ # Length regulation
154+ cond = model .length_regulator (S_alt , ylens = target_lengths , n_quantizers = int (n_quantizers ))[0 ]
155+ prompt_condition = model .length_regulator (S_ori , ylens = target2_lengths , n_quantizers = int (n_quantizers ))[0 ]
150156 cat_condition = torch .cat ([prompt_condition , cond ], dim = 1 )
151- prompt_target = target2
152157
153158 time_vc_start = time .time ()
154- vc_target = model .cfm .inference (cat_condition , torch .LongTensor ([cat_condition .size (1 )]).to (prompt_target .device ), prompt_target , style2 , None , diffusion_steps , inference_cfg_rate = inference_cfg_rate )
155- vc_target = vc_target [:, :, prompt_target .size (- 1 ):]
159+ vc_target = model .cfm .inference (
160+ cat_condition ,
161+ torch .LongTensor ([cat_condition .size (1 )]).to (mel2 .device ),
162+ mel2 , style2 , None , diffusion_steps ,
163+ inference_cfg_rate = inference_cfg_rate )
164+ vc_target = vc_target [:, :, mel2 .size (- 1 ):]
165+
166+ # Convert to waveform
156167 vc_wave = hift_gen .inference (vc_target )
168+
157169 time_vc_end = time .time ()
158170 print (f"RTF: { (time_vc_end - time_vc_start ) / vc_wave .size (- 1 ) * sr } " )
159171
@@ -163,11 +175,10 @@ def main(args):
163175 torchaudio .save (os .path .join (args .output , f"vc_{ source_name } _{ target_name } _{ length_adjust } _{ diffusion_steps } _{ inference_cfg_rate } .wav" ), vc_wave .cpu (), sr )
164176
165177
166-
167178if __name__ == "__main__" :
168179 parser = argparse .ArgumentParser ()
169180 parser .add_argument ("--source" , type = str , default = "./examples/source/source_s1.wav" )
170- parser .add_argument ("--target" , type = str , default = "./examples/target /s1p1.wav" )
181+ parser .add_argument ("--target" , type = str , default = "./examples/reference /s1p1.wav" )
171182 parser .add_argument ("--output" , type = str , default = "./reconstructed" )
172183 parser .add_argument ("--diffusion-steps" , type = int , default = 10 )
173184 parser .add_argument ("--length-adjust" , type = float , default = 1.0 )
0 commit comments