|
6 | 6 | import os |
7 | 7 | import pickle |
8 | 8 | import shutil |
| 9 | +from collections import OrderedDict |
9 | 10 |
|
10 | 11 | import pandas as pd |
11 | 12 |
|
@@ -80,6 +81,13 @@ def __getstate__(self): |
80 | 81 | state = {k: v for k, v in state.items() if k != 'memo'} |
81 | 82 | return state |
82 | 83 |
|
| 84 | + def to_df(self, include_params=False): |
| 85 | + out = OrderedDict(trial_no=self.trial_no, succeeded=self.succeeded, reward=self.reward, elapsed=self.elapsed) |
| 86 | + if include_params: |
| 87 | + for p in self.space_sample.get_assigned_params(): |
| 88 | + out[p.alias] = p.value |
| 89 | + return pd.DataFrame({k: [v] for k, v in out.items()}) |
| 90 | + |
83 | 91 |
|
84 | 92 | class TrialHistory(): |
85 | 93 | def __init__(self, optimize_direction): |
@@ -111,15 +119,23 @@ def get_best(self): |
111 | 119 | else: |
112 | 120 | return top1[0] |
113 | 121 |
|
114 | | - def get_top(self, n=10): |
| 122 | + def get_worst(self): |
| 123 | + topn = self.get_top() |
| 124 | + return topn[-1] if len(topn) > 0 else None |
| 125 | + |
| 126 | + def get_top(self, n=None): |
| 127 | + assert n is None or isinstance(n, int) |
| 128 | + |
115 | 129 | valid_trials = [t for t in self.trials if t.succeeded] |
116 | 130 | if len(valid_trials) <= 0: |
117 | 131 | return [] |
118 | 132 | sorted_trials = sorted(valid_trials, key=lambda t: t.reward, |
119 | 133 | reverse=self.optimize_direction in ['max', OptimizeDirection.Maximize]) |
120 | | - if n > len(sorted_trials): |
121 | | - n = len(sorted_trials) |
122 | | - return sorted_trials[:n] |
| 134 | + |
| 135 | + if isinstance(n, int) and n < len(sorted_trials): |
| 136 | + sorted_trials = sorted_trials[:n] |
| 137 | + |
| 138 | + return sorted_trials |
123 | 139 |
|
124 | 140 | def get_space_signatures(self): |
125 | 141 | signatures = set() |
@@ -200,6 +216,24 @@ def load_history(space_fn, filepath): |
200 | 216 | history.append(trial) |
201 | 217 | return history |
202 | 218 |
|
| 219 | + def __repr__(self): |
| 220 | + out = OrderedDict(direction=self.optimize_direction) |
| 221 | + if len(self.trials) > 0: |
| 222 | + tops = self.get_top() |
| 223 | + out['size'] = len(self.trials) |
| 224 | + out['succeeded'] = len(tops) |
| 225 | + if len(tops) > 0: |
| 226 | + out['best_reward'] = tops[0].reward |
| 227 | + out['worst_reward'] = tops[-1].reward |
| 228 | + |
| 229 | + repr_ = ', '.join('%s=%r' % (k, v) for k, v in out.items()) |
| 230 | + return f'{type(self).__name__}({repr_})' |
| 231 | + |
| 232 | + def to_df(self, include_params=False): |
| 233 | + df = pd.concat([t.to_df(include_params) for t in self.trials], axis=0) |
| 234 | + df.reset_index(drop=True, inplace=True) |
| 235 | + return df |
| 236 | + |
203 | 237 | def plot_hyperparams(self, destination='notebook', output='hyperparams.html'): |
204 | 238 | """Plot hyperparams in a parallel line chart |
205 | 239 |
|
|
0 commit comments