Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions libs/langchain_v1/langchain/agents/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@
from .types import (
AgentMiddleware,
AgentState,
ModelCallHandler,
ModelCallResult,
ModelCallWrapper,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
before_model,
dynamic_prompt,
hook_config,
wrap_model_call,
wrap_tool_call,
)

__all__ = [
Expand All @@ -41,9 +46,13 @@
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"ModelCallHandler",
"ModelCallLimitMiddleware",
"ModelCallResult",
"ModelCallWrapper",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",
Expand All @@ -56,4 +65,5 @@
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]
114 changes: 98 additions & 16 deletions libs/langchain_v1/langchain/agents/middleware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
)

if TYPE_CHECKING:
from collections.abc import Awaitable

from langchain.tools.tool_node import ToolCallRequest
from langchain.tools.tool_node import (
AsyncToolCallHandler,
ToolCallHandler,
ToolCallRequest,
)

# Needed as top level import for Pydantic schema generation on AgentState
from typing import TypeAlias
Expand All @@ -43,6 +45,9 @@
"AgentMiddleware",
"AgentState",
"ContextT",
"ModelCallHandler",
"ModelCallResult",
"ModelCallWrapper",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
Expand All @@ -53,6 +58,7 @@
"before_model",
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]

Expand Down Expand Up @@ -102,6 +108,82 @@ class ModelResponse:
"""


ModelCallHandler = Callable[[ModelRequest], ModelResponse]
"""Type alias for the handler callback passed to wrap_model_call hooks.

The handler executes the model request and returns a ModelResponse. It can be called
multiple times for retry logic or skipped entirely to short-circuit execution.

Examples:
Simple passthrough:
```python
def my_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
return handler(request)
```

Retry logic:
```python
def retry_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
for attempt in range(3):
try:
return handler(request)
except Exception:
if attempt == 2:
raise
```
"""

AsyncModelCallHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]
"""Type alias for the async handler callback passed to wrap_model_call hooks.

The async handler executes the model request and returns a ModelResponse. It can be
called multiple times for retry logic or skipped entirely to short-circuit execution.
"""

ModelCallWrapper = Callable[[ModelRequest, ModelCallHandler], ModelCallResult]
"""Type alias for synchronous model call wrapper functions.

A wrapper receives a ModelRequest and a handler callback. It can modify the request,
call the handler (potentially multiple times), modify the response, or short-circuit
entirely.

Args:
request: Model request containing state, runtime, messages, tools, etc.
handler: Callback to execute the model. Can be called multiple times.

Returns:
ModelCallResult (either ModelResponse or AIMessage)

Examples:
Basic retry pattern:
```python
def retry_on_error(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
for attempt in range(3):
try:
return handler(request)
except Exception:
if attempt == 2:
raise
```

Access runtime context:
```python
def use_runtime(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
user_id = request.runtime.context.get("user_id")
# Modify request based on context
return handler(request)
```
"""

AsyncModelCallWrapper = Callable[[ModelRequest, AsyncModelCallHandler], Awaitable[ModelCallResult]]
"""Type alias for asynchronous model call wrapper functions.

A wrapper receives a ModelRequest and an async handler callback. It can modify the
request, call the handler (potentially multiple times), modify the response, or
short-circuit entirely.
"""


@dataclass
class OmitFromSchema:
"""Annotation used to mark state attributes as omitted from input or output schemas."""
Expand Down Expand Up @@ -195,7 +277,7 @@ async def aafter_model(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
"""Intercept and control model execution via handler callback.

Expand Down Expand Up @@ -278,7 +360,7 @@ def wrap_model_call(self, request, handler):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
"""Intercept and control async model execution via handler callback.

Expand Down Expand Up @@ -331,7 +413,7 @@ async def aafter_agent(
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
"""Intercept tool execution for retries, monitoring, or modification.

Expand Down Expand Up @@ -395,7 +477,7 @@ def wrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
handler: AsyncToolCallHandler,
) -> ToolMessage | Command:
"""Intercept and control async tool execution via handler callback.

Expand Down Expand Up @@ -480,7 +562,7 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
"""Intercept model execution via handler callback."""
...
Expand All @@ -495,7 +577,7 @@ class _CallableReturningToolResponse(Protocol):
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
"""Intercept tool execution via handler callback."""
...
Expand Down Expand Up @@ -1174,7 +1256,7 @@ def decorator(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
prompt = await func(request) # type: ignore[misc]
request.system_prompt = prompt
Expand All @@ -1195,7 +1277,7 @@ async def async_wrapped(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
prompt = cast("str", func(request))
request.system_prompt = prompt
Expand All @@ -1204,7 +1286,7 @@ def wrapped(
async def async_wrapped_from_sync(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
# Delegate to sync function
prompt = cast("str", func(request))
Expand Down Expand Up @@ -1337,7 +1419,7 @@ def decorator(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
return await func(request, handler) # type: ignore[misc, arg-type]

Expand All @@ -1358,7 +1440,7 @@ async def async_wrapped(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
return func(request, handler)

Expand Down Expand Up @@ -1480,7 +1562,7 @@ def decorator(
async def async_wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
handler: AsyncToolCallHandler,
) -> ToolMessage | Command:
return await func(request, handler) # type: ignore[arg-type,misc]

Expand All @@ -1501,7 +1583,7 @@ async def async_wrapped(
def wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
return func(request, handler)

Expand Down
16 changes: 15 additions & 1 deletion libs/langchain_v1/langchain/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,28 @@
tool,
)

from langchain.tools.tool_node import InjectedState, InjectedStore, ToolInvocationError
from langchain.tools.tool_node import (
AsyncToolCallHandler,
AsyncToolCallWrapper,
InjectedState,
InjectedStore,
ToolCallHandler,
ToolCallRequest,
ToolCallWrapper,
ToolInvocationError,
)

__all__ = [
"AsyncToolCallHandler",
"AsyncToolCallWrapper",
"BaseTool",
"InjectedState",
"InjectedStore",
"InjectedToolArg",
"InjectedToolCallId",
"ToolCallHandler",
"ToolCallRequest",
"ToolCallWrapper",
"ToolException",
"ToolInvocationError",
"tool",
Expand Down
Loading
Loading