diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7b8faf17879d6..8e1da64fb82b6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { + if (hsv >= 512) { + return 2; + } else { + return 8; + } +} // The FA coopmat1 shader assumes 16x16x16 matrix multiply support. // 128 threads split into four subgroups, each subgroup does 1/4 @@ -1760,7 +1767,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - return {scalar_flash_attention_num_large_rows, 32}; + return {get_fa_scalar_num_large_rows(hsv), 32}; } } @@ -1779,7 +1786,11 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 // small cols to reduce register count if (ggml_is_quantized(type) || hsk >= 256) { - return {64, 32}; + if (hsk >= 512) { + return {32, 32}; + } else { + return {64, 32}; + } } return {64, 64}; } @@ -1821,7 +1832,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -1946,10 +1957,10 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 32, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_warptile_mmqid = { 256, 128, 128, 16, 0 }; m_warptile_mmqid = { 256, 128, 64, 16, 0 }; s_warptile_mmqid = { 256, 128, 64, 16, 0 }; - l_mmqid_wg_denoms = { 128, 64, 1 }; + l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Br = get_fa_scalar_num_large_rows(hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = scalar_flash_attention_num_large_rows; + max_gqa = get_fa_scalar_num_large_rows(HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 888ce79f6ec11..f481549911b92 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -18,6 +18,7 @@ #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; +uint _ne1; +#ifdef COOPMAT +shared uint _ne1_sh; +#endif #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -172,7 +177,47 @@ void main() { const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; +#ifdef COOPMAT + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec2(ii0, ii1); + } + _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; +#else + _ne1 = 0; for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii0 = 0; ii0 < p.nei0; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { @@ -183,6 +228,7 @@ void main() { } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 9184657573281..29e4b5c9ce2d4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -162,17 +162,32 @@ void main() { _ne1 = 0; uint num_elements = p.nei1 * p.nei0; - for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; bool in_range = i < num_elements; - uint ii0 = i % p.nei0; uint ii1 = i / p.nei0; - uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); uint idx = subgroupBallotExclusiveBitCount(ballot); if (in_range && id == expert_idx) { row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); } _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; } _ne1_sh = _ne1; } @@ -414,17 +429,31 @@ void main() { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - coopmat mat_a; - coopmat mat_b; + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif - sum = coopMatMulAdd(mat_a, mat_b, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } } // Convert from ACC_TYPE to D_TYPE