@@ -225,3 +225,65 @@ def __call__(self, x: str, **kwargs):
225225
226226 assert all_chunks [- 1 ].predict_name == "predict2"
227227 assert all_chunks [- 1 ].signature_field_name == "judgement"
228+
229+
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
def test_sync_streaming():
    """Verify apply_sync_streaming yields field chunks from both predictors, in order."""

    class MyProgram(dspy.Module):
        def __init__(self):
            self.predict1 = dspy.Predict("question->answer")
            self.predict2 = dspy.Predict("question, answer->judgement")

        def __call__(self, x: str, **kwargs):
            answer = self.predict1(question=x, **kwargs)
            judgement = self.predict2(question=x, answer=answer, **kwargs)
            return judgement

    # Turn off the cache to ensure the stream is produced.
    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False))
    streamed = dspy.streamify(
        MyProgram(),
        stream_listeners=[
            dspy.streaming.StreamListener(signature_field_name="answer"),
            dspy.streaming.StreamListener(signature_field_name="judgement"),
        ],
        include_final_prediction_in_output_stream=False,
    )
    async_stream = streamed(x="why did a chicken cross the kitchen?")
    sync_stream = dspy.streaming.apply_sync_streaming(async_stream)

    # Keep only the streamed field chunks; ignore any other stream items.
    all_chunks = [item for item in sync_stream if isinstance(item, dspy.streaming.StreamResponse)]

    # First chunk must come from predict1's "answer" field ...
    assert all_chunks[0].predict_name == "predict1"
    assert all_chunks[0].signature_field_name == "answer"
    # ... and the last chunk from predict2's "judgement" field.
    assert all_chunks[-1].predict_name == "predict2"
    assert all_chunks[-1].signature_field_name == "judgement"
265+
266+
def test_sync_status_streaming():
    """Verify the sync stream adapter surfaces tool-call status messages."""

    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
    with dspy.context(lm=lm):
        stream = dspy.streamify(MyProgram())("sky")
        # Drain the stream synchronously, keeping only the status messages.
        status_messages = [
            item
            for item in dspy.streaming.apply_sync_streaming(stream)
            if isinstance(item, StatusMessage)
        ]

    # One message when the tool starts, one when it finishes.
    assert len(status_messages) == 2
    assert status_messages[0].message == "Calling tool generate_question..."
    assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."