|
| 1 | +import scipy.io.wavfile |
| 2 | +from os.path import expanduser |
| 3 | +import os |
| 4 | +import array |
| 5 | +from pylab import * |
| 6 | +import scipy.signal |
| 7 | +import scipy |
| 8 | +import wave |
| 9 | +import numpy as np |
| 10 | +import time |
| 11 | +import sys |
| 12 | +import math |
| 13 | +import matplotlib |
| 14 | +import subprocess |
| 15 | + |
| 16 | +# Author: Brian K. Vogel |
| 17 | + |
| 18 | + |
| 19 | +fft_size = 2048 |
| 20 | +iterations = 300 |
| 21 | +hopsamp = fft_size // 8 |
| 22 | + |
| 23 | + |
| 24 | +def ensure_audio(): |
| 25 | + if not os.path.exists("audio"): |
| 26 | + print("Downloading audio dataset...") |
| 27 | + subprocess.check_output( |
| 28 | + "curl -SL https://storage.googleapis.com/wandb/audio.tar.gz | tar xz", shell=True) |
| 29 | + |
| 30 | + |
| 31 | +def griffin_lim(stft, scale): |
| 32 | + # Undo the rescaling. |
| 33 | + stft_modified_scaled = stft / scale |
| 34 | + stft_modified_scaled = stft_modified_scaled**0.5 |
| 35 | + # Use the Griffin&Lim algorithm to reconstruct an audio signal from the |
| 36 | + # magnitude spectrogram. |
| 37 | + x_reconstruct = reconstruct_signal_griffin_lim(stft_modified_scaled, |
| 38 | + fft_size, hopsamp, |
| 39 | + iterations) |
| 40 | + # The output signal must be in the range [-1, 1], otherwise we need to clip or normalize. |
| 41 | + max_sample = np.max(abs(x_reconstruct)) |
| 42 | + if max_sample > 1.0: |
| 43 | + x_reconstruct = x_reconstruct / max_sample |
| 44 | + return x_reconstruct |
| 45 | + |
| 46 | + |
| 47 | +def hz_to_mel(f_hz): |
| 48 | + """Convert Hz to mel scale. |
| 49 | +
|
| 50 | + This uses the formula from O'Shaugnessy's book. |
| 51 | + Args: |
| 52 | + f_hz (float): The value in Hz. |
| 53 | +
|
| 54 | + Returns: |
| 55 | + The value in mels. |
| 56 | + """ |
| 57 | + return 2595*np.log10(1.0 + f_hz/700.0) |
| 58 | + |
| 59 | + |
| 60 | +def mel_to_hz(m_mel): |
| 61 | + """Convert mel scale to Hz. |
| 62 | +
|
| 63 | + This uses the formula from O'Shaugnessy's book. |
| 64 | + Args: |
| 65 | + m_mel (float): The value in mels |
| 66 | +
|
| 67 | + Returns: |
| 68 | + The value in Hz |
| 69 | + """ |
| 70 | + return 700*(10**(m_mel/2595) - 1.0) |
| 71 | + |
| 72 | + |
| 73 | +def fft_bin_to_hz(n_bin, sample_rate_hz, fft_size): |
| 74 | + """Convert FFT bin index to frequency in Hz. |
| 75 | +
|
| 76 | + Args: |
| 77 | + n_bin (int or float): The FFT bin index. |
| 78 | + sample_rate_hz (int or float): The sample rate in Hz. |
| 79 | + fft_size (int or float): The FFT size. |
| 80 | +
|
| 81 | + Returns: |
| 82 | + The value in Hz. |
| 83 | + """ |
| 84 | + n_bin = float(n_bin) |
| 85 | + sample_rate_hz = float(sample_rate_hz) |
| 86 | + fft_size = float(fft_size) |
| 87 | + return n_bin*sample_rate_hz/(2.0*fft_size) |
| 88 | + |
| 89 | + |
| 90 | +def hz_to_fft_bin(f_hz, sample_rate_hz, fft_size): |
| 91 | + """Convert frequency in Hz to FFT bin index. |
| 92 | +
|
| 93 | + Args: |
| 94 | + f_hz (int or float): The frequency in Hz. |
| 95 | + sample_rate_hz (int or float): The sample rate in Hz. |
| 96 | + fft_size (int or float): The FFT size. |
| 97 | +
|
| 98 | + Returns: |
| 99 | + The FFT bin index as an int. |
| 100 | + """ |
| 101 | + f_hz = float(f_hz) |
| 102 | + sample_rate_hz = float(sample_rate_hz) |
| 103 | + fft_size = float(fft_size) |
| 104 | + fft_bin = int(np.round((f_hz*2.0*fft_size/sample_rate_hz))) |
| 105 | + if fft_bin >= fft_size: |
| 106 | + fft_bin = fft_size-1 |
| 107 | + return fft_bin |
| 108 | + |
| 109 | + |
| 110 | +def make_mel_filterbank(min_freq_hz, max_freq_hz, mel_bin_count, |
| 111 | + linear_bin_count, sample_rate_hz): |
| 112 | + """Create a mel filterbank matrix. |
| 113 | +
|
| 114 | + Create and return a mel filterbank matrix `filterbank` of shape (`mel_bin_count`, |
| 115 | + `linear_bin_couont`). The `filterbank` matrix can be used to transform a |
| 116 | + (linear scale) spectrum or spectrogram into a mel scale spectrum or |
| 117 | + spectrogram as follows: |
| 118 | +
|
| 119 | + `mel_scale_spectrum` = `filterbank`*'linear_scale_spectrum' |
| 120 | +
|
| 121 | + where linear_scale_spectrum' is a shape (`linear_bin_count`, `m`) and |
| 122 | + `mel_scale_spectrum` is shape ('mel_bin_count', `m`) where `m` is the number |
| 123 | + of spectral time slices. |
| 124 | +
|
| 125 | + Likewise, the reverse-direction transform can be performed as: |
| 126 | +
|
| 127 | + 'linear_scale_spectrum' = filterbank.T`*`mel_scale_spectrum` |
| 128 | +
|
| 129 | + Note that the process of converting to mel scale and then back to linear |
| 130 | + scale is lossy. |
| 131 | +
|
| 132 | + This function computes the mel-spaced filters such that each filter is triangular |
| 133 | + (in linear frequency) with response 1 at the center frequency and decreases linearly |
| 134 | + to 0 upon reaching an adjacent filter's center frequency. Note that any two adjacent |
| 135 | + filters will overlap having a response of 0.5 at the mean frequency of their |
| 136 | + respective center frequencies. |
| 137 | +
|
| 138 | + Args: |
| 139 | + min_freq_hz (float): The frequency in Hz corresponding to the lowest |
| 140 | + mel scale bin. |
| 141 | + max_freq_hz (flloat): The frequency in Hz corresponding to the highest |
| 142 | + mel scale bin. |
| 143 | + mel_bin_count (int): The number of mel scale bins. |
| 144 | + linear_bin_count (int): The number of linear scale (fft) bins. |
| 145 | + sample_rate_hz (float): The sample rate in Hz. |
| 146 | +
|
| 147 | + Returns: |
| 148 | + The mel filterbank matrix as an 2-dim Numpy array. |
| 149 | + """ |
| 150 | + min_mels = hz_to_mel(min_freq_hz) |
| 151 | + max_mels = hz_to_mel(max_freq_hz) |
| 152 | + # Create mel_bin_count linearly spaced values between these extreme mel values. |
| 153 | + mel_lin_spaced = np.linspace(min_mels, max_mels, num=mel_bin_count) |
| 154 | + # Map each of these mel values back into linear frequency (Hz). |
| 155 | + center_frequencies_hz = np.array([mel_to_hz(n) for n in mel_lin_spaced]) |
| 156 | + mels_per_bin = float(max_mels - min_mels)/float(mel_bin_count - 1) |
| 157 | + mels_start = min_mels - mels_per_bin |
| 158 | + hz_start = mel_to_hz(mels_start) |
| 159 | + fft_bin_start = hz_to_fft_bin(hz_start, sample_rate_hz, linear_bin_count) |
| 160 | + #print('fft_bin_start: ', fft_bin_start) |
| 161 | + mels_end = max_mels + mels_per_bin |
| 162 | + hz_stop = mel_to_hz(mels_end) |
| 163 | + fft_bin_stop = hz_to_fft_bin(hz_stop, sample_rate_hz, linear_bin_count) |
| 164 | + #print('fft_bin_stop: ', fft_bin_stop) |
| 165 | + # Map each center frequency to the closest fft bin index. |
| 166 | + linear_bin_indices = np.array([hz_to_fft_bin( |
| 167 | + f_hz, sample_rate_hz, linear_bin_count) for f_hz in center_frequencies_hz]) |
| 168 | + # Create filterbank matrix. |
| 169 | + filterbank = np.zeros((mel_bin_count, linear_bin_count)) |
| 170 | + for mel_bin in range(mel_bin_count): |
| 171 | + center_freq_linear_bin = int(linear_bin_indices[mel_bin].item()) |
| 172 | + # Create a triangular filter having the current center freq. |
| 173 | + # The filter will start with 0 response at left_bin (if it exists) |
| 174 | + # and ramp up to 1.0 at center_freq_linear_bin, and then ramp |
| 175 | + # back down to 0 response at right_bin (if it exists). |
| 176 | + |
| 177 | + # Create the left side of the triangular filter that ramps up |
| 178 | + # from 0 to a response of 1 at the center frequency. |
| 179 | + if center_freq_linear_bin > 1: |
| 180 | + # It is possible to create the left triangular filter. |
| 181 | + if mel_bin == 0: |
| 182 | + # Since this is the first center frequency, the left side |
| 183 | + # must start ramping up from linear bin 0 or 1 mel bin before the center freq. |
| 184 | + left_bin = max(0, fft_bin_start) |
| 185 | + else: |
| 186 | + # Start ramping up from the previous center frequency bin. |
| 187 | + left_bin = int(linear_bin_indices[mel_bin - 1].item()) |
| 188 | + for f_bin in range(left_bin, center_freq_linear_bin+1): |
| 189 | + if (center_freq_linear_bin - left_bin) > 0: |
| 190 | + response = float(f_bin - left_bin) / \ |
| 191 | + float(center_freq_linear_bin - left_bin) |
| 192 | + filterbank[mel_bin, f_bin] = response |
| 193 | + # Create the right side of the triangular filter that ramps down |
| 194 | + # from 1 to 0. |
| 195 | + if center_freq_linear_bin < linear_bin_count-2: |
| 196 | + # It is possible to create the right triangular filter. |
| 197 | + if mel_bin == mel_bin_count - 1: |
| 198 | + # Since this is the last mel bin, we must ramp down to response of 0 |
| 199 | + # at the last linear freq bin. |
| 200 | + right_bin = min(linear_bin_count - 1, fft_bin_stop) |
| 201 | + else: |
| 202 | + right_bin = int(linear_bin_indices[mel_bin + 1].item()) |
| 203 | + for f_bin in range(center_freq_linear_bin, right_bin+1): |
| 204 | + if (right_bin - center_freq_linear_bin) > 0: |
| 205 | + response = float(right_bin - f_bin) / \ |
| 206 | + float(right_bin - center_freq_linear_bin) |
| 207 | + filterbank[mel_bin, f_bin] = response |
| 208 | + filterbank[mel_bin, center_freq_linear_bin] = 1.0 |
| 209 | + |
| 210 | + return filterbank |
| 211 | + |
| 212 | + |
| 213 | +def stft_for_reconstruction(x, fft_size, hopsamp): |
| 214 | + """Compute and return the STFT of the supplied time domain signal x. |
| 215 | +
|
| 216 | + Args: |
| 217 | + x (1-dim Numpy array): A time domain signal. |
| 218 | + fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used. |
| 219 | + hopsamp (int): |
| 220 | +
|
| 221 | + Returns: |
| 222 | + The STFT. The rows are the time slices and columns are the frequency bins. |
| 223 | + """ |
| 224 | + window = np.hanning(fft_size) |
| 225 | + fft_size = int(fft_size) |
| 226 | + hopsamp = int(hopsamp) |
| 227 | + return np.array([np.fft.rfft(window*x[i:i+fft_size]) |
| 228 | + for i in range(0, len(x)-fft_size, hopsamp)]) |
| 229 | + |
| 230 | + |
| 231 | +def istft_for_reconstruction(X, fft_size, hopsamp): |
| 232 | + """Invert a STFT into a time domain signal. |
| 233 | +
|
| 234 | + Args: |
| 235 | + X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins. |
| 236 | + fft_size (int): |
| 237 | + hopsamp (int): The hop size, in samples. |
| 238 | +
|
| 239 | + Returns: |
| 240 | + The inverse STFT. |
| 241 | + """ |
| 242 | + fft_size = int(fft_size) |
| 243 | + hopsamp = int(hopsamp) |
| 244 | + window = np.hanning(fft_size) |
| 245 | + time_slices = X.shape[0] |
| 246 | + len_samples = int(time_slices*hopsamp + fft_size) |
| 247 | + x = np.zeros(len_samples) |
| 248 | + for n, i in enumerate(range(0, len(x)-fft_size, hopsamp)): |
| 249 | + x[i:i+fft_size] += window*np.real(np.fft.irfft(X[n])) |
| 250 | + return x |
| 251 | + |
| 252 | + |
| 253 | +def get_signal(in_file, expected_fs=44100): |
| 254 | + """Load a wav file. |
| 255 | +
|
| 256 | + If the file contains more than one channel, return a mono file by taking |
| 257 | + the mean of all channels. |
| 258 | +
|
| 259 | + If the sample rate differs from the expected sample rate (default is 44100 Hz), |
| 260 | + raise an exception. |
| 261 | +
|
| 262 | + Args: |
| 263 | + in_file: The input wav file, which should have a sample rate of `expected_fs`. |
| 264 | + expected_fs (int): The expected sample rate of the input wav file. |
| 265 | +
|
| 266 | + Returns: |
| 267 | + The audio siganl as a 1-dim Numpy array. The values will be in the range [-1.0, 1.0]. fixme ( not yet) |
| 268 | + """ |
| 269 | + fs, y = scipy.io.wavfile.read(in_file) |
| 270 | + num_type = y[0].dtype |
| 271 | + if num_type == 'int16': |
| 272 | + y = y*(1.0/32768) |
| 273 | + elif num_type == 'int32': |
| 274 | + y = y*(1.0/2147483648) |
| 275 | + elif num_type == 'float32': |
| 276 | + # Nothing to do |
| 277 | + pass |
| 278 | + elif num_type == 'uint8': |
| 279 | + raise Exception('8-bit PCM is not supported.') |
| 280 | + else: |
| 281 | + raise Exception('Unknown format.') |
| 282 | + if fs != expected_fs: |
| 283 | + raise Exception('Invalid sample rate.') |
| 284 | + if y.ndim == 1: |
| 285 | + return y |
| 286 | + else: |
| 287 | + return y.mean(axis=1) |
| 288 | + |
| 289 | + |
| 290 | +def reconstruct_signal_griffin_lim(magnitude_spectrogram, fft_size, hopsamp, iterations): |
| 291 | + """Reconstruct an audio signal from a magnitude spectrogram. |
| 292 | +
|
| 293 | + Given a magnitude spectrogram as input, reconstruct |
| 294 | + the audio signal and return it using the Griffin-Lim algorithm from the paper: |
| 295 | + "Signal estimation from modified short-time fourier transform" by Griffin and Lim, |
| 296 | + in IEEE transactions on Acoustics, Speech, and Signal Processing. Vol ASSP-32, No. 2, April 1984. |
| 297 | +
|
| 298 | + Args: |
| 299 | + magnitude_spectrogram (2-dim Numpy array): The magnitude spectrogram. The rows correspond to the time slices |
| 300 | + and the columns correspond to frequency bins. |
| 301 | + fft_size (int): The FFT size, which should be a power of 2. |
| 302 | + hopsamp (int): The hope size in samples. |
| 303 | + iterations (int): Number of iterations for the Griffin-Lim algorithm. Typically a few hundred |
| 304 | + is sufficient. |
| 305 | +
|
| 306 | + Returns: |
| 307 | + The reconstructed time domain signal as a 1-dim Numpy array. |
| 308 | + """ |
| 309 | + time_slices = magnitude_spectrogram.shape[0] |
| 310 | + len_samples = int(time_slices*hopsamp + fft_size) |
| 311 | + # Initialize the reconstructed signal to noise. |
| 312 | + x_reconstruct = np.random.randn(len_samples) |
| 313 | + n = iterations # number of iterations of Griffin-Lim algorithm. |
| 314 | + while n > 0: |
| 315 | + n -= 1 |
| 316 | + reconstruction_spectrogram = stft_for_reconstruction( |
| 317 | + x_reconstruct, fft_size, hopsamp) |
| 318 | + reconstruction_angle = np.angle(reconstruction_spectrogram) |
| 319 | + # Discard magnitude part of the reconstruction and use the supplied magnitude spectrogram instead. |
| 320 | + proposal_spectrogram = magnitude_spectrogram * \ |
| 321 | + np.exp(1.0j*reconstruction_angle) |
| 322 | + prev_x = x_reconstruct |
| 323 | + x_reconstruct = istft_for_reconstruction( |
| 324 | + proposal_spectrogram, fft_size, hopsamp) |
| 325 | + diff = sqrt(sum((x_reconstruct - prev_x)**2)/x_reconstruct.size) |
| 326 | + #print('Reconstruction iteration: {}/{} RMSE: {} '.format(iterations - n, iterations, diff)) |
| 327 | + return x_reconstruct |
| 328 | + |
| 329 | + |
| 330 | +def save_audio_to_file(x, sample_rate, outfile='out.wav'): |
| 331 | + """Save a mono signal to a file. |
| 332 | +
|
| 333 | + Args: |
| 334 | + x (1-dim Numpy array): The audio signal to save. The signal values should be in the range [-1.0, 1.0]. |
| 335 | + sample_rate (int): The sample rate of the signal, in Hz. |
| 336 | + outfile: Name of the file to save. |
| 337 | +
|
| 338 | + """ |
| 339 | + x_max = np.max(abs(x)) |
| 340 | + assert x_max <= 1.0, 'Input audio value is out of range. Should be in the range [-1.0, 1.0].' |
| 341 | + x = x*32767.0 |
| 342 | + data = array.array('h') |
| 343 | + for i in range(len(x)): |
| 344 | + cur_samp = int(round(x[i])) |
| 345 | + data.append(cur_samp) |
| 346 | + f = wave.open(outfile, 'w') |
| 347 | + f.setparams((1, 2, sample_rate, 0, "NONE", "Uncompressed")) |
| 348 | + f.writeframes(data.tostring()) |
| 349 | + f.close() |
0 commit comments