1212 "max" : 30.0 ,
1313}
1414# assume single speaker
15+ def to_mel_fn (wave , mel_fn_args ):
16+ return mel_spectrogram (wave , ** mel_fn_args )
17+
1518class FT_Dataset (torch .utils .data .Dataset ):
16- def __init__ (self ,
17- data_path ,
18- spect_params ,
19- sr = 22050 ,
20- batch_size = 1 ,
21- ):
19+ def __init__ (
20+ self ,
21+ data_path ,
22+ spect_params ,
23+ sr = 22050 ,
24+ batch_size = 1 ,
25+ ):
2226 self .data_path = data_path
23- # recursively find all files in data_path
2427 self .data = []
2528 for root , _ , files in os .walk (data_path ):
2629 for file in files :
27- if (file .endswith (".wav" ) or
28- file .endswith (".mp3" ) or
29- file .endswith (".flac" ) or
30- file .endswith (".ogg" ) or
31- file .endswith (".m4a" ) or
32- file .endswith (".opus" )):
30+ if file .endswith ((".wav" , ".mp3" , ".flac" , ".ogg" , ".m4a" , ".opus" )):
3331 self .data .append (os .path .join (root , file ))
3432
35- mel_fn_args = {
33+ self .sr = sr
34+ self .mel_fn_args = {
3635 "n_fft" : spect_params ['n_fft' ],
3736 "win_size" : spect_params ['win_length' ],
3837 "hop_size" : spect_params ['hop_length' ],
@@ -42,11 +41,8 @@ def __init__(self,
4241 "fmax" : None if spect_params ['fmax' ] == "None" else spect_params ['fmax' ],
4342 "center" : False
4443 }
45- self .to_mel = lambda x : mel_spectrogram (x , ** mel_fn_args )
46- self .sr = sr
4744
4845 assert len (self .data ) != 0
49- # if dataset length is less than batch size, repeat the dataset
5046 while len (self .data ) < batch_size :
5147 self .data += self .data
5248
@@ -64,17 +60,14 @@ def __getitem__(self, idx):
         if len(speech) < self.sr * duration_setting["min"] or len(speech) > self.sr * duration_setting["max"]:
             print(f"Audio {wav_path} is too short or too long, skipping")
             return self.__getitem__(random.randint(0, len(self)))
-        return_dict = {
-            'audio': speech,
-            'sr': orig_sr
-        }
-        wave, orig_sr = return_dict['audio'], return_dict['sr']
         if orig_sr != self.sr:
-            wave = librosa.resample(wave, orig_sr, self.sr)
-        wave = torch.from_numpy(wave).float()
-        mel = self.to_mel(wave.unsqueeze(0)).squeeze(0)
+            speech = librosa.resample(speech, orig_sr=orig_sr, target_sr=self.sr)  # keyword args for librosa >= 0.10
+
+        wave = torch.from_numpy(speech).float().unsqueeze(0)
+        mel = to_mel_fn(wave, self.mel_fn_args).squeeze(0)
+
+        return wave.squeeze(0), mel
 
-        return wave, mel
 
 def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0):
     dataset = FT_Dataset(data_path, spect_params, sr, batch_size)
@@ -130,4 +123,4 @@ def collate(batch):
         wave, mel, wave_lengths, mel_lengths = batch
         print(wave.shape, mel.shape)
         if idx == 10:
-            break
+            break
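
For context, a minimal, hypothetical smoke test of the updated loader (not part of the diff): the spect_params values and the ./data path are placeholders, and mel_spectrogram may expect additional keys that this hunk does not show. Moving the mel transform from the old self.to_mel lambda to the module-level to_mel_fn also keeps the dataset picklable for spawn-based multi-worker loading.

# Hypothetical usage sketch; all parameter values below are assumptions, not repo defaults.
spect_params = {
    "n_fft": 1024,
    "win_length": 1024,
    "hop_length": 256,
    "fmax": "None",  # the string "None" is handled explicitly in __init__
}
loader = build_ft_dataloader("./data", spect_params, sr=22050, batch_size=2, num_workers=2)
for idx, (wave, mel, wave_lengths, mel_lengths) in enumerate(loader):
    print(wave.shape, mel.shape)  # padded waveform batch and mel batch from collate()
    if idx == 0:
        break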