|
61 | 61 | from tensorflow.contrib.tpu.python.tpu import tpu
|
62 | 62 | from tensorflow.contrib.tpu.python.tpu import tpu_function
|
63 | 63 | from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
|
| 64 | +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
64 | 65 | from tensorflow.core.protobuf import config_pb2
|
65 | 66 | from tensorflow.python.client import session as tf_session
|
66 | 67 | from tensorflow.python.data.ops import dataset_ops
|
|
80 | 81 | from tensorflow.python.ops import random_ops
|
81 | 82 | from tensorflow.python.ops import variable_scope
|
82 | 83 | from tensorflow.python.platform import tf_logging as logging
|
83 | | -from tensorflow.python.util import tf_inspect
84 | 84 |
|
85 | 85 |
|
86 | 86 | _SESSIONS = {}
|
@@ -110,31 +110,52 @@ def reset_tpu_sessions():
|
110 | 110 | _SESSIONS.clear()
|
111 | 111 |
|
112 | 112 |
|
113 | | -# Work-around dependency cycle between DistributionStrategy and TPU lib.
114 | | -def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None):  # pylint: disable=invalid-name
115 | | -  """Construct a TPUDistributionStrategy."""
116 | | -  from tensorflow.contrib.distribute.python import tpu_strategy  # pylint: disable=g-import-not-at-top
117 | | -  # TODO(b/112705069): Remove this when TPUStrategy API is consistent.
118 | | -  # We are including this for (a) backwards compatibility for open sourced
119 | | -  # releases of TensorFlow and (b) to work around a circular dependency
120 | | -  # where keras_support and tpu_strategy depends on each other. Once we release
121 | | -  # a final version and remove support for the old API, this will be deleted.
122 | | -  # (See bug above for more details)
123 | | -  if tpu_cluster_resolver is None:
124 | | -    tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
125 | | -
126 | | -  args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
127 | | -  if len(args) == 4:
128 | | -    logging.info('Detected new TPUStrategy API.')
129 | | -    return tpu_strategy.TPUStrategy(tpu_cluster_resolver,
130 | | -                                    steps_per_run=1,
131 | | -                                    num_cores=num_cores)
132 | | -  else:
133 | | -    logging.info('Detected old TPUStrategy API.')
134 | | -    strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
135 | | -    strategy._tpu_cluster_resolver = tpu_cluster_resolver
| 113 | +def get_tpu_system_metadata(tpu_cluster_resolver):
| 114 | +  """Retrieves TPU system metadata given a TPUClusterResolver."""
| 115 | +  master = tpu_cluster_resolver.master()
| 116 | +
| 117 | +  # pylint: disable=protected-access
| 118 | +  cluster_spec = tpu_cluster_resolver.cluster_spec()
| 119 | +  cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
| 120 | +  tpu_system_metadata = (
| 121 | +      tpu_system_metadata_lib._query_tpu_system_metadata(
| 122 | +          master,
| 123 | +          cluster_def=cluster_def,
| 124 | +          query_topology=False))
| 125 | +
| 126 | +  return tpu_system_metadata
| 127 | +
| 128 | +
| 129 | +class TPUDistributionStrategy(object):
| 130 | +  """The strategy to run a Keras model on TPU."""
| 131 | +
| 132 | +  def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
| 133 | +    """Construct a TPUDistributionStrategy.
| 134 | +
| 135 | +    Args:
| 136 | +      tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, one
| 137 | +        will be created with '' as the master address.
| 138 | +      using_single_core: Bool. A debugging option, which might be removed in
| 139 | +        the future once the model replication functionality is mature enough.
| 140 | +        If `False` (default behavior), the system automatically finds the best
| 141 | +        configuration, in terms of number of TPU cores, for the model
| 142 | +        replication, typically using all available TPU cores. If `True`,
| 143 | +        forces the model replication onto a single core, i.e., no replication.
| 144 | +    """
136 | 146 |
|
137 | | -  return strategy
| 147 | +    if tpu_cluster_resolver is None:
| 148 | +      tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
| 149 | +
| 150 | +    num_cores = (1 if using_single_core else
| 151 | +                 get_tpu_system_metadata(tpu_cluster_resolver).num_cores)
| 152 | +
| 153 | +    self._tpu_cluster_resolver = tpu_cluster_resolver
| 154 | +    self._num_cores = num_cores
| 155 | +
| 156 | +  @property
| 157 | +  def num_towers(self):
| 158 | +    return self._num_cores
138 | 159 |
|
139 | 160 |
|
140 | 161 | class TPUEmbedding(embeddings.Embedding):
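The new constructor resolves the replica count eagerly from the TPU system metadata instead of inspecting the `TPUStrategy.__init__` signature at call time. A minimal usage sketch of the changed API follows; it assumes the contrib-era import paths and a reachable TPU worker (the `grpc://10.0.0.1:8470` address is a placeholder):

```python
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.tpu.python.tpu import keras_support

# Placeholder endpoint; on Cloud TPU you would typically pass the TPU name.
resolver = TPUClusterResolver(tpu='grpc://10.0.0.1:8470')

# Default: replicate over every core reported by the system metadata.
# Constructing the strategy queries the TPU system, so the worker must be up.
strategy = keras_support.TPUDistributionStrategy(resolver)
print(strategy.num_towers)  # e.g. 8 on a single Cloud TPU v2 device

# Debugging escape hatch: run on one core with no replication.
single_core = keras_support.TPUDistributionStrategy(
    resolver, using_single_core=True)
assert single_core.num_towers == 1
```

Note that `num_towers` is now a plain property backed by the metadata query, so downstream code no longer needs to branch on old versus new `TPUStrategy` APIs.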
|
@@ -1212,5 +1233,10 @@ def tpu_model(model, strategy=None):
|
1212 | 1233 |
|
1213 | 1234 | if strategy is None:
|
1214 | 1235 | strategy = TPUDistributionStrategy()
|
| 1236 | +  else:
| 1237 | +    if not isinstance(strategy, TPUDistributionStrategy):
| 1238 | +      raise TypeError(
| 1239 | +          '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
| 1240 | +          'Got: {}'.format(type(strategy)))
1215 | 1241 |
|
1216 | 1242 | return KerasTPUModel(cpu_model=model, strategy=strategy)
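With the stricter check, `tpu_model` now fails fast on anything that is not a `TPUDistributionStrategy` instead of misbehaving later inside `KerasTPUModel`. A hedged sketch of the call site; it assumes the public `tf.contrib.tpu.keras_to_tpu_model` wrapper around `tpu_model` and reuses the placeholder TPU address from the earlier example:

```python
import tensorflow as tf

# A small Keras model built on the CPU side.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.train.GradientDescentOptimizer(0.1), loss='mse')

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.1:8470')  # placeholder endpoint
strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)

# OK: `strategy` is a TPUDistributionStrategy instance.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

# Raises TypeError under the new check: a bare string is not a strategy.
# tf.contrib.tpu.keras_to_tpu_model(model, strategy='tpu')
```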
|