add out_f32x4_shared_bcf_merge_write_row2col(2d) #339

Merged: 1 commit, Jun 16, 2025
43 changes: 43 additions & 0 deletions kernels/mat-transpose/mat_transpose.cu
@@ -216,6 +216,7 @@ __global__ void mat_transpose_f32x4_shared_row2col2d_kernel(float *x, float *y,
  }
}


__global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel(float *x,
                                                                float *y,
                                                                const int row,
@@ -296,6 +297,44 @@ __global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel(float *x,
    y[(out_y + 3) * row + out_x] = smem_val.w;
  }
}


__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(float *x,
                                                                            float *y,
                                                                            const int row,
                                                                            const int col) {
  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
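  // the +PAD on the tile's minor dimension (declared below) skews rows so
  // the column-wise reads in the transposed phase spread across
  // shared-memory banks instead of all hitting one bank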
  __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S + PAD];
  if (global_y * 4 < row && global_x < col) {
    // load value from x to shared memory
    float4 x_val;
    x_val.x = x[(global_y * 4) * col + global_x];
    x_val.y = x[(global_y * 4 + 1) * col + global_x];
    x_val.z = x[(global_y * 4 + 2) * col + global_x];
    x_val.w = x[(global_y * 4 + 3) * col + global_x];
    tile[local_y * 4][local_x] = x_val.x;
    tile[local_y * 4 + 1][local_x] = x_val.y;
    tile[local_y * 4 + 2][local_x] = x_val.z;
    tile[local_y * 4 + 3][local_x] = x_val.w;
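    // note: this barrier sits inside the guard, so the launch is expected
    // to tile the matrix exactly (every thread in the block takes this
    // branch), keeping the barrier uniform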
    __syncthreads();
    float4 smem_val;
    // read the transposed values back from shared memory
    smem_val.x = tile[local_x * 4][local_y];
    smem_val.y = tile[local_x * 4 + 1][local_y];
    smem_val.z = tile[local_x * 4 + 2][local_y];
    smem_val.w = tile[local_x * 4 + 3][local_y];

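    // merged write: out_y is a multiple of 4, so the four transposed
    // elements sit contiguously in y and can be stored with one float4
    // (assuming row % 4 == 0 so the cast store stays 16-byte aligned);
    // threads with consecutive threadIdx.x then write adjacent float4s of
    // the same output row, giving coalesced 128-bit stores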
    const int gid_x = blockIdx.x * blockDim.x;
    const int gid_y = blockIdx.y * blockDim.y * 4;
    const int out_y = gid_y + local_x * 4;
    const int out_x = gid_x + local_y;
    reinterpret_cast<float4 *>(y)[(out_x * row + out_y) / 4] = FLOAT4(smem_val);
  }
}

// TODO: may support double buffer pipeline mat transpose ?
// TODO: may support fp16 mat transpose ?

@@ -361,6 +400,9 @@ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float,
                              1, 4)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32, float,
                              4, 1)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col, torch::kFloat32, float,
                              4, 1)

// TODO: may support double buffer pipeline mat transpose ?
// TODO: may support fp16 mat transpose ?

@@ -400,6 +442,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // shared memory optimization with bcf
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_col2row2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_row2col2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
  // CuTe implementations
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_col2row_reg)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_row2col_reg)
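For reference, a host-side launch for the new kernel could look like the minimal sketch below. This is not part of the PR: the launcher name is hypothetical, the WARP_SIZE_S and PAD values are assumed (use the actual definitions from mat_transpose.cu), and the grid math assumes row is a multiple of 4 * WARP_SIZE_S and col a multiple of WARP_SIZE_S, matching the kernel's uniform-branch layout.

#include <cuda_runtime.h>

#define WARP_SIZE_S 32 // assumed value; see the definition in mat_transpose.cu
#define PAD 1          // assumed value; see the definition in mat_transpose.cu

// Hypothetical launcher; assumes the kernel above is visible in this
// translation unit and that x, y are device pointers to row*col floats.
void launch_mat_transpose_merge_write(float *x, float *y, int row, int col) {
  // blockDim must match the kernel's tile: local_x, local_y < WARP_SIZE_S,
  // with each thread covering 4 rows of x.
  dim3 block(WARP_SIZE_S, WARP_SIZE_S);
  dim3 grid(col / WARP_SIZE_S,        // tiles along the column dimension
            row / (4 * WARP_SIZE_S)); // tiles along the row dimension
  mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel<<<grid, block>>>(
      x, y, row, col);
}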
6 changes: 6 additions & 0 deletions kernels/mat-transpose/mat_transpose.py
@@ -127,6 +127,12 @@ def transpose_copy_compiled(input: torch.Tensor, out: torch.Tensor):
"f32x4_shared_bcf_row2col(2d)",
y,
)
run_benchmark(
lib.mat_transpose_f32x4_shared_bcf_merge_write_row2col2d,
x,
"f32x4_shared_bcf_merge_write_row2col(2d)",
y,
)
run_benchmark(
lib.mat_transpose_cute_col2row_reg,
x,