
[FEA] Improve FP8xFP8 GEMM performance, especially for FP8_E4M3 #394

@sanchitintel

Description


Problem

  1. Currently, FP8xFP8 GEMM performance is quite poor, especially for FP8_E4M3,
    since the FP8_E4M3 -> FP16 conversion carries a large overhead.
  2. Explore the performance bottleneck of E5M2 -> FP16 conversion (which is just a shift op); a sketch of why the two conversions differ in cost follows this list.
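
The E5M2 case is cheap because E5M2 shares FP16's 5-bit exponent and bias, so the 8 bits only need to be shifted into the upper half of a 16-bit pattern. A minimal sketch (plain C++, not the repo's implementation) of that observation:

```cpp
#include <cstdint>

// FP8 E5M2 (1 sign, 5 exponent, 2 mantissa bits) uses the same exponent
// width and bias as FP16, so conversion is a single left shift that places
// sign|exponent|mantissa into the top 8 bits of the FP16 bit pattern.
inline uint16_t e5m2_to_fp16_bits(uint8_t x) {
    return static_cast<uint16_t>(x) << 8;
}

// FP8 E4M3 (4 exponent bits, bias 7) instead needs an exponent re-bias to
// FP16's bias 15 plus special handling of denormals and NaN, which is why
// its conversion is noticeably more expensive.
```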

Tentative solutions

  • We should cache the converted FP8_E4M3 -> FP16 elements of A & B in shared local memory and reuse them;
    oneDNN already does this. Based on the writeup in the first comment, we should do the same for E5M2 (see the sketch after this list).

  • Also, the E4M3 GEMM implementation makes unnecessary copies, as it creates several new vectors. Copy elision (Return Value Optimization) for some of these vectors does not seem to happen with nightly DPCPP dated March 24, even at optimization level O2: after eliminating all unnecessary copies, I measured at least ~86% better performance on DeepSeek input shapes for FP8_E5M2 GEMMs. So, this task is about removing all unnecessary copies by not creating new vectors where they are not needed.
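
Below is a minimal SYCL sketch of the caching idea: each work-group converts its FP8 tile to FP16 once into shared local memory and then reuses the converted values. The function name, tile size, and USM-pointer interface are hypothetical and only for illustration; the actual GEMM kernel would fold the reuse into its MMA inner loop rather than copying the values back out.

```cpp
#include <sycl/sycl.hpp>
#include <cstdint>

// Illustrative tile size; the real kernel's tiling is dictated by the GEMM config.
constexpr size_t TILE = 256;

// Sketch only: convert an FP8 (E5M2) buffer to FP16 via SLM. `a_fp8` and `out`
// are assumed to be USM device allocations and `n` a multiple of TILE.
void convert_and_cache(sycl::queue& q, const uint8_t* a_fp8, sycl::half* out, size_t n) {
    q.submit([&](sycl::handler& h) {
        // Shared local memory tile holding the converted FP16 values.
        sycl::local_accessor<sycl::half, 1> a_slm(sycl::range<1>(TILE), h);
        h.parallel_for(
            sycl::nd_range<1>{sycl::range<1>(n), sycl::range<1>(TILE)},
            [=](sycl::nd_item<1> it) {
                size_t gid = it.get_global_id(0);
                size_t lid = it.get_local_id(0);
                // Convert each element exactly once (E5M2 -> FP16 is a shift)
                // and stash the result in SLM.
                uint16_t bits = static_cast<uint16_t>(a_fp8[gid]) << 8;
                a_slm[lid] = sycl::bit_cast<sycl::half>(bits);
                sycl::group_barrier(it.get_group());
                // In a GEMM, the MMA loop would now read a_slm repeatedly
                // instead of re-converting from global memory; here we just
                // copy the cached value out to keep the sketch self-contained.
                out[gid] = a_slm[lid];
            });
    }).wait();
}
```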

PRs

  1. Reduce FP8_E5M2 -> FP16 conversion overhead with fewer shift instructions (#396).

  2. Adding FP8 input support for flash attention prefill (#419) removed all unnecessary copies from the E4M3 -> FP16 conversion.

  3. @jiyang1011 will submit a PR for the approach described in #394 (comment): converting FP8 -> FP16 elements, caching them in SLM, and then reusing them. It would help both E4M3 and E5M2.
