Skip to content

PrefixTuning fails with DynamicCache: 'DynamicCache' object has no attribute 'key_cache' #2821

@dnanper

Description

@dnanper

🧩 System Info

  • Platform: Kaggle Notebook
  • Transformers: 4.57.0
  • PEFT: 0.17.1
  • Model: Qwen/Qwen2.5-Coder-0.5B-Instruct

🧪 Here is my code

### Reproduction code
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, get_peft_model
import torch

MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"  # model identifier passed to both from_pretrained calls below
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)  # tokenizer for the same checkpoint
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # request float16 (half-precision) weights
    device_map="auto"  # device placement is left to the loader
)

peft_config = PrefixTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20  # prefix length; the traceback shows it being added to max_cache_len
)
model = get_peft_model(model, peft_config)  # rebinds `model` to the PEFT-wrapped model

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)  # one-word prompt, moved to the model's device
loss = model(**inputs, labels=inputs["input_ids"]).loss  # cell line 22: this forward pass raises the AttributeError in the traceback below

Here is the error

AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_37/3537529634.py in <cell line: 0>()
     20 
     21 inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
---> 22 loss = model(**inputs, labels=inputs["input_ids"]).loss

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1737             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738         else:
-> 1739             return self._call_impl(*args, **kwargs)
   1740 
   1741     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1748                 or _global_backward_pre_hooks or _global_backward_hooks
   1749                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750             return forward_call(*args, **kwargs)
   1751 
   1752         result = None

/usr/local/lib/python3.11/dist-packages/peft/peft_model.py in forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1888             else:
   1889                 max_cache_len = inputs_embeds.shape[1] + peft_config.num_virtual_tokens
-> 1890             kwargs["past_key_values"] = self.get_prompt(batch_size, max_cache_len=max_cache_len)
   1891             return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
   1892         elif peft_config.peft_type == PeftType.CPT:

/usr/local/lib/python3.11/dist-packages/peft/peft_model.py in get_prompt(self, batch_size, task_ids, max_cache_len)
    788                     layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
    789                 }
--> 790             map_cache_to_layer_device_map(self.get_base_model(), past_key_values)  # no-op if not a Cache instance
    791             return past_key_values
    792         else:

/usr/local/lib/python3.11/dist-packages/peft/utils/integrations.py in map_cache_to_layer_device_map(model, cache)
    180     for idx in range(model.config.num_hidden_layers):
    181         layer_device = layer_device_map[idx]
--> 182         cache.key_cache[idx] = cache.key_cache[idx].to(layer_device)
    183         cache.value_cache[idx] = cache.value_cache[idx].to(layer_device)
    184 

AttributeError: 'DynamicCache' object has no attribute 'key_cache'

🙏 Thank You

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions