14 changes: 4 additions & 10 deletions tests/entrypoints/openai/test_truncation.py
@@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
         "truncate_prompt_tokens": truncation_size
     }

-    with pytest.raises(openai.BadRequestError) as err:
-        await client.post(path="embeddings", cast_to=object, body={**kwargs})
-
-    assert err.value.status_code == 400
-    error_details = err.value.response.json()["error"]
+    response = await client.post(path="embeddings",
+                                 cast_to=object,
+                                 body={**kwargs})

-    assert error_details["type"] == "BadRequestError"
-    assert "This model's maximum context length is" in error_details["message"]
-    assert "tokens in the input for embedding generation" in error_details[
-        "message"]
-    assert "Please reduce the length of the input" in error_details["message"]
+    assert response["usage"]["prompt_tokens"] == truncation_size


 @pytest.mark.asyncio
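With this change the endpoint no longer rejects the request with a 400; the prompt is truncated and the reported usage reflects the truncated length. For reference, a minimal client-side sketch of exercising truncate_prompt_tokens against a vLLM OpenAI-compatible server (the base URL, model name, and the use of -1 are illustrative assumptions, not taken from this diff):

import asyncio

import openai


async def main() -> None:
    # Assumes a vLLM server on localhost:8000 serving an embedding model.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    response = await client.post(
        path="embeddings",
        cast_to=object,
        body={
            "model": "my-embedding-model",   # placeholder model name
            "input": "Hello world " * 1000,  # deliberately long prompt
            # -1 asks the server to truncate to the model's context length
            # (behavior introduced by the renderer change later in this PR).
            "truncate_prompt_tokens": -1,
        },
    )
    # usage.prompt_tokens reports how many tokens were actually embedded.
    print(response["usage"]["prompt_tokens"])


asyncio.run(main())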
17 changes: 17 additions & 0 deletions tests/entrypoints/test_renderer.py
@@ -130,6 +130,23 @@ async def test_truncation_positive(self, renderer, mock_async_tokenizer):
         assert call_args.kwargs["truncation"] is True
         assert call_args.kwargs["max_length"] == 50

+    @pytest.mark.asyncio
+    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
+        # Test that negative truncation uses model's max_model_len
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])  # Truncated to max_model_len
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=200,
+                                               truncate_prompt_tokens=-1)
+
+        assert len(results) == 1
+        call_args = mock_async_tokenizer.call_args
+        assert call_args.kwargs["truncation"] is True
+        assert call_args.kwargs["max_length"] == 100  # model's max_model_len
+
     @pytest.mark.asyncio
     async def test_token_truncation_last_elements(self, renderer):
         # Test that token truncation keeps the last N elements
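For readers unfamiliar with the mock-assertion style used here, a self-contained sketch of the same pattern (stdlib only; the token IDs and the max_length value of 100 mirror the test, while the bare AsyncMock stands in for the renderer's pooled async tokenizer):

import asyncio
from unittest.mock import AsyncMock


async def demo() -> None:
    # Stand-in for the pooled async tokenizer; it records how it was called.
    async_tokenizer = AsyncMock(return_value=[101, 7592, 2088])

    # What the renderer is expected to do for truncate_prompt_tokens=-1 when
    # the model's max_model_len is 100: request truncation at the model limit.
    await async_tokenizer("Hello world", truncation=True, max_length=100)

    call_args = async_tokenizer.call_args
    assert call_args.kwargs["truncation"] is True
    assert call_args.kwargs["max_length"] == 100


asyncio.run(demo())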
13 changes: 5 additions & 8 deletions vllm/entrypoints/openai/serving_classification.py
@@ -54,14 +54,11 @@ async def _preprocess(
             ctx.tokenizer = await self.engine_client.get_tokenizer(
                 ctx.lora_request)

-            (
-                ctx.request_prompts,
-                ctx.engine_prompts,
-            ) = await self._preprocess_completion(
-                ctx.request,
-                ctx.tokenizer,
-                ctx.request.input,
-            )
+            renderer = self._get_renderer(ctx.tokenizer)
+            ctx.engine_prompts = await renderer.render_prompt(
+                prompt_or_prompts=ctx.request.input,
+                max_length=self.max_model_len,
+                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)

             return None

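The classification path now mirrors the other serving handlers: the renderer produces engine prompts directly, so no parallel request_prompts list is kept. A hedged sketch of the flow with stand-in types (only the render_prompt keyword arguments are taken from the diff; the wrapper itself is illustrative, not vLLM code):

from typing import Any, Optional, Sequence


async def render_classification_inputs(
        renderer: Any,
        inputs: Sequence[str],
        max_model_len: int,
        truncate_prompt_tokens: Optional[int]):
    # Every prompt is capped at the model's context window; the request's
    # truncate_prompt_tokens (None, -1, or a positive bound) is forwarded
    # unchanged and resolved inside the renderer.
    return await renderer.render_prompt(
        prompt_or_prompts=list(inputs),
        max_length=max_model_len,
        truncate_prompt_tokens=truncate_prompt_tokens)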
45 changes: 19 additions & 26 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -24,7 +24,6 @@
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                      OpenAIServing,
-                                                     RequestPrompt,
                                                      ServeContext,
                                                      TextTokensPrompt)
 # yapf: enable
@@ -79,11 +78,12 @@ async def _preprocess(

             tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
                                                                )
+            renderer = self._get_renderer(tokenizer)

             if isinstance(ctx.request, EmbeddingChatRequest):
                 (
                     _,
-                    ctx.request_prompts,
+                    _,
                     ctx.engine_prompts,
                 ) = await self._preprocess_chat(
                     ctx.request,
@@ -98,13 +98,18 @@
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             else:
-                (ctx.request_prompts,
-                 ctx.engine_prompts) = await self._preprocess_completion(
-                     ctx.request,
-                     tokenizer,
-                     ctx.request.input,
-                     add_special_tokens=ctx.request.add_special_tokens,
-                 )
+                # Set max_length based on chunked processing capability
+                if self._should_use_chunked_processing(ctx.request):
+                    max_length = None
+                else:
+                    max_length = self.max_embed_len or self.max_model_len
+
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=ctx.request.input,
+                    max_length=max_length,
+                    truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
+                    add_special_tokens=ctx.request.add_special_tokens,
+                )
             return None
         except (ValueError, TypeError) as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -286,7 +291,6 @@ async def _create_single_prompt_generator(
         self,
         ctx: EmbeddingServeContext,
         engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
-        request_prompt: RequestPrompt,
         pooling_params: PoolingParams,
         trace_headers: Optional[Mapping[str, str]],
         prompt_index: int,
@@ -295,7 +299,7 @@
         request_id_item = f"{ctx.request_id}-{prompt_index}"

         self._log_inputs(request_id_item,
-                         request_prompt,
+                         engine_prompt,
                          params=pooling_params,
                          lora_request=ctx.lora_request)

@@ -353,20 +357,14 @@ async def _prepare_generators(
                 return self.create_error_response(
                     "Engine prompts not available")

-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
             max_pos_embeddings = self._get_max_position_embeddings()

             for i, engine_prompt in enumerate(ctx.engine_prompts):
-                request_prompt = ctx.request_prompts[i]
-
                 # Check if this specific prompt needs chunked processing
-                if self._is_text_tokens_prompt(request_prompt):
+                if self._is_text_tokens_prompt(engine_prompt):
                     # Cast to TextTokensPrompt since we've verified
                     # prompt_token_ids
-                    text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
+                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                     if (len(text_tokens_prompt["prompt_token_ids"])
                             > max_pos_embeddings):
                         # Use chunked processing for this prompt
@@ -382,8 +380,7 @@
                         Union[EngineTokensPrompt, EngineEmbedsPrompt],
                         engine_prompt)
                     generator = await self._create_single_prompt_generator(
-                        ctx, engine_prompt_typed, request_prompt, pooling_params,
-                        trace_headers, i)
+                        ctx, engine_prompt_typed, pooling_params, trace_headers, i)
                     generators.append(generator)

         from vllm.utils import merge_async_iterators
@@ -419,10 +416,6 @@ async def _collect_batch(
             if not use_chunked:
                 return await super()._collect_batch(ctx=ctx)

-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
             if ctx.result_generator is None:
                 return self.create_error_response(
                     "Result generator not available")
@@ -538,7 +531,7 @@ async def _collect_batch(
                             data=final_embedding)

                     # Get original prompt token IDs for this prompt
-                    original_prompt = ctx.request_prompts[prompt_idx]
+                    original_prompt = ctx.engine_prompts[prompt_idx]
                     if not self._is_text_tokens_prompt(original_prompt):
                         return self.create_error_response(
                             f"Chunked prompt {prompt_idx} is not a "
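The embedding path now chooses the renderer-side length cap up front: when chunked processing will handle over-length inputs, no cap is applied at render time; otherwise the configured embedding input limit (or the model's context window) is used. A standalone restatement of that selection, with illustrative values:

from typing import Optional


def select_render_max_length(use_chunked_processing: bool,
                             max_embed_len: Optional[int],
                             max_model_len: int) -> Optional[int]:
    # Chunked processing handles over-length prompts itself, so the renderer
    # is not asked to enforce a cap (max_length=None).
    if use_chunked_processing:
        return None
    # Otherwise prefer the configured embedding input limit, falling back to
    # the model's context window.
    return max_embed_len or max_model_len


assert select_render_max_length(True, 3072, 512) is None
assert select_render_max_length(False, 3072, 512) == 3072
assert select_render_max_length(False, None, 512) == 512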
17 changes: 7 additions & 10 deletions vllm/entrypoints/openai/serving_engine.py
@@ -368,23 +368,20 @@ async def _prepare_generators(
             for i, engine_prompt in enumerate(ctx.engine_prompts):
                 request_id_item = f"{ctx.request_id}-{i}"

-                if ctx.request_prompts is None:
-                    return self.create_error_response(
-                        "Request prompts not available")
+                # Mypy has an existing bug related to inferring the variance of
+                # TypedDicts with `builtins.enumerate`:
+                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
+                engine_prompt = cast(
+                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
+                    engine_prompt)

                 self._log_inputs(
                     request_id_item,
-                    ctx.request_prompts[i],
+                    engine_prompt,
                     params=pooling_params,
                     lora_request=ctx.lora_request,
                 )

-                # Mypy has an existing bug related to inferring the variance of
-                # TypedDicts with `builtins.enumerate`:
-                # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
-                engine_prompt = cast(
-                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
-                    engine_prompt)
                 generator = self.engine_client.encode(
                     engine_prompt,
                     pooling_params,
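The cast is hoisted above _log_inputs so that the already-narrowed engine_prompt is both logged and encoded. A self-contained illustration of the mypy workaround the comment refers to (python/mypy#8586); the TypedDicts below are simplified stand-ins for EngineTokensPrompt and EngineEmbedsPrompt:

from typing import List, TypedDict, Union, cast


class TokensPrompt(TypedDict):
    prompt_token_ids: List[int]


class EmbedsPrompt(TypedDict):
    prompt_embeds_shape: List[int]


prompts: List[Union[TokensPrompt, EmbedsPrompt]] = [
    {"prompt_token_ids": [1, 2, 3]},
]

for i, prompt in enumerate(prompts):
    # Re-cast each item: mypy's variance inference around enumerate() over a
    # list of TypedDicts otherwise loses the union type, which is why the diff
    # performs the cast before the prompt is first used (logging), not after.
    prompt = cast(Union[TokensPrompt, EmbedsPrompt], prompt)
    print(i, prompt)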
9 changes: 7 additions & 2 deletions vllm/entrypoints/renderer.py
@@ -108,10 +108,15 @@ async def render_prompt(
         for detailed parameter documentation.
         """
         if truncate_prompt_tokens is not None:
-            if max_length is not None:
-                assert 0 <= truncate_prompt_tokens <= max_length
             if truncate_prompt_tokens == 0:
                 return []
+            if truncate_prompt_tokens < 0:
+                truncate_prompt_tokens = self.model_config.max_model_len
+            if max_length is not None and truncate_prompt_tokens > max_length:
+                raise ValueError(
+                    f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
+                    f"cannot be greater than max_length ({max_length}). "
+                    f"Please select a smaller truncation size.")

         # Parse and batch the input prompts
         batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
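Taken together, the new checks give truncate_prompt_tokens three cases: 0 short-circuits to an empty result, a negative value resolves to the model's context length, and a resolved value larger than max_length is rejected with a ValueError instead of the previous assert. A standalone restatement (illustrative helper, not vLLM code; values mirror the tests above):

from typing import Optional


def resolve_truncation(truncate_prompt_tokens: Optional[int],
                       max_length: Optional[int],
                       max_model_len: int) -> Optional[int]:
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens == 0:
        return 0  # caller returns [] without tokenizing anything
    if truncate_prompt_tokens < 0:
        truncate_prompt_tokens = max_model_len
    if max_length is not None and truncate_prompt_tokens > max_length:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) cannot be "
            f"greater than max_length ({max_length}).")
    return truncate_prompt_tokens


assert resolve_truncation(-1, max_length=200, max_model_len=100) == 100
assert resolve_truncation(0, max_length=200, max_model_len=100) == 0
try:
    resolve_truncation(-1, max_length=50, max_model_len=100)
except ValueError:
    pass  # -1 resolves to 100, which exceeds max_length=50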