@@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
1543
1543
}
1544
1544
};
1545
1545
1546
+ // GGML_OP_RWKV_WKV
1547
+ struct test_rwkv_wkv : public test_case {
1548
+ const ggml_type type;
1549
+
1550
+ const int64_t head_count;
1551
+ const int64_t head_size;
1552
+ const int64_t n_seq_tokens;
1553
+ const int64_t n_seqs;
1554
+
1555
+ std::string vars () override {
1556
+ return VARS_TO_STR5 (type, head_count, head_size, n_seq_tokens, n_seqs);
1557
+ }
1558
+
1559
+ test_rwkv_wkv (ggml_type type = GGML_TYPE_F32,
1560
+ int64_t head_count = 32 , int64_t head_size = 64 , int64_t n_seq_tokens = 32 , int64_t n_seqs = 32 )
1561
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1562
+
1563
+ ggml_tensor * build_graph (ggml_context * ctx) override {
1564
+ const int64_t n_tokens = n_seq_tokens * n_seqs;
1565
+ ggml_tensor * r = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ 1 , head_size, head_count, n_tokens }.data ());
1566
+ ggml_tensor * k = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ head_size, 1 , head_count, n_tokens }.data ());
1567
+ ggml_tensor * v = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ 1 , head_size, head_count, n_tokens }.data ());
1568
+ ggml_tensor * tf = ggml_new_tensor (ctx, type, 2 , std::vector<int64_t >{ head_size, head_count }.data ());
1569
+ ggml_tensor * td = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ 1 , head_size, head_count, n_tokens }.data ());
1570
+ ggml_tensor * s = ggml_new_tensor (ctx, type, 2 , std::vector<int64_t >{ head_size * head_size * head_count, n_seqs }.data ());
1571
+ ggml_tensor * out = ggml_rwkv_wkv (ctx, k, v, r, tf, td, s);
1572
+ return out;
1573
+ }
1574
+ };
1575
+
1546
1576
// GGML_OP_MUL_MAT
1547
1577
struct test_mul_mat : public test_case {
1548
1578
const ggml_type type_a;
@@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
3337
3367
3338
3368
test_cases.emplace_back (new test_ssm_scan (GGML_TYPE_F32, 16 , 1024 , 32 , 4 ));
3339
3369
3370
    // GGML_OP_RWKV_WKV: head_count=32, head_size=64 in all cases; vary
    // sequence length (1/32/128 tokens) and number of sequences (1/4)
    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
3340
3375
#if 1
3341
3376
for (ggml_type type_a : base_types) {
3342
3377
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
0 commit comments