Skip to content

Commit 6e50100

Browse files
Allow to pass both session and input list (#1298)
1 parent daa9695 commit 6e50100

File tree

4 files changed

+110
-18
lines changed

4 files changed

+110
-18
lines changed

src/agents/memory/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from .openai_conversations_session import OpenAIConversationsSession
22
from .session import Session, SessionABC
33
from .sqlite_session import SQLiteSession
4+
from .util import SessionInputCallback
45

56
__all__ = [
67
"Session",
78
"SessionABC",
9+
"SessionInputCallback",
810
"SQLiteSession",
911
"OpenAIConversationsSession",
1012
]

src/agents/memory/util.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
5+
from ..items import TResponseInputItem
6+
from ..util._types import MaybeAwaitable
7+
8+
SessionInputCallback = Callable[
9+
[list[TResponseInputItem], list[TResponseInputItem]],
10+
MaybeAwaitable[list[TResponseInputItem]],
11+
]
12+
"""A function that combines session history with new input items.
13+
14+
Args:
15+
history_items: The list of items from the session history.
16+
new_items: The list of new input items for the current turn.
17+
18+
Returns:
19+
A list of combined items to be used as input for the agent. Can be sync or async.
20+
"""

src/agents/run.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
)
5555
from .lifecycle import RunHooks
5656
from .logger import logger
57-
from .memory import Session
57+
from .memory import Session, SessionInputCallback
5858
from .model_settings import ModelSettings
5959
from .models.interface import Model, ModelProvider
6060
from .models.multi_provider import MultiProvider
@@ -179,6 +179,13 @@ class RunConfig:
179179
An optional dictionary of additional metadata to include with the trace.
180180
"""
181181

182+
session_input_callback: SessionInputCallback | None = None
183+
"""Defines how to handle session history when new input is provided.
184+
- `None` (default): The new input is appended to the session history.
185+
- `SessionInputCallback`: A custom function that receives the history and new input, and
186+
returns the desired combined list of items.
187+
"""
188+
182189
call_model_input_filter: CallModelInputFilter | None = None
183190
"""
184191
Optional callback that is invoked immediately before calling the model. It receives the current
@@ -413,7 +420,9 @@ async def run(
413420

414421
# Keep original user input separate from session-prepared input
415422
original_user_input = input
416-
prepared_input = await self._prepare_input_with_session(input, session)
423+
prepared_input = await self._prepare_input_with_session(
424+
input, session, run_config.session_input_callback
425+
)
417426

418427
tool_use_tracker = AgentToolUseTracker()
419428

@@ -781,7 +790,9 @@ async def _start_streaming(
781790

782791
try:
783792
# Prepare input with session if enabled
784-
prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session)
793+
prepared_input = await AgentRunner._prepare_input_with_session(
794+
starting_input, session, run_config.session_input_callback
795+
)
785796

786797
# Update the streamed result with the prepared input
787798
streamed_result.input = prepared_input
@@ -1474,19 +1485,20 @@ async def _prepare_input_with_session(
14741485
cls,
14751486
input: str | list[TResponseInputItem],
14761487
session: Session | None,
1488+
session_input_callback: SessionInputCallback | None,
14771489
) -> str | list[TResponseInputItem]:
14781490
"""Prepare input by combining it with session history if enabled."""
14791491
if session is None:
14801492
return input
14811493

1482-
# Validate that we don't have both a session and a list input, as this creates
1483-
# ambiguity about whether the list should append to or replace existing session history
1484-
if isinstance(input, list):
1494+
# If the user doesn't specify an input callback and pass a list as input
1495+
if isinstance(input, list) and not session_input_callback:
14851496
raise UserError(
1486-
"Cannot provide both a session and a list of input items. "
1487-
"When using session memory, provide only a string input to append to the "
1488-
"conversation, or use session=None and provide a list to manually manage "
1489-
"conversation history."
1497+
"When using session memory, list inputs require a "
1498+
"`RunConfig.session_input_callback` to define how they should be merged "
1499+
"with the conversation history. If you don't want to use a callback, "
1500+
"provide your input as a string instead, or disable session memory "
1501+
"(session=None) and pass a list to manage the history manually."
14901502
)
14911503

14921504
# Get previous conversation history
@@ -1495,10 +1507,18 @@ async def _prepare_input_with_session(
14951507
# Convert input to list format
14961508
new_input_list = ItemHelpers.input_to_new_input_list(input)
14971509

1498-
# Combine history with new input
1499-
combined_input = history + new_input_list
1500-
1501-
return combined_input
1510+
if session_input_callback is None:
1511+
return history + new_input_list
1512+
elif callable(session_input_callback):
1513+
res = session_input_callback(history, new_input_list)
1514+
if inspect.isawaitable(res):
1515+
return await res
1516+
return res
1517+
else:
1518+
raise UserError(
1519+
f"Invalid `session_input_callback` value: {session_input_callback}. "
1520+
"Choose between `None` or a custom callable function."
1521+
)
15021522

15031523
@classmethod
15041524
async def _save_result_to_session(
@@ -1507,7 +1527,11 @@ async def _save_result_to_session(
15071527
original_input: str | list[TResponseInputItem],
15081528
new_items: list[RunItem],
15091529
) -> None:
1510-
"""Save the conversation turn to session."""
1530+
"""
1531+
Save the conversation turn to session.
1532+
It does not account for any filtering or modification performed by
1533+
`RunConfig.session_input_callback`.
1534+
"""
15111535
if session is None:
15121536
return
15131537

tests/test_session.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from agents import Agent, Runner, SQLiteSession, TResponseInputItem
9+
from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem
1010
from agents.exceptions import UserError
1111

1212
from .fake_model import FakeModel
@@ -394,11 +394,57 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method)
394394
await run_agent_async(runner_method, agent, list_input, session=session)
395395

396396
# Verify the error message explains the issue
397-
assert "Cannot provide both a session and a list of input items" in str(exc_info.value)
398-
assert "manually manage conversation history" in str(exc_info.value)
397+
assert "list inputs require a `RunConfig.session_input_callback" in str(exc_info.value)
398+
assert "to manage the history manually" in str(exc_info.value)
399399

400400
session.close()
401401

402+
403+
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
404+
@pytest.mark.asyncio
405+
async def test_session_callback_prepared_input(runner_method):
406+
"""Test if the user passes a list of items and want to append them."""
407+
with tempfile.TemporaryDirectory() as temp_dir:
408+
db_path = Path(temp_dir) / "test_memory.db"
409+
410+
model = FakeModel()
411+
agent = Agent(name="test", model=model)
412+
413+
# Session
414+
session_id = "session_1"
415+
session = SQLiteSession(session_id, db_path)
416+
417+
# Add first messages manually
418+
initial_history: list[TResponseInputItem] = [
419+
{"role": "user", "content": "Hello there."},
420+
{"role": "assistant", "content": "Hi, I'm here to assist you."},
421+
]
422+
await session.add_items(initial_history)
423+
424+
def filter_assistant_messages(history, new_input):
425+
# Only include user messages from history
426+
return [item for item in history if item["role"] == "user"] + new_input
427+
428+
new_turn_input = [{"role": "user", "content": "What your name?"}]
429+
model.set_next_output([get_text_message("I'm gpt-4o")])
430+
431+
# Run the agent with the callable
432+
await run_agent_async(
433+
runner_method,
434+
agent,
435+
new_turn_input,
436+
session=session,
437+
run_config=RunConfig(session_input_callback=filter_assistant_messages),
438+
)
439+
440+
expected_model_input = [
441+
initial_history[0], # From history
442+
new_turn_input[0], # New input
443+
]
444+
445+
assert len(model.last_turn_args["input"]) == 2
446+
assert model.last_turn_args["input"] == expected_model_input
447+
402448
@pytest.mark.asyncio
403449
async def test_sqlite_session_unicode_content():
404450
"""Test that session correctly stores and retrieves unicode/non-ASCII content."""

0 commit comments

Comments
 (0)