Skip to content

Commit 2675766

Browse files
committed
stream : speed-up real-time streaming at cost of some accuracy
1 parent a728be9 commit 2675766

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

examples/stream/stream.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ int main(int argc, char ** argv) {
217217
const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
218218
const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
219219
const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
220+
const int n_samples_keep = 0.2*WHISPER_SAMPLE_RATE;
220221

221222
std::vector<float> pcmf32(n_samples_30s, 0.0f);
222223
std::vector<float> pcmf32_old;
@@ -299,7 +300,7 @@ int main(int argc, char ** argv) {
299300
//const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
300301

301302
// take up to params.length_ms audio from previous iteration
302-
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_len - n_samples_new));
303+
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
303304

304305
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
305306

@@ -373,7 +374,8 @@ int main(int argc, char ** argv) {
373374
if ((n_iter % n_new_line) == 0) {
374375
printf("\n");
375376

376-
pcmf32_old.clear();
377+
// keep part of the audio for next iteration to try to mitigate word boundary issues
378+
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
377379
}
378380
}
379381
}

whisper.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
613613
const int n_audio_state = hparams.n_audio_state;
614614
const int n_audio_layer = hparams.n_audio_layer;
615615

616-
const int n_text_ctx = hparams.n_text_ctx;
616+
const int n_text_ctx = hparams.n_text_ctx;
617617
const int n_text_state = hparams.n_text_state;
618618
const int n_text_layer = hparams.n_text_layer;
619619

@@ -748,7 +748,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
748748
const int n_audio_state = hparams.n_audio_state;
749749
const int n_audio_layer = hparams.n_audio_layer;
750750

751-
const int n_text_ctx = hparams.n_text_ctx;
751+
const int n_text_ctx = hparams.n_text_ctx;
752752
const int n_text_state = hparams.n_text_state;
753753
const int n_text_layer = hparams.n_text_layer;
754754

@@ -967,7 +967,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
967967

968968
// key/value memory for the cross-attention layer
969969
{
970-
const int n_audio_ctx = hparams.n_audio_ctx;
970+
const int n_audio_ctx = hparams.n_audio_ctx;
971971

972972
const int n_mem = n_text_layer*n_audio_ctx;
973973
const int n_elements = n_text_state*n_mem;
@@ -1054,6 +1054,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
10541054
}
10551055
}
10561056

1057+
model.e_pe->ne[1] = WHISPER_EXPERIMENT_AUDIO_CTX;
1058+
10571059
fin.close();
10581060

10591061
return true;
@@ -1076,13 +1078,11 @@ static bool whisper_encode(
10761078
const auto & mel_inp = wctx.mel;
10771079
const auto & hparams = model.hparams;
10781080

1079-
const int n_ctx = hparams.n_audio_ctx;
1081+
const int n_ctx = WHISPER_EXPERIMENT_AUDIO_CTX;
10801082
const int n_state = hparams.n_audio_state;
10811083
const int n_head = hparams.n_audio_head;
10821084
const int n_layer = hparams.n_audio_layer;
10831085

1084-
const int N = n_ctx;
1085-
10861086
const int n_mels = hparams.n_mels;
10871087
assert(mel_inp.n_mel == n_mels);
10881088

@@ -1198,24 +1198,24 @@ static bool whisper_encode(
11981198
ggml_permute(ctxL,
11991199
ggml_cpy(ctxL,
12001200
Qcur,
1201-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1201+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
12021202
0, 2, 1, 3);
12031203

12041204
struct ggml_tensor * K =
12051205
ggml_permute(ctxL,
12061206
ggml_cpy(ctxL,
12071207
Kcur,
1208-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1208+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
12091209
0, 2, 1, 3);
12101210

12111211
struct ggml_tensor * V =
12121212
ggml_cpy(ctxL,
12131213
ggml_permute(ctxL,
12141214
ggml_reshape_3d(ctxL,
12151215
Vcur,
1216-
n_state/n_head, n_head, N),
1216+
n_state/n_head, n_head, n_ctx),
12171217
1, 2, 0, 3),
1218-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
1218+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
12191219
);
12201220

12211221
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1224,14 +1224,14 @@ static bool whisper_encode(
12241224
ggml_permute(ctxL,
12251225
ggml_cpy(ctxL,
12261226
Qcur,
1227-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
1227+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
12281228
0, 2, 1, 3);
12291229

12301230
struct ggml_tensor * K =
12311231
ggml_permute(ctxL,
12321232
ggml_cpy(ctxL,
12331233
Kcur,
1234-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1234+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
12351235
0, 2, 1, 3);
12361236

12371237
// K * Q
@@ -1249,7 +1249,7 @@ static bool whisper_encode(
12491249
// ggml_permute(ctxL,
12501250
// ggml_cpy(ctxL,
12511251
// Vcur,
1252-
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1252+
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
12531253
// 1, 2, 0, 3);
12541254

12551255
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1259,9 +1259,9 @@ static bool whisper_encode(
12591259
ggml_permute(ctxL,
12601260
ggml_reshape_3d(ctxL,
12611261
Vcur,
1262-
n_state/n_head, n_head, N),
1262+
n_state/n_head, n_head, n_ctx),
12631263
0, 2, 1, 3),
1264-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
1264+
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
12651265
);
12661266

12671267
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1271,7 +1271,7 @@ static bool whisper_encode(
12711271

12721272
cur = ggml_cpy(ctxL,
12731273
KQV_merged,
1274-
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
1274+
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
12751275
}
12761276

12771277
// projection
@@ -1474,7 +1474,7 @@ static bool whisper_decode(
14741474
const int n_layer = hparams.n_text_layer;
14751475

14761476
const int N = n_tokens;
1477-
const int M = hparams.n_audio_ctx;
1477+
const int M = WHISPER_EXPERIMENT_AUDIO_CTX;
14781478

14791479
struct ggml_init_params params = {
14801480
.mem_size = wctx.buf_compute.size(),
@@ -2609,6 +2609,7 @@ int whisper_full(
26092609

26102610
// timestamp token - update sliding window
26112611
if (token.id > whisper_token_beg(ctx)) {
2612+
token.id = std::min(50*WHISPER_CHUNK_SIZE, token.id - whisper_token_beg(ctx)) + whisper_token_beg(ctx);
26122613
seek_delta = 2*(token.id - whisper_token_beg(ctx));
26132614
result_len = i + 1;
26142615
}
@@ -2623,7 +2624,7 @@ int whisper_full(
26232624
//}
26242625

26252626
// end of text token
2626-
if (token.id == whisper_token_eot(ctx)) {
2627+
if (token.id == whisper_token_eot(ctx) || (i > WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT)) {
26272628
if (result_len == 0) {
26282629
if (seek + seek_delta + 100 >= seek_end) {
26292630
result_len = i + 1;
@@ -2805,7 +2806,7 @@ int whisper_full_parallel(
28052806

28062807
// key/value memory for the cross-attention layer
28072808
{
2808-
const int n_audio_ctx = hparams.n_audio_ctx;
2809+
const int n_audio_ctx = hparams.n_audio_ctx;
28092810

28102811
const int n_mem = n_text_layer*n_audio_ctx;
28112812
const int n_elements = n_text_state*n_mem;

whisper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
#define WHISPER_HOP_LENGTH 160
2525
#define WHISPER_CHUNK_SIZE 30
2626

27+
#define WHISPER_EXPERIMENT_AUDIO_CTX 384
28+
#define WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT 32
29+
2730
#ifdef __cplusplus
2831
extern "C" {
2932
#endif

0 commit comments

Comments
 (0)