
fix: providing a tensor to cache_position in model.generate kwargs always crashes because of boolean test #39261

Closed

Conversation

guicho271828

Currently, passing cache_position to model.generate is broken: any tensor value handed to it ends up in a boolean (truthiness) test, which raises a runtime error. This PR fixes it.

This is basically an oversight in the change committed in #37986 --- the code path is untested.

Traceback (most recent call last):
  File "/home/masataro/test-kv.py", line 77, in <module>
    outputs = model.generate(**inputs_trimmed, do_sample=False, max_new_tokens=256, past_key_values=past_key_values, cache_position=cache_position)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/masataro/miniforge3/envs/test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/masataro/miniforge3/envs/test/lib/python3.12/site-packages/transformers/generation/utils.py", line 2623, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/masataro/miniforge3/envs/test/lib/python3.12/site-packages/transformers/generation/utils.py", line 3568, in _sample
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/masataro/miniforge3/envs/test/lib/python3.12/site-packages/transformers/generation/utils.py", line 1799, in _get_initial_cache_position
    if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
                                            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
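
For context, the failing line evaluates the tensor itself as a bool, which PyTorch rejects for any tensor with more than one element. A minimal sketch of the kind of change needed (the exact patch in this PR may differ in detail):

# transformers/generation/utils.py, in _get_initial_cache_position (sketch, not the exact patch)

# before: `and model_kwargs["cache_position"]` evaluates the tensor as a bool,
# which raises for any tensor with more than one element
#     if "cache_position" in model_kwargs and model_kwargs["cache_position"]:

# after: test for presence only, never for truthiness
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
    ...  # keep the user-provided cache_position instead of recomputing it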

sample code (test-kv.py):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

model_id = "ibm-granite/granite-3.2-8b-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)

user_prompts = ["Hello, what's your name?",
                "Btw, yesterday I was on a rock concert.",
                "The song was 'we are the world'",
                ]

# with kv cache for all past context
past_key_values = DynamicCache()
last_output_length = 0
messages = []
for prompt in user_prompts:
    #                      outputs[0].shape[1]
    #                  /~~~~~~~~~~~~~~~~~~~~~~~~~~\ output contains this string
    # IIIIIIIIOOOOOOOOOOIIIIIIIIIOOOOOOOOOOOOOOOOOO
    #                  |        +-- input length  |
    #                  +--last output length      + next output length
    #                           \-----------------/
    #                               completion
    print(prompt)
    messages.append({"role": "user", "content": prompt})
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
    input_length = inputs["input_ids"].shape[1]
    inputs_trimmed = {
        'input_ids': inputs["input_ids"][:, last_output_length:],
        'attention_mask': inputs["attention_mask"],
    }
    cache_position = torch.arange(last_output_length, input_length, dtype=torch.int64, device=model.device)
    assert inputs_trimmed["input_ids"].shape[1] == cache_position.shape[0]
    outputs = model.generate(**inputs_trimmed,
                             do_sample=False,
                             max_new_tokens=256,
                             use_cache=True,
                             past_key_values=past_key_values,
                             cache_position=cache_position)
    completion = tokenizer.decode(outputs[0, input_length - last_output_length:], skip_special_tokens=True)
    print(completion)
    messages.append({"role": "assistant", "content": completion})
    last_output_length += outputs.shape[1]

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). --- does not apply
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case. --- could not find any similar issue
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings. --- does not apply
  • Did you write any new necessary tests? --- help needed.

Who can review?

@FremyCompany @ArthurZucker @zucchini-nlp @gante

@guicho271828 guicho271828 marked this pull request as draft July 7, 2025 21:19
@guicho271828 guicho271828 marked this pull request as ready for review July 7, 2025 21:20
@github-actions github-actions bot requested a review from gante July 7, 2025 21:20
@zucchini-nlp (Member) left a comment

@gante (Member) left a comment

Thank you for locating the bug and opening a PR with the fix!

LGTM (and +1 to @zucchini-nlp 's comment, let's add a test to prevent regressions 💛 )

@guicho271828 (Author)

Done. I wonder if this test is correct.

@gante (Member) commented Jul 8, 2025

@guicho271828 ideally, the test would be an integration test (runs once per CI run) rather than a mixin test (runs once per model per CI run 😬 ) i.e. keep it in the same file, but in GenerationIntegrationTests instead of in GenerationTesterMixin, using a dummy model checkpoint (e.g. hf-internal-testing/tiny-random-MistralForCausalLM)

if it's too much hassle for you, lmk and I'll push the changes :)
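
For reference, an integration-style test of the kind suggested here (placed in GenerationIntegrationTests and using the tiny dummy checkpoint mentioned above) could look roughly like the sketch below; the test name, prompt, and assertion are illustrative and may differ from what was eventually merged:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

def test_generate_with_tensor_cache_position():
    # illustrative sketch: a tiny random checkpoint keeps the test cheap enough
    # to run once per CI run
    model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("Hello world", return_tensors="pt")
    input_length = inputs["input_ids"].shape[1]
    cache_position = torch.arange(input_length, dtype=torch.int64)

    # before the fix, passing a tensor here raised
    # "RuntimeError: Boolean value of Tensor with more than one value is ambiguous"
    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=5,
        past_key_values=DynamicCache(),
        cache_position=cache_position,
    )
    assert outputs.shape[1] > input_length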

@guicho271828 (Author)

> if it's too much hassle for you, lmk and I'll push the changes :)

Yeah, I don't know much about the testing mechanism here; thanks a lot.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) commented Jul 9, 2025

@guicho271828 sorry, I did something wrong with your branch and it closed this PR 👀 I've opened a copy here #39300

You'll still be attributed as an author, since it has your commits 🤗
