@@ -94,13 +94,22 @@ def custom_infer(model_set,
9494 reference_wav_name = new_reference_wav_name
9595
9696 converted_waves_16k = input_wav_res
97- start_event = torch .cuda .Event (enable_timing = True )
98- end_event = torch .cuda .Event (enable_timing = True )
99- torch .cuda .synchronize ()
97+ if device .type == "mps" :
98+ start_event = torch .mps .event .Event (enable_timing = True )
99+ end_event = torch .mps .event .Event (enable_timing = True )
100+ torch .mps .synchronize ()
101+ else :
102+ start_event = torch .cuda .Event (enable_timing = True )
103+ end_event = torch .cuda .Event (enable_timing = True )
104+ torch .cuda .synchronize ()
105+
100106 start_event .record ()
101107 S_alt = semantic_fn (converted_waves_16k .unsqueeze (0 ))
102108 end_event .record ()
103- torch .cuda .synchronize () # Wait for the events to be recorded!
109+ if device .type == "mps" :
110+ torch .mps .synchronize () # MPS - Wait for the events to be recorded!
111+ else :
112+ torch .cuda .synchronize () # Wait for the events to be recorded!
104113 elapsed_time_ms = start_event .elapsed_time (end_event )
105114 print (f"Time taken for semantic_fn: { elapsed_time_ms } ms" )
106115
@@ -466,7 +475,14 @@ def launcher(self):
466475 initial_folder = os .path .join (
467476 os .getcwd (), "examples/reference"
468477 ),
469- file_types = ((". wav" ), (". mp3" ), (". flac" ), (". m4a" ), (". ogg" ), (". opus" )),
478+ file_types = [
479+ ("WAV Files" , "*.wav" ),
480+ ("MP3 Files" , "*.mp3" ),
481+ ("FLAC Files" , "*.flac" ),
482+ ("M4A Files" , "*.m4a" ),
483+ ("OGG Files" , "*.ogg" ),
484+ ("Opus Files" , "*.opus" ),
485+ ],
470486 ),
471487 ],
472488 ],
@@ -786,7 +802,10 @@ def set_values(self, values):
786802 return True
787803
788804 def start_vc (self ):
789- torch .cuda .empty_cache ()
805+ if device .type == "mps" :
806+ torch .mps .empty_cache ()
807+ else :
808+ torch .cuda .empty_cache ()
790809 self .reference_wav , _ = librosa .load (
791810 self .gui_config .reference_audio_path , sr = self .model_set [- 1 ]["sampling_rate" ]
792811 )
@@ -942,9 +961,14 @@ def audio_callback(
942961 indata = librosa .to_mono (indata .T )
943962
944963 # VAD first
945- start_event = torch .cuda .Event (enable_timing = True )
946- end_event = torch .cuda .Event (enable_timing = True )
947- torch .cuda .synchronize ()
964+ if device .type == "mps" :
965+ start_event = torch .mps .event .Event (enable_timing = True )
966+ end_event = torch .mps .event .Event (enable_timing = True )
967+ torch .mps .synchronize ()
968+ else :
969+ start_event = torch .cuda .Event (enable_timing = True )
970+ end_event = torch .cuda .Event (enable_timing = True )
971+ torch .cuda .synchronize ()
948972 start_event .record ()
949973 indata_16k = librosa .resample (indata , orig_sr = self .gui_config .samplerate , target_sr = 16000 )
950974 res = self .vad_model .generate (input = indata_16k , cache = self .vad_cache , is_final = False , chunk_size = self .vad_chunk_size )
@@ -955,7 +979,10 @@ def audio_callback(
955979 elif len (res_value ) % 2 == 1 and self .vad_speech_detected :
956980 self .set_speech_detected_false_at_end_flag = True
957981 end_event .record ()
958- torch .cuda .synchronize () # Wait for the events to be recorded!
982+ if device .type == "mps" :
983+ torch .mps .synchronize () # MPS - Wait for the events to be recorded!
984+ else :
985+ torch .cuda .synchronize () # Wait for the events to be recorded!
959986 elapsed_time_ms = start_event .elapsed_time (end_event )
960987 print (f"Time taken for VAD: { elapsed_time_ms } ms" )
961988
@@ -993,9 +1020,14 @@ def audio_callback(
9931020 if self .function == "vc" :
9941021 if self .gui_config .extra_time_ce - self .gui_config .extra_time < 0 :
9951022 raise ValueError ("Content encoder extra context must be greater than DiT extra context!" )
996- start_event = torch .cuda .Event (enable_timing = True )
997- end_event = torch .cuda .Event (enable_timing = True )
998- torch .cuda .synchronize ()
1023+ if device .type == "mps" :
1024+ start_event = torch .mps .event .Event (enable_timing = True )
1025+ end_event = torch .mps .event .Event (enable_timing = True )
1026+ torch .mps .synchronize ()
1027+ else :
1028+ start_event = torch .cuda .Event (enable_timing = True )
1029+ end_event = torch .cuda .Event (enable_timing = True )
1030+ torch .cuda .synchronize ()
9991031 start_event .record ()
10001032 infer_wav = custom_infer (
10011033 self .model_set ,
@@ -1014,7 +1046,10 @@ def audio_callback(
10141046 if self .resampler2 is not None :
10151047 infer_wav = self .resampler2 (infer_wav )
10161048 end_event .record ()
1017- torch .cuda .synchronize () # Wait for the events to be recorded!
1049+ if device .type == "mps" :
1050+ torch .mps .synchronize () # MPS - Wait for the events to be recorded!
1051+ else :
1052+ torch .cuda .synchronize () # Wait for the events to be recorded!
10181053 elapsed_time_ms = start_event .elapsed_time (end_event )
10191054 print (f"Time taken for VC: { elapsed_time_ms } ms" )
10201055 if not self .vad_speech_detected :
@@ -1037,12 +1072,16 @@ def audio_callback(
10371072 )
10381073 + 1e-8
10391074 )
1040- if sys .platform == "darwin" :
1041- _ , sola_offset = torch .max (cor_nom [0 , 0 ] / cor_den [0 , 0 ])
1042- sola_offset = sola_offset .item ()
1043- else :
10441075
1045- sola_offset = torch .argmax (cor_nom [0 , 0 ] / cor_den [0 , 0 ])
1076+ tensor = cor_nom [0 , 0 ] / cor_den [0 , 0 ]
1077+ if tensor .numel () > 1 : # If tensor has multiple elements
1078+ if sys .platform == "darwin" :
1079+ _ , sola_offset = torch .max (tensor , dim = 0 )
1080+ sola_offset = sola_offset .item ()
1081+ else :
1082+ sola_offset = torch .argmax (tensor , dim = 0 ).item ()
1083+ else :
1084+ sola_offset = tensor .item ()
10461085
10471086 print (f"sola_offset = { int (sola_offset )} " )
10481087
@@ -1141,5 +1180,11 @@ def get_device_channels(self):
11411180 parser .add_argument ("--gpu" , type = int , help = "Which GPU id to use" , default = 0 )
11421181 args = parser .parse_args ()
11431182 cuda_target = f"cuda:{ args .gpu } " if args .gpu else "cuda"
1144- device = torch .device (cuda_target if torch .cuda .is_available () else "cpu" )
1183+
1184+ if torch .cuda .is_available ():
1185+ device = torch .device (cuda_target )
1186+ elif torch .backends .mps .is_available ():
1187+ device = torch .device ("mps" )
1188+ else :
1189+ device = torch .device ("cpu" )
11451190 gui = GUI (args )
0 commit comments