
Commit a0955ef

Use more descriptive variable names & lr schedulers (pytorch#141)
1 parent dda0b2f commit a0955ef

File tree

1 file changed: +37 −46 lines changed

beginner_source/transfer_learning_tutorial.py

Lines changed: 37 additions & 46 deletions
@@ -38,13 +38,13 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from torch.optim import lr_scheduler
 from torch.autograd import Variable
 import numpy as np
 import torchvision
 from torchvision import datasets, models, transforms
 import matplotlib.pyplot as plt
 import time
-import copy
 import os
 
 plt.ion()   # interactive mode
@@ -64,13 +64,13 @@
 # well.
 #
 # This dataset is a very small subset of imagenet.
-# 
+#
 # .. Note ::
 #    Download the data from
 #    `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
 #    and extract it to the current directory.
 
-# Data augmentation and normalization for training 
+# Data augmentation and normalization for training
 # Just normalization for validation
 data_transforms = {
     'train': transforms.Compose([
@@ -88,17 +88,17 @@
 }
 
 data_dir = 'hymenoptera_data'
-dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
-         for x in ['train', 'val']}
-dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=4,
-                                               shuffle=True, num_workers=4)
-                for x in ['train', 'val']}
-dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']}
-dset_classes = dsets['train'].classes
+image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
+                                          data_transforms[x])
+                  for x in ['train', 'val']}
+dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
+                                             shuffle=True, num_workers=4)
+              for x in ['train', 'val']}
+dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
+class_names = image_datasets['train'].classes
 
 use_gpu = torch.cuda.is_available()
 
-
 ######################################################################
 # Visualize a few images
 # ^^^^^^^^^^^^^^^^^^^^^^
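A quick way to sanity-check the renamed objects is to pull one mini-batch through them. This is a minimal sketch, assuming the data-loading block above has already run (note that the loader dict is spelled ``dataloders`` in this commit), and the printed values are only indicative:

# Pull one mini-batch and inspect it using the new names.
inputs, labels = next(iter(dataloders['train']))

print(inputs.size())                       # one batch of 4 images, each 3 x 224 x 224
print([class_names[x] for x in labels])    # e.g. ['ants', 'bees', 'ants', 'bees']
print(dataset_sizes)                       # number of images in the 'train' and 'val' splits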
@@ -118,12 +118,12 @@ def imshow(inp, title=None):
 
 
 # Get a batch of training data
-inputs, classes = next(iter(dset_loaders['train']))
+inputs, classes = next(iter(dataloders['train']))
 
 # Make a grid from batch
 out = torchvision.utils.make_grid(inputs)
 
-imshow(out, title=[dset_classes[x] for x in classes])
+imshow(out, title=[class_names[x] for x in classes])
 
 
 ######################################################################
@@ -134,16 +134,16 @@ def imshow(inp, title=None):
 # illustrate:
 #
 # - Scheduling the learning rate
-# - Saving (deep copying) the best model
+# - Saving the best model
 #
-# In the following, parameter ``lr_scheduler(optimizer, epoch)``
-# is a function which modifies ``optimizer`` so that the learning
-# rate is changed according to desired schedule.
+# In the following, parameter ``scheduler`` is an LR scheduler object from
+# ``torch.optim.lr_scheduler``.
+
 
-def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
+def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     since = time.time()
 
-    best_model = model
+    best_model_wts = model.state_dict()
     best_acc = 0.0
 
     for epoch in range(num_epochs):
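The new ``scheduler`` argument is an object from ``torch.optim.lr_scheduler`` rather than a hand-written function. A minimal, self-contained sketch of the calling convention used above (the throwaway ``nn.Linear`` model and the 25-epoch loop are illustrative assumptions, not part of the commit):

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

model = nn.Linear(10, 2)   # stand-in model, only to give the optimizer some parameters
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

for epoch in range(25):
    scheduler.step()   # updates the lr stored in optimizer.param_groups
    # ... run the train/val phases for this epoch, as train_model() does ...
    print(epoch, optimizer.param_groups[0]['lr'])   # 1e-3 for epochs 0-6, then 1e-4, ...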
@@ -153,7 +153,7 @@ def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
         # Each epoch has a training and validation phase
         for phase in ['train', 'val']:
             if phase == 'train':
-                optimizer = lr_scheduler(optimizer, epoch)
+                scheduler.step()
                 model.train(True)  # Set model to training mode
             else:
                 model.train(False)  # Set model to evaluate mode
@@ -162,14 +162,14 @@ def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
             running_corrects = 0
 
             # Iterate over data.
-            for data in dset_loaders[phase]:
+            for data in dataloders[phase]:
                 # get the inputs
                 inputs, labels = data
 
                 # wrap them in Variable
                 if use_gpu:
-                    inputs, labels = Variable(inputs.cuda()), \
-                        Variable(labels.cuda())
+                    inputs = Variable(inputs.cuda())
+                    labels = Variable(labels.cuda())
                 else:
                     inputs, labels = Variable(inputs), Variable(labels)
 
@@ -190,42 +190,27 @@ def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
                 running_loss += loss.data[0]
                 running_corrects += torch.sum(preds == labels.data)
 
-            epoch_loss = running_loss / dset_sizes[phase]
-            epoch_acc = running_corrects / dset_sizes[phase]
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = running_corrects / dataset_sizes[phase]
 
             print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                 phase, epoch_loss, epoch_acc))
 
             # deep copy the model
             if phase == 'val' and epoch_acc > best_acc:
                 best_acc = epoch_acc
-                best_model = copy.deepcopy(model)
+                best_model_wts = model.state_dict()
 
         print()
 
     time_elapsed = time.time() - since
     print('Training complete in {:.0f}m {:.0f}s'.format(
         time_elapsed // 60, time_elapsed % 60))
     print('Best val Acc: {:4f}'.format(best_acc))
-    return best_model
-
-######################################################################
-# Learning rate scheduler
-# ^^^^^^^^^^^^^^^^^^^^^^^
-# Let's create our learning rate scheduler. We will exponentially
-# decrease the learning rate once every few epochs.
 
-def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
-    """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
-    lr = init_lr * (0.1**(epoch // lr_decay_epoch))
-
-    if epoch % lr_decay_epoch == 0:
-        print('LR is set to {}'.format(lr))
-
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
-
-    return optimizer
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+    return model
 
 
 ######################################################################
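The commit also switches from keeping a deep copy of the whole model to keeping only its ``state_dict`` and restoring it with ``load_state_dict`` at the end of training. A small self-contained sketch of that pattern (the ``nn.Linear`` stand-in, the ``copy.deepcopy`` of the state dict, and the ``torch.save`` call are illustrative additions, not part of the commit; snapshotting with ``copy.deepcopy`` guards against later in-place updates to the live parameter tensors):

import copy
import torch
import torch.nn as nn

model = nn.Linear(10, 2)                           # stand-in model for illustration
best_model_wts = copy.deepcopy(model.state_dict())

# ... training mutates the model in place; whenever validation accuracy improves:
best_model_wts = copy.deepcopy(model.state_dict())

# after training, restore the best weights and optionally persist them to disk
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), 'best_model_weights.pth')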
@@ -239,7 +224,7 @@ def visualize_model(model, num_images=6):
    images_so_far = 0
    fig = plt.figure()
 
-    for i, data in enumerate(dset_loaders['val']):
+    for i, data in enumerate(dataloders['val']):
        inputs, labels = data
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
@@ -253,7 +238,7 @@ def visualize_model(model, num_images=6):
            images_so_far += 1
            ax = plt.subplot(num_images//2, 2, images_so_far)
            ax.axis('off')
-            ax.set_title('predicted: {}'.format(dset_classes[preds[j]]))
+            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])
 
            if images_so_far == num_images:
@@ -278,6 +263,9 @@ def visualize_model(model, num_images=6):
 # Observe that all parameters are being optimized
 optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
 
+# Decay LR by a factor of 0.1 every 7 epochs
+exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
+
 ######################################################################
 # Train and evaluate
 # ^^^^^^^^^^^^^^^^^^
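For reference, ``StepLR(optimizer_ft, step_size=7, gamma=0.1)`` reproduces the schedule of the removed ``exp_lr_scheduler`` function, lr = 0.001 * 0.1 ** (epoch // 7). Because the StepLR instance reuses the old name ``exp_lr_scheduler``, the fine-tuning call in the unchanged "Train and evaluate" section presumably keeps its old form (an assumption about code not shown in this diff):

# Assumed call site in the unchanged section below this hunk (not part of the diff):
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)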
@@ -324,6 +312,9 @@ def visualize_model(model, num_images=6):
 # opoosed to before.
 optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
 
+# Decay LR by a factor of 0.1 every 7 epochs
+exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
+
 
 ######################################################################
 # Train and evaluate
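In the fixed-feature-extractor setting, ``optimizer_conv`` only holds the parameters of the new final layer, so the scheduler above only decays that single parameter group's learning rate. A self-contained sketch of that setup (it mirrors the tutorial's earlier ``model_conv`` construction, which is not shown in this diff; the ``print`` checks are illustrative additions):

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models

model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False                           # freeze the pretrained backbone
model_conv.fc = nn.Linear(model_conv.fc.in_features, 2)   # new, trainable classifier head

optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

print(len(optimizer_conv.param_groups))               # 1 parameter group
print(len(optimizer_conv.param_groups[0]['params']))  # 2 tensors: fc.weight and fc.bias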
