Skip to content

Commit 54efd63

Browse files
ebrevdotensorflower-gardener
authored andcommitted
Update RNN helpers to be able to handle dynamic state sizes.
This fixes a bug I introduced previously by adding the alignment into the AttentionWrapper's state (since the alignment's size may have to be a Tensor - the encoder's max_time is not usually static). PiperOrigin-RevId: 156077314
1 parent eef5ba5 commit 54efd63

File tree

6 files changed

+71
-47
lines changed

6 files changed

+71
-47
lines changed

tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ def testLearnShiftByOne(self):
525525
num_classes = 2
526526
num_unroll = 32
527527
sequence_length = 32
528-
train_steps = 200
528+
train_steps = 300
529529
eval_steps = 20
530530
num_units = [4]
531531
learning_rate = 0.5

tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,8 +942,8 @@ def SampledLoss(labels, logits):
942942
perplexities[bucket].append(math.exp(float(res[1])))
943943
for bucket in range(len(buckets)):
944944
if len(perplexities[bucket]) > 1: # Assert that perplexity went down.
945-
self.assertLess(perplexities[bucket][-1], # 10% margin of error.
946-
1.1 * perplexities[bucket][0])
945+
self.assertLess(perplexities[bucket][-1], # 20% margin of error.
946+
1.2 * perplexities[bucket][0])
947947

948948
def testModelWithBooleanFeedPrevious(self):
949949
"""Test the model behavior when feed_previous is True.

tensorflow/contrib/rnn/python/ops/core_rnn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
# pylint: disable=protected-access
34-
_state_size_with_prefix = rnn_cell_impl._state_size_with_prefix
34+
_concat = rnn_cell_impl._concat
3535
_infer_state_dtype = rnn._infer_state_dtype
3636
_reverse_seq = rnn._reverse_seq
3737
_rnn_step = rnn._rnn_step
@@ -159,11 +159,10 @@ def static_rnn(cell, inputs, initial_state=None, dtype=None,
159159
"sequence_length must be a vector of length batch_size")
160160
def _create_zero_output(output_size):
161161
# convert int to TensorShape if necessary
162-
size = _state_size_with_prefix(output_size, prefix=[batch_size])
162+
size = _concat(batch_size, output_size)
163163
output = array_ops.zeros(
164164
array_ops.stack(size), _infer_state_dtype(dtype, state))
165-
shape = _state_size_with_prefix(
166-
output_size, prefix=[fixed_batch_size.value])
165+
shape = _concat(fixed_batch_size.value, output_size, static=True)
167166
output.set_shape(tensor_shape.TensorShape(shape))
168167
return output
169168

tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def testStepWithGreedyEmbeddingHelper(self):
124124
vocabulary_size = 7
125125
cell_depth = vocabulary_size # cell's logits must match vocabulary size
126126
input_depth = 10
127-
start_tokens = [0] * batch_size
127+
start_tokens = np.random.randint(0, vocabulary_size, size=batch_size)
128128
end_token = 1
129129

130130
with self.test_session(use_gpu=True) as sess:

tensorflow/python/ops/rnn.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434

3535
# pylint: disable=protected-access
36-
_state_size_with_prefix = rnn_cell_impl._state_size_with_prefix
36+
_concat = rnn_cell_impl._concat
3737
# pylint: enable=protected-access
3838

3939

@@ -660,7 +660,7 @@ def _dynamic_rnn_loop(cell,
660660

661661
# Prepare dynamic conditional copying of state & output
662662
def _create_zero_arrays(size):
663-
size = _state_size_with_prefix(size, prefix=[batch_size])
663+
size = _concat(batch_size, size)
664664
return array_ops.zeros(
665665
array_ops.stack(size), _infer_state_dtype(dtype, state))
666666

@@ -746,8 +746,8 @@ def _time_step(time, output_ta_t, state):
746746

747747
# Restore some shape information
748748
for output, output_size in zip(final_outputs, flat_output_size):
749-
shape = _state_size_with_prefix(
750-
output_size, prefix=[const_time_steps, const_batch_size])
749+
shape = _concat(
750+
[const_time_steps, const_batch_size], output_size, static=True)
751751
output.set_shape(shape)
752752

753753
final_outputs = nest.pack_sequence_as(
@@ -981,9 +981,7 @@ def loop_fn(time, cell_output, cell_state, loop_state):
981981
emit_ta = nest.pack_sequence_as(structure=emit_structure,
982982
flat_sequence=flat_emit_ta)
983983
flat_zero_emit = [
984-
array_ops.zeros(
985-
_state_size_with_prefix(size_i, prefix=[batch_size]),
986-
dtype_i)
984+
array_ops.zeros(_concat(batch_size, size_i), dtype_i)
987985
for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]
988986
zero_emit = nest.pack_sequence_as(structure=emit_structure,
989987
flat_sequence=flat_zero_emit)

tensorflow/python/ops/rnn_cell_impl.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,55 +26,82 @@
2626

2727
from tensorflow.python.framework import ops
2828
from tensorflow.python.framework import tensor_shape
29+
from tensorflow.python.framework import tensor_util
2930
from tensorflow.python.layers import base as base_layer
3031
from tensorflow.python.ops import array_ops
3132
from tensorflow.python.ops import variable_scope as vs
3233
from tensorflow.python.ops import variables as tf_variables
3334
from tensorflow.python.util import nest
3435

3536

36-
def _state_size_with_prefix(state_size, prefix=None):
37-
"""Helper function that enables int or TensorShape shape specification.
37+
def _concat(prefix, suffix, static=False):
38+
"""Concat that enables int, Tensor, or TensorShape values.
3839
39-
This function takes a size specification, which can be an integer or a
40-
TensorShape, and converts it into a list of integers. One may specify any
41-
additional dimensions that precede the final state size specification.
40+
This function takes a size specification, which can be an integer, a
41+
TensorShape, or a Tensor, and converts it into a concatenated Tensor
42+
(if static = False) or a list of integers (if static = True).
4243
4344
Args:
44-
state_size: TensorShape or int that specifies the size of a tensor.
45-
prefix: optional additional list of dimensions to prepend.
45+
prefix: The prefix; usually the batch size (and/or time step size).
46+
(TensorShape, int, or Tensor.)
47+
suffix: TensorShape, int, or Tensor.
48+
static: If `True`, return a python list with possibly unknown dimensions.
49+
Otherwise return a `Tensor`.
4650
4751
Returns:
48-
result_state_size: list of dimensions the resulting tensor size.
52+
shape: the concatenation of prefix and suffix.
53+
54+
Raises:
55+
ValueError: if `suffix` is not a scalar or vector (or TensorShape).
56+
ValueError: if prefix or suffix was `None` and asked for dynamic
57+
Tensors out.
4958
"""
50-
result_state_size = tensor_shape.as_shape(state_size).as_list()
51-
if prefix is not None:
52-
if not isinstance(prefix, list):
53-
raise TypeError("prefix of _state_size_with_prefix should be a list.")
54-
result_state_size = prefix + result_state_size
55-
return result_state_size
59+
if isinstance(prefix, ops.Tensor):
60+
p = prefix
61+
p_static = tensor_util.constant_value(prefix)
62+
if p.shape.ndims == 0:
63+
p = array_ops.expand_dims(p, 0)
64+
elif p.shape.ndims != 1:
65+
raise ValueError("prefix tensor must be either a scalar or vector, "
66+
"but saw tensor: %s" % p)
67+
else:
68+
p = tensor_shape.as_shape(prefix)
69+
p = p.as_list() if p.ndims is not None else None
70+
p_static = p
71+
if isinstance(suffix, ops.Tensor):
72+
s = suffix
73+
s_static = tensor_util.constant_value(suffix)
74+
if s.shape.ndims == 0:
75+
s = array_ops.expand_dims(s, 0)
76+
elif s.shape.ndims != 1:
77+
raise ValueError("suffix tensor must be either a scalar or vector, "
78+
"but saw tensor: %s" % s)
79+
else:
80+
s = tensor_shape.as_shape(suffix)
81+
s = s.as_list() if s.ndims is not None else None
82+
s_static = s
83+
84+
if static:
85+
shape = tensor_shape.as_shape(p_static).concatenate(s_static)
86+
shape = shape.as_list() if shape.ndims is not None else None
87+
else:
88+
if p is None or s is None:
89+
raise ValueError("Provided a prefix or suffix of None: %s and %s"
90+
% (prefix, suffix))
91+
shape = array_ops.concat((p, s), 0)
92+
return shape
5693

5794

5895
def _zero_state_tensors(state_size, batch_size, dtype):
5996
"""Create tensors of zeros based on state_size, batch_size, and dtype."""
60-
if nest.is_sequence(state_size):
61-
state_size_flat = nest.flatten(state_size)
62-
zeros_flat = [
63-
array_ops.zeros(
64-
array_ops.stack(_state_size_with_prefix(
65-
s, prefix=[batch_size])),
66-
dtype=dtype) for s in state_size_flat
67-
]
68-
for s, z in zip(state_size_flat, zeros_flat):
69-
z.set_shape(_state_size_with_prefix(s, prefix=[None]))
70-
zeros = nest.pack_sequence_as(structure=state_size,
71-
flat_sequence=zeros_flat)
72-
else:
73-
zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
74-
zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype)
75-
zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
76-
77-
return zeros
97+
def get_state_shape(s):
98+
"""Combine s with batch_size to get a proper tensor shape."""
99+
c = _concat(batch_size, s)
100+
c_static = _concat(batch_size, s, static=True)
101+
size = array_ops.zeros(c, dtype=dtype)
102+
size.set_shape(c_static)
103+
return size
104+
return nest.map_structure(get_state_shape, state_size)
78105

79106

80107
class _RNNCell(base_layer.Layer):

0 commit comments

Comments
 (0)