Skip to content

Commit aff3097

Browse files
Add multiprocessing guard for Windows (spawn + freeze_support)
1 parent 258fde9 commit aff3097

File tree

1 file changed

+105
-59
lines changed

1 file changed

+105
-59
lines changed

train.py

Lines changed: 105 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
2+
import sys
23
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
34
import torch
5+
import torch.multiprocessing as mp
46
import random
57
import librosa
68
import yaml
@@ -9,14 +11,12 @@
911
import torchaudio.compliance.kaldi as kaldi
1012
import glob
1113
from tqdm import tqdm
14+
import shutil
1215

1316
from modules.commons import recursive_munch, build_model, load_checkpoint
1417
from optimizers import build_optimizer
1518
from data.ft_dataset import build_ft_dataloader
1619
from hf_utils import load_custom_model_from_hf
17-
import shutil
18-
19-
2020

2121

2222
class Trainer:
@@ -79,23 +79,22 @@ def __init__(self,
7979

8080
# initialize optimizers after preparing models for compatibility with FSDP
8181
self.optimizer = build_optimizer({key: self.model[key] for key in self.model},
82-
lr=float(scheduler_params['base_lr']))
82+
lr=float(scheduler_params['base_lr']))
8383

8484
if pretrained_ckpt_path is None:
85-
# find latest checkpoint with name pattern of 'T2V_epoch_*_step_*.pth'
85+
# find latest checkpoint
8686
available_checkpoints = glob.glob(os.path.join(self.log_dir, "DiT_epoch_*_step_*.pth"))
8787
if len(available_checkpoints) > 0:
88-
# find the checkpoint that has the highest step number
8988
latest_checkpoint = max(
9089
available_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
9190
)
9291
earliest_checkpoint = min(
9392
available_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
9493
)
95-
# delete the earliest checkpoint
94+
# delete the earliest checkpoint if we have more than 2
9695
if (
97-
earliest_checkpoint != latest_checkpoint
98-
and len(available_checkpoints) > 2
96+
earliest_checkpoint != latest_checkpoint
97+
and len(available_checkpoints) > 2
9998
):
10099
os.remove(earliest_checkpoint)
101100
print(f"Removed {earliest_checkpoint}")
@@ -108,16 +107,18 @@ def __init__(self,
108107
latest_checkpoint = pretrained_ckpt_path
109108

110109
if os.path.exists(latest_checkpoint):
111-
self.model, self.optimizer, self.epoch, self.iters = load_checkpoint(self.model, self.optimizer, latest_checkpoint,
112-
load_only_params=True,
113-
ignore_modules=[],
114-
is_distributed=False)
110+
self.model, self.optimizer, self.epoch, self.iters = load_checkpoint(
111+
self.model, self.optimizer, latest_checkpoint,
112+
load_only_params=True,
113+
ignore_modules=[],
114+
is_distributed=False
115+
)
115116
print(f"Loaded checkpoint from {latest_checkpoint}")
116117
else:
117118
self.epoch, self.iters = 0, 0
118-
print("Failed to load any checkpoint, this implies you are training from scratch.")
119+
print("Failed to load any checkpoint, training from scratch.")
120+
119121
def build_sv_model(self, device, config):
120-
# speaker verification model
121122
from modules.campplus.DTDNN import CAMPPlus
122123
self.campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
123124
campplus_sd_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
@@ -126,16 +127,17 @@ def build_sv_model(self, device, config):
126127
self.campplus_model.eval()
127128
self.campplus_model.to(device)
128129
self.sv_fn = self.campplus_model
130+
129131
def build_f0_fn(self, device, config):
130132
from modules.rmvpe import RMVPE
131133
model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
132134
self.rmvpe = RMVPE(model_path, is_half=False, device=device)
133135
self.f0_fn = self.rmvpe
136+
134137
def build_converter(self, device, config):
135-
# speaker perturbation model
136138
from modules.openvoice.api import ToneColorConverter
137139
ckpt_converter, config_converter = load_custom_model_from_hf("myshell-ai/OpenVoiceV2", "converter/checkpoint.pth", "converter/config.json")
138-
self.tone_color_converter = ToneColorConverter(config_converter, device=device,)
140+
self.tone_color_converter = ToneColorConverter(config_converter, device=device)
139141
self.tone_color_converter.load_ckpt(ckpt_converter)
140142
self.tone_color_converter.model.eval()
141143
se_db_path = load_custom_model_from_hf("Plachta/Seed-VC", "se_db.pt", None)
@@ -146,9 +148,7 @@ def build_vocoder(self, device, config):
146148
vocoder_name = config['model_params']['vocoder'].get('name', None)
147149
if vocoder_type == 'bigvgan':
148150
from modules.bigvgan import bigvgan
149-
bigvgan_name = vocoder_name
150-
self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
151-
# remove weight norm in the model and set to eval mode
151+
self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(vocoder_name, use_cuda_kernel=False)
152152
self.bigvgan_model.remove_weight_norm()
153153
self.bigvgan_model = self.bigvgan_model.eval().to(device)
154154
vocoder_fn = self.bigvgan_model
@@ -158,7 +158,7 @@ def build_vocoder(self, device, config):
158158
hift_config = yaml.safe_load(open('configs/hifigan.yml', 'r'))
159159
hift_path = load_custom_model_from_hf("FunAudioLLM/CosyVoice-300M", 'hift.pt', None)
160160
self.hift_gen = HiFTGenerator(**hift_config['hift'],
161-
f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
161+
f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
162162
self.hift_gen.load_state_dict(torch.load(hift_path, map_location='cpu'))
163163
self.hift_gen.eval()
164164
self.hift_gen.to(device)
@@ -168,21 +168,25 @@ def build_vocoder(self, device, config):
168168
self.vocoder_fn = vocoder_fn
169169

170170
def build_semantic_fn(self, device, config):
171-
# speech tokenizer
172171
speech_tokenizer_type = config['model_params']['speech_tokenizer'].get('type', 'cosyvoice')
173172
if speech_tokenizer_type == 'whisper':
174173
from transformers import AutoFeatureExtractor, WhisperModel
175174
whisper_model_name = config['model_params']['speech_tokenizer']['name']
176175
self.whisper_model = WhisperModel.from_pretrained(whisper_model_name).to(device)
177176
self.whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_model_name)
177+
# remove decoder to save memory
178178
del self.whisper_model.decoder
179+
179180
def semantic_fn(waves_16k):
180-
ori_inputs = self.whisper_feature_extractor([w16k.cpu().numpy() for w16k in waves_16k],
181-
return_tensors="pt",
182-
return_attention_mask=True,
183-
sampling_rate=16000,)
181+
ori_inputs = self.whisper_feature_extractor(
182+
[w16k.cpu().numpy() for w16k in waves_16k],
183+
return_tensors="pt",
184+
return_attention_mask=True,
185+
sampling_rate=16000,
186+
)
184187
ori_input_features = self.whisper_model._mask_input_features(
185-
ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
188+
ori_inputs.input_features, attention_mask=ori_inputs.attention_mask
189+
).to(device)
186190
with torch.no_grad():
187191
ori_outputs = self.whisper_model.encoder(
188192
ori_input_features.to(self.whisper_model.encoder.dtype),
@@ -194,6 +198,7 @@ def semantic_fn(waves_16k):
194198
S_ori = ori_outputs.last_hidden_state.to(torch.float32)
195199
S_ori = S_ori[:, :waves_16k.size(-1) // 320 + 1]
196200
return S_ori
201+
197202
elif speech_tokenizer_type == 'xlsr':
198203
from transformers import (
199204
Wav2Vec2FeatureExtractor,
@@ -209,15 +214,14 @@ def semantic_fn(waves_16k):
209214
self.wav2vec_model = self.wav2vec_model.half()
210215

211216
def semantic_fn(waves_16k):
212-
ori_waves_16k_input_list = [
213-
waves_16k[bib].cpu().numpy()
214-
for bib in range(len(waves_16k))
215-
]
216-
ori_inputs = self.wav2vec_feature_extractor(ori_waves_16k_input_list,
217-
return_tensors="pt",
218-
return_attention_mask=True,
219-
padding=True,
220-
sampling_rate=16000).to(device)
217+
ori_waves_16k_input_list = [waves_16k[bib].cpu().numpy() for bib in range(len(waves_16k))]
218+
ori_inputs = self.wav2vec_feature_extractor(
219+
ori_waves_16k_input_list,
220+
return_tensors="pt",
221+
return_attention_mask=True,
222+
padding=True,
223+
sampling_rate=16000
224+
).to(device)
221225
with torch.no_grad():
222226
ori_outputs = self.wav2vec_model(
223227
ori_inputs.input_values.half(),
@@ -246,11 +250,12 @@ def train_one_step(self, batch):
246250
se_batch = self.tone_color_converter.extract_se(waves_22k, wave_lengths_22k)
247251

248252
ref_se_idx = torch.randint(0, len(self.se_db), (B,))
249-
ref_se = self.se_db[ref_se_idx]
250-
ref_se = ref_se.to(self.device)
253+
ref_se = self.se_db[ref_se_idx].to(self.device)
251254

252255
# convert
253-
converted_waves_22k = self.tone_color_converter.convert(waves_22k, wave_lengths_22k, se_batch, ref_se).squeeze(1)
256+
converted_waves_22k = self.tone_color_converter.convert(
257+
waves_22k, wave_lengths_22k, se_batch, ref_se
258+
).squeeze(1)
254259

255260
if self.sr != 22050:
256261
converted_waves = torchaudio.functional.resample(converted_waves_22k, 22050, self.sr)
@@ -260,6 +265,7 @@ def train_one_step(self, batch):
260265
waves_16k = torchaudio.functional.resample(waves, self.sr, 16000)
261266
wave_lengths_16k = (wave_lengths.float() * 16000 / self.sr).long()
262267
converted_waves_16k = torchaudio.functional.resample(converted_waves, self.sr, 16000)
268+
263269
# extract S_alt (perturbed speech tokens)
264270
S_ori = self.semantic_fn(waves_16k)
265271
S_alt = self.semantic_fn(converted_waves_16k)
@@ -268,11 +274,14 @@ def train_one_step(self, batch):
268274
F0_ori = self.rmvpe.infer_from_audio_batch(waves_16k)
269275
else:
270276
F0_ori = None
277+
271278
# interpolate speech token to match acoustic feature length
272279
alt_cond, _, alt_codes, alt_commitment_loss, alt_codebook_loss = (
273-
self.model.length_regulator(S_alt, ylens=target_lengths, f0=F0_ori))
280+
self.model.length_regulator(S_alt, ylens=target_lengths, f0=F0_ori)
281+
)
274282
ori_cond, _, ori_codes, ori_commitment_loss, ori_codebook_loss = (
275-
self.model.length_regulator(S_ori, ylens=target_lengths, f0=F0_ori))
283+
self.model.length_regulator(S_ori, ylens=target_lengths, f0=F0_ori)
284+
)
276285
if alt_commitment_loss is None:
277286
alt_commitment_loss = 0
278287
alt_codebook_loss = 0
@@ -281,10 +290,10 @@ def train_one_step(self, batch):
281290

282291
# randomly set a length as prompt
283292
prompt_len_max = target_lengths - 1
284-
prompt_len = (torch.rand([B], device=alt_cond.device) * prompt_len_max).floor().to(dtype=torch.long)
293+
prompt_len = (torch.rand([B], device=alt_cond.device) * prompt_len_max).floor().long()
285294
prompt_len[torch.rand([B], device=alt_cond.device) < 0.1] = 0
286295

287-
# for prompt cond token, it must be from ori_cond instead of alt_cond
296+
# for prompt cond token, use ori_cond instead of alt_cond
288297
cond = alt_cond.clone()
289298
for bib in range(B):
290299
cond[bib, :prompt_len[bib]] = ori_cond[bib, :prompt_len[bib]]
@@ -295,13 +304,16 @@ def train_one_step(self, batch):
295304
cond = cond[:, :common_min_len]
296305
target_lengths = torch.clamp(target_lengths, max=common_min_len)
297306
x = target
298-
# style vectors are extracted from prompt only to avoid inference time OOD
307+
308+
# style vectors are extracted from the prompt only
299309
feat_list = []
300310
for bib in range(B):
301-
feat = kaldi.fbank(waves_16k[bib:bib + 1, :wave_lengths_16k[bib]],
302-
num_mel_bins=80,
303-
dither=0,
304-
sample_frequency=16000)
311+
feat = kaldi.fbank(
312+
waves_16k[bib:bib + 1, :wave_lengths_16k[bib]],
313+
num_mel_bins=80,
314+
dither=0,
315+
sample_frequency=16000
316+
)
305317
feat = feat - feat.mean(dim=0, keepdim=True)
306318
feat_list.append(feat)
307319
y_list = []
@@ -313,31 +325,39 @@ def train_one_step(self, batch):
313325

314326
loss, _ = self.model.cfm(x, target_lengths, prompt_len, cond, y)
315327

316-
loss_total = (loss +
317-
(alt_commitment_loss + ori_commitment_loss) * 0.05 +
318-
(ori_codebook_loss + alt_codebook_loss) * 0.15)
328+
loss_total = (
329+
loss +
330+
(alt_commitment_loss + ori_commitment_loss) * 0.05 +
331+
(ori_codebook_loss + alt_codebook_loss) * 0.15
332+
)
319333

320334
self.optimizer.zero_grad()
321335
loss_total.backward()
322-
grad_norm_g = torch.nn.utils.clip_grad_norm_(self.model.cfm.parameters(), 10.0)
323-
grad_norm_g2 = torch.nn.utils.clip_grad_norm_(self.model.length_regulator.parameters(), 10.0)
336+
torch.nn.utils.clip_grad_norm_(self.model.cfm.parameters(), 10.0)
337+
torch.nn.utils.clip_grad_norm_(self.model.length_regulator.parameters(), 10.0)
324338
self.optimizer.step('cfm')
325339
self.optimizer.step('length_regulator')
326340
self.optimizer.scheduler(key='cfm')
327341
self.optimizer.scheduler(key='length_regulator')
328342

329343
return loss.detach().item()
344+
330345
def train_one_epoch(self):
331346
_ = [self.model[key].train() for key in self.model]
332347
for i, batch in enumerate(tqdm(self.train_dataloader)):
333348
batch = [b.to(self.device) for b in batch]
334349
loss = self.train_one_step(batch)
335-
self.ema_loss = self.ema_loss * self.loss_smoothing_rate + loss * (1 - self.loss_smoothing_rate) if self.iters > 0 else loss
350+
self.ema_loss = (
351+
self.ema_loss * self.loss_smoothing_rate + loss * (1 - self.loss_smoothing_rate)
352+
if self.iters > 0 else loss
353+
)
336354
if self.iters % self.log_interval == 0:
337355
print(f"epoch {self.epoch}, step {self.iters}, loss: {self.ema_loss}")
338356
self.iters += 1
357+
339358
if self.iters >= self.max_steps:
340359
break
360+
341361
if self.iters % self.save_interval == 0:
342362
print('Saving..')
343363
state = {
@@ -347,13 +367,15 @@ def train_one_epoch(self):
347367
'iters': self.iters,
348368
'epoch': self.epoch,
349369
}
350-
save_path = os.path.join(self.log_dir, 'DiT_epoch_%05d_step_%05d.pth' % (self.epoch, self.iters))
370+
save_path = os.path.join(
371+
self.log_dir,
372+
f'DiT_epoch_{self.epoch:05d}_step_{self.iters:05d}.pth'
373+
)
351374
torch.save(state, save_path)
352375

353376
# find all checkpoints and remove old ones
354377
checkpoints = glob.glob(os.path.join(self.log_dir, 'DiT_epoch_*.pth'))
355378
if len(checkpoints) > 2:
356-
# sort by step
357379
checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
358380
for cp in checkpoints[:-2]:
359381
os.remove(cp)
@@ -364,15 +386,34 @@ def train(self):
364386
for epoch in range(self.n_epochs):
365387
self.epoch = epoch
366388
self.train_one_epoch()
389+
# Save after each epoch
390+
print('Epoch completed. Saving..')
391+
state = {
392+
'net': {key: self.model[key].state_dict() for key in self.model},
393+
'optimizer': self.optimizer.state_dict(),
394+
'scheduler': self.optimizer.scheduler_state_dict(),
395+
'iters': self.iters,
396+
'epoch': self.epoch,
397+
}
398+
save_path = os.path.join(
399+
self.log_dir,
400+
f'DiT_epoch_{self.epoch:05d}_step_{self.iters:05d}.pth'
401+
)
402+
torch.save(state, save_path)
403+
print(f"Checkpoint saved at {save_path}")
404+
367405
if self.iters >= self.max_steps:
368406
break
369-
print('Saving..')
407+
408+
print('Saving final model..')
370409
state = {
371410
'net': {key: self.model[key].state_dict() for key in self.model},
372411
}
373412
os.makedirs(self.log_dir, exist_ok=True)
374413
save_path = os.path.join(self.log_dir, 'ft_model.pth')
375414
torch.save(state, save_path)
415+
print(f"Final model saved at {save_path}")
416+
376417

377418
def main(args):
378419
trainer = Trainer(
@@ -387,8 +428,12 @@ def main(args):
387428
num_workers=args.num_workers,
388429
)
389430
trainer.train()
390-
431+
391432
if __name__ == '__main__':
433+
if sys.platform == 'win32':
434+
mp.freeze_support()
435+
mp.set_start_method('spawn', force=True)
436+
392437
parser = argparse.ArgumentParser()
393438
parser.add_argument('--config', type=str, default='./configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml')
394439
parser.add_argument('--pretrained-ckpt', type=str, default=None)
@@ -400,4 +445,5 @@ def main(args):
400445
parser.add_argument('--save-every', type=int, default=500)
401446
parser.add_argument('--num-workers', type=int, default=0)
402447
args = parser.parse_args()
403-
main(args)
448+
449+
main(args)

0 commit comments

Comments
 (0)