Skip to content

Commit 593bcda

Browse files
committed
Switched from pickle to dill
1 parent bf4bf83 commit 593bcda

File tree

3 files changed

+170
-3
lines changed

3 files changed

+170
-3
lines changed

edbo/bro.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
import numpy as np
77

8-
import pickle
8+
import dill
99

1010
from gpytorch.priors import GammaPrior
1111

@@ -401,7 +401,7 @@ def save(self, path='BO.pkl'):
401401
"""
402402

403403
file = open(path, 'wb')
404-
pickle.dump(self.__dict__, file)
404+
dill.dump(self.__dict__, file)
405405
file.close()
406406

407407
# Load BO instance
@@ -419,7 +419,7 @@ def load(self, path='BO.pkl'):
419419
"""
420420

421421
file = open(path, 'rb')
422-
tmp_dict = pickle.load(file)
422+
tmp_dict = dill.load(file)
423423
file.close()
424424

425425
self.__dict__.update(tmp_dict)

install.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ conda install -y pandas=0.25.3 numpy=1.17.4 xlrd
3232
conda install -y pytorch=1.3.1 cudatoolkit=10.1 torchvision -c pytorch
3333
conda install -y scikit-learn=0.22.1
3434
conda install -y matplotlib seaborn
35+
conda install -y dill
3536

3637
pip install gpytorch==1.0.0 pyclustering==0.9.3.1
3738
pip install pyro-ppl==1.1

tests/pickle_test.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
############################################################################## Setup
2+
"""
3+
1D Bayesian Optimization Test:
4+
(1) Generate 1D objective.
5+
(2) Initialize with data.
6+
(3) Test predictions, variance estimation, and sampling.
7+
(4) Run single iteration of each acquisition function.
8+
"""
9+
10+
# Imports
11+
12+
import numpy as np
13+
import pandas as pd
14+
from edbo.bro import BO_express
15+
from edbo.pd_utils import to_torch, torch_to_numpy
16+
import matplotlib.pyplot as plt
17+
import random
18+
19+
############################################################################## Test Functions
20+
21+
# Objective
22+
23+
def random_result(*args, **kwargs):
    """Random objective.

    Every positional and keyword argument is accepted and ignored, so the
    function can stand in for a computational objective no matter how the
    caller invokes it.

    Returns:
        float: ``random.random()`` rounded to 3 decimals and multiplied by
        100, i.e. a value between 0 and 100 inclusive.
    """

    return round(random.random(), 3) * 100
27+
28+
# Test a precomputed objective
29+
30+
def BO_pred(acq_func, plot=False, return_='pred', append=False, init='rand'):
    """Build a BO_express optimizer, save/load it, then exercise the reloaded copy.

    A randomized reaction space is built (ligands are sampled from
    ``ligands.csv`` in the current working directory), the optimizer is
    initialized and run once, written out with ``save()`` using its default
    path, and read back into a fresh ``BO_express()`` via ``load()``. The
    check selected by ``return_`` is then performed on the reloaded object.

    Args:
        acq_func: Passed to ``BO_express`` as ``acquisition_function``.
        plot: Unused; kept so existing callers keep working.
        return_: Which check to perform: ``'pred'``, ``'var'``, ``'sample'``,
            ``'plot'`` or ``'simulate'``.
        append: Forwarded to ``bo.run(append=...)``.
        init: Passed to ``BO_express`` as ``init_method``.

    Returns:
        bool: ``True`` when the selected check succeeds. For ``'pred'``,
        ``'var'`` and ``'sample'``, ``False`` is returned when the model call
        raises for any of the supported input types.

    Raises:
        ValueError: If ``return_`` is not one of the supported check names.
    """

    # Define reaction space and auto-encode
    n_ligands = random.choice([3, 4, 5, 6, 7, 8])
    ligands = pd.read_csv('ligands.csv').sample(n_ligands).values.flatten()
    bases = ['DBU', 'MTBD', 'potassium carbonate', 'potassium phosphate', 'potassium tert-butoxide']
    reaction_components = {
        'aryl_halide': ['chlorobenzene', 'iodobenzene', 'bromobenzene'],
        'base': bases,
        'solvent': ['THF', 'Toluene', 'DMSO', 'DMAc'],
        'ligand': ligands,
        'concentration': [0.1, 0.2, 0.3],
        'temperature': [20, 30, 40],
    }
    encoding = {
        'aryl_halide': 'resolve',
        'base': 'resolve',
        'solvent': 'resolve',
        'ligand': 'mordred',
        'concentration': 'numeric',
        'temperature': 'numeric',
    }

    # Instantiate BO class
    # NOTE(review): range(30) can yield a batch size of 0 -- confirm that
    # BO_express handles an empty batch, otherwise start the range at 1.
    bo = BO_express(reaction_components=reaction_components,
                    encoding=encoding,
                    acquisition_function=acq_func,
                    init_method=init,
                    batch_size=random.choice(range(30)),
                    computational_objective=random_result,
                    target='yield')

    bo.init_sample(append=True)
    bo.run(append=append)

    # Serialization round trip: everything below runs on the reloaded object
    bo.save()
    bo = BO_express()
    bo.load()

    # Check prediction
    if return_ == 'pred':
        return _model_accepts_all_input_types(bo, 'predict')

    # Check predictive posterior variance
    elif return_ == 'var':
        return _model_accepts_all_input_types(bo, 'variance')

    # Make sure sampling works with tensors, arrays, lists, and DataFrames
    elif return_ == 'sample':
        return _model_accepts_all_input_types(bo, 'sample_posterior')

    # Plot model
    elif return_ == 'plot':
        mean = bo.obj.scaler.unstandardize(bo.model.predict(bo.obj.domain))
        std = np.sqrt(bo.model.variance(bo.obj.domain)) * bo.obj.scaler.std * 2
        samples = bo.obj.scaler.unstandardize(bo.model.sample_posterior(bo.obj.domain, batch_size=3))

        plt.figure(1, figsize=(6, 6))

        # Model mean and standard deviation
        plt.subplot(211)
        plt.plot(range(len(mean)), mean, label='GP')
        plt.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.4)
        # Known results and next selected point
        plt.scatter(bo.obj.results_input().index.values, bo.obj.results_input()['yield'], color='black', label='known')
        plt.ylabel('f(x)')
        # Samples
        plt.subplot(212)
        for sample in samples:
            plt.plot(range(len(mean)), torch_to_numpy(sample))
        plt.xlabel('x')
        plt.ylabel('Posterior Samples')
        plt.show()

        return True

    elif return_ == 'simulate':

        if init != 'external':
            bo.init_seq.batch_size = random.choice([2, 3, 4, 5, 6, 7, 8, 9, 10])

        bo.simulate(iterations=3)
        bo.plot_convergence()
        bo.model.regression()

        return True

    # An unknown check name used to fall through and return None silently
    raise ValueError('Unknown return_ value: {!r}'.format(return_))


def _model_accepts_all_input_types(bo, method_name):
    """Return True if ``bo.model.<method_name>`` runs on the domain in every input form.

    The method is called with ``bo.obj.domain`` converted to a torch tensor,
    as a numpy array, as a list, and as the DataFrame itself. Any ordinary
    exception yields ``False``; KeyboardInterrupt and SystemExit propagate.
    """

    try:
        method = getattr(bo.model, method_name)
        method(to_torch(bo.obj.domain))        # torch.tensor
        method(bo.obj.domain.values)           # numpy.array
        method(list(bo.obj.domain.values))     # list
        method(bo.obj.domain)                  # pandas.DataFrame
    except Exception:
        return False

    return True
138+
139+
140+
############################################################################## Tests
141+
142+
# Test predicted mean and variance, sampling, and plotting
143+
144+
def test_BO_pred_mean_TS():
    """The 'pred' check must pass with the 'TS' acquisition function."""
    outcome = BO_pred('TS', return_='pred')
    assert outcome
146+
147+
def test_BO_var():
    """The 'var' check must pass with the 'TS' acquisition function."""
    outcome = BO_pred('TS', return_='var')
    assert outcome
149+
150+
def test_BO_sample():
    """The 'sample' check must pass with the 'TS' acquisition function."""
    outcome = BO_pred('TS', return_='sample')
    assert outcome
152+
153+
def test_BO_plot():
    """The 'plot' check must pass with the 'TS' acquisition function."""
    outcome = BO_pred('TS', return_='plot')
    assert outcome
155+
156+
# Test simulations
157+
158+
def test_BO_simulate_TS():
    """The 'simulate' check must pass with the 'TS' acquisition function."""
    outcome = BO_pred('TS', return_='simulate')
    assert outcome
160+
161+
def test_BO_simulate_EI():
    """The 'simulate' check must pass with the 'EI' acquisition function."""
    outcome = BO_pred('EI', return_='simulate')
    assert outcome
163+
164+
165+
166+

0 commit comments

Comments
 (0)