Skip to content

Commit ed88f1f

Browse files
committed
测试三个模型
1 parent 08a4c39 commit ed88f1f

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

pytorch-06/README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11

22
## 1 使用net_cnn.py
3+
```bash
4+
python cifar.py --model cnn --gpu_id 1
5+
cnn : val_acc : 0.6813, val_loss : 0.3075759708881378
6+
```
7+
8+
39
```bash
410
2021-08-07 03:17:48 [10, 2000] train acc: 0.889 loss: 0.308 ; val acc: 0.694 loss: 0.299 lr: 0.001
511
2021-08-07 03:18:07 [10, 4000] train acc: 0.869 loss: 0.356 ; val acc: 0.689 loss: 0.295 lr: 0.001
@@ -21,6 +27,11 @@ Accuracy of truck : 68 %
2127
```
2228

2329
## 2 使用net_gap.py
30+
```bash
31+
python cifar.py --model gap --gpu_id 1
32+
gap : val_acc : 0.6217, val_loss : 0.27244704961776733
33+
```
34+
2435
```bash
2536
2021-08-07 03:19:08 [10, 2000] train acc: 0.651 loss: 1.004 ; val acc: 0.616 loss: 0.272 lr: 0.001
2637
2021-08-07 03:19:27 [10, 4000] train acc: 0.655 loss: 0.982 ; val acc: 0.650 loss: 0.253 lr: 0.001
@@ -43,6 +54,12 @@ Accuracy of truck : 90 %
4354

4455

4556
## 3 net_vgg.py
57+
```bash
58+
python cifar.py --model vgg --gpu_id 1
59+
vgg : val_acc : 0.8284, val_loss : 0.1539972424507141
60+
```
61+
62+
4663
```bash
4764
2021-08-07 04:01:42 [10, 2000] train acc: 0.948 loss: 0.156 ; val acc: 0.828 loss: 0.159 lr: 0.001
4865
2021-08-07 04:02:41 [10, 4000] train acc: 0.946 loss: 0.158 ; val acc: 0.828 loss: 0.158 lr: 0.001

pytorch-06/cifar.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ def val(net, testloader, criterion):
5151
return correct / total, loss / total
5252

5353

54-
def train(net, trainloader, testloader, num_epoch):
55-
criterion = nn.CrossEntropyLoss()
56-
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
57-
54+
def train(net, trainloader, testloader, num_epoch, optimizer, criterion):
5855
# Print optimizer's state_dict
5956
print("Optimizer's state_dict:")
6057
for var_name in optimizer.state_dict():
@@ -126,15 +123,20 @@ def main(args):
126123
# print(device)
127124
net.to(device)
128125

126+
criterion = nn.CrossEntropyLoss()
127+
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
128+
129129
trainloader, testloader = load_data_cifar()
130130
weights_file = './cifar_net_{}.pth'.format(args.model)
131131
if not os.path.isfile(weights_file):
132-
train(net, trainloader, testloader, args.num_epoch)
132+
train(net, trainloader, testloader, args.num_epoch, optimizer, criterion)
133133
# PATH = './cifar_net.pth'
134134
torch.save(net.state_dict(), weights_file)
135135
else:
136136
net.load_state_dict(torch.load(weights_file))
137137
net.to(device)
138+
val_acc, val_loss = val(net, testloader, criterion)
139+
print('{} : val_acc : {}, val_loss : {}'.format(args.model, val_acc, val_loss))
138140
# 打印每个类别的准确率
139141
print_accuracy_of_classes(net, testloader)
140142

0 commit comments

Comments
 (0)