Skip to content

Commit ae10c0e

Browse files
authored
avg over multi checkpoints
1 parent 51b7b74 commit ae10c0e

File tree

3 files changed

+300
-8
lines changed

3 files changed

+300
-8
lines changed

params.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55
66
@author: wayne
77
"""
8+
import torch
9+
10+
TRAIN_ROOT = 'data/validation_folder/'#'data/videos_folder/'
11+
VALIDATION_ROOT = 'data/validation_folder/'
12+
813

914
arch = 'resnet18' # preact_resnet50, resnet152
1015
pretrained = 'imagenet' #imagenet
1116
evaluate = False
1217
checkpoint_filename = arch + '_' + pretrained
18+
save_freq = 2
1319
try_resume = False
1420
print_freq = 10
1521
if_debug = False
@@ -34,7 +40,7 @@
3440
# training parameters:
3541
BATCH_SIZE = 32
3642
INPUT_WORKERS = 8
37-
epochs = 29
43+
epochs = 3
3844
use_epoch_decay = True # 可以加每次调lr时load回来最好的checkpoint
3945
lr = 0.0001 #0.01 0.001
4046
lr_min = 1e-6
@@ -52,4 +58,6 @@
5258
confusion_weight = 0.5 #for pairwise loss is 0.1N to 0.2N (where N is the number of classes), and for entropic is 0.1-0.5. https://github.com/abhimanyudubey/confusion
5359
betas=(0.9, 0.999)
5460
eps=1e-08 # 0.1的话一开始都是prec3 4.几
55-
momentum = 0.9
61+
momentum = 0.9
62+
63+
use_gpu = torch.cuda.is_available()

test2_multi_check.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
'''
5+
*Epoch:[0] Prec@1 99.384 Prec@3 100.000 Loss 0.5274
6+
'''
7+
8+
import os
9+
import torch
10+
import torch.nn as nn
11+
from torch.autograd import Variable
12+
from PIL import Image
13+
from torch.utils.data import Dataset, DataLoader
14+
import time
15+
import json
16+
from model import load_model
17+
from config import data_transforms
18+
import pickle
19+
import csv
20+
from params import *
21+
import torchvision.datasets as td
22+
import numpy as np
23+
24+
phases = ['val']
25+
batch_size = BATCH_SIZE
26+
27+
if phases[0] == 'test_A':
28+
test_root = 'data/test_A'
29+
elif phases[0] == 'test_B':
30+
test_root = 'data/test_B'
31+
elif phases[0] == 'val':
32+
test_root = 'data/validation_folder_full'
33+
34+
checkpoint_filename = arch + '_' + pretrained
35+
multi_checks = []
36+
'''
37+
在这里指定使用哪几个epoch的checkpoint进行平均
38+
'''
39+
for epoch_check in ['1']: # epoch的列表,如['10', '20']
40+
multi_checks.append('checkpoint/' + checkpoint_filename + '_' + str(epoch_check)+'.pth.tar')
41+
42+
'''
43+
这是imagefolder的顺序
44+
'''
45+
aaa = ['1','10', '11','12','13','14', '15', '16', '17', '18','19', '2', '20', '21', '22','23',
46+
'24', '25', '26', '27', '28', '29', '3', '30', '4', '5', '6', '7', '8','9']
47+
48+
49+
50+
51+
52+
best_check = 'checkpoint/' + checkpoint_filename + '_best.pth.tar'
53+
model_conv = load_model(arch, pretrained, use_gpu=use_gpu, num_classes=30, AdaptiveAvgPool=AdaptiveAvgPool, SPP=SPP, num_levels=num_levels, pool_type=pool_type, bilinear=bilinear, stage=stage, SENet=SENet,se_stage=se_stage,se_layers=se_layers)
54+
for param in model_conv.parameters():
55+
param.requires_grad = False #节省显存
56+
57+
best_checkpoint = torch.load(best_check)
58+
if arch.lower().startswith('alexnet') or arch.lower().startswith('vgg'):
59+
model_conv.features = nn.DataParallel(model_conv.features)
60+
model_conv.cuda()
61+
model_conv.load_state_dict(best_checkpoint['state_dict'])
62+
else:
63+
model_conv = nn.DataParallel(model_conv).cuda()
64+
model_conv.load_state_dict(best_checkpoint['state_dict'])
65+
66+
67+
68+
with open(test_root+'/pig_test_annotations.json', 'r') as f: #label文件, 测试的是我自己生成的
69+
label_raw_test = json.load(f)
70+
71+
def write_to_csv(aug_softmax, epoch_i=None):
    """Write per-image class probabilities to a CSV file under result/.

    aug_softmax maps image name -> length-30 array of softmax scores
    ordered like the ImageFolder class list ``aaa``.  Each written row is
    (image id, class 1..30, renormalized probability).

    epoch_i: optional checkpoint path; when given, its epoch suffix is
    appended to the file name so per-checkpoint results are kept apart.
    """
    if epoch_i is not None:  # fixed: identity test against None, not !=
        # 'checkpoint/resnet18_imagenet_10.pth.tar' -> epoch suffix '10'
        file = 'result/' + phases[0] + '_1_' + epoch_i.split('.')[0].split('_')[-1] + '.csv'
    else:
        file = 'result/' + phases[0] + '_1.csv'
    # newline='' is required by the csv module; without it every row is
    # followed by a blank line on Windows.
    with open(file, 'w', encoding='utf-8', newline='') as csvfile:
        spamwriter = csv.writer(csvfile, dialect='excel')
        for item in aug_softmax.keys():
            the_sum = sum(aug_softmax[item])  # renormalize so the 30 probs sum to 1
            for c in range(0, 30):
                prob = aug_softmax[item][aaa.index(str(c+1))] / the_sum
                if phases[0] != 'val':
                    # test ids are numeric file names such as '123.jpg'
                    spamwriter.writerow([int(item.split('.')[0]), c+1, prob])
                else:
                    spamwriter.writerow([item, c+1, prob])
86+
87+
88+
class SceneDataset(Dataset):
    """Dataset over test images described by a JSON annotation list.

    Each annotation is a dict with 'image_id' (file name inside root_dir)
    and 'label_id'.  Items are (image, label, image_file_name) triples;
    the raw file name is returned so results can be keyed per image.
    """

    def __init__(self, json_labels, root_dir, transform=None):
        self.label_raw = json_labels    # list of {'image_id', 'label_id'} dicts
        self.root_dir = root_dir        # directory holding the image files
        self.transform = transform      # optional torchvision-style transform

    def __len__(self):
        return len(self.label_raw)

    def __getitem__(self, idx):
        record = self.label_raw[idx]
        file_name = record['image_id']
        image = Image.open(os.path.join(self.root_dir, file_name))
        if self.transform:
            image = self.transform(image)
        return image, record['label_id'], file_name
111+
112+
113+
# Loader over the chosen split; shuffle=False keeps file order deterministic
# (results are keyed by file name anyway, so order only affects logging).
transformed_dataset_test = SceneDataset(json_labels=label_raw_test,
                                        root_dir=test_root,
                                        transform=data_transforms('test',input_size, train_scale, test_scale)
                                        )
dataloader = {phases[0]:DataLoader(transformed_dataset_test, batch_size=batch_size,shuffle=False, num_workers=INPUT_WORKERS)
              }
dataset_sizes = {phases[0]: len(label_raw_test)}
120+
121+
122+
class AverageMeter(object):
    """Tracks the latest value plus a running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
137+
138+
def accuracy(output, target, topk=(1,)):
139+
"""Computes the precision@k for the specified values of k
140+
output: logits
141+
target: labels
142+
"""
143+
maxk = max(topk)
144+
batch_size = target.size(0)
145+
146+
_, pred = output.topk(maxk, 1, True, True)
147+
pred = pred.t()
148+
correct = pred.eq(target.view(1, -1).expand_as(pred))
149+
150+
res = []
151+
for k in topk:
152+
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
153+
res.append(correct_k.mul_(100.0 / batch_size))
154+
155+
156+
pred_list = pred.tolist() #[[14, 13], [72, 15], [74, 11]]
157+
return res, pred_list
158+
159+
160+
def test_model(model, criterion):
    """Evaluate `model` over every phase in the global `phases` list.

    For each batch the class-probability vector of every image is stored
    under its raw file name, so results stay correct even if loader
    ordering changes.  Prints final top-1/top-3 precision and mean loss.

    Returns {image_name: numpy softmax scores} for the last phase.
    """
    since = time.time()

    mystep = 0

    for phase in phases:

        model.eval()  # switch off dropout / batch-norm updates

        top1 = AverageMeter()
        top3 = AverageMeter()
        loss1 = AverageMeter()
        aug_softmax = {}

        # Iterate over data.
        for data in dataloader[phase]:
            mystep = mystep + 1
            # if(mystep%10 ==0):
            #     duration = time.time() - since
            #     print('step %d vs %d in %.0f s' % (mystep, total_steps, duration))

            inputs, labels, img_name_raw = data

            # wrap them in Variable
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
            else:
                inputs, labels = Variable(inputs), Variable(labels)

            # forward
            outputs = model(inputs)
            # dim=1: softmax over the class axis of (batch, classes) logits;
            # leaving dim implicit is deprecated/ambiguous in PyTorch.
            crop_softmax = nn.functional.softmax(outputs, dim=1)
            temp = crop_softmax.cpu().data.numpy()
            for item in range(len(img_name_raw)):
                # key by image name: robust against any reordering by workers
                aug_softmax[img_name_raw[item]] = temp[item, :]

            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)

            # statistics
            res, pred_list = accuracy(outputs.data, labels.data, topk=(1, 3))
            prec1 = res[0]
            prec3 = res[1]
            top1.update(prec1[0], inputs.size(0))
            top3.update(prec3[0], inputs.size(0))
            # .item() instead of loss.data[0]: indexing a 0-dim tensor
            # raises on PyTorch >= 0.5.
            loss1.update(loss.item(), inputs.size(0))

        print(' * Prec@1 {top1.avg:.6f} Prec@3 {top3.avg:.6f} Loss@1 {loss1.avg:.6f}'.format(top1=top1, top3=top3, loss1=loss1))

    return aug_softmax
214+
215+
216+
217+
criterion = nn.CrossEntropyLoss()


######################################################################
# val and test
# Expected total batch count across all checkpoints (progress estimate only;
# referenced from the commented-out progress print in test_model).
total_steps = 1.0 * len(label_raw_test) / batch_size * len(multi_checks)
print(total_steps)
224+
225+
class Average_Softmax(object):
    """Weighted running average of per-image score dicts.

    `update()` takes one checkpoint's {image_name: score array} dict; after
    each update `avg` holds the weighted mean over all checkpoints so far.

    Bug fix: the original `reset` bound `val`, `sum` and `avg` to the very
    same dict object passed in, so `average()` silently overwrote the
    running sum — the result was wrong whenever three or more checkpoints
    were combined.  `sum` and `avg` now use independent copies.
    """

    def __init__(self, inits):
        self.reset(inits)

    def reset(self, inits):
        """Start over from `inits` ({image_name: zero array})."""
        self.val = inits
        # Independent copies: sum and avg must never alias each other
        # (or the caller's dict).
        self.sum = {k: np.array(v, dtype=np.float64) for k, v in inits.items()}
        self.avg = {k: np.array(v, dtype=np.float64) for k, v in inits.items()}
        self.total_weight = 0

    def update(self, val, w=1):
        """Fold in one checkpoint's score dict with weight `w`."""
        self.val = val
        self.sum_dict(w)
        self.total_weight += w
        self.average()

    def sum_dict(self, w):
        # accumulate w-weighted scores into the running sum
        for item in self.val.keys():
            self.sum[item] = self.sum[item] + self.val[item] * w

    def average(self):
        # avg = sum / total weight, per image
        for item in self.avg.keys():
            self.avg[item] = self.sum[item] / self.total_weight
247+
248+
# Running average of softmax scores over all selected checkpoints,
# keyed by image name and started from all-zero vectors.
image_names = [item['image_id'] for item in label_raw_test]
inits = {}
for name in image_names:
    inits[name] = np.zeros(30)
aug_softmax_multi = Average_Softmax(inits)


# Evaluate each checkpoint, dump its individual CSV, and fold its softmax
# scores into the running average.
for i in multi_checks:
    i_checkpoint = torch.load(i)
    print(i)
    # model_conv was already DataParallel-wrapped above, so only the
    # weights need to be swapped in here (hence the commented-out wrapping).
    if arch.lower().startswith('alexnet') or arch.lower().startswith('vgg'):
        #model_conv.features = nn.DataParallel(model_conv.features)
        #model_conv.cuda()
        model_conv.load_state_dict(i_checkpoint['state_dict'])
    else:
        #model_conv = nn.DataParallel(model_conv).cuda()
        model_conv.load_state_dict(i_checkpoint['state_dict'])
    aug_softmax = test_model(model_conv, criterion)
    write_to_csv(aug_softmax, i)
    aug_softmax_multi.update(aug_softmax)
268+
269+
'''
Output the fused (checkpoint-averaged) result and compute the fused loss.
'''
def cal_loss(aug_softmax, label_raw_test):
    """Mean cross-entropy of fused probabilities against the true labels.

    aug_softmax: {image_id: length-30 probability array};
    label_raw_test: list of {'image_id', 'label_id'} dicts.

    Prints the loss and also returns it (backward-compatible addition:
    the original printed only), so callers can log or compare fusions.
    """
    loss1 = 0.0
    for row in label_raw_test:
        # negative log-probability assigned to the true class
        loss1 -= np.log(aug_softmax[row['image_id']][row['label_id']])
    loss1 /= len(label_raw_test)
    print('Loss@1 {loss1:.6f}'.format(loss1=loss1))
    return loss1
278+
# Dump the averaged probabilities and report the cross-entropy of the fusion.
write_to_csv(aug_softmax_multi.avg)
cal_loss(aug_softmax_multi.avg, label_raw_test)

train_pig.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,6 @@
2626
from hyperboard import Agent
2727
from params import *
2828

29-
use_gpu = torch.cuda.is_available()
30-
31-
TRAIN_ROOT = 'data/validation_folder/'#'data/videos_folder/'
32-
VALIDATION_ROOT = 'data/validation_folder/'
33-
3429

3530
def effect(alist):
3631
temp = 0
@@ -209,6 +204,14 @@ def run():
209204
# remember best
210205
is_best = loss1 <= best_loss1
211206
best_loss1 = min(loss1, best_loss1)
207+
if epoch % save_freq == 0:
208+
save_checkpoint_epoch({
209+
'epoch': epoch + 1,
210+
'arch': arch,
211+
'state_dict': model.state_dict(),
212+
'best_prec3': best_prec3,
213+
'loss1': loss1
214+
}, epoch+1)
212215
save_checkpoint({
213216
'epoch': epoch + 1,
214217
'arch': arch,
@@ -413,6 +416,8 @@ def save_checkpoint(state, is_best):
413416
if is_best:
414417
shutil.copyfile(latest_check, best_check)
415418

416-
419+
def save_checkpoint_epoch(state, epoch):
    """Persist a per-epoch snapshot as checkpoint/<name>_<epoch>.pth.tar."""
    target = 'checkpoint/{}_{}.pth.tar'.format(checkpoint_filename, epoch)
    torch.save(state, target)
421+
417422
if __name__ == '__main__':
418423
run()

0 commit comments

Comments
 (0)