@@ -613,7 +613,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
613
613
const int n_audio_state = hparams.n_audio_state ;
614
614
const int n_audio_layer = hparams.n_audio_layer ;
615
615
616
- const int n_text_ctx = hparams.n_text_ctx ;
616
+ const int n_text_ctx = hparams.n_text_ctx ;
617
617
const int n_text_state = hparams.n_text_state ;
618
618
const int n_text_layer = hparams.n_text_layer ;
619
619
@@ -748,7 +748,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
748
748
const int n_audio_state = hparams.n_audio_state ;
749
749
const int n_audio_layer = hparams.n_audio_layer ;
750
750
751
- const int n_text_ctx = hparams.n_text_ctx ;
751
+ const int n_text_ctx = hparams.n_text_ctx ;
752
752
const int n_text_state = hparams.n_text_state ;
753
753
const int n_text_layer = hparams.n_text_layer ;
754
754
@@ -967,7 +967,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
967
967
968
968
// key/value memory for the cross-attention layer
969
969
{
970
- const int n_audio_ctx = hparams.n_audio_ctx ;
970
+ const int n_audio_ctx = hparams.n_audio_ctx ;
971
971
972
972
const int n_mem = n_text_layer*n_audio_ctx;
973
973
const int n_elements = n_text_state*n_mem;
@@ -1054,6 +1054,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
1054
1054
}
1055
1055
}
1056
1056
1057
+ model.e_pe ->ne [1 ] = WHISPER_EXPERIMENT_AUDIO_CTX;
1058
+
1057
1059
fin.close ();
1058
1060
1059
1061
return true ;
@@ -1076,13 +1078,11 @@ static bool whisper_encode(
1076
1078
const auto & mel_inp = wctx.mel ;
1077
1079
const auto & hparams = model.hparams ;
1078
1080
1079
- const int n_ctx = hparams. n_audio_ctx ;
1081
+ const int n_ctx = WHISPER_EXPERIMENT_AUDIO_CTX ;
1080
1082
const int n_state = hparams.n_audio_state ;
1081
1083
const int n_head = hparams.n_audio_head ;
1082
1084
const int n_layer = hparams.n_audio_layer ;
1083
1085
1084
- const int N = n_ctx;
1085
-
1086
1086
const int n_mels = hparams.n_mels ;
1087
1087
assert (mel_inp.n_mel == n_mels);
1088
1088
@@ -1198,24 +1198,24 @@ static bool whisper_encode(
1198
1198
ggml_permute (ctxL,
1199
1199
ggml_cpy (ctxL,
1200
1200
Qcur,
1201
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N )),
1201
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx )),
1202
1202
0 , 2 , 1 , 3 );
1203
1203
1204
1204
struct ggml_tensor * K =
1205
1205
ggml_permute (ctxL,
1206
1206
ggml_cpy (ctxL,
1207
1207
Kcur,
1208
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N )),
1208
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx )),
1209
1209
0 , 2 , 1 , 3 );
1210
1210
1211
1211
struct ggml_tensor * V =
1212
1212
ggml_cpy (ctxL,
1213
1213
ggml_permute (ctxL,
1214
1214
ggml_reshape_3d (ctxL,
1215
1215
Vcur,
1216
- n_state/n_head, n_head, N ),
1216
+ n_state/n_head, n_head, n_ctx ),
1217
1217
1 , 2 , 0 , 3 ),
1218
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, N , n_state/n_head, n_head)
1218
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_ctx , n_state/n_head, n_head)
1219
1219
);
1220
1220
1221
1221
struct ggml_tensor * KQV = ggml_flash_attn (ctxL, Q, K, V, false );
@@ -1224,14 +1224,14 @@ static bool whisper_encode(
1224
1224
ggml_permute (ctxL,
1225
1225
ggml_cpy (ctxL,
1226
1226
Qcur,
1227
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N )),
1227
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx )),
1228
1228
0 , 2 , 1 , 3 );
1229
1229
1230
1230
struct ggml_tensor * K =
1231
1231
ggml_permute (ctxL,
1232
1232
ggml_cpy (ctxL,
1233
1233
Kcur,
1234
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N )),
1234
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx )),
1235
1235
0 , 2 , 1 , 3 );
1236
1236
1237
1237
// K * Q
@@ -1249,7 +1249,7 @@ static bool whisper_encode(
1249
1249
// ggml_permute(ctxL,
1250
1250
// ggml_cpy(ctxL,
1251
1251
// Vcur,
1252
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N )),
1252
+ // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx )),
1253
1253
// 1, 2, 0, 3);
1254
1254
1255
1255
// struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1259,9 +1259,9 @@ static bool whisper_encode(
1259
1259
ggml_permute (ctxL,
1260
1260
ggml_reshape_3d (ctxL,
1261
1261
Vcur,
1262
- n_state/n_head, n_head, N ),
1262
+ n_state/n_head, n_head, n_ctx ),
1263
1263
0 , 2 , 1 , 3 ),
1264
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, N , n_head)
1264
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx , n_head)
1265
1265
);
1266
1266
1267
1267
struct ggml_tensor * KQV = ggml_mul_mat (ctxL, ggml_transpose (ctxL, V), KQ_soft_max);
@@ -1271,7 +1271,7 @@ static bool whisper_encode(
1271
1271
1272
1272
cur = ggml_cpy (ctxL,
1273
1273
KQV_merged,
1274
- ggml_new_tensor_2d (ctxL, GGML_TYPE_F32, n_state, N ));
1274
+ ggml_new_tensor_2d (ctxL, GGML_TYPE_F32, n_state, n_ctx ));
1275
1275
}
1276
1276
1277
1277
// projection
@@ -1474,7 +1474,7 @@ static bool whisper_decode(
1474
1474
const int n_layer = hparams.n_text_layer ;
1475
1475
1476
1476
const int N = n_tokens;
1477
- const int M = hparams. n_audio_ctx ;
1477
+ const int M = WHISPER_EXPERIMENT_AUDIO_CTX ;
1478
1478
1479
1479
struct ggml_init_params params = {
1480
1480
.mem_size = wctx.buf_compute .size (),
@@ -2656,7 +2656,7 @@ int whisper_full(
2656
2656
// }
2657
2657
2658
2658
// end of text token
2659
- if (token.id == whisper_token_eot (ctx)) {
2659
+ if (token.id == whisper_token_eot (ctx) || (i > WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT) ) {
2660
2660
if (result_len == 0 ) {
2661
2661
if (seek + seek_delta + 100 >= seek_end) {
2662
2662
result_len = i + 1 ;
@@ -2844,7 +2844,7 @@ int whisper_full_parallel(
2844
2844
2845
2845
// key/value memory for the cross-attention layer
2846
2846
{
2847
- const int n_audio_ctx = hparams.n_audio_ctx ;
2847
+ const int n_audio_ctx = hparams.n_audio_ctx ;
2848
2848
2849
2849
const int n_mem = n_text_layer*n_audio_ctx;
2850
2850
const int n_elements = n_text_state*n_mem;
0 commit comments