Skip to content

Commit c4aa3f3

Browse files
bsmith89twiecki
authored andcommitted
Catch keyboard interrupt during ADVI (pymc-devs#1357)
1 parent dc0bc29 commit c4aa3f3

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

pymc3/variational/advi.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,28 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
136136

137137
# Optimization loop
138138
elbos = np.empty(n)
139-
for i in range(n):
140-
uw_i, e = f()
141-
elbos[i] = e
142-
if verbose and not i % (n // 10):
143-
if not i:
144-
print('Iteration {0} [{1}%]: ELBO = {2}'.format(
145-
i, 100 * i // n, e.round(2)))
146-
else:
147-
avg_elbo = elbos[i - n // 10:i].mean()
148-
print('Iteration {0} [{1}%]: Average ELBO = {2}'.format(
149-
i, 100 * i // n, avg_elbo.round(2)))
150-
151-
if verbose:
152-
avg_elbo = elbos[-n // 10:].mean()
153-
print('Finished [100%]: Average ELBO = {}'.format(avg_elbo.round(2)))
139+
try:
140+
for i in range(n):
141+
uw_i, e = f()
142+
elbos[i] = e
143+
if verbose and not i % (n // 10):
144+
if not i:
145+
print('Iteration {0} [{1}%]: ELBO = {2}'.format(
146+
i, 100 * i // n, e.round(2)))
147+
else:
148+
avg_elbo = elbos[i - n // 10:i].mean()
149+
print('Iteration {0} [{1}%]: Average ELBO = {2}'.format(
150+
i, 100 * i // n, avg_elbo.round(2)))
151+
except KeyboardInterrupt:
152+
if verbose:
153+
elbos = elbos[:i]
154+
avg_elbo = elbos[i - n // 10:].mean()
155+
print('Interrupted at {0} [{1}%]: Average ELBO = {2}'.format(
156+
i, 100 * i // n, avg_elbo.round(2)))
157+
else:
158+
if verbose:
159+
avg_elbo = elbos[-n // 10:].mean()
160+
print('Finished [100%]: Average ELBO = {}'.format(avg_elbo.round(2)))
154161

155162
# Estimated parameters
156163
l = int(uw_i.size / 2)

0 commit comments

Comments
 (0)