@@ -96,70 +96,145 @@ def __init__(
9696 base_url = self ._url ,
9797 )
9898
99- async def _arun (
@property
def token_counter(self) -> BaseTokenCounter:
    r"""Return the token counter of this backend, creating it on first use.

    When no counter has been stored on the instance yet, an
    `OpenAITokenCounter` built for `ModelType.GPT_4O_MINI` is created
    and kept for later accesses.

    Returns:
        BaseTokenCounter: The stored token counter.
    """
    counter = self._token_counter
    if counter:
        return counter

    # Nothing stored yet: build the counter once and keep it on the
    # instance so subsequent accesses reuse it.
    counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
    self._token_counter = counter
    return counter
110+
def _run(
    self,
    messages: List[OpenAIMessage],
    response_format: Optional[Type[BaseModel]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
    r"""Run a synchronous chat completion request.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Optional[Type[BaseModel]]): The format of the
            response. When not given, the `"response_format"` entry of
            the model config (if any) is used instead.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        Union[ChatCompletion, Stream[ChatCompletionChunk]]:
            `ChatCompletion` in the non-stream mode, or
            `Stream[ChatCompletionChunk]` in the stream mode.
    """
    # An explicit argument wins; otherwise fall back to the model config.
    effective_format = response_format or self.model_config_dict.get(
        "response_format", None
    )

    # No response format at all: plain chat completion.
    if not effective_format:
        return self._request_chat_completion(messages, tools)

    return self._request_parse(messages, effective_format, tools)
124139
125- def _run (
async def _arun(
    self,
    messages: List[OpenAIMessage],
    response_format: Optional[Type[BaseModel]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
    r"""Run a chat completion request in async mode.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Optional[Type[BaseModel]]): The format of the
            response. When not given, the `"response_format"` entry of
            the model config (if any) is used instead.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request.

    Returns:
        Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
            `ChatCompletion` in the non-stream mode, or
            `AsyncStream[ChatCompletionChunk]` in the stream mode.
    """
    # An explicit argument wins; otherwise fall back to the model config.
    effective_format = response_format or self.model_config_dict.get(
        "response_format", None
    )

    # No response format at all: plain chat completion.
    if not effective_format:
        return await self._arequest_chat_completion(messages, tools)

    return await self._arequest_parse(messages, effective_format, tools)
168+
def _request_chat_completion(
    self,
    messages: List[OpenAIMessage],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
    r"""Send a plain chat completion request through the sync client.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request. Ignored when empty or `None`.

    Returns:
        Union[ChatCompletion, Stream[ChatCompletionChunk]]: Whatever
            `chat.completions.create` returns for the request.
    """
    # Work on a copy so the stored model config is never modified.
    call_kwargs = self.model_config_dict.copy()
    if tools:
        call_kwargs.update(tools=tools)

    completions = self._client.chat.completions
    return completions.create(
        messages=messages,
        model=self.model_type,
        **call_kwargs,
    )
149- return response
150184
151- @property
152- def token_counter (self ) -> BaseTokenCounter :
153- r"""Initialize the token counter for the model backend.
async def _arequest_chat_completion(
    self,
    messages: List[OpenAIMessage],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
    r"""Send a plain chat completion request through the async client.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request. Ignored when empty or `None`.

    Returns:
        Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]: Whatever
            the awaited `chat.completions.create` call returns.
    """
    # Work on a copy so the stored model config is never modified.
    call_kwargs = self.model_config_dict.copy()
    if tools:
        call_kwargs.update(tools=tools)

    completions = self._async_client.chat.completions
    response = await completions.create(
        messages=messages,
        model=self.model_type,
        **call_kwargs,
    )
    return response
200+
def _request_parse(
    self,
    messages: List[OpenAIMessage],
    response_format: Type[BaseModel],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> ChatCompletion:
    r"""Send a structured-output request through the sync client.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Type[BaseModel]): The model class describing the
            expected response; it overrides any `"response_format"`
            entry of the model config.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request. Forwarded whenever it is not `None`.

    Returns:
        ChatCompletion: Whatever `beta.chat.completions.parse` returns
            for the request.
    """
    # Work on a copy so the stored model config is never modified.
    request_config = self.model_config_dict.copy()

    request_config["response_format"] = response_format

    # `parse` has no `stream` parameter, so a `stream` entry inherited
    # from the model config would make the call fail with a TypeError.
    request_config.pop("stream", None)

    if tools is not None:
        request_config["tools"] = tools

    return self._client.beta.chat.completions.parse(
        messages=messages,
        model=self.model_type,
        **request_config,
    )
219+
async def _arequest_parse(
    self,
    messages: List[OpenAIMessage],
    response_format: Type[BaseModel],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> ChatCompletion:
    r"""Send a structured-output request through the async client.

    Args:
        messages (List[OpenAIMessage]): Message list with the chat history
            in OpenAI API format.
        response_format (Type[BaseModel]): The model class describing the
            expected response; it overrides any `"response_format"`
            entry of the model config.
        tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
            use for the request. Forwarded whenever it is not `None`.

    Returns:
        ChatCompletion: Whatever the awaited
            `beta.chat.completions.parse` call returns.
    """
    # Work on a copy so the stored model config is never modified.
    request_config = self.model_config_dict.copy()

    request_config["response_format"] = response_format

    # `parse` has no `stream` parameter, so a `stream` entry inherited
    # from the model config would make the call fail with a TypeError.
    request_config.pop("stream", None)

    if tools is not None:
        request_config["tools"] = tools

    return await self._async_client.beta.chat.completions.parse(
        messages=messages,
        model=self.model_type,
        **request_config,
    )
163238
164239 def check_model_config (self ):
165240 r"""Check whether the model configuration contains any
0 commit comments