WIP fast_autotune: Add lookup table and ML model to filter triton matmul configs #156683


Open · wants to merge 2 commits into main

Conversation

@exclamaforte (Contributor) commented Jun 24, 2025

The goal of this PR is to deliver max-autotune performance at a much lower compile-time cost. Currently, max-autotune has low adoption in training workloads because of its compile-time cost, and it could see more adoption in inference workloads if compile time were lower. We want to deliver better performance per unit of compile time by benchmarking better configs, selected with the input shapes taken into account. We provide two ways of doing this: manually filtering the configs with a LUT, and automatically filtering them with an ML model.

Part 1:

Features:

  • fast_autotune: Uses an ML model and a lookup table to evaluate and filter the config space before benchmarking.
  • kernel_lut: Uses a lookup table with a standardized format to filter configs. Users can create these tables by instantiating the Python class and saving the .serialize() output. The tables are serialized to JSON and can be edited by hand or, in the future, by external tools.
  • New configs:
    • matmul_gemm_autotune_benchmark_space: disaggregates the space of configs we benchmark from the space we search. "DEFAULT" matches the number of configs in the max-autotune defaults, "SAME" means the search space is the same as the benchmarking space (so model filtering is disabled), and <int> sets the top-k to <int>.
    • fast_autotune: New setting that sets matmul_gemm_autotune_search_space to "EXHAUSTIVE" and matmul_gemm_autotune_benchmark_space to 1.
    • kernel_lut_path: Path to the kernel LUT. The kernel LUT is enabled whenever this is non-None, regardless of the other settings, so the two features can be applied independently (a usage sketch follows this list).
  • Fixes the autoheuristic tests, which have been broken for a while.
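
A hedged usage sketch of how these new configs would fit together (config names follow the description above; the LUT path, shapes, and dtypes below are made up for illustration):

# Usage sketch only: these config names come from this PR, and the LUT file is a
# made-up example, not a shipped artifact.
import torch
import torch._inductor.config as inductor_config

inductor_config.fast_autotune = True  # EXHAUSTIVE search space, benchmark only the model's top pick
# Alternatively, keep the default number of benchmarked configs but let the model choose them:
# inductor_config.matmul_gemm_autotune_benchmark_space = 8  # or "DEFAULT" / "SAME" / top-k int
inductor_config.kernel_lut_path = "/tmp/mm_lut.json"  # LUT is applied whenever this is non-None

@torch.compile(mode="max-autotune")
def mm(a, b):
    return a @ b

out = mm(
    torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16),
    torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16),
)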

Why not integrate it into Autoheuristic?

We plan on merging with Autoheuristic eventually, and we're using it for data collection, but it's currently in a broken state for inference because its tests were disabled for ~6 months. We fixed some of the tests, but we also found a bit of a mismatch between it and what we're trying to do: autoheuristic predicts over all configs, while we are currently focused on just the Triton configs, and the shared infrastructure for the LUT requires it to sit at the start of the mm lowering function. Once our model and LUT handle all configs, we will move the callsite lower in the mm lowering function until it's at the same spot as autoheuristic, at which point they can be merged. The end goal is to create config "middleware" that users can apply independently to filter the triton configs.

For the LUT, why not just find/replace the configs in template_heuristic.py?

This will work, but our approach has several advantages:

  • JSON files are easier to manage than code patches, and they can be shared and generated externally.
  • Configs are changed from template_heuristic before they are appended, so the configs found in the LUT are guaranteed to get in.

Why is the PR so huge?

We just wanted to get feedback on the whole thing first, since it has a shared design. We'll split it up before merging.

Why so many tests for the lookup table?

We really don't want string parsing to go wrong in production, so we use a combination of property-based tests and unit tests to ensure that tables are parsed successfully, and that if they aren't, we fail gracefully and notify the user. from_dict and to_dict are tricky functions, so I think the coverage is warranted. The tests run in ~5s since it's just string parsing.
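
As an illustration of the round-trip property being exercised, a minimal hypothesis sketch; the LutEntry class below is a hypothetical stand-in, not the PR's actual table classes:

from dataclasses import dataclass, asdict
from hypothesis import given, strategies as st

# Hypothetical stand-in for a single LUT entry; the real classes in this PR differ.
@dataclass
class LutEntry:
    m: int
    n: int
    k: int
    block_m: int
    block_n: int
    block_k: int

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "LutEntry":
        return cls(**d)

dims = st.integers(min_value=1, max_value=2**14)

# Round-trip property: from_dict(to_dict(x)) must reproduce x for arbitrary entries.
@given(st.builds(LutEntry, dims, dims, dims, dims, dims, dims))
def test_lut_entry_round_trip(entry: LutEntry) -> None:
    assert LutEntry.from_dict(entry.to_dict()) == entry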

TODO

  • Graph of the performance vs. compile-time tradeoff, status quo vs. this change.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Jun 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156683

Note: Links to docs will display an error until the docs builds have been completed.

❌ 29 New Failures, 2 Unrelated Failures

As of commit 5c3c489 with merge base 2625c70:

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@exclamaforte exclamaforte added the topic: new features topic category label Jun 24, 2025
@exclamaforte exclamaforte force-pushed the exclamforte/gemm-model-final branch from fb5c0bd to 8281e63 Compare June 24, 2025 08:48
@exclamaforte exclamaforte changed the title Add lookup table and ML model to filter triton matmul configs fast_autotune: Add lookup table and ML model to filter triton matmul configs Jun 24, 2025
Comment on lines 75 to 92
Predict as log(exp(inputs) + self.kernel_overhead).

Works well for predicting log(runtime) when runtime contains a constant
overhead of `kernel_overhead`. (The log specification means that this
wouldn't be trivially modeled with a bias term.)

Probably could have fit the overhead rather than hard-coding it by
having `self.kernel_overhead` be a tunable parameter or by having exp
and log layers.
"""
# TODO: test this
log_base_pred = self.linear_relu_stack(x)
log_overhead_tsr = torch.full_like(
    input=log_base_pred, fill_value=self.log_kernel_overhead
)
return torch.logsumexp(
    torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
)
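
For reference, the forward pass above computes log(exp(pred) + overhead) via logsumexp; a small standalone check of that identity (overhead and prediction values chosen arbitrarily):

import math
import torch

kernel_overhead = 0.00541  # arbitrary example overhead, in ms
log_overhead = math.log(kernel_overhead)

log_base_pred = torch.tensor([-3.2, -1.0, 0.5])  # pretend model outputs, log(ms)
lse = torch.logsumexp(
    torch.stack([log_base_pred, torch.full_like(log_base_pred, log_overhead)], dim=-1),
    dim=-1,
)
expected = torch.log(torch.exp(log_base_pred) + kernel_overhead)
assert torch.allclose(lse, expected)
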
Member

Note that this can be simplified considerably if we don't care about the interpretability of the predictions since this is a monotonic transformation.

Suggested change
-Predict as log(exp(inputs) + self.kernel_overhead).
-
-Works well for predicting log(runtime) when runtime contains a constant
-overhead of `kernel_overhead`. (The log specification means that this
-wouldn't be trivially modeled with a bias term.)
-
-Probably could have fit the overhead rather than hard-coding it by
-having `self.kernel_overhead` be a tunable parameter or by having exp
-and log layers.
-"""
-# TODO: test this
-log_base_pred = self.linear_relu_stack(x)
-log_overhead_tsr = torch.full_like(
-    input=log_base_pred, fill_value=self.log_kernel_overhead
-)
-return torch.logsumexp(
-    torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
-)
+Produce predictions that are on a log scale and do
+not account for constant overhead.
+Exponentiating these predictions produces a runtime prediction in ms,
+not accounting for overhead.
+"""
+return self.linear_relu_stack(x)

self.model = NeuralNetwork(
    n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)]
)
self.model.load_state_dict(torch.load(MODEL_PATH))
Member

Aside: In a future update I hope to save the coefficients on the model so that load_state_dict is all that's needed, rather than hardcoding the numbers in L480-L509

Comment on lines +593 to +598
if dtype == torch.bfloat16 or dtype == torch.float16:
    dsize = 16
elif dtype == torch.float32:
    dsize = 32
else:
    raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
Member

Suggested change
-if dtype == torch.bfloat16 or dtype == torch.float16:
-    dsize = 16
-elif dtype == torch.float32:
-    dsize = 32
-else:
-    raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
+dsize = dtype.itemsize * 8

(I'm not sure how this plays into the broader logic; it might still be good to check that it's a float and is 16- or 32-bit)
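
A sketch combining the two suggestions (the itemsize-based width plus an explicit float/width check); this is illustrative only, not the code in the PR:

import torch

def dtype_bit_width(dtype: torch.dtype) -> int:
    # Derive the width from itemsize, but still restrict to the floating-point
    # dtypes the model was trained on (assumption: fp16/bf16/fp32 only).
    if not dtype.is_floating_point or dtype.itemsize not in (2, 4):
        raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
    return dtype.itemsize * 8

assert dtype_bit_width(torch.bfloat16) == 16
assert dtype_bit_width(torch.float16) == 16
assert dtype_bit_width(torch.float32) == 32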

Comment on lines +659 to +670
def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
    """
    Decode the model output tensor.

    Args:
        ret_tensor: Output tensor from the model

    Returns:
        Decoded tensor representing runtime predictions
    """
    return ret_tensor

Member

doesn't seem to do anything

@@ -443,12 +443,61 @@ def prologue_fusion_enabled() -> bool:
).upper()


# Specify the size of the benchmarking space for GEMM autotuning with the neural network model.
# SAME - There should be no functional difference between this and max_autotune_gemm_search_space
# DEFAULT - Benchmark the same number of configs as max_autotune, but search over a larger space using the model
Member

Do we know what the number is? Or does it depend on something?

Suggested change
-# DEFAULT - Benchmark the same number of configs as max_autotune, but search over a larger space using the model
+# DEFAULT - Benchmark the same number of configs as max_autotune, but use the model to tailor those configs to the inputs, selecting them from a large space


CLA Not Signed

@@ -443,12 +443,61 @@ def prologue_fusion_enabled() -> bool:
).upper()


# Specify the size of the benchmarking space for GEMM autotuning with the neural network model.
Contributor

can you put all this stuff under an autotune: class? we can migrate other things to it later but I'd like to have some logical consistency
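
For illustration, one possible shape of the grouping this comment asks for (class name per the comment; the field and environment-variable names below are placeholders, not the PR's actual settings):

import os
from typing import Optional

# Hypothetical grouping inside torch/_inductor/config.py: collect the new knobs
# under a single `autotune` namespace instead of adding top-level options.
class autotune:
    # ML-model / LUT driven filtering of the benchmarking space
    fast: bool = os.environ.get("TORCHINDUCTOR_FAST_AUTOTUNE") == "1"
    # "DEFAULT", "SAME", or an integer top-k (kept as a string when read from the env)
    benchmark_space: str = os.environ.get(
        "TORCHINDUCTOR_MATMUL_GEMM_AUTOTUNE_BENCHMARK_SPACE", "DEFAULT"
    )
    # Path to a serialized lookup table; the LUT is enabled when this is non-None
    kernel_lut_path: Optional[str] = os.environ.get("TORCHINDUCTOR_KERNEL_LUT_PATH")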

@@ -1474,7 +1474,10 @@ def get_tma_workspace_arg(

 def use_max_autotune() -> bool:
     return (
-        config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
+        config.max_autotune
Contributor

@nmacchioni is this the thing we're deprecating?

Contributor

Yep, this will be gone very shortly. It was being used improperly before and was causing a bit of a headache while deprecating search_autotune_cache. I'd suggest searching for occurrences of max_autotune_gemm and deciding what needs to be changed from there.

@@ -0,0 +1,364 @@
"""
Contributor

A couple of things to make this easier to digest:

  • _inductor/performance_prediction_models/
  • add an abstract class interface that exposes the basic functions you have in here. You can just make it say estimate() and take in m, n, k and the list of configs (see the sketch after this list)
  • make the wrapper here etc. all an implementation of that, one layer deeper
  • implement the get_model() in the init of the performance_prediction_model class
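
A hedged sketch of the kind of interface described above (the module path, class name, and method signatures are placeholders, not the PR's actual API):

from abc import ABC, abstractmethod
from typing import Any, Sequence

# Placeholder sketch of a performance-prediction interface, e.g. something that
# could live under torch/_inductor/performance_prediction_models/.
class PerformancePredictionModel(ABC):
    def __init__(self) -> None:
        # Implementations load their weights here (the "get_model() in __init__" idea).
        self.model = self._load_model()

    @abstractmethod
    def _load_model(self) -> Any:
        """Load and return the underlying predictor (e.g. an nn.Module)."""

    @abstractmethod
    def estimate(self, m: int, n: int, k: int, configs: Sequence[Any]) -> list[float]:
        """Return a predicted runtime for each candidate config on an m/n/k GEMM."""

# A concrete wrapper (e.g. around the neural network in this PR) would subclass
# this and implement _load_model() and estimate() one layer deeper.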

-max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
-    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
-).upper()  # type: ignore[assignment]
+max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = (
Contributor

I would shy away from subjective names like "fast" for now and rather do more sterile things until we have a better idea about usability, top-k, etc., but that's more of a nit, not a blocker


# Default model path - can be overridden by environment variable
DEFAULT_MODEL_PATH = "./triton_h100_from_arm_108.pkl"
MODEL_PATH = os.environ.get("TRITON_KERNEL_SELECTION_MODEL_PATH", DEFAULT_MODEL_PATH)
Contributor

idk if I would add this feature, as I think the user needs to provide more than just a pkl. This seems to me like you're saying "provide a model pkl and it'll work", and that's not true, is it? I think it's fine for local development etc. internally to ask people to override this and not expose the env var for now.

@nmacchioni (Contributor)

Cool! I see you're already planning to split this up later, so I can provide a more thorough review then. In the meantime, I'd be really interested to see some perf numbers; mainly, I'm interested in knowing how the model holds up against standard max-autotune and max-autotune with the exhaustive search space (for a variety of shapes, dtypes, etc.).

I'd also like to hear more about your thoughts on model freshness/update cadence? Essentially, how many data points did it take to initially train the model and how often do you plan on updating it? Will it be re-trained for every Triton update? Is it possible that a stale model can cause perf drops?

Thanks!

@exclamaforte (Contributor, Author) commented Jun 25, 2025

@nmacchioni

In the meantime, I'd be really interested to see some perf numbers; mainly, I'm interested in knowing how the model holds up against standard max-autotune and max-autotune with exhaustive search space (for a variety of shapes, dtypes, etc.).

Yes, very fair, evidence is on the way! I'm currently generating a compile-time vs. perf tradeoff graph of the model vs. the status quo.

I'd also like to hear more about your thoughts on model freshness/update cadence?

I'd say Triton version updates, CUDA/ROCm version updates, or new Triton configs like Paul's split-K. The model itself takes only a few hours to train, so that's not the issue; it's mainly a question of when the data gets stale.

Essentially, how many data points did it take to initially train the model and how often do you plan on updating it?

I think it's ~100M rows right now. We want to be more efficient with sampling using Bayesian methods next half. Elizabeth and the applied science team she's a part of are specialists in Bayesian methods, so there's a lot of room for improvement here. This model is more of a kitchen-sink proof of concept.

Is it possible that a stale model can cause perf drops?

Yes, this is definitely a danger; ML models are a pretty extreme form of tech debt. I think the solution is monitoring and keeping on top of updates, and also becoming more efficient with sampling so that we have enough data to train the updated version.

@exclamaforte exclamaforte changed the title fast_autotune: Add lookup table and ML model to filter triton matmul configs WIP fast_autotune: Add lookup table and ML model to filter triton matmul configs Jul 1, 2025