Commit fab344c

Fix Inference (Plachtaa#2)
- load campplus from huggingface
- load tokenizer dynamically
1 parent 74a535f commit fab344c
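For context on the first bullet: the inference.py hunk below replaces the config-path lookup for the CAMPPlus checkpoint with a Hugging Face Hub download via load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None). The helper's body is not part of this diff; the sketch below assumes it is a thin wrapper around huggingface_hub.hf_hub_download and is only meant to show the intent of the change.

import torch

from huggingface_hub import hf_hub_download

from modules.campplus.DTDNN import CAMPPlus  # import path as shown in the diff

# Fetch the CAMPPlus speaker-embedding checkpoint from the funasr/campplus repo
# (assumed equivalent of load_custom_model_from_hf in the hunk below).
campplus_ckpt_path = hf_hub_download(repo_id="funasr/campplus",
                                     filename="campplus_cn_common.bin")

campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
campplus_model.eval()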

2 files changed: 109 additions & 28 deletions


inference.py

Lines changed: 39 additions & 28 deletions
@@ -45,8 +45,9 @@
 # Load additional modules
 from modules.campplus.DTDNN import CAMPPlus
 
+campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
 campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
-campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path'], map_location='cpu'))
+campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
 campplus_model.eval()
 campplus_model.to(device)
 
@@ -103,6 +104,7 @@ def main(args):
     diffusion_steps = args.diffusion_steps
     length_adjust = args.length_adjust
     inference_cfg_rate = args.inference_cfg_rate
+    n_quantizers = args.n_quantizers
     source_audio = librosa.load(source, sr=sr)[0]
     ref_audio = librosa.load(target_name, sr=sr)[0]
     # decoded_wav = encodec_model.decoder(encodec_latent)
@@ -117,43 +119,53 @@ def main(args):
     source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
     ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
 
-    S_alt = [
-        cosyvoice_frontend.extract_speech_token(source_waves_16k, )
-    ]
-    S_alt_lens = torch.LongTensor([s[1] for s in S_alt]).to(device)
-    S_alt = torch.cat([torch.nn.functional.pad(s[0], (0, max(S_alt_lens) - s[0].size(1))) for s in S_alt], dim=0)
-
-    S_ori = [
-        cosyvoice_frontend.extract_speech_token(ref_waves_16k, )
-    ]
-    S_ori_lens = torch.LongTensor([s[1] for s in S_ori]).to(device)
-    S_ori = torch.cat([torch.nn.functional.pad(s[0], (0, max(S_ori_lens) - s[0].size(1))) for s in S_ori], dim=0)
+    if speech_tokenizer_type == "cosyvoice":
+        S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
+        S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
+    elif speech_tokenizer_type == "facodec":
+        converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
+        wave_lengths_24k = torch.LongTensor([converted_waves_24k.size(1)]).to(converted_waves_24k.device)
+        waves_input = converted_waves_24k.unsqueeze(1)
+        z = codec_encoder.encoder(waves_input)
+        (quantized, codes) = codec_encoder.quantizer(z, waves_input)
+        S_alt = torch.cat([codes[1], codes[0]], dim=1)
+
+        # S_ori should be extracted in the same way
+        waves_24k = torchaudio.functional.resample(ref_audio, sr, 24000)
+        waves_input = waves_24k.unsqueeze(1)
+        z = codec_encoder.encoder(waves_input)
+        (quantized, codes) = codec_encoder.quantizer(z, waves_input)
+        S_ori = torch.cat([codes[1], codes[0]], dim=1)
 
     mel = to_mel(source_audio.to(device).float())
     mel2 = to_mel(ref_audio.to(device).float())
 
-    target = mel
-    target2 = mel2
-
-    target_lengths = torch.LongTensor([int(target.size(2) * length_adjust)]).to(target.device)
-    target2_lengths = torch.LongTensor([target2.size(2)]).to(target2.device)
+    target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
+    target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
 
-    feat2 = kaldi.fbank(ref_waves_16k,
-                        num_mel_bins=80,
-                        dither=0,
-                        sample_frequency=16000)
+    feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
+                                              num_mel_bins=80,
+                                              dither=0,
+                                              sample_frequency=16000)
     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
     style2 = campplus_model(feat2.unsqueeze(0))
 
-    cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
-    prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
+    # Length regulation
+    cond = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=int(n_quantizers))[0]
+    prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=int(n_quantizers))[0]
     cat_condition = torch.cat([prompt_condition, cond], dim=1)
-    prompt_target = target2
 
     time_vc_start = time.time()
-    vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(prompt_target.device), prompt_target, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
-    vc_target = vc_target[:, :, prompt_target.size(-1):]
+    vc_target = model.cfm.inference(
+        cat_condition,
+        torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+        mel2, style2, None, diffusion_steps,
+        inference_cfg_rate=inference_cfg_rate)
+    vc_target = vc_target[:, :, mel2.size(-1):]
+
+    # Convert to waveform
     vc_wave = hift_gen.inference(vc_target)
+
    time_vc_end = time.time()
     print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")
 
@@ -163,11 +175,10 @@ def main(args):
     torchaudio.save(os.path.join(args.output, f"vc_{source_name}_{target_name}_{length_adjust}_{diffusion_steps}_{inference_cfg_rate}.wav"), vc_wave.cpu(), sr)
 
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--source", type=str, default="./examples/source/source_s1.wav")
-    parser.add_argument("--target", type=str, default="./examples/target/s1p1.wav")
+    parser.add_argument("--target", type=str, default="./examples/reference/s1p1.wav")
     parser.add_argument("--output", type=str, default="./reconstructed")
     parser.add_argument("--diffusion-steps", type=int, default=10)
     parser.add_argument("--length-adjust", type=float, default=1.0)
ruff.toml

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+]
+extend-exclude = []
+line-length = 88
+indent-width = 4
+target-version = "py310"
+show-fixes = true
+src = [".", "modules"]
+
+[lint]
+select = [
+    "E", "F", "B", "Q", "I", "C90", "N", "D", "UP", "YTT", "ANN", "S", "BLE",
+    "FBT", "A", "COM", "C4", "DTZ", "T10", "EM", "EXE", "ISC", "ICN", "INP",
+    "PIE", "T20", "PT", "Q", "RET", "SIM", "ARG", "ERA", "PD", "PGH", "PL",
+    "TRY", "RUF",
+]
+ignore = [
+    "D105",
+    "D107",
+    "D203",
+    "D213",
+    "S101", # assert-used
+    "INP001", # implicit-namespace-package
+    "ANN101", # missing-type-self
+    "ANN102", # missing-type-cls
+    "ANN204", # missing-return-type-special-method
+    "ERA001", # commented-out-code
+    "ANN002", # missing-type-args
+    "ANN003", # missing-type-kwargs
+    "RET504", # unnecessary-assign
+    "COM812", # TBD: some conflict
+    "ISC001", # TBD: some conflict
+]
+fixable = ["ALL"]
+unfixable = []
+
+[format]
+quote-style = "double"
+indent-style = "space"
+
+[lint.isort]
+# force-sort-within-sections and lines-between-types should be incompatible
+force-sort-within-sections = false
+lines-between-types = 1
+force-single-line = true
+no-sections = false
+from-first = false
+
+[lint.pydocstyle]
+convention = "google"
0 commit comments