The following program multiplies two FP32 matrices on an H100 and achieves a rough throughput of 400 TFLOP/s, which is impossible without TensorCores (the H100's FP32 peak on CUDA cores is roughly 67 TFLOP/s). XLA calls cuBLAS, which converts the matrices to TF32, making the kernel TensorCore-eligible. While the GPU kernel stats page of the JAX profiler correctly identifies both TensorCore eligibility and TensorCore usage, the framework op stats page incorrectly reports the op as not TensorCore-eligible.
```python
import datetime

import jax
import jax.numpy as jnp

MATRIX_DIM = 32768
STEPS = 10

# jnp.ones defaults to float32, so these are FP32 matrices.
A = jnp.ones((MATRIX_DIM, MATRIX_DIM))
B = jnp.ones((MATRIX_DIM, MATRIX_DIM))

num_bytes = A.size * 4  # 4 bytes per FP32 element
total_num_bytes_crossing_to_hbm = num_bytes * 3  # read A, read B, write C
total_num_flops = 2 * MATRIX_DIM * MATRIX_DIM**2  # 2 * N^3 for an N x N matmul

def matmul(A, B):
    return A @ B

matmul(A, B)  # warmup

with jax.profiler.trace("tensorboard"):
    start_time = datetime.datetime.now()
    for i in range(STEPS):
        C = A @ B
    C.block_until_ready()  # wait for the asynchronously dispatched matmuls
    end_time = datetime.datetime.now()

average_time_per_step = (end_time - start_time).total_seconds() / STEPS
print(f"{average_time_per_step}, teraflops per second: {total_num_flops / average_time_per_step / 1e12}, gigabytes per second: {total_num_bytes_crossing_to_hbm / average_time_per_step / 1e9}")
```
The optimized HLO shows the dot lowered to a cuBLAS custom call:

```
%custom-call.1 = (f32[32768,32768]{1,0}, s8[33554432]{0}) custom-call(f32[32768,32768]{1,0} %Arg_0.1, f32[32768,32768]{1,0} %Arg_1.2), custom_call_target="__cublas$gemm"
```
The kernel name records the TF32 compute type (note the `tf32` in the type string): `sm90_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize128x256x32_warpgroupsize2x1x1_execute_segment_k_off_kernel__5x_cublas`