Skip to content

Commit 06102cd

Browse files
Replace load_variable from checkpoint in gmm and kmeans as it will be deprecated soon.
Change: 130957334
1 parent 8b54edf commit 06102cd

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

tensorflow/contrib/factorization/python/ops/gmm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from tensorflow.contrib.learn.python.learn.estimators import estimator
3131
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
3232
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
33-
from tensorflow.contrib.learn.python.learn.utils import checkpoints
3433
from tensorflow.python.ops.control_flow_ops import with_dependencies
3534

3635

@@ -157,13 +156,13 @@ def transform(self, x, batch_size=None):
157156

158157
def clusters(self):
159158
"""Returns cluster centers."""
160-
clusters = checkpoints.load_variable(self.model_dir,
161-
gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
159+
clusters = tf.contrib.framework.load_variable(
160+
self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
162161
return np.squeeze(clusters, 1)
163162

164163
def covariances(self):
165164
"""Returns the covariances."""
166-
return checkpoints.load_variable(
165+
return tf.contrib.framework.load_variable(
167166
self.model_dir,
168167
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
169168

tensorflow/contrib/factorization/python/ops/kmeans.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
2929
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
3030
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
31-
from tensorflow.contrib.learn.python.learn.utils import checkpoints
3231
from tensorflow.python.ops.control_flow_ops import with_dependencies
3332

3433
SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
@@ -221,7 +220,7 @@ def transform(self, x, batch_size=None):
221220

222221
def clusters(self):
223222
"""Returns cluster centers."""
224-
return checkpoints.load_variable(self.model_dir, self.CLUSTERS)
223+
return tf.contrib.framework.load_variable(self.model_dir, self.CLUSTERS)
225224

226225
def _get_train_ops(self, features, _):
227226
(_,

0 commit comments

Comments
 (0)