Skip to content

GraphBuildingError when chaining sequential stream steps #3587

@jb2197

Description

@jb2197

Initial Checks

Description

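The following exception is raised when chaining sequential streaming steps with the pydantic-graph beta API, as in the example code below (run in a Jupyter notebook). The problem appears to be that when no `node_id` is passed to the `@g.stream` decorator, the node ID is taken from the internal `wrapper` function rather than from the decorated function. As a result, every stream step without an explicit `node_id` gets the same ID, `'wrapper'`, and building the graph fails with a duplicate-ID error. In my local testing, the following fix resolves it: add the line `node_id = node_id or get_callable_name(call)` after the `wrapper` definition and before `return self.step(...)` (around line 286 of `pydantic_graph/beta/graph_builder.py`).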

# We need to wrap the call so that we can call `await` even though the result is an async iterator
async def wrapper(ctx: StepContext[StateT, DepsT, InputT]):
return call(ctx)
return self.step(call=wrapper, node_id=node_id, label=label)

---------------------------------------------------------------------------
GraphBuildingError                        Traceback (most recent call last)
Cell In[1], [line 33](vscode-notebook-cell:?execution_count=1&line=33)
     28     return ctx.inputs + 1.6
     31 collect = g.join(reduce_list_append, initial_factory=list[int])
---> [33](vscode-notebook-cell:?execution_count=1&line=33) g.add(
     34     g.edge_from(g.start_node).to(generate_stream),
     35     # The stream output is an AsyncIterable, so we can map over it
     36     g.edge_from(generate_stream).map().to(square),
     37     g.edge_from(square).map().to(plus_one),
     38     g.edge_from(plus_one).to(collect),
     39     g.edge_from(collect).to(g.end_node),
     40 )
     42 graph = g.build()
     44 async def main():

File ~/code/.venv/lib/python3.12/site-packages/pydantic_graph/beta/graph_builder.py:380, in GraphBuilder.add(self, *edges)
    378         self._edges_by_source[source_node.id].append(edge.path)
    379     for destination_node in edge.destinations:
--> 380         _handle_destination_node(destination_node)
    381     _handle_path(edge.path)
    383 # Automatically create edges from step function return hints including `BaseNode`s

File ~/code/.venv/lib/python3.12/site-packages/pydantic_graph/beta/graph_builder.py:366, in GraphBuilder.add.<locals>._handle_destination_node(d)
    364 destination_ids.add(id(d))
    365 destinations.append(d)
--> 366 self._insert_node(d)
    367 if isinstance(d, Decision):
    368     for branch in d.branches:

File ~/code/.venv/lib/python3.12/site-packages/pydantic_graph/beta/graph_builder.py:562, in GraphBuilder._insert_node(self, node)
    560     pass
    561 elif existing is not node:
--> 562     raise GraphBuildingError(
    563         f'All nodes must have unique node IDs. {node.id!r} was the ID for {existing} and {node}'
    564     )

GraphBuildingError: All nodes must have unique node IDs. 'wrapper' was the ID for Step(id='wrapper', _call=<function GraphBuilder.stream.<locals>.wrapper at 0x107b50f40>, label=None) and Step(id='wrapper', _call=<function GraphBuilder.stream.<locals>.wrapper at 0x10eeb40e0>, label=None)

Example Code

from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append


@dataclass
class SimpleState:
    """Graph run state placeholder: this example tracks no state, so it declares no fields."""

# Builder for a graph whose run state is SimpleState and whose final output is a list of ints.
g = GraphBuilder(output_type=list[int], state_type=SimpleState)

@g.stream  # alternative: @g.stream(node_id="my first stream") gives this step an explicit ID
async def generate_stream(ctx: StepContext[SimpleState, None, None]):
    """Stream numbers from 1 to 5.

    Each number is printed just before it is yielded; ``ctx`` is not used.
    """
    i = 1
    while i <= 5:
        print(f"generate stream {i}")
        yield i
        i += 1

@g.stream  # alternative: @g.stream(node_id="my second stream") gives this step an explicit ID
async def square(ctx: StepContext[SimpleState, None, int]):
    """Yield exactly one value: the square of the incoming number, printed before it is yielded."""
    squared = ctx.inputs * ctx.inputs
    print(f"squaring {squared}")
    yield squared

@g.step
async def plus_one(ctx: StepContext[SimpleState, None, int]) -> int:
    """Return the incoming integer plus one, printing the result first."""
    # Add exactly 1 (the original added 1.6). That kept the result an int, matching the
    # `-> int` annotation and the graph's `list[int]` output type. It is computed once so
    # the printed value is always the returned value.
    result = ctx.inputs + 1
    print(f"plus_one {result}")
    return result


# Join node that gathers the plus_one outputs into a list (reduce_list_append, starting from []).
collect = g.join(reduce_list_append, initial_factory=list[int])

# Pipeline wiring, in order:
# start -> generate_stream -(map)-> square -(map)-> plus_one -> collect -> end
_pipeline_edges = (
    g.edge_from(g.start_node).to(generate_stream),
    # The stream output is an AsyncIterable, so we can map over it
    g.edge_from(generate_stream).map().to(square),
    g.edge_from(square).map().to(plus_one),
    g.edge_from(plus_one).to(collect),
    g.edge_from(collect).to(g.end_node),
)
g.add(*_pipeline_edges)

graph = g.build()

async def main():
    """Run the built graph from a fresh SimpleState and print the collected outputs in ascending order.

    Nothing in this script calls it; in a notebook, run it with ``await main()``.
    """
    outcome = await graph.run(state=SimpleState())
    ordered = sorted(outcome)
    print(ordered)

Python, Pydantic AI & LLM client version

Python==3.12.11
pydantic-graph==1.25.0
No LLM was used

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions