Skip to content

Commit f20aca4

Browse files
aloctavodiaspringcoil
authored andcommitted
add jitter to fast_kde (#1629)
* add jitter to fast_kde, prevents errors when input values are all the same * remove unnecessary print
1 parent ca7f68d commit f20aca4

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

pymc3/plots.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
from scipy.stats import kde, mode
3-
from numpy.linalg import LinAlgError
43
import matplotlib.pyplot as plt
54
import pymc3 as pm
65
from .stats import quantiles, hpd
@@ -122,26 +121,18 @@ def histplot_op(ax, data, alpha=.35):
122121

123122

124123
def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
125-
errored = []
126124
for i in range(data.shape[1]):
127125
d = data[:, i]
128-
try:
129-
density, l, u = fast_kde(d)
130-
x = np.linspace(l, u, len(density))
131-
132-
if prior is not None:
133-
p = prior.logp(x).eval()
134-
ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)
126+
density, l, u = fast_kde(d)
127+
x = np.linspace(l, u, len(density))
135128

136-
ax.plot(x, density)
129+
if prior is not None:
130+
p = prior.logp(x).eval()
131+
ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)
137132

138-
except LinAlgError:
139-
errored.append(i)
133+
ax.plot(x, density)
140134

141135
ax.set_ylim(ymin=0)
142-
if errored:
143-
ax.text(.27, .47, 'WARNING: KDE plot failed for: ' + str(errored), style='italic',
144-
bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 10})
145136

146137

147138
def make_2d(a):
@@ -793,6 +784,7 @@ def get_trace_dict(tr, varnames):
793784

794785
fig.tight_layout()
795786
return ax
787+
796788

797789
def fast_kde(x):
798790
"""
@@ -813,14 +805,16 @@ def fast_kde(x):
813805
xmax: maximum value of x
814806
815807
"""
808+
# add small jitter in case input values are the same
809+
x = np.random.normal(x, 1e-12)
816810

817811
xmin, xmax = x.min(), x.max()
818812

819813
n = len(x)
820814
nx = 256
821815

822816
# compute histogram
823-
bins = np.linspace(x.min(), x.max(), nx)
817+
bins = np.linspace(xmin, xmax, nx)
824818
xyi = np.digitize(x, bins)
825819
dx = (xmax - xmin) / (nx - 1)
826820
grid = np.histogram(x, bins=nx)[0]

0 commit comments

Comments
 (0)