
Conversation

@simondanielsson (Contributor) commented on Oct 14, 2025

Purpose

Part of #26201.

Adds Automatic Prefix Caching (APC) for GDN (Gated DeltaNet), following the approach of APC for Mamba2 introduced in #25752.

TODOs:

  • Add better logic for making the kernel return intermediate states, rather than relying on GDN_RECOMPUTE_SUPPRESS_LEVEL=4 (see the sketch after this list).
  • Make it work with fullgraph (decode)
  • Make it compatible with specdec (possibly in another PR if this requires substantial work)
  • Extend APC test suite to also run on qwen3-next (tiny random)
  • Benchmark on 80B-A3B (I will need help from someone here)
  • Overall cleanup of PR
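
For reference, a minimal sketch of how the current workaround is toggled. GDN_RECOMPUTE_SUPPRESS_LEVEL comes from this PR's kernel changes; its exact semantics are an assumption here, so treat this as illustrative only:

import os

# Assumption from this PR: level 4 suppresses state recomputation, so the
# kernel returns the intermediate recurrent states that APC needs to cache.
os.environ["GDN_RECOMPUTE_SUPPRESS_LEVEL"] = "4"

from vllm import LLM  # import after setting the env var so it is picked up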

Test Plan

Note: this was tested only with the tiny tiny-random/qwen3-next-moe model, as I only have an L4 with 20 GB of VRAM. It would be great if someone could also try Qwen3-Next-80B-A3B (a sketch for the larger model follows the script below).

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
import time

if __name__ == "__main__":
    # Note: should be tested with Qwen/Qwen3-Next-80B-A3B-Instruct
    MODEL = "tiny-random/qwen3-next-moe"
    PROMPT_MULTIPLE = 310
    sampling_params = SamplingParams(temperature=0.0)
    prefix = (  # examples/offline_inference/prefix_caching.py
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that emphasizes "
        "community, joyful discovery, and life-long learning. The candidate is "
        "coming in for a first-round panel interview for a 8th grade Math "
        "teaching role. They have 5 years of previous teaching experience "
        "as an assistant teacher at a co-ed, public school with experience "
        "in middle school math teaching. "
    )
    prefix2 = "Based on these information, fulfill " "the following paragraph: "
    prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is"
    print("Prompt length:", len(prompt))
    # Run the same benchmark with and without automatic prefix caching
    for APC in [True, False]:
        engine = LLM(
            model=MODEL,
            enable_prefix_caching=APC,
            gpu_memory_utilization=0.3,
            disable_log_stats=False,
        )
        for i in range(3):
            if i == 0:
                print("Warm-up")
            if i == 1:
                print("Measuring")
                start_time = time.time()
            outputs = engine.generate(prompt, sampling_params)
            print("APC:", APC, i, f"Generated text: {outputs[0].outputs[0].text!r}")
            for m in engine.llm_engine.get_metrics():
                if "vllm:prefix_cache_hits" in m.name:
                    print(m.name, m.value)
        print("APC:", APC, "loop took --- %s seconds ---" % (time.time() - start_time))
        del engine
        cleanup_dist_env_and_memory()
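
For anyone with more VRAM, a hedged sketch of pointing the same loop at the full-size checkpoint (tensor_parallel_size=4 is an assumption; adjust to your hardware):

# Same benchmark, but against the full-size model. Requires a multi-GPU node.
engine = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    enable_prefix_caching=True,
    tensor_parallel_size=4,  # assumption: adjust to the available GPUs
    gpu_memory_utilization=0.9,
)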

Test Result

Note: the generated text is gibberish because the model has random weights.

No cudagraphs (enforce_eager=True):

Warm-up
APC: True 0 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
Measuring
APC: True 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 31680
APC: True 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 63360
APC: True loop took --- 0.7412824630737305 seconds ---

Warm-up
APC: False 0 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
Measuring
APC: False 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False loop took --- 0.9228880405426025 seconds ---

With cudagraphs (enforce_eager=False):

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100% 67/67 [00:02<00:00, 24.18it/s]
Capturing CUDA graphs (decode, FULL): 100% 35/35 [00:03<00:00, 8.98it/s]
INFO 10-14 13:44:50 [gpu_model_runner.py:3821] Graph capturing finished in 7 secs, took 0.34 GiB
INFO 10-14 13:44:50 [core.py:242] init engine (profile, create kv cache, warmup model) took 25.02 seconds
INFO 10-14 13:44:51 [loggers.py:191] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 10969
INFO 10-14 13:44:51 [llm.py:335] Supported tasks: ('generate',)
Warm-up
APC: True 0 Generated text: ' estado Bernie阿拉 remotelySr春晚 ứngibelENCYcancel scientificallyResidentsnah Stout__))荁'
vllm:prefix_cache_hits 0
Measuring
APC: True 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 31680
APC: True 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 63360
APC: True loop took --- 0.3312194347381592 seconds ---

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100% 67/67 [00:00<00:00, 72.41it/s]
Capturing CUDA graphs (decode, FULL): 100% 35/35 [00:01<00:00, 17.99it/s]
INFO 10-14 13:54:26 [gpu_model_runner.py:3821] Graph capturing finished in 3 secs, took 0.20 GiB
INFO 10-14 13:54:26 [core.py:242] init engine (profile, create kv cache, warmup model) took 8.07 seconds
INFO 10-14 13:54:27 [loggers.py:191] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 11615
INFO 10-14 13:54:27 [llm.py:335] Supported tasks: ('generate',)
Warm-up
APC: False 0 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
Measuring
APC: False 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False loop took --- 0.5677089691162109 seconds ---
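
Takeaway: even on the tiny random model, APC cuts the measured loop from ~0.92 s to ~0.74 s in eager mode and from ~0.57 s to ~0.33 s with CUDA graphs, and the generated text matches between the APC and non-APC runs after warm-up.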


@mergify bot added the qwen (Related to Qwen models) and v1 labels on Oct 14, 2025
mergify bot commented on Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @simondanielsson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Oct 14, 2025

@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
# @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("model", ["tiny-random/qwen3-next-moe"])
@simondanielsson (Contributor, Author) commented:
Note to self: revert before merging.

@mergify bot removed the needs-rebase label on Oct 14, 2025
  # used by e.g. Mamba2, NemotronH, Zamba
  chunk_size = getattr(self.hf_text_config, "chunk_size", None)
- return chunk_size
+ return chunk_size or 64
@simondanielsson (Contributor, Author) commented on Oct 14, 2025:
Note to self: Need to find a better way to inject the chunk size. Currently this comes from the hardcoded chunk size in chunk_gated_delta_rule_fwd
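
One possible direction, purely as a sketch (GDN_CHUNK_SIZE is hypothetical, not part of this PR): expose the kernel's chunk size as a single module-level constant and fall back to it here instead of a magic number:

# Hypothetical sketch: make the kernel's chunk size a single source of truth
# that both chunk_gated_delta_rule_fwd and the model config can import.
GDN_CHUNK_SIZE = 64  # assumption: matches the block size hardcoded in the kernel

def get_mamba_chunk_size(self):
    # used by e.g. Mamba2, NemotronH, Zamba
    chunk_size = getattr(self.hf_text_config, "chunk_size", None)
    # fall back to the kernel's constant instead of a hardcoded 64
    return chunk_size or GDN_CHUNK_SIZE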

slot_in_copy = slot_in_safe.clamp(min=0).to(
device=conv_state.device, dtype=torch.long
)
breakpoint()
@simondanielsson (Contributor, Author) commented:
Note to self: remove these before merging.
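
i.e. keep the clamp but drop the stray breakpoint():

# clamp(min=0) guards against negative slot indices (presumably padding slots)
slot_in_copy = slot_in_safe.clamp(min=0).to(
    device=conv_state.device, dtype=torch.long
)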

