From 8194ed5ef45ffb3f783cdaac4881642b253e603f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Tue, 8 Jul 2025 10:56:20 +0200 Subject: [PATCH] feat: Restore convenience `FLASHINFER_ENABLE_AOT` option Restore the minimal support for `FLASHINFER_ENABLE_AOT` environment variable that was removed as part of the refactoring in #1075. This option was quite useful for downstreams like Red Hat, since it provided for convenient integration of AOT mode in regular PEP 517 workflows. This change aims to be absolutely minimal -- it does not change anything for the default workflow, merely re-adds `FLASHINFER_ENABLE_AOT`, which automatically invokes `flashinfer.aot.main()`. This makes it possible to perform an AOT build in a single step such as: ``` TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" FLASHINFER_ENABLE_AOT=1 python -m build -w ``` or: ``` TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" FLASHINFER_ENABLE_AOT=1 pip install . ``` --- custom_backend.py | 6 ++++++ flashinfer/aot.py | 7 ++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/custom_backend.py b/custom_backend.py index 6c349ce0a..2002634a0 100644 --- a/custom_backend.py +++ b/custom_backend.py @@ -1,3 +1,4 @@ +import os import shutil from pathlib import Path @@ -109,5 +110,10 @@ def build_sdist(sdist_directory, config_settings=None): def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): + if os.environ.get("FLASHINFER_ENABLE_AOT", "0") == "1": + from flashinfer.aot import main as aot_main + + aot_main([]) + _prepare_for_wheel() return orig.build_wheel(wheel_directory, config_settings, metadata_directory) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 395d94ae3..79c258d0e 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -1,6 +1,7 @@ import argparse import os import shutil +import sys from itertools import product from pathlib import Path from typing import List, Tuple @@ -375,7 +376,7 @@ def parse_head_dim(head_dim: str) -> Tuple[int, int]: 
return qo, kv -def main(): +def main(args: list[str]): parser = argparse.ArgumentParser( description="Ahead-of-Time (AOT) build all modules" ) @@ -426,7 +427,7 @@ def main(): type=parse_bool, help="Add kernels for Gemma Model (head_dim=256, use_sliding_window, use_logits_soft_cap)", ) - args = parser.parse_args() + args = parser.parse_args(args) # Default values project_root = Path(__file__).resolve().parents[1] @@ -561,4 +562,4 @@ def main(): if __name__ == "__main__": - main() + main(sys.argv[1:])