Skip to content

Commit b818c85

Browse files
Merge pull request falloutdurham#50 from MarcusFra/ch9_FastBERT
Add FastBERT --> chapter9/Fast_bert_.ipynb
2 parents 528157c + 8698185 commit b818c85

File tree

1 file changed

+224
-0
lines changed

1 file changed

+224
-0
lines changed

chapter9/Fast_bert_.ipynb

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"collapsed": true,
7+
"pycharm": {
8+
"name": "#%% md\n"
9+
}
10+
},
11+
"source": [
12+
"## FastBERT"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"outputs": [],
19+
"source": [
20+
"# %pip (rather than !pip) installs into the running kernel's own environment\n",
"%pip install fast-bert"
],
22+
"metadata": {
23+
"collapsed": false,
24+
"pycharm": {
25+
"name": "#%%\n"
26+
}
27+
}
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": null,
32+
"outputs": [],
33+
"source": [
34+
"import logging\n",
35+
"import numpy as np\n",
36+
"import pandas as pd\n",
37+
"import torch\n",
38+
"\n",
39+
"from transformers import BertTokenizer\n",
40+
"from fast_bert.data_cls import BertDataBunch\n",
41+
"from fast_bert.learner_cls import BertLearner\n",
42+
"from fast_bert.metrics import accuracy"
43+
],
44+
"metadata": {
45+
"collapsed": false,
46+
"pycharm": {
47+
"name": "#%%\n"
48+
}
49+
}
50+
},
51+
{
52+
"cell_type": "markdown",
53+
"source": [
54+
"Before execution, create the required directories with `mkdir twitterdata labels`.\n",
55+
"\n",
56+
"Then set the paths:"
57+
],
58+
"metadata": {
59+
"collapsed": false
60+
}
61+
},
62+
{
63+
"cell_type": "code",
64+
"execution_count": null,
65+
"outputs": [],
66+
"source": [
"import os\n",
"\n",
67+
"PATH_TO_DATA = \"./twitterdata/\"\n",
68+
"PATH_TO_LABELS = \"./labels/\"\n",
69+
"OUTPUT_DIR = \"./\"\n",
"\n",
"# Create the data/label directories if they do not exist yet, so the\n",
"# later to_csv() calls cannot fail on a fresh checkout (the manual\n",
"# `mkdir twitterdata labels` step becomes optional).\n",
"os.makedirs(PATH_TO_DATA, exist_ok=True)\n",
"os.makedirs(PATH_TO_LABELS, exist_ok=True)"
],
71+
"metadata": {
72+
"collapsed": false,
73+
"pycharm": {
74+
"name": "#%%\n"
75+
}
76+
}
77+
},
78+
{
79+
"cell_type": "markdown",
80+
"source": [
81+
"Read the relevant data from Chapter 5, split it into train/validation/test sets (60/20/20), and save each split as a CSV file"
82+
],
83+
"metadata": {
84+
"collapsed": false
85+
}
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"outputs": [],
91+
"source": [
92+
"df = pd.read_csv('../chapter5/train-processed.csv', encoding='latin-1')\n",
93+
"df = df.drop(df.columns[[0, 1, 2, 3, 4, 6]], axis=1)\n",
94+
"df.columns = ['text', 'label']\n",
95+
"\n",
96+
"# https://stackoverflow.com/questions/38250710/\n",
97+
"# how-to-split-data-into-3-sets-train-validation-and-test/38251213#38251213\n",
98+
"np.random.seed(0)\n",
99+
"train, valid, test = \\\n",
100+
" np.split(df.sample(frac=1), [int(.6*len(df)), int(.8*len(df))])\n",
101+
"\n",
"# Write the splits under PATH_TO_DATA instead of repeating the literal\n",
"# './twitterdata/' path, so the location is configured in one place.\n",
102+
"train.to_csv(PATH_TO_DATA + 'train.csv', index=False)\n",
103+
"valid.to_csv(PATH_TO_DATA + 'valid.csv', index=False)\n",
104+
"test.to_csv(PATH_TO_DATA + 'test.csv', index=False)"
],
106+
"metadata": {
107+
"collapsed": false,
108+
"pycharm": {
109+
"name": "#%%\n"
110+
}
111+
}
112+
},
113+
{
114+
"cell_type": "markdown",
115+
"source": [
116+
"Get the unique labels and save them as a CSV file in the separate `labels` directory (`PATH_TO_LABELS`)"
117+
],
118+
"metadata": {
119+
"collapsed": false
120+
}
121+
},
122+
{
123+
"cell_type": "code",
124+
"source": [
125+
"labels = pd.DataFrame(df.label.unique())\n",
"# Use the PATH_TO_LABELS constant rather than repeating './labels/'\n",
126+
"labels.to_csv(PATH_TO_LABELS + \"labels.csv\", header=False, index=False)"
],
128+
"metadata": {
129+
"collapsed": false,
130+
"pycharm": {
131+
"name": "#%%\n"
132+
}
133+
},
134+
"execution_count": null,
135+
"outputs": []
136+
},
137+
{
138+
"cell_type": "markdown",
139+
"source": [
140+
"Define and train model"
141+
],
142+
"metadata": {
143+
"collapsed": false,
144+
"pycharm": {
145+
"name": "#%% md\n"
146+
}
147+
}
148+
},
149+
{
150+
"cell_type": "code",
151+
"source": [
152+
"# Use the GPU when available, otherwise fall back to the CPU —\n",
"# torch.device('cuda') alone raises on CPU-only machines.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
153+
"logger = logging.getLogger()\n",
154+
"metrics = [{'name': 'accuracy', 'function': accuracy}]\n",
155+
"\n",
156+
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',\n",
157+
" do_lower_case=True)\n",
158+
"\n",
159+
"databunch = BertDataBunch(PATH_TO_DATA,\n",
160+
" PATH_TO_LABELS,\n",
161+
" tokenizer,\n",
162+
" train_file=\"train.csv\",\n",
163+
" val_file=\"valid.csv\",\n",
164+
" test_data=\"test.csv\",\n",
165+
" text_col=0, label_col=1,\n",
166+
" batch_size_per_gpu=32,\n",
167+
" max_seq_length=140,\n",
168+
" multi_gpu=False,\n",
169+
" multi_label=False,\n",
170+
" model_type=\"bert\")\n",
171+
"\n",
172+
"learner = BertLearner.from_pretrained_model(databunch,\n",
173+
" 'bert-base-uncased',\n",
174+
" metrics=metrics,\n",
175+
" device=device,\n",
176+
" logger=logger,\n",
177+
" output_dir=OUTPUT_DIR,\n",
178+
" is_fp16=False,\n",
179+
" multi_gpu=False,\n",
180+
" multi_label=False)\n",
181+
"\n",
"# Fine-tuning a pretrained BERT model needs a small learning rate:\n",
"# the BERT authors recommend 2e-5 to 5e-5. The previous lr=1e-2 is\n",
"# large enough to destroy the pretrained weights during fine-tuning.\n",
182+
"learner.fit(3, lr=3e-5)\n"
],
184+
"metadata": {
185+
"collapsed": false,
186+
"pycharm": {
187+
"name": "#%%\n"
188+
}
189+
},
190+
"execution_count": null,
191+
"outputs": []
192+
},
193+
{
194+
"cell_type": "markdown",
195+
"source": [
196+
"\n"
197+
],
198+
"metadata": {
199+
"collapsed": false
200+
}
201+
}
202+
],
203+
"metadata": {
204+
"kernelspec": {
205+
"display_name": "Python 3",
206+
"language": "python",
207+
"name": "python3"
208+
},
209+
"language_info": {
210+
"codemirror_mode": {
211+
"name": "ipython",
212+
"version": 2
213+
},
214+
"file_extension": ".py",
215+
"mimetype": "text/x-python",
216+
"name": "python",
217+
"nbconvert_exporter": "python",
218+
"pygments_lexer": "ipython2",
219+
"version": "2.7.6"
220+
}
221+
},
222+
"nbformat": 4,
223+
"nbformat_minor": 0
224+
}

0 commit comments

Comments
 (0)