Skip to content

Commit d5eb677

Browse files
qlzh727 authored and tensorflower-gardener committed
Cleanup keras distribution tests.
PiperOrigin-RevId: 345066746 Change-Id: Ibd5adc02e9b1ae7c9f4d640f6fa7ad8140de89ac
1 parent 0d4b7e2 commit d5eb677

File tree

4 files changed

+126
-115
lines changed

4 files changed

+126
-115
lines changed

tensorflow/python/keras/distribute/BUILD

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ py_library(
6969
":model_combinations",
7070
":multi_worker_testing_utils",
7171
":saved_model_test_base",
72+
":test_example",
7273
],
7374
)
7475

@@ -202,7 +203,6 @@ cuda_py_test(
202203
"//tensorflow/python/distribute:multi_worker_test_base",
203204
"//tensorflow/python/distribute:multi_worker_util",
204205
"//tensorflow/python/distribute:strategy_combinations",
205-
"//tensorflow/python/distribute:strategy_test_lib",
206206
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
207207
"//tensorflow/python/eager:context",
208208
"//tensorflow/python/keras:testing_utils",
@@ -601,10 +601,10 @@ distribute_py_test(
601601
"notsan",
602602
],
603603
deps = [
604+
":distribute_strategy_test_lib",
604605
":keras_test_lib",
605606
":optimizer_combinations",
606607
"//tensorflow/python/distribute:parameter_server_strategy",
607-
"//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
608608
],
609609
)
610610

@@ -654,6 +654,7 @@ distribute_py_test(
654654
],
655655
deps = [
656656
":optimizer_combinations",
657+
":test_example",
657658
"//tensorflow/python:control_flow_ops",
658659
"//tensorflow/python:control_flow_v2_toggles",
659660
"//tensorflow/python:framework_ops",
@@ -663,9 +664,7 @@ distribute_py_test(
663664
"//tensorflow/python/data/ops:dataset_ops",
664665
"//tensorflow/python/distribute:combinations",
665666
"//tensorflow/python/distribute:mirrored_strategy",
666-
"//tensorflow/python/distribute:single_loss_example",
667667
"//tensorflow/python/distribute:strategy_combinations",
668-
"//tensorflow/python/distribute:strategy_test_lib",
669668
"//tensorflow/python/eager:context",
670669
"//tensorflow/python/eager:test",
671670
"//tensorflow/python/ops/losses",
@@ -842,27 +841,6 @@ distribute_py_test(
842841
],
843842
)
844843

845-
distribute_py_test(
846-
name = "step_fn_test",
847-
srcs = ["step_fn_test.py"],
848-
main = "step_fn_test.py",
849-
tags = [
850-
"multi_and_single_gpu",
851-
],
852-
deps = [
853-
":optimizer_combinations",
854-
"//tensorflow/python:framework_test_lib",
855-
"//tensorflow/python:variables",
856-
"//tensorflow/python/distribute:combinations",
857-
"//tensorflow/python/distribute:single_loss_example",
858-
"//tensorflow/python/distribute:strategy_combinations",
859-
"//tensorflow/python/eager:context",
860-
"//tensorflow/python/eager:test",
861-
"//third_party/py/numpy",
862-
"@absl_py//absl/testing:parameterized",
863-
],
864-
)
865-
866844
tf_py_test(
867845
name = "parameter_server_training_test",
868846
srcs = ["parameter_server_training_test.py"],
@@ -962,3 +940,16 @@ py_library(
962940
"//tensorflow/python/distribute:strategy_combinations",
963941
],
964942
)
943+
944+
py_library(
945+
name = "test_example",
946+
srcs = ["test_example.py"],
947+
deps = [
948+
"//tensorflow/python:array_ops",
949+
"//tensorflow/python:constant_op",
950+
"//tensorflow/python:framework_ops",
951+
"//tensorflow/python:layers",
952+
"//tensorflow/python:math_ops",
953+
"//tensorflow/python/data/ops:dataset_ops",
954+
],
955+
)

tensorflow/python/keras/distribute/minimize_loss_test.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,19 @@
2525
from tensorflow.python.distribute import combinations as ds_combinations
2626
from tensorflow.python.distribute import reduce_util
2727
from tensorflow.python.distribute import strategy_combinations
28-
from tensorflow.python.distribute import strategy_test_lib
29-
from tensorflow.python.distribute.single_loss_example import batchnorm_example
30-
from tensorflow.python.distribute.single_loss_example import minimize_loss_example
3128
from tensorflow.python.eager import context
3229
from tensorflow.python.framework import constant_op
33-
from tensorflow.python.framework import dtypes
3430
from tensorflow.python.framework import ops
3531
from tensorflow.python.framework import test_combinations as combinations
3632
from tensorflow.python.keras.distribute import optimizer_combinations
33+
from tensorflow.python.keras.distribute.test_example import batchnorm_example
34+
from tensorflow.python.keras.distribute.test_example import minimize_loss_example
35+
from tensorflow.python.keras.layers import core
36+
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
3737
from tensorflow.python.ops import array_ops
3838
from tensorflow.python.ops import control_flow_ops
3939
from tensorflow.python.ops import control_flow_v2_toggles
4040
from tensorflow.python.ops import math_ops
41-
from tensorflow.python.ops import nn_ops
4241
from tensorflow.python.ops import variable_scope
4342
from tensorflow.python.ops import variables as variables_lib
4443
from tensorflow.python.ops.losses import losses_impl
@@ -212,7 +211,7 @@ def run_step():
212211
def get_expected_variables(num_parameter_devices):
213212
name = optimizer._name
214213

215-
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
214+
if isinstance(optimizer, optimizer_v2.OptimizerV2):
216215
variables = VAR_MAP_V2[name]
217216
else:
218217
variables = VAR_MAP_V1[name]
@@ -353,7 +352,7 @@ def loss_fn():
353352

354353
optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate
355354

356-
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
355+
if isinstance(optimizer, optimizer_v2.OptimizerV2):
357356
return optimizer.minimize(loss_fn, [w])
358357
else:
359358
if use_callable_loss:
@@ -430,25 +429,20 @@ def dataset_fn():
430429
return dataset.batch(batch_size=1, drop_remainder=True)
431430

432431
optimizer = optimizer_fn()
433-
kernel = strategy_test_lib.create_variable_like_keras_layer(
434-
"kernel", (1, 1), dtypes.float32)
435-
bias = strategy_test_lib.create_variable_like_keras_layer(
436-
"bias", (1,), dtypes.float32)
437-
# layer = core.Dense(1, use_bias=True)
432+
layer = core.Dense(1, use_bias=True)
438433

439434
key1 = "foo"
440435
value1 = "bar"
441436

442437
def model_fn(output_context, x):
443438
"""A very simple model written by the user."""
444439
def loss_fn():
445-
y = array_ops.reshape(nn_ops.bias_add(
446-
math_ops.matmul(x, kernel), bias), []) - constant_op.constant(1.)
440+
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
447441
return y * y
448442

449-
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
443+
if isinstance(optimizer, optimizer_v2.OptimizerV2):
450444
train_op = optimizer.minimize(
451-
loss_fn, lambda: [kernel, bias])
445+
loss_fn, lambda: layer.trainable_variables)
452446
else:
453447
train_op = optimizer.minimize(loss_fn)
454448
loss = loss_fn()
@@ -517,8 +511,8 @@ def run_step():
517511
for _ in range(5):
518512
_, loss = run_step()
519513
losses.append(loss)
520-
weights.append(self.evaluate(kernel))
521-
biases.append(self.evaluate(bias))
514+
weights.append(self.evaluate(layer.kernel))
515+
biases.append(self.evaluate(layer.bias))
522516

523517
loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:]))
524518
self.assertTrue(loss_is_not_increasing)

tensorflow/python/keras/distribute/step_fn_test.py

Lines changed: 0 additions & 72 deletions
This file was deleted.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""A simple network to use in tests and examples."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from tensorflow.python.data.ops import dataset_ops
22+
from tensorflow.python.framework import constant_op
23+
from tensorflow.python.framework import ops
24+
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
25+
from tensorflow.python.layers import core
26+
from tensorflow.python.layers import normalization
27+
from tensorflow.python.ops import array_ops
28+
from tensorflow.python.ops import math_ops
29+
30+
31+
def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
  """Example of non-distribution-aware legacy code."""

  def dataset_fn():
    """Returns an endlessly repeating dataset of a single [[1.]] example."""
    repeated = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
    # TODO(isaprykin): batch with drop_remainder causes shapes to be
    # fully defined for TPU. Remove this when XLA supports dynamic shapes.
    return repeated.batch(1, drop_remainder=True)

  layer = core.Dense(1, use_bias=use_bias)

  def model_fn(x):
    """A very simple model written by the user."""

    def loss_fn():
      # Squared error of the layer output against a constant target of 1.
      delta = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
      return delta * delta

    # V2 optimizers require the variable list (as a callable) explicitly.
    if isinstance(optimizer, optimizer_v2.OptimizerV2):
      return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
    # Legacy optimizers accept either a callable loss or a loss tensor.
    if use_callable_loss:
      return optimizer.minimize(loss_fn)
    return optimizer.minimize(loss_fn())

  return model_fn, dataset_fn, layer
57+
58+
59+
def batchnorm_example(optimizer_fn,
                      batch_per_epoch=1,
                      momentum=0.9,
                      renorm=False,
                      update_ops_in_replica_mode=False):
  """Example of non-distribution-aware legacy code with batch normalization."""

  def dataset_fn():
    # input shape is [16, 8], input values are increasing in both dimensions.
    data = [[[float(x * 8 + y + z * 100)
              for y in range(8)]
             for x in range(16)]
            for z in range(batch_per_epoch)]
    return dataset_ops.Dataset.from_tensor_slices(data).repeat()

  optimizer = optimizer_fn()
  batchnorm = normalization.BatchNormalization(
      renorm=renorm, momentum=momentum, fused=False)
  layer = core.Dense(1, use_bias=False)

  def model_fn(x):
    """A model that uses batchnorm."""

    def loss_fn():
      y = batchnorm(x, training=True)
      # Optionally run the batchnorm update ops before computing the loss,
      # emulating replica-mode update handling.
      update_deps = (
          ops.get_collection(ops.GraphKeys.UPDATE_OPS)
          if update_ops_in_replica_mode else [])
      with ops.control_dependencies(update_deps):
        loss = math_ops.reduce_mean(
            math_ops.reduce_sum(layer(y)) - constant_op.constant(1.))
      # `x` and `y` will be fetched by the gradient computation, but not `loss`.
      return loss

    # V2 optimizers require the variable list (as a callable) explicitly.
    if isinstance(optimizer, optimizer_v2.OptimizerV2):
      return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)

    # Callable loss.
    return optimizer.minimize(loss_fn)

  return model_fn, dataset_fn, batchnorm

0 commit comments

Comments
 (0)