8
8
import tensorflow as tf
9
9
from tensorflow .contrib .slim .python .slim .nets import vgg
10
10
from tensorflow .contrib .slim .python .slim .nets import inception
11
+ from tensorflow .contrib .slim .python .slim .nets import resnet_v1
11
12
from tensorflow .contrib .slim .python .slim .nets import resnet_v2
12
13
from mmdnn .conversion .examples .imagenet_test import TestKit
13
14
14
15
slim = tf .contrib .slim
15
16
16
17
input_layer_map = {
17
- 'vgg16' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
18
- 'vgg19' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
19
- 'inception_v1' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
20
- 'inception_v2' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
21
- 'inception_v3' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
22
- 'resnet50' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
23
- 'resnet101' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
24
- 'resnet152' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
25
- 'resnet200' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
18
+ 'vgg16' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
19
+ 'vgg19' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
20
+ 'inception_v1' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
21
+ 'inception_v2' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
22
+ 'inception_v3' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
23
+ 'resnet50' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
24
+ 'resnet_v1_101' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 224 , 224 , 3 ]),
25
+ 'resnet101' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
26
+ 'resnet152' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
27
+ 'resnet200' : lambda : tf .placeholder (name = 'input' , dtype = tf .float32 , shape = [None , 299 , 299 , 3 ]),
26
28
}
27
29
28
30
arg_scopes_map = {
32
34
'inception_v2' : inception .inception_v3_arg_scope ,
33
35
'inception_v3' : inception .inception_v3_arg_scope ,
34
36
'resnet50' : resnet_v2 .resnet_arg_scope ,
37
+ 'resnet_v1_101' : resnet_v2 .resnet_arg_scope ,
35
38
'resnet101' : resnet_v2 .resnet_arg_scope ,
36
39
'resnet152' : resnet_v2 .resnet_arg_scope ,
37
40
'resnet200' : resnet_v2 .resnet_arg_scope ,
44
47
'inception_v1' : lambda : inception .inception_v1 ,
45
48
'inception_v2' : lambda : inception .inception_v2 ,
46
49
'inception_v3' : lambda : inception .inception_v3 ,
50
+ 'resnet_v1_101' : lambda : resnet_v1 .resnet_v1_101 ,
47
51
'resnet50' : lambda : resnet_v2 .resnet_v2_50 ,
48
52
'resnet101' : lambda : resnet_v2 .resnet_v2_101 ,
49
53
'resnet152' : lambda : resnet_v2 .resnet_v2_152 ,
@@ -65,11 +69,11 @@ def _main():
65
69
66
70
args = parser .parse_args ()
67
71
68
- num_classes = 1000 if args .network in ('vgg16' , 'vgg19' ) else 1001
72
+ num_classes = 1000 if args .network in ('vgg16' , 'vgg19' , 'resnet_v1_101' ) else 1001
69
73
70
74
with slim .arg_scope (arg_scopes_map [args .network ]()):
71
75
data_input = input_layer_map [args .network ]()
72
- logits , endpoints = networks_map [args .network ]()(data_input , num_classes = num_classes , is_training = False )
76
+ logits , endpoints = networks_map [args .network ]()(data_input , num_classes = num_classes , is_training = False )
73
77
labels = tf .squeeze (logits )
74
78
75
79
init = tf .global_variables_initializer ()
0 commit comments