Autoparallel support for DP-only, DP+TP, or TP-only #1349

Open · wants to merge 5 commits into base: autoparallel
Conversation

wconstab (Contributor) commented:

Lets the existing torchtitan knobs that govern DP/TP mesh creation and mesh size influence the sharding constraints of autoparallel, allowing it to support these different sharding configurations.
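For intuition, here is a minimal standalone sketch (not code from this PR; the helper name is hypothetical) of how the existing knobs map onto the mesh shapes that autoparallel ends up constraining against, matching the three runs below:

```
# Hypothetical helper: derive the mesh dim names/sizes from the existing
# torchtitan knobs (world size and TP degree); whatever is left over after
# TP becomes the dp_shard dimension.
def mesh_from_degrees(world_size: int, tp_degree: int) -> tuple[list[str], list[int]]:
    assert world_size % tp_degree == 0, "TP degree must divide world size"
    dp_shard = world_size // tp_degree
    names, sizes = [], []
    if dp_shard > 1:
        names.append("dp_shard")
        sizes.append(dp_shard)
    if tp_degree > 1:
        names.append("tp")
        sizes.append(tp_degree)
    return names, sizes

# The three example runs below, all on 8 GPUs:
print(mesh_from_degrees(8, 1))  # (['dp_shard'], [8])          -> DP-only
print(mesh_from_degrees(8, 4))  # (['dp_shard', 'tp'], [2, 4]) -> DP+TP
print(mesh_from_degrees(8, 8))  # (['tp'], [8])                -> TP-only
```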

Examples:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 1 --training.dataset c4`
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpUf57BL/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

```
[rank0]:[titan] 2025-06-26 18:12:46,592 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-06-26 18:12:46,593 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-06-26 18:23:14,389 - root - INFO - step:  1  loss:  8.1996  memory:  2.65GiB(2.79%)  tps: 1,772  tflops: 0.13  mfu: 0.01%
[rank0]:[titan] 2025-06-26 18:23:14,486 - root - INFO - step:  2  loss:  8.1709  memory:  2.66GiB(2.80%)  tps: 168,877  tflops: 12.14  mfu: 1.23%
[rank0]:[titan] 2025-06-26 18:23:14,580 - root - INFO - step:  3  loss:  8.1121  memory:  2.66GiB(2.80%)  tps: 175,100  tflops: 12.59  mfu: 1.27%
[rank0]:[titan] 2025-06-26 18:23:14,677 - root - INFO - step:  4  loss:  8.0119  memory:  2.66GiB(2.80%)  tps: 170,227  tflops: 12.24  mfu: 1.24%
[rank0]:[titan] 2025-06-26 18:23:14,771 - root - INFO - step:  5  loss:  7.8920  memory:  2.66GiB(2.80%)  tps: 174,614  tflops: 12.56  mfu: 1.27%
[rank0]:[titan] 2025-06-26 18:23:14,867 - root - INFO - step:  6  loss:  7.7511  memory:  2.66GiB(2.80%)  tps: 170,863  tflops: 12.29  mfu: 1.24%
[rank0]:[titan] 2025-06-26 18:23:14,963 - root - INFO - step:  7  loss:  7.6531  memory:  2.66GiB(2.80%)  tps: 172,868  tflops: 12.43  mfu: 1.26%
[rank0]:[titan] 2025-06-26 18:23:15,060 - root - INFO - step:  8  loss:  7.5231  memory:  2.66GiB(2.80%)  tps: 168,378  tflops: 12.11  mfu: 1.22%
[rank0]:[titan] 2025-06-26 18:23:15,157 - root - INFO - step:  9  loss:  7.3795  memory:  2.66GiB(2.80%)  tps: 170,250  tflops: 12.24  mfu: 1.24%
[rank0]:[titan] 2025-06-26 18:23:15,251 - root - INFO - step: 10  loss:  7.3036  memory:  2.66GiB(2.80%)  tps: 175,755  tflops: 12.64  mfu: 1.28%
```

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp981ifR/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

```
[rank0]:[titan] 2025-06-26 18:24:05,617 - root - INFO - Building 2-D device mesh with ['dp_shard', 'tp'], [2, 4]
[rank0]:[titan] 2025-06-26 18:27:44,952 - root - INFO - step:  1  loss:  8.2345  memory:  1.08GiB(1.14%)  tps: 74  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-06-26 18:27:45,003 - root - INFO - step:  2  loss:  8.2156  memory:  1.15GiB(1.21%)  tps: 80,543  tflops: 5.79  mfu: 0.59%
[rank0]:[titan] 2025-06-26 18:27:45,054 - root - INFO - step:  3  loss:  8.1867  memory:  1.15GiB(1.21%)  tps: 81,472  tflops: 5.86  mfu: 0.59%
[rank0]:[titan] 2025-06-26 18:27:45,099 - root - INFO - step:  4  loss:  8.1072  memory:  1.15GiB(1.21%)  tps: 90,961  tflops: 6.54  mfu: 0.66%
[rank0]:[titan] 2025-06-26 18:27:45,145 - root - INFO - step:  5  loss:  8.0360  memory:  1.15GiB(1.21%)  tps: 90,280  tflops: 6.49  mfu: 0.66%
[rank0]:[titan] 2025-06-26 18:27:45,193 - root - INFO - step:  6  loss:  7.9681  memory:  1.15GiB(1.21%)  tps: 84,915  tflops: 6.11  mfu: 0.62%
[rank0]:[titan] 2025-06-26 18:27:45,241 - root - INFO - step:  7  loss:  7.8870  memory:  1.15GiB(1.21%)  tps: 86,096  tflops: 6.19  mfu: 0.63%
[rank0]:[titan] 2025-06-26 18:27:45,292 - root - INFO - step:  8  loss:  7.8493  memory:  1.15GiB(1.21%)  tps: 81,182  tflops: 5.84  mfu: 0.59%
[rank0]:[titan] 2025-06-26 18:27:45,341 - root - INFO - step:  9  loss:  7.7431  memory:  1.15GiB(1.21%)  tps: 84,341  tflops: 6.06  mfu: 0.61%
[rank0]:[titan] 2025-06-26 18:27:45,396 - root - INFO - step: 10  loss:  7.7052  memory:  1.15GiB(1.21%)  tps: 74,973  tflops: 5.39  mfu: 0.55%
```

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 8 --training.dataset c4`
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpgPuMRF/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

```
[rank0]:[titan] 2025-06-26 18:32:37,789 - root - INFO - Building 1-D device mesh with ['tp'], [8]
[rank0]:[titan] 2025-06-26 18:33:00,183 - root - INFO - step:  1  loss:  8.2190  memory:  0.81GiB(0.85%)  tps: 205  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-06-26 18:33:00,251 - root - INFO - step:  2  loss:  8.1733  memory:  0.87GiB(0.92%)  tps: 30,431  tflops: 2.19  mfu: 0.22%
[rank0]:[titan] 2025-06-26 18:33:00,297 - root - INFO - step:  3  loss:  8.1438  memory:  0.87GiB(0.92%)  tps: 44,284  tflops: 3.18  mfu: 0.32%
[rank0]:[titan] 2025-06-26 18:33:00,342 - root - INFO - step:  4  loss:  8.0361  memory:  0.87GiB(0.92%)  tps: 45,921  tflops: 3.30  mfu: 0.33%
[rank0]:[titan] 2025-06-26 18:33:00,384 - root - INFO - step:  5  loss:  7.9559  memory:  0.87GiB(0.92%)  tps: 49,178  tflops: 3.54  mfu: 0.36%
[rank0]:[titan] 2025-06-26 18:33:00,426 - root - INFO - step:  6  loss:  7.8346  memory:  0.87GiB(0.92%)  tps: 49,172  tflops: 3.54  mfu: 0.36%
[rank0]:[titan] 2025-06-26 18:33:00,462 - root - INFO - step:  7  loss:  7.7266  memory:  0.87GiB(0.92%)  tps: 58,273  tflops: 4.19  mfu: 0.42%
[rank0]:[titan] 2025-06-26 18:33:00,499 - root - INFO - step:  8  loss:  7.6807  memory:  0.87GiB(0.92%)  tps: 54,435  tflops: 3.91  mfu: 0.40%
[rank0]:[titan] 2025-06-26 18:33:00,537 - root - INFO - step:  9  loss:  7.5616  memory:  0.87GiB(0.92%)  tps: 55,232  tflops: 3.97  mfu: 0.40%
[rank0]:[titan] 2025-06-26 18:33:00,575 - root - INFO - step: 10  loss:  7.5090  memory:  0.87GiB(0.92%)  tps: 54,284  tflops: 3.90  mfu: 0.39%
```

wconstab added 5 commits June 16, 2025 12:32
TODO
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop
"""
[rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step:  1  loss:  8.1880  memory:  4.88GiB(6.16%)  tps: 28
[rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step:  2  loss:  8.1610  memory:  4.90GiB(6.20%)  tps: 13,785
[rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step:  3  loss:  8.0871  memory:  4.90GiB(6.20%)  tps: 14,006
[rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step:  4  loss:  7.9516  memory:  4.90GiB(6.20%)  tps: 13,770
[rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step:  5  loss:  7.8552  memory:  4.90GiB(6.20%)  tps: 13,959
[rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step:  6  loss:  7.7732  memory:  4.90GiB(6.20%)  tps: 13,859
[rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step:  7  loss:  7.6987  memory:  4.90GiB(6.20%)  tps: 13,664
[rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step:  8  loss:  7.6779  memory:  4.90GiB(6.20%)  tps: 13,985
[rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step:  9  loss:  7.6043  memory:  4.90GiB(6.20%)  tps: 13,962
[rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10  loss:  7.5778  memory:  4.90GiB(6.20%)  tps: 13,891
"""
Allows reverting a lot of the hacks in the original integration that were
caused by not creating a model object in train.py (a consequence of passing
a model_fn builder to autop).

Lets existing torchtitan knobs which govern DP/TP mesh creation and mesh
size influence the sharding constraints of autoparallel, allowing it to
support these different sharding configurations.

facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jun 27, 2025
```
@@ -37,30 +37,35 @@ def input_fn():
     if global_batch_size < 0:
         # This global batch size results in 1 gradient accumulation
         # step.
-        dp_degree = world_mesh["dp"].size()
+        dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
```
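To make the changed line concrete, here is a small standalone illustration (the batch size and degrees are assumed values, not from this PR) of how the data-parallel degree feeds the default global batch size:

```
# The effective DP degree is the product of the replicate and shard dims,
# and the default global batch size is sized so that training takes one
# gradient-accumulation step per optimizer step.
local_batch_size = 8            # assumed per-rank batch size
dp_replicate, dp_shard = 2, 4   # assumed HSDP-style layout on 8 GPUs

dp_degree = dp_replicate * dp_shard
global_batch_size = local_batch_size * dp_degree
print(dp_degree, global_batch_size)  # 8 64
```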
Contributor:
Making dp_replicate and dp_shard configurable seems weird, since 2 * 4 and 4 * 2 make no difference. Should we just stick to one and assert when the other is set larger than 1?

Contributor Author:

What do you mean, they make no difference? I want to use them to control the memory constraints so autop behaves more like DDP vs FSDP vs HSDP.

Contributor:

I was thinking of having more explicit arguments to tune the memory constraints, but I understand now.

Contributor Author:

Yeah, I think I will add explicit knobs to get the most out of autoparallel. I was first trying to slot autoparallel into the existing torchtitan as 'seamlessly' as possible.

One issue is that if users ignore --dp_replicate_degree and similar command-line args and use other args to influence autoparallel, we have the problem of deciding which mesh dims to create. I will have to think about what to do for that.

```
assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
assert parallel_dims.cp_enabled is False, "CP not supported yet"
assert parallel_dims.pp_enabled is False, "PP not supported yet"

# bail out
# model = model_fn()
# return model

autop = AutoParallel(model, input_fn, world_mesh)
autop.add_parameter_memory_constraint(low=None, high=None)
```
fmassa (Member), Jun 27, 2025:

I believe you might also need to tweak the memory constraint if you want to get DP behavior. For example, you can get DDP with low=1.0, high=1.0.

By default (i.e., low=None, high=None), we get low=0, high=1 / mesh.size(), which says "let's shard the sum of all parameters so that each GPU has 1 / mesh.size() of them."

Contributor Author:

Yeah, I will try that next.
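For reference, a standalone sketch (not autoparallel code) of the constraint semantics described above: the range of the total-parameter fraction each GPU may hold under the default versus a DDP-like setting.

```
# low=None/high=None is treated as low=0, high=1/mesh_size, i.e. shard the
# sum of all parameters so each GPU holds at most 1/mesh_size of them
# (FSDP-like); low=high=1.0 forces a full replica per GPU (DDP-like).
def per_gpu_param_fraction(mesh_size: int, low=None, high=None):
    lo = 0.0 if low is None else low
    hi = (1.0 / mesh_size) if high is None else high
    return lo, hi

print(per_gpu_param_fraction(8))            # (0.0, 0.125) -> FSDP-like
print(per_gpu_param_fraction(8, 1.0, 1.0))  # (1.0, 1.0)   -> DDP-like
```

In terms of the snippet above, the DDP-like setting would correspond to calling autop.add_parameter_memory_constraint(low=1.0, high=1.0) instead of leaving both at None.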
