Skip to content

vulkan: optimizations for deepseek prompt processing #14555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;

static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
if (hsv >= 512) {
return 2;
} else {
return 8;
}
}

// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
// 128 threads split into four subgroups, each subgroup does 1/4
Expand All @@ -1760,7 +1767,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
return {scalar_flash_attention_num_large_rows, 32};
return {get_fa_scalar_num_large_rows(hsv), 32};
}
}

Expand All @@ -1779,7 +1786,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3

// small cols to reduce register count
if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32};
if (hsk >= 512) {
return {32, 32};
} else {
return {64, 32};
}
}
return {64, 64};
}
Expand Down Expand Up @@ -1821,7 +1832,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];

const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;

const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
Expand Down Expand Up @@ -1946,10 +1957,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 32, 1 };

// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_mmqid_wg_denoms = { 128, 64, 1 };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };

Expand Down Expand Up @@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
const uint32_t Bc = scalar_flash_attention_Bc;

const uint32_t tmpsh = wg_size * sizeof(float);
Expand Down Expand Up @@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = scalar_flash_attention_num_large_rows;
max_gqa = get_fa_scalar_num_large_rows(HSV);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
Expand Down
48 changes: 47 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

#ifdef MUL_MAT_ID
Expand Down Expand Up @@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
shared uint _ne1_sh;
#endif
#endif // MUL_MAT_ID

#define NUM_WARPS (BLOCK_SIZE / WARP)
Expand Down Expand Up @@ -172,7 +177,47 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
uint _ne1 = 0;
#ifdef COOPMAT
// Spread the search across all elements in the first subgroup
if (gl_SubgroupID == 0) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;

uint ids[16];
uint iter = 0;

for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}

barrier();

_ne1 = _ne1_sh;
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
Expand All @@ -183,6 +228,7 @@ void main() {
}

barrier();
#endif

// Workgroup has no work
if (ic * BN >= _ne1) return;
Expand Down
47 changes: 38 additions & 9 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,32 @@ void main() {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;

for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
uint ids[16];
uint iter = 0;

for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii0 = i % p.nei0;
uint ii1 = i / p.nei0;
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}
Expand Down Expand Up @@ -414,17 +429,31 @@ void main() {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}

coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

sum = coopMatMulAdd(mat_a, mat_b, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}

// Convert from ACC_TYPE to D_TYPE
Expand Down
Loading