Skip to content

Commit ce5ee0a

Browse files
authored
chore: update vllm implementation (camel-ai#2124)
1 parent 3f0ff7e commit ce5ee0a

File tree

1 file changed

+128
-57
lines changed

1 file changed

+128
-57
lines changed

camel/models/vllm_model.py

Lines changed: 128 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -119,106 +119,177 @@ def token_counter(self) -> BaseTokenCounter:
119119
self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
120120
return self._token_counter
121121

def _run(
    self,
    messages: List[OpenAIMessage],
    response_format: Optional[Type[BaseModel]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
    r"""Runs inference of OpenAI chat completion.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Optional[Type[BaseModel]]): The format of the
            response.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        Union[ChatCompletion, Stream[ChatCompletionChunk]]:
            `ChatCompletion` in the non-stream mode, or
            `Stream[ChatCompletionChunk]` in the stream mode.
    """
    # An explicitly passed response_format wins; otherwise fall back to
    # any `response_format` carried in the model config dict.
    response_format = response_format or self.model_config_dict.get(
        "response_format", None
    )
    if response_format:
        # Structured output requested: route through the parse endpoint.
        return self._request_parse(messages, response_format, tools)
    else:
        return self._request_chat_completion(messages, tools)
async def _arun(
    self,
    messages: List[OpenAIMessage],
    response_format: Optional[Type[BaseModel]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
    r"""Runs inference of OpenAI chat completion in async mode.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Optional[Type[BaseModel]]): The format of the
            response.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
            `ChatCompletion` in the non-stream mode, or
            `AsyncStream[ChatCompletionChunk]` in the stream mode.
    """
    # An explicitly passed response_format wins; otherwise fall back to
    # any `response_format` carried in the model config dict.
    response_format = response_format or self.model_config_dict.get(
        "response_format", None
    )
    if response_format:
        # Structured output requested: route through the parse endpoint.
        return await self._arequest_parse(messages, response_format, tools)
    else:
        return await self._arequest_chat_completion(messages, tools)
def _request_chat_completion(
    self,
    messages: List[OpenAIMessage],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
    r"""Sends a plain chat completion request to the vLLM backend.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        Union[ChatCompletion, Stream[ChatCompletionChunk]]:
            `ChatCompletion` in the non-stream mode, or
            `Stream[ChatCompletionChunk]` in the stream mode.
    """
    # Copy so the shared model config is never mutated by this request.
    request_config = self.model_config_dict.copy()

    if tools:
        request_config["tools"] = tools

    # Remove additionalProperties from each tool's function parameters
    if tools and "tools" in request_config:
        for tool in request_config["tools"]:
            if "function" in tool and "parameters" in tool["function"]:
                tool["function"]["parameters"].pop(
                    "additionalProperties", None
                )

    return self._client.chat.completions.create(
        messages=messages,
        model=self.model_type,
        **request_config,
    )

180-
def _run(
204+
async def _arequest_chat_completion(
181205
self,
182206
messages: List[OpenAIMessage],
183-
response_format: Optional[Type[BaseModel]] = None,
184207
tools: Optional[List[Dict[str, Any]]] = None,
185-
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
186-
r"""Runs inference of OpenAI chat completion.
208+
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
209+
request_config = self.model_config_dict.copy()
187210

188-
Args:
189-
messages (List[OpenAIMessage]): Message list with the chat history
190-
in OpenAI API format.
191-
response_format (Optional[Type[BaseModel]], optional): The format
192-
to return the response in.
193-
tools (Optional[List[Dict[str, Any]]], optional): List of tools
194-
the model may call.
211+
if tools:
212+
request_config["tools"] = tools
213+
# Remove additionalProperties from each tool's function parameters
214+
if "tools" in request_config:
215+
for tool in request_config["tools"]:
216+
if "function" in tool and "parameters" in tool["function"]:
217+
tool["function"]["parameters"].pop(
218+
"additionalProperties", None
219+
)
195220

196-
Returns:
197-
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
198-
`ChatCompletion` in the non-stream mode, or
199-
`Stream[ChatCompletionChunk]` in the stream mode.
200-
"""
221+
return await self._async_client.chat.completions.create(
222+
messages=messages,
223+
model=self.model_type,
224+
**request_config,
225+
)
def _request_parse(
    self,
    messages: List[OpenAIMessage],
    response_format: Type[BaseModel],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> ChatCompletion:
    r"""Sends a structured-output (parse) request to the vLLM backend.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Type[BaseModel]): The Pydantic model describing
            the expected response structure.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        ChatCompletion: The parsed (non-stream) completion.
    """
    # Copy so the shared model config is never mutated by this request.
    request_config = self.model_config_dict.copy()

    request_config["response_format"] = response_format
    # The parse endpoint is non-streaming; drop any configured flag.
    request_config.pop("stream", None)
    if tools is not None:
        request_config["tools"] = tools
    # Remove additionalProperties from each tool's function parameters
    if "tools" in request_config:
        for tool in request_config["tools"]:
            if "function" in tool and "parameters" in tool["function"]:
                tool["function"]["parameters"].pop(
                    "additionalProperties", None
                )

    return self._client.beta.chat.completions.parse(
        messages=messages,
        model=self.model_type,
        **request_config,
    )
async def _arequest_parse(
    self,
    messages: List[OpenAIMessage],
    response_format: Type[BaseModel],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> ChatCompletion:
    r"""Sends a structured-output (parse) request to the vLLM backend (async).

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Type[BaseModel]): The Pydantic model describing
            the expected response structure.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        ChatCompletion: The parsed (non-stream) completion.
    """
    # Copy so the shared model config is never mutated by this request.
    request_config = self.model_config_dict.copy()

    request_config["response_format"] = response_format
    # The parse endpoint is non-streaming; drop any configured flag.
    request_config.pop("stream", None)
    if tools is not None:
        request_config["tools"] = tools
    # Remove additionalProperties from each tool's function parameters
    if "tools" in request_config:
        for tool in request_config["tools"]:
            if "function" in tool and "parameters" in tool["function"]:
                tool["function"]["parameters"].pop(
                    "additionalProperties", None
                )

    return await self._async_client.beta.chat.completions.parse(
        messages=messages,
        model=self.model_type,
        **request_config,
    )
def check_model_config(self):
    r"""Check whether the model configuration contains any
    unexpected arguments to vLLM API.

    Raises:
        ValueError: If the model configuration dictionary contains any
            unexpected arguments to OpenAI API.
    """
    for param in self.model_config_dict:
        if param not in VLLM_API_PARAMS:
            raise ValueError(
                f"Unexpected argument `{param}` is "
                "input into vLLM model backend."
            )
222293

223294
@property
224295
def stream(self) -> bool:

0 commit comments

Comments
 (0)