Skip to content
3 changes: 3 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,6 @@ class SchedulerOutput:

# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None

# Number of steps to schedule
step_num: int = 1
11 changes: 7 additions & 4 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
step_num: int = 1,
) -> None:
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
Expand Down Expand Up @@ -159,7 +160,8 @@
self.num_spec_tokens = speculative_config.num_speculative_tokens
if speculative_config.use_eagle():
self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens
self.num_lookahead_tokens = self.num_spec_tokens + \
(self.step_num - 1) * (1 + self.num_spec_tokens)

Check failure on line 164 in vllm/v1/core/sched/scheduler.py

View workflow job for this annotation

GitHub Actions / pre-commit

"Scheduler" has no attribute "step_num" [attr-defined]

# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
Expand Down Expand Up @@ -913,6 +915,7 @@
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
num_steps = model_runner_output.step_num

outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
Expand Down Expand Up @@ -959,8 +962,8 @@
scheduler_output.scheduled_spec_decode_tokens.get(req_id)
)
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_draft_tokens = len(scheduled_spec_token_ids) * num_steps
num_accepted = len(generated_token_ids) - num_steps
num_rejected = num_draft_tokens - num_accepted
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
Expand Down Expand Up @@ -1265,7 +1268,7 @@
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
spec_decoding_stats = SpecDecodingStats.new(num_draft_tokens)
spec_decoding_stats.observe_draft(
num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens
)
Expand Down
16 changes: 16 additions & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,20 @@
* vllm_config.parallel_config.decode_context_parallel_size
)

self.step_num = int(additional_config.get("multi_step", 1))

Check failure on line 153 in vllm/v1/engine/core.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F821)

vllm/v1/engine/core.py:153:29: F821 Undefined name `additional_config`
if self.step_num > 1:
logger.info(
("multi step is enabled. step num is %d"),
self.step_num,
)
self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
log_stats=self.log_stats,
block_size=scheduler_block_size,
step_num=self.step_num,
)
self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore
Expand Down Expand Up @@ -198,6 +205,7 @@
self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue
)
additional_config = vllm_config.additional_config

Check failure on line 208 in vllm/v1/engine/core.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F841)

vllm/v1/engine/core.py:208:9: F841 Local variable `additional_config` is assigned to but never used

def _initialize_kv_caches(
self, vllm_config: VllmConfig
Expand Down Expand Up @@ -321,10 +329,18 @@
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
scheduler_output.step_num = self.step_num
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output,
)
for (
req_id,
num_scheduled_token,
) in scheduler_output.num_scheduled_tokens.items():
self.scheduler.requests[req_id].num_computed_tokens += (
num_scheduled_token * (model_output.step_num - 1)
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None

# actual step num in model_runner
step_num: int = 1


# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):
Expand Down
Loading