Skip to content

Commit e3ed784

Browse files
committed
Added Mac M Series (Apple Silicon) support
1 parent 993feb7 commit e3ed784

File tree

10 files changed

+145
-32
lines changed

10 files changed

+145
-32
lines changed

app.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
from pydub import AudioSegment
1010

1111
# Load model and configuration
12-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12+
13+
if torch.cuda.is_available():
14+
device = torch.device("cuda")
15+
elif torch.backends.mps.is_available():
16+
device = torch.device("mps")
17+
else:
18+
device = torch.device("cpu")
1319

1420
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
1521
"DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
@@ -233,8 +239,12 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
233239
F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.03)
234240
F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03)
235241

236-
F0_ori = torch.from_numpy(F0_ori).to(device)[None]
237-
F0_alt = torch.from_numpy(F0_alt).to(device)[None]
242+
if device == "mps":
243+
F0_ori = torch.from_numpy(F0_ori).float().to(device)[None]
244+
F0_alt = torch.from_numpy(F0_alt).float().to(device)[None]
245+
else:
246+
F0_ori = torch.from_numpy(F0_ori).to(device)[None]
247+
F0_alt = torch.from_numpy(F0_alt).to(device)[None]
238248

239249
voiced_F0_ori = F0_ori[F0_ori > 1]
240250
voiced_F0_alt = F0_alt[F0_alt > 1]

app_svc.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,12 @@ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_c
294294
F0_ori = f0_fn(ref_waves_16k[0], thred=0.03)
295295
F0_alt = f0_fn(converted_waves_16k[0], thred=0.03)
296296

297-
F0_ori = torch.from_numpy(F0_ori).to(device)[None]
298-
F0_alt = torch.from_numpy(F0_alt).to(device)[None]
297+
if device.type == "mps":
298+
F0_ori = torch.from_numpy(F0_ori).float().to(device)[None]
299+
F0_alt = torch.from_numpy(F0_alt).float().to(device)[None]
300+
else:
301+
F0_ori = torch.from_numpy(F0_ori).to(device)[None]
302+
F0_alt = torch.from_numpy(F0_alt).to(device)[None]
299303

300304
voiced_F0_ori = F0_ori[F0_ori > 1]
301305
voiced_F0_alt = F0_alt[F0_alt > 1]
@@ -436,5 +440,11 @@ def main(args):
436440
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
437441
args = parser.parse_args()
438442
cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda"
439-
device = torch.device(cuda_target if torch.cuda.is_available() else "cpu")
443+
444+
if torch.cuda.is_available():
445+
device = torch.device(cuda_target)
446+
elif torch.backends.mps.is_available():
447+
device = torch.device("mps")
448+
else:
449+
device = torch.device("cpu")
440450
main(args)

app_vc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,5 +389,11 @@ def main(args):
389389
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
390390
args = parser.parse_args()
391391
cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda"
392-
device = torch.device(cuda_target if torch.cuda.is_available() else "cpu")
392+
393+
if torch.cuda.is_available():
394+
device = torch.device(cuda_target)
395+
elif torch.backends.mps.is_available():
396+
device = torch.device("mps")
397+
else:
398+
device = torch.device("cpu")
393399
main(args)

dac/utils/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def decode(
4343
model_bitrate: str
4444
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
4545
device : str, optional
46-
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
46+
Device to use, by default "cuda". Use "mps" on Apple Silicon, or "cpu" to load the model on the CPU.
4747
model_type : str, optional
4848
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
4949
"""

dac/utils/encode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def encode(
4747
n_quantizers : int, optional
4848
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
4949
device : str, optional
50-
Device to use, by default "cuda"
50+
Device to use, by default "cuda". Use "mps" on Apple Silicon devices.
5151
model_type : str, optional
5252
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
5353
"""

eval.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@
2323
from resemblyzer import preprocess_wav, VoiceEncoder
2424

2525
# Load model and configuration
26-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26+
27+
if torch.cuda.is_available():
28+
device = torch.device("cuda")
29+
elif torch.backends.mps.is_available():
30+
device = torch.device("mps")
31+
else:
32+
device = torch.device("cpu")
2733

2834
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
2935
from transformers import Wav2Vec2Processor, HubertForCTC

inference.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,14 @@
2525

2626

2727
# Load model and configuration
28-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29+
if torch.cuda.is_available():
30+
device = torch.device("cuda")
31+
elif torch.backends.mps.is_available():
32+
device = torch.device("mps")
33+
else:
34+
device = torch.device("cpu")
35+
2936
fp16 = False
3037
def load_models(args):
3138
global fp16

modules/rmvpe.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,13 @@ def __init__(self, model_path: str, is_half, device=None, use_jit=False):
486486
self.resample_kernel = {}
487487
self.is_half = is_half
488488
if device is None:
489-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
489+
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
490+
if torch.cuda.is_available():
491+
device = "cuda:0"
492+
elif torch.backends.mps.is_available():
493+
device = "mps"
494+
else:
495+
device = "cpu"
490496
self.device = device
491497
self.mel_extractor = MelSpectrogram(
492498
is_half, 128, 16000, 1024, 160, None, 30, 8000

real-time-gui.py

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,22 @@ def custom_infer(model_set,
9494
reference_wav_name = new_reference_wav_name
9595

9696
converted_waves_16k = input_wav_res
97-
start_event = torch.cuda.Event(enable_timing=True)
98-
end_event = torch.cuda.Event(enable_timing=True)
99-
torch.cuda.synchronize()
97+
if device.type == "mps":
98+
start_event = torch.mps.event.Event(enable_timing=True)
99+
end_event = torch.mps.event.Event(enable_timing=True)
100+
torch.mps.synchronize()
101+
else:
102+
start_event = torch.cuda.Event(enable_timing=True)
103+
end_event = torch.cuda.Event(enable_timing=True)
104+
torch.cuda.synchronize()
105+
100106
start_event.record()
101107
S_alt = semantic_fn(converted_waves_16k.unsqueeze(0))
102108
end_event.record()
103-
torch.cuda.synchronize() # Wait for the events to be recorded!
109+
if device.type == "mps":
110+
torch.mps.synchronize() # MPS - Wait for the events to be recorded!
111+
else:
112+
torch.cuda.synchronize() # Wait for the events to be recorded!
104113
elapsed_time_ms = start_event.elapsed_time(end_event)
105114
print(f"Time taken for semantic_fn: {elapsed_time_ms}ms")
106115

@@ -466,7 +475,14 @@ def launcher(self):
466475
initial_folder=os.path.join(
467476
os.getcwd(), "examples/reference"
468477
),
469-
file_types=((". wav"), (". mp3"), (". flac"), (". m4a"), (". ogg"), (". opus")),
478+
file_types=[
479+
("WAV Files", "*.wav"),
480+
("MP3 Files", "*.mp3"),
481+
("FLAC Files", "*.flac"),
482+
("M4A Files", "*.m4a"),
483+
("OGG Files", "*.ogg"),
484+
("Opus Files", "*.opus"),
485+
],
470486
),
471487
],
472488
],
@@ -786,7 +802,10 @@ def set_values(self, values):
786802
return True
787803

788804
def start_vc(self):
789-
torch.cuda.empty_cache()
805+
if device.type == "mps":
806+
torch.mps.empty_cache()
807+
else:
808+
torch.cuda.empty_cache()
790809
self.reference_wav, _ = librosa.load(
791810
self.gui_config.reference_audio_path, sr=self.model_set[-1]["sampling_rate"]
792811
)
@@ -942,9 +961,14 @@ def audio_callback(
942961
indata = librosa.to_mono(indata.T)
943962

944963
# VAD first
945-
start_event = torch.cuda.Event(enable_timing=True)
946-
end_event = torch.cuda.Event(enable_timing=True)
947-
torch.cuda.synchronize()
964+
if device.type == "mps":
965+
start_event = torch.mps.event.Event(enable_timing=True)
966+
end_event = torch.mps.event.Event(enable_timing=True)
967+
torch.mps.synchronize()
968+
else:
969+
start_event = torch.cuda.Event(enable_timing=True)
970+
end_event = torch.cuda.Event(enable_timing=True)
971+
torch.cuda.synchronize()
948972
start_event.record()
949973
indata_16k = librosa.resample(indata, orig_sr=self.gui_config.samplerate, target_sr=16000)
950974
res = self.vad_model.generate(input=indata_16k, cache=self.vad_cache, is_final=False, chunk_size=self.vad_chunk_size)
@@ -955,7 +979,10 @@ def audio_callback(
955979
elif len(res_value) % 2 == 1 and self.vad_speech_detected:
956980
self.set_speech_detected_false_at_end_flag = True
957981
end_event.record()
958-
torch.cuda.synchronize() # Wait for the events to be recorded!
982+
if device.type == "mps":
983+
torch.mps.synchronize() # MPS - Wait for the events to be recorded!
984+
else:
985+
torch.cuda.synchronize() # Wait for the events to be recorded!
959986
elapsed_time_ms = start_event.elapsed_time(end_event)
960987
print(f"Time taken for VAD: {elapsed_time_ms}ms")
961988

@@ -993,9 +1020,14 @@ def audio_callback(
9931020
if self.function == "vc":
9941021
if self.gui_config.extra_time_ce - self.gui_config.extra_time < 0:
9951022
raise ValueError("Content encoder extra context must be greater than DiT extra context!")
996-
start_event = torch.cuda.Event(enable_timing=True)
997-
end_event = torch.cuda.Event(enable_timing=True)
998-
torch.cuda.synchronize()
1023+
if device.type == "mps":
1024+
start_event = torch.mps.event.Event(enable_timing=True)
1025+
end_event = torch.mps.event.Event(enable_timing=True)
1026+
torch.mps.synchronize()
1027+
else:
1028+
start_event = torch.cuda.Event(enable_timing=True)
1029+
end_event = torch.cuda.Event(enable_timing=True)
1030+
torch.cuda.synchronize()
9991031
start_event.record()
10001032
infer_wav = custom_infer(
10011033
self.model_set,
@@ -1014,7 +1046,10 @@ def audio_callback(
10141046
if self.resampler2 is not None:
10151047
infer_wav = self.resampler2(infer_wav)
10161048
end_event.record()
1017-
torch.cuda.synchronize() # Wait for the events to be recorded!
1049+
if device.type == "mps":
1050+
torch.mps.synchronize() # MPS - Wait for the events to be recorded!
1051+
else:
1052+
torch.cuda.synchronize() # Wait for the events to be recorded!
10181053
elapsed_time_ms = start_event.elapsed_time(end_event)
10191054
print(f"Time taken for VC: {elapsed_time_ms}ms")
10201055
if not self.vad_speech_detected:
@@ -1037,12 +1072,16 @@ def audio_callback(
10371072
)
10381073
+ 1e-8
10391074
)
1040-
if sys.platform == "darwin":
1041-
_, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0])
1042-
sola_offset = sola_offset.item()
1043-
else:
10441075

1045-
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
1076+
tensor = cor_nom[0, 0] / cor_den[0, 0]
1077+
if tensor.numel() > 1: # If tensor has multiple elements
1078+
if sys.platform == "darwin":
1079+
_, sola_offset = torch.max(tensor, dim=0)
1080+
sola_offset = sola_offset.item()
1081+
else:
1082+
sola_offset = torch.argmax(tensor, dim=0).item()
1083+
else:
1084+
# A single candidate can only be at index 0; tensor.item() would return the
# correlation value itself rather than an offset.
sola_offset = 0
10461085

10471086
print(f"sola_offset = {int(sola_offset)}")
10481087

@@ -1141,5 +1180,11 @@ def get_device_channels(self):
11411180
parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0)
11421181
args = parser.parse_args()
11431182
cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda"
1144-
device = torch.device(cuda_target if torch.cuda.is_available() else "cpu")
1183+
1184+
if torch.cuda.is_available():
1185+
device = torch.device(cuda_target)
1186+
elif torch.backends.mps.is_available():
1187+
device = torch.device("mps")
1188+
else:
1189+
device = torch.device("cpu")
11451190
gui = GUI(args)

requirements-mac.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# NOTE(review): the cu121 index below serves CUDA builds; presumably it is not
# needed on macOS - confirm and drop if so.
# NOTE(review): pip documents --pre and --extra-index-url as global options
# that belong on their own lines; written after a package name they may not
# take effect - verify that the nightly index is actually being used.
--extra-index-url https://download.pytorch.org/whl/cu121
2+
torch --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
3+
torchvision --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
4+
torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
5+
scipy==1.13.1
6+
librosa==0.10.2
7+
huggingface-hub==0.23.4
8+
munch==4.0.0
9+
einops==0.8.0
10+
descript-audio-codec==1.0.0
11+
gradio==4.44.0
12+
pydub==0.25.1
13+
resemblyzer
14+
jiwer==3.0.3
15+
transformers==4.46.3
16+
FreeSimpleGUI==5.1.1
17+
soundfile==0.12.1
18+
sounddevice==0.5.0
19+
modelscope==1.18.1
20+
funasr==1.1.5
21+
numpy==1.26.4
22+
pyyaml
23+
python-dotenv

0 commit comments

Comments
 (0)