@@ -15,7 +15,7 @@ pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
15
15
16
16
/// [crate::Client] relies on this for every API call on OpenAI
17
17
/// or Azure OpenAI service
18
- pub trait Config : Clone {
18
+ pub trait Config : Send + Sync {
19
19
fn headers ( & self ) -> HeaderMap ;
20
20
fn url ( & self , path : & str ) -> String ;
21
21
fn query ( & self ) -> Vec < ( & str , & str ) > ;
@@ -25,6 +25,32 @@ pub trait Config: Clone {
25
25
fn api_key ( & self ) -> & SecretString ;
26
26
}
27
27
28
+ /// Macro to implement Config trait for pointer types with dyn objects
29
+ macro_rules! impl_config_for_ptr {
30
+ ( $t: ty) => {
31
+ impl Config for $t {
32
+ fn headers( & self ) -> HeaderMap {
33
+ self . as_ref( ) . headers( )
34
+ }
35
+ fn url( & self , path: & str ) -> String {
36
+ self . as_ref( ) . url( path)
37
+ }
38
+ fn query( & self ) -> Vec <( & str , & str ) > {
39
+ self . as_ref( ) . query( )
40
+ }
41
+ fn api_base( & self ) -> & str {
42
+ self . as_ref( ) . api_base( )
43
+ }
44
+ fn api_key( & self ) -> & SecretString {
45
+ self . as_ref( ) . api_key( )
46
+ }
47
+ }
48
+ } ;
49
+ }
50
+
51
+ impl_config_for_ptr ! ( Box <dyn Config >) ;
52
+ impl_config_for_ptr ! ( std:: sync:: Arc <dyn Config >) ;
53
+
28
54
/// Configuration for OpenAI API
29
55
#[ derive( Clone , Debug , Deserialize ) ]
30
56
#[ serde( default ) ]
@@ -211,3 +237,55 @@ impl Config for AzureConfig {
211
237
vec ! [ ( "api-version" , & self . api_version) ]
212
238
}
213
239
}
240
+
241
+ #[ cfg( test) ]
242
+ mod test {
243
+ use super :: * ;
244
+ use crate :: types:: {
245
+ ChatCompletionRequestMessage , ChatCompletionRequestUserMessage , CreateChatCompletionRequest ,
246
+ } ;
247
+ use crate :: Client ;
248
+ use std:: sync:: Arc ;
249
+ #[ test]
250
+ fn test_client_creation ( ) {
251
+ unsafe { std:: env:: set_var ( "OPENAI_API_KEY" , "test" ) }
252
+ let openai_config = OpenAIConfig :: default ( ) ;
253
+ let config = Box :: new ( openai_config. clone ( ) ) as Box < dyn Config > ;
254
+ let client = Client :: with_config ( config) ;
255
+ assert ! ( client. config( ) . url( "" ) . ends_with( "/v1" ) ) ;
256
+
257
+ let config = Arc :: new ( openai_config) as Arc < dyn Config > ;
258
+ let client = Client :: with_config ( config) ;
259
+ assert ! ( client. config( ) . url( "" ) . ends_with( "/v1" ) ) ;
260
+ let cloned_client = client. clone ( ) ;
261
+ assert ! ( cloned_client. config( ) . url( "" ) . ends_with( "/v1" ) ) ;
262
+ }
263
+
264
+ async fn dynamic_dispatch_compiles ( client : & Client < Box < dyn Config > > ) {
265
+ let _ = client. chat ( ) . create ( CreateChatCompletionRequest {
266
+ model : "gpt-4o" . to_string ( ) ,
267
+ messages : vec ! [ ChatCompletionRequestMessage :: User (
268
+ ChatCompletionRequestUserMessage {
269
+ content: "Hello, world!" . into( ) ,
270
+ ..Default :: default ( )
271
+ } ,
272
+ ) ] ,
273
+ ..Default :: default ( )
274
+ } ) ;
275
+ }
276
+
277
+ #[ tokio:: test]
278
+ async fn test_dynamic_dispatch ( ) {
279
+ let openai_config = OpenAIConfig :: default ( ) ;
280
+ let azure_config = AzureConfig :: default ( ) ;
281
+
282
+ let azure_client = Client :: with_config ( Box :: new ( azure_config. clone ( ) ) as Box < dyn Config > ) ;
283
+ let oai_client = Client :: with_config ( Box :: new ( openai_config. clone ( ) ) as Box < dyn Config > ) ;
284
+
285
+ let _ = dynamic_dispatch_compiles ( & azure_client) . await ;
286
+ let _ = dynamic_dispatch_compiles ( & oai_client) . await ;
287
+
288
+ let _ = tokio:: spawn ( async move { dynamic_dispatch_compiles ( & azure_client) . await } ) ;
289
+ let _ = tokio:: spawn ( async move { dynamic_dispatch_compiles ( & oai_client) . await } ) ;
290
+ }
291
+ }
0 commit comments