Skip to content

Commit c8b5ad6

Browse files
committed
Override __repr__ method
1 parent cbedd0c commit c8b5ad6

File tree

8 files changed

+91
-27
lines changed

8 files changed

+91
-27
lines changed

hypernets/core/searcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
"""
33
44
"""
5-
from .stateful import Stateful
65
import enum
76

7+
from hypernets.utils import to_repr
8+
from .stateful import Stateful
9+
810

911
class OptimizeDirection(enum.Enum):
1012
Minimize = 'min'
@@ -63,3 +65,6 @@ def reset(self):
6365

6466
def export(self):
6567
raise NotImplementedError
68+
69+
def __repr__(self):
70+
return to_repr(self)

hypernets/core/trial.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import pickle
88
import shutil
9+
from collections import OrderedDict
910

1011
import pandas as pd
1112

@@ -80,6 +81,13 @@ def __getstate__(self):
8081
state = {k: v for k, v in state.items() if k != 'memo'}
8182
return state
8283

84+
def to_df(self, include_params=False):
85+
out = OrderedDict(trial_no=self.trial_no, succeeded=self.succeeded, reward=self.reward, elapsed=self.elapsed)
86+
if include_params:
87+
for p in self.space_sample.get_assigned_params():
88+
out[p.alias] = p.value
89+
return pd.DataFrame({k: [v] for k, v in out.items()})
90+
8391

8492
class TrialHistory():
8593
def __init__(self, optimize_direction):
@@ -111,15 +119,23 @@ def get_best(self):
111119
else:
112120
return top1[0]
113121

114-
def get_top(self, n=10):
122+
def get_worst(self):
123+
topn = self.get_top()
124+
return topn[-1] if len(topn) > 0 else None
125+
126+
def get_top(self, n=None):
127+
assert n is None or isinstance(n, int)
128+
115129
valid_trials = [t for t in self.trials if t.succeeded]
116130
if len(valid_trials) <= 0:
117131
return []
118132
sorted_trials = sorted(valid_trials, key=lambda t: t.reward,
119133
reverse=self.optimize_direction in ['max', OptimizeDirection.Maximize])
120-
if n > len(sorted_trials):
121-
n = len(sorted_trials)
122-
return sorted_trials[:n]
134+
135+
if isinstance(n, int) and n < len(sorted_trials):
136+
sorted_trials = sorted_trials[:n]
137+
138+
return sorted_trials
123139

124140
def get_space_signatures(self):
125141
signatures = set()
@@ -200,6 +216,24 @@ def load_history(space_fn, filepath):
200216
history.append(trial)
201217
return history
202218

219+
def __repr__(self):
220+
out = OrderedDict(direction=self.optimize_direction)
221+
if len(self.trials) > 0:
222+
tops = self.get_top()
223+
out['size'] = len(self.trials)
224+
out['succeeded'] = len(tops)
225+
if len(tops) > 0:
226+
out['best_reward'] = tops[0].reward
227+
out['worst_reward'] = tops[-1].reward
228+
229+
repr_ = ', '.join('%s=%r' % (k, v) for k, v in out.items())
230+
return f'{type(self).__name__}({repr_})'
231+
232+
def to_df(self, include_params=False):
233+
df = pd.concat([t.to_df(include_params) for t in self.trials], axis=0)
234+
df.reset_index(drop=True, inplace=True)
235+
return df
236+
203237
def plot_hyperparams(self, destination='notebook', output='hyperparams.html'):
204238
"""Plot hyperparams in a parallel line chart
205239

hypernets/model/hyper_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..core.trial import *
1111
from ..discriminators import UnPromisingTrial
1212
from ..dispatchers import get_dispatcher
13-
from ..utils import logging, infer_task_type as _infer_task_type, hash_data, const
13+
from ..utils import logging, infer_task_type as _infer_task_type, hash_data, const, to_repr
1414

1515
logger = logging.get_logger(__name__)
1616

@@ -97,6 +97,11 @@ def _run_trial(self, space_sample, trial_no, X, y, X_eval, y_eval, cv=False, num
9797
elapsed = time.time() - start_time
9898
trial = Trial(space_sample, trial_no, 0, elapsed, succeeded=succeeded)
9999

100+
if self.history is not None:
101+
t = self.history.get_worst()
102+
if t is not None:
103+
self.searcher.update_result(space_sample, t.reward)
104+
100105
return trial
101106

102107
def _get_reward(self, value, key=None):
@@ -220,3 +225,6 @@ def infer_task_type(self, y):
220225

221226
def plot_hyperparams(self, destination='notebook', output='hyperparams.html'):
222227
return self.history.plot_hyperparams(destination, output)
228+
229+
def __repr__(self):
230+
return to_repr(self)

hypernets/searchers/evolution_searcher.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
"""
33
44
"""
5-
import numpy as np
65

7-
from ..core.searcher import Searcher, OptimizeDirection
86
from ..core import get_random_state
7+
from ..core.searcher import Searcher, OptimizeDirection
98
from ..utils import logging
109

1110
logger = logging.get_logger(__name__)
@@ -96,6 +95,7 @@ class EvolutionSearcher(Searcher):
9695
----------
9796
Real, Esteban, et al. "Regularized evolution for image classifier architecture search." Proceedings of the aaai conference on artificial intelligence. Vol. 33. 2019.
9897
"""
98+
9999
def __init__(self, space_fn, population_size, sample_size, regularized=False,
100100
candidates_size=10, optimize_direction=OptimizeDirection.Minimize, use_meta_learner=True,
101101
space_sample_validation_fn=None, random_state=None):
@@ -123,10 +123,15 @@ def __init__(self, space_fn, population_size, sample_size, regularized=False,
123123
Searcher.__init__(self, space_fn=space_fn, optimize_direction=optimize_direction,
124124
use_meta_learner=use_meta_learner, space_sample_validation_fn=space_sample_validation_fn)
125125
self.random_state = random_state if random_state is not None else get_random_state()
126-
self.population = Population(size=population_size, optimize_direction=optimize_direction, random_state=self.random_state)
126+
self.population = Population(size=population_size, optimize_direction=optimize_direction,
127+
random_state=self.random_state)
127128
self.sample_size = sample_size
128129
self.regularized = regularized
129-
self.candidate_size = candidates_size
130+
self.candidates_size = candidates_size
131+
132+
@property
133+
def population_size(self):
134+
return self.population.size
130135

131136
@property
132137
def parallelizable(self):
@@ -151,7 +156,7 @@ def _get_offspring(self, space_sample):
151156
candidates = []
152157
scores = []
153158
no = 0
154-
for i in range(self.candidate_size):
159+
for i in range(self.candidates_size):
155160
new_space = self.space_fn()
156161
try:
157162
candidate = self._sample_and_check(lambda: self.population.mutate(space_sample, new_space))

hypernets/searchers/mcts_searcher.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ def __init__(self, space_fn, policy=None, max_node_space=10, candidates_size=10,
3939
Searcher.__init__(self, space_fn, optimize_direction, use_meta_learner=use_meta_learner,
4040
space_sample_validation_fn=space_sample_validation_fn)
4141
self.nodes_map = {}
42-
self.candidate_size = candidates_size
42+
self.candidates_size = candidates_size
43+
44+
@property
45+
def max_node_space(self):
46+
return self.tree.max_node_space
4347

4448
def parallelizable(self):
4549
return self.use_meta_learner and self.meta_learner is not None
@@ -71,7 +75,7 @@ def sample():
7175
def _select_best_candidate(self, node):
7276
candidates = []
7377
scores = []
74-
for i in range(self.candidate_size):
78+
for i in range(self.candidates_size):
7579
candidate = self._roll_out(node)
7680
candidates.append(candidate)
7781
scores.append(self.meta_learner.predict(candidate, 0.5))

hypernets/searchers/playback_searcher.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88

99

1010
class PlaybackSearcher(Searcher):
11-
def __init__(self, trail_history: TrialHistory, top_n=None, reverse=False,
11+
def __init__(self, history: TrialHistory, top_n=None, reverse=False,
1212
optimize_direction=OptimizeDirection.Minimize):
13-
assert trail_history is not None
14-
assert len(trail_history.trials) > 0
13+
assert history is not None
14+
assert len(history.trials) > 0
1515

16-
self.history = trail_history
17-
self.top_n = top_n if top_n is not None else len(trail_history.trials)
16+
self.history = history
17+
self.top_n = top_n if top_n is not None else len(history.trials)
1818
self.samples = [t.space_sample for t in self.history.get_top(self.top_n)]
1919
self.index = 0
2020
self.reverse = reverse

hypernets/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
from ._doc_lens import DocLens
88
from ._fsutils import filesystem as fs
99
from ._tic_tok import tic_toc, report as tic_toc_report, report_as_dataframe as tic_toc_report_as_dataframe
10-
from .common import generate_id, combinations, isnotebook, Counter, to_repr
10+
from .common import generate_id, combinations, isnotebook, Counter, to_repr, get_params
1111
from .common import infer_task_type, hash_data, hash_dataframe, load_data, load_module

hypernets/utils/common.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import math
1010
import pickle
1111
import uuid
12+
from collections import OrderedDict
1213
from functools import partial
1314
from io import BytesIO
1415

@@ -27,9 +28,9 @@ def generate_id():
2728
return str(uuid.uuid1())
2829

2930

30-
def to_repr(obj):
31+
def get_params(obj, include_default=False):
3132
def _get_init_params(cls):
32-
init = cls.__init__ # getattr(cls.__init__, 'deprecated_original', cls.__init__)
33+
init = cls.__init__
3334
if init is object.__init__:
3435
return []
3536

@@ -38,15 +39,22 @@ def _get_init_params(cls):
3839
if p.name != 'self' and p.kind != p.VAR_KEYWORD]
3940
return parameters
4041

41-
out = []
42-
cls_ = type(obj)
43-
for p in _get_init_params(cls_):
42+
out = OrderedDict()
43+
for p in _get_init_params(type(obj)):
4444
name = p.name
4545
value = getattr(obj, name, None)
46-
if value is not p.default:
47-
out.append('%s=%r' % (name, value))
46+
if include_default or value is not p.default:
47+
out[name] = value
48+
49+
return out
50+
51+
52+
def to_repr(obj, excludes=None):
53+
if excludes is None:
54+
excludes = []
55+
out = ['%s=%r' % (k, v) for k, v in get_params(obj).items() if k not in excludes]
4856
repr_ = ', '.join(out)
49-
return f'{cls_.__name__}({repr_})'
57+
return f'{type(obj).__name__}({repr_})'
5058

5159

5260
def combinations(n, m_max, m_min=1):

0 commit comments

Comments
 (0)