vllm/entrypoints/context.py (28 changes: 26 additions & 2 deletions)
@@ -3,6 +3,7 @@
 import json
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Union
 
 from openai_harmony import Author, Message, Role, StreamState, TextContent
@@ -67,15 +68,27 @@ def __init__(
 
         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
-        # TODO(woosuk): Implement the following fields.
         self.num_prompt_tokens = 0
-        self.num_cached_tokens = 0
         self.num_output_tokens = 0
+        # TODO(woosuk): Implement the following fields.
+        self.num_cached_tokens = 0
         self.num_reasoning_tokens = 0
 
+    def _update_num_prompt_tokens(self, output: RequestOutput):
+        if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
+            # NOTE: with built-in tools, there might be multiple rounds in
+            # the conversation, with the full conversation being resent
+            # as a new prompt each time. Hence the sum.
+            self.num_prompt_tokens += len(output.prompt_token_ids)
+
+    def _update_num_output_tokens(self, token_ids: Sequence[int]):
+        self.num_output_tokens += len(token_ids)
+
     def append_output(self, output) -> None:
         if isinstance(output, RequestOutput):
+            self._update_num_prompt_tokens(output)
             output_token_ids = output.outputs[0].token_ids
+            self._update_num_output_tokens(output_token_ids)
             self.parser = get_streamable_parser_for_assistant()
             for token_id in output_token_ids:
                 self.parser.process(token_id)
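Why `_update_num_prompt_tokens` sums rather than sets: with built-in tools, each round resends the whole conversation as the prompt, so per-round prompt lengths accumulate. Below is a minimal sketch of that behavior, using hypothetical stand-ins (`FakeRequestOutput`, `UsageCounter`) rather than vLLM's real classes:

```python
# A minimal sketch (not part of this PR): hypothetical stand-ins that mimic
# how the new helpers accumulate usage across multiple tool-calling rounds.
from dataclasses import dataclass


@dataclass
class FakeRequestOutput:  # stand-in for vLLM's RequestOutput
    prompt_token_ids: list[int]


class UsageCounter:  # stand-in for the context's usage counters
    def __init__(self):
        self.num_prompt_tokens = 0
        self.num_output_tokens = 0

    def _update_num_prompt_tokens(self, output: FakeRequestOutput):
        if output.prompt_token_ids:
            # The full conversation is resent as the prompt each round,
            # so per-round prompt lengths are summed.
            self.num_prompt_tokens += len(output.prompt_token_ids)

    def _update_num_output_tokens(self, token_ids):
        self.num_output_tokens += len(token_ids)


ctx = UsageCounter()
# Round 1: 10-token prompt, 4 generated tokens.
ctx._update_num_prompt_tokens(FakeRequestOutput(list(range(10))))
ctx._update_num_output_tokens([1, 2, 3, 4])
# Round 2 (after a built-in tool call): conversation regrown to 20 tokens.
ctx._update_num_prompt_tokens(FakeRequestOutput(list(range(20))))
ctx._update_num_output_tokens([5, 6])
assert ctx.num_prompt_tokens == 30  # 10 + 20, summed across rounds
assert ctx.num_output_tokens == 6
```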
@@ -158,15 +171,26 @@ def __init__(self, *args, **kwargs):
         self.parser = get_streamable_parser_for_assistant()
         self.encoding = get_encoding()
         self.last_tok = None
+        self.first_tok_of_message = True
 
     @property
     def messages(self) -> list:
         return self.parser.messages
 
     def append_output(self, output) -> None:
         if isinstance(output, RequestOutput):
+            # append_output is called for each output token in the streaming
+            # case, so we only want to add the prompt tokens once per message.
+            if self.first_tok_of_message:
+                self._update_num_prompt_tokens(output)
+            # Reset self.first_tok_of_message if needed:
+            # if the current token is the last one of the current message
+            # (finished=True), then the next token processed will mark the
+            # beginning of a new message.
+            self.first_tok_of_message = output.finished
             tok = output.outputs[0].token_ids[0]
             self.parser.process(tok)
+            self._update_num_output_tokens(output.outputs[0].token_ids)
             self.last_tok = tok
         else:
             # Handle the case of tool output in direct message format
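In the streaming path each generated token arrives as its own `RequestOutput`, so the prompt must be counted only on the first token of a message, and `output.finished` re-arms the flag for the following message. A rough sketch of that handshake, with a hypothetical `stream_usage` driver standing in for the engine loop:

```python
# Sketch (not from the PR): per-token streaming accounting, mirroring the
# first_tok_of_message / output.finished handshake with plain counters.
def stream_usage(message_streams):
    """message_streams: list of (prompt_token_ids, output_token_ids) pairs,
    one pair per assistant message. Hypothetical driver for illustration."""
    num_prompt_tokens = 0
    num_output_tokens = 0
    first_tok_of_message = True
    for prompt_ids, out_ids in message_streams:
        for i, _tok in enumerate(out_ids):
            # The full prompt accompanies every per-token output, so count
            # it only on the first token of each message.
            if first_tok_of_message:
                num_prompt_tokens += len(prompt_ids)
            finished = i == len(out_ids) - 1  # last token of this message
            # finished=True means the next token begins a new message.
            first_tok_of_message = finished
            num_output_tokens += 1
    return num_prompt_tokens, num_output_tokens


prompt_toks, output_toks = stream_usage(
    [(list(range(10)), [1, 2, 3]), (list(range(15)), [4, 5])]
)
assert prompt_toks == 25  # each message's prompt counted exactly once
assert output_toks == 5
```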