Skip to content

Commit 908ede4

Browse files
committed
Merge branch 'main' into vo/feat/ai-mind-tool
2 parents da36e16 + fbe4aa4 commit 908ede4

24 files changed

+1920
-87
lines changed

lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212

1313

1414
class QdrantToolSchema(BaseModel):
15-
query: str = Field(..., description="Query to search in Qdrant DB")
15+
query: str = Field(
16+
..., description="Query to search in Qdrant DB - always required."
17+
)
1618
filter_by: str | None = Field(
17-
default=None, description="Parameter to filter the search by."
19+
default=None,
20+
description="Parameter to filter the search by. When filtering, needs to be used in conjunction with filter_value.",
1821
)
1922
filter_value: Any | None = Field(
20-
default=None, description="Value to filter the search by."
23+
default=None,
24+
description="Value to filter the search by. When filtering, needs to be used in conjunction with filter_by.",
2125
)
2226

2327

lib/crewai/src/crewai/agents/crew_agent_executor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
)
3939
from crewai.utilities.constants import TRAINING_DATA_FILE
4040
from crewai.utilities.i18n import I18N, get_i18n
41+
from crewai.utilities.llm_call_hooks import (
42+
get_after_llm_call_hooks,
43+
get_before_llm_call_hooks,
44+
)
4145
from crewai.utilities.printer import Printer
4246
from crewai.utilities.tool_utils import execute_tool_and_check_finality
4347
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -130,6 +134,10 @@ def __init__(
130134
self.messages: list[LLMMessage] = []
131135
self.iterations = 0
132136
self.log_error_after = 3
137+
self.before_llm_call_hooks: list[Callable] = []
138+
self.after_llm_call_hooks: list[Callable] = []
139+
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
140+
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
133141
if self.llm:
134142
# This may be mutating the shared llm object and needs further evaluation
135143
existing_stop = getattr(self.llm, "stop", [])
@@ -226,6 +234,7 @@ def _invoke_loop(self) -> AgentFinish:
226234
from_task=self.task,
227235
from_agent=self.agent,
228236
response_model=self.response_model,
237+
executor_context=self,
229238
)
230239
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
231240

lib/crewai/src/crewai/cli/authentication/main.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Any
2+
from typing import TYPE_CHECKING, Any, TypeVar, cast
33
import webbrowser
44

55
from pydantic import BaseModel, Field
@@ -13,6 +13,8 @@
1313

1414
console = Console()
1515

16+
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
17+
1618

1719
class Oauth2Settings(BaseModel):
1820
provider: str = Field(
@@ -28,22 +30,36 @@ class Oauth2Settings(BaseModel):
2830
description="OAuth2 audience value, typically used to identify the target API or resource.",
2931
default=None,
3032
)
33+
extra: dict[str, Any] = Field(
34+
description="Extra configuration for the OAuth2 provider.",
35+
default={},
36+
)
3137

3238
@classmethod
33-
def from_settings(cls):
39+
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
40+
"""Create an Oauth2Settings instance from the CLI settings."""
41+
3442
settings = Settings()
3543

3644
return cls(
3745
provider=settings.oauth2_provider,
3846
domain=settings.oauth2_domain,
3947
client_id=settings.oauth2_client_id,
4048
audience=settings.oauth2_audience,
49+
extra=settings.oauth2_extra,
4150
)
4251

4352

53+
if TYPE_CHECKING:
54+
from crewai.cli.authentication.providers.base_provider import BaseProvider
55+
56+
4457
class ProviderFactory:
4558
@classmethod
46-
def from_settings(cls, settings: Oauth2Settings | None = None):
59+
def from_settings(
60+
cls: type["ProviderFactory"], # noqa: UP037
61+
settings: Oauth2Settings | None = None,
62+
) -> "BaseProvider": # noqa: UP037
4763
settings = settings or Oauth2Settings.from_settings()
4864

4965
import importlib
@@ -53,11 +69,11 @@ def from_settings(cls, settings: Oauth2Settings | None = None):
5369
)
5470
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
5571

56-
return provider(settings)
72+
return cast("BaseProvider", provider(settings))
5773

5874

5975
class AuthenticationCommand:
60-
def __init__(self):
76+
def __init__(self) -> None:
6177
self.token_manager = TokenManager()
6278
self.oauth2_provider = ProviderFactory.from_settings()
6379

@@ -84,7 +100,7 @@ def _get_device_code(self) -> dict[str, Any]:
84100
timeout=20,
85101
)
86102
response.raise_for_status()
87-
return response.json()
103+
return cast(dict[str, Any], response.json())
88104

89105
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
90106
"""Display the authentication instructions to the user."""

lib/crewai/src/crewai/cli/authentication/providers/base_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ def get_audience(self) -> str: ...
2424

2525
@abstractmethod
2626
def get_client_id(self) -> str: ...
27+
28+
def get_required_fields(self) -> list[str]:
29+
"""Returns which provider-specific fields inside the "extra" dict will be required"""
30+
return []

lib/crewai/src/crewai/cli/authentication/providers/okta.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33

44
class OktaProvider(BaseProvider):
55
def get_authorize_url(self) -> str:
6-
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
6+
return f"{self._oauth2_base_url()}/v1/device/authorize"
77

88
def get_token_url(self) -> str:
9-
return f"https://{self.settings.domain}/oauth2/default/v1/token"
9+
return f"{self._oauth2_base_url()}/v1/token"
1010

1111
def get_jwks_url(self) -> str:
12-
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
12+
return f"{self._oauth2_base_url()}/v1/keys"
1313

1414
def get_issuer(self) -> str:
15-
return f"https://{self.settings.domain}/oauth2/default"
15+
return self._oauth2_base_url().removesuffix("/oauth2")
1616

1717
def get_audience(self) -> str:
1818
if self.settings.audience is None:
@@ -27,3 +27,16 @@ def get_client_id(self) -> str:
2727
"Client ID is required. Please set it in the configuration."
2828
)
2929
return self.settings.client_id
30+
31+
def get_required_fields(self) -> list[str]:
32+
return ["authorization_server_name", "using_org_auth_server"]
33+
34+
def _oauth2_base_url(self) -> str:
35+
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
36+
37+
if using_org_auth_server:
38+
base_url = f"https://{self.settings.domain}/oauth2"
39+
else:
40+
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
41+
42+
return f"{base_url}"

lib/crewai/src/crewai/cli/command.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@
1111

1212

1313
class BaseCommand:
14-
def __init__(self):
14+
def __init__(self) -> None:
1515
self._telemetry = Telemetry()
1616
self._telemetry.set_tracer()
1717

1818

1919
class PlusAPIMixin:
20-
def __init__(self, telemetry):
20+
def __init__(self, telemetry: Telemetry) -> None:
2121
try:
2222
telemetry.set_tracer()
2323
self.plus_api_client = PlusAPI(api_key=get_auth_token())
2424
except Exception:
25-
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
25+
telemetry.deploy_signup_error_span()
2626
console.print(
2727
"Please sign up/login to CrewAI+ before using the CLI.",
2828
style="bold red",

lib/crewai/src/crewai/cli/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from logging import getLogger
33
from pathlib import Path
44
import tempfile
5+
from typing import Any
56

67
from pydantic import BaseModel, Field
78

@@ -136,7 +137,12 @@ class Settings(BaseModel):
136137
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
137138
)
138139

139-
def __init__(self, config_path: Path | None = None, **data):
140+
oauth2_extra: dict[str, Any] = Field(
141+
description="Extra configuration for the OAuth2 provider.",
142+
default={},
143+
)
144+
145+
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
140146
"""Load Settings from config path with fallback support"""
141147
if config_path is None:
142148
config_path = get_writable_config_path()

lib/crewai/src/crewai/cli/enterprise/main.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any
1+
from typing import Any, cast
22

33
import requests
44
from requests.exceptions import JSONDecodeError, RequestException
55
from rich.console import Console
66

7+
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
78
from crewai.cli.command import BaseCommand
89
from crewai.cli.settings.main import SettingsCommand
910
from crewai.cli.version import get_crewai_version
@@ -13,7 +14,7 @@
1314

1415

1516
class EnterpriseConfigureCommand(BaseCommand):
16-
def __init__(self):
17+
def __init__(self) -> None:
1718
super().__init__()
1819
self.settings_command = SettingsCommand()
1920

@@ -54,25 +55,12 @@ def _fetch_oauth_config(self, enterprise_url: str) -> dict[str, Any]:
5455
except JSONDecodeError as e:
5556
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
5657

57-
required_fields = [
58-
"audience",
59-
"domain",
60-
"device_authorization_client_id",
61-
"provider",
62-
]
63-
missing_fields = [
64-
field for field in required_fields if field not in oauth_config
65-
]
66-
67-
if missing_fields:
68-
raise ValueError(
69-
f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}"
70-
)
58+
self._validate_oauth_config(oauth_config)
7159

7260
console.print(
7361
"✅ Successfully retrieved OAuth2 configuration", style="green"
7462
)
75-
return oauth_config
63+
return cast(dict[str, Any], oauth_config)
7664

7765
except RequestException as e:
7866
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
@@ -89,6 +77,7 @@ def _update_oauth_settings(
8977
"oauth2_audience": oauth_config["audience"],
9078
"oauth2_client_id": oauth_config["device_authorization_client_id"],
9179
"oauth2_domain": oauth_config["domain"],
80+
"oauth2_extra": oauth_config["extra"],
9281
}
9382

9483
console.print("🔄 Updating local OAuth2 configuration...")
@@ -99,3 +88,38 @@ def _update_oauth_settings(
9988

10089
except Exception as e:
10190
raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e
91+
92+
def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
93+
required_fields = [
94+
"audience",
95+
"domain",
96+
"device_authorization_client_id",
97+
"provider",
98+
"extra",
99+
]
100+
101+
missing_basic_fields = [
102+
field for field in required_fields if field not in oauth_config
103+
]
104+
missing_provider_specific_fields = [
105+
field
106+
for field in self._get_provider_specific_fields(oauth_config["provider"])
107+
if field not in oauth_config.get("extra", {})
108+
]
109+
110+
if missing_basic_fields:
111+
raise ValueError(
112+
f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
113+
)
114+
115+
if missing_provider_specific_fields:
116+
raise ValueError(
117+
f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')"
118+
)
119+
120+
def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
121+
provider = ProviderFactory.from_settings(
122+
Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
123+
)
124+
125+
return provider.get_required_fields()

lib/crewai/src/crewai/cli/git.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
class Repository:
6-
def __init__(self, path="."):
6+
def __init__(self, path: str = ".") -> None:
77
self.path = path
88

99
if not self.is_git_installed():

0 commit comments

Comments
 (0)