1- from typing import Any
1+ from typing import Any , cast
22
33import requests
44from requests .exceptions import JSONDecodeError , RequestException
55from rich .console import Console
66
7+ from crewai .cli .authentication .main import Oauth2Settings , ProviderFactory
78from crewai .cli .command import BaseCommand
89from crewai .cli .settings .main import SettingsCommand
910from crewai .cli .version import get_crewai_version
1314
1415
1516class EnterpriseConfigureCommand (BaseCommand ):
16- def __init__ (self ):
17+ def __init__ (self ) -> None :
1718 super ().__init__ ()
1819 self .settings_command = SettingsCommand ()
1920
@@ -54,25 +55,12 @@ def _fetch_oauth_config(self, enterprise_url: str) -> dict[str, Any]:
5455 except JSONDecodeError as e :
5556 raise ValueError (f"Invalid JSON response from { oauth_endpoint } " ) from e
5657
57- required_fields = [
58- "audience" ,
59- "domain" ,
60- "device_authorization_client_id" ,
61- "provider" ,
62- ]
63- missing_fields = [
64- field for field in required_fields if field not in oauth_config
65- ]
66-
67- if missing_fields :
68- raise ValueError (
69- f"Missing required fields in OAuth2 configuration: { ', ' .join (missing_fields )} "
70- )
58+ self ._validate_oauth_config (oauth_config )
7159
7260 console .print (
7361 "✅ Successfully retrieved OAuth2 configuration" , style = "green"
7462 )
75- return oauth_config
63+ return cast ( dict [ str , Any ], oauth_config )
7664
7765 except RequestException as e :
7866 raise ValueError (f"Failed to connect to enterprise URL: { e !s} " ) from e
@@ -89,6 +77,7 @@ def _update_oauth_settings(
8977 "oauth2_audience" : oauth_config ["audience" ],
9078 "oauth2_client_id" : oauth_config ["device_authorization_client_id" ],
9179 "oauth2_domain" : oauth_config ["domain" ],
80+ "oauth2_extra" : oauth_config ["extra" ],
9281 }
9382
9483 console .print ("🔄 Updating local OAuth2 configuration..." )
@@ -99,3 +88,38 @@ def _update_oauth_settings(
9988
10089 except Exception as e :
10190 raise ValueError (f"Failed to update OAuth2 settings: { e !s} " ) from e
91+
92+ def _validate_oauth_config (self , oauth_config : dict [str , Any ]) -> None :
93+ required_fields = [
94+ "audience" ,
95+ "domain" ,
96+ "device_authorization_client_id" ,
97+ "provider" ,
98+ "extra" ,
99+ ]
100+
101+ missing_basic_fields = [
102+ field for field in required_fields if field not in oauth_config
103+ ]
104+ missing_provider_specific_fields = [
105+ field
106+ for field in self ._get_provider_specific_fields (oauth_config ["provider" ])
107+ if field not in oauth_config .get ("extra" , {})
108+ ]
109+
110+ if missing_basic_fields :
111+ raise ValueError (
112+ f"Missing required fields in OAuth2 configuration: [{ ', ' .join (missing_basic_fields )} ]"
113+ )
114+
115+ if missing_provider_specific_fields :
116+ raise ValueError (
117+ f"Missing authentication provider required fields in OAuth2 configuration: [{ ', ' .join (missing_provider_specific_fields )} ] (Configured provider: '{ oauth_config ['provider' ]} ')"
118+ )
119+
120+ def _get_provider_specific_fields (self , provider_name : str ) -> list [str ]:
121+ provider = ProviderFactory .from_settings (
122+ Oauth2Settings (provider = provider_name , client_id = "dummy" , domain = "dummy" )
123+ )
124+
125+ return provider .get_required_fields ()
0 commit comments