Commit c7ef271

fix chunk transition bug in app.py, add long-form inference support for inference.py
1 parent 761986a commit c7ef271

2 files changed (+105, -34 lines)

app.py

Lines changed: 3 additions & 2 deletions
@@ -125,9 +125,7 @@ def crossfade(chunk1, chunk2, overlap):
     return chunk2

 # streaming and chunk processing related params
-max_context_window = sr // hop_length * 30
 overlap_frame_len = 16
-overlap_wave_len = overlap_frame_len * hop_length
 bitrate = "320k"

 @torch.no_grad()
@@ -137,6 +135,9 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
     mel_fn = to_mel if not f0_condition else to_mel_f0
     bigvgan_fn = bigvgan_model if not f0_condition else bigvgan_44k_model
     sr = 22050 if not f0_condition else 44100
+    hop_length = 256 if not f0_condition else 512
+    max_context_window = sr // hop_length * 30
+    overlap_wave_len = overlap_frame_len * hop_length
     # Load audio
     source_audio = librosa.load(source, sr=sr)[0]
     ref_audio = librosa.load(target, sr=sr)[0]
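The fix here: max_context_window and overlap_wave_len were previously computed once at module level from a single hop_length, but hop_length (and sr) now depend on f0_condition, so both values are recomputed per call inside voice_conversion. A quick standalone sketch of what the recomputed values come out to (illustrative only, not repo code):

# Illustrative check of the per-call chunking params; mirrors the assignments
# added in voice_conversion above.
for f0_condition in (False, True):
    sr = 22050 if not f0_condition else 44100
    hop_length = 256 if not f0_condition else 512
    max_context_window = sr // hop_length * 30   # 86 frames/s * 30 s = 2580 frames in both cases
    overlap_wave_len = 16 * hop_length           # 4096 samples at 22.05 kHz vs. 8192 at 44.1 kHz
    print(f0_condition, max_context_window, overlap_wave_len)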

inference.py

Lines changed: 102 additions & 32 deletions
@@ -1,4 +1,7 @@
 import os
+
+import numpy as np
+
 os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
 import shutil
 import warnings
@@ -230,9 +233,18 @@ def adjust_f0_semitones(f0_sequence, n_semitones):
     factor = 2 ** (n_semitones / 12)
     return f0_sequence * factor

+def crossfade(chunk1, chunk2, overlap):
+    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+    fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+    if len(chunk2) < overlap:
+        chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+    else:
+        chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+    return chunk2
+
 @torch.no_grad()
 def main(args):
-    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args = load_models(args)
+    model, semantic_fn, f0_fn, vocoder_fn, campplus_model, mel_fn, mel_fn_args = load_models(args)
     sr = mel_fn_args['sampling_rate']
     f0_condition = args.f0_condition
     auto_f0_adjust = args.auto_f0_adjust
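The crossfade() added here mirrors the one already defined in app.py: cos²-shaped fade-out and fade-in weights that sum to 1 at every sample, so blending the head of each new chunk with the tail of the previous one keeps the level constant across the seam. A small sketch of that property (assumes numpy is available; not repo code):

import numpy as np

overlap = 4096  # e.g. overlap_wave_len for hop_length 256
fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
# cos^2(t) + cos^2(pi/2 - t) = cos^2(t) + sin^2(t) = 1, so the weights always sum to 1
assert np.allclose(fade_in + fade_out, 1.0)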
@@ -246,36 +258,62 @@ def main(args):
     source_audio = librosa.load(source, sr=sr)[0]
     ref_audio = librosa.load(target_name, sr=sr)[0]

-    source_audio = source_audio[:sr * 30]
-    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
-
-    ref_audio = ref_audio[:(sr * 30 - source_audio.size(-1))]
-    ref_audio = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
+    sr = 22050 if not f0_condition else 44100
+    hop_length = 256 if not f0_condition else 512
+    max_context_window = sr // hop_length * 30
+    overlap_frame_len = 16
+    overlap_wave_len = overlap_frame_len * hop_length

-    source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
-    ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+    # Process audio
+    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)
+    ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(device)

+    time_vc_start = time.time()
+    # Resample
     converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+    # if source audio less than 30 seconds, whisper can handle in one forward
+    if converted_waves_16k.size(-1) <= 16000 * 30:
+        S_alt = semantic_fn(converted_waves_16k)
+    else:
+        overlapping_time = 5  # 5 seconds
+        S_alt_list = []
+        buffer = None
+        traversed_time = 0
+        while traversed_time < converted_waves_16k.size(-1):
+            if buffer is None:  # first chunk
+                chunk = converted_waves_16k[:, traversed_time:traversed_time + 16000 * 30]
+            else:
+                chunk = torch.cat(
+                    [buffer, converted_waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]],
+                    dim=-1)
+            S_alt = semantic_fn(chunk)
+            if traversed_time == 0:
+                S_alt_list.append(S_alt)
+            else:
+                S_alt_list.append(S_alt[:, 50 * overlapping_time:])
+            buffer = chunk[:, -16000 * overlapping_time:]
+            traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+        S_alt = torch.cat(S_alt_list, dim=1)
+
     ori_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
-    S_alt = semantic_fn(converted_waves_16k)
     S_ori = semantic_fn(ori_waves_16k)

-    mel = to_mel(source_audio.to(device).float())
-    mel2 = to_mel(ref_audio.to(device).float())
+    mel = mel_fn(source_audio.to(device).float())
+    mel2 = mel_fn(ref_audio.to(device).float())

     target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
     target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)

-    feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
+    feat2 = torchaudio.compliance.kaldi.fbank(ori_waves_16k,
                                               num_mel_bins=80,
                                               dither=0,
                                               sample_frequency=16000)
     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
     style2 = campplus_model(feat2.unsqueeze(0))

     if f0_condition:
-        F0_ori = f0_fn(ref_waves_16k[0], thred=0.03)
-        F0_alt = f0_fn(source_waves_16k[0], thred=0.03)
+        F0_ori = f0_fn(ori_waves_16k[0], thred=0.03)
+        F0_alt = f0_fn(converted_waves_16k[0], thred=0.03)

         F0_ori = torch.from_numpy(F0_ori).to(device)[None]
         F0_alt = torch.from_numpy(F0_alt).to(device)[None]
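This is the long-form path from the commit message: sources longer than 30 s are fed to the semantic (Whisper) encoder in 30 s windows that overlap by 5 s, and the first 50 * overlapping_time frames of every later window are dropped before concatenation. The trim assumes the semantic features come out at roughly 50 frames per second; a hypothetical frame count under that assumption (not repo code):

# Hypothetical frame arithmetic for the overlapped pass above, assuming ~50 semantic
# frames per second (the basis of the `50 * overlapping_time` trim).
frames_per_s = 50
window_s, overlap_s = 30, 5
first_chunk_frames = frames_per_s * window_s                             # 1500 frames covering 0-30 s
later_kept_frames = frames_per_s * window_s - frames_per_s * overlap_s   # 1250 frames of new audio per later window
print(first_chunk_frames, later_kept_frames)                             # 1500 1250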
@@ -288,6 +326,7 @@ def main(args):
         voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
         median_log_f0_ori = torch.median(voiced_log_f0_ori)
         median_log_f0_alt = torch.median(voiced_log_f0_alt)
+
         # shift alt log f0 level to ori log f0 level
         shifted_log_f0_alt = log_f0_alt.clone()
         if auto_f0_adjust:
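For context, the medians computed here feed the auto-F0 adjustment: with --auto-f0-adjust enabled, the source's voiced log-F0 contour is shifted so its median matches the reference's, keeping the melody while moving it into the target's register. A simplified sketch of that shift (the unvoiced-frame masking in the real code is not shown in this hunk and is omitted here):

import torch

# Simplified median-matching in the log-F0 domain; illustrative values only.
log_f0_alt = torch.log(torch.tensor([110.0, 130.0, 150.0]))   # source voiced F0 contour (Hz)
median_log_f0_alt = log_f0_alt.median()
median_log_f0_ori = torch.log(torch.tensor(220.0))            # reference median F0 (Hz)
shifted = log_f0_alt - median_log_f0_alt + median_log_f0_ori
print(torch.exp(shifted))   # ~[186.2, 220.0, 253.8] Hz: same contour, target register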
@@ -301,22 +340,53 @@ def main(args):
         shifted_f0_alt = None

     # Length regulation
-    cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-    prompt_condition, _, prompt_codes, commitment_loss, codebook_loss = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
-    cat_condition = torch.cat([prompt_condition, cond], dim=1)
-
-    time_vc_start = time.time()
-    with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
-        vc_target = model.cfm.inference(
-            cat_condition,
-            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-            mel2, style2, None, diffusion_steps,
-            inference_cfg_rate=inference_cfg_rate)
-    vc_target = vc_target[:, :, mel2.size(-1):]
-
-    # Convert to waveform
-    vc_wave = vocoder_fn(vc_target).squeeze()  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
-    vc_wave = vc_wave[None, :].float()
+    cond, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_alt, ylens=target_lengths,
+                                                                            n_quantizers=3,
+                                                                            f0=shifted_f0_alt)
+    prompt_condition, _, codes, commitment_loss, codebook_loss = model.length_regulator(S_ori,
+                                                                                        ylens=target2_lengths,
+                                                                                        n_quantizers=3,
+                                                                                        f0=F0_ori)
+
+    max_source_window = max_context_window - mel2.size(2)
+    # split source condition (cond) into chunks
+    processed_frames = 0
+    generated_wave_chunks = []
+    # generate chunk by chunk and stream the output
+    while processed_frames < cond.size(1):
+        chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+        is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+        cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+        with torch.autocast(device_type=device.type, dtype=torch.float16):
+            # Voice Conversion
+            vc_target = model.cfm.inference(cat_condition,
+                                            torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+                                            mel2, style2, None, diffusion_steps,
+                                            inference_cfg_rate=inference_cfg_rate)
+            vc_target = vc_target[:, :, mel2.size(-1):]
+        vc_wave = vocoder_fn(vc_target).squeeze()
+        vc_wave = vc_wave[None, :]
+        if processed_frames == 0:
+            if is_last_chunk:
+                output_wave = vc_wave[0].cpu().numpy()
+                generated_wave_chunks.append(output_wave)
+                break
+            output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+        elif is_last_chunk:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            processed_frames += vc_target.size(2) - overlap_frame_len
+            break
+        else:
+            output_wave = crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(),
+                                    overlap_wave_len)
+            generated_wave_chunks.append(output_wave)
+            previous_chunk = vc_wave[0, -overlap_wave_len:]
+            processed_frames += vc_target.size(2) - overlap_frame_len
+    vc_wave = torch.tensor(np.concatenate(generated_wave_chunks))[None, :].float()
     time_vc_end = time.time()
     print(f"RTF: {(time_vc_end - time_vc_start) / vc_wave.size(-1) * sr}")
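The streaming loop pairs frame-domain and sample-domain trimming: the vocoder turns each mel frame of vc_target into hop_length waveform samples, so advancing processed_frames by vc_target.size(2) - overlap_frame_len corresponds to keeping all but the last overlap_wave_len samples of each chunk, with that tail crossfaded into the next chunk's head. Rough per-chunk numbers (illustrative, 22.05 kHz / hop 256 branch; chunk_frames is a made-up example value):

# Rough per-chunk accounting for the loop above; not repo code.
hop_length, overlap_frame_len = 256, 16
overlap_wave_len = overlap_frame_len * hop_length            # 4096 samples, ~186 ms at 22.05 kHz
chunk_frames = 2000                                          # mel frames produced for one chunk (example)
emitted_now = chunk_frames * hop_length - overlap_wave_len   # samples appended immediately
frame_advance = chunk_frames - overlap_frame_len             # how far processed_frames moves
print(emitted_now, frame_advance)                            # 507904 1984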

@@ -334,8 +404,8 @@ def main(args):
     parser.add_argument("--diffusion-steps", type=int, default=30)
     parser.add_argument("--length-adjust", type=float, default=1.0)
    parser.add_argument("--inference-cfg-rate", type=float, default=0.7)
-    parser.add_argument("--f0-condition", type=str2bool, default=True)
-    parser.add_argument("--auto-f0-adjust", type=str2bool, default=True)
+    parser.add_argument("--f0-condition", type=str2bool, default=False)
+    parser.add_argument("--auto-f0-adjust", type=str2bool, default=False)
     parser.add_argument("--semi-tone-shift", type=int, default=0)
     parser.add_argument("--checkpoint-path", type=str, help="Path to the checkpoint file", default=None)
     parser.add_argument("--config-path", type=str, help="Path to the config file", default=None)
