Skip to content

Commit f2b2645

Browse files
aseyboldtJunpeng Lao
authored and
Junpeng Lao
committed
Some minor changes (#2643)
* Improve sample docstring * Cast scaling in metropolis to double * Update advi+adapt_diag_grad * Update doc of nuts._Tree.extend * Add adapt_diag_grad to release notes
1 parent 63768f2 commit f2b2645

File tree

5 files changed

+18
-11
lines changed

5 files changed

+18
-11
lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Release Notes
22

3+
## PyMC 3.3. (Unreleased)
4+
5+
- Improve NUTS initialization `advi+adapt_diag_grad` and add `jitter+adapt_diag_grad` (#2643)
6+
7+
38
## PyMC3 3.2 (October 10, 2017)
49

510
### New features

pymc3/sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
236236
BLAS. In those cases it might be faster to set this to one.
237237
tune : int
238238
Number of iterations to tune, if applicable (defaults to 500).
239-
These samples will be drawn in addition to samples and discarded
240-
unless discard_tuned_samples is set to True.
239+
Samplers adjust the step sizes, scalings or similar during
240+
tuning. These samples will be drawn in addition to samples
241+
and discarded unless discard_tuned_samples is set to True.
241242
nuts_kwargs : dict
242243
Options for the NUTS sampler. See the docstring of NUTS
243244
for a complete list of options. Common options are

pymc3/step_methods/hmc/nuts.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,10 @@ def extend(self, direction):
285285
If direction is larger than 0, extend it to the right, otherwise
286286
extend it to the left.
287287
288-
Return a tuple `(diverging, turning)` of type (bool, bool).
289-
`diverging` indicates, that the tree extension was aborted because
290-
the energy change exceeded `self.Emax`. `turning` indicates that
288+
Return a tuple `(diverging, turning)`. `diverging` indicates if the
289+
tree extension was aborted because the energy change exceeded
290+
`self.Emax`. If so, it is a tuple containing details about the reason.
291+
Otherwise, it will be `False`. `turning` indicates that
291292
the tree extension was stopped because the termination criterior
292293
was reached (the trajectory is turning back).
293294
"""

pymc3/step_methods/hmc/quadpotential.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,21 @@ def _update(self, var):
198198

199199
def adapt(self, sample, grad):
200200
"""Inform the potential about a new sample during tuning."""
201-
self._grads1[:] += grad ** 2
202-
self._grads2[:] += grad ** 2
201+
self._grads1[:] += np.abs(grad)
202+
self._grads2[:] += np.abs(grad)
203203
self._ngrads1 += 1
204204
self._ngrads2 += 1
205205

206206
if self._n_samples <= 150:
207207
super().adapt(sample, grad)
208208
else:
209-
self._update(self._ngrads1 / self._grads1)
209+
self._update((self._ngrads1 / self._grads1) ** 2)
210210

211211
if self._n_samples > 100 and self._n_samples % 100 == 50:
212212
self._ngrads1 = self._ngrads2
213-
self._ngrads2 = 0
213+
self._ngrads2 = 1
214214
self._grads1[:] = self._grads2
215-
self._grads2[:] = 0
215+
self._grads2[:] = 1
216216

217217

218218
class _WeightedVariance(object):

pymc3/step_methods/metropolis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, vars=None, S=None, proposal_dist=None, scaling=1.,
112112
else:
113113
raise ValueError("Invalid rank for variance: %s" % S.ndim)
114114

115-
self.scaling = np.atleast_1d(scaling)
115+
self.scaling = np.atleast_1d(scaling).astype('d')
116116
self.tune = tune
117117
self.tune_interval = tune_interval
118118
self.steps_until_tune = tune_interval

0 commit comments

Comments
 (0)