
False negative for TensorCore utilization in framework op stats page #1704

Description

@emergenz

The following program multiplies two FP32 matrices on an H100. It achieves roughly 400 TFLOP/s, which is impossible without TensorCores (H100's FP32 peak without TensorCores is on the order of 60 TFLOP/s). XLA calls cuBLAS, which converts the inputs to TF32 and thereby makes the kernel TensorCore-eligible. While the GPU kernel stats page of the JAX profiler correctly reports both TensorCore eligibility and TensorCore usage, the framework op stats page incorrectly reports the op as not TensorCore-eligible.

import datetime
import jax.numpy as jnp
import jax

MATRIX_DIM = 32768
STEPS = 10

# Two FP32 matrices (float32 is the JAX default dtype).
A = jnp.ones((MATRIX_DIM, MATRIX_DIM))
B = jnp.ones((MATRIX_DIM, MATRIX_DIM))

num_bytes = A.size * 4  # bytes per FP32 matrix
total_num_bytes_crossing_to_hbm = num_bytes * 3  # read A, read B, write C

total_num_flops = 2 * MATRIX_DIM * MATRIX_DIM**2  # 2 * N^3 FLOPs for an N x N matmul

def matmul(A, B):
  return A @ B

matmul(A, B).block_until_ready()  # warmup; block so compilation finishes before timing

with jax.profiler.trace("tensorboard"):
  start_time = datetime.datetime.now()
  for i in range(STEPS):
    C = A @ B
    C.block_until_ready()  # wait for the async dispatch so timing is accurate
  end_time = datetime.datetime.now()

average_time_per_step = (end_time - start_time).total_seconds() / STEPS

print(f"seconds per step: {average_time_per_step}, "
      f"teraflops per second: {total_num_flops / average_time_per_step / 1e12}, "
      f"gigabytes per second: {total_num_bytes_crossing_to_hbm / average_time_per_step / 1e9}")
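The compiled HLO shows the matmul lowered to a cuBLAS GEMM custom call. A minimal sketch for dumping it, assuming JAX's jax.jit lowering/compile API (one could also set XLA_FLAGS=--xla_dump_to=<dir>):

import jax

# Lower and compile the matmul, then print the optimized HLO text;
# it contains the __cublas$gemm custom call shown below.
compiled = jax.jit(matmul).lower(A, B).compile()
print(compiled.as_text())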
%custom-call.1 = (f32[32768,32768]{1,0}, s8[33554432]{0}) custom-call(f32[32768,32768]{1,0} %Arg_0.1, f32[32768,32768]{1,0} %Arg_1.2), custom_call_target="__cublas$gemm"

Kernel name: sm90_xmma_gemm_f32f32_tf32f32_f32_nn_n_tilesize128x256x32_warpgroupsize2x1x1_execute_segment_k_off_kernel__5x_cublas
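As a cross-check that the throughput comes from TF32 TensorCores, one can force true-FP32 accumulation, which disables the TF32 downcast. A hedged sketch using jax.lax.Precision.HIGHEST (throughput should then drop toward the non-TensorCore FP32 peak, and the resulting kernel should no longer be TensorCore-eligible):

import jax
import jax.numpy as jnp

def matmul_fp32(A, B):
  # Precision.HIGHEST requests true FP32 math, so cuBLAS should pick a
  # non-TF32 kernel instead of the sm90_xmma_gemm_f32f32_tf32f32 one above.
  return jnp.matmul(A, B, precision=jax.lax.Precision.HIGHEST)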
