@@ -71,26 +71,18 @@ def forward(self, input):
7171 res2 = self .predictor2 .batch ([input ] * 5 )
7272
7373 return (res1 , res2 )
74-
75- result , reason_result = MyModule ()(dspy .Example (input = "test input" ).with_inputs ("input" ))
7674
77- assert result [0 ].output == "test output 1"
78- assert result [1 ].output == "test output 2"
79- assert result [2 ].output == "test output 3"
80- assert result [3 ].output == "test output 4"
81- assert result [4 ].output == "test output 5"
75+ result , reason_result = MyModule ()(dspy .Example (input = "test input" ).with_inputs ("input" ))
8276
83- assert reason_result [0 ].output == "test output 1"
84- assert reason_result [1 ].output == "test output 2"
85- assert reason_result [2 ].output == "test output 3"
86- assert reason_result [3 ].output == "test output 4"
87- assert reason_result [4 ].output == "test output 5"
77+ # Check that we got all expected outputs without caring about order
78+ expected_outputs = {f"test output { i } " for i in range (1 , 6 )}
79+ assert {r .output for r in result } == expected_outputs
80+ assert {r .output for r in reason_result } == expected_outputs
8881
89- assert reason_result [0 ].reasoning == "test reasoning 1"
90- assert reason_result [1 ].reasoning == "test reasoning 2"
91- assert reason_result [2 ].reasoning == "test reasoning 3"
92- assert reason_result [3 ].reasoning == "test reasoning 4"
93- assert reason_result [4 ].reasoning == "test reasoning 5"
82+ # Check that reasoning matches outputs for reason_result
83+ for r in reason_result :
84+ num = r .output .split ()[- 1 ] # get the number from "test output X"
85+ assert r .reasoning == f"test reasoning { num } "
9486
9587
9688def test_nested_parallel_module ():
@@ -120,7 +112,7 @@ def forward(self, input):
120112 (self .predictor , input ),
121113 ]),
122114 ])
123-
115+
124116 output = MyModule ()(dspy .Example (input = "test input" ).with_inputs ("input" ))
125117
126118 assert output [0 ].output == "test output 1"
@@ -148,7 +140,7 @@ def forward(self, input):
148140 res = self .predictor .batch ([dspy .Example (input = input ).with_inputs ("input" )]* 2 )
149141
150142 return res
151-
143+
152144 result = MyModule ().batch ([dspy .Example (input = "test input" ).with_inputs ("input" )]* 2 )
153145
154146 assert {result [0 ][0 ].output , result [0 ][1 ].output , result [1 ][0 ].output , result [1 ][1 ].output } \
0 commit comments