Commit 083fdaf

mikekgfb and Jack-Khuu authored
Enable sdpa backends for server export in export.py (#1478)
* Enable sdpa backends for server export in export.py: FLASH worked for dso models, so try this with methodical tests
* Update more-tests.yml: add tests for sdpa backends with server export (x86 cpu & cuda)
* Update export.py: fix typo
* Update more-tests.yml: display test information to simplify debugging
* Update more-tests.yml: fix typo
* Update more-tests.yml: need to generate cmake-out
* Update more-tests.yml: update the MODEL_DIR definition so that aoti_run can find tokenizer.model
* Update more-tests.yml
* Update more-tests.yml

Co-authored-by: Jack-Khuu <[email protected]>
1 parent d607ecc commit 083fdaf
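
For context: the export.py change below wraps the server export paths in PyTorch's torch.nn.attention.sdpa_kernel context manager, which pins scaled_dot_product_attention to an explicit backend while the model is exported. A minimal standalone sketch of that API (illustration only, not torchchat code; shapes are arbitrary):

# Minimal sketch of the PyTorch API used by this commit (not torchchat code).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary attention-shaped inputs: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 8, 128, 64)

# Inside the context, scaled_dot_product_attention may only use the listed
# backend(s); export.py now wraps export_for_server() the same way.
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])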

2 files changed: +84 -18 lines

.github/workflows/more-tests.yml

+67 -2

@@ -19,6 +19,7 @@ jobs:
       gpu-arch-version: "12.4"
       timeout: 60
       script: |
+        set -xeou pipefail
         echo "::group::Print machine info"
         uname -a
         echo "::endgroup::"
@@ -39,9 +40,10 @@ jobs:
         echo "::endgroup::"

         echo "::group::Run inference"
-        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_DIR=checkpoints/stories15M/
+        export MODEL_PATH=${MODEL_DIR}/stories15M.pt
         export MODEL_NAME=stories15M
-        export MODEL_DIR=/tmp
+

         for DTYPE in bfloat16 float16 float32; do
           ###################################################################
@@ -83,3 +85,66 @@ jobs:
         echo "tests complete"
         echo "******************************************"
         echo "::endgroup::"
+
+
+  test-sdpa-backends-export:
+    permissions:
+      id-token: write
+      contents: read
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.4"
+      timeout: 60
+      script: |
+        set -xeou pipefail
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        # Install requirements
+        ./install/install_requirements.sh cuda
+        pip3 list
+        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_DIR=checkpoints/stories15M/
+        export MODEL_PATH=${MODEL_DIR}/stories15M.pt
+        export MODEL_NAME=stories15M
+
+        ./torchchat/utils/scripts/build_native.sh aoti
+
+        for DEVICE in cpu cuda; do
+          # Depending on how the parameter passing works, we may only be able to use bfloat16 for aoti_run, similar to runner-cuda-dtype.yml
+          # (although the runner environment should not have an opinion about what we use in the artifact, and we might suitably abstract that)
+          for DTYPE in bfloat16 float16 float32; do
+            for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
+              echo "***************************************************************"
+              echo "*** $DEVICE $DTYPE $SDPA"
+              ###################################################################
+              # Export DSO and run with Python
+              python torchchat.py export --output-dso dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
+              python torchchat.py generate --dso-path dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --prompt "Once upon a time"
+              ###################################################################
+              # Export AOTI and run with aoti_run
+              python torchchat.py export --output-aoti /tmp/model.pt2 --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
+              ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "Once upon a time"
+              ###################################################################
+            done
+          done
+        done
+
+        echo "tests complete"
+        echo "******************************************"
+        echo "::endgroup::"
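
The new job sweeps a 2 x 3 x 4 matrix of device, dtype, and SDPA backend. For local debugging, a standalone probe of the same matrix might look like the following sketch (not part of the commit; it only checks that each backend can execute, not that export succeeds):

# Sketch: probe which SDPA backends run for each device/dtype combination,
# mirroring the workflow's loop. Not part of this commit.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

BACKENDS = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        q = k = v = torch.randn(1, 8, 32, 64, device=device, dtype=dtype)
        for name, backend in BACKENDS.items():
            try:
                with sdpa_kernel([backend]):
                    F.scaled_dot_product_attention(q, k, v)
                status = "ok"
            except RuntimeError:
                # Raised when no kernel is available for this combination.
                status = "unsupported"
            print(f"{device:4s} {str(dtype):15s} {name:20s} {status}")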

torchchat/export.py

+17 -16

@@ -490,13 +490,14 @@ def main(args):
         print(
             "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
         )
-        export_for_server(
-            model_to_dso,
-            builder_args.device,
-            output_dso_path,
-            builder_args.dynamic_shapes,
-            package=False,
-        )
+        with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
+            export_for_server(
+                model_to_dso,
+                builder_args.device,
+                output_dso_path,
+                builder_args.dynamic_shapes,
+                package=False,
+            )

     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
@@ -512,14 +513,15 @@ def main(args):
         print(
             "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
         )
-        export_for_server(
-            model_to_aoti_package,
-            builder_args.device,
-            output_aoti_package_path,
-            builder_args.dynamic_shapes,
-            package=True,
-            metadata=metadata,
-        )
+        with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
+            export_for_server(
+                model_to_aoti_package,
+                builder_args.device,
+                output_aoti_package_path,
+                builder_args.dynamic_shapes,
+                package=True,
+                metadata=metadata,
+            )

     if output_snapshot_path:
         output_snapshot_path = str(os.path.abspath(output_snapshot_path))
@@ -529,4 +531,3 @@ def main(args):
             builder_args.device,
             output_snapshot_path,
         )
-
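
Note that builder_args.attention_backend is passed straight to sdpa_kernel, so by this point it must already hold an SDPBackend value rather than the CLI string (e.g. 'flash_attention') used in the workflow. A hypothetical mapping helper, purely for illustration (the name to_sdp_backend and its placement are assumptions, not part of this commit):

# Hypothetical helper (not from this commit): map the CLI's
# --attention-backend strings to torch.nn.attention.SDPBackend values.
from torch.nn.attention import SDPBackend

_SDPA_BACKENDS = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

def to_sdp_backend(name: str) -> SDPBackend:
    try:
        return _SDPA_BACKENDS[name]
    except KeyError:
        raise ValueError(
            f"Unknown attention backend {name!r}; "
            f"expected one of {sorted(_SDPA_BACKENDS)}"
        ) from None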
