Skip to content

Commit 4dcd000

Browse files
Jianwei Xie authored and tensorflower-gardener committed
Step 1 toward automatically assigning # of TPU cores.
Also fix the API, as `num_cores` is not a good API. Changed to a debugging-oriented argument, `using_single_core`. Ideally we should not need that in the future. The updated LSTM example cannot work with TF 1.10 anymore. PiperOrigin-RevId: 210632592
1 parent bb0f1e9 commit 4dcd000

File tree

1 file changed

+51
-25
lines changed

1 file changed

+51
-25
lines changed

tensorflow/contrib/tpu/python/tpu/keras_support.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from tensorflow.contrib.tpu.python.tpu import tpu
6262
from tensorflow.contrib.tpu.python.tpu import tpu_function
6363
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
64+
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
6465
from tensorflow.core.protobuf import config_pb2
6566
from tensorflow.python.client import session as tf_session
6667
from tensorflow.python.data.ops import dataset_ops
@@ -80,7 +81,6 @@
8081
from tensorflow.python.ops import random_ops
8182
from tensorflow.python.ops import variable_scope
8283
from tensorflow.python.platform import tf_logging as logging
83-
from tensorflow.python.util import tf_inspect
8484

8585

8686
_SESSIONS = {}
@@ -110,31 +110,52 @@ def reset_tpu_sessions():
110110
_SESSIONS.clear()
111111

112112

113-
# Work-around dependency cycle between DistributionStrategy and TPU lib.
114-
def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name
115-
"""Construct a TPUDistributionStrategy."""
116-
from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
117-
# TODO(b/112705069): Remove this when TPUStrategy API is consistent.
118-
# We are including this for (a) backwards compatibility for open sourced
119-
# releases of TensorFlow and (b) to work around a circular dependency
120-
# where keras_support and tpu_strategy depends on each other. Once we release
121-
# a final version and remove support for the old API, this will be deleted.
122-
# (See bug above for more details)
123-
if tpu_cluster_resolver is None:
124-
tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
125-
126-
args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
127-
if len(args) == 4:
128-
logging.info('Detected new TPUStrategy API.')
129-
return tpu_strategy.TPUStrategy(tpu_cluster_resolver,
130-
steps_per_run=1,
131-
num_cores=num_cores)
132-
else:
133-
logging.info('Detected old TPUStrategy API.')
134-
strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
135-
strategy._tpu_cluster_resolver = tpu_cluster_resolver
113+
def get_tpu_system_metadata(tpu_cluster_resolver):
114+
"""Retrieves TPU system metadata given a TPUClusterResolver."""
115+
master = tpu_cluster_resolver.master()
116+
117+
# pylint: disable=protected-access
118+
cluster_spec = tpu_cluster_resolver.cluster_spec()
119+
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
120+
tpu_system_metadata = (
121+
tpu_system_metadata_lib._query_tpu_system_metadata(
122+
master,
123+
cluster_def=cluster_def,
124+
query_topology=False))
125+
126+
return tpu_system_metadata
127+
128+
129+
class TPUDistributionStrategy(object):
130+
"""The strategy to run Keras model on TPU."""
131+
132+
def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
133+
"""Construct a TPUDistributionStrategy.
134+
135+
Args:
136+
tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will
137+
create one with '' as master address.
138+
using_single_core: Bool. This is the debugging option, which might be
139+
removed in future once the model replication functionality is mature
140+
enough. If `False` (default behavior), the system automatically finds
141+
the best configuration, in terms of number of TPU cores, for the model
142+
replication, typically using all avaiable TPU cores. If overwrites as
143+
`True`, force the model replication using single core, i.e., no
144+
replication.
145+
"""
136146

137-
return strategy
147+
if tpu_cluster_resolver is None:
148+
tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
149+
150+
num_cores = (1 if using_single_core else
151+
get_tpu_system_metadata(tpu_cluster_resolver).num_cores)
152+
153+
self._tpu_cluster_resolver = tpu_cluster_resolver
154+
self._num_cores = num_cores
155+
156+
@property
157+
def num_towers(self):
158+
return self._num_cores
138159

139160

140161
class TPUEmbedding(embeddings.Embedding):
@@ -1212,5 +1233,10 @@ def tpu_model(model, strategy=None):
12121233

12131234
if strategy is None:
12141235
strategy = TPUDistributionStrategy()
1236+
else:
1237+
if not isinstance(strategy, TPUDistributionStrategy):
1238+
raise TypeError(
1239+
'`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
1240+
'Got: {}'.format(type(strategy)))
12151241

12161242
return KerasTPUModel(cpu_model=model, strategy=strategy)

0 commit comments

Comments (0)