|
9 | 9 | },
|
10 | 10 | {
|
11 | 11 | "cell_type": "code",
|
12 |
| - "execution_count": null, |
| 12 | + "execution_count": 37, |
13 | 13 | "metadata": {},
|
14 | 14 | "outputs": [],
|
15 | 15 | "source": [
|
16 | 16 | "import torch \n",
|
17 | 17 | "import torch.nn as nn\n",
|
| 18 | + "import torch.nn.functional as F\n", |
| 19 | + "import torch.optim as optim\n", |
18 | 20 | "import numpy as np\n",
|
19 | 21 | "from torchtext import data \n",
|
20 | 22 | "import torchtext\n",
|
|
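These imports target the pre-0.9 torchtext API (`data.Field`, `data.LabelField`). For anyone replaying the notebook on a newer release, a hedged compatibility shim (assumption: torchtext 0.9 through 0.11 keep the old classes under `torchtext.legacy`; later releases drop them entirely):

```python
# Compatibility sketch, not part of the commit: resolve the declarative
# torchtext API across the versions before it was removed.
try:
    from torchtext.legacy import data   # torchtext 0.9 - 0.11
except ImportError:
    from torchtext import data          # torchtext < 0.9
```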
30 | 32 | "## Loading & Data Cleaning"
|
31 | 33 | ]
|
32 | 34 | },
|
| 35 | + { |
| 36 | + "cell_type": "code", |
| 37 | + "execution_count": 8, |
| 38 | + "metadata": {}, |
| 39 | + "outputs": [], |
| 40 | + "source": [ |
| 41 | + "device = \"cuda\"\n" |
| 42 | + ] |
| 43 | + }, |
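The new cell pins `device` to `"cuda"`, which fails at tensor-placement time on a CPU-only machine. A minimal sketch of a safer alternative, assuming nothing else depends on the literal string:

```python
import torch

# Sketch: use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
```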
33 | 44 | {
|
34 | 45 | "cell_type": "code",
|
35 | 46 | "execution_count": null,
|
36 | 47 | "metadata": {},
|
37 | 48 | "outputs": [],
|
38 | 49 | "source": [
|
39 |
| - "device = \"cuda\"\n", |
40 | 50 | "# You'll probably need to use the 'python' engine to load the CSV\n",
|
41 | 51 | "# tweetsDF = pd.read_csv(\"training.1600000.processed.noemoticon.csv\", header=None)\n",
|
42 | 52 | "tweetsDF = pd.read_csv(\"training.1600000.processed.noemoticon.csv\", \n",
|
|
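The `read_csv` call is truncated by the hunk. Based on the 'python' engine comment and the Sentiment140 file name, a plausible full call looks like the sketch below; the column names and the latin-1 encoding are assumptions (the names are inferred from the `fields` list in the next cell):

```python
import pandas as pd

# Sketch: Sentiment140 ships as a headerless CSV; the 'python' engine is
# more tolerant of its embedded quoting.
tweetsDF = pd.read_csv(
    "training.1600000.processed.noemoticon.csv",
    engine="python",
    header=None,
    names=["score", "id", "date", "query", "name", "tweet"],
    encoding="latin-1",  # assumption: the dataset's usual encoding
)
```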
66 | 76 | },
|
67 | 77 | {
|
68 | 78 | "cell_type": "code",
|
69 |
| - "execution_count": null, |
| 79 | + "execution_count": 38, |
70 | 80 | "metadata": {},
|
71 | 81 | "outputs": [],
|
72 | 82 | "source": [
|
73 | 83 | "LABEL = data.LabelField()\n",
|
74 |
| - "TWEET = data.Field(tokenize='spacy', lower=true)\n", |
| 84 | + "TWEET = data.Field(tokenize='spacy', lower=True)\n", |
75 | 85 | "\n",
|
76 | 86 | "fields = [('score',None), ('id',None),('date',None),('query',None),\n",
|
77 | 87 | " ('name',None),\n",
|
|
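`fields` pairs each CSV column with a `Field`, with `None` meaning the column is dropped. The hunk cuts the list off after `name`, but the later `sort_key=lambda x: len(x.tweet)` and `strata_field='label'` imply it ends with `('tweet', TWEET)` and `('label', LABEL)` entries. A hedged sketch of how such a list is usually handed to the legacy `TabularDataset` (the path here is hypothetical):

```python
# Sketch: build the dataset from the cleaned CSV; 'processed-tweets.csv'
# is a stand-in name, not a file from the commit.
twitterDataset = data.TabularDataset(
    path="processed-tweets.csv",
    format="CSV",
    fields=fields,
    skip_header=False,
)
```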
87 | 97 | },
|
88 | 98 | {
|
89 | 99 | "cell_type": "code",
|
90 |
| - "execution_count": null, |
| 100 | + "execution_count": 39, |
91 | 101 | "metadata": {},
|
92 | 102 | "outputs": [],
|
93 | 103 | "source": [
|
|
100 | 110 | },
|
101 | 111 | {
|
102 | 112 | "cell_type": "code",
|
103 |
| - "execution_count": null, |
| 113 | + "execution_count": 40, |
104 | 114 | "metadata": {},
|
105 |
| - "outputs": [], |
| 115 | + "outputs": [ |
| 116 | + { |
| 117 | + "data": { |
| 118 | + "text/plain": [ |
| 119 | + "(6000, 2000, 2000)" |
| 120 | + ] |
| 121 | + }, |
| 122 | + "execution_count": 40, |
| 123 | + "metadata": {}, |
| 124 | + "output_type": "execute_result" |
| 125 | + } |
| 126 | + ], |
106 | 127 | "source": [
|
107 |
| - "(train, test, valid) = twitterDataset.split(split_ratio=[0.8,0.1,0.1])\n", |
| 128 | + "(train, test, valid) = twitterDataset.split(split_ratio=[0.6, 0.2, 0.2], stratified=True, strata_field='label')\n", |
108 | 129 | "\n",
|
109 | 130 | "(len(train),len(test),len(valid))"
|
110 | 131 | ]
|
111 | 132 | },
|
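The revised call stratifies on the `label` field, so the 6000/2000/2000 splits shown in the output keep the same positive/negative ratio as the full dataset. A quick sketch to verify that, using only names defined above:

```python
from collections import Counter

# Sketch: each split should report (roughly) the same label balance.
for name, split in (("train", train), ("test", test), ("valid", valid)):
    print(name, Counter(example.label for example in split))
```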
112 | 133 | {
|
113 | 134 | "cell_type": "code",
|
114 |
| - "execution_count": null, |
| 135 | + "execution_count": 41, |
115 | 136 | "metadata": {},
|
116 |
| - "outputs": [], |
| 137 | + "outputs": [ |
| 138 | + { |
| 139 | + "data": { |
| 140 | + "text/plain": [ |
| 141 | + "[('i', 3742),\n", |
| 142 | + " ('!', 3315),\n", |
| 143 | + " ('.', 3084),\n", |
| 144 | + " (' ', 2175),\n", |
| 145 | + " ('to', 2115),\n", |
| 146 | + " ('the', 2022),\n", |
| 147 | + " (',', 1823),\n", |
| 148 | + " ('a', 1461),\n", |
| 149 | + " ('my', 1205),\n", |
| 150 | + " ('it', 1197)]" |
| 151 | + ] |
| 152 | + }, |
| 153 | + "execution_count": 41, |
| 154 | + "metadata": {}, |
| 155 | + "output_type": "execute_result" |
| 156 | + } |
| 157 | + ], |
117 | 158 | "source": [
|
118 | 159 | "vocab_size = 20000\n",
|
119 | 160 | "TWEET.build_vocab(train, max_size = vocab_size)\n",
|
|
123 | 164 | },
|
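`build_vocab(train, max_size=20000)` keeps the 20,000 most frequent training tokens, and the `Embedding(20002, 300)` printed further down confirms that torchtext adds two specials (`<unk>` and `<pad>`) on top of `max_size`. The frequency listing in the output above is consistent with a cell along these lines (the `LABEL.build_vocab` call is an assumption, though the `LabelField` does need a vocab before iteration):

```python
# Sketch: the labels need a vocab too, and the raw counter behind the
# tweet vocabulary reproduces the execute_result above.
LABEL.build_vocab(train)
print(len(TWEET.vocab))                  # 20002 = max_size + 2 specials
print(TWEET.vocab.freqs.most_common(10))
```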
124 | 165 | {
|
125 | 166 | "cell_type": "code",
|
126 |
| - "execution_count": null, |
| 167 | + "execution_count": 42, |
127 | 168 | "metadata": {},
|
128 | 169 | "outputs": [],
|
129 | 170 | "source": [
|
130 | 171 | "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
|
131 | 172 | "(train, valid, test), \n",
|
132 | 173 | "batch_size = 32,\n",
|
133 |
| - "device = device)" |
| 174 | + "device = device,\n", |
| 175 | + "sort_key = lambda x: len(x.tweet),\n", |
| 176 | + "sort_within_batch = False)" |
134 | 177 | ]
|
135 | 178 | },
|
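The added `sort_key` lets `BucketIterator` batch tweets of similar length together, which keeps padding to a minimum. A hedged sanity check of what one batch exposes (attribute names come from the `fields` list):

```python
# Sketch: with the default batch_first=False, batch.tweet is shaped
# [seq_len, batch_size] and batch.label is [batch_size].
batch = next(iter(train_iterator))
print(batch.tweet.shape)   # e.g. torch.Size([seq_len, 32])
print(batch.label.shape)   # torch.Size([32])
```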
136 | 179 | {
|
|
142 | 185 | },
|
143 | 186 | {
|
144 | 187 | "cell_type": "code",
|
145 |
| - "execution_count": null, |
| 188 | + "execution_count": 43, |
146 | 189 | "metadata": {},
|
147 |
| - "outputs": [], |
| 190 | + "outputs": [ |
| 191 | + { |
| 192 | + "data": { |
| 193 | + "text/plain": [ |
| 194 | + "OurFirstLSTM(\n", |
| 195 | + " (embedding): Embedding(20002, 300)\n", |
| 196 | + " (encoder): LSTM(300, 100)\n", |
| 197 | + " (predictor): Linear(in_features=100, out_features=2, bias=True)\n", |
| 198 | + ")" |
| 199 | + ] |
| 200 | + }, |
| 201 | + "execution_count": 43, |
| 202 | + "metadata": {}, |
| 203 | + "output_type": "execute_result" |
| 204 | + } |
| 205 | + ], |
148 | 206 | "source": [
|
149 | 207 | "class OurFirstLSTM(nn.Module):\n",
|
150 | 208 | " def __init__(self, hidden_size, embedding_dim, vocab_size):\n",
|
|
173 | 231 | },
|
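The class body is cut off after `__init__`, but the module tree in the output (`Embedding(20002, 300)`, `LSTM(300, 100)`, `Linear(100, 2)`) pins the layer shapes down. A sketch consistent with that output; the `forward` body is an assumption, not the committed code:

```python
class OurFirstLSTM(nn.Module):
    def __init__(self, hidden_size, embedding_dim, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.LSTM(input_size=embedding_dim,
                               hidden_size=hidden_size,
                               num_layers=1)
        self.predictor = nn.Linear(hidden_size, 2)

    def forward(self, seq):
        # Assumption: classify from the LSTM's final hidden state.
        _, (hidden, _) = self.encoder(self.embedding(seq))
        return self.predictor(hidden.squeeze(0))

model = OurFirstLSTM(hidden_size=100, embedding_dim=300,
                     vocab_size=20002).to(device)
```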
174 | 232 | {
|
175 | 233 | "cell_type": "code",
|
176 |
| - "execution_count": null, |
| 234 | + "execution_count": 44, |
177 | 235 | "metadata": {},
|
178 | 236 | "outputs": [],
|
179 | 237 | "source": [
|
|
187 | 245 | " valid_loss = 0.0\n",
|
188 | 246 | " model.train()\n",
|
189 | 247 | " for batch_idx, batch in enumerate(train_iterator):\n",
|
190 |
| - " opt.zero_grad()\n", |
| 248 | + " optimizer.zero_grad()\n", |
191 | 249 | " predict = model(batch.tweet)\n",
|
192 | 250 | " loss = criterion(predict,batch.label)\n",
|
193 | 251 | " loss.backward()\n",
|
|
203 | 261 | " valid_loss += loss.data.item() * batch.tweet.size(0)\n",
|
204 | 262 | " \n",
|
205 | 263 | " valid_loss /= len(valid_iterator)\n",
|
206 |
| - " print('Epoch: {}, Training Loss: {:.2f}, \n", |
207 |
| - " Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))" |
| 264 | + " print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}'.format(epoch, training_loss, valid_loss))" |
| 265 | + ] |
| 266 | + }, |
| 267 | + { |
| 268 | + "cell_type": "code", |
| 269 | + "execution_count": 45, |
| 270 | + "metadata": {}, |
| 271 | + "outputs": [ |
| 272 | + { |
| 273 | + "name": "stdout", |
| 274 | + "output_type": "stream", |
| 275 | + "text": [ |
| 276 | + "Epoch: 1, Training Loss: 24.47, Validation Loss: 14.04\n", |
| 277 | + "Epoch: 2, Training Loss: 23.81, Validation Loss: 14.57\n", |
| 278 | + "Epoch: 3, Training Loss: 23.25, Validation Loss: 15.69\n", |
| 279 | + "Epoch: 4, Training Loss: 23.12, Validation Loss: 16.16\n", |
| 280 | + "Epoch: 5, Training Loss: 21.71, Validation Loss: 18.80\n" |
| 281 | + ] |
| 282 | + } |
| 283 | + ], |
| 284 | + "source": [ |
| 285 | + "train(5, model, optimizer, criterion, train_iterator, valid_iterator) " |
208 | 286 | ]
|
209 | 287 | },
|
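The call assumes `model`, `optimizer`, and `criterion` already exist; a sketch of a setup consistent with the loop above (the Adam learning rate is an assumption):

```python
# Sketch: two-class cross-entropy plus Adam, matching the loop's use of
# optimizer.zero_grad() and criterion(predict, batch.label).
optimizer = optim.Adam(model.parameters(), lr=2e-2)
criterion = nn.CrossEntropyLoss()
train(5, model, optimizer, criterion, train_iterator, valid_iterator)
```

Two things worth noting in the printed run: training loss falls while validation loss climbs from epoch 2 onward, the usual overfitting signature on a 6000-tweet training split; and because `batch_first` is left at its default, `batch.tweet.size(0)` is the sequence length, so both reported losses are scaled by tweet length rather than batch size.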
210 | 288 | {
|
|
219 | 297 | "execution_count": null,
|
220 | 298 | "metadata": {},
|
221 | 299 | "outputs": [],
|
| 300 | + "source": [] |
| 301 | + }, |
| 302 | + { |
| 303 | + "cell_type": "code", |
| 304 | + "execution_count": 46, |
| 305 | + "metadata": {}, |
| 306 | + "outputs": [], |
222 | 307 | "source": [
|
223 | 308 | "def classify_tweet(tweet):\n",
|
224 | 309 | " categories = {0: \"Negative\", 1:\"Positive\"}\n",
|
|
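`classify_tweet` is cut off by the end of the diff. A hedged completion using the legacy `Field.process` helper, which runs `preprocess` (tokenization) and numericalization in one step; everything past the `categories` line is an assumption:

```python
# Sketch of a plausible completion for the truncated cell.
def classify_tweet(tweet):
    categories = {0: "Negative", 1: "Positive"}
    processed = TWEET.process([TWEET.preprocess(tweet)]).to(device)
    model.eval()
    return categories[model(processed).argmax().item()]
```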