Description
🐛 Describe the bug
Llama3 workloads from the Intel AI Reference Models fail with the PyTorch 2025-06-22 nightly wheel (2.8.0.dev20250622+cpu).
The failure appears to come from the latest flex attention implementation: during Inductor's freezing passes, fake tensor propagation reads func.tags on the FlexAttentionHOP higher-order op, which has no such attribute, so compilation aborts with BackendCompilerFailed.
Suspected culprit commit: ccc6279
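For context, here is a minimal sketch of the kind of setup that should exercise the same path as the reference-model script: torch.compile with the inductor backend, Inductor freezing enabled, and flex attention on CPU. The shapes, dtype, and the causal score_mod are illustrative assumptions, not the exact reference-model configuration.

```python
# Hypothetical minimal-reproducer sketch (assumed setup, not the actual
# Intel AI Reference Models script): inductor backend + freezing + flex
# attention on CPU, the combination the traceback below goes through.
import torch
import torch._inductor.config as inductor_config
from torch.nn.attention.flex_attention import flex_attention

inductor_config.freezing = True  # freezing is what triggers fake_tensor_prop


def causal(score, b, h, q_idx, kv_idx):
    # Illustrative causal score_mod: mask out future key positions.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


# Illustrative shapes: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))
compiled = torch.compile(flex_attention, backend="inductor")

with torch.no_grad():  # generate() also runs under no_grad, enabling freezing
    out = compiled(q, k, v, score_mod=causal)
```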
Error logs:
```
Traceback (most recent call last):
  File "/home/sdp/Jenkins_dir/workspace/pytorch-inductor-launch-benchmark/pytorch-models//models_v2/pytorch/llama/inference/cpu/inductor/run_llm_inductor_greedy.py", line 280, in <module>
    model.generate(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/sdp/Jenkins_dir/workspace/pytorch-inductor-launch-benchmark/pytorch-models/models_v2/pytorch/llama/inference/cpu/transformers/src/transformers/generation/utils.py", line 2237, in generate
    result, latency_list = self._sample(
  File "/home/sdp/Jenkins_dir/workspace/pytorch-inductor-launch-benchmark/pytorch-models/models_v2/pytorch/llama/inference/cpu/transformers/src/transformers/generation/utils.py", line 3228, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 723, in compile_wrapper
    raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1870, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1845, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/__init__.py", line 2398, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2059, in compile_fx
    return compile_fx(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
    return aot_autograd(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1198, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1121, in load
    compiled_fn = dispatch_and_compile()
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1183, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 244, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1859, in fw_compiler_freezing
    opt_model, preserved_arg_indices = freeze(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/freezing.py", line 95, in freeze
    return _freeze(dynamo_gm, aot_autograd_gm, example_inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/freezing.py", line 127, in _freeze
    freezing_passes(aot_autograd_gm, aot_example_inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/fx_passes/freezing_patterns.py", line 49, in freezing_passes
    fake_tensor_prop(gm, aot_example_inputs, True)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 658, in fake_tensor_prop
    FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/fx/passes/fake_tensor_prop.py", line 109, in propagate_dont_convert_inputs
    return super().run(*args)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/fx/interpreter.py", line 173, in run
    self.env[node] = self.run_node(node)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/fx/passes/fake_tensor_prop.py", line 75, in run_node
    result = super().run_node(n)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/fx/interpreter.py", line 242, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/fx/interpreter.py", line 322, in call_function
    return target(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 97, in __call__
    return super().__call__(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
    return wrapper()
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 509, in wrapper
    return self.dispatch(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 369, in dispatch
    return kernel(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 781, in flex_attention_autograd
    out, logsumexp = FlexAttentionAutogradOp.apply(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 646, in forward
    out, logsumexp = flex_attention(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 97, in __call__
    return super().__call__(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
    return wrapper()
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 504, in wrapper
    return torch.overrides.handle_torch_function(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/overrides.py", line 1725, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 142, in __torch_function__
    return func(*args, **(kwargs or {}))
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py", line 97, in __call__
    return super().__call__(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
    return wrapper()
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 509, in wrapper
    return self.dispatch(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_ops.py", line 405, in dispatch
    result = handler(mode, *args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_higher_order_ops/utils.py", line 550, in impl
    return mode.__torch_dispatch__(hop, [], args, kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1457, in _cached_dispatch_impl
    return self._dispatch_impl(func, types, args, kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2352, in _dispatch_impl
    (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2803, in validate_and_convert_non_fake_tensors
    validated_args = [validate(a) for a in flat_args]
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2803, in <listcomp>
    validated_args = [validate(a) for a in flat_args]
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2777, in validate
    if torch.Tag.inplace_view in func.tags:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'FlexAttentionHOP' object has no attribute 'tags'
```
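The last frame pinpoints the problem: the fake tensor validate() path assumes every dispatched func is an OpOverload carrying a .tags sequence, but FlexAttentionHOP is a HigherOrderOperator, which has no tags attribute at all. Below is a minimal sketch of the attribute check involved, written with a getattr fallback; this is illustrative only, not the actual upstream fix.

```python
import torch


def has_inplace_view_tag(func) -> bool:
    # OpOverload objects expose a .tags sequence, but HigherOrderOperator
    # instances such as FlexAttentionHOP do not; treating a missing
    # attribute as "no tags" avoids the AttributeError seen above.
    return torch.Tag.inplace_view in getattr(func, "tags", ())
```

Whether the right fix is such a guard or keeping higher-order ops out of the fake tensor cache's validation path is a maintainer call; the sketch only pins down where the assumption breaks.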
Versions
```
Collecting environment information...
PyTorch version: 2.8.0.dev20250622+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.0.3
Libc version: glibc-2.35
Python version: 3.10.18 | packaged by conda-forge | (main, Jun 4 2025, 14:45:41) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-140-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 384
On-line CPU(s) list: 0-383
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) 6972P
CPU family: 6
Model: 173
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
CPU max MHz: 3900.0000
CPU min MHz: 800.0000
BogoMIPS: 4800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 9 MiB (192 instances)
L1i cache: 12 MiB (192 instances)
L2 cache: 384 MiB (192 instances)
L3 cache: 960 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95,192-287
NUMA node1 CPU(s): 96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS Not affected; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.8.0.dev20250622+cpu
[pip3] torch-fidelity==0.3.0
[pip3] torchmetrics==1.7.3
[pip3] torchvision==0.23.0a0+d164893
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.8.0.dev20250622+cpu pypi_0 pypi
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchmetrics 1.7.3 pypi_0 pypi
[conda] torchvision 0.23.0a0+d164893 pypi_0 pypi
```
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng @chuanqi129 @Valentine233