Skip to content

Commit f590923

Browse files
authored
Add files via upload
1 parent 966d8d9 commit f590923

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed

charpter23_HMM/hmm.ipynb

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"### HMM"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 3,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"import numpy as np\n",
17+
"\n",
18+
"### 定义HMM模型\n",
19+
"class HMM:\n",
20+
" def __init__(self, N, M, pi=None, A=None, B=None):\n",
21+
" # 可能的状态数\n",
22+
" self.N = N\n",
23+
" # 可能的观测数\n",
24+
" self.M = M\n",
25+
" # 初始状态概率向量\n",
26+
" self.pi = pi\n",
27+
" # 状态转移概率矩阵\n",
28+
" self.A = A\n",
29+
" # 观测概率矩阵\n",
30+
" self.B = B\n",
31+
"\n",
32+
" # 根据给定的概率分布随机返回数据\n",
33+
" def rdistribution(self, dist): \n",
34+
" r = np.random.rand()\n",
35+
" for ix, p in enumerate(dist):\n",
36+
" if r < p: \n",
37+
" return ix\n",
38+
" r -= p\n",
39+
"\n",
40+
" # 生成HMM观测序列\n",
41+
" def generate(self, T):\n",
42+
" # 根据初始概率分布生成第一个状态\n",
43+
" i = self.rdistribution(self.pi) \n",
44+
" # 生成第一个观测数据\n",
45+
" o = self.rdistribution(self.B[i]) \n",
46+
" observed_data = [o]\n",
47+
" # 遍历生成剩下的状态和观测数据\n",
48+
" for _ in range(T-1): \n",
49+
" i = self.rdistribution(self.A[i])\n",
50+
" o = self.rdistribution(self.B[i])\n",
51+
" observed_data.append(o)\n",
52+
" return observed_data"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": 4,
58+
"metadata": {},
59+
"outputs": [
60+
{
61+
"name": "stdout",
62+
"output_type": "stream",
63+
"text": [
64+
"[1, 0, 0, 1, 0]\n"
65+
]
66+
}
67+
],
68+
"source": [
69+
"# 初始状态概率分布\n",
70+
"pi = np.array([0.25, 0.25, 0.25, 0.25])\n",
71+
"# 状态转移概率矩阵\n",
72+
"A = np.array([\n",
73+
" [0, 1, 0, 0],\n",
74+
" [0.4, 0, 0.6, 0],\n",
75+
" [0, 0.4, 0, 0.6],\n",
76+
"[0, 0, 0.5, 0.5]])\n",
77+
"# 观测概率矩阵\n",
78+
"B = np.array([\n",
79+
" [0.5, 0.5],\n",
80+
" [0.6, 0.4],\n",
81+
" [0.2, 0.8],\n",
82+
" [0.3, 0.7]])\n",
83+
"# 可能的状态数和观测数\n",
84+
"N = 4\n",
85+
"M = 2\n",
86+
"# 创建HMM实例\n",
87+
"hmm = HMM(N, M, pi, A, B)\n",
88+
"# 生成观测序列\n",
89+
"print(hmm.generate(5))"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 6,
95+
"metadata": {},
96+
"outputs": [
97+
{
98+
"name": "stdout",
99+
"output_type": "stream",
100+
"text": [
101+
"0.01983169125\n"
102+
]
103+
}
104+
],
105+
"source": [
106+
"### 前向算法计算条件概率\n",
107+
"def prob_calc(O):\n",
108+
" '''\n",
109+
" 输入:\n",
110+
" O:观测序列\n",
111+
" 输出:\n",
112+
" alpha.sum():条件概率\n",
113+
" '''\n",
114+
" # 初值\n",
115+
" alpha = pi * B[:, O[0]]\n",
116+
" # 递推\n",
117+
" for o in O[1:]:\n",
118+
" alpha_next = np.empty(4)\n",
119+
" for j in range(4):\n",
120+
" alpha_next[j] = np.sum(A[:,j] * alpha * B[j,o])\n",
121+
" alpha = alpha_next\n",
122+
" return alpha.sum()\n",
123+
"\n",
124+
"# 给定观测\n",
125+
"O = [1,0,1,0,0]\n",
126+
"print(prob_calc(O))"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": 7,
132+
"metadata": {},
133+
"outputs": [
134+
{
135+
"name": "stdout",
136+
"output_type": "stream",
137+
"text": [
138+
"[0, 1, 2, 3, 3]\n"
139+
]
140+
}
141+
],
142+
"source": [
143+
"### 序列标注问题和维特比算法\n",
144+
"def viterbi_decode(O):\n",
145+
" '''\n",
146+
" 输入:\n",
147+
" O:观测序列\n",
148+
" 输出:\n",
149+
" path:最优隐状态路径\n",
150+
" ''' \n",
151+
" # 序列长度和初始观测\n",
152+
" T, o = len(O), O[0]\n",
153+
" # 初始化delta变量\n",
154+
" delta = pi * B[:, o]\n",
155+
" # 初始化varphi变量\n",
156+
" varphi = np.zeros((T, 4), dtype=int)\n",
157+
" path = [0] * T\n",
158+
" # 递推\n",
159+
" for i in range(1, T):\n",
160+
" delta = delta.reshape(-1, 1) \n",
161+
" tmp = delta * A\n",
162+
" varphi[i, :] = np.argmax(tmp, axis=0)\n",
163+
" delta = np.max(tmp, axis=0) * B[:, O[i]]\n",
164+
" # 终止\n",
165+
" path[-1] = np.argmax(delta)\n",
166+
" # 回溯最优路径\n",
167+
" for i in range(T-1, 0, -1):\n",
168+
" path[i-1] = varphi[i, path[i]]\n",
169+
" return path\n",
170+
"\n",
171+
"# 给定观测序列\n",
172+
"O = [1,0,1,1,0]\n",
173+
"# 输出最可能的隐状态序列\n",
174+
"print(viterbi_decode(O))"
175+
]
176+
},
177+
{
178+
"cell_type": "code",
179+
"execution_count": null,
180+
"metadata": {},
181+
"outputs": [],
182+
"source": []
183+
}
184+
],
185+
"metadata": {
186+
"kernelspec": {
187+
"display_name": "Python 3",
188+
"language": "python",
189+
"name": "python3"
190+
},
191+
"language_info": {
192+
"codemirror_mode": {
193+
"name": "ipython",
194+
"version": 3
195+
},
196+
"file_extension": ".py",
197+
"mimetype": "text/x-python",
198+
"name": "python",
199+
"nbconvert_exporter": "python",
200+
"pygments_lexer": "ipython3",
201+
"version": "3.7.3"
202+
},
203+
"toc": {
204+
"base_numbering": 1,
205+
"nav_menu": {},
206+
"number_sections": true,
207+
"sideBar": true,
208+
"skip_h1_title": false,
209+
"title_cell": "Table of Contents",
210+
"title_sidebar": "Contents",
211+
"toc_cell": false,
212+
"toc_position": {},
213+
"toc_section_display": true,
214+
"toc_window_display": false
215+
}
216+
},
217+
"nbformat": 4,
218+
"nbformat_minor": 2
219+
}

0 commit comments

Comments
 (0)