Skip to content

Commit 1938f66

Browse files
committed
Exp std estimate from advi.
1 parent 3da3e6c commit 1938f66

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

pymc3/examples/advi.ipynb

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@
8080
{
8181
"data": {
8282
"text/plain": [
83-
"({'mu': array(-0.04257040855843567), 'sd_log': array(0.04718829740487737)},\n",
84-
" {'mu': array(-2.1597198585453112), 'sd_log': array(-2.5448566614228585)})"
83+
"({'mu': array(-0.00185745477891906), 'sd_log': array(0.08483499873770008)},\n",
84+
" {'mu': 0.11968501081096175, 'sd_log': 0.078377195440148151})"
8585
]
8686
},
8787
"execution_count": 4,
@@ -95,7 +95,7 @@
9595
},
9696
{
9797
"cell_type": "code",
98-
"execution_count": 9,
98+
"execution_count": 5,
9999
"metadata": {
100100
"collapsed": false
101101
},
@@ -107,6 +107,14 @@
107107
"\r",
108108
" [-----------------100%-----------------] 500 of 500 complete in 0.2 sec"
109109
]
110+
},
111+
{
112+
"name": "stderr",
113+
"output_type": "stream",
114+
"text": [
115+
"/home/wiecki/envs/pymc3/local/lib/python2.7/site-packages/theano/scan_module/scan_perform_ext.py:133: RuntimeWarning: numpy.ndarray size changed, may indicate binary incompatibility\n",
116+
" from scan_perform.scan_perform import *\n"
117+
]
110118
}
111119
],
112120
"source": [
@@ -117,7 +125,7 @@
117125
},
118126
{
119127
"cell_type": "code",
120-
"execution_count": 16,
128+
"execution_count": 7,
121129
"metadata": {
122130
"collapsed": false
123131
},
@@ -126,19 +134,28 @@
126134
"name": "stdout",
127135
"output_type": "stream",
128136
"text": [
129-
"-0.0670183517618\n",
130-
"0.0498948596914\n",
131-
"-2.17277976734\n",
132-
"-2.68756925176\n"
137+
"-0.0234951954338\n",
138+
"0.084690267036\n",
139+
"0.117815331104\n",
140+
"0.0747069952536\n"
133141
]
134142
}
135143
],
136144
"source": [
137145
"print trace['mu'].mean()\n",
138146
"print trace['sd_log'].mean()\n",
139-
"print np.log(trace['mu'].std())\n",
140-
"print np.log(trace['sd_log'].std())"
147+
"print trace['mu'].std()\n",
148+
"print trace['sd_log'].std()"
141149
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"execution_count": null,
154+
"metadata": {
155+
"collapsed": true
156+
},
157+
"outputs": [],
158+
"source": []
142159
}
143160
],
144161
"metadata": {

pymc3/variational/ADVI.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def advi(vars=None, start=None, model=None, n=5000):
4242

4343
u = bij.rmap(result[:l])
4444
w = bij.rmap(result[l:])
45+
# w is in log space
46+
for var in w.keys():
47+
w[var] = np.exp(w[var])
4548
return u, w
4649

4750
def run_adagrad(uw, grad, inarray, n):

0 commit comments

Comments
 (0)