Skip to content

Commit 1ed8274

Browse files
authored
hmm code.
1 parent 490aa61 commit 1ed8274

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed

charpter23_hmm/charpter_23.ipynb

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 3,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np\n",
10+
"\n",
11+
"class HMM(object):\n",
12+
" def __init__(self, N, M, pi=None, A=None, B=None):\n",
13+
" # 可能的状态数\n",
14+
" self.N = N\n",
15+
" # 可能的观测数\n",
16+
" self.M = M\n",
17+
" # 初始状态概率向量\n",
18+
" self.pi = pi\n",
19+
" # 状态转移概率矩阵\n",
20+
" self.A = A\n",
21+
" # 观测概率矩阵\n",
22+
" self.B = B\n",
23+
"\n",
24+
" # 按给定的离散概率分布 dist 随机抽样,返回被抽中的下标\n",
25+
" def rdistribution(self, dist): \n",
26+
" r = np.random.rand()\n",
27+
" for ix, p in enumerate(dist):\n",
28+
" if r < p: \n",
29+
" return ix\n",
30+
" r -= p\n",
31+
"\n",
32+
" # 生成HMM观测序列\n",
33+
" def generate(self, T):\n",
34+
" # 根据初始概率分布生成第一个状态\n",
35+
" i = self.rdistribution(self.pi) \n",
36+
" # 生成第一个观测数据\n",
37+
" o = self.rdistribution(self.B[i]) \n",
38+
" observed_data = [o]\n",
39+
" # 遍历生成剩下的状态和观测数据\n",
40+
" for _ in range(T-1): \n",
41+
" i = self.rdistribution(self.A[i])\n",
42+
" o = self.rdistribution(self.B[i])\n",
43+
" observed_data.append(o)\n",
44+
" return observed_data"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": 4,
50+
"metadata": {},
51+
"outputs": [
52+
{
53+
"name": "stdout",
54+
"output_type": "stream",
55+
"text": [
56+
"[1, 0, 0, 1, 0]\n"
57+
]
58+
}
59+
],
60+
"source": [
61+
"pi = np.array([0.25, 0.25, 0.25, 0.25])\n",
62+
"A = np.array([\n",
63+
" [0, 1, 0, 0],\n",
64+
" [0.4, 0, 0.6, 0],\n",
65+
" [0, 0.4, 0, 0.6],\n",
66+
" [0, 0, 0.5, 0.5]])\n",
67+
"B = np.array([\n",
68+
" [0.5, 0.5],\n",
69+
" [0.6, 0.4],\n",
70+
" [0.2, 0.8],\n",
71+
" [0.3, 0.7]])\n",
72+
"hmm = HMM(4, 2, pi, A, B)\n",
73+
"print(hmm.generate(5))"
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": 5,
79+
"metadata": {},
80+
"outputs": [
81+
{
82+
"name": "stdout",
83+
"output_type": "stream",
84+
"text": [
85+
"0.01983169125\n"
86+
]
87+
}
88+
],
89+
"source": [
90+
"### 前向算法计算条件概率(矩阵广播的向量化写法)\n",
91+
"def prob_calc(O):\n",
92+
" '''\n",
93+
" 输入:\n",
94+
" O:观测序列\n",
95+
" 输出:\n",
96+
" alpha.sum():给定模型 (pi, A, B) 时观测序列 O 的概率 P(O|λ)\n",
97+
" '''\n",
98+
" # 初值\n",
99+
" alpha = pi * B[:, O[0]]\n",
100+
" # 递推\n",
101+
" for o in O[1:]:\n",
102+
" alpha = np.sum(A * alpha.reshape(-1,1) * B[:,o].reshape(1,-1), axis=0)\n",
103+
" return alpha.sum()\n",
104+
"\n",
105+
"# 给定观测\n",
106+
"O = [1,0,1,0,0]\n",
107+
"print(prob_calc(O))"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": 6,
113+
"metadata": {},
114+
"outputs": [
115+
{
116+
"name": "stdout",
117+
"output_type": "stream",
118+
"text": [
119+
"0.01983169125\n"
120+
]
121+
}
122+
],
123+
"source": [
124+
"### 前向算法计算条件概率(逐状态显式循环的写法)\n",
125+
"def prob_calc(O):\n",
126+
" '''\n",
127+
" 输入:\n",
128+
" O:观测序列\n",
129+
" 输出:\n",
130+
" alpha.sum():给定模型 (pi, A, B) 时观测序列 O 的概率 P(O|λ)\n",
131+
" '''\n",
132+
" # 初值\n",
133+
" alpha = pi * B[:, O[0]]\n",
134+
" # 递推\n",
135+
" for o in O[1:]:\n",
136+
" alpha_next = np.empty(4)\n",
137+
" for j in range(4):\n",
138+
" alpha_next[j] = np.sum(A[:,j] * alpha * B[j,o])\n",
139+
" alpha = alpha_next\n",
140+
" return alpha.sum()\n",
141+
"\n",
142+
"# 给定观测\n",
143+
"O = [1,0,1,0,0]\n",
144+
"print(prob_calc(O))"
145+
]
146+
},
147+
{
148+
"cell_type": "code",
149+
"execution_count": 7,
150+
"metadata": {},
151+
"outputs": [
152+
{
153+
"name": "stdout",
154+
"output_type": "stream",
155+
"text": [
156+
"[0, 1, 2, 3, 3]\n"
157+
]
158+
}
159+
],
160+
"source": [
161+
"### 序列标注问题和维特比算法\n",
162+
"def viterbi_decode(O):\n",
163+
" '''\n",
164+
" 输入:\n",
165+
" O:观测序列\n",
166+
" 输出:\n",
167+
" path:最优隐状态路径\n",
168+
" ''' \n",
169+
" # 序列长度和初始观测\n",
170+
" T, o = len(O), O[0]\n",
171+
" # 初始化delta变量:delta[j] 为当前时刻以状态 j 结尾的所有路径中的最大概率\n",
172+
" delta = pi * B[:, o]\n",
173+
" # 初始化varphi变量:varphi[t, j] 记录 t 时刻到达状态 j 的最优路径上的前一个状态\n",
174+
" varphi = np.zeros((T, 4), dtype=int)\n",
175+
" path = [0] * T\n",
176+
" # 递推\n",
177+
" for i in range(1, T):\n",
178+
" delta = delta.reshape(-1, 1) \n",
179+
" tmp = delta * A\n",
180+
" varphi[i, :] = np.argmax(tmp, axis=0)\n",
181+
" delta = np.max(tmp, axis=0) * B[:, O[i]]\n",
182+
" # 终止\n",
183+
" path[-1] = np.argmax(delta)\n",
184+
" # 回溯最优路径\n",
185+
" for i in range(T-1, 0, -1):\n",
186+
" path[i-1] = varphi[i, path[i]]\n",
187+
" return path\n",
188+
"\n",
189+
"# 给定观测序列\n",
190+
"O = [1,0,1,1,0]\n",
191+
"print(viterbi_decode(O))"
192+
]
193+
},
194+
{
195+
"cell_type": "code",
196+
"execution_count": null,
197+
"metadata": {},
198+
"outputs": [],
199+
"source": []
200+
}
201+
],
202+
"metadata": {
203+
"kernelspec": {
204+
"display_name": "Python 3",
205+
"language": "python",
206+
"name": "python3"
207+
},
208+
"language_info": {
209+
"codemirror_mode": {
210+
"name": "ipython",
211+
"version": 3
212+
},
213+
"file_extension": ".py",
214+
"mimetype": "text/x-python",
215+
"name": "python",
216+
"nbconvert_exporter": "python",
217+
"pygments_lexer": "ipython3",
218+
"version": "3.7.3"
219+
},
220+
"toc": {
221+
"base_numbering": 1,
222+
"nav_menu": {},
223+
"number_sections": true,
224+
"sideBar": true,
225+
"skip_h1_title": false,
226+
"title_cell": "Table of Contents",
227+
"title_sidebar": "Contents",
228+
"toc_cell": false,
229+
"toc_position": {},
230+
"toc_section_display": true,
231+
"toc_window_display": false
232+
}
233+
},
234+
"nbformat": 4,
235+
"nbformat_minor": 2
236+
}

0 commit comments

Comments
 (0)