 import logging
 import types
-from typing import TYPE_CHECKING, Any, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, Tuple

 if TYPE_CHECKING:
     import pandas as pd
@@ -54,7 +54,7 @@ def __init__(
         metric: Optional[Callable] = None,
         num_threads: int = 1,
         display_progress: bool = False,
-        display_table: bool = False,
+        display_table: Union[bool, int] = False,
         max_errors: int = 5,
         return_all_scores: bool = False,
         return_outputs: bool = False,
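With this change, `display_table` accepts either a boolean or an integer row limit. A minimal construction sketch under assumed names (the `exact_match` metric and one-example devset below are illustrative placeholders, not part of this diff):

```python
import dspy

def exact_match(example, prediction, trace=None):  # assumed metric signature
    return example.answer == prediction.answer

devset = [dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")]

evaluator = dspy.Evaluate(
    devset=devset,
    metric=exact_match,
    display_table=5,  # True shows the full table; an int shows only the first N rows
)
```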
@@ -68,7 +68,8 @@ def __init__(
             metric (Callable): The metric function to use for evaluation.
             num_threads (int): The number of threads to use for parallel evaluation.
             display_progress (bool): Whether to display progress during evaluation.
-            display_table (bool): Whether to display the evaluation results in a table.
+            display_table (Union[bool, int]): Whether to display the evaluation results in a table.
+                If a number is passed, the evaluation results will be truncated to that number of rows before being displayed.
             max_errors (int): The maximum number of errors to allow before stopping evaluation.
             return_all_scores (bool): Whether to return scores for every data record in `devset`.
             return_outputs (bool): Whether to return the dspy program's outputs for every example in `devset`.
@@ -94,7 +95,7 @@ def __call__(
         devset: Optional[List["dspy.Example"]] = None,
         num_threads: Optional[int] = None,
         display_progress: Optional[bool] = None,
-        display_table: Optional[bool] = None,
+        display_table: Optional[Union[bool, int]] = None,
         return_all_scores: Optional[bool] = None,
         return_outputs: Optional[bool] = None,
         callback_metadata: Optional[dict[str, Any]] = None,
@@ -108,8 +109,8 @@ def __call__(
                 `self.num_threads`.
             display_progress (bool): Whether to display progress during evaluation. If not provided, use
                 `self.display_progress`.
-            display_table (bool): Whether to display the evaluation results in a table. If not provided, use
-                `self.display_table`.
+            display_table (Union[bool, int]): Whether to display the evaluation results in a table. If not provided, use
+                `self.display_table`. If a number is passed, the evaluation results will be truncated to that number of rows before being displayed.
             return_all_scores (bool): Whether to return scores for every data record in `devset`. If not provided,
                 use `self.return_all_scores`.
             return_outputs (bool): Whether to return the dspy program's outputs for every example in `devset`. If not
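Combined with the return flags documented above, the call site can unpack one, two, or three values. A sketch of the possible shapes, assuming `evaluator` and `program` are already defined (names are placeholders):

```python
score = evaluator(program)  # float percentage, rounded to two decimals

score, results = evaluator(program, return_outputs=True)  # results: (example, prediction, score) triples

score, all_scores = evaluator(program, return_all_scores=True)  # all_scores: one score per devset record

score, results, all_scores = evaluator(program, return_all_scores=True, return_outputs=True)
```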
@@ -174,12 +175,38 @@ def process_item(example):
         ncorrect, ntotal = sum(score for *_, score in results), len(devset)

         logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)")
-
-        def prediction_is_dictlike(prediction):
-            # Downstream logic for displaying dictionary-like predictions depends solely on the predictions
-            # having a method called `items()` for iterating through key/value pairs
-            return hasattr(prediction, "items") and callable(getattr(prediction, "items"))

+        # Rename the 'correct' column to the name of the metric object
+        metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
+        # Construct a pandas DataFrame from the results
+        result_df = self._construct_result_table(results, metric_name)
+
+        if display_table:
+            self._display_result_table(result_df, display_table, metric_name)
+
+        if return_all_scores and return_outputs:
+            return round(100 * ncorrect / ntotal, 2), results, [score for *_, score in results]
+        if return_all_scores:
+            return round(100 * ncorrect / ntotal, 2), [score for *_, score in results]
+        if return_outputs:
+            return round(100 * ncorrect / ntotal, 2), results
+
+        return round(100 * ncorrect / ntotal, 2)
+
+
+    def _construct_result_table(self, results: list[Tuple[dspy.Example, dspy.Example, Any]], metric_name: str) -> "pd.DataFrame":
+        """
+        Construct a pandas DataFrame from the specified result list.
+        Note: do not rename this method, as it may be patched by external tracing tools.
+
+        Args:
+            results: The list of results to construct the result DataFrame from.
+            metric_name: The name of the metric used for evaluation.
+
+        Returns:
+            The constructed pandas DataFrame.
+        """
+        import pandas as pd
         data = [
             (
                 merge_dicts(example, prediction) | {"correct": score}
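The `metric_name` resolution above distinguishes plain functions from callable metric objects. A small self-contained check of that dispatch (the `FuzzyMatch` class is hypothetical):

```python
import types

def answer_match(example, prediction):  # a plain function metric
    return example.answer == prediction.answer

class FuzzyMatch:  # a callable metric object
    def __call__(self, example, prediction):
        return example.answer in prediction.answer

# Functions report __name__; callable objects report their class name.
assert isinstance(answer_match, types.FunctionType) and answer_match.__name__ == "answer_match"
assert not isinstance(FuzzyMatch(), types.FunctionType)
assert FuzzyMatch().__class__.__name__ == "FuzzyMatch"
```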
@@ -189,50 +216,53 @@ def prediction_is_dictlike(prediction):
             for example, prediction, score in results
         ]

-
-        import pandas as pd
         # Truncate every cell in the DataFrame (DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0)
         result_df = pd.DataFrame(data)
         result_df = result_df.map(truncate_cell) if hasattr(result_df, "map") else result_df.applymap(truncate_cell)

-        # Rename the 'correct' column to the name of the metric object
-        metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
-        result_df = result_df.rename(columns={"correct": metric_name})
+        return result_df.rename(columns={"correct": metric_name})

-        if display_table:
-            if isinstance(display_table, bool):
-                df_to_display = result_df.copy()
-                truncated_rows = 0
-            else:
-                df_to_display = result_df.head(display_table).copy()
-                truncated_rows = len(result_df) - display_table
-
-            df_to_display = stylize_metric_name(df_to_display, metric_name)
-
-            display_dataframe(df_to_display)
-
-            if truncated_rows > 0:
-                # Simplified message about the truncated rows
-                message = f"""
-                <div style='
-                    text-align: center;
-                    font-size: 16px;
-                    font-weight: bold;
-                    color: #555;
-                    margin: 10px 0;'>
-                    ... {truncated_rows} more rows not displayed ...
-                </div>
-                """
-                display(HTML(message))

-        if return_all_scores and return_outputs:
-            return round(100 * ncorrect / ntotal, 2), results, [score for *_, score in results]
-        if return_all_scores:
-            return round(100 * ncorrect / ntotal, 2), [score for *_, score in results]
-        if return_outputs:
-            return round(100 * ncorrect / ntotal, 2), results
+    def _display_result_table(self, result_df: "pd.DataFrame", display_table: Union[bool, int], metric_name: str):
+        """
+        Display the specified result DataFrame in a table format.

-        return round(100 * ncorrect / ntotal, 2)
+        Args:
+            result_df: The result DataFrame to display.
+            display_table: Whether to display the evaluation results in a table.
+                If a number is passed, the evaluation results will be truncated to that number of rows before being displayed.
+            metric_name: The name of the metric used for evaluation.
+        """
+        if isinstance(display_table, bool):
+            df_to_display = result_df.copy()
+            truncated_rows = 0
+        else:
+            df_to_display = result_df.head(display_table).copy()
+            truncated_rows = len(result_df) - display_table
+
+        df_to_display = stylize_metric_name(df_to_display, metric_name)
+
+        display_dataframe(df_to_display)
+
+        if truncated_rows > 0:
+            # Simplified message about the truncated rows
+            message = f"""
+            <div style='
+                text-align: center;
+                font-size: 16px;
+                font-weight: bold;
+                color: #555;
+                margin: 10px 0;'>
+                ... {truncated_rows} more rows not displayed ...
+            </div>
+            """
+            display(HTML(message))
+
+
+def prediction_is_dictlike(prediction):
+    # Downstream logic for displaying dictionary-like predictions depends solely on the predictions
+    # having a method called `items()` for iterating through key/value pairs
+    return hasattr(prediction, "items") and callable(getattr(prediction, "items"))


 def merge_dicts(d1, d2) -> dict:
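Since `prediction_is_dictlike` is now module-level, it can be exercised directly: it only requires a callable `items()` method. A quick illustration (the `FakePrediction` class is hypothetical):

```python
class FakePrediction:
    items = "present but not callable"  # attribute exists, but is not a method

assert prediction_is_dictlike({"answer": "4"})        # dicts expose a callable .items()
assert not prediction_is_dictlike(FakePrediction())   # non-callable .items fails the check
assert not prediction_is_dictlike("plain string")     # no .items attribute at all
```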