
Commit c3ac575

Refactor evaluate and introduce construct_result_table method (stanfordnlp#7991)
* set last_result_df when evaluate is called
* refactor evaluate logic to add construct_result_df
* refactor _display_result_table
* fix annotation
* Fix result signature
* Fix doc
1 parent a8df4a7 commit c3ac575
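
For context, the user-facing surface touched here is `Evaluate`'s `display_table` option, whose annotation now matches the existing behavior of accepting either a bool or an int (a row limit for the displayed result table). A minimal usage sketch, assuming an LM has already been configured via `dspy.configure(lm=...)`; the toy devset and metric below are illustrative, not part of this commit:

```python
import dspy
from dspy.evaluate import Evaluate

# Illustrative toy data; any dspy.Example objects with inputs marked work here.
devset = [
    dspy.Example(question="What is 1+1?", answer="2").with_inputs("question"),
    dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
]

def exact_match(example, prediction, trace=None):
    # Toy metric: 1.0 when the predicted answer matches the gold answer exactly.
    return float(example.answer == prediction.answer)

evaluator = Evaluate(
    devset=devset,
    metric=exact_match,
    num_threads=2,
    display_progress=True,
    display_table=5,  # int: show at most 5 rows of the result table
)

score = evaluator(dspy.Predict("question -> answer"))
```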

File tree

2 files changed (+103, -48 lines):

* dspy/evaluate/evaluate.py
* tests/evaluate/test_evaluate.py


dspy/evaluate/evaluate.py

Lines changed: 78 additions & 48 deletions
@@ -1,6 +1,6 @@
 import logging
 import types
-from typing import TYPE_CHECKING, Any, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, Tuple
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -54,7 +54,7 @@ def __init__(
         metric: Optional[Callable] = None,
         num_threads: int = 1,
         display_progress: bool = False,
-        display_table: bool = False,
+        display_table: Union[bool, int] = False,
         max_errors: int = 5,
         return_all_scores: bool = False,
         return_outputs: bool = False,
@@ -68,7 +68,8 @@ def __init__(
             metric (Callable): The metric function to use for evaluation.
             num_threads (int): The number of threads to use for parallel evaluation.
             display_progress (bool): Whether to display progress during evaluation.
-            display_table (bool): Whether to display the evaluation results in a table.
+            display_table (Union[bool, int]): Whether to display the evaluation results in a table.
+                If a number is passed, the evaluation results will be truncated to that number before displayed.
             max_errors (int): The maximum number of errors to allow before stopping evaluation.
             return_all_scores (bool): Whether to return scores for every data record in `devset`.
             return_outputs (bool): Whether to return the dspy program's outputs for every data in `devset`.
@@ -94,7 +95,7 @@ def __call__(
         devset: Optional[List["dspy.Example"]] = None,
         num_threads: Optional[int] = None,
         display_progress: Optional[bool] = None,
-        display_table: Optional[bool] = None,
+        display_table: Optional[Union[bool, int]] = None,
         return_all_scores: Optional[bool] = None,
         return_outputs: Optional[bool] = None,
         callback_metadata: Optional[dict[str, Any]] = None,
@@ -108,8 +109,8 @@ def __call__(
                 `self.num_threads`.
             display_progress (bool): Whether to display progress during evaluation. if not provided, use
                 `self.display_progress`.
-            display_table (bool): Whether to display the evaluation results in a table. if not provided, use
-                `self.display_table`.
+            display_table (Union[bool, int]): Whether to display the evaluation results in a table. if not provided, use
+                `self.display_table`. If a number is passed, the evaluation results will be truncated to that number before displayed.
             return_all_scores (bool): Whether to return scores for every data record in `devset`. if not provided,
                 use `self.return_all_scores`.
             return_outputs (bool): Whether to return the dspy program's outputs for every data in `devset`. if not
@@ -174,12 +175,38 @@ def process_item(example):
         ncorrect, ntotal = sum(score for *_, score in results), len(devset)
 
         logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)")
-
-        def prediction_is_dictlike(prediction):
-            # Downstream logic for displaying dictionary-like predictions depends solely on the predictions
-            # having a method called `items()` for iterating through key/value pairs
-            return hasattr(prediction, "items") and callable(getattr(prediction, "items"))
 
+        # Rename the 'correct' column to the name of the metric object
+        metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
+        # Construct a pandas DataFrame from the results
+        result_df = self._construct_result_table(results, metric_name)
+
+        if display_table:
+            self._display_result_table(result_df, display_table, metric_name)
+
+        if return_all_scores and return_outputs:
+            return round(100 * ncorrect / ntotal, 2), results, [score for *_, score in results]
+        if return_all_scores:
+            return round(100 * ncorrect / ntotal, 2), [score for *_, score in results]
+        if return_outputs:
+            return round(100 * ncorrect / ntotal, 2), results
+
+        return round(100 * ncorrect / ntotal, 2)
+
+
+    def _construct_result_table(self, results: list[Tuple[dspy.Example, dspy.Example, Any]], metric_name: str) -> "pd.DataFrame":
+        """
+        Construct a pandas DataFrame from the specified result list.
+        Let's not try to change the name of this method as it may be patched by external tracing tools.
+
+        Args:
+            results: The list of results to construct the result DataFrame from.
+            metric_name: The name of the metric used for evaluation.
+
+        Returns:
+            The constructed pandas DataFrame.
+        """
+        import pandas as pd
         data = [
             (
                 merge_dicts(example, prediction) | {"correct": score}
@@ -189,50 +216,53 @@ def prediction_is_dictlike(prediction):
             for example, prediction, score in results
         ]
 
-
-        import pandas as pd
         # Truncate every cell in the DataFrame (DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0)
         result_df = pd.DataFrame(data)
        result_df = result_df.map(truncate_cell) if hasattr(result_df, "map") else result_df.applymap(truncate_cell)
 
-        # Rename the 'correct' column to the name of the metric object
-        metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
-        result_df = result_df.rename(columns={"correct": metric_name})
+        return result_df.rename(columns={"correct": metric_name})
 
-        if display_table:
-            if isinstance(display_table, bool):
-                df_to_display = result_df.copy()
-                truncated_rows = 0
-            else:
-                df_to_display = result_df.head(display_table).copy()
-                truncated_rows = len(result_df) - display_table
-
-            df_to_display = stylize_metric_name(df_to_display, metric_name)
-
-            display_dataframe(df_to_display)
-
-            if truncated_rows > 0:
-                # Simplified message about the truncated rows
-                message = f"""
-                <div style='
-                    text-align: center;
-                    font-size: 16px;
-                    font-weight: bold;
-                    color: #555;
-                    margin: 10px 0;'>
-                    ... {truncated_rows} more rows not displayed ...
-                </div>
-                """
-                display(HTML(message))
 
-        if return_all_scores and return_outputs:
-            return round(100 * ncorrect / ntotal, 2), results, [score for *_, score in results]
-        if return_all_scores:
-            return round(100 * ncorrect / ntotal, 2), [score for *_, score in results]
-        if return_outputs:
-            return round(100 * ncorrect / ntotal, 2), results
+    def _display_result_table(self, result_df: "pd.DataFrame", display_table: Union[bool, int], metric_name: str):
+        """
+        Display the specified result DataFrame in a table format.
 
-        return round(100 * ncorrect / ntotal, 2)
+        Args:
+            result_df: The result DataFrame to display.
+            display_table: Whether to display the evaluation results in a table.
+                If a number is passed, the evaluation results will be truncated to that number before displayed.
+            metric_name: The name of the metric used for evaluation.
+        """
+        if isinstance(display_table, bool):
+            df_to_display = result_df.copy()
+            truncated_rows = 0
+        else:
+            df_to_display = result_df.head(display_table).copy()
+            truncated_rows = len(result_df) - display_table
+
+        df_to_display = stylize_metric_name(df_to_display, metric_name)
+
+        display_dataframe(df_to_display)
+
+        if truncated_rows > 0:
+            # Simplified message about the truncated rows
+            message = f"""
+            <div style='
+                text-align: center;
+                font-size: 16px;
+                font-weight: bold;
+                color: #555;
+                margin: 10px 0;'>
+                ... {truncated_rows} more rows not displayed ...
+            </div>
+            """
+            display(HTML(message))
+
+
+def prediction_is_dictlike(prediction):
+    # Downstream logic for displaying dictionary-like predictions depends solely on the predictions
+    # having a method called `items()` for iterating through key/value pairs
    return hasattr(prediction, "items") and callable(getattr(prediction, "items"))
 
 
 def merge_dicts(d1, d2) -> dict:
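
The new `_construct_result_table` docstring asks that the method keep its name because external tracing tools may patch it, which lines up with the "set last_result_df when evaluate is called" bullet in the commit message. A rough sketch of that patching pattern, purely illustrative and not taken from any particular tracing library:

```python
import functools

from dspy.evaluate import Evaluate

# Illustrative only: wrap the stable-named hook so every evaluation's result
# DataFrame can be captured (e.g. logged to an experiment tracker).
_original = Evaluate._construct_result_table
_captured = {}

@functools.wraps(_original)
def _capturing_construct_result_table(self, results, metric_name):
    result_df = _original(self, results, metric_name)
    _captured["last_result_df"] = result_df  # hypothetical storage/logging step
    return result_df

Evaluate._construct_result_table = _capturing_construct_result_table
```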

tests/evaluate/test_evaluate.py

Lines changed: 25 additions & 0 deletions
@@ -1,6 +1,7 @@
 import signal
 import threading
 from unittest.mock import patch
+import pandas as pd
 
 import pytest
 
@@ -54,6 +55,30 @@ def test_evaluate_call():
     assert score == 100.0
 
 
+def test_construct_result_df():
+    devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
+    ev = Evaluate(
+        devset=devset,
+        metric=answer_exact_match,
+    )
+    results = [
+        (devset[0], {"answer": "2"}, 100.0),
+        (devset[1], {"answer": "4"}, 100.0),
+    ]
+    result_df = ev._construct_result_table(results, answer_exact_match.__name__)
+    pd.testing.assert_frame_equal(
+        result_df,
+        pd.DataFrame(
+            {
+                "question": ["What is 1+1?", "What is 2+2?"],
+                "example_answer": ["2", "4"],
+                "pred_answer": ["2", "4"],
+                "answer_exact_match": [100.0, 100.0],
+            }
+        )
+    )
+
+
 def test_multithread_evaluate_call():
     dspy.settings.configure(lm=DummyLM({"What is 1+1?": {"answer": "2"}, "What is 2+2?": {"answer": "4"}}))
     devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]