Skip to content

Commit 28c4934

Browse files
committed
Add sort_key lambda to make training loop work!
1 parent 64c6d3b commit 28c4934

File tree

1 file changed

+103
-18
lines changed

1 file changed

+103
-18
lines changed

chapter5/Chapter 5.ipynb

Lines changed: 103 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,14 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": null,
12+
"execution_count": 37,
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
1616
"import torch \n",
1717
"import torch.nn as nn\n",
18+
"import torch.nn.functional as F\n",
19+
"import torch.optim as optim\n",
1820
"import numpy as np\n",
1921
"from torchtext import data \n",
2022
"import torchtext\n",
@@ -30,13 +32,21 @@
3032
"## Loading & Data Cleaning"
3133
]
3234
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 8,
38+
"metadata": {},
39+
"outputs": [],
40+
"source": [
41+
"device = \"cuda\"\n"
42+
]
43+
},
3344
{
3445
"cell_type": "code",
3546
"execution_count": null,
3647
"metadata": {},
3748
"outputs": [],
3849
"source": [
39-
"device = \"cuda\"\n",
4050
"# You'll probably need to use the 'python' engine to load the CSV\n",
4151
"# tweetsDF = pd.read_csv(\"training.1600000.processed.noemoticon.csv\", header=None)\n",
4252
"tweetsDF = pd.read_csv(\"training.1600000.processed.noemoticon.csv\", \n",
@@ -66,12 +76,12 @@
6676
},
6777
{
6878
"cell_type": "code",
69-
"execution_count": null,
79+
"execution_count": 38,
7080
"metadata": {},
7181
"outputs": [],
7282
"source": [
7383
"LABEL = data.LabelField()\n",
74-
"TWEET = data.Field(tokenize='spacy', lower=true)\n",
84+
"TWEET = data.Field(tokenize='spacy', lower=True)\n",
7585
"\n",
7686
"fields = [('score',None), ('id',None),('date',None),('query',None),\n",
7787
" ('name',None),\n",
@@ -87,7 +97,7 @@
8797
},
8898
{
8999
"cell_type": "code",
90-
"execution_count": null,
100+
"execution_count": 39,
91101
"metadata": {},
92102
"outputs": [],
93103
"source": [
@@ -100,20 +110,51 @@
100110
},
101111
{
102112
"cell_type": "code",
103-
"execution_count": null,
113+
"execution_count": 40,
104114
"metadata": {},
105-
"outputs": [],
115+
"outputs": [
116+
{
117+
"data": {
118+
"text/plain": [
119+
"(6000, 2000, 2000)"
120+
]
121+
},
122+
"execution_count": 40,
123+
"metadata": {},
124+
"output_type": "execute_result"
125+
}
126+
],
106127
"source": [
107-
"(train, test, valid) = twitterDataset.split(split_ratio=[0.8,0.1,0.1])\n",
128+
"(train, test, valid)=twitterDataset.split(split_ratio=[0.6,0.2,0.2],stratified=True, strata_field='label')\n",
108129
"\n",
109130
"(len(train),len(test),len(valid))"
110131
]
111132
},
112133
{
113134
"cell_type": "code",
114-
"execution_count": null,
135+
"execution_count": 41,
115136
"metadata": {},
116-
"outputs": [],
137+
"outputs": [
138+
{
139+
"data": {
140+
"text/plain": [
141+
"[('i', 3742),\n",
142+
" ('!', 3315),\n",
143+
" ('.', 3084),\n",
144+
" (' ', 2175),\n",
145+
" ('to', 2115),\n",
146+
" ('the', 2022),\n",
147+
" (',', 1823),\n",
148+
" ('a', 1461),\n",
149+
" ('my', 1205),\n",
150+
" ('it', 1197)]"
151+
]
152+
},
153+
"execution_count": 41,
154+
"metadata": {},
155+
"output_type": "execute_result"
156+
}
157+
],
117158
"source": [
118159
"vocab_size = 20000\n",
119160
"TWEET.build_vocab(train, max_size = vocab_size)\n",
@@ -123,14 +164,16 @@
123164
},
124165
{
125166
"cell_type": "code",
126-
"execution_count": null,
167+
"execution_count": 42,
127168
"metadata": {},
128169
"outputs": [],
129170
"source": [
130171
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
131172
"(train, valid, test), \n",
132173
"batch_size = 32,\n",
133-
"device = device)"
174+
"device = device,\n",
175+
"sort_key = lambda x: len(x.tweet),\n",
176+
"sort_within_batch = False)"
134177
]
135178
},
136179
{
@@ -142,9 +185,24 @@
142185
},
143186
{
144187
"cell_type": "code",
145-
"execution_count": null,
188+
"execution_count": 43,
146189
"metadata": {},
147-
"outputs": [],
190+
"outputs": [
191+
{
192+
"data": {
193+
"text/plain": [
194+
"OurFirstLSTM(\n",
195+
" (embedding): Embedding(20002, 300)\n",
196+
" (encoder): LSTM(300, 100)\n",
197+
" (predictor): Linear(in_features=100, out_features=2, bias=True)\n",
198+
")"
199+
]
200+
},
201+
"execution_count": 43,
202+
"metadata": {},
203+
"output_type": "execute_result"
204+
}
205+
],
148206
"source": [
149207
"class OurFirstLSTM(nn.Module):\n",
150208
" def __init__(self, hidden_size, embedding_dim, vocab_size):\n",
@@ -173,7 +231,7 @@
173231
},
174232
{
175233
"cell_type": "code",
176-
"execution_count": null,
234+
"execution_count": 44,
177235
"metadata": {},
178236
"outputs": [],
179237
"source": [
@@ -187,7 +245,7 @@
187245
" valid_loss = 0.0\n",
188246
" model.train()\n",
189247
" for batch_idx, batch in enumerate(train_iterator):\n",
190-
" opt.zero_grad()\n",
248+
" optimizer.zero_grad()\n",
191249
" predict = model(batch.tweet)\n",
192250
" loss = criterion(predict,batch.label)\n",
193251
" loss.backward()\n",
@@ -203,8 +261,28 @@
203261
" valid_loss += loss.data.item() * batch.tweet.size(0)\n",
204262
" \n",
205263
" valid_loss /= len(valid_iterator)\n",
206-
" print('Epoch: {}, Training Loss: {:.2f}, \n",
207-
" Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))"
264+
" print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))"
265+
]
266+
},
267+
{
268+
"cell_type": "code",
269+
"execution_count": 45,
270+
"metadata": {},
271+
"outputs": [
272+
{
273+
"name": "stdout",
274+
"output_type": "stream",
275+
"text": [
276+
"Epoch: 1, Training Loss: 24.47, Validation Loss: 14.04\n",
277+
"Epoch: 2, Training Loss: 23.81, Validation Loss: 14.57\n",
278+
"Epoch: 3, Training Loss: 23.25, Validation Loss: 15.69\n",
279+
"Epoch: 4, Training Loss: 23.12, Validation Loss: 16.16\n",
280+
"Epoch: 5, Training Loss: 21.71, Validation Loss: 18.80\n"
281+
]
282+
}
283+
],
284+
"source": [
285+
"train(5, model, optimizer, criterion, train_iterator, valid_iterator) "
208286
]
209287
},
210288
{
@@ -219,6 +297,13 @@
219297
"execution_count": null,
220298
"metadata": {},
221299
"outputs": [],
300+
"source": []
301+
},
302+
{
303+
"cell_type": "code",
304+
"execution_count": 46,
305+
"metadata": {},
306+
"outputs": [],
222307
"source": [
223308
"def classify_tweet(tweet):\n",
224309
" categories = {0: \"Negative\", 1:\"Positive\"}\n",

0 commit comments

Comments
 (0)