
Commit ebfee60

eellison authored and pytorchmergebot committed
[WIP] more aggressive persistent reduction (pytorch#161055)
Gives an 18% speedup on rms norm (2048, 32768), and we have seen other instances where Inductor is not aggressive enough about codegening persistent reductions, e.g. 39% on [this kernel from torch ao](pytorch#159769 (comment)). Codegen-ing persistent reductions can be risky if you run out of registers. Here, I'm effectively making the persistent reduction an option of the looped reduction by setting RBLOCK == rnumel, so that we can still fall back to a looped reduction as needed.

As criteria:

- there needs to be significant memory savings from doing a persistent reduction (by keeping values in registers and avoiding another iteration over the input)
- we should not be coalescing on the x dimension, otherwise a large rblock will inhibit coalescing
- we should not be especially register or arithmetic intensive (this last part uses mem_ops_per_thread, but could be improved)

Still need to do a dashboard run, although I'm not sure we get a lot of large-rblock kernels in our benchmarks.

Pull Request resolved: pytorch#161055
Approved by: https://github.com/jansel
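A minimal sketch of the gating criteria above (illustrative function and argument names, not the actual Inductor code; the real check lives in codegen_kernel in torch/_inductor/codegen/triton.py, shown in the diff below):

# Simplified sketch of the persistent-rblock gate; thresholds mirror the diff below.
def should_add_persistent_rblock(
    saved_bytes_ratio: float,   # bytes moved by the looped variant / bytes moved by the persistent variant
    r_coalesce_ratio: float,    # tiling score of r0_ relative to x (how much the rblock coalesces reads)
    reduction_numel: int,       # statically known reduction size (rnumel)
    mem_ops_per_thread: int,    # per-thread memory ops of the persistent variant
) -> bool:
    return (
        saved_bytes_ratio >= 1.3            # significant memory-bandwidth savings
        and r_coalesce_ratio >= 8.0         # reads mostly coalesced by the rblock, not the xblock
        and 2048 < reduction_numel <= 32768 # <= 2048 already gets a persistent config; above 32768 registers become a risk
        and mem_ops_per_thread <= 10        # not especially register/arithmetic intensive
    )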
1 parent 6db872f · commit ebfee60

3 files changed: +101 −13 lines changed


torch/_inductor/codegen/simd.py

Lines changed: 1 addition & 0 deletions
@@ -408,6 +408,7 @@ def __init__(
             else self.should_use_cooperative_reduction()
         )
         self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores
+        self.tiling: dict[str, sympy.Expr] = tiling
         self.persistent_reduction: bool = (
             override_persistent_reduction
             if override_persistent_reduction is not None

torch/_inductor/codegen/triton.py

Lines changed: 53 additions & 1 deletion
@@ -4022,7 +4022,12 @@ def inductor_meta_common():
         )
         return inductor_meta

-    def codegen_kernel(self, name=None):
+    def codegen_kernel(self, name=None) -> str:
+        """
+        Convert the TritonKernel from Inductor SIMD IR to triton code, including inductor triton heuristics, imports,
+        metadata, and benchmarking infra.
+        """
+
         code = IndentedBuffer()

         size_hints = {}
@@ -4163,6 +4168,53 @@ def add_constexpr_arg(arg_name):
             "num_reduction": self.num_reduction,
             **self.inductor_meta_common(),
         }
+
+        # Bail on 3d tiling, which has more complicated coalesce patterns
+        looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction
+        tiling_scores = self.tiling_scores
+        two_d_red = (
+            len(self.tiling) == 2 and tiling_scores is not None and "x" in tiling_scores
+        )
+        if looped_red and two_d_red:
+            assert tiling_scores is not None
+            memory_stats = self.features.memory_stats(self.tiling)
+            dim_stats = memory_stats.persistent.memory.dim[0]
+            mem_ops_per_thread = dim_stats.count_per_thread
+
+            # check if majority of reads are coalesced by the rblock
+            r_coalesce_ratio = tiling_scores["r0_"] / max(tiling_scores["x"], 1)
+
+            looped_mem = memory_stats.looped.memory.bytes
+            persistent_mem = memory_stats.persistent.memory.bytes
+            # check that we save significant memory by doing persistent
+            saved_bytes_ratio = V.graph.sizevars.size_hint(
+                looped_mem, fallback=config.unbacked_symint_fallback
+            ) / max(
+                V.graph.sizevars.size_hint(
+                    persistent_mem, fallback=config.unbacked_symint_fallback
+                ),
+                1,
+            )
+
+            # TODO - rnumel should be reasonably close to power of 2
+            if (
+                # significant memory bandwidth savings
+                saved_bytes_ratio >= 1.3
+                # large rblock inhibits xblock size, dont attempt if there is a decent amount of
+                # reads coalesced by xblock
+                and r_coalesce_ratio >= 8.0
+                # TODO - need more detailed register analysis
+                and V.graph.sizevars.statically_known_leq(
+                    self.features.reduction_numel, 32768
+                )
+                # We will already generate a persistent config in this case
+                and V.graph.sizevars.statically_known_gt(
+                    self.features.reduction_numel, 2048
+                )
+                and mem_ops_per_thread <= 10
+            ):
+                inductor_meta["add_persistent_rblock"] = True
+
         if self.tiling_scores:
             inductor_meta["tiling_scores"] = self.tiling_scores

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 47 additions & 12 deletions
@@ -82,6 +82,14 @@
 )


+class InductorConfig(Config):
+    """Inductor-specific Triton config with additional control flags"""
+
+    def __init__(self, *args, dynamic_scale_rblock=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dynamic_scale_rblock = dynamic_scale_rblock
+
+
 class NoTritonConfigsError(RuntimeError):
     pass

@@ -2249,6 +2257,7 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    dynamic_scale_rblock=True,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -2292,7 +2301,12 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    return InductorConfig(
+        cfg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        dynamic_scale_rblock=dynamic_scale_rblock,
+    )


 def _get_config(numels: dict[str, int]) -> dict[str, int]:
@@ -2490,11 +2504,10 @@ def _reduction_configs(

     register_intensive = False
     MAX_R0_BLOCK = 2048
-    if (
-        size_hints["x"] >= 1024
-        and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
-        >= 10
-    ):
+    loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
+        "num_reduction", 0
+    )
+    if size_hints["x"] >= 1024 and loads_and_red >= 10:
         # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers.
         # Consider load and reduction since load need move data into registers and
         # reduction needs an accumulator.
@@ -2510,7 +2523,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
         MAX_R0_BLOCK = 1024
         register_intensive = True

-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(
+        x,
+        r,
+        num_warps=None,
+        num_stages=1,
+        register_intensive=False,
+        dynamic_scale_rblock=True,
+    ):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
             assert "tiling_scores" in inductor_meta
@@ -2532,6 +2552,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
             num_warps=num_warps,
             num_stages=num_stages,
             register_intensive=register_intensive,
+            dynamic_scale_rblock=dynamic_scale_rblock,
         )

     def outer_config_opt():
@@ -2598,6 +2619,19 @@ def outer_config_opt():
     # for correctness
     if not torch.version.hip and not is_fbcode():
         outer_config = outer_config_opt()
+
+    configs = []
+
+    if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8:
+        xnumel = max(4096 // rnumel, 1)
+        c = make_config(
+            xnumel,
+            rnumel,
+            register_intensive=register_intensive,
+            dynamic_scale_rblock=False,
+        )
+        configs.append(c)
+
     # For 3d tiling, default to more autotuning initially
     if "y" in size_hints:
         pass
@@ -2606,14 +2640,15 @@
     ):
         pass  # skip all these cases
     elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
+        return configs + [contiguous_config]
     elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
+        return configs + [outer_config]
    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
+        return configs + [tiny_config]
     if disable_pointwise_autotuning(inductor_meta):
-        return [make_config(32, 128)]
-    return [
+        return configs + [make_config(32, 128)]
+
+    return configs + [
         contiguous_config,
         outer_config,
         tiny_config,
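A worked example of the extra candidate this produces (illustrative values only, mirroring the arithmetic in _reduction_configs above; the rms-norm shape is taken from the PR description):

# Illustrative only, not actual tuner output: extra candidate for rnumel = 32768.
rnumel = 32768
xnumel = max(4096 // rnumel, 1)                            # -> 1
extra_candidate = {"XBLOCK": xnumel, "R0_BLOCK": rnumel}   # RBLOCK spans the whole reduction
# This candidate is returned ahead of the usual looped-reduction configs and is built with
# dynamic_scale_rblock=False, presumably so later heuristics do not shrink its R0_BLOCK.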
