Skip to content

Commit 09a32f3

Browse files
derekjchowsguada
authored andcommitted
Update slim/ (tensorflow#2307)
1 parent 42f507f commit 09a32f3

File tree

5 files changed

+39
-33
lines changed

5 files changed

+39
-33
lines changed

slim/export_inference_graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def with the variables inlined as constants using:
6262
from datasets import dataset_factory
6363
from nets import nets_factory
6464

65+
6566
slim = tf.contrib.slim
6667

6768
tf.app.flags.DEFINE_string(

slim/nets/nets_factory_test.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,19 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
2223
import tensorflow as tf
2324

2425
from nets import nets_factory
2526

26-
slim = tf.contrib.slim
27-
2827

2928
class NetworksTest(tf.test.TestCase):
3029

31-
def testGetNetworkFn(self):
30+
def testGetNetworkFnFirstHalf(self):
3231
batch_size = 5
3332
num_classes = 1000
34-
for net in nets_factory.networks_map:
35-
with self.test_session():
33+
for net in nets_factory.networks_map.keys()[:10]:
34+
with tf.Graph().as_default() as g, self.test_session(g):
3635
net_fn = nets_factory.get_network_fn(net, num_classes)
3736
# Most networks use 224 as their default_image_size
3837
image_size = getattr(net_fn, 'default_image_size', 224)
@@ -43,19 +42,20 @@ def testGetNetworkFn(self):
4342
self.assertEqual(logits.get_shape().as_list()[0], batch_size)
4443
self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
4544

46-
def testGetNetworkFnArgScope(self):
45+
def testGetNetworkFnSecondHalf(self):
4746
batch_size = 5
48-
num_classes = 10
49-
net = 'cifarnet'
50-
with self.test_session(use_gpu=True):
51-
net_fn = nets_factory.get_network_fn(net, num_classes)
52-
image_size = getattr(net_fn, 'default_image_size', 224)
53-
with slim.arg_scope([slim.model_variable, slim.variable],
54-
device='/CPU:0'):
47+
num_classes = 1000
48+
for net in nets_factory.networks_map.keys()[10:]:
49+
with tf.Graph().as_default() as g, self.test_session(g):
50+
net_fn = nets_factory.get_network_fn(net, num_classes)
51+
# Most networks use 224 as their default_image_size
52+
image_size = getattr(net_fn, 'default_image_size', 224)
5553
inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
56-
net_fn(inputs)
57-
weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'CifarNet/conv1')[0]
58-
self.assertDeviceEqual('/CPU:0', weights.device)
54+
logits, end_points = net_fn(inputs)
55+
self.assertTrue(isinstance(logits, tf.Tensor))
56+
self.assertTrue(isinstance(end_points, dict))
57+
self.assertEqual(logits.get_shape().as_list()[0], batch_size)
58+
self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
5959

6060
if __name__ == '__main__':
6161
tf.test.main()

slim/preprocessing/inception_preprocessing.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def preprocess_for_train(image, height, width, bbox,
212212
num_resize_cases = 1 if fast_mode else 4
213213
distorted_image = apply_with_random_selector(
214214
distorted_image,
215-
lambda x, method: tf.image.resize_images(x, [height, width], method=method),
215+
lambda x, method: tf.image.resize_images(x, [height, width], method),
216216
num_cases=num_resize_cases)
217217

218218
tf.summary.image('cropped_resized_image',
@@ -248,7 +248,7 @@ def preprocess_for_eval(image, height, width,
248248
image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
249249
[0, 1], otherwise it would converted to tf.float32 assuming that the range
250250
is [0, MAX], where MAX is largest positive representable number for
251-
int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
251+
int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
252252
height: integer
253253
width: integer
254254
central_fraction: Optional Float, fraction of the image to crop.
@@ -282,7 +282,11 @@ def preprocess_image(image, height, width,
282282
"""Pre-process one image for training or evaluation.
283283
284284
Args:
285-
image: 3-D Tensor [height, width, channels] with the image.
285+
image: 3-D Tensor [height, width, channels] with the image. If dtype is
286+
tf.float32 then the range should be [0, 1], otherwise it would converted
287+
to tf.float32 assuming that the range is [0, MAX], where MAX is largest
288+
positive representable number for int(8/16/32) data type (see
289+
`tf.image.convert_image_dtype` for details).
286290
height: integer, image expected height.
287291
width: integer, image expected width.
288292
is_training: Boolean. If true it would transform an image for train,

slim/scripts/export_mobilenet.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# of the model, and the input image size, which can be 224, 192, 160, or 128
1313
# pixels, and affects the amount of computation needed, and the latency.
1414
# Here's an example generating a frozen model from pretrained weights:
15-
#
15+
#
1616

1717
set -e
1818

1919
print_usage () {
2020
echo "Creates a frozen mobilenet model suitable for mobile use"
2121
echo "Usage:"
2222
echo "$0 <mobilenet version> <input size> [checkpoint path]"
23-
}
23+
}
2424

2525
MOBILENET_VERSION=$1
2626
IMAGE_SIZE=$2

slim/train_image_classifier.py

100755100644
Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@
117117
'momentum', 0.9,
118118
'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
119119

120+
tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
121+
120122
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
121123

122124
#######################
@@ -301,14 +303,15 @@ def _configure_optimizer(learning_rate):
301303
optimizer = tf.train.RMSPropOptimizer(
302304
learning_rate,
303305
decay=FLAGS.rmsprop_decay,
304-
momentum=FLAGS.momentum,
306+
momentum=FLAGS.rmsprop_momentum,
305307
epsilon=FLAGS.opt_epsilon)
306308
elif FLAGS.optimizer == 'sgd':
307309
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
308310
else:
309311
raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
310312
return optimizer
311313

314+
312315
def _get_init_fn():
313316
"""Returns a function run by the chief worker to warm-start the training.
314317
@@ -450,20 +453,19 @@ def main(_):
450453
####################
451454
def clone_fn(batch_queue):
452455
"""Allows data parallelism by creating multiple clones of network_fn."""
453-
with tf.device(deploy_config.inputs_device()):
454-
images, labels = batch_queue.dequeue()
456+
images, labels = batch_queue.dequeue()
455457
logits, end_points = network_fn(images)
456458

457459
#############################
458460
# Specify the loss function #
459461
#############################
460462
if 'AuxLogits' in end_points:
461-
tf.losses.softmax_cross_entropy(
462-
logits=end_points['AuxLogits'], onehot_labels=labels,
463-
label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
464-
tf.losses.softmax_cross_entropy(
465-
logits=logits, onehot_labels=labels,
466-
label_smoothing=FLAGS.label_smoothing, weights=1.0)
463+
slim.losses.softmax_cross_entropy(
464+
end_points['AuxLogits'], labels,
465+
label_smoothing=FLAGS.label_smoothing, weights=0.4,
466+
scope='aux_loss')
467+
slim.losses.softmax_cross_entropy(
468+
logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
467469
return end_points
468470

469471
# Gather initial summaries.
@@ -515,10 +517,9 @@ def clone_fn(batch_queue):
515517
optimizer = tf.train.SyncReplicasOptimizer(
516518
opt=optimizer,
517519
replicas_to_aggregate=FLAGS.replicas_to_aggregate,
520+
total_num_replicas=FLAGS.worker_replicas,
518521
variable_averages=variable_averages,
519-
variables_to_average=moving_average_variables,
520-
replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
521-
total_num_replicas=FLAGS.worker_replicas)
522+
variables_to_average=moving_average_variables)
522523
elif FLAGS.moving_average_decay:
523524
# Update ops executed locally by trainer.
524525
update_ops.append(variable_averages.apply(moving_average_variables))

0 commit comments

Comments
 (0)