import torch
import torch.nn as nn
import torch.optim as optim
+ from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
- import copy
import os

plt.ion()   # interactive mode
# well.
#
# This dataset is a very small subset of imagenet.
- #
+ #
# .. Note ::
#    Download the data from
#    `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
#    and extract it to the current directory.

- # Data augmentation and normalization for training
+ # Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        ...
}

data_dir = 'hymenoptera_data'
- dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
-          for x in ['train', 'val']}
- dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=4,
-                                                shuffle=True, num_workers=4)
-                 for x in ['train', 'val']}
- dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']}
- dset_classes = dsets['train'].classes
+ image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
+                                           data_transforms[x])
+                   for x in ['train', 'val']}
+ dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
+                                              shuffle=True, num_workers=4)
+               for x in ['train', 'val']}
+ dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
+ class_names = image_datasets['train'].classes
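# Illustrative aside, not part of this commit: ``ImageFolder`` expects one
# sub-directory per class under each split, so the extracted archive is assumed
# to be laid out as hymenoptera_data/{train,val}/{ants,bees}/*.jpg. A quick
# sanity check of what the new pipeline produces:
print(dataset_sizes)   # number of images in each split
print(class_names)     # ['ants', 'bees']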

use_gpu = torch.cuda.is_available()

-
######################################################################
# Visualize a few images
# ^^^^^^^^^^^^^^^^^^^^^^
@@ -118,12 +118,12 @@ def imshow(inp, title=None):

# Get a batch of training data
- inputs, classes = next(iter(dset_loaders['train']))
+ inputs, classes = next(iter(dataloders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

- imshow(out, title=[dset_classes[x] for x in classes])
+ imshow(out, title=[class_names[x] for x in classes])


######################################################################
@@ -134,16 +134,16 @@ def imshow(inp, title=None):
# illustrate:
#
# - Scheduling the learning rate
- # - Saving (deep copying) the best model
+ # - Saving the best model
#
- # In the following, parameter ``lr_scheduler(optimizer, epoch)``
- # is a function which modifies ``optimizer`` so that the learning
- # rate is changed according to desired schedule.
+ # In the following, parameter ``scheduler`` is an LR scheduler object from
+ # ``torch.optim.lr_scheduler``.
+
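# Illustrative aside, not part of this commit: a minimal sketch of how such a
# scheduler object is built and driven once per epoch, using the StepLR policy
# adopted later in this tutorial; the ``sketch_*`` names are placeholders.
sketch_model = nn.Linear(10, 2)   # stand-in for the real network
sketch_optimizer = optim.SGD(sketch_model.parameters(), lr=0.001, momentum=0.9)
sketch_scheduler = lr_scheduler.StepLR(sketch_optimizer, step_size=7, gamma=0.1)
for sketch_epoch in range(25):
    sketch_scheduler.step()   # decays the lr by gamma=0.1 every 7 epochs
    # ... one epoch of training and validation would run here ...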

- def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
+ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

-     best_model = model
+     best_model_wts = model.state_dict()
    best_acc = 0.0

    for epoch in range(num_epochs):
@@ -153,7 +153,7 @@ def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
-                 optimizer = lr_scheduler(optimizer, epoch)
+                 scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode
@@ -162,14 +162,14 @@ def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
            running_corrects = 0

            # Iterate over data.
-             for data in dset_loaders[phase]:
+             for data in dataloders[phase]:
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                if use_gpu:
-                     inputs, labels = Variable(inputs.cuda()), \
-                         Variable(labels.cuda())
+                     inputs = Variable(inputs.cuda())
+                     labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
@@ -190,42 +190,27 @@ def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
                running_loss += loss.data[0]
                running_corrects += torch.sum(preds == labels.data)

-             epoch_loss = running_loss / dset_sizes[phase]
-             epoch_acc = running_corrects / dset_sizes[phase]
+             epoch_loss = running_loss / dataset_sizes[phase]
+             epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
-                 best_model = copy.deepcopy(model)
+                 best_model_wts = model.state_dict()

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
-     return best_model
-
- ######################################################################
- # Learning rate scheduler
- # ^^^^^^^^^^^^^^^^^^^^^^^
- # Let's create our learning rate scheduler. We will exponentially
- # decrease the learning rate once every few epochs.

- def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
-     """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
-     lr = init_lr * (0.1 ** (epoch // lr_decay_epoch))
-
-     if epoch % lr_decay_epoch == 0:
-         print('LR is set to {}'.format(lr))
-
-     for param_group in optimizer.param_groups:
-         param_group['lr'] = lr
-
-     return optimizer
+     # load best model weights
+     model.load_state_dict(best_model_wts)
+     return model
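# Illustrative aside, not part of this commit: because the best weights are now
# tracked as a state_dict rather than a deep-copied model, they can also be
# written to disk and restored into a freshly built model of the same
# architecture; the file name below is just a placeholder.
def save_and_restore_sketch(trained_model, fresh_model, path='best_model_wts.pth'):
    torch.save(trained_model.state_dict(), path)   # serialize the weights
    fresh_model.load_state_dict(torch.load(path))  # load them into another model
    return fresh_model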

######################################################################
@@ -239,7 +224,7 @@ def visualize_model(model, num_images=6):
    images_so_far = 0
    fig = plt.figure()

-     for i, data in enumerate(dset_loaders['val']):
+     for i, data in enumerate(dataloders['val']):
        inputs, labels = data
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
@@ -253,7 +238,7 @@ def visualize_model(model, num_images=6):
            images_so_far += 1
            ax = plt.subplot(num_images // 2, 2, images_so_far)
            ax.axis('off')
-             ax.set_title('predicted: {}'.format(dset_classes[preds[j]]))
+             ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])

            if images_so_far == num_images:
@@ -278,6 +263,9 @@ def visualize_model(model, num_images=6):
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

+ # Decay LR by a factor of 0.1 every 7 epochs
+ exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
+
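# Illustrative aside, not part of this commit: the "Train and evaluate" step
# below presumably hands this scheduler to ``train_model`` roughly as follows,
# with ``criterion`` assumed to be the tutorial's cross-entropy loss.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)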
######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
@@ -324,6 +312,9 @@ def visualize_model(model, num_images=6):
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

+ # Decay LR by a factor of 0.1 every 7 epochs
+ exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
+
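# Illustrative aside, not part of this commit: only the parameters of
# ``model_conv.fc`` were handed to ``optimizer_conv``, so the schedule above
# touches a single parameter group; its current learning rate can be read back
# at any time.
print('current fc learning rate: {}'.format(optimizer_conv.param_groups[0]['lr']))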
######################################################################
# Train and evaluate