Autoparallel support for DP-only, DP+TP, or TP-only #1349
base: autoparallel
Conversation
TODO:
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop
""" [rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28 [rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785 [rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006 [rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770 [rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959 [rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859 [rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664 [rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985 [rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962 [rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891 """
Allows reverting a lot of the hacks in the original integration, which were caused by not creating a model object in train.py (we passed a model_fn builder to autop instead).
Lets existing torchtitan knobs which govern DP/TP mesh creation and mesh size influence the sharding constraints of autoparallel, allowing it to support these different sharding configurations. (A sketch of how the mesh dims can be derived from these knobs follows the examples below.)

Examples:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 1 --training.dataset c4`

tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpUf57BL/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

```
[rank0]:[titan] 2025-06-26 18:12:46,592 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-06-26 18:12:46,593 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-06-26 18:23:14,389 - root - INFO - step: 1 loss: 8.1996 memory: 2.65GiB(2.79%) tps: 1,772 tflops: 0.13 mfu: 0.01%
[rank0]:[titan] 2025-06-26 18:23:14,486 - root - INFO - step: 2 loss: 8.1709 memory: 2.66GiB(2.80%) tps: 168,877 tflops: 12.14 mfu: 1.23%
[rank0]:[titan] 2025-06-26 18:23:14,580 - root - INFO - step: 3 loss: 8.1121 memory: 2.66GiB(2.80%) tps: 175,100 tflops: 12.59 mfu: 1.27%
[rank0]:[titan] 2025-06-26 18:23:14,677 - root - INFO - step: 4 loss: 8.0119 memory: 2.66GiB(2.80%) tps: 170,227 tflops: 12.24 mfu: 1.24%
[rank0]:[titan] 2025-06-26 18:23:14,771 - root - INFO - step: 5 loss: 7.8920 memory: 2.66GiB(2.80%) tps: 174,614 tflops: 12.56 mfu: 1.27%
[rank0]:[titan] 2025-06-26 18:23:14,867 - root - INFO - step: 6 loss: 7.7511 memory: 2.66GiB(2.80%) tps: 170,863 tflops: 12.29 mfu: 1.24%
[rank0]:[titan] 2025-06-26 18:23:14,963 - root - INFO - step: 7 loss: 7.6531 memory: 2.66GiB(2.80%) tps: 172,868 tflops: 12.43 mfu: 1.26%
[rank0]:[titan] 2025-06-26 18:23:15,060 - root - INFO - step: 8 loss: 7.5231 memory: 2.66GiB(2.80%) tps: 168,378 tflops: 12.11 mfu: 1.22%
[rank0]:[titan] 2025-06-26 18:23:15,157 - root - INFO - step: 9 loss: 7.3795 memory: 2.66GiB(2.80%) tps: 170,250 tflops: 12.24 mfu: 1.24%
[rank0]:[titan] 2025-06-26 18:23:15,251 - root - INFO - step: 10 loss: 7.3036 memory: 2.66GiB(2.80%) tps: 175,755 tflops: 12.64 mfu: 1.28%
```

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`

tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp981ifR/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

```
[rank0]:[titan] 2025-06-26 18:24:05,617 - root - INFO - Building 2-D device mesh with ['dp_shard', 'tp'], [2, 4]
[rank0]:[titan] 2025-06-26 18:27:44,952 - root - INFO - step: 1 loss: 8.2345 memory: 1.08GiB(1.14%) tps: 74 tflops: 0.01 mfu: 0.00%
[rank0]:[titan] 2025-06-26 18:27:45,003 - root - INFO - step: 2 loss: 8.2156 memory: 1.15GiB(1.21%) tps: 80,543 tflops: 5.79 mfu: 0.59%
[rank0]:[titan] 2025-06-26 18:27:45,054 - root - INFO - step: 3 loss: 8.1867 memory: 1.15GiB(1.21%) tps: 81,472 tflops: 5.86 mfu: 0.59%
[rank0]:[titan] 2025-06-26 18:27:45,099 - root - INFO - step: 4 loss: 8.1072 memory: 1.15GiB(1.21%) tps: 90,961 tflops: 6.54 mfu: 0.66%
[rank0]:[titan] 2025-06-26 18:27:45,145 - root - INFO - step: 5 loss: 8.0360 memory: 1.15GiB(1.21%) tps: 90,280 tflops: 6.49 mfu: 0.66%
[rank0]:[titan] 2025-06-26 18:27:45,193 - root - INFO - step: 6 loss: 7.9681 memory: 1.15GiB(1.21%) tps: 84,915 tflops: 6.11 mfu: 0.62%
[rank0]:[titan] 2025-06-26 18:27:45,241 - root - INFO - step: 7 loss: 7.8870 memory: 1.15GiB(1.21%) tps: 86,096 tflops: 6.19 mfu: 0.63%
[rank0]:[titan] 2025-06-26 18:27:45,292 - root - INFO - step: 8 loss: 7.8493 memory: 1.15GiB(1.21%) tps: 81,182 tflops: 5.84 mfu: 0.59%
[rank0]:[titan] 2025-06-26 18:27:45,341 - root - INFO - step: 9 loss: 7.7431 memory: 1.15GiB(1.21%) tps: 84,341 tflops: 6.06 mfu: 0.61%
[rank0]:[titan] 2025-06-26 18:27:45,396 - root - INFO - step: 10 loss: 7.7052 memory: 1.15GiB(1.21%) tps: 74,973 tflops: 5.39 mfu: 0.55%
```

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 8 --training.dataset c4`

tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpgPuMRF/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

```
[rank0]:[titan] 2025-06-26 18:32:37,789 - root - INFO - Building 1-D device mesh with ['tp'], [8]
[rank0]:[titan] 2025-06-26 18:33:00,183 - root - INFO - step: 1 loss: 8.2190 memory: 0.81GiB(0.85%) tps: 205 tflops: 0.01 mfu: 0.00%
[rank0]:[titan] 2025-06-26 18:33:00,251 - root - INFO - step: 2 loss: 8.1733 memory: 0.87GiB(0.92%) tps: 30,431 tflops: 2.19 mfu: 0.22%
[rank0]:[titan] 2025-06-26 18:33:00,297 - root - INFO - step: 3 loss: 8.1438 memory: 0.87GiB(0.92%) tps: 44,284 tflops: 3.18 mfu: 0.32%
[rank0]:[titan] 2025-06-26 18:33:00,342 - root - INFO - step: 4 loss: 8.0361 memory: 0.87GiB(0.92%) tps: 45,921 tflops: 3.30 mfu: 0.33%
[rank0]:[titan] 2025-06-26 18:33:00,384 - root - INFO - step: 5 loss: 7.9559 memory: 0.87GiB(0.92%) tps: 49,178 tflops: 3.54 mfu: 0.36%
[rank0]:[titan] 2025-06-26 18:33:00,426 - root - INFO - step: 6 loss: 7.8346 memory: 0.87GiB(0.92%) tps: 49,172 tflops: 3.54 mfu: 0.36%
[rank0]:[titan] 2025-06-26 18:33:00,462 - root - INFO - step: 7 loss: 7.7266 memory: 0.87GiB(0.92%) tps: 58,273 tflops: 4.19 mfu: 0.42%
[rank0]:[titan] 2025-06-26 18:33:00,499 - root - INFO - step: 8 loss: 7.6807 memory: 0.87GiB(0.92%) tps: 54,435 tflops: 3.91 mfu: 0.40%
[rank0]:[titan] 2025-06-26 18:33:00,537 - root - INFO - step: 9 loss: 7.5616 memory: 0.87GiB(0.92%) tps: 55,232 tflops: 3.97 mfu: 0.40%
[rank0]:[titan] 2025-06-26 18:33:00,575 - root - INFO - step: 10 loss: 7.5090 memory: 0.87GiB(0.92%) tps: 54,284 tflops: 3.90 mfu: 0.39%
```
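The mesh layouts in the logs above follow directly from the DP/TP degrees. As a rough illustration (not the exact torchtitan code; `build_mesh_for_autop` is a hypothetical helper, and only `init_device_mesh` is a real PyTorch API), dims of size 1 are dropped so you end up with 1-D `['dp_shard']`, 2-D `['dp_shard', 'tp']`, or 1-D `['tp']`:

```python
from torch.distributed.device_mesh import init_device_mesh


def build_mesh_for_autop(dp_shard: int, tp: int, device_type: str = "cuda"):
    """Map DP/TP degrees onto mesh dims, skipping any dimension of size 1."""
    names, sizes = [], []
    if dp_shard > 1:
        names.append("dp_shard")
        sizes.append(dp_shard)
    if tp > 1:
        names.append("tp")
        sizes.append(tp)
    assert sizes, "at least one parallel dim must be > 1"
    return init_device_mesh(device_type, tuple(sizes), mesh_dim_names=tuple(names))


# e.g. 8 GPUs with --parallelism.tensor_parallel_degree 4 gives dp_shard = 2,
# producing the "2-D device mesh with ['dp_shard', 'tp'], [2, 4]" seen above.
# mesh = build_mesh_for_autop(dp_shard=2, tp=4)
```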
```diff
@@ -37,30 +37,35 @@ def input_fn():
     if global_batch_size < 0:
         # This global batch size results in 1 gradient accumulation
         # step.
-        dp_degree = world_mesh["dp"].size()
+        dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
```
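The "1 gradient accumulation step" comment in the hunk is easier to see with the arithmetic spelled out. A minimal sketch, with illustrative names rather than torchtitan's actual helpers:

```python
def grad_accum_steps(global_batch_size: int, local_batch_size: int,
                     dp_replicate: int, dp_shard: int) -> int:
    """Microbatches each rank accumulates per optimizer step, where
    dp_degree is the product of the replicate and shard dims."""
    dp_degree = dp_replicate * dp_shard
    if global_batch_size < 0:
        # default in the hunk above: exactly one accumulation step
        global_batch_size = local_batch_size * dp_degree
    assert global_batch_size % (local_batch_size * dp_degree) == 0
    return global_batch_size // (local_batch_size * dp_degree)


# e.g. grad_accum_steps(-1, 8, dp_replicate=1, dp_shard=8) == 1
```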
Making `dp_replicate` and `dp_shard` configurable seems weird, as 2 * 4 or 4 * 2 makes no difference. Should we just stick to one and assert when the other is assigned to be larger than 1?
What do you mean they make no difference? I want to use them to control the memory constraints so autop behaves more like DDP vs FSDP vs HSDP.
I was thinking of having more explicit arguments to tune memory constraints, but I understand now.
Yeah, I think I will add explicit knobs to get the most out of autoparallel. I was first trying to slot autoparallel into the existing torchtitan as 'seamlessly' as possible.
One thing is that if users ignore `--dp_replicate_degree` and similar cmdline args, and instead use other args to influence autoparallel, we have the problem of how to decide which mesh dims to create. I will have to think about what to do for that.
```python
assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
assert parallel_dims.cp_enabled is False, "CP not supported yet"
assert parallel_dims.pp_enabled is False, "PP not supported yet"

# bail out
# model = model_fn()
# return model

autop = AutoParallel(model, input_fn, world_mesh)
autop.add_parameter_memory_constraint(low=None, high=None)
```
I believe you might also need to tweak the memory constraint if you want to get DP behavior. For example, you can get DDP with `low=1.0, high=1.0`.
By default (i.e., `low=None, high=None`), we get `low=0, high=1 / mesh.size()`, which says "let's shard the sum of all parameters so that each GPU holds `1 / mesh.size()` of the total".
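To make that concrete, a sketch of how the constraint could be switched based on the desired data-parallel style (the `dp_mode` selector and helper are hypothetical; only `add_parameter_memory_constraint` and the low/high semantics come from the discussion above):

```python
def constrain_parameter_memory(autop, dp_mode: str):
    """Pick a parameter-memory constraint that nudges AutoParallel toward
    DDP-like replication or FSDP-like sharding."""
    if dp_mode == "ddp":
        # every rank keeps a full copy of the parameters
        autop.add_parameter_memory_constraint(low=1.0, high=1.0)
    elif dp_mode == "fsdp":
        # default: shard parameters so each GPU holds ~1/mesh.size() of them
        autop.add_parameter_memory_constraint(low=None, high=None)
    else:
        raise ValueError(f"unknown dp_mode: {dp_mode}")
```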
Yeah, I will try that next.