| 
 | 1 | +{  | 
 | 2 | + "cells": [  | 
 | 3 | +  {  | 
 | 4 | +   "cell_type": "markdown",  | 
 | 5 | +   "metadata": {},  | 
 | 6 | +   "source": [  | 
 | 7 | +    "### CRF"  | 
 | 8 | +   ]  | 
 | 9 | +  },  | 
 | 10 | +  {  | 
 | 11 | +   "cell_type": "markdown",  | 
 | 12 | +   "metadata": {},  | 
 | 13 | +   "source": [  | 
 | 14 | +    "基于 sklearn_crfsuite 搭建命名实体识别(NER)系统。本例来自 sklearn_crfsuite 官方教程(tutorial)。"  | 
 | 15 | +   ]  | 
 | 16 | +  },  | 
 | 17 | +  {  | 
 | 18 | +   "cell_type": "code",  | 
 | 19 | +   "execution_count": 3,  | 
 | 20 | +   "metadata": {},  | 
 | 21 | +   "outputs": [],  | 
 | 22 | +   "source": [  | 
 | 23 | +    "# 导入相关库\n",  | 
 | 24 | +    "import nltk\n",  | 
 | 25 | +    "import sklearn\n",  | 
 | 26 | +    "import scipy.stats\n",  | 
 | 27 | +    "from sklearn.metrics import make_scorer\n",  | 
 | 28 | +    "from sklearn.model_selection import cross_val_score\n",  | 
 | 29 | +    "from sklearn.model_selection import RandomizedSearchCV\n",  | 
 | 30 | +    "\n",  | 
 | 31 | +    "import sklearn_crfsuite\n",  | 
 | 32 | +    "from sklearn_crfsuite import scorers\n",  | 
 | 33 | +    "from sklearn_crfsuite import metrics"  | 
 | 34 | +   ]  | 
 | 35 | +  },  | 
 | 36 | +  {  | 
 | 37 | +   "cell_type": "code",  | 
 | 38 | +   "execution_count": 5,  | 
 | 39 | +   "metadata": {},  | 
 | 40 | +   "outputs": [  | 
 | 41 | +    {  | 
 | 42 | +     "name": "stderr",  | 
 | 43 | +     "output_type": "stream",  | 
 | 44 | +     "text": [  | 
 | 45 | +      "[nltk_data] Downloading package conll2002 to\n",  | 
 | 46 | +      "[nltk_data]     C:\\Users\\92070\\AppData\\Roaming\\nltk_data...\n",  | 
 | 47 | +      "[nltk_data]   Unzipping corpora\\conll2002.zip.\n"  | 
 | 48 | +     ]  | 
 | 49 | +    },  | 
 | 50 | +    {  | 
 | 51 | +     "data": {  | 
 | 52 | +      "text/plain": [  | 
 | 53 | +       "True"  | 
 | 54 | +      ]  | 
 | 55 | +     },  | 
 | 56 | +     "execution_count": 5,  | 
 | 57 | +     "metadata": {},  | 
 | 58 | +     "output_type": "execute_result"  | 
 | 59 | +    }  | 
 | 60 | +   ],  | 
 | 61 | +   "source": [  | 
 | 62 | +    "# Download the example dataset (conll2002) through NLTK\n",  | 
 | 63 | +    "nltk.download('conll2002')"  | 
 | 64 | +   ]  | 
 | 65 | +  },  | 
 | 66 | +  {  | 
 | 67 | +   "cell_type": "code",  | 
 | 68 | +   "execution_count": 6,  | 
 | 69 | +   "metadata": {},  | 
 | 70 | +   "outputs": [],  | 
 | 71 | +   "source": [  | 
 | 72 | +    "# Load the training and test sentences from the corpus files 'esp.train' / 'esp.testb'\n",  | 
 | 73 | +    "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n",  | 
 | 74 | +    "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))"  | 
 | 75 | +   ]  | 
 | 76 | +  },  | 
 | 77 | +  {  | 
 | 78 | +   "cell_type": "code",  | 
 | 79 | +   "execution_count": 7,  | 
 | 80 | +   "metadata": {},  | 
 | 81 | +   "outputs": [  | 
 | 82 | +    {  | 
 | 83 | +     "data": {  | 
 | 84 | +      "text/plain": [  | 
 | 85 | +       "[('Melbourne', 'NP', 'B-LOC'),\n",  | 
 | 86 | +       " ('(', 'Fpa', 'O'),\n",  | 
 | 87 | +       " ('Australia', 'NP', 'B-LOC'),\n",  | 
 | 88 | +       " (')', 'Fpt', 'O'),\n",  | 
 | 89 | +       " (',', 'Fc', 'O'),\n",  | 
 | 90 | +       " ('25', 'Z', 'O'),\n",  | 
 | 91 | +       " ('may', 'NC', 'O'),\n",  | 
 | 92 | +       " ('(', 'Fpa', 'O'),\n",  | 
 | 93 | +       " ('EFE', 'NC', 'B-ORG'),\n",  | 
 | 94 | +       " (')', 'Fpt', 'O'),\n",  | 
 | 95 | +       " ('.', 'Fp', 'O')]"  | 
 | 96 | +      ]  | 
 | 97 | +     },  | 
 | 98 | +     "execution_count": 7,  | 
 | 99 | +     "metadata": {},  | 
 | 100 | +     "output_type": "execute_result"  | 
 | 101 | +    }  | 
 | 102 | +   ],  | 
 | 103 | +   "source": [  | 
 | 104 | +    "train_sents[0]  # first training sentence: a list of (token, POS tag, IOB label) triples"  | 
 | 105 | +   ]  | 
 | 106 | +  },  | 
 | 107 | +  {  | 
 | 108 | +   "cell_type": "code",  | 
 | 109 | +   "execution_count": 8,  | 
 | 110 | +   "metadata": {},  | 
 | 111 | +   "outputs": [],  | 
 | 112 | +   "source": [  | 
 | 113 | +    "# Turn one token (plus its neighbours) into a CRF feature dict\n",  | 
 | 114 | +    "def word2features(sent, i):\n",  | 
 | 115 | +    "    \"\"\"Build the feature dict for the token at index ``i`` of ``sent``.\n",  | 
 | 116 | +    "\n",  | 
 | 117 | +    "    ``sent[k][0]`` is read as the word and ``sent[k][1]`` as its POS tag.\n",  | 
 | 118 | +    "    The previous / next tokens add '-1:' / '+1:' features; at the sentence\n",  | 
 | 119 | +    "    edges a 'BOS' / 'EOS' flag is set instead.\n",  | 
 | 120 | +    "    \"\"\"\n",  | 
 | 121 | +    "    token, tag = sent[i][0], sent[i][1]\n",  | 
 | 122 | +    "    feats = {\n",  | 
 | 123 | +    "        'bias': 1.0,\n",  | 
 | 124 | +    "        'word.lower()': token.lower(),\n",  | 
 | 125 | +    "        'word[-3:]': token[-3:],\n",  | 
 | 126 | +    "        'word[-2:]': token[-2:],\n",  | 
 | 127 | +    "        'word.isupper()': token.isupper(),\n",  | 
 | 128 | +    "        'word.istitle()': token.istitle(),\n",  | 
 | 129 | +    "        'word.isdigit()': token.isdigit(),\n",  | 
 | 130 | +    "        'postag': tag,\n",  | 
 | 131 | +    "        'postag[:2]': tag[:2],\n",  | 
 | 132 | +    "    }\n",  | 
 | 133 | +    "\n",  | 
 | 134 | +    "    # Left context: previous token, or a beginning-of-sentence flag.\n",  | 
 | 135 | +    "    if i <= 0:\n",  | 
 | 136 | +    "        feats['BOS'] = True\n",  | 
 | 137 | +    "    else:\n",  | 
 | 138 | +    "        prev_token, prev_tag = sent[i-1][0], sent[i-1][1]\n",  | 
 | 139 | +    "        feats['-1:word.lower()'] = prev_token.lower()\n",  | 
 | 140 | +    "        feats['-1:word.istitle()'] = prev_token.istitle()\n",  | 
 | 141 | +    "        feats['-1:word.isupper()'] = prev_token.isupper()\n",  | 
 | 142 | +    "        feats['-1:postag'] = prev_tag\n",  | 
 | 143 | +    "        feats['-1:postag[:2]'] = prev_tag[:2]\n",  | 
 | 144 | +    "\n",  | 
 | 145 | +    "    # Right context: next token, or an end-of-sentence flag.\n",  | 
 | 146 | +    "    if i >= len(sent) - 1:\n",  | 
 | 147 | +    "        feats['EOS'] = True\n",  | 
 | 148 | +    "    else:\n",  | 
 | 149 | +    "        next_token, next_tag = sent[i+1][0], sent[i+1][1]\n",  | 
 | 150 | +    "        feats['+1:word.lower()'] = next_token.lower()\n",  | 
 | 151 | +    "        feats['+1:word.istitle()'] = next_token.istitle()\n",  | 
 | 152 | +    "        feats['+1:word.isupper()'] = next_token.isupper()\n",  | 
 | 153 | +    "        feats['+1:postag'] = next_tag\n",  | 
 | 154 | +    "        feats['+1:postag[:2]'] = next_tag[:2]\n",  | 
 | 155 | +    "    return feats\n",  | 
 | 156 | +    "\n",  | 
 | 157 | +    "\n",  | 
 | 158 | +    "def sent2features(sent):\n",  | 
 | 159 | +    "    \"\"\"Return one feature dict per token of ``sent``, in token order.\"\"\"\n",  | 
 | 160 | +    "    per_token = []\n",  | 
 | 161 | +    "    for position in range(len(sent)):\n",  | 
 | 162 | +    "        per_token.append(word2features(sent, position))\n",  | 
 | 163 | +    "    return per_token\n",  | 
 | 160 | +    "\n",  | 
 | 161 | +    "def sent2labels(sent):\n",  | 
 | 162 | +    "    \"\"\"Return the label (third field) of each (token, postag, label) triple in ``sent``.\"\"\"\n",  | 
 | 163 | +    "    tags = []\n",  | 
 | 164 | +    "    for _token, _postag, tag in sent:\n",  | 
 | 165 | +    "        tags.append(tag)\n",  | 
 | 166 | +    "    return tags\n",  | 
 | 163 | +    "\n",  | 
 | 164 | +    "def sent2tokens(sent):\n",  | 
 | 165 | +    "    \"\"\"Return the token (first field) of each (token, postag, label) triple in ``sent``.\"\"\"\n",  | 
 | 166 | +    "    words = []\n",  | 
 | 167 | +    "    for word, _postag, _label in sent:\n",  | 
 | 168 | +    "        words.append(word)\n",  | 
 | 169 | +    "    return words"  | 
 | 166 | +   ]  | 
 | 167 | +  },  | 
 | 168 | +  {  | 
 | 169 | +   "cell_type": "code",  | 
 | 170 | +   "execution_count": 9,  | 
 | 171 | +   "metadata": {},  | 
 | 172 | +   "outputs": [  | 
 | 173 | +    {  | 
 | 174 | +     "data": {  | 
 | 175 | +      "text/plain": [  | 
 | 176 | +       "{'bias': 1.0,\n",  | 
 | 177 | +       " 'word.lower()': 'melbourne',\n",  | 
 | 178 | +       " 'word[-3:]': 'rne',\n",  | 
 | 179 | +       " 'word[-2:]': 'ne',\n",  | 
 | 180 | +       " 'word.isupper()': False,\n",  | 
 | 181 | +       " 'word.istitle()': True,\n",  | 
 | 182 | +       " 'word.isdigit()': False,\n",  | 
 | 183 | +       " 'postag': 'NP',\n",  | 
 | 184 | +       " 'postag[:2]': 'NP',\n",  | 
 | 185 | +       " 'BOS': True,\n",  | 
 | 186 | +       " '+1:word.lower()': '(',\n",  | 
 | 187 | +       " '+1:word.istitle()': False,\n",  | 
 | 188 | +       " '+1:word.isupper()': False,\n",  | 
 | 189 | +       " '+1:postag': 'Fpa',\n",  | 
 | 190 | +       " '+1:postag[:2]': 'Fp'}"  | 
 | 191 | +      ]  | 
 | 192 | +     },  | 
 | 193 | +     "execution_count": 9,  | 
 | 194 | +     "metadata": {},  | 
 | 195 | +     "output_type": "execute_result"  | 
 | 196 | +    }  | 
 | 197 | +   ],  | 
 | 198 | +   "source": [  | 
 | 199 | +    "sent2features(train_sents[0])[0]  # feature dict of the first token of the first training sentence"  | 
 | 200 | +   ]  | 
 | 201 | +  },  | 
 | 202 | +  {  | 
 | 203 | +   "cell_type": "code",  | 
 | 204 | +   "execution_count": 10,  | 
 | 205 | +   "metadata": {},  | 
 | 206 | +   "outputs": [],  | 
 | 207 | +   "source": [  | 
 | 208 | +    "# Build the feature sequences (X) and label sequences (y) for the train and test sets\n",  | 
 | 209 | +    "X_train = list(map(sent2features, train_sents))\n",  | 
 | 210 | +    "y_train = list(map(sent2labels, train_sents))\n",  | 
 | 211 | +    "\n",  | 
 | 212 | +    "X_test = list(map(sent2features, test_sents))\n",  | 
 | 213 | +    "y_test = list(map(sent2labels, test_sents))"  | 
 | 214 | +   ]  | 
 | 215 | +  },  | 
 | 216 | +  {  | 
 | 217 | +   "cell_type": "code",  | 
 | 218 | +   "execution_count": 11,  | 
 | 219 | +   "metadata": {},  | 
 | 220 | +   "outputs": [  | 
 | 221 | +    {  | 
 | 222 | +     "name": "stdout",  | 
 | 223 | +     "output_type": "stream",  | 
 | 224 | +     "text": [  | 
 | 225 | +      "8323 1517\n"  | 
 | 226 | +     ]  | 
 | 227 | +    }  | 
 | 228 | +   ],  | 
 | 229 | +   "source": [  | 
 | 230 | +    "print(len(X_train), len(X_test))  # number of training and test sentences"  | 
 | 231 | +   ]  | 
 | 232 | +  },  | 
 | 233 | +  {  | 
 | 234 | +   "cell_type": "code",  | 
 | 235 | +   "execution_count": 18,  | 
 | 236 | +   "metadata": {},  | 
 | 237 | +   "outputs": [  | 
 | 238 | +    {  | 
 | 239 | +     "data": {  | 
 | 240 | +      "text/plain": [  | 
 | 241 | +       "0.7964686316443963"  | 
 | 242 | +      ]  | 
 | 243 | +     },  | 
 | 244 | +     "execution_count": 18,  | 
 | 245 | +     "metadata": {},  | 
 | 246 | +     "output_type": "execute_result"  | 
 | 247 | +    }  | 
 | 248 | +   ],  | 
 | 249 | +   "source": [  | 
 | 250 | +    "# Create the CRF model instance\n",  | 
 | 251 | +    "crf = sklearn_crfsuite.CRF(\n",  | 
 | 252 | +    "    algorithm='lbfgs',\n",  | 
 | 253 | +    "    c1=0.1,\n",  | 
 | 254 | +    "    c2=0.1,\n",  | 
 | 255 | +    "    max_iterations=100,\n",  | 
 | 256 | +    "    all_possible_transitions=True\n",  | 
 | 257 | +    ")\n",  | 
 | 258 | +    "# Train the model\n",  | 
 | 259 | +    "crf.fit(X_train, y_train)\n",  | 
 | 260 | +    "# Class labels reported by the fitted model (crf.classes_), with 'O' removed so it is excluded from the score below\n",  | 
 | 261 | +    "labels = list(crf.classes_)\n",  | 
 | 262 | +    "labels.remove('O')\n",  | 
 | 263 | +    "# Predict labels for the test set\n",  | 
 | 264 | +    "y_pred = crf.predict(X_test)\n",  | 
 | 265 | +    "# Compute the weighted F1 score over the remaining labels\n",  | 
 | 266 | +    "metrics.flat_f1_score(y_test, y_pred,\n",  | 
 | 267 | +    "                      average='weighted', labels=labels)"  | 
 | 268 | +   ]  | 
 | 269 | +  },  | 
 | 270 | +  {  | 
 | 271 | +   "cell_type": "code",  | 
 | 272 | +   "execution_count": 19,  | 
 | 273 | +   "metadata": {},  | 
 | 274 | +   "outputs": [  | 
 | 275 | +    {  | 
 | 276 | +     "name": "stdout",  | 
 | 277 | +     "output_type": "stream",  | 
 | 278 | +     "text": [  | 
 | 279 | +      "              precision    recall  f1-score   support\n",  | 
 | 280 | +      "\n",  | 
 | 281 | +      "       B-LOC      0.810     0.784     0.797      1084\n",  | 
 | 282 | +      "       I-LOC      0.690     0.637     0.662       325\n",  | 
 | 283 | +      "      B-MISC      0.731     0.569     0.640       339\n",  | 
 | 284 | +      "      I-MISC      0.699     0.589     0.639       557\n",  | 
 | 285 | +      "       B-ORG      0.807     0.832     0.820      1400\n",  | 
 | 286 | +      "       I-ORG      0.852     0.786     0.818      1104\n",  | 
 | 287 | +      "       B-PER      0.850     0.884     0.867       735\n",  | 
 | 288 | +      "       I-PER      0.893     0.943     0.917       634\n",  | 
 | 289 | +      "\n",  | 
 | 290 | +      "   micro avg      0.813     0.787     0.799      6178\n",  | 
 | 291 | +      "   macro avg      0.791     0.753     0.770      6178\n",  | 
 | 292 | +      "weighted avg      0.809     0.787     0.796      6178\n",  | 
 | 293 | +      "\n"  | 
 | 294 | +     ]  | 
 | 295 | +    }  | 
 | 296 | +   ],  | 
 | 297 | +   "source": [  | 
 | 298 | +    "# Print the per-label report; the sort key (name[1:], name[0]) puts the B- and I- labels of one entity type next to each other\n",  | 
 | 299 | +    "sorted_labels = sorted(\n",  | 
 | 300 | +    "    labels,\n",  | 
 | 301 | +    "    key=lambda name: (name[1:], name[0])\n",  | 
 | 302 | +    ")\n",  | 
 | 303 | +    "print(metrics.flat_classification_report(\n",  | 
 | 304 | +    "    y_test, y_pred, labels=sorted_labels, digits=3\n",  | 
 | 305 | +    "))"  | 
 | 306 | +   ]  | 
 | 307 | +  },  | 
 | 308 | +  {  | 
 | 309 | +   "cell_type": "code",  | 
 | 310 | +   "execution_count": null,  | 
 | 311 | +   "metadata": {},  | 
 | 312 | +   "outputs": [],  | 
 | 313 | +   "source": []  | 
 | 314 | +  }  | 
 | 315 | + ],  | 
 | 316 | + "metadata": {  | 
 | 317 | +  "kernelspec": {  | 
 | 318 | +   "display_name": "Python 3",  | 
 | 319 | +   "language": "python",  | 
 | 320 | +   "name": "python3"  | 
 | 321 | +  },  | 
 | 322 | +  "language_info": {  | 
 | 323 | +   "codemirror_mode": {  | 
 | 324 | +    "name": "ipython",  | 
 | 325 | +    "version": 3  | 
 | 326 | +   },  | 
 | 327 | +   "file_extension": ".py",  | 
 | 328 | +   "mimetype": "text/x-python",  | 
 | 329 | +   "name": "python",  | 
 | 330 | +   "nbconvert_exporter": "python",  | 
 | 331 | +   "pygments_lexer": "ipython3",  | 
 | 332 | +   "version": "3.7.3"  | 
 | 333 | +  },  | 
 | 334 | +  "toc": {  | 
 | 335 | +   "base_numbering": 1,  | 
 | 336 | +   "nav_menu": {},  | 
 | 337 | +   "number_sections": true,  | 
 | 338 | +   "sideBar": true,  | 
 | 339 | +   "skip_h1_title": false,  | 
 | 340 | +   "title_cell": "Table of Contents",  | 
 | 341 | +   "title_sidebar": "Contents",  | 
 | 342 | +   "toc_cell": false,  | 
 | 343 | +   "toc_position": {},  | 
 | 344 | +   "toc_section_display": true,  | 
 | 345 | +   "toc_window_display": false  | 
 | 346 | +  }  | 
 | 347 | + },  | 
 | 348 | + "nbformat": 4,  | 
 | 349 | + "nbformat_minor": 2  | 
 | 350 | +}  | 
0 commit comments