WIP fast_autotune: Add lookup table and ML model to filter triton matmul configs #156683
Conversation
```python
        Predict as log(exp(inputs) + self.kernel_overhead).

        Works well for predicting log(runtime) when runtime contains a constant
        overhead of `kernel_overhead`. (The log specification means that this
        wouldn't be trivially modeled with a bias term.)

        Probably could have fit the overhead rather than hard-coding it by
        having `self.kernel_overhead` be a tunable parameter or by having exp
        and log layers.
        """
        # TODO: test this
        log_base_pred = self.linear_relu_stack(x)
        log_overhead_tsr = torch.full_like(
            input=log_base_pred, fill_value=self.log_kernel_overhead
        )
        return torch.logsumexp(
            torch.stack([log_base_pred, log_overhead_tsr], dim=-1), dim=-1
        )
```
Note that this can be simplified considerably if we don't care about the interpretability of the predictions since this is a monotonic transformation.
Suggested change (replace the docstring and `logsumexp` body above with):

```python
        Produce predictions that are on a log scale and do
        not account for constant overhead.

        Exponentiating these predictions produces a runtime prediction in ms,
        not accounting for overhead.
        """
        return self.linear_relu_stack(x)
```
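For reference, a minimal sketch (dummy inputs, hypothetical overhead constant) of why the two forms agree and why dropping the overhead term is safe for ranking configs: `log(exp(x) + c)` equals `logsumexp([x, log(c)])` and is a monotonic transformation of `x`.

```python
import math

import torch

kernel_overhead = 0.00541  # hypothetical constant overhead in ms
log_base_pred = torch.tensor([-3.2, -1.0, 0.5])  # dummy model outputs (log-ms)

log_overhead = torch.full_like(log_base_pred, math.log(kernel_overhead))
with_overhead = torch.logsumexp(
    torch.stack([log_base_pred, log_overhead], dim=-1), dim=-1
)

# Same value as the direct formula log(exp(pred) + overhead)...
assert torch.allclose(
    with_overhead, torch.log(torch.exp(log_base_pred) + kernel_overhead)
)
# ...and the same ordering as the raw predictions, so top-k config selection
# is unaffected by dropping the overhead term.
assert torch.equal(torch.argsort(with_overhead), torch.argsort(log_base_pred))
```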
```python
        self.model = NeuralNetwork(
            n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)]
        )
        self.model.load_state_dict(torch.load(MODEL_PATH))
```
Aside: In a future update I hope to save the coefficients on the model so that `load_state_dict` is all that's needed, rather than hardcoding the numbers in L480-L509.
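A rough sketch of what that could look like (the file name is a placeholder, and it assumes the PR's `NeuralNetwork` class):

```python
import torch

# Hypothetical one-time export, run wherever the trained weights live:
#   torch.save(trained_model.state_dict(), "triton_mm_predictor_state.pt")

# Loading then needs no hardcoded coefficients, only the architecture:
model = NeuralNetwork(n_inputs=12, hidden_layer_widths=[2**8 for _ in range(6)])
model.load_state_dict(torch.load("triton_mm_predictor_state.pt", map_location="cpu"))
model.eval()
```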
```python
    if dtype == torch.bfloat16 or dtype == torch.float16:
        dsize = 16
    elif dtype == torch.float32:
        dsize = 32
    else:
        raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
```
Suggested change (replace the branch above with):

```python
    dsize = dtype.itemsize * 8
```
(I'm not sure how this plays into the broader logic; it might still be good to check that it's a float and is 16- or 32-bit)
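Combining the two points, one possible version (just a sketch; the helper name is made up) derives the bit width from the dtype while keeping a guard for 16-/32-bit floats:

```python
import torch


def _dtype_bit_width(dtype: torch.dtype) -> int:
    # Derive the width from the dtype instead of enumerating dtypes,
    # but keep the guard that only 16- or 32-bit floats are expected.
    dsize = dtype.itemsize * 8
    if not dtype.is_floating_point or dsize not in (16, 32):
        raise ValueError(f"Unsupported dtype: {dtype}. Add support for this dtype.")
    return dsize
```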
```python
    def decode(self, ret_tensor: torch.Tensor) -> torch.Tensor:
        """
        Decode the model output tensor.

        Args:
            ret_tensor: Output tensor from the model

        Returns:
            Decoded tensor representing runtime predictions
        """
        return ret_tensor
```
doesn't seem to do anything
```
@@ -443,12 +443,61 @@ def prologue_fusion_enabled() -> bool:
).upper()


# Specify the size of the benchmarking space for GEMM autotuning with the neural network model.
# SAME - There should be no functional difference between this and max_autotune_gemm_search_space
# DEFAULT - Benchmark the same number of configs as max_autotune, but search over a larger space using the model
```
Do we know what the number is? Or does it depend on something?
Suggested change:

```diff
-# DEFAULT - Benchmark the same number of configs as max_autotune, but search over a larger space using the model
+# DEFAULT - Benchmark the same number of configs as max_autotune, but use the model to tailor those configs to the inputs, selecting them from a large space
```
```
@@ -443,12 +443,61 @@ def prologue_fusion_enabled() -> bool:
).upper()


# Specify the size of the benchmarking space for GEMM autotuning with the neural network model.
```
can you put all this stuff under an autotune: class? we can migrate other things to it later but I'd like to have some logical consistency
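For context, `torch/_inductor/config.py` already groups some options under nested classes (e.g. `class trace:`), so the grouping being asked for might look roughly like this; the option and env-var names below are illustrative, taken from this PR's description, not settled API:

```python
import os
from typing import Optional


class autotune:
    # Size of the benchmarking space for GEMM autotuning with the NN model:
    # "SAME", "DEFAULT", or an integer top-k (see the PR description).
    matmul_gemm_benchmark_space: str = os.environ.get(
        "TORCHINDUCTOR_MATMUL_GEMM_AUTOTUNE_BENCHMARK_SPACE", "DEFAULT"
    ).upper()

    # The kernel LUT is enabled whenever this is non-None.
    kernel_lut_path: Optional[str] = os.environ.get("TORCHINDUCTOR_KERNEL_LUT_PATH")
```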
```diff
@@ -1474,7 +1474,10 @@ def get_tma_workspace_arg(


 def use_max_autotune() -> bool:
     return (
-        config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
+        config.max_autotune
```
@nmacchioni is this the thing we're deprecating?
Yep, this will be gone very shortly. It was being used improperly before and causing a bit of a headache deprecating `search_autotune_cache`. I'd suggest searching for occurrences of `max_autotune_gemm` and deciding what needs to be changed from there.
```
@@ -0,0 +1,364 @@
"""
```
A couple of things to make this easier to digest, under `_inductor/performance_prediction_models/`:
- add an abstract class interface that exposes the basic functions you have in here. You can just make it say `estimate()` and take in the m, n, k and the list of configs (a rough sketch of such an interface follows)
- make the wrapper here etc. an implementation of that, one layer deeper
- implement the `get_model()` in the init of the performance_prediction_model class
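A minimal sketch of the kind of interface being suggested (names hypothetical, not the PR's actual classes):

```python
from abc import ABC, abstractmethod
from typing import Any


class PerformancePredictionModel(ABC):
    """Scores candidate triton mm configs for a given problem shape."""

    @abstractmethod
    def estimate(self, m: int, n: int, k: int, configs: list[Any]) -> list[float]:
        """Return one predicted runtime (or score) per entry of `configs`."""


# The NN wrapper in this file would then subclass this one layer deeper,
# building the network and loading its weights (get_model) in __init__.
```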
```diff
-max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
-    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
-).upper()  # type: ignore[assignment]
+max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = (
```
I would shy away from subjective names like fast for now and rather do more sterile things until we have a better idea about usability, topk, etc., but that's more of a nit, not a blocker.
```python
# Default model path - can be overridden by environment variable
DEFAULT_MODEL_PATH = "./triton_h100_from_arm_108.pkl"
MODEL_PATH = os.environ.get("TRITON_KERNEL_SELECTION_MODEL_PATH", DEFAULT_MODEL_PATH)
```
idk if I would add this feature, as I think the user needs to provide more than just a pkl. This seems to me like you're saying "provide a model pkl and it'll work", and that's not true, is it? I think it's fine for local development etc. internally to ask people to override this, and not expose the env var for now.
Cool! I see you're already planning to split this up later, so I can provide a more thorough review then. In the meantime, I'd be really interested to see some perf numbers; mainly, I'm interested in knowing how the model holds up against standard max-autotune and max-autotune with the exhaustive search space (for a variety of shapes, dtypes, etc.). I'd also like to hear more about your thoughts on model freshness/update cadence. Essentially, how many data points did it take to initially train the model, and how often do you plan on updating it? Will it be re-trained for every Triton update? Is it possible that a stale model can cause perf drops? Thanks!
Yes, very fair, evidence on the way! Currently generating a compile time vs. perf tradeoff graph of the model vs. the status quo.
I'd say Triton version updates, CUDA/ROCm version updates, or new triton configs like Paul's split-K. The model itself takes a few hours to train, so that's not the issue. It's mainly when the data gets stale.
I think it's ~100m rows right now. We want to be more efficient with sampling using Bayesian methods next half. Elizabeth and the applied science team she's a part of are specialists in Bayesian methods, so there's a lot of room for improvement here. This model is more of a kitchen sink as a proof of concept.
Yes, this is definitely a danger; ML models are a pretty extreme form of tech debt. I think the solution is monitoring and keeping on top of updates, and also becoming more efficient with sampling so that we have enough data to train the updated version.
The goal of this PR is to deliver `max-autotune` performance at a much lower compile time cost. Currently, `max-autotune` has low adoption in training workloads because of the compile time cost, and could have more adoption in inference workloads with lower compile time. We want to deliver better performance per unit of compile time by benchmarking better configs, selected with the input shapes taken into account. We provide two ways of doing this: manually filtering the configs with a LUT, and automatically filtering the configs with an ML model.

Part 1: `fast-autotune`: Model Prediction of Triton Kernel Runtimes #156851

Features:
- `fast_autotune`: Uses the ML model and lookup table to evaluate and filter the config space before benchmarking.
- `kernel_lut`: Uses a lookup table with a standardized format to filter configs. Users can create these tables by creating the python class and saving the `.serialize()` string. These tables are serialized to json, and can be edited by hand or by external tools in the future.
- `matmul_gemm_autotune_benchmark_space`: Disaggregates the space of what we benchmark from what we search. `"DEFAULT"` will match the number of configs in the max-autotune defaults, `"SAME"` means that the search space is the same as the benchmarking space, so model filtering will be disabled, and `<int>` will set the top-k to `<int>`.
- `fast_autotune`: New setting that sets `matmul_gemm_autotune_search_space` to `"EXHAUSTIVE"` and `matmul_gemm_autotune_benchmark_space` to `1`.
- `kernel_lut_path`: Lookup path for the kernel LUT. The kernel LUT will be on when this is non-None, regardless of the other settings, so these features can be applied independently (see the usage sketch after this list).
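As a rough usage sketch of the settings above (option names come from this description; the exact config locations and final API are assumptions, not settled):

```python
# Hypothetical usage; names are taken from this PR's description and may
# change before this lands.
import torch._inductor.config as inductor_config

# One knob: search the EXHAUSTIVE config space with the model and benchmark
# only the predicted-best config.
inductor_config.fast_autotune = True

# Independently of the model, the kernel LUT turns on whenever a path is set.
inductor_config.kernel_lut_path = "/path/to/matmul_lut.json"
```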
We plan on merging with Autoheuristic eventually, and we're using it for data collection, but it's currently in a broken state for inference because its tests were disabled for ~6 months. We fixed some of the tests, but we also found there was a bit of a mismatch between it and what we're trying to do: Autoheuristic predicts over all configs, we are currently focused on just Triton configs, and the shared infrastructure for the LUT requires it being at the start of the mm lowering function. Once our model and LUT handle all configs, we will move the callsite lower in the mm lowering function until it's at the same spot as Autoheuristic, at which point they can be merged. The end goal is to create config "middleware" that users can apply independently to filter the triton configs.
This will work, but our approach has several advantages:
Just wanted to get feedback on the whole thing first since it has a shared design. We'll split it up before merge.
We really don't want string parsing to go wrong in production, so we use a combination of property-based tests and unit tests to ensure that tables are parsed successfully; if they're not, we fail gracefully and notify the user. `from_dict` and `to_dict` are tricky functions, so I think it's warranted. The tests run in ~5s since it's just string parsing.

TODO
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov