
Commit f201ff6

finetune, data-augmentation
1 parent ed88f1f commit f201ff6

File tree

.gitignore
pytorch-10/10.2.py
pytorch-10/10.4.py

3 files changed, +239 -0 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ Thumbs.db
 .idea/
 
 *.pth
+data/

pytorch-10/10.2.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime


def get_acc(output, label):
    # Fraction of correctly classified samples in a batch
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    # `device` is the module-level torch.device defined below
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)        # (bs, 3, h, w)
            label = label.to(device)  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                im = im.to(device)        # (bs, 3, h, w)
                label = label.to(device)  # (bs,)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


trans_train = transforms.Compose(
    [transforms.RandomResizedCrop(224),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])

trans_valid = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=False, transform=trans_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=False, transform=trans_valid)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# # Randomly grab a batch of training data
# dataiter = iter(trainloader)
# images, labels = dataiter.next()

# Use a pretrained model
net = models.resnet18(pretrained=True)

# Freeze model weights
for param in net.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a 10-class classifier
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
net.fc = nn.Linear(512, 10)
# net = torch.nn.DataParallel(net)

# Inspect total and trainable parameter counts
total_params = sum(p.numel() for p in net.parameters())
print('Total parameters: {}'.format(total_params))
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Trainable parameters: {}'.format(total_trainable_params))

net = net.to(device)

criterion = nn.CrossEntropyLoss()
# Only the final layer's parameters are optimized
optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)

train(net, trainloader, testloader, 20, optimizer, criterion)
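
A side note on the frozen-backbone setup above: newer torchvision releases (0.13+) replace the pretrained flag with an explicit weights argument. The sketch below, assuming torchvision >= 0.13, shows the equivalent setup and a quick check that only the new head remains trainable; it is not part of this commit.

# Minimal sketch, assuming torchvision >= 0.13 (not from this commit).
from torch import nn
from torchvision import models
from torchvision.models import ResNet18_Weights

net = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)  # replaces pretrained=True

# Freeze the backbone, then attach a fresh 10-class head, as in 10.2.py.
for param in net.parameters():
    param.requires_grad = False
net.fc = nn.Linear(net.fc.in_features, 10)

# Only the new head should remain trainable.
print([name for name, p in net.named_parameters() if p.requires_grad])
# -> ['fc.weight', 'fc.bias']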

pytorch-10/10.4.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime


def get_acc(output, label):
    # Fraction of correctly classified samples in a batch
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    # `device` is the module-level torch.device defined below
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)        # (bs, 3, h, w)
            label = label.to(device)  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += get_acc(output, label)

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                im = im.to(device)        # (bs, 3, h, w)
                label = label.to(device)  # (bs,)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


# Heavier data augmentation for the training set
trans_train = transforms.Compose(
    [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
     transforms.RandomRotation(degrees=15),
     transforms.ColorJitter(),
     transforms.RandomResizedCrop(224),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])

trans_valid = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=False, transform=trans_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=False, transform=trans_valid)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Use a pretrained model
net = models.resnet18(pretrained=True)
# print(net)

# Replace the final fully connected layer with a 10-class classifier
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.fc = nn.Linear(512, 10)
# net = torch.nn.DataParallel(net)
net.to(device)

criterion = nn.CrossEntropyLoss()
# Fine-tune all parameters of the network (no freezing here)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, weight_decay=1e-3, momentum=0.9)

train(net, trainloader, testloader, 20, optimizer, criterion)
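
After training, a held-out test accuracy can be reported in a single pass with gradient tracking disabled. Below is a minimal sketch reusing the net, testloader, and device defined above; the evaluate helper is an illustration, not part of this commit.

# Minimal evaluation sketch (not from this commit).
import torch

def evaluate(net, loader, device):
    # Top-1 accuracy over a DataLoader, with gradients disabled.
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for im, label in loader:
            im, label = im.to(device), label.to(device)
            pred = net(im).argmax(dim=1)
            correct += (pred == label).sum().item()
            total += label.size(0)
    return correct / total

# Example usage after the train(...) call above:
# print("Test accuracy: %.4f" % evaluate(net, testloader, device))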
