
CUDA: add set rows for f32 and f16 #14551

Open
wants to merge 1 commit into master

Conversation

am17an
Collaborator

@am17an am17an commented Jul 6, 2025

@JohannesGaessler - I'm still working on the refactor for this; for now, set_rows might still be useful in this form for #14363, so I'm putting up a PR.

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jul 6, 2025
@am17an am17an requested a review from JohannesGaessler July 6, 2025 14:38
@ggerganov
Member

You can also test it with #14482

@am17an am17an mentioned this pull request Jul 7, 2025
Comment on lines +63 to +67
const int max_threads_per_row = 256;
const int threads_per_row = std::min((int)ne00, max_threads_per_row);

const int max_threads_per_block = 256;
const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
Member

Why not use 1024 here instead of 256?

Collaborator Author

I think that's a mistake, it should be 1024

Member

Using fewer threads per block may achieve higher occupancy and result in better performance.

Collaborator

A streaming multiprocessor can at most have 1024 concurrent threads, so for most CUDA kernels the number of threads should be <= 256 to make scheduling easier. I would say the only real use case for more threads is if you're using so much shared memory that occupancy is not limited by the number of threads anyways (and register pressure is also not an issue).

@am17an
Collaborator Author

am17an commented Jul 7, 2025

Actually, this looks like it's worse than just doing the cpy kernel, and it's even worse with threads_per_block = 1024.

Collaborator

@JohannesGaessler JohannesGaessler left a comment

I would recommend you organize the threads differently. Instead of a loop over ne00 I think it will be better to just assign one copy operation to each thread; yes, that will result in some redundant work for the index calculation. But this kernel is going to be 100% I/O bound anyways so the compute pipelines will be severely underutilized and especially for small input sizes the achieved occupancy will be much better.

For the indices I would put dimension 1 in blockIdx.x, dimension 0 in blockIdx.y (blockIdx.x has a higher limit), and dimensions 2 and 3 in blockIdx.z. For threadIdx I would just use a one-dimensional layout that maps thread indices to flattened tensor data indices. For this I would recommend you use code like this:

const div_t q = div((int) blockIdx.z, (int) ne02);
const int i03 = q.quot;
const int i02 = q.rem;

I recently noticed that the above instruction would be a better way to calculate indices than how I did previously. From what I can tell we're also not using it in CPU code. @ggerganov is there a reason for this?
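
Putting the layout suggestion above together, here is a minimal sketch for the f32 -> f16 case (hypothetical kernel name and launch code; the byte-stride arguments mirror the PR's kernel, and the id-tensor indexing is simplified to dimension 1 only):

#include <cstdint>
#include <cuda_fp16.h>

// Hypothetical sketch: one element copied per thread.
// dim 1 -> blockIdx.x, dim 0 -> blockIdx.y plus a flat 1D threadIdx, dims 2 and 3 -> blockIdx.z.
static __global__ void k_set_rows_f32_f16(const char * src0, char * dst, const int64_t * row_ids,
                                          const int64_t ne00, const int64_t ne02,
                                          const size_t nb01, const size_t nb02, const size_t nb03,
                                          const size_t nb1,  const size_t nb2,  const size_t nb3) {
    const int64_t i00 = (int64_t) blockIdx.y*blockDim.x + threadIdx.x; // element within the row
    const int64_t i01 = blockIdx.x;                                    // row (highest grid limit)
    const int64_t i02 = blockIdx.z % ne02;                             // dims 2 and 3 packed together
    const int64_t i03 = blockIdx.z / ne02;

    if (i00 >= ne00) {
        return;
    }

    const int64_t dst_row = row_ids[i01]; // simplified: the real op indexes the id tensor per (i01, i02, i03)

    const char * src_elem = src0 + i03*nb03 + i02*nb02 + i01*nb01     + i00*sizeof(float);
    char       * dst_elem = dst  + i03*nb3  + i02*nb2  + dst_row*nb1  + i00*sizeof(half);

    *(half *) dst_elem = __float2half(*(const float *) src_elem);
}

// Launch (hypothetical): a flat block of 256 threads, remaining dims in the grid.
// const dim3 grid(ne01, (ne00 + 255)/256, ne02*ne03);
// k_set_rows_f32_f16<<<grid, 256, 0, stream>>>(...);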

Comment on lines +5 to +15
static __device__ void set_rows_1_f32_f32(const char * src, char * dst) {
const float * src_f = (const float *) src;
float * dst_f = (float *) dst;
*dst_f = *src_f;
}

static __device__ void set_rows_1_f32_f16(const char * src, char * dst) {
const float * src_f = (const float *) src;
half * dst_h = (half *) dst;
*dst_h = __float2half(*src_f);
}
Collaborator

This is fine for now but long-term I think it will be simpler (for floating-point data types) to define __device__ functions that map from and to float rather than explicit mappings between 2 types.
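
A minimal sketch of that idea (hypothetical helper names): each floating-point type only defines a conversion to and from float, and a single templated per-element copy then covers every type pair:

#include <cuda_fp16.h>

// Hypothetical helpers: per-type conversions to/from float instead of explicit pair mappings.
static __device__ __forceinline__ float to_f32(float x) { return x; }
static __device__ __forceinline__ float to_f32(half  x) { return __half2float(x); }

template <typename T> static __device__ __forceinline__ T from_f32(float x);
template <> __device__ __forceinline__ float from_f32<float>(float x) { return x; }
template <> __device__ __forceinline__ half  from_f32<half> (float x) { return __float2half(x); }

// One generic per-element copy replaces set_rows_1_f32_f32, set_rows_1_f32_f16, ...
template <typename src_t, typename dst_t>
static __device__ __forceinline__ void set_rows_1(const src_t * src, dst_t * dst) {
    *dst = from_f32<dst_t>(to_f32(*src));
}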

Comment on lines +42 to +49
const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;

for (int col = threadIdx.x; col < ne00; col += blockDim.x) {
const char * src_elem = src0_row + col * src_type_size;
char * dst_elem = dst_row_ptr + col * dst_type_size;
set_rows_1(src_elem, dst_elem);
}
Collaborator

I didn't look at the generated PTX code for this specific CUDA code but in my experience the CPU code pattern with explicit byte offsets performs comparatively poorly on GPUs. My recommendation would be to have the input and output types as template parameters and to compute the strides in units of the types (e.g. nb01/ggml_element_size(src0)) in host code. This is of course assuming that this kernel has a non-negligible contribution to the end-to-end performance in the first place.
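
A rough sketch of the typed variant (hypothetical names, restricted to dims 0 and 1 for brevity): the kernel takes typed pointers plus strides already converted to element units, so the inner copy is a plain indexed load and store rather than char * byte arithmetic:

// Hypothetical sketch: template parameters for the element types, element-unit strides.
template <typename src_t, typename dst_t>
static __global__ void k_set_rows_typed(const src_t * src0, dst_t * dst, const int64_t * row_ids,
                                        const int64_t ne00, const int64_t s01, const int64_t s1) {
    const int64_t i00 = (int64_t) blockIdx.y*blockDim.x + threadIdx.x;
    const int64_t i01 = blockIdx.x;

    if (i00 >= ne00) {
        return;
    }

    dst[row_ids[i01]*s1 + i00] = (dst_t) src0[i01*s01 + i00];
}

// Host side (assumption): byte strides are converted to element strides once, before the launch:
//     const int64_t s01 = src0->nb[1] / ggml_element_size(src0);
//     const int64_t s1  = dst->nb[1]  / ggml_element_size(dst);

Whether any of this is worth it depends, as noted above, on whether the kernel shows up in end-to-end profiles at all.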

cudaStream_t stream) {

const int max_threads_per_row = 256;
const int threads_per_row = std::min((int)ne00, max_threads_per_row);
Collaborator

I didn't compare the two versions but I think it would be better to use a static block size of e.g. 256 here and to map the thread indices to tensor data indices as if the tensor was flattened. Running CUDA code with fractional warps will always result in wasted GPU resources.
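
A minimal sketch of that flattened mapping (hypothetical names; assumes src0 is contiguous and dst strides are passed in element units). With a fixed block size of 256, every block except possibly the last one consists of full warps:

// Hypothetical sketch: fixed 256-thread blocks over the flattened element count.
template <typename src_t, typename dst_t>
static __global__ void k_set_rows_flat(const src_t * src0, dst_t * dst, const int64_t * row_ids,
                                       const int64_t ne00, const int64_t ne01, const int64_t ne02,
                                       const int64_t s1, const int64_t s2, const int64_t s3,
                                       const int64_t n_elems) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n_elems) {
        return;
    }

    // recover the tensor indices as if src0 were flattened
    const int64_t i00 =  i                   % ne00;
    const int64_t i01 = (i /  ne00)          % ne01;
    const int64_t i02 = (i / (ne00*ne01))    % ne02;
    const int64_t i03 =  i / (ne00*ne01*ne02);

    // simplified id indexing, as in the sketches above
    dst[i03*s3 + i02*s2 + row_ids[i01]*s1 + i00] = (dst_t) src0[i];
}

// launch (hypothetical): k_set_rows_flat<<<(n_elems + 255)/256, 256, 0, stream>>>(...);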

@ggerganov
Member

I recently noticed that the above instruction would be a better way to calculate indices than how I did previously. From what I can tell we're also not using it in CPU code. @ggerganov is there a reason for this?

I didn't know about std::div. We should start using it.

@slaren
Member

slaren commented Jul 7, 2025

At least for the CPU, std::div is not likely to be faster because the compiler can merge the division and modulus into a single instruction. In fact, it seems that std::div is not inlined by some compilers, which results in worse performance.
https://godbolt.org/z/1hMK5vjvj
https://quick-bench.com/q/KUIm7TlEB-muDY49TVaUdiNkXnU
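
For illustration (a hypothetical host-side snippet, not from the PR): when quotient and remainder are taken from the same operands, compilers typically derive both from one division, so writing them out directly is at least as good as std::div:

// Hypothetical illustration: most compilers compute both results from a single divide here.
static inline void split_index(const int i, const int ne02, int & i02, int & i03) {
    i03 = i / ne02; // quotient
    i02 = i % ne02; // remainder
}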

@JohannesGaessler
Collaborator

@am17an sorry, my advice regarding division was also bad for CUDA. Use division and modulo operators as per usual.
