Skip to content

Commit 5327ae6

Browse files
committed
SQISTEP: Fix the multi-dimensional usage case
1 parent 9efcc37 commit 5327ae6

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

sqistep.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def begin(self, bounds, point0=None, axis=None):
8585

8686
# We make the qxmin, qfmin arrays as big as the points arrays
8787
# even though in fact we use only N-2 elements instead of N.
88-
self.qxmin = np.array([None, None, None])
88+
self.qxmin = np.array([self.mdpoint(None), self.mdpoint(None), self.mdpoint(None)])
8989
self.qfmin = np.array([np.Inf, np.Inf, np.Inf])
9090
self._update_qmins(0)
9191
self.itercnt = 0
@@ -107,6 +107,13 @@ def easiest_sqi_interval(self):
107107
else:
108108
return None
109109

110+
def mdpoint(self, x):
111+
if self.axis is None:
112+
return x
113+
mdx = np.array(self.points[0])
114+
mdx[self.axis] = x
115+
return mdx
116+
110117
def one_step(self):
111118
"""
112119
Perform one iteration of the SQISTEP algorithm, which amounts to
@@ -129,7 +136,8 @@ def one_step(self):
129136
newpoint = self.qxmin[npi_i]
130137
newvalue = self.qfmin[npi_i]
131138
# Convert NPI index to interval index
132-
if newpoint > self.points[npi_i+1]:
139+
if ((self.axis is None and newpoint > self.points[npi_i+1]) or
140+
(self.axis is not None and newpoint[self.axis] > self.points[npi_i+1][self.axis])):
133141
i = npi_i + 1
134142
if self.disp:
135143
print('SQI chose interval %s: x=[%s %s {%s} %s] y=[%s %s {%s} %s]' %
@@ -161,12 +169,14 @@ def one_step(self):
161169
self.difficulty = np.insert(self.difficulty, i+1, np.nan, axis=0)
162170
self.easiest_i_cache = None # we touched .difficulty[]
163171

164-
self.qxmin = np.insert(self.qxmin, i+1, np.nan, axis=0)
172+
self.qxmin = np.insert(self.qxmin, i+1, self.mdpoint(np.nan), axis=0)
165173
self.qfmin = np.insert(self.qfmin, i+1, np.Inf, axis=0)
166174
if i > 0:
167175
self._update_qmins(i-1)
168-
self._update_qmins(i)
169-
self._update_qmins(i+1)
176+
if i < np.size(self.points, axis=0) - 2:
177+
self._update_qmins(i)
178+
if i+1 < np.size(self.points, axis=0) - 2:
179+
self._update_qmins(i+1)
170180

171181
if newvalue < self.fmin:
172182
# New fmin, recompute difficulties of all intervals
@@ -211,15 +221,22 @@ def _nip_qinterp(self, points, values):
211221
# than the sampled points
212222
return (None, np.Inf)
213223

214-
xr = points[1] - points[0]
224+
if self.axis is None:
225+
x0 = points[0]
226+
xr = points[1] - points[0]
227+
xs = points[2] - points[0]
228+
else:
229+
x0 = points[0][self.axis]
230+
xr = points[1][self.axis] - points[0][self.axis]
231+
xs = points[2][self.axis] - points[0][self.axis]
215232
yr = values[1] - values[0]
216-
xs = points[2] - points[0]
217233
ys = values[2] - values[0]
218234
a = (xr * ys - xs * yr) / (xr * xs * (xs - xr))
219235
b = (yr / xr) - (xr * ys - xs * yr) / (xs * (xs - xr))
220-
xm = points[0] - b / (2*a)
236+
xm = x0 - b / (2*a)
221237
ym = values[0] - b**2 / (4*a)
222-
return (xm, ym)
238+
239+
return (self.mdpoint(xm), ym)
223240

224241

225242
def sqistep_minimize(fun, bounds, args=(), maxiter=100, callback=None, axis=None, point0=None, logf=None, staglimit=None, **options):

0 commit comments

Comments
 (0)