Bug description
I'm looking at https://github.com/pytorch/torchtitan/blob/main/benchmarks/llama3_h100_202412_torchtitan.md. Specifically, this table:

I'm not certain what the exact repro command for this is. Following https://github.com/pytorch/torchtitan/blob/main/docs/float8.md, I went ahead with:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile
I made the following changes to my llama3 toml (https://gist.github.com/xmfan/53fca4ed56cf7e713a282ce6e1922e9e); see the sketch after this list:
- seq_len = 32768
- data_parallel_shard_degree = 8 (for 8-GPU FSDP)
- activation_checkpoint.mode = "full"
- steps = 400 (just for a shorter run)
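For reference, a minimal sketch of how those overrides might look inside llama3_8b.toml. The section placement and key names here are assumptions based on the stock config and may differ across torchtitan revisions, so treat this as illustrative rather than the exact diff:

```toml
# Sketch of the edits applied on top of llama3_8b.toml
# (section names assumed to match the stock config; verify against your checkout)

[training]
seq_len = 32768                  # benchmark sequence length
steps = 400                      # shortened run
data_parallel_shard_degree = 8   # FSDP over 8 GPUs

[activation_checkpoint]
mode = "full"                    # full activation checkpointing
```

These could presumably also be passed as dotted command-line overrides to run_train.sh (in the same style as the --float8.* flags above), assuming the same key paths.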
But the peak memory of my run is much lower than the one quoted in the perf benchmarks, which makes me think I did something wrong. @tianyu-l tried these settings and got a hang instead.
Are these the correct settings for this benchmark?
https://gist.github.com/xmfan/5a6b6daa0968aed7499ef364dae61420
Versions
latest torchao (installed with USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git), PyTorch 06/25 nightly, torchtitan main