Initial Checks
- I confirm that I'm using the latest version of Pydantic AI
- I confirm that I searched for my issue in https://github.com/pydantic/pydantic-ai/issues before opening this issue
Description
The following exception was raised when I tried sequential streaming (one `@g.stream` step feeding into another) with the pydantic-graph beta API, using the example code below in a Jupyter notebook. The problem seems to be that when no `node_id` is passed to the `@g.stream` decorator, the ID is taken from the internal `wrapper` function rather than from the decorated function, so every unnamed stream step ends up with the ID `'wrapper'`. In my local testing, adding the line `node_id = node_id or get_callable_name(call)` after the `wrapper` definition and before the `return self.step(...)` call (line 286) fixes it.
pydantic-ai/pydantic_graph/pydantic_graph/beta/graph_builder.py, lines 283 to 287 in ffafffe:

```python
# We need to wrap the call so that we can call `await` even though the result is an async iterator
async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
    return call(ctx)
return self.step(call=wrapper, node_id=node_id, label=label)
```
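The naming collision can be reproduced without pydantic-graph in a small toy model. This is a hypothetical, much-simplified sketch: `ToyBuilder` and `callable_name` are invented names (the latter only loosely standing in for `get_callable_name`), and it assumes that `step()` falls back to the callable's `__name__` when `node_id` is omitted, which is what the `'wrapper'` ID in the error suggests.

```python
from typing import Any, Callable, Optional


def callable_name(func: Callable[..., Any]) -> str:
    """Hypothetical helper, loosely standing in for get_callable_name."""
    return getattr(func, "__name__", repr(func))


class ToyBuilder:
    """Hypothetical, much-simplified builder; not pydantic-graph's GraphBuilder."""

    def __init__(self, fix_applied: bool) -> None:
        self.fix_applied = fix_applied

    def step(self, call: Callable[..., Any], node_id: Optional[str] = None) -> str:
        # Assumed behaviour: without an explicit node_id, fall back to the callable's name.
        return node_id or callable_name(call)

    def stream(self, call: Callable[..., Any], node_id: Optional[str] = None) -> str:
        # Wrap the async-generator function so the result can be awaited.
        async def wrapper(ctx: Any) -> Any:
            return call(ctx)

        if self.fix_applied:
            # The proposed fix: resolve the ID from the user's function, not from `wrapper`.
            node_id = node_id or callable_name(call)
        return self.step(call=wrapper, node_id=node_id)


async def generate_stream(ctx: Any):
    yield 1


async def square(ctx: Any):
    yield ctx * ctx


buggy = ToyBuilder(fix_applied=False)
fixed = ToyBuilder(fix_applied=True)
print(buggy.stream(generate_stream), buggy.stream(square))  # both IDs are "wrapper"
print(fixed.stream(generate_stream), fixed.stream(square))  # distinct IDs
```

With the fallback resolved before wrapping, an explicit `node_id` still takes precedence and unnamed streams get their own function's name. Decorating `wrapper` with `functools.wraps(call)` would be another way to give it the user's name, but whether that plays well with pydantic-graph's signature and type-hint inspection is untested here.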
```
---------------------------------------------------------------------------
GraphBuildingError                        Traceback (most recent call last)
Cell In[1], line 33
     28     return ctx.inputs + 1.6
     31 collect = g.join(reduce_list_append, initial_factory=list[int])
---> 33 g.add(
     34     g.edge_from(g.start_node).to(generate_stream),
     35     # The stream output is an AsyncIterable, so we can map over it
     36     g.edge_from(generate_stream).map().to(square),
     37     g.edge_from(square).map().to(plus_one),
     38     g.edge_from(plus_one).to(collect),
     39     g.edge_from(collect).to(g.end_node),
     40 )
     42 graph = g.build()
     44 async def main():

File ~/code/.venv/lib/python3.12/site-packages/pydantic_graph/beta/graph_builder.py:380, in GraphBuilder.add(self, *edges)
    378 self._edges_by_source[source_node.id].append(edge.path)
    379 for destination_node in edge.destinations:
--> 380     _handle_destination_node(destination_node)
    381 _handle_path(edge.path)
    383 # Automatically create edges from step function return hints including `BaseNode`s

File ~/code/.venv/lib/python3.12/site-packages/pydantic_graph/beta/graph_builder.py:366, in GraphBuilder.add.<locals>._handle_destination_node(d)
    364 destination_ids.add(id(d))
    365 destinations.append(d)
--> 366 self._insert_node(d)
    367 if isinstance(d, Decision):
    368     for branch in d.branches:

File ~/code/.venv/lib/python3.12/site-packages/pydantic_graph/beta/graph_builder.py:562, in GraphBuilder._insert_node(self, node)
    560     pass
    561 elif existing is not node:
--> 562     raise GraphBuildingError(
    563         f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}'
    564     )

GraphBuildingError: All nodes must have unique node IDs. 'wrapper' was the ID for Step(id='wrapper', _call=<function GraphBuilder.stream.<locals>.wrapper at 0x107b50f40>, label=None) and Step(id='wrapper', _call=<function GraphBuilder.stream.<locals>.wrapper at 0x10eeb40e0>, label=None)
```
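The check that raises here compares by identity (`existing is not node`), so inserting the same step object twice is accepted, while two distinct `Step` objects that happen to share the ID `'wrapper'` are rejected. The toy model below illustrates that check; `ToyStep`, `ToyRegistry` and the local `GraphBuildingError` are hypothetical stand-ins, and only the `is not` comparison and the message format are taken from the traceback.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


class GraphBuildingError(Exception):
    """Hypothetical stand-in for pydantic-graph's GraphBuildingError."""


@dataclass
class ToyStep:
    """Hypothetical step node: an ID plus the callable it runs."""

    id: str
    call: Callable[..., Any] = field(repr=False)


class ToyRegistry:
    """Hypothetical registry enforcing unique node IDs, modelled on the traceback."""

    def __init__(self) -> None:
        self._nodes: Dict[str, ToyStep] = {}

    def insert(self, node: ToyStep) -> None:
        existing = self._nodes.get(node.id)
        if existing is None:
            self._nodes[node.id] = node  # first node registered under this ID
        elif existing is not node:
            # Identity check: a different object reusing the same ID is an error.
            raise GraphBuildingError(
                f"All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}"
            )
        # Otherwise the very same object was inserted again, which is harmless.


# Two stream steps that were both named after an inner "wrapper" function.
first = ToyStep(id="wrapper", call=lambda ctx: ctx)
second = ToyStep(id="wrapper", call=lambda ctx: ctx)

registry = ToyRegistry()
registry.insert(first)
registry.insert(first)  # same object again: accepted
message = ""
try:
    registry.insert(second)  # distinct object with the same ID: rejected
except GraphBuildingError as exc:
    message = str(exc)
print(message)
```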
Example Code
```python
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append


@dataclass
class SimpleState:
    pass


g = GraphBuilder(state_type=SimpleState, output_type=list[int])


@g.stream  # alternatively: @g.stream(node_id="my first stream")
async def generate_stream(ctx: StepContext[SimpleState, None, None]):
    """Stream numbers from 1 to 5."""
    for i in range(1, 6):
        print(f"generate stream {i}")
        yield i


@g.stream  # alternatively: @g.stream(node_id="my second stream")
async def square(ctx: StepContext[SimpleState, None, int]):
    print(f"squaring {ctx.inputs * ctx.inputs}")
    yield ctx.inputs * ctx.inputs


@g.step
async def plus_one(ctx: StepContext[SimpleState, None, int]) -> int:
    print(f"plus_one {ctx.inputs + 1.6}")
    return ctx.inputs + 1.6


collect = g.join(reduce_list_append, initial_factory=list[int])

# The GraphBuildingError above is raised by this call
g.add(
    g.edge_from(g.start_node).to(generate_stream),
    # The stream output is an AsyncIterable, so we can map over it
    g.edge_from(generate_stream).map().to(square),
    g.edge_from(square).map().to(plus_one),
    g.edge_from(plus_one).to(collect),
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()


async def main():
    result = await graph.run(state=SimpleState())
    print(sorted(result))


# In a Jupyter cell, run with: await main()
```

Python, Pydantic AI & LLM client version
Python==3.12.11
pydantic-graph==1.25.0
No LLM was used