Skip to content

Commit a76fd0f

Browse files
authored
CRF based sklearn_crfsuite.
1 parent 25323b4 commit a76fd0f

File tree

1 file changed

+350
-0
lines changed

1 file changed

+350
-0
lines changed

charpter24_CRF/crf.ipynb

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"### CRF"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"基于sklearn_crfsuite NER系统搭建,本例来自于sklearn_crfsuite官方tutorial"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": 3,
20+
"metadata": {},
21+
"outputs": [],
22+
"source": [
23+
"# 导入相关库\n",
24+
"import nltk\n",
25+
"import sklearn\n",
26+
"import scipy.stats\n",
27+
"from sklearn.metrics import make_scorer\n",
28+
"from sklearn.model_selection import cross_val_score\n",
29+
"from sklearn.model_selection import RandomizedSearchCV\n",
30+
"\n",
31+
"import sklearn_crfsuite\n",
32+
"from sklearn_crfsuite import scorers\n",
33+
"from sklearn_crfsuite import metrics"
34+
]
35+
},
36+
{
37+
"cell_type": "code",
38+
"execution_count": 5,
39+
"metadata": {},
40+
"outputs": [
41+
{
42+
"name": "stderr",
43+
"output_type": "stream",
44+
"text": [
45+
"[nltk_data] Downloading package conll2002 to\n",
46+
"[nltk_data] C:\\Users\\92070\\AppData\\Roaming\\nltk_data...\n",
47+
"[nltk_data] Unzipping corpora\\conll2002.zip.\n"
48+
]
49+
},
50+
{
51+
"data": {
52+
"text/plain": [
53+
"True"
54+
]
55+
},
56+
"execution_count": 5,
57+
"metadata": {},
58+
"output_type": "execute_result"
59+
}
60+
],
61+
"source": [
62+
"# 基于NLTK下载示例数据集\n",
63+
"nltk.download('conll2002')"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": 6,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"# 设置训练和测试样本\n",
73+
"train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n",
74+
"test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 7,
80+
"metadata": {},
81+
"outputs": [
82+
{
83+
"data": {
84+
"text/plain": [
85+
"[('Melbourne', 'NP', 'B-LOC'),\n",
86+
" ('(', 'Fpa', 'O'),\n",
87+
" ('Australia', 'NP', 'B-LOC'),\n",
88+
" (')', 'Fpt', 'O'),\n",
89+
" (',', 'Fc', 'O'),\n",
90+
" ('25', 'Z', 'O'),\n",
91+
" ('may', 'NC', 'O'),\n",
92+
" ('(', 'Fpa', 'O'),\n",
93+
" ('EFE', 'NC', 'B-ORG'),\n",
94+
" (')', 'Fpt', 'O'),\n",
95+
" ('.', 'Fp', 'O')]"
96+
]
97+
},
98+
"execution_count": 7,
99+
"metadata": {},
100+
"output_type": "execute_result"
101+
}
102+
],
103+
"source": [
104+
"train_sents[0]"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 8,
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"# 单词转化为数值特征\n",
114+
"def word2features(sent, i):\n",
115+
" word = sent[i][0]\n",
116+
" postag = sent[i][1]\n",
117+
"\n",
118+
" features = {\n",
119+
" 'bias': 1.0,\n",
120+
" 'word.lower()': word.lower(),\n",
121+
" 'word[-3:]': word[-3:],\n",
122+
" 'word[-2:]': word[-2:],\n",
123+
" 'word.isupper()': word.isupper(),\n",
124+
" 'word.istitle()': word.istitle(),\n",
125+
" 'word.isdigit()': word.isdigit(),\n",
126+
" 'postag': postag,\n",
127+
" 'postag[:2]': postag[:2],\n",
128+
" }\n",
129+
" if i > 0:\n",
130+
" word1 = sent[i-1][0]\n",
131+
" postag1 = sent[i-1][1]\n",
132+
" features.update({\n",
133+
" '-1:word.lower()': word1.lower(),\n",
134+
" '-1:word.istitle()': word1.istitle(),\n",
135+
" '-1:word.isupper()': word1.isupper(),\n",
136+
" '-1:postag': postag1,\n",
137+
" '-1:postag[:2]': postag1[:2],\n",
138+
" })\n",
139+
" else:\n",
140+
" features['BOS'] = True\n",
141+
"\n",
142+
" if i < len(sent)-1:\n",
143+
" word1 = sent[i+1][0]\n",
144+
" postag1 = sent[i+1][1]\n",
145+
" features.update({\n",
146+
" '+1:word.lower()': word1.lower(),\n",
147+
" '+1:word.istitle()': word1.istitle(),\n",
148+
" '+1:word.isupper()': word1.isupper(),\n",
149+
" '+1:postag': postag1,\n",
150+
" '+1:postag[:2]': postag1[:2],\n",
151+
" })\n",
152+
" else:\n",
153+
" features['EOS'] = True\n",
154+
"\n",
155+
" return features\n",
156+
"\n",
157+
"\n",
158+
"def sent2features(sent):\n",
159+
" return [word2features(sent, i) for i in range(len(sent))]\n",
160+
"\n",
161+
"def sent2labels(sent):\n",
162+
" return [label for token, postag, label in sent]\n",
163+
"\n",
164+
"def sent2tokens(sent):\n",
165+
" return [token for token, postag, label in sent]"
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": 9,
171+
"metadata": {},
172+
"outputs": [
173+
{
174+
"data": {
175+
"text/plain": [
176+
"{'bias': 1.0,\n",
177+
" 'word.lower()': 'melbourne',\n",
178+
" 'word[-3:]': 'rne',\n",
179+
" 'word[-2:]': 'ne',\n",
180+
" 'word.isupper()': False,\n",
181+
" 'word.istitle()': True,\n",
182+
" 'word.isdigit()': False,\n",
183+
" 'postag': 'NP',\n",
184+
" 'postag[:2]': 'NP',\n",
185+
" 'BOS': True,\n",
186+
" '+1:word.lower()': '(',\n",
187+
" '+1:word.istitle()': False,\n",
188+
" '+1:word.isupper()': False,\n",
189+
" '+1:postag': 'Fpa',\n",
190+
" '+1:postag[:2]': 'Fp'}"
191+
]
192+
},
193+
"execution_count": 9,
194+
"metadata": {},
195+
"output_type": "execute_result"
196+
}
197+
],
198+
"source": [
199+
"sent2features(train_sents[0])[0]"
200+
]
201+
},
202+
{
203+
"cell_type": "code",
204+
"execution_count": 10,
205+
"metadata": {},
206+
"outputs": [],
207+
"source": [
208+
"# 构造训练集和测试集\n",
209+
"X_train = [sent2features(s) for s in train_sents]\n",
210+
"y_train = [sent2labels(s) for s in train_sents]\n",
211+
"\n",
212+
"X_test = [sent2features(s) for s in test_sents]\n",
213+
"y_test = [sent2labels(s) for s in test_sents]"
214+
]
215+
},
216+
{
217+
"cell_type": "code",
218+
"execution_count": 11,
219+
"metadata": {},
220+
"outputs": [
221+
{
222+
"name": "stdout",
223+
"output_type": "stream",
224+
"text": [
225+
"8323 1517\n"
226+
]
227+
}
228+
],
229+
"source": [
230+
"print(len(X_train), len(X_test))"
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": 18,
236+
"metadata": {},
237+
"outputs": [
238+
{
239+
"data": {
240+
"text/plain": [
241+
"0.7964686316443963"
242+
]
243+
},
244+
"execution_count": 18,
245+
"metadata": {},
246+
"output_type": "execute_result"
247+
}
248+
],
249+
"source": [
250+
"# 创建CRF模型实例\n",
251+
"crf = sklearn_crfsuite.CRF(\n",
252+
" algorithm='lbfgs',\n",
253+
" c1=0.1,\n",
254+
" c2=0.1,\n",
255+
" max_iterations=100,\n",
256+
" all_possible_transitions=True\n",
257+
")\n",
258+
"# 模型训练\n",
259+
"crf.fit(X_train, y_train)\n",
260+
"# 类别标签\n",
261+
"labels = list(crf.classes_)\n",
262+
"labels.remove('O')\n",
263+
"# 模型预测\n",
264+
"y_pred = crf.predict(X_test)\n",
265+
"# 计算F1得分\n",
266+
"metrics.flat_f1_score(y_test, y_pred,\n",
267+
" average='weighted', labels=labels)"
268+
]
269+
},
270+
{
271+
"cell_type": "code",
272+
"execution_count": 19,
273+
"metadata": {},
274+
"outputs": [
275+
{
276+
"name": "stdout",
277+
"output_type": "stream",
278+
"text": [
279+
" precision recall f1-score support\n",
280+
"\n",
281+
" B-LOC 0.810 0.784 0.797 1084\n",
282+
" I-LOC 0.690 0.637 0.662 325\n",
283+
" B-MISC 0.731 0.569 0.640 339\n",
284+
" I-MISC 0.699 0.589 0.639 557\n",
285+
" B-ORG 0.807 0.832 0.820 1400\n",
286+
" I-ORG 0.852 0.786 0.818 1104\n",
287+
" B-PER 0.850 0.884 0.867 735\n",
288+
" I-PER 0.893 0.943 0.917 634\n",
289+
"\n",
290+
" micro avg 0.813 0.787 0.799 6178\n",
291+
" macro avg 0.791 0.753 0.770 6178\n",
292+
"weighted avg 0.809 0.787 0.796 6178\n",
293+
"\n"
294+
]
295+
}
296+
],
297+
"source": [
298+
"# 打印B和I组的模型结果\n",
299+
"sorted_labels = sorted(\n",
300+
" labels,\n",
301+
" key=lambda name: (name[1:], name[0])\n",
302+
")\n",
303+
"print(metrics.flat_classification_report(\n",
304+
" y_test, y_pred, labels=sorted_labels, digits=3\n",
305+
"))"
306+
]
307+
},
308+
{
309+
"cell_type": "code",
310+
"execution_count": null,
311+
"metadata": {},
312+
"outputs": [],
313+
"source": []
314+
}
315+
],
316+
"metadata": {
317+
"kernelspec": {
318+
"display_name": "Python 3",
319+
"language": "python",
320+
"name": "python3"
321+
},
322+
"language_info": {
323+
"codemirror_mode": {
324+
"name": "ipython",
325+
"version": 3
326+
},
327+
"file_extension": ".py",
328+
"mimetype": "text/x-python",
329+
"name": "python",
330+
"nbconvert_exporter": "python",
331+
"pygments_lexer": "ipython3",
332+
"version": "3.7.3"
333+
},
334+
"toc": {
335+
"base_numbering": 1,
336+
"nav_menu": {},
337+
"number_sections": true,
338+
"sideBar": true,
339+
"skip_h1_title": false,
340+
"title_cell": "Table of Contents",
341+
"title_sidebar": "Contents",
342+
"toc_cell": false,
343+
"toc_position": {},
344+
"toc_section_display": true,
345+
"toc_window_display": false
346+
}
347+
},
348+
"nbformat": 4,
349+
"nbformat_minor": 2
350+
}

0 commit comments

Comments
 (0)