Skip to content

Commit 9490f44

Browse files
committed
fix fp16 for inference.py
1 parent 41104fc commit 9490f44

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,15 @@ def main(args):
357357
chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
358358
is_last_chunk = processed_frames + max_source_window >= cond.size(1)
359359
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
360-
with torch.autocast(device_type=device.type, dtype=torch.float16):
360+
with torch.autocast(device_type=device.type, dtype=torch.float16 if fp16 else torch.float32):
361361
# Voice Conversion
362362
vc_target = model.cfm.inference(cat_condition,
363363
torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
364364
mel2, style2, None, diffusion_steps,
365365
inference_cfg_rate=inference_cfg_rate)
366366
vc_target = vc_target[:, :, mel2.size(-1):]
367-
vc_wave = vocoder_fn(vc_target).squeeze()
368-
vc_wave = vc_wave[None, :]
367+
vc_wave = vocoder_fn(vc_target.float()).squeeze()
368+
vc_wave = vc_wave[None, :]
369369
if processed_frames == 0:
370370
if is_last_chunk:
371371
output_wave = vc_wave[0].cpu().numpy()

0 commit comments

Comments
 (0)