Skip to content

Commit 9595e34

Browse files
authored
Add files via upload
1 parent c1977fd commit 9595e34

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import pandas as pd\n",
10+
"import numpy as np\n",
11+
"from collections import defaultdict\n",
12+
"\n",
13+
"class MaxEnt:\n",
14+
" def __init__(self, max_iter=100):\n",
15+
" # 训练输入\n",
16+
" self.X_ = None\n",
17+
" # 训练标签\n",
18+
" self.y_ = None\n",
19+
" # 标签类别数量\n",
20+
" self.m = None \n",
21+
" # 特征数量\n",
22+
" self.n = None \n",
23+
" # 训练样本量\n",
24+
" self.N = None \n",
25+
" # 常数特征取值\n",
26+
" self.M = None\n",
27+
" # 权重系数\n",
28+
" self.w = None\n",
29+
" # 标签名称\n",
30+
" self.labels = defaultdict(int)\n",
31+
" # 特征名称\n",
32+
" self.features = defaultdict(int)\n",
33+
" # 最大迭代次数\n",
34+
" self.max_iter = max_iter\n",
35+
"\n",
36+
" ### 计算特征函数关于经验联合分布P(X,Y)的期望\n",
37+
" def _EP_hat_f(self, x, y):\n",
38+
" self.Pxy = np.zeros((self.m, self.n))\n",
39+
" self.Px = np.zeros(self.n)\n",
40+
" for x_, y_ in zip(x, y):\n",
41+
" # 遍历每个样本\n",
42+
" for x__ in set(x_):\n",
43+
" self.Pxy[self.labels[y_], self.features[x__]] += 1\n",
44+
" self.Px[self.features[x__]] += 1 \n",
45+
" self.EP_hat_f = self.Pxy/self.N\n",
46+
" \n",
47+
" ### 计算特征函数关于模型P(Y|X)与经验分布P(X)的期望\n",
48+
" def _EP_f(self):\n",
49+
" self.EPf = np.zeros((self.m, self.n))\n",
50+
" for X in self.X_:\n",
51+
" pw = self._pw(X)\n",
52+
" pw = pw.reshape(self.m, 1)\n",
53+
" px = self.Px.reshape(1, self.n)\n",
54+
" self.EP_f += pw*px / self.N\n",
55+
" \n",
56+
" ### 最大熵模型P(y|x)\n",
57+
" def _pw(self, x):\n",
58+
" mask = np.zeros(self.n+1)\n",
59+
" for ix in x:\n",
60+
" mask[self.features[ix]] = 1\n",
61+
" tmp = self.w * mask[1:]\n",
62+
" pw = np.exp(np.sum(tmp, axis=1))\n",
63+
" Z = np.sum(pw)\n",
64+
" pw = pw/Z\n",
65+
" return pw\n",
66+
"\n",
67+
" ### 熵模型拟合\n",
68+
" ### 基于改进的迭代尺度方法IIS\n",
69+
" def fit(self, x, y):\n",
70+
" # 训练输入\n",
71+
" self.X_ = x\n",
72+
" # 训练输出\n",
73+
" self.y_ = list(set(y))\n",
74+
" # 输入数据展平后集合\n",
75+
" tmp = set(self.X_.flatten())\n",
76+
" # 特征命名\n",
77+
" self.features = defaultdict(int, zip(tmp, range(1, len(tmp)+1))) \n",
78+
" # 标签命名\n",
79+
" self.labels = dict(zip(self.y_, range(len(self.y_))))\n",
80+
" # 特征数\n",
81+
" self.n = len(self.features)+1 \n",
82+
" # 标签类别数量\n",
83+
" self.m = len(self.labels)\n",
84+
" # 训练样本量\n",
85+
" self.N = len(x) \n",
86+
" # 计算EP_hat_f\n",
87+
" self._EP_hat_f(x, y)\n",
88+
" # 初始化系数矩阵\n",
89+
" self.w = np.zeros((self.m, self.n))\n",
90+
" # 循环迭代\n",
91+
" i = 0\n",
92+
" while i <= self.max_iter:\n",
93+
" # 计算EPf\n",
94+
" self._EP_f()\n",
95+
" # 令常数特征函数为M\n",
96+
" self.M = 100\n",
97+
" # IIS算法步骤(3)\n",
98+
" tmp = np.true_divide(self.EP_hat_f, self.EP_f)\n",
99+
" tmp[tmp == np.inf] = 0\n",
100+
" tmp = np.nan_to_num(tmp)\n",
101+
" sigma = np.where(tmp != 0, 1/self.M*np.log(tmp), 0) \n",
102+
" # 更新系数:IIS步骤(4)\n",
103+
" self.w = self.w + sigma\n",
104+
" i += 1\n",
105+
" print('training done.')\n",
106+
" return self\n",
107+
"\n",
108+
" # 定义最大熵模型预测函数\n",
109+
" def predict(self, x):\n",
110+
" res = np.zeros(len(x), dtype=np.int64)\n",
111+
" for ix, x_ in enumerate(x):\n",
112+
" tmp = self._pw(x_)\n",
113+
" print(tmp, np.argmax(tmp), self.labels)\n",
114+
" res[ix] = self.labels[self.y_[np.argmax(tmp)]]\n",
115+
" return np.array([self.y_[ix] for ix in res])"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": 2,
121+
"metadata": {},
122+
"outputs": [
123+
{
124+
"name": "stdout",
125+
"output_type": "stream",
126+
"text": [
127+
"(105, 4) (105,)\n"
128+
]
129+
}
130+
],
131+
"source": [
132+
"from sklearn.datasets import load_iris\n",
133+
"from sklearn.model_selection import train_test_split\n",
134+
"raw_data = load_iris()\n",
135+
"X, labels = raw_data.data, raw_data.target\n",
136+
"X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.3, random_state=43)\n",
137+
"print(X_train.shape, y_train.shape)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": 3,
143+
"metadata": {},
144+
"outputs": [
145+
{
146+
"data": {
147+
"text/plain": [
148+
"array([2, 2, 2, 2, 2])"
149+
]
150+
},
151+
"execution_count": 3,
152+
"metadata": {},
153+
"output_type": "execute_result"
154+
}
155+
],
156+
"source": [
157+
"labels[-5:]"
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": 4,
163+
"metadata": {},
164+
"outputs": [
165+
{
166+
"name": "stderr",
167+
"output_type": "stream",
168+
"text": [
169+
"D:\\Installation\\anaconda\\install\\lib\\site-packages\\ipykernel_launcher.py:90: RuntimeWarning: invalid value encountered in true_divide\n",
170+
"D:\\Installation\\anaconda\\install\\lib\\site-packages\\ipykernel_launcher.py:93: RuntimeWarning: divide by zero encountered in log\n"
171+
]
172+
},
173+
{
174+
"name": "stdout",
175+
"output_type": "stream",
176+
"text": [
177+
"training done.\n",
178+
"[0.87116843 0.04683368 0.08199789] 0 {0: 0, 1: 1, 2: 2}\n",
179+
"[0.00261138 0.49573305 0.50165557] 2 {0: 0, 1: 1, 2: 2}\n",
180+
"[0.12626693 0.017157 0.85657607] 2 {0: 0, 1: 1, 2: 2}\n",
181+
"[1.55221378e-04 4.45985560e-05 9.99800180e-01] 2 {0: 0, 1: 1, 2: 2}\n",
182+
"[7.29970746e-03 9.92687370e-01 1.29226740e-05] 1 {0: 0, 1: 1, 2: 2}\n",
183+
"[0.01343943 0.01247887 0.9740817 ] 2 {0: 0, 1: 1, 2: 2}\n",
184+
"[0.85166079 0.05241898 0.09592023] 0 {0: 0, 1: 1, 2: 2}\n",
185+
"[0.00371481 0.00896982 0.98731537] 2 {0: 0, 1: 1, 2: 2}\n",
186+
"[2.69340079e-04 9.78392776e-01 2.13378835e-02] 1 {0: 0, 1: 1, 2: 2}\n",
187+
"[0.01224702 0.02294254 0.96481044] 2 {0: 0, 1: 1, 2: 2}\n",
188+
"[0.00323508 0.98724246 0.00952246] 1 {0: 0, 1: 1, 2: 2}\n",
189+
"[0.00196548 0.01681989 0.98121463] 2 {0: 0, 1: 1, 2: 2}\n",
190+
"[0.00480966 0.00345107 0.99173927] 2 {0: 0, 1: 1, 2: 2}\n",
191+
"[0.00221101 0.01888735 0.97890163] 2 {0: 0, 1: 1, 2: 2}\n",
192+
"[9.87528545e-01 3.25313387e-04 1.21461416e-02] 0 {0: 0, 1: 1, 2: 2}\n",
193+
"[3.84153917e-05 5.25603786e-01 4.74357798e-01] 1 {0: 0, 1: 1, 2: 2}\n",
194+
"[0.91969448 0.00730851 0.07299701] 0 {0: 0, 1: 1, 2: 2}\n",
195+
"[3.48493252e-03 9.96377722e-01 1.37345863e-04] 1 {0: 0, 1: 1, 2: 2}\n",
196+
"[0.00597935 0.02540794 0.96861271] 2 {0: 0, 1: 1, 2: 2}\n",
197+
"[0.96593729 0.01606867 0.01799404] 0 {0: 0, 1: 1, 2: 2}\n",
198+
"[7.07324443e-01 2.92672257e-01 3.29961259e-06] 0 {0: 0, 1: 1, 2: 2}\n",
199+
"[0.96122092 0.03604362 0.00273547] 0 {0: 0, 1: 1, 2: 2}\n",
200+
"[9.92671813e-01 7.31265179e-03 1.55352641e-05] 0 {0: 0, 1: 1, 2: 2}\n",
201+
"[9.99997290e-01 2.58555077e-06 1.24081335e-07] 0 {0: 0, 1: 1, 2: 2}\n",
202+
"[1.77991802e-05 4.62006560e-04 9.99520194e-01] 2 {0: 0, 1: 1, 2: 2}\n",
203+
"[9.99995176e-01 3.85240188e-06 9.72067357e-07] 0 {0: 0, 1: 1, 2: 2}\n",
204+
"[0.15306343 0.21405142 0.63288515] 2 {0: 0, 1: 1, 2: 2}\n",
205+
"[0.25817329 0.28818997 0.45363674] 2 {0: 0, 1: 1, 2: 2}\n",
206+
"[2.43530473e-04 4.07929999e-01 5.91826471e-01] 2 {0: 0, 1: 1, 2: 2}\n",
207+
"[0.71160155 0.27290911 0.01548934] 0 {0: 0, 1: 1, 2: 2}\n",
208+
"[2.94976826e-06 2.51510534e-02 9.74845997e-01] 2 {0: 0, 1: 1, 2: 2}\n",
209+
"[0.97629163 0.00331591 0.02039245] 0 {0: 0, 1: 1, 2: 2}\n",
210+
"[0.04513811 0.01484173 0.94002015] 2 {0: 0, 1: 1, 2: 2}\n",
211+
"[0.61382753 0.38321073 0.00296174] 0 {0: 0, 1: 1, 2: 2}\n",
212+
"[9.65538451e-01 3.86322918e-06 3.44576854e-02] 0 {0: 0, 1: 1, 2: 2}\n",
213+
"[0.00924088 0.01731108 0.97344804] 2 {0: 0, 1: 1, 2: 2}\n",
214+
"[0.02511142 0.93818613 0.03670245] 1 {0: 0, 1: 1, 2: 2}\n",
215+
"[9.99127831e-01 3.29723254e-04 5.42445518e-04] 0 {0: 0, 1: 1, 2: 2}\n",
216+
"[0.05081665 0.0038204 0.94536295] 2 {0: 0, 1: 1, 2: 2}\n",
217+
"[9.99985376e-01 6.85280694e-06 7.77081022e-06] 0 {0: 0, 1: 1, 2: 2}\n",
218+
"[9.99791732e-01 2.06536005e-04 1.73191035e-06] 0 {0: 0, 1: 1, 2: 2}\n",
219+
"[2.72323181e-04 2.99692548e-03 9.96730751e-01] 2 {0: 0, 1: 1, 2: 2}\n",
220+
"[0.02005139 0.97151852 0.00843009] 1 {0: 0, 1: 1, 2: 2}\n",
221+
"[0.95642409 0.02485912 0.01871679] 0 {0: 0, 1: 1, 2: 2}\n",
222+
"[0.00297317 0.01261126 0.98441558] 2 {0: 0, 1: 1, 2: 2}\n",
223+
"0.37777777777777777\n"
224+
]
225+
}
226+
],
227+
"source": [
228+
"from sklearn.metrics import accuracy_score\n",
229+
"maxent = MaxEnt()\n",
230+
"maxent.fit(X_train, y_train)\n",
231+
"y_pred = maxent.predict(X_test)\n",
232+
"print(accuracy_score(y_test, y_pred))"
233+
]
234+
},
235+
{
236+
"cell_type": "code",
237+
"execution_count": null,
238+
"metadata": {},
239+
"outputs": [],
240+
"source": []
241+
}
242+
],
243+
"metadata": {
244+
"kernelspec": {
245+
"display_name": "Python 3",
246+
"language": "python",
247+
"name": "python3"
248+
},
249+
"language_info": {
250+
"codemirror_mode": {
251+
"name": "ipython",
252+
"version": 3
253+
},
254+
"file_extension": ".py",
255+
"mimetype": "text/x-python",
256+
"name": "python",
257+
"nbconvert_exporter": "python",
258+
"pygments_lexer": "ipython3",
259+
"version": "3.7.3"
260+
},
261+
"toc": {
262+
"base_numbering": 1,
263+
"nav_menu": {},
264+
"number_sections": true,
265+
"sideBar": true,
266+
"skip_h1_title": false,
267+
"title_cell": "Table of Contents",
268+
"title_sidebar": "Contents",
269+
"toc_cell": false,
270+
"toc_position": {},
271+
"toc_section_display": true,
272+
"toc_window_display": false
273+
}
274+
},
275+
"nbformat": 4,
276+
"nbformat_minor": 2
277+
}

0 commit comments

Comments
 (0)