
Paged attention does not work #41501

@yuyijiong

Description

System Info

  • transformers version: 4.57.0
  • Platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
  • Python version: 3.12.11
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 8
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA H20

Who can help?

@vasqu @ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When I set attn_implementation to "paged_attention", I get the error below, even with very simple inference code.
This error is similar to #39525.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

if __name__ == '__main__':
    model_path="/share/models/Qwen3-4B-Thinking-2507"

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, add_bos_token=False,
                                              add_eos_token=False)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="paged_attention",  # alternatives: "flash_attention_2", "sdpa"
        device_map="cuda",
    ).eval()

    prompt="How are you today?"
    chat_prompt=tokenizer.apply_chat_template([{"role":"user","content":prompt}],tokenize=False,add_generation_prompt=True)

    chat_prompt_ids=tokenizer(chat_prompt,return_tensors="pt")["input_ids"].to(model.device)
    output = model.generate(
        input_ids=chat_prompt_ids,
        max_new_tokens=500,
        num_beams=1,
        do_sample=False,
        temperature=1.0,
        use_cache=True,
        return_dict_in_generate=True,
        output_logits=True,
    )
    output_text=tokenizer.decode(output['sequences'][0][chat_prompt_ids.size(1):])
    print(output_text)

Error:

  File "/root/miniconda3/envs/yyj/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 260, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/yyj/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/yyj/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/yyj/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/yyj/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 216, in forward
    attn_output, attn_weights = attention_interface(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/yyj/lib/python3.12/site-packages/transformers/integrations/flash_paged.py", line 85, in paged_attention_forward
    cu_seq_lens_q.to(torch.int32),
    ^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'to'
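The failing line calls .to(torch.int32) on cu_seq_lens_q, so that argument reaches paged_attention_forward as None. In FlashAttention-style variable-length kernels, an argument like this conventionally holds the cumulative sum of the per-sequence query lengths with a leading 0, which lets several prompts be packed into one flat token axis. The sketch below illustrates that convention in plain Python, assuming the paged kernel follows it; build_cu_seq_lens is a hypothetical helper, not part of transformers.

```python
from itertools import accumulate


def build_cu_seq_lens(seq_lens):
    """Hypothetical helper: cumulative sequence lengths with a leading 0.

    For packed sequences of lengths [l0, l1, ...], the tokens of sequence i
    occupy the half-open range [cu[i], cu[i + 1]) of the flat token axis.
    """
    if any(n < 0 for n in seq_lens):
        raise ValueError("sequence lengths must be non-negative")
    return [0, *accumulate(seq_lens)]


# Three packed prompts of 5, 3 and 7 tokens
lens = [5, 3, 7]
cu = build_cu_seq_lens(lens)
max_seqlen = max(lens)
print(cu, max_seqlen)  # [0, 5, 8, 15] 7
```

Under this convention, a single unpadded prompt of n tokens, as in the reproduction script, would correspond to [0, n].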

Expected behavior

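Generation with attn_implementation="paged_attention" should complete and print the decoded output instead of raising AttributeError.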
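Independently of the root cause, the AttributeError could be made more actionable by validating the inputs before they are used. A minimal sketch of such a guard follows; require_paged_kwargs is hypothetical, and only the name cu_seq_lens_q is taken from the traceback. Any other arguments the real kernel needs are deliberately not assumed here.

```python
# Names the paged kernel is assumed to need; only cu_seq_lens_q comes
# from the traceback, the real list may be longer.
REQUIRED_PAGED_KWARGS = ("cu_seq_lens_q",)


def require_paged_kwargs(kwargs, required=REQUIRED_PAGED_KWARGS):
    """Hypothetical guard: fail early with a clear message instead of
    letting a None value surface later as an AttributeError."""
    missing = [name for name in required if kwargs.get(name) is None]
    if missing:
        raise ValueError(
            "paged attention requires " + ", ".join(missing)
            + ", but the caller passed None or omitted it"
        )
    return kwargs


try:
    require_paged_kwargs({"cu_seq_lens_q": None})
except ValueError as err:
    print(err)  # paged attention requires cu_seq_lens_q, but the caller passed None or omitted it
```

The guard treats an explicit None the same as a missing key, which matches the failure in the traceback.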
