Skip to content

Commit 56a1249

Browse files
committed
fix save image bug, add checkpoint@150epoch
1 parent 7e6912f commit 56a1249

File tree

6 files changed

+38
-38
lines changed

6 files changed

+38
-38
lines changed
Binary file not shown.

cityscapes.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,8 @@ def trainId2color(dataset_root, id_map, name):
129129
color_map[id_map == label.trainId] = np.array(label.color)
130130
color_map = color_map.astype(np.uint8)
131131

132-
# plot
133-
plt.imshow(color_map)
134-
# plt.show()
135-
136132
# save
137-
plt.savefig(dataset_root + '/' + name)
133+
cv2.imwrite(dataset_root + '/' + name, color_map)
138134

139135
return color_map
140136

config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def __init__(self):
2929

3030
# train parameters
3131
self.num_epoch = 150
32-
self.base_lr = 0.00025
32+
self.base_lr = 0.0002
3333
self.power = 0.9
3434
self.momentum = 0.9
3535
self.weight_decay = 0.0005
3636
self.should_val = True
37-
self.val_every = 1
37+
self.val_every = 2
3838
self.display = 1 # show train result every display epoch
3939
self.should_split = True # should split training procedure into several parts
4040
self.split = 2 # number of split
352 KB
Loading

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def main():
5252
LOG('Model Built.\n')
5353

5454
# let's start to train!
55-
# net.Train()
55+
net.Train()
5656
net.Test()
5757

5858
if __name__ == '__main__':

network.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -206,38 +206,40 @@ def Train(self):
206206
Train network in n epochs, n is defined in params.num_epoch
207207
"""
208208
self.init_epoch = self.epoch
209-
assert self.epoch <= self.params.num_epoch, 'Current epoch should not be larger than num_epoch'
210-
for _ in range(self.epoch, self.params.num_epoch):
211-
self.epoch += 1
212-
print('-' * 20 + 'Epoch.' + str(self.epoch) + '-' * 20)
209+
if self.epoch >= self.params.num_epoch:
210+
WARNING('Num_epoch should be smaller than current epoch. Skip training......\n')
211+
else:
212+
for _ in range(self.epoch, self.params.num_epoch):
213+
self.epoch += 1
214+
print('-' * 20 + 'Epoch.' + str(self.epoch) + '-' * 20)
213215

214-
# train one epoch
215-
self.train_one_epoch()
216+
# train one epoch
217+
self.train_one_epoch()
216218

217-
# should display
218-
if self.epoch % self.params.display == 0:
219-
print('\tTrain loss: %.4f' % self.train_loss[-1])
219+
# should display
220+
if self.epoch % self.params.display == 0:
221+
print('\tTrain loss: %.4f' % self.train_loss[-1])
220222

221-
# should save
222-
if self.params.should_save:
223-
if self.epoch % self.params.save_every == 0:
224-
self.save_checkpoint()
223+
# should save
224+
if self.params.should_save:
225+
if self.epoch % self.params.save_every == 0:
226+
self.save_checkpoint()
225227

226-
# test every params.test_every epoch
227-
if self.params.should_val:
228-
if self.epoch % self.params.val_every == 0:
229-
self.val_one_epoch()
230-
print('\tVal loss: %.4f' % self.val_loss[-1])
228+
# test every params.test_every epoch
229+
if self.params.should_val:
230+
if self.epoch % self.params.val_every == 0:
231+
self.val_one_epoch()
232+
print('\tVal loss: %.4f' % self.val_loss[-1])
231233

232-
# adjust learning rate
233-
self.adjust_lr()
234+
# adjust learning rate
235+
self.adjust_lr()
234236

235-
# save the last network state
236-
if self.params.should_save:
237-
self.save_checkpoint()
237+
# save the last network state
238+
if self.params.should_save:
239+
self.save_checkpoint()
238240

239-
# train visualization
240-
self.plot_curve()
241+
# train visualization
242+
self.plot_curve()
241243

242244
def Test(self):
243245
"""
@@ -276,8 +278,8 @@ def Test(self):
276278
image_orig = image[i].numpy().transpose(1, 2, 0)
277279
image_orig = image_orig*255
278280
image_orig = image_orig.astype(np.uint8)
279-
self.summary_writer.add_image('img_%d/orig' % idx, image_orig, idx)
280-
self.summary_writer.add_image('img_%d/seg' % idx, color_map, idx)
281+
self.summary_writer.add_image('test/img_%d/orig' % idx, image_orig, idx)
282+
self.summary_writer.add_image('test/img_%d/seg' % idx, color_map, idx)
281283

282284
"""##########################"""
283285
"""# Model Save and Restore #"""
@@ -367,9 +369,11 @@ def plot_curve(self):
367369
"""
368370
Plot train/val loss curve
369371
"""
370-
x = np.arange(self.init_epoch, self.params.num_epoch+1, dtype=np.int).tolist()
371-
plt.plot(x, self.train_loss, label='train_loss')
372-
plt.plot(x, self.val_loss, label='val_loss')
372+
x1 = np.arange(self.init_epoch, self.params.num_epoch+1, dtype=np.int).tolist()
373+
x2 = np.linspace(self.init_epoch, self.epoch,
374+
num=(self.epoch-self.init_epoch)//self.params.val_every+1, dtype=np.int64)
375+
plt.plot(x1, self.train_loss, label='train_loss')
376+
plt.plot(x2, self.val_loss, label='val_loss')
373377
plt.legend(loc='best')
374378
plt.title('Train/Val loss')
375379
plt.grid()

0 commit comments

Comments
 (0)