@@ -1,22 +1,22 @@
 import os
+import sys
 os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
 import torch
+import torch.multiprocessing as mp
 import random
 import librosa
 import yaml
 import torchaudio
 import argparse
 import torchaudio.compliance.kaldi as kaldi
 import glob
 from tqdm import tqdm
+import shutil
 
 from modules.commons import recursive_munch, build_model, load_checkpoint
 from optimizers import build_optimizer
 from data.ft_dataset import build_ft_dataloader
 from hf_utils import load_custom_model_from_hf
-import shutil
-
-
 
 
 class Trainer:
@@ -79,23 +79,22 @@ def __init__(self,
 
         # initialize optimizers after preparing models for compatibility with FSDP
         self.optimizer = build_optimizer({key: self.model[key] for key in self.model},
-                                          lr=float(scheduler_params['base_lr']))
+                                         lr=float(scheduler_params['base_lr']))
 
         if pretrained_ckpt_path is None:
-            # find latest checkpoint with name pattern of 'T2V_epoch_*_step_*.pth'
+            # find latest checkpoint
            available_checkpoints = glob.glob(os.path.join(self.log_dir, "DiT_epoch_*_step_*.pth"))
             if len(available_checkpoints) > 0:
-                # find the checkpoint that has the highest step number
                 latest_checkpoint = max(
                     available_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
                 )
                 earliest_checkpoint = min(
                     available_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
                 )
-                # delete the earliest checkpoint
+                # delete the earliest checkpoint if we have more than 2
                 if (
-                        earliest_checkpoint != latest_checkpoint
-                        and len(available_checkpoints) > 2
+                    earliest_checkpoint != latest_checkpoint
+                    and len(available_checkpoints) > 2
                 ):
                     os.remove(earliest_checkpoint)
                     print(f"Removed {earliest_checkpoint}")
@@ -108,16 +107,18 @@ def __init__(self,
             latest_checkpoint = pretrained_ckpt_path
 
         if os.path.exists(latest_checkpoint):
-            self.model, self.optimizer, self.epoch, self.iters = load_checkpoint(self.model, self.optimizer, latest_checkpoint,
-                                                                                 load_only_params=True,
-                                                                                 ignore_modules=[],
-                                                                                 is_distributed=False)
+            self.model, self.optimizer, self.epoch, self.iters = load_checkpoint(
+                self.model, self.optimizer, latest_checkpoint,
+                load_only_params=True,
+                ignore_modules=[],
+                is_distributed=False
+            )
             print(f"Loaded checkpoint from {latest_checkpoint}")
         else:
             self.epoch, self.iters = 0, 0
-            print("Failed to load any checkpoint, this implies you are training from scratch.")
+            print("Failed to load any checkpoint, training from scratch.")
+
     def build_sv_model(self, device, config):
-        # speaker verification model
         from modules.campplus.DTDNN import CAMPPlus
         self.campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
         campplus_sd_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
@@ -126,16 +127,17 @@ def build_sv_model(self, device, config):
         self.campplus_model.eval()
         self.campplus_model.to(device)
         self.sv_fn = self.campplus_model
+
     def build_f0_fn(self, device, config):
         from modules.rmvpe import RMVPE
         model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
         self.rmvpe = RMVPE(model_path, is_half=False, device=device)
         self.f0_fn = self.rmvpe
+
     def build_converter(self, device, config):
-        # speaker perturbation model
         from modules.openvoice.api import ToneColorConverter
         ckpt_converter, config_converter = load_custom_model_from_hf("myshell-ai/OpenVoiceV2", "converter/checkpoint.pth", "converter/config.json")
-        self.tone_color_converter = ToneColorConverter(config_converter, device=device, )
+        self.tone_color_converter = ToneColorConverter(config_converter, device=device)
         self.tone_color_converter.load_ckpt(ckpt_converter)
         self.tone_color_converter.model.eval()
         se_db_path = load_custom_model_from_hf("Plachta/Seed-VC", "se_db.pt", None)
@@ -146,9 +148,7 @@ def build_vocoder(self, device, config):
         vocoder_name = config['model_params']['vocoder'].get('name', None)
         if vocoder_type == 'bigvgan':
             from modules.bigvgan import bigvgan
-            bigvgan_name = vocoder_name
-            self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
-            # remove weight norm in the model and set to eval mode
+            self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(vocoder_name, use_cuda_kernel=False)
             self.bigvgan_model.remove_weight_norm()
             self.bigvgan_model = self.bigvgan_model.eval().to(device)
             vocoder_fn = self.bigvgan_model
@@ -158,7 +158,7 @@ def build_vocoder(self, device, config):
             hift_config = yaml.safe_load(open('configs/hifigan.yml', 'r'))
             hift_path = load_custom_model_from_hf("FunAudioLLM/CosyVoice-300M", 'hift.pt', None)
             self.hift_gen = HiFTGenerator(**hift_config['hift'],
-                                           f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
+                                          f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
             self.hift_gen.load_state_dict(torch.load(hift_path, map_location='cpu'))
             self.hift_gen.eval()
             self.hift_gen.to(device)
@@ -168,21 +168,25 @@ def build_vocoder(self, device, config):
         self.vocoder_fn = vocoder_fn
 
     def build_semantic_fn(self, device, config):
-        # speech tokenizer
         speech_tokenizer_type = config['model_params']['speech_tokenizer'].get('type', 'cosyvoice')
         if speech_tokenizer_type == 'whisper':
             from transformers import AutoFeatureExtractor, WhisperModel
             whisper_model_name = config['model_params']['speech_tokenizer']['name']
             self.whisper_model = WhisperModel.from_pretrained(whisper_model_name).to(device)
             self.whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_model_name)
+            # remove decoder to save memory
             del self.whisper_model.decoder
+
             def semantic_fn(waves_16k):
-                ori_inputs = self.whisper_feature_extractor([w16k.cpu().numpy() for w16k in waves_16k],
-                                                            return_tensors="pt",
-                                                            return_attention_mask=True,
-                                                            sampling_rate=16000,)
+                ori_inputs = self.whisper_feature_extractor(
+                    [w16k.cpu().numpy() for w16k in waves_16k],
+                    return_tensors="pt",
+                    return_attention_mask=True,
+                    sampling_rate=16000,
+                )
                 ori_input_features = self.whisper_model._mask_input_features(
-                    ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
+                    ori_inputs.input_features, attention_mask=ori_inputs.attention_mask
+                ).to(device)
                 with torch.no_grad():
                     ori_outputs = self.whisper_model.encoder(
                         ori_input_features.to(self.whisper_model.encoder.dtype),
@@ -194,6 +198,7 @@ def semantic_fn(waves_16k):
                 S_ori = ori_outputs.last_hidden_state.to(torch.float32)
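+                # Whisper's encoder yields one frame per 320 input samples at 16 kHz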
                 S_ori = S_ori[:, :waves_16k.size(-1) // 320 + 1]
                 return S_ori
+
         elif speech_tokenizer_type == 'xlsr':
             from transformers import (
                 Wav2Vec2FeatureExtractor,
@@ -209,15 +214,14 @@ def semantic_fn(waves_16k):
             self.wav2vec_model = self.wav2vec_model.half()
 
             def semantic_fn(waves_16k):
-                ori_waves_16k_input_list = [
-                    waves_16k[bib].cpu().numpy()
-                    for bib in range(len(waves_16k))
-                ]
-                ori_inputs = self.wav2vec_feature_extractor(ori_waves_16k_input_list,
-                                                            return_tensors="pt",
-                                                            return_attention_mask=True,
-                                                            padding=True,
-                                                            sampling_rate=16000).to(device)
+                ori_waves_16k_input_list = [waves_16k[bib].cpu().numpy() for bib in range(len(waves_16k))]
+                ori_inputs = self.wav2vec_feature_extractor(
+                    ori_waves_16k_input_list,
+                    return_tensors="pt",
+                    return_attention_mask=True,
+                    padding=True,
+                    sampling_rate=16000
+                ).to(device)
                 with torch.no_grad():
                     ori_outputs = self.wav2vec_model(
                         ori_inputs.input_values.half(),
@@ -246,11 +250,12 @@ def train_one_step(self, batch):
         se_batch = self.tone_color_converter.extract_se(waves_22k, wave_lengths_22k)
 
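+        # draw a random reference speaker embedding from the se_db database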
         ref_se_idx = torch.randint(0, len(self.se_db), (B,))
-        ref_se = self.se_db[ref_se_idx]
-        ref_se = ref_se.to(self.device)
+        ref_se = self.se_db[ref_se_idx].to(self.device)
 
         # convert
-        converted_waves_22k = self.tone_color_converter.convert(waves_22k, wave_lengths_22k, se_batch, ref_se).squeeze(1)
+        converted_waves_22k = self.tone_color_converter.convert(
+            waves_22k, wave_lengths_22k, se_batch, ref_se
+        ).squeeze(1)
 
         if self.sr != 22050:
             converted_waves = torchaudio.functional.resample(converted_waves_22k, 22050, self.sr)
@@ -260,6 +265,7 @@ def train_one_step(self, batch):
         waves_16k = torchaudio.functional.resample(waves, self.sr, 16000)
         wave_lengths_16k = (wave_lengths.float() * 16000 / self.sr).long()
         converted_waves_16k = torchaudio.functional.resample(converted_waves, self.sr, 16000)
+
         # extract S_alt (perturbed speech tokens)
         S_ori = self.semantic_fn(waves_16k)
         S_alt = self.semantic_fn(converted_waves_16k)
@@ -268,11 +274,14 @@ def train_one_step(self, batch):
             F0_ori = self.rmvpe.infer_from_audio_batch(waves_16k)
         else:
             F0_ori = None
+
         # interpolate speech token to match acoustic feature length
         alt_cond, _, alt_codes, alt_commitment_loss, alt_codebook_loss = (
-            self.model.length_regulator(S_alt, ylens=target_lengths, f0=F0_ori))
+            self.model.length_regulator(S_alt, ylens=target_lengths, f0=F0_ori)
+        )
         ori_cond, _, ori_codes, ori_commitment_loss, ori_codebook_loss = (
-            self.model.length_regulator(S_ori, ylens=target_lengths, f0=F0_ori))
+            self.model.length_regulator(S_ori, ylens=target_lengths, f0=F0_ori)
+        )
         if alt_commitment_loss is None:
             alt_commitment_loss = 0
             alt_codebook_loss = 0
@@ -281,10 +290,10 @@ def train_one_step(self, batch):
 
         # randomly set a length as prompt
         prompt_len_max = target_lengths - 1
-        prompt_len = (torch.rand([B], device=alt_cond.device) * prompt_len_max).floor().to(dtype=torch.long)
+        prompt_len = (torch.rand([B], device=alt_cond.device) * prompt_len_max).floor().long()
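+        # with probability 0.1, use no prompt at all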
         prompt_len[torch.rand([B], device=alt_cond.device) < 0.1] = 0
 
-        # for prompt cond token, it must be from ori_cond instead of alt_cond
+        # for prompt cond token, use ori_cond instead of alt_cond
         cond = alt_cond.clone()
         for bib in range(B):
             cond[bib, :prompt_len[bib]] = ori_cond[bib, :prompt_len[bib]]
@@ -295,13 +304,16 @@ def train_one_step(self, batch):
         cond = cond[:, :common_min_len]
         target_lengths = torch.clamp(target_lengths, max=common_min_len)
         x = target
-        # style vectors are extracted from prompt only to avoid inference time OOD
+
+        # style vectors are extracted from the prompt only
         feat_list = []
         for bib in range(B):
-            feat = kaldi.fbank(waves_16k[bib:bib + 1, :wave_lengths_16k[bib]],
-                               num_mel_bins=80,
-                               dither=0,
-                               sample_frequency=16000)
+            feat = kaldi.fbank(
+                waves_16k[bib:bib + 1, :wave_lengths_16k[bib]],
+                num_mel_bins=80,
+                dither=0,
+                sample_frequency=16000
+            )
             feat = feat - feat.mean(dim=0, keepdim=True)
             feat_list.append(feat)
         y_list = []
@@ -313,31 +325,39 @@ def train_one_step(self, batch):
 
         loss, _ = self.model.cfm(x, target_lengths, prompt_len, cond, y)
 
-        loss_total = (loss +
-                      (alt_commitment_loss + ori_commitment_loss) * 0.05 +
-                      (ori_codebook_loss + alt_codebook_loss) * 0.15)
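+        # total loss: CFM loss plus weighted VQ commitment (x0.05) and codebook (x0.15) losses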
+        loss_total = (
+            loss +
+            (alt_commitment_loss + ori_commitment_loss) * 0.05 +
+            (ori_codebook_loss + alt_codebook_loss) * 0.15
+        )
 
         self.optimizer.zero_grad()
         loss_total.backward()
-        grad_norm_g = torch.nn.utils.clip_grad_norm_(self.model.cfm.parameters(), 10.0)
-        grad_norm_g2 = torch.nn.utils.clip_grad_norm_(self.model.length_regulator.parameters(), 10.0)
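+        # clip gradient norms at 10.0 to stabilize training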
+        torch.nn.utils.clip_grad_norm_(self.model.cfm.parameters(), 10.0)
+        torch.nn.utils.clip_grad_norm_(self.model.length_regulator.parameters(), 10.0)
         self.optimizer.step('cfm')
         self.optimizer.step('length_regulator')
         self.optimizer.scheduler(key='cfm')
         self.optimizer.scheduler(key='length_regulator')
 
         return loss.detach().item()
+
     def train_one_epoch(self):
         _ = [self.model[key].train() for key in self.model]
         for i, batch in enumerate(tqdm(self.train_dataloader)):
             batch = [b.to(self.device) for b in batch]
             loss = self.train_one_step(batch)
-            self.ema_loss = self.ema_loss * self.loss_smoothing_rate + loss * (1 - self.loss_smoothing_rate) if self.iters > 0 else loss
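+            # exponentially smoothed loss for steadier logging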
+            self.ema_loss = (
+                self.ema_loss * self.loss_smoothing_rate + loss * (1 - self.loss_smoothing_rate)
+                if self.iters > 0 else loss
+            )
             if self.iters % self.log_interval == 0:
                 print(f"epoch {self.epoch}, step {self.iters}, loss: {self.ema_loss}")
             self.iters += 1
+
             if self.iters >= self.max_steps:
                 break
+
             if self.iters % self.save_interval == 0:
                 print('Saving..')
                 state = {
@@ -347,13 +367,15 @@ def train_one_epoch(self):
                     'iters': self.iters,
                     'epoch': self.epoch,
                 }
-                save_path = os.path.join(self.log_dir, 'DiT_epoch_%05d_step_%05d.pth' % (self.epoch, self.iters))
+                save_path = os.path.join(
+                    self.log_dir,
+                    f'DiT_epoch_{self.epoch:05d}_step_{self.iters:05d}.pth'
+                )
                 torch.save(state, save_path)
 
                 # find all checkpoints and remove old ones
                 checkpoints = glob.glob(os.path.join(self.log_dir, 'DiT_epoch_*.pth'))
                 if len(checkpoints) > 2:
-                    # sort by step
                     checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
                     for cp in checkpoints[:-2]:
                         os.remove(cp)
@@ -364,15 +386,34 @@ def train(self):
         for epoch in range(self.n_epochs):
             self.epoch = epoch
             self.train_one_epoch()
+            # Save after each epoch
+            print('Epoch completed. Saving..')
+            state = {
+                'net': {key: self.model[key].state_dict() for key in self.model},
+                'optimizer': self.optimizer.state_dict(),
+                'scheduler': self.optimizer.scheduler_state_dict(),
+                'iters': self.iters,
+                'epoch': self.epoch,
+            }
+            save_path = os.path.join(
+                self.log_dir,
+                f'DiT_epoch_{self.epoch:05d}_step_{self.iters:05d}.pth'
+            )
+            torch.save(state, save_path)
+            print(f"Checkpoint saved at {save_path}")
+
             if self.iters >= self.max_steps:
                 break
-        print('Saving..')
+
+        print('Saving final model..')
         state = {
             'net': {key: self.model[key].state_dict() for key in self.model},
         }
         os.makedirs(self.log_dir, exist_ok=True)
         save_path = os.path.join(self.log_dir, 'ft_model.pth')
         torch.save(state, save_path)
+        print(f"Final model saved at {save_path}")
+
 
 def main(args):
     trainer = Trainer(
@@ -387,8 +428,12 @@ def main(args):
         num_workers=args.num_workers,
     )
     trainer.train()
-
+
 if __name__ == '__main__':
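+    # on Windows, spawn-based DataLoader workers need freeze_support and an explicit start method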
+    if sys.platform == 'win32':
+        mp.freeze_support()
+        mp.set_start_method('spawn', force=True)
+
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, default='./configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml')
     parser.add_argument('--pretrained-ckpt', type=str, default=None)
@@ -400,4 +445,5 @@ def main(args):
     parser.add_argument('--save-every', type=int, default=500)
     parser.add_argument('--num-workers', type=int, default=0)
     args = parser.parse_args()
-    main(args)
+
+    main(args)