@@ -51,10 +51,7 @@ def val(net, testloader, criterion):
51
51
return correct / total , loss / total
52
52
53
53
54
- def train (net , trainloader , testloader , num_epoch ):
55
- criterion = nn .CrossEntropyLoss ()
56
- optimizer = optim .SGD (net .parameters (), lr = 0.001 , momentum = 0.9 )
57
-
54
+ def train (net , trainloader , testloader , num_epoch , optimizer , criterion ):
58
55
# Print optimizer's state_dict
59
56
print ("Optimizer's state_dict:" )
60
57
for var_name in optimizer .state_dict ():
@@ -126,15 +123,20 @@ def main(args):
126
123
# print(device)
127
124
net .to (device )
128
125
126
+ criterion = nn .CrossEntropyLoss ()
127
+ optimizer = optim .SGD (net .parameters (), lr = 0.001 , momentum = 0.9 )
128
+
129
129
trainloader , testloader = load_data_cifar ()
130
130
weights_file = './cifar_net_{}.pth' .format (args .model )
131
131
if not os .path .isfile (weights_file ):
132
- train (net , trainloader , testloader , args .num_epoch )
132
+ train (net , trainloader , testloader , args .num_epoch , optimizer , criterion )
133
133
# PATH = './cifar_net.pth'
134
134
torch .save (net .state_dict (), weights_file )
135
135
else :
136
136
net .load_state_dict (torch .load (weights_file ))
137
137
net .to (device )
138
+ val_acc , val_loss = val (net , testloader , criterion )
139
+ print ('{} : val_acc : {}, val_loss : {}' .format (args .model , val_acc , val_loss ))
138
140
# 打印每个类别的准确率
139
141
print_accuracy_of_classes (net , testloader )
140
142
0 commit comments