From b34e7d492fa71c396f111e2f78e410f2002c29eb Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 2 Jul 2025 12:50:54 -0600 Subject: [PATCH] Use contextvars for agent overriding, rather than a local attribute --- pydantic_ai_slim/pydantic_ai/agent.py | 28 +++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 0cbb12af0..5342daa1b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -6,6 +6,7 @@ import warnings from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager +from contextvars import ContextVar from copy import deepcopy from types import FrameType from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload @@ -157,8 +158,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) _default_retries: int = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) - _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False) - _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) @overload def __init__( @@ -367,6 +366,9 @@ def __init__( else: self._register_tool(Tool(tool)) + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) + self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: """Set the instrumentation options for all agents where `instrument` is not set.""" @@ -1113,24 +1115,22 @@ def override( model: The model to use instead of the model passed to the agent run. """ if _utils.is_set(deps): - override_deps_before = self._override_deps - self._override_deps = _utils.Some(deps) + deps_token = self._override_deps.set(_utils.Some(deps)) else: - override_deps_before = _utils.UNSET + deps_token = None if _utils.is_set(model): - override_model_before = self._override_model - self._override_model = _utils.Some(models.infer_model(model)) + model_token = self._override_model.set(_utils.Some(models.infer_model(model))) else: - override_model_before = _utils.UNSET + model_token = None try: yield finally: - if _utils.is_set(override_deps_before): - self._override_deps = override_deps_before - if _utils.is_set(override_model_before): - self._override_model = override_model_before + if deps_token is not None: + self._override_deps.reset(deps_token) + if model_token is not None: + self._override_model.reset(model_token) @overload def instructions( @@ -1604,7 +1604,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | str | None) - The model used """ model_: models.Model - if some_model := self._override_model: + if some_model := self._override_model.get(): # we don't want `override()` to cover up errors from the model not being defined, hence this check if model is None and self.model is None: raise exceptions.UserError( @@ -1633,7 +1633,7 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. """ - if some_deps := self._override_deps: + if some_deps := self._override_deps.get(): return some_deps.value else: return deps