
Commit c25124d

add Resnet ce (PaddlePaddle#2502)

* add ce for dygraph mnist
* del mnist_dygraph.py
* change mnist_dygraph to train
* fix print style
* add resnet
* fix ce bug
* fix ce description

1 parent 4bb42e2 commit c25124d

File tree: 3 files changed, 100 additions and 6 deletions

* dygraph/resnet/.run_ce.sh
* dygraph/resnet/_ce.py
* dygraph/resnet/train.py

dygraph/resnet/.run_ce.sh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+# This file is only used for continuous evaluation.
+# dygraph single card
+export FLAGS_cudnn_deterministic=True
+export CUDA_VISIBLE_DEVICES=0
+python train.py --ce --epoch 1 --batch_size 128 | python _ce.py
+
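The pipe in the last line is the whole CE contract: with --ce, train.py writes tab-separated KPI records to stdout and _ce.py reads them back from stdin. A minimal sketch of what travels through that pipe (the numeric values below are invented for illustration, not real results):

# Sketch of the stdout format train.py emits under --ce and _ce.py consumes.
# _ce.py keeps only lines that split into exactly three tab-separated fields
# with 'kpis' as the first field; all other output is ignored.
print("kpis\ttrain_acc1\t0.123")
print("kpis\ttrain_acc5\t0.456")
print("kpis\ttrain_loss\t5.678")
print("epoch 0 | batch step 10, loss 5.678")  # ordinary progress line, skipped by _ce.py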

dygraph/resnet/_ce.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+#### This file is only used for the continuous evaluation test!
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import sys
+sys.path.append(os.environ['ceroot'])
+from kpi import CostKpi, DurationKpi, AccKpi
+
+#### NOTE: kpi.py should be shared across models in some way!
+
+train_acc1 = AccKpi('train_acc1', 0.01, 0, actived=True, desc="train acc1")
+train_acc5 = AccKpi('train_acc5', 0.01, 0, actived=True, desc="train acc5")
+train_loss = CostKpi('train_loss', 0.01, 0, actived=True, desc="train loss")
+test_acc1 = AccKpi('test_acc1', 0.01, 0, actived=True, desc='test acc1')
+test_acc5 = AccKpi('test_acc5', 0.01, 0, actived=True, desc='test acc5')
+test_loss = CostKpi('test_loss', 0.01, 0, actived=True, desc='test loss')
+#train_speed_kpi = DurationKpi(
+#    'train_speed',
+#    0.05,
+#    0,
+#    actived=True,
+#    unit_repr='seconds/image',
+#    desc='train speed in one GPU card')
+tracking_kpis = [train_acc1, train_acc5, train_loss,
+                 test_acc1, test_acc5, test_loss]
+
+def parse_log(log):
+    '''
+    This method should be implemented by model developers.
+
+    The suggestion:
+
+    each line in the log should be a tab-separated key/value record,
+    for example:
+
+    "
+    kpis\ttrain_cost\t1.0
+    kpis\ttest_cost\t1.0
+    kpis\ttrain_acc\t1.2
+    "
+    '''
+    for line in log.split('\n'):
+        fs = line.strip().split('\t')
+        print(fs)
+        if len(fs) == 3 and fs[0] == 'kpis':
+            print("-----%s" % fs)
+            kpi_name = fs[1]
+            kpi_value = float(fs[2])
+            yield kpi_name, kpi_value
+
+
+def log_to_ce(log):
+    kpi_tracker = {}
+    for kpi in tracking_kpis:
+        kpi_tracker[kpi.name] = kpi
+
+    for (kpi_name, kpi_value) in parse_log(log):
+        print(kpi_name, kpi_value)
+        kpi_tracker[kpi_name].add_record(kpi_value)
+        kpi_tracker[kpi_name].persist()
+
+
+if __name__ == '__main__':
+    log = sys.stdin.read()
+    print("*****")
+    print(log)
+    print("****")
+    log_to_ce(log)
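As a quick sanity check of parse_log above, here is a hedged usage sketch. The sample log text is fabricated, and this assumes parse_log is importable (i.e. the CE environment with the kpi module and the ceroot variable is available); only the tab-separated lines beginning with 'kpis' are yielded:

# Hypothetical usage of parse_log; the log content below is made up.
sample_log = ("epoch 0 | batch step 10, loss 6.812 acc1 0.012 acc5 0.051\n"
              "kpis\ttrain_acc1\t0.012\n"
              "kpis\ttrain_loss\t6.812\n")
for name, value in parse_log(sample_log):
    # Yields ('train_acc1', 0.012), then ('train_loss', 6.812); the ordinary
    # progress line is skipped. parse_log also prints each split line as it goes.
    print(name, value)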

dygraph/resnet/train.py

Lines changed: 22 additions & 6 deletions
@@ -26,27 +26,28 @@
 import math
 import sys

-batch_size = 32
-epoch = 120
 IMAGENET1000 = 1281167
 base_lr = 0.1
 momentum_rate = 0.9
 l2_decay = 1e-4


 def parse_args():
-    parser = argparse.ArgumentParser("Training for Mnist.")
+    parser = argparse.ArgumentParser("Training for Resnet.")
     parser.add_argument(
         "--use_data_parallel",
         type=ast.literal_eval,
         default=False,
         help="The flag indicating whether to shuffle instances in each pass.")
+    parser.add_argument("-e", "--epoch", default=120, type=int, help="set epoch")
+    parser.add_argument("-b", "--batch_size", default=32, type=int, help="set batch size")
+    parser.add_argument("--ce", action="store_true", help="run ce")
     args = parser.parse_args()
     return args


 args = parse_args()
-
+batch_size = args.batch_size

 def optimizer_setting():

@@ -263,16 +264,28 @@ def eval(model, data):
             print("test | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
                   ( batch_id, total_loss / total_sample, \
                    total_acc1 / total_sample, total_acc5 / total_sample))
+    if args.ce:
+        print("kpis\ttest_acc1\t%0.3f" % (total_acc1 / total_sample))
+        print("kpis\ttest_acc5\t%0.3f" % (total_acc5 / total_sample))
+        print("kpis\ttest_loss\t%0.3f" % (total_loss / total_sample))
     print("final eval loss %0.3f acc1 %0.3f acc5 %0.3f" % \
           (total_loss / total_sample, \
           total_acc1 / total_sample, total_acc5 / total_sample))


 def train_resnet():
+    epoch = args.epoch
     trainer_count = fluid.dygraph.parallel.Env().nranks
     place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
         if args.use_data_parallel else fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
+        if args.ce:
+            print("ce mode")
+            seed = 33
+            np.random.seed(seed)
+            fluid.default_startup_program().random_seed = seed
+            fluid.default_main_program().random_seed = seed
+
         if args.use_data_parallel:
             strategy = fluid.dygraph.parallel.prepare_context()

@@ -340,24 +353,27 @@ def train_resnet():
                 optimizer.minimize(avg_loss)
                 resnet.clear_gradients()

-                framework._dygraph_tracer_._clear_ops()

                 total_loss += dy_out
                 total_acc1 += acc_top1.numpy()
                 total_acc5 += acc_top5.numpy()
                 total_sample += 1
-
                 #print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
                 if batch_id % 10 == 0:
                     print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
                            ( eop, batch_id, total_loss / total_sample, \
                             total_acc1 / total_sample, total_acc5 / total_sample))

+            if args.ce:
+                print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
+                print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
+                print("kpis\ttrain_loss\t%0.3f" % (total_loss / total_sample))
             print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
                   (eop, batch_id, total_loss / total_sample, \
                    total_acc1 / total_sample, total_acc5 / total_sample))
             resnet.eval()
             eval(resnet, test_reader)
+            fluid.dygraph.save_persistables(resnet.state_dict(), 'resnet_params')


 if __name__ == '__main__':
