Skip to content

Commit d89028f

Browse files
Utility that converts async stream to sync stream (stanfordnlp#8162)
* init sync streaming * increment * fix tests * fix tests
1 parent bb5f0d1 commit d89028f

File tree

3 files changed

+109
-2
lines changed

3 files changed

+109
-2
lines changed

dspy/streaming/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dspy.streaming.messages import StatusMessage, StatusMessageProvider, StreamResponse
2-
from dspy.streaming.streamify import streamify, streaming_response
2+
from dspy.streaming.streamify import apply_sync_streaming, streamify, streaming_response
33
from dspy.streaming.streaming_listener import StreamListener
44

55
__all__ = [
@@ -9,4 +9,5 @@
99
"StreamListener",
1010
"StreamResponse",
1111
"streaming_response",
12+
"apply_sync_streaming",
1213
]

dspy/streaming/streamify.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import asyncio
2+
import contextvars
13
import logging
4+
import threading
25
from asyncio import iscoroutinefunction
3-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, List, Optional
6+
from queue import Queue
7+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, List, Optional
48

59
import litellm
610
import ujson
@@ -200,6 +204,46 @@ async def streamer(*args, **kwargs):
200204
return streamer
201205

202206

207+
def apply_sync_streaming(async_generator: AsyncGenerator) -> Generator:
    """Convert the async streaming generator to a sync generator.

    The async generator is driven to completion on a background thread
    (via `asyncio.run`), and each item is relayed through a queue so the
    caller can iterate synchronously.

    Args:
        async_generator: The async generator to consume, e.g. the output
            of a `streamify`-wrapped program.

    Yields:
        Items produced by `async_generator`, in order.

    Raises:
        Any exception raised inside `async_generator` is re-raised in the
        consuming thread instead of dying silently in the background.
    """
    queue = Queue()  # Queue to hold items from the async generator
    stop_sentinel = object()  # Sentinel to signal the generator is complete
    error_sentinel = object()  # Sentinel tagging an exception relayed from the producer

    # To propagate prediction request ID context to the child thread
    context = contextvars.copy_context()
    from dspy.dsp.utils.settings import thread_local_overrides

    parent_overrides = thread_local_overrides.overrides.copy()

    def producer():
        """Runs in a background thread to fetch items asynchronously."""

        original_overrides = thread_local_overrides.overrides
        thread_local_overrides.overrides = parent_overrides.copy()

        async def runner():
            try:
                async for item in async_generator:
                    queue.put(item)
            except BaseException as exc:
                # Relay the failure to the consumer; swallowing it here would
                # make a crashed stream indistinguishable from a finished one.
                queue.put((error_sentinel, exc))
            finally:
                # Signal completion
                queue.put(stop_sentinel)

        try:
            context.run(asyncio.run, runner())
        finally:
            # Restore even if asyncio.run itself raises.
            thread_local_overrides.overrides = original_overrides

    # Start the producer in a background thread
    thread = threading.Thread(target=producer, daemon=True)
    thread.start()

    # Consume items from the queue
    while True:
        item = queue.get()  # Block until an item is available
        if item is stop_sentinel:
            break
        if isinstance(item, tuple) and len(item) == 2 and item[0] is error_sentinel:
            raise item[1]
        yield item

    # Make sure the producer has finished (and restored its overrides)
    # before the sync generator returns.
    thread.join()
203247
async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
204248
"""
205249
Convert a DSPy program output stream to an OpenAI-compatible output stream that can be

tests/streaming/test_streaming.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,65 @@ def __call__(self, x: str, **kwargs):
225225

226226
assert all_chunks[-1].predict_name == "predict2"
227227
assert all_chunks[-1].signature_field_name == "judgement"
228+
229+
230+
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
def test_sync_streaming():
    """End-to-end check that `apply_sync_streaming` relays chunks from both predictors in order."""

    class MyProgram(dspy.Module):
        def __init__(self):
            self.predict1 = dspy.Predict("question->answer")
            self.predict2 = dspy.Predict("question, answer->judgement")

        def __call__(self, x: str, **kwargs):
            answer = self.predict1(question=x, **kwargs)
            judgement = self.predict2(question=x, answer=answer, **kwargs)
            return judgement

    # Turn off the cache to ensure the stream is produced.
    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False))
    program = dspy.streamify(
        MyProgram(),
        stream_listeners=[
            dspy.streaming.StreamListener(signature_field_name="answer"),
            dspy.streaming.StreamListener(signature_field_name="judgement"),
        ],
        include_final_prediction_in_output_stream=False,
    )

    sync_output = dspy.streaming.apply_sync_streaming(program(x="why did a chicken cross the kitchen?"))
    all_chunks = [value for value in sync_output if isinstance(value, dspy.streaming.StreamResponse)]

    # First chunk belongs to the first predictor's output field...
    assert all_chunks[0].predict_name == "predict1"
    assert all_chunks[0].signature_field_name == "answer"

    # ...and the last chunk to the second predictor's output field.
    assert all_chunks[-1].predict_name == "predict2"
    assert all_chunks[-1].signature_field_name == "judgement"
266+
267+
def test_sync_status_streaming():
    """Tool-call status messages should flow through the sync stream adapter unchanged."""

    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
    with dspy.context(lm=lm):
        sync_output = dspy.streaming.apply_sync_streaming(dspy.streamify(MyProgram())("sky"))
        status_messages = [value for value in sync_output if isinstance(value, StatusMessage)]

        # One message when the tool starts, one when it finishes.
        assert len(status_messages) == 2
        assert status_messages[0].message == "Calling tool generate_question..."
        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."

0 commit comments

Comments
 (0)