Skip to content

Commit b994f6e

Browse files
kundaMwizapytorchmergebot
authored andcommitted
[inductor] check block options after broadcasting and singleton dims have been removed (pytorch#161602)
This will allow for some more cases to use tensor descriptors e.g. before the following block params would not match because the innermost dimension does not have stride 1 ```python block_params=BlockParameters(shape=[64, 4, 1, 1], block_shape=[((XBLOCK + 3)//4), Min(4, XBLOCK), 1, 1], strides=[0, 1, 0, 0], offsets=[(xoffset//4), ModularIndexing(xoffset, 1, 4), 0, 0]) ``` After broadcasting dimensions and singleton dimensions are removed: ```python block_params=BlockParameters(shape=[4], block_shape=[Min(4, XBLOCK)], strides=[1], offsets=[ModularIndexing(xoffset, 1, 4)]) ``` Pull Request resolved: pytorch#161602 Approved by: https://github.com/jansel
1 parent f44ad54 commit b994f6e

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def xfail_if_use_tensor_descriptor(fn):
7878
"test_2d_reduction_odd_shapes_view_size1_num_block_pointers_3_num_triton_kernels_2_reduction_op1",
7979
"test_broadcast_prefer_nd_tiling_False_x_size0_y_size0",
8080
"test_broadcast_prefer_nd_tiling_False_x_size2_y_size2",
81-
"test_broadcast_prefer_nd_tiling_False_x_size3_y_size3",
8281
"test_broadcast_prefer_nd_tiling_True_x_size0_y_size0",
8382
"test_broadcast_prefer_nd_tiling_True_x_size2_y_size2",
8483
"test_broadcast_with_singleton_dims",

torch/_inductor/codegen/triton.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2298,27 +2298,31 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]:
22982298

22992299
# Form the block pointer or TMA descriptor.
23002300
self.filter_masks(mask_vars)
2301-
options_class: type[BlockDescriptorOptions]
2302-
if config.triton.use_block_ptr:
2303-
options_class = BlockPtrOptions
2304-
else:
2301+
2302+
options_class = (
2303+
BlockPtrOptions
2304+
if config.triton.use_block_ptr
2305+
else TensorDescriptorOptions
2306+
)
2307+
options = options_class.create(
2308+
params=block_params,
2309+
constant_offset=offset,
2310+
range_trees=range_trees,
2311+
mask_vars=mask_vars,
2312+
get_max_block=self.max_block,
2313+
)
2314+
2315+
if options_class == TensorDescriptorOptions:
23052316
nonlocal tma_compatibility_checker
23062317
tma_compatibility_checker = cast(
23072318
TMACompatibilityChecker, tma_compatibility_checker
23082319
)
23092320
if not tma_compatibility_checker.are_block_parameters_compatible(
2310-
block_params
2321+
options.params
23112322
):
23122323
return None
2313-
options_class = TensorDescriptorOptions
23142324

2315-
return options_class.create(
2316-
params=block_params,
2317-
constant_offset=offset,
2318-
range_trees=range_trees,
2319-
mask_vars=mask_vars,
2320-
get_max_block=self.max_block,
2321-
)
2325+
return options
23222326

23232327
# Return a block pointer, if indexing matches the pattern.
23242328
options = match_block_expr()

0 commit comments

Comments
 (0)