Skip to content

Commit fc26c2f

Browse files
committed
训练mnist
1 parent 1876937 commit fc26c2f

File tree

12 files changed

+321
-0
lines changed

12 files changed

+321
-0
lines changed

pytorch-03/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
logs/
2+
data/

pytorch-03/net.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from torch import nn
2+
import torch.nn.functional as F
3+
4+
5+
class Net(nn.Module):
6+
"""
7+
使用sequential构建网络,Sequential()函数的功能是将网络的层组合到一起
8+
"""
9+
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
10+
super(Net, self).__init__()
11+
self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1))
12+
self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2))
13+
self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
14+
15+
16+
def forward(self, x):
17+
x = F.relu(self.layer1(x))
18+
x = F.relu(self.layer2(x))
19+
x = self.layer3(x)
20+
return x

pytorch-03/net1.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
3+
4+
class Net1(torch.nn.Module):
5+
def __init__(self):
6+
super(Net4, self).__init__()
7+
self.conv = torch.nn.Sequential(
8+
OrderedDict(
9+
[
10+
("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
11+
("relu1", torch.nn.ReLU()),
12+
("pool", torch.nn.MaxPool2d(2))
13+
]
14+
))
15+
16+
self.dense = torch.nn.Sequential(
17+
OrderedDict([
18+
("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
19+
("relu2", torch.nn.ReLU()),
20+
("dense2", torch.nn.Linear(128, 10))
21+
])
22+
)
23+
24+
25+
26+
27+

pytorch-03/net2.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
import torch
3+
from torch import nn
4+
import torch.nn.functional as F
5+
6+
7+
class Net2(torch.nn.Module):
8+
# 初始化
9+
def __init__(self):
10+
super(Net2, self).__init__()
11+
self.hidden = torch.nn.Linear(1, 20)
12+
self.predict = torch.nn.Linear(20, 1)
13+
14+
# 前向传递
15+
def forward(self, x):
16+
x = F.relu(self.hidden(x))
17+
x = self.predict(x)
18+
return x
19+

pytorch-03/optimizers.png

82.2 KB
Loading

pytorch-03/plot.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
5+
6+
def plot_loss(losses, eval_losses, save_path):
7+
plt.clf() # clear figure
8+
plt.title('train-val-loss')
9+
plt.plot(np.arange(len(losses)), losses)
10+
plt.plot(np.arange(len(eval_losses)), eval_losses)
11+
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
12+
plt.legend(['Train Loss'], loc='upper right')
13+
plt.savefig(save_path)
14+
15+
16+
def plot_acc(acces, eval_acces, save_path):
17+
plt.clf() # clear figure
18+
plt.title('train-val-acc')
19+
plt.plot(np.arange(len(acces)), acces)
20+
plt.plot(np.arange(len(eval_acces)), eval_acces)
21+
plt.legend(['Train acc', 'Test acc'], loc='upper right')
22+
plt.legend(['Train acc'], loc='upper right')
23+
plt.savefig(save_path)
24+
25+

pytorch-03/train.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
2+
import numpy as np
3+
import torch
4+
# 导入 pytorch 内置的 mnist 数据
5+
from torchvision.datasets import mnist
6+
#import torchvision
7+
#导入预处理模块
8+
import torchvision.transforms as transforms
9+
from torch.utils.data import DataLoader
10+
#导入nn及优化器
11+
import torch.nn.functional as F
12+
import torch.optim as optim
13+
from torch import nn
14+
15+
from tensorboardX import SummaryWriter
16+
17+
from net import Net
18+
from plot import plot_acc, plot_loss
19+
20+
# 定义一些超参数
21+
train_batch_size = 64
22+
test_batch_size = 128
23+
learning_rate = 0.01
24+
num_epoches = 20
25+
26+
27+
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
28+
29+
#定义预处理函数
30+
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
31+
#下载数据,并对数据进行预处理
32+
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
33+
test_dataset = mnist.MNIST('./data', train=False, transform=transform)
34+
#得到一个生成器
35+
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
36+
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
37+
38+
# examples = enumerate(test_loader)
39+
# batch_idx, (example_data, example_targets) = next(examples)
40+
41+
42+
lr = 0.01
43+
momentum = 0.9
44+
45+
#实例化模型
46+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
47+
#if torch.cuda.device_count() > 1:
48+
# print("Let's use", torch.cuda.device_count(), "GPUs")
49+
# # dim = 0 [20, xxx] -> [10, ...], [10, ...] on 2GPUs
50+
# model = nn.DataParallel(model)
51+
model = Net(28 * 28, 300, 100, 10)
52+
model.to(device)
53+
54+
# 定义损失函数和优化器
55+
criterion = nn.CrossEntropyLoss()
56+
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
57+
58+
59+
# 开始训练
60+
losses = []
61+
acces = []
62+
eval_losses = []
63+
eval_acces = []
64+
writer = SummaryWriter(log_dir='logs',comment='train-loss')
65+
66+
for epoch in range(num_epoches):
67+
train_loss = 0
68+
train_acc = 0
69+
model.train()
70+
#动态修改参数学习率
71+
if epoch%5==0:
72+
optimizer.param_groups[0]['lr']*=0.9
73+
print(optimizer.param_groups[0]['lr'])
74+
for img, label in train_loader:
75+
img = img.to(device)
76+
label = label.to(device)
77+
img = img.view(img.size(0), -1)
78+
# 前向传播
79+
out = model(img)
80+
loss = criterion(out, label)
81+
# 反向传播
82+
optimizer.zero_grad()
83+
loss.backward()
84+
optimizer.step()
85+
# 记录误差
86+
train_loss += loss.item()
87+
# 保存loss的数据与epoch数值
88+
writer.add_scalar('Train', train_loss/len(train_loader), epoch)
89+
# 计算分类的准确率
90+
_, pred = out.max(1)
91+
num_correct = (pred == label).sum().item()
92+
acc = num_correct / img.shape[0]
93+
train_acc += acc
94+
95+
losses.append(train_loss / len(train_loader))
96+
acces.append(train_acc / len(train_loader))
97+
# 在测试集上检验效果
98+
eval_loss = 0
99+
eval_acc = 0
100+
#net.eval() # 将模型改为预测模式
101+
model.eval()
102+
for img, label in test_loader:
103+
img = img.to(device)
104+
label = label.to(device)
105+
img = img.view(img.size(0), -1)
106+
out = model(img)
107+
loss = criterion(out, label)
108+
# 记录误差
109+
eval_loss += loss.item()
110+
# 记录准确率
111+
_, pred = out.max(1)
112+
num_correct = (pred == label).sum().item()
113+
acc = num_correct / img.shape[0]
114+
eval_acc += acc
115+
116+
eval_losses.append(eval_loss / len(test_loader))
117+
eval_acces.append(eval_acc / len(test_loader))
118+
print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
119+
.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader),
120+
eval_loss / len(test_loader), eval_acc / len(test_loader)))
121+
122+
123+
124+
plot_loss(losses, eval_losses, 'train_loss.png')
125+
plot_acc(acces, eval_acces, 'train_acc.png')
126+
127+
128+
129+
130+
131+
132+
133+
134+
135+
136+
137+
138+

pytorch-03/train1.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
import torch.utils.data as Data
3+
import torch.nn.functional as F
4+
import matplotlib.pyplot as plt
5+
# %matplotlib inline
6+
7+
8+
9+
10+
11+
12+

pytorch-03/train2.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
2+
import torch
3+
import torch.utils.data as Data
4+
import torch.nn.functional as F
5+
import matplotlib.pyplot as plt
6+
from net2 import Net2
7+
8+
9+
# 超参数
10+
LR = 0.01
11+
BATCH_SIZE = 32
12+
EPOCH = 12
13+
14+
# 生成训练数据
15+
# torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
16+
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
17+
# 0.1 * torch.normal(x.size())增加噪点
18+
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))
19+
20+
21+
torch_dataset = Data.TensorDataset(x,y)
22+
#得到一个代批量的生成器
23+
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True)
24+
25+
26+
net_SGD = Net2()
27+
net_Momentum = Net2()
28+
net_RMSProp = Net2()
29+
net_Adam = Net2()
30+
31+
nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]
32+
33+
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
34+
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.9)
35+
opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)
36+
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
37+
optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]
38+
39+
loss_func = torch.nn.MSELoss()
40+
41+
loss_his = [[], [], [], []] # 记录损失
42+
43+
for epoch in range(EPOCH):
44+
for step, (batch_x, batch_y) in enumerate(loader):
45+
for net, opt,l_his in zip(nets, optimizers, loss_his):
46+
output = net(batch_x) # get output for every net
47+
loss = loss_func(output, batch_y) # compute loss for every net
48+
opt.zero_grad() # clear gradients for next train
49+
loss.backward() # backpropagation, compute gradients
50+
opt.step() # apply gradients
51+
l_his.append(loss.data.numpy()) # loss recoder
52+
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
53+
for i, l_his in enumerate(loss_his):
54+
plt.plot(l_his, label=labels[i])
55+
plt.legend(loc='best')
56+
plt.xlabel('Steps')
57+
plt.ylabel('Loss')
58+
plt.ylim((0, 0.2))
59+
# plt.show()
60+
plt.savefig('optimizers.png')

pytorch-03/train_acc.png

22.9 KB
Loading

pytorch-03/train_loss.png

22 KB
Loading

pytorch-03/vis.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
3+
import matplotlib.pyplot as plt
4+
# %matplotlib inline
5+
6+
def plot_img(test_loader, save_path):
7+
examples = enumerate(test_loader)
8+
batch_idx, (example_data, example_targets) = next(examples)
9+
10+
fig = plt.figure()
11+
for i in range(6):
12+
plt.subplot(2,3,i+1)
13+
plt.tight_layout()
14+
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
15+
plt.title("Ground Truth: {}".format(example_targets[i]))
16+
plt.xticks([])
17+
plt.yticks([])
18+
plt.savefig(save_path)

0 commit comments

Comments
 (0)