Skip to content

Commit 708daf0

Browse files
author
lijc08
committed
权重7-
1 parent 37ed2da commit 708daf0

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ MaysaFaceboxesDiscROC.txt
1313
weights_*
1414

1515
face-eval/*.pdf
16-
weight_730
16+
weight_730
17+
18+
models_*

data/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@
99
'aspect_ratios': [[1], [1], [1]],
1010
'variance': [0.1, 0.2],
1111
'clip': False,
12-
'loc_weight': 2.0
12+
# 'loc_weight': 2.0
13+
'loc_weight': 1.0
1314
}

train.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# coding=utf-8
22
from __future__ import print_function
3+
4+
import logging
35
import os
46

57
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -24,7 +26,7 @@
2426
parser.add_argument('--ngpu', default=1, type=int, help='gpus')
2527
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
2628
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
27-
parser.add_argument('--resume_net', default=None, help='resume net for retraining')
29+
parser.add_argument('--resume_net', default="./weights/FaceBoxes_epoch_295.pth", help='resume net for retraining')
2830
parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining')
2931
parser.add_argument('-max', '--max_epoch', default=300, type=int, help='max epoch for retraining')
3032
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
@@ -71,7 +73,7 @@
7173
cudnn.benchmark = True
7274

7375
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
74-
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)
76+
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 3, 0.35, False)
7577

7678
priorbox = PriorBox(cfg)
7779
with torch.no_grad():
@@ -81,16 +83,29 @@
8183

8284

8385
def train():
86+
prefix = time.strftime("%Y-%m-%d-%H:%M:%S")
87+
file_path = "models_{}".format(prefix)
88+
if not os.path.exists(file_path):
89+
os.mkdir(file_path)
90+
logging.basicConfig()
91+
logging.getLogger().setLevel(logging.INFO)
92+
fh = logging.FileHandler("{}/train.log".format(file_path))
93+
# create formatter#
94+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
95+
# add formatter to ch
96+
fh.setFormatter(formatter)
97+
logging.getLogger().addHandler(fh)
98+
8499
net.train()
85100
epoch = 0 + args.resume_epoch
86-
print('Loading Dataset...')
101+
logging.info('Loading Dataset...')
87102

88103
args.training_dataset = os.path.expanduser(args.training_dataset)
89104
dataset = VOCDetection(args.training_dataset, preproc(img_dim, rgb_means), AnnotationTransform())
90105

91-
print("len(dataset):", len(dataset))
106+
logging.info("len(dataset): %s", len(dataset))
92107
epoch_size = int(math.ceil(len(dataset) / args.batch_size))
93-
print("epoch_size:", epoch_size)
108+
logging.info("epoch_size: %s", epoch_size)
94109
max_iter = args.max_epoch * epoch_size
95110

96111
stepvalues = (200 * epoch_size, 250 * epoch_size)
@@ -106,7 +121,7 @@ def train():
106121
# create batch iterator
107122
batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=detection_collate))
108123
if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > 200):
109-
torch.save(net.state_dict(), args.save_folder + 'FaceBoxes_epoch_' + repr(epoch) + '.pth')
124+
torch.save(net.state_dict(), file_path + 'FaceBoxes_epoch_' + repr(epoch) + '.pth')
110125
epoch += 1
111126

112127
load_t0 = time.time()
@@ -133,10 +148,12 @@ def train():
133148
loss.backward()
134149
optimizer.step()
135150
load_t1 = time.time()
136-
print('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size) +
137-
'|| Totel iter ' + repr(iteration) + ' || L: %.4f C: %.4f||' % (cfg['loc_weight'] * loss_l.item(), loss_c.item()) +
138-
'Batch time: %.4f sec. ||' % (load_t1 - load_t0) + 'LR: %.8f' % (lr))
151+
# logging.info('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size) +
152+
# '|| Totel iter ' + repr(iteration) + ' || L: %.4f C: %.4f||' % (cfg['loc_weight'] * loss_l.item(), loss_c.item()) +
153+
# 'Batch time: %.4f sec. ||' % (load_t1 - load_t0) + 'LR: %.8f' % (lr))
139154

155+
logging.info("epoch %s epochiter %s epoch_size %s iteration %s loss %s l_loss %s c_loss %s",
156+
epoch, iteration % epoch_size, epoch_size, iteration, loss.item(), loss_l.item(), loss_c.item())
140157
torch.save(net.state_dict(), args.save_folder + 'Final_FaceBoxes.pth')
141158

142159

0 commit comments

Comments
 (0)