Skip to content

Commit bd95866

Browse files
authored
Add files via upload
1 parent 9d4199a commit bd95866

File tree

1 file changed

+164
-0
lines changed

1 file changed

+164
-0
lines changed

charpter22_EM/em.ipynb

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"### EM"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 1,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"# 导入numpy库 \n",
17+
"import numpy as np\n",
18+
"\n",
19+
"### EM算法过程函数定义\n",
20+
def em(data, thetas, max_iter=30, eps=1e-3):
    """Estimate per-component outcome probabilities with the EM algorithm.

    Each row of ``data`` holds the outcome counts of one trial (for the
    coin example: ``(heads, tails)``), and each row of ``thetas`` holds the
    outcome probabilities of one candidate component (one coin). Which
    component produced which trial is the hidden variable.

    Args:
        data: 2-D array-like of observed counts, one row per trial and one
            column per outcome.
        thetas: 2-D array-like of initial parameter estimates, one row per
            component and one column per outcome. It needs at least two
            rows because the progress line prints ``thetas[0, 0]`` and
            ``thetas[1, 0]``.
        max_iter: Maximum number of EM iterations.
        eps: Convergence threshold on the absolute change of the printed
            log-likelihood value between consecutive iterations.

    Returns:
        Array of estimated parameters with the same layout as ``thetas``.
        The input ``thetas`` is not modified.
    """
    data = np.asarray(data, dtype=float)
    thetas = np.asarray(thetas, dtype=float)

    # np.inf instead of np.infty: the alias was removed in NumPy 2.0.
    ll_old = -np.inf
    for i in range(max_iter):
        ### E-step: distribution of the hidden variable
        # Log-likelihood of every trial under every component, laid out as
        # (component, trial). The multinomial coefficient is omitted; it
        # does not depend on theta and cancels in the normalisation below.
        log_like = np.log(thetas) @ data.T
        # Shift by the per-trial maximum before exponentiating so that
        # large counts cannot underflow every component to 0 (0/0 -> NaN).
        # The shift cancels when normalising, so the weights are unchanged.
        like = np.exp(log_like - log_like.max(axis=0))
        # Responsibility of each component for each trial.
        ws = like / like.sum(axis=0)

        ### M-step: update the parameters
        # Responsibility-weighted outcome counts per component, renormalised
        # so every row is again a probability vector.
        weighted = ws @ data
        thetas = weighted / weighted.sum(axis=1, keepdims=True)

        # Responsibility-weighted log-likelihood, evaluated with the
        # parameters this iteration started from.
        ll_new = np.sum(ws * log_like)
        print("Iteration: %d" % (i+1))
        print("theta_B = %.2f, theta_C = %.2f, ll = %.2f" 
              % (thetas[0,0], thetas[1,0], ll_new))
        # Stop as soon as the log-likelihood stops changing.
        if np.abs(ll_new - ll_old) < eps:
            break
        ll_old = ll_new
    return thetas
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": 2,
59+
"metadata": {},
60+
"outputs": [
61+
{
62+
"name": "stdout",
63+
"output_type": "stream",
64+
"text": [
65+
"Iteration: 1\n",
66+
"theta_B = 0.71, theta_C = 0.58, ll = -32.69\n",
67+
"Iteration: 2\n",
68+
"theta_B = 0.75, theta_C = 0.57, ll = -31.26\n",
69+
"Iteration: 3\n",
70+
"theta_B = 0.77, theta_C = 0.55, ll = -30.76\n",
71+
"Iteration: 4\n",
72+
"theta_B = 0.78, theta_C = 0.53, ll = -30.33\n",
73+
"Iteration: 5\n",
74+
"theta_B = 0.79, theta_C = 0.53, ll = -30.07\n",
75+
"Iteration: 6\n",
76+
"theta_B = 0.79, theta_C = 0.52, ll = -29.95\n",
77+
"Iteration: 7\n",
78+
"theta_B = 0.80, theta_C = 0.52, ll = -29.90\n",
79+
"Iteration: 8\n",
80+
"theta_B = 0.80, theta_C = 0.52, ll = -29.88\n",
81+
"Iteration: 9\n",
82+
"theta_B = 0.80, theta_C = 0.52, ll = -29.87\n",
83+
"Iteration: 10\n",
84+
"theta_B = 0.80, theta_C = 0.52, ll = -29.87\n",
85+
"Iteration: 11\n",
86+
"theta_B = 0.80, theta_C = 0.52, ll = -29.87\n",
87+
"Iteration: 12\n",
88+
"theta_B = 0.80, theta_C = 0.52, ll = -29.87\n"
89+
]
90+
}
91+
],
92+
"source": [
93+
# Observed data: 5 independent trials of 10 coin tosses each, stored as
# (heads, tails) counts. For example, the first trial gave 5 heads and 5 tails.
observed_data = np.array([
    [5, 5],
    [9, 1],
    [8, 2],
    [4, 6],
    [7, 3],
])
# Initial parameter guess: coin B lands heads with probability 0.6,
# coin C lands heads with probability 0.5.
thetas = np.array([
    [0.6, 0.4],
    [0.5, 0.5],
])
thetas = em(data=observed_data, thetas=thetas, max_iter=30, eps=1e-3)
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": 3,
104+
"metadata": {},
105+
"outputs": [
106+
{
107+
"data": {
108+
"text/plain": [
109+
"array([[0.7967829 , 0.2032171 ],\n",
110+
" [0.51959543, 0.48040457]])"
111+
]
112+
},
113+
"execution_count": 3,
114+
"metadata": {},
115+
"output_type": "execute_result"
116+
}
117+
],
118+
"source": [
119+
"thetas"
120+
]
121+
},
122+
{
123+
"cell_type": "code",
124+
"execution_count": null,
125+
"metadata": {},
126+
"outputs": [],
127+
"source": []
128+
}
129+
],
130+
"metadata": {
131+
"kernelspec": {
132+
"display_name": "Python 3",
133+
"language": "python",
134+
"name": "python3"
135+
},
136+
"language_info": {
137+
"codemirror_mode": {
138+
"name": "ipython",
139+
"version": 3
140+
},
141+
"file_extension": ".py",
142+
"mimetype": "text/x-python",
143+
"name": "python",
144+
"nbconvert_exporter": "python",
145+
"pygments_lexer": "ipython3",
146+
"version": "3.7.3"
147+
},
148+
"toc": {
149+
"base_numbering": 1,
150+
"nav_menu": {},
151+
"number_sections": true,
152+
"sideBar": true,
153+
"skip_h1_title": false,
154+
"title_cell": "Table of Contents",
155+
"title_sidebar": "Contents",
156+
"toc_cell": false,
157+
"toc_position": {},
158+
"toc_section_display": true,
159+
"toc_window_display": false
160+
}
161+
},
162+
"nbformat": 4,
163+
"nbformat_minor": 4
164+
}

0 commit comments

Comments
 (0)