#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
*Epoch:[0] Prec@1 99.384 Prec@3 100.000 Loss 0.5274
'''

import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import time
import json
from model import load_model
from config import data_transforms
import pickle
import csv
from params import *
import torchvision.datasets as td
import numpy as np
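
# Note (assumption, not stated in this file): names used below that are not defined
# here, such as arch, pretrained, BATCH_SIZE, use_gpu, input_size, train_scale,
# test_scale and INPUT_WORKERS, are assumed to come from params.py via the star import above.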

phases = ['val']
batch_size = BATCH_SIZE

if phases[0] == 'test_A':
    test_root = 'data/test_A'
elif phases[0] == 'test_B':
    test_root = 'data/test_B'
elif phases[0] == 'val':
    test_root = 'data/validation_folder_full'

checkpoint_filename = arch + '_' + pretrained
multi_checks = []
'''
Specify here which epochs' checkpoints are averaged.
'''
for epoch_check in ['1']:  # list of epochs, e.g. ['10', '20']
    multi_checks.append('checkpoint/' + checkpoint_filename + '_' + str(epoch_check) + '.pth.tar')

'''
This is the class order produced by ImageFolder.
'''
aaa = ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23',
       '24', '25', '26', '27', '28', '29', '3', '30', '4', '5', '6', '7', '8', '9']
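
# Sketch (assumption, not in the original code): torchvision.datasets.ImageFolder
# sorts class directory names lexicographically, so if the training folders are
# named '1'..'30', the order above is exactly sorted(str(i) for i in range(1, 31)).
assert aaa == sorted(str(i) for i in range(1, 31))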


best_check = 'checkpoint/' + checkpoint_filename + '_best.pth.tar'
model_conv = load_model(arch, pretrained, use_gpu=use_gpu, num_classes=30, AdaptiveAvgPool=AdaptiveAvgPool,
                        SPP=SPP, num_levels=num_levels, pool_type=pool_type, bilinear=bilinear, stage=stage,
                        SENet=SENet, se_stage=se_stage, se_layers=se_layers)
for param in model_conv.parameters():
    param.requires_grad = False  # inference only, so freezing parameters saves GPU memory

best_checkpoint = torch.load(best_check)
# For AlexNet/VGG only the feature extractor is wrapped in DataParallel; other
# architectures are wrapped as a whole, so the checkpoint state_dict keys line up.
if arch.lower().startswith('alexnet') or arch.lower().startswith('vgg'):
    model_conv.features = nn.DataParallel(model_conv.features)
    model_conv.cuda()
    model_conv.load_state_dict(best_checkpoint['state_dict'])
else:
    model_conv = nn.DataParallel(model_conv).cuda()
    model_conv.load_state_dict(best_checkpoint['state_dict'])


with open(test_root + '/pig_test_annotations.json', 'r') as f:  # label file; for 'val' this annotation file was generated by the author
    label_raw_test = json.load(f)
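
# Note (assumption, the schema is not shown in this file): judging from how
# SceneDataset and cal_loss index it, label_raw_test is expected to be a list of
# records such as [{"image_id": "0001.jpg", "label_id": 0}, ...], with label_id
# being the 0..29 class index in ImageFolder (aaa) order.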


def write_to_csv(aug_softmax, epoch_i=None):
    """Write one CSV row per (image, class): image id, class label 1..30, renormalized probability.

    aug_softmax maps image_id -> softmax vector in ImageFolder (aaa) order, so
    aaa.index(str(c + 1)) maps class label c + 1 back to its column.
    """
    if epoch_i is not None:
        file = 'result/' + phases[0] + '_1_' + epoch_i.split('.')[0].split('_')[-1] + '.csv'
    else:
        file = 'result/' + phases[0] + '_1.csv'
    with open(file, 'w', encoding='utf-8', newline='') as csvfile:
        spamwriter = csv.writer(csvfile, dialect='excel')
        for item in aug_softmax.keys():
            the_sum = sum(aug_softmax[item])
            for c in range(0, 30):
                if phases[0] != 'val':
                    spamwriter.writerow([int(item.split('.')[0]), c + 1, aug_softmax[item][aaa.index(str(c + 1))] / the_sum])
                else:
                    spamwriter.writerow([item, c + 1, aug_softmax[item][aaa.index(str(c + 1))] / the_sum])


class SceneDataset(Dataset):

    def __init__(self, json_labels, root_dir, transform=None):
        self.label_raw = json_labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.label_raw)

    def __getitem__(self, idx):
        # if phases[0] == 'val':
        #     img_name = self.root_dir + '/' + str(self.label_raw[idx]['label_id'] + 1) + '/' + self.label_raw[idx]['image_id']
        # else:
        img_name = os.path.join(self.root_dir, self.label_raw[idx]['image_id'])
        img_name_raw = self.label_raw[idx]['image_id']
        image = Image.open(img_name)
        label = self.label_raw[idx]['label_id']

        if self.transform:
            image = self.transform(image)

        return image, label, img_name_raw


transformed_dataset_test = SceneDataset(json_labels=label_raw_test,
                                        root_dir=test_root,
                                        transform=data_transforms('test', input_size, train_scale, test_scale))
dataloader = {phases[0]: DataLoader(transformed_dataset_test, batch_size=batch_size, shuffle=False, num_workers=INPUT_WORKERS)}
dataset_sizes = {phases[0]: len(label_raw_test)}


class AverageMeter(object):
    """Keeps a running average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.
    output: logits
    target: labels
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))

    pred_list = pred.tolist()  # [[14, 13], [72, 15], [74, 11]]
    return res, pred_list
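
# Worked example (illustrative values, not from this repo): with a batch of 2 and
# 3 classes,
#     output = torch.tensor([[0.1, 0.7, 0.2], [0.6, 0.1, 0.3]])
#     target = torch.tensor([1, 2])
# the top-1 predictions are [1, 0], so prec@1 = 50.0, while class 2 is inside the
# top-2 of the second sample, so prec@2 = 100.0.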


def test_model(model, criterion):
    since = time.time()

    mystep = 0

    for phase in phases:

        model.eval()  # Set model to evaluate mode

        top1 = AverageMeter()
        top3 = AverageMeter()
        loss1 = AverageMeter()
        aug_softmax = {}

        # Iterate over data.
        for data in dataloader[phase]:
            # get the inputs
            mystep = mystep + 1
            # if (mystep % 10 == 0):
            #     duration = time.time() - since
            #     print('step %d vs %d in %.0f s' % (mystep, total_steps, duration))

            inputs, labels, img_name_raw = data

            # wrap them in Variable
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
            else:
                inputs, labels = Variable(inputs), Variable(labels)

            # forward
            outputs = model(inputs)
            crop_softmax = nn.functional.softmax(outputs, dim=1)
            temp = crop_softmax.cpu().data.numpy()
            for item in range(len(img_name_raw)):
                # keyed by image id, so that loader workers cannot shuffle the order of the results
                aug_softmax[img_name_raw[item]] = temp[item, :]

            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)

            # statistics
            res, pred_list = accuracy(outputs.data, labels.data, topk=(1, 3))
            prec1 = res[0]
            prec3 = res[1]
            top1.update(prec1[0], inputs.size(0))
            top3.update(prec3[0], inputs.size(0))
            loss1.update(loss.data[0], inputs.size(0))

        print(' * Prec@1 {top1.avg:.6f} Prec@3 {top3.avg:.6f} Loss@1 {loss1.avg:.6f}'.format(top1=top1, top3=top3, loss1=loss1))

    return aug_softmax


criterion = nn.CrossEntropyLoss()


######################################################################
# val and test
# approximate number of mini-batches over all checkpoints; used by the
# commented-out progress print inside test_model
total_steps = 1.0 * len(label_raw_test) / batch_size * len(multi_checks)
print(total_steps)


class Average_Softmax(object):
    """Weighted running average of per-image softmax vectors across checkpoints.

    update() takes a dict mapping image_id -> softmax vector (as returned by
    test_model) and folds it into the running weighted average self.avg.
    """

    def __init__(self, inits):
        self.reset(inits)

    def reset(self, inits):
        self.val = inits
        # keep independent buffers: if sum and avg aliased the same dict, averaging
        # more than two checkpoints would rescale the accumulated sum in place
        self.sum = {k: np.array(v, dtype=np.float64) for k, v in inits.items()}
        self.avg = {k: np.array(v, dtype=np.float64) for k, v in inits.items()}
        self.total_weight = 0

    def update(self, val, w=1):
        self.val = val
        self.sum_dict(w)
        self.total_weight += w
        self.average()

    def sum_dict(self, w):
        for item in self.val.keys():
            self.sum[item] += (self.val[item] * w)

    def average(self):
        for item in self.avg.keys():
            self.avg[item] = self.sum[item] / self.total_weight
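
# Usage sketch (hypothetical values, not part of the original script): with two
# equally weighted checkpoints and one image, the result is the element-wise mean.
#     m = Average_Softmax({'a.jpg': np.zeros(3)})
#     m.update({'a.jpg': np.array([0.2, 0.5, 0.3])})
#     m.update({'a.jpg': np.array([0.4, 0.3, 0.3])})
#     m.avg['a.jpg']  # -> array([0.3, 0.4, 0.3])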


image_names = [item['image_id'] for item in label_raw_test]
inits = {}
for name in image_names:
    inits[name] = np.zeros(30)
aug_softmax_multi = Average_Softmax(inits)


for i in multi_checks:
    i_checkpoint = torch.load(i)
    print(i)
    # model_conv is already wrapped in DataParallel above, so only the weights are swapped here
    if arch.lower().startswith('alexnet') or arch.lower().startswith('vgg'):
        # model_conv.features = nn.DataParallel(model_conv.features)
        # model_conv.cuda()
        model_conv.load_state_dict(i_checkpoint['state_dict'])
    else:
        # model_conv = nn.DataParallel(model_conv).cuda()
        model_conv.load_state_dict(i_checkpoint['state_dict'])
    aug_softmax = test_model(model_conv, criterion)
    write_to_csv(aug_softmax, i)
    aug_softmax_multi.update(aug_softmax)


'''
Write out the ensembled result and compute the loss (and accuracy) after ensembling.
'''
def cal_loss(aug_softmax, label_raw_test):
    # cross-entropy of the averaged softmax against the ground-truth labels
    loss1 = 0
    for row in label_raw_test:
        loss1 -= np.log(aug_softmax[row['image_id']][row['label_id']])
    loss1 /= len(label_raw_test)
    print('Loss@1 {loss1:.6f}'.format(loss1=loss1))

write_to_csv(aug_softmax_multi.avg)
cal_loss(aug_softmax_multi.avg, label_raw_test)