26 | 26 |
27 | 27 | from tensorflow.python.framework import ops |
28 | 28 | from tensorflow.python.framework import tensor_shape |
| 29 | +from tensorflow.python.framework import tensor_util |
29 | 30 | from tensorflow.python.layers import base as base_layer |
30 | 31 | from tensorflow.python.ops import array_ops |
31 | 32 | from tensorflow.python.ops import variable_scope as vs |
32 | 33 | from tensorflow.python.ops import variables as tf_variables |
33 | 34 | from tensorflow.python.util import nest |
34 | 35 |
35 | 36 |
36 | | -def _state_size_with_prefix(state_size, prefix=None): |
37 | | - """Helper function that enables int or TensorShape shape specification. |
| 37 | +def _concat(prefix, suffix, static=False): |
| 38 | + """Concat that enables int, Tensor, or TensorShape values. |
38 | 39 |
39 | | - This function takes a size specification, which can be an integer or a |
40 | | - TensorShape, and converts it into a list of integers. One may specify any |
41 | | - additional dimensions that precede the final state size specification. |
| 40 | +  This function takes a prefix and a suffix, each of which can be an |
| 41 | +  integer, a TensorShape, or a Tensor, and concatenates them into either a |
| 42 | +  Tensor (if static = False) or a list of integers (if static = True). |
42 | 43 |
43 | 44 | Args: |
44 | | - state_size: TensorShape or int that specifies the size of a tensor. |
45 | | - prefix: optional additional list of dimensions to prepend. |
| 45 | + prefix: The prefix; usually the batch size (and/or time step size). |
| 46 | + (TensorShape, int, or Tensor.) |
| 47 | + suffix: TensorShape, int, or Tensor. |
| 48 | + static: If `True`, return a python list with possibly unknown dimensions. |
| 49 | + Otherwise return a `Tensor`. |
46 | 50 |
47 | 51 | Returns: |
48 | | - result_state_size: list of dimensions the resulting tensor size. |
| 52 | + shape: the concatenation of prefix and suffix. |
| 53 | +
| 54 | + Raises: |
| 55 | +    ValueError: if a Tensor `prefix` or `suffix` is not a scalar or vector. |
| 56 | +    ValueError: if `prefix` or `suffix` is `None` and a dynamic `Tensor` |
| 57 | +      output is requested. |
49 | 58 | """ |
50 | | - result_state_size = tensor_shape.as_shape(state_size).as_list() |
51 | | - if prefix is not None: |
52 | | - if not isinstance(prefix, list): |
53 | | - raise TypeError("prefix of _state_size_with_prefix should be a list.") |
54 | | - result_state_size = prefix + result_state_size |
55 | | - return result_state_size |
| 59 | + if isinstance(prefix, ops.Tensor): |
| 60 | + p = prefix |
| 61 | + p_static = tensor_util.constant_value(prefix) |
| 62 | + if p.shape.ndims == 0: |
| 63 | + p = array_ops.expand_dims(p, 0) |
| 64 | + elif p.shape.ndims != 1: |
| 65 | + raise ValueError("prefix tensor must be either a scalar or vector, " |
| 66 | + "but saw tensor: %s" % p) |
| 67 | + else: |
| 68 | + p = tensor_shape.as_shape(prefix) |
| 69 | + p = p.as_list() if p.ndims is not None else None |
| 70 | + p_static = p |
| 71 | + if isinstance(suffix, ops.Tensor): |
| 72 | + s = suffix |
| 73 | + s_static = tensor_util.constant_value(suffix) |
| 74 | + if s.shape.ndims == 0: |
| 75 | + s = array_ops.expand_dims(s, 0) |
| 76 | + elif s.shape.ndims != 1: |
| 77 | + raise ValueError("suffix tensor must be either a scalar or vector, " |
| 78 | + "but saw tensor: %s" % s) |
| 79 | + else: |
| 80 | + s = tensor_shape.as_shape(suffix) |
| 81 | + s = s.as_list() if s.ndims is not None else None |
| 82 | + s_static = s |
| 83 | + |
| 84 | + if static: |
| 85 | + shape = tensor_shape.as_shape(p_static).concatenate(s_static) |
| 86 | + shape = shape.as_list() if shape.ndims is not None else None |
| 87 | + else: |
| 88 | + if p is None or s is None: |
| 89 | + raise ValueError("Provided a prefix or suffix of None: %s and %s" |
| 90 | + % (prefix, suffix)) |
| 91 | + shape = array_ops.concat((p, s), 0) |
| 92 | + return shape |
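A minimal usage sketch of the two `_concat` modes (not part of this commit; it assumes the TF 1.x public API in graph mode, and the names `batch` and `state_size` below are hypothetical):

    import tensorflow as tf  # assumed TF 1.x

    batch = tf.placeholder(tf.int32, shape=[])  # statically unknown batch size
    state_size = 128

    # Dynamic mode: a 1-D int32 Tensor whose runtime value is [<batch>, 128].
    dynamic_shape = _concat(batch, state_size)

    # Static mode with plain ints: the python list [32, 128].
    static_shape = _concat(32, state_size, static=True)

    # Static mode with the placeholder: its constant value is unknown, so the
    # concatenated static shape degrades to None (unknown).
    unknown_static = _concat(batch, state_size, static=True)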
56 | 93 |
57 | 94 |
58 | 95 | def _zero_state_tensors(state_size, batch_size, dtype): |
59 | 96 | """Create tensors of zeros based on state_size, batch_size, and dtype.""" |
60 | | - if nest.is_sequence(state_size): |
61 | | - state_size_flat = nest.flatten(state_size) |
62 | | - zeros_flat = [ |
63 | | - array_ops.zeros( |
64 | | - array_ops.stack(_state_size_with_prefix( |
65 | | - s, prefix=[batch_size])), |
66 | | - dtype=dtype) for s in state_size_flat |
67 | | - ] |
68 | | - for s, z in zip(state_size_flat, zeros_flat): |
69 | | - z.set_shape(_state_size_with_prefix(s, prefix=[None])) |
70 | | - zeros = nest.pack_sequence_as(structure=state_size, |
71 | | - flat_sequence=zeros_flat) |
72 | | - else: |
73 | | - zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size]) |
74 | | - zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype) |
75 | | - zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None])) |
76 | | - |
77 | | - return zeros |
| 97 | + def get_state_shape(s): |
| 98 | + """Combine s with batch_size to get a proper tensor shape.""" |
| 99 | + c = _concat(batch_size, s) |
| 100 | + c_static = _concat(batch_size, s, static=True) |
| 101 | + size = array_ops.zeros(c, dtype=dtype) |
| 102 | + size.set_shape(c_static) |
| 103 | + return size |
| 104 | + return nest.map_structure(get_state_shape, state_size) |
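A similarly hedged sketch (not part of this commit) of how `_zero_state_tensors` maps over a nested `state_size`, again assuming the TF 1.x public API; the `(c, h)` pair below is a hypothetical LSTM-style state:

    import tensorflow as tf  # assumed TF 1.x

    state_size = (64, 64)                       # any nest of ints/TensorShapes
    batch = tf.placeholder(tf.int32, shape=[])  # dynamic batch size

    zeros = _zero_state_tensors(state_size, batch, tf.float32)
    # `zeros` is a (c, h) tuple of float32 zero Tensors, each with runtime
    # shape [batch, 64]; when batch is a plain int instead, set_shape also
    # records the fully static shape.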
78 | 105 |
79 | 106 |
80 | 107 | class _RNNCell(base_layer.Layer): |