Skip to content

Commit 1dd44f3

Browse files
Fixes for shared queue read functions:
. fix typo in call . re-export the functions so that they be accessed from outside code . do not create a new graph for the dummy placement variable. The new graph doesn't have the device setter applied, so the variable was not actually placed on the PS. Change: 137745823
1 parent db74f71 commit 1dd44f3

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

tensorflow/contrib/learn/python/learn/learn_io/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from tensorflow.contrib.learn.python.learn.learn_io.dask_io import extract_dask_data
2323
from tensorflow.contrib.learn.python.learn.learn_io.dask_io import extract_dask_labels
2424
from tensorflow.contrib.learn.python.learn.learn_io.dask_io import HAS_DASK
25+
from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_examples_shared_queue
26+
from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_features_shared_queue
2527
from tensorflow.contrib.learn.python.learn.learn_io.graph_io import queue_parsed_features
2628
from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_batch_examples
2729
from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_batch_features

tensorflow/contrib/learn/python/learn/learn_io/graph_io.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):
239239
# Creating a dummy variable so we can put the shared queue in ps if there is
240240
# a PS and in the worker otherwise. TODO(rohanj): Figure out how to place an
241241
# op on PS without this hack
242-
with ops.Graph().as_default():
243-
dummy_var = var_ops.Variable(initial_value=0, name='dummy_var')
242+
dummy_var = var_ops.Variable(initial_value=0, name='queue_placement_var')
244243
with ops.device(dummy_var.device):
245244
shared_file_name_queue = input_ops.string_input_producer(
246245
constant_op.constant(
@@ -561,7 +560,7 @@ def _read_keyed_batch_features_shared_queue(file_pattern,
561560
"""
562561

563562
with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope:
564-
keys, examples = read_keyed_batch_examples_shared_queue(
563+
keys, examples = _read_keyed_batch_examples_shared_queue(
565564
file_pattern,
566565
batch_size,
567566
reader,

0 commit comments

Comments
 (0)