
Issue reproducing Float8 performance benchmark #1344 (Open)

Description (@xmfan)

Bug description

I'm looking at https://github.com/pytorch/torchtitan/blob/main/benchmarks/llama3_h100_202412_torchtitan.md. Specifically, this table:

[Image: Llama 3 8B performance table from llama3_h100_202412_torchtitan.md]

I'm not certain what the repro command for this is. Based on https://github.com/pytorch/torchtitan/blob/main/docs/float8.md, I went ahead with CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile.
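For readability, here is the same command with line continuations (flags copied verbatim from the command above, nothing added or changed):

```bash
# Same repro command as above, wrapped for readability
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" \
./run_train.sh \
  --model.converters="float8" \
  --float8.enable_fsdp_float8_all_gather \
  --float8.precompute_float8_dynamic_scale_for_fsdp \
  --float8.force_recompute_fp8_weight_in_bwd \
  --training.compile
```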

I made the following changes to my llama3 toml (gist: https://gist.github.com/xmfan/53fca4ed56cf7e713a282ce6e1922e9e); see the toml sketch after this list:

  • seq_len = 32768
  • data_parallel_shard_degree = 8 (for 8-GPU FSDP)
  • activation_checkpoint.mode = "full"
  • steps = 400 (just for a shorter run)
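
In toml form, the overrides above look roughly like this (a sketch only: section names assume the stock llama3_8b.toml layout, and data_parallel_shard_degree has moved between sections across torchtitan versions, so adjust to your checkout):

```toml
# Sketch of the overrides listed above; section placement assumes the stock
# llama3_8b.toml layout and may differ on newer torchtitan main.
[training]
seq_len = 32768
steps = 400
data_parallel_shard_degree = 8  # 8-GPU FSDP

[activation_checkpoint]
mode = "full"
```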

But the peak memory of my run is much lower than the number quoted in the perf benchmarks, which makes me think I did something wrong. @tianyu-l tried these settings and got a hang instead.

Are these the correct settings for this benchmark?

https://gist.github.com/xmfan/5a6b6daa0968aed7499ef364dae61420

Versions

latest torchao (USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git), pytorch 06/25 nightly, torchtitan main
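
Rough setup commands (a sketch; the nightly index URL / CUDA variant is my assumption, substitute whatever matches your machine and the 06/25 nightly date):

```bash
# Environment sketch; the cu126 nightly index is an assumption, not from the issue.
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git   # latest torchao
git clone https://github.com/pytorch/torchtitan.git && cd torchtitan    # torchtitan main
pip install -r requirements.txt
```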

Metadata

Labels: documentation (Improvements or additions to documentation)
