Skip to content

Use contextvars for agent overriding, rather than a local attribute #2118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
from copy import deepcopy
from types import FrameType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
Expand Down Expand Up @@ -157,8 +158,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
_mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
_default_retries: int = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

@overload
def __init__(
Expand Down Expand Up @@ -367,6 +366,9 @@ def __init__(
else:
self._register_tool(Tool(tool))

self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)

@staticmethod
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
"""Set the instrumentation options for all agents where `instrument` is not set."""
Expand Down Expand Up @@ -1113,24 +1115,22 @@ def override(
model: The model to use instead of the model passed to the agent run.
"""
if _utils.is_set(deps):
override_deps_before = self._override_deps
self._override_deps = _utils.Some(deps)
deps_token = self._override_deps.set(_utils.Some(deps))
else:
override_deps_before = _utils.UNSET
deps_token = None

if _utils.is_set(model):
override_model_before = self._override_model
self._override_model = _utils.Some(models.infer_model(model))
model_token = self._override_model.set(_utils.Some(models.infer_model(model)))
else:
override_model_before = _utils.UNSET
model_token = None

try:
yield
finally:
if _utils.is_set(override_deps_before):
self._override_deps = override_deps_before
if _utils.is_set(override_model_before):
self._override_model = override_model_before
if deps_token is not None:
self._override_deps.reset(deps_token)
if model_token is not None:
self._override_model.reset(model_token)

@overload
def instructions(
Expand Down Expand Up @@ -1604,7 +1604,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | str | None) -
The model used
"""
model_: models.Model
if some_model := self._override_model:
if some_model := self._override_model.get():
# we don't want `override()` to cover up errors from the model not being defined, hence this check
if model is None and self.model is None:
raise exceptions.UserError(
Expand Down Expand Up @@ -1633,7 +1633,7 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:

We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
"""
if some_deps := self._override_deps:
if some_deps := self._override_deps.get():
return some_deps.value
else:
return deps
Expand Down