Skip to content

Commit 78c431d

Browse files
authored
cleanup whisper a little (ml-explore#639)
1 parent f6283ef commit 78c431d

File tree

6 files changed

+252
-236
lines changed

6 files changed

+252
-236
lines changed

llms/mlx_lm/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,13 @@ def generate(
239239
),
240240
range(max_tokens),
241241
):
242-
if token == tokenizer.eos_token_id:
243-
break
242+
token = token.item()
244243
if n == 0:
245244
prompt_time = time.perf_counter() - tic
246245
tic = time.perf_counter()
247-
tokens.append(token.item())
246+
if token == tokenizer.eos_token_id:
247+
break
248+
tokens.append(token)
248249

249250
if verbose:
250251
s = tokenizer.decode(tokens)

whisper/convert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def _download(url: str, root: str) -> str:
9191
output.write(buffer)
9292
loop.update(len(buffer))
9393

94-
model_bytes = open(download_target, "rb").read()
94+
with open(download_target, "rb") as fid:
95+
model_bytes = fid.read()
9596
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
9697
raise RuntimeError(
9798
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."

whisper/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def test_transcribe_alice(self):
297297
"temperature": 0.0,
298298
"avg_logprob": -0.1350895343440594,
299299
"compression_ratio": 1.6208333333333333,
300-
"no_speech_prob": 0.002246702555567026,
300+
"no_speech_prob": 0.009053784422576427,
301301
}
302302

303303
def check_segment(seg, expected):

whisper/whisper/audio.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
5858
except CalledProcessError as e:
5959
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
6060

61-
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
61+
return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0
6262

6363

6464
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@@ -73,8 +73,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
7373
if array.shape[axis] < length:
7474
pad_widths = [(0, 0)] * array.ndim
7575
pad_widths[axis] = (0, length - array.shape[axis])
76-
pad_fn = mx.pad if isinstance(array, mx.array) else np.pad
77-
array = pad_fn(array, pad_widths)
76+
array = mx.pad(array, pad_widths)
7877

7978
return array
8079

@@ -154,9 +153,9 @@ def log_mel_spectrogram(
154153
"""
155154
device = mx.default_device()
156155
mx.set_default_device(mx.cpu)
157-
if not isinstance(audio, mx.array):
158-
if isinstance(audio, str):
159-
audio = load_audio(audio)
156+
if isinstance(audio, str):
157+
audio = load_audio(audio)
158+
elif not isinstance(audio, mx.array):
160159
audio = mx.array(audio)
161160

162161
if padding > 0:

0 commit comments

Comments
 (0)