diff --git a/custom_backend.py b/custom_backend.py index 6c349ce0a..2002634a0 100644 --- a/custom_backend.py +++ b/custom_backend.py @@ -1,3 +1,4 @@ +import os import shutil from pathlib import Path @@ -109,5 +110,10 @@ def build_sdist(sdist_directory, config_settings=None): def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): + if os.environ.get("FLASHINFER_ENABLE_AOT", "0") == "1": + from flashinfer.aot import main as aot_main + + aot_main([]) + _prepare_for_wheel() return orig.build_wheel(wheel_directory, config_settings, metadata_directory) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 395d94ae3..79c258d0e 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -1,6 +1,7 @@ import argparse import os import shutil +import sys from itertools import product from pathlib import Path from typing import List, Tuple @@ -375,7 +376,7 @@ def parse_head_dim(head_dim: str) -> Tuple[int, int]: return qo, kv -def main(): +def main(args: list[str]): parser = argparse.ArgumentParser( description="Ahead-of-Time (AOT) build all modules" ) @@ -426,7 +427,7 @@ def main(): type=parse_bool, help="Add kernels for Gemma Model (head_dim=256, use_sliding_window, use_logits_soft_cap)", ) - args = parser.parse_args() + args = parser.parse_args(args) # Default values project_root = Path(__file__).resolve().parents[1] @@ -561,4 +562,4 @@ def main(): if __name__ == "__main__": - main() + main(sys.argv[1:])