@@ -1,11 +1,10 @@
-
 import tensorflow as tf
-
-
+from tensorflow.contrib.layers.python.layers.layers import layer_norm
 
 class ConvRNNCell(object):
   """Abstract object representing an Convolutional RNN cell.
   """
+
   def __call__(self, inputs, state, scope=None):
     """Run this RNN cell on inputs, starting from the given state.
     """
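The new import pulls layer_norm directly from its defining module inside contrib; in TF 1.x this should be the same function that is re-exported as tf.contrib.layers.layer_norm, which the calls removed further down used. A small sanity-check sketch, assuming a TF 1.x installation with contrib available:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers.layers import layer_norm

# Both names should refer to the same contrib implementation in TF 1.x.
print(layer_norm is tf.contrib.layers.layer_norm)  # expected: True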
@@ -37,10 +36,10 @@ def zero_state(self, batch_size):
     if self._state_is_tuple:
       if self._data_format == 'NHWC':
         zeros = (tf.zeros([batch_size, shape[0], shape[1], num_features]),
-            tf.zeros([batch_size, shape[0], shape[1], num_features]))
+                 tf.zeros([batch_size, shape[0], shape[1], num_features]))
       else:
         zeros = (tf.zeros([batch_size, num_features, shape[0], shape[1]]),
-            tf.zeros([batch_size, num_features, shape[0], shape[1]]))
+                 tf.zeros([batch_size, num_features, shape[0], shape[1]]))
     else:
       if self._data_format == 'NHWC':
         zeros = tf.zeros([batch_size, shape[0], shape[1], num_features * 2])
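For context on the state shapes built by zero_state above, here is a minimal sketch using illustrative values (batch size 2, an 8x8 spatial grid, 16 features; none of these come from the diff). With state_is_tuple the state is a (c, h) pair; otherwise c and h are packed along the channel axis.

import tensorflow as tf

batch_size, height, width, num_features = 2, 8, 8, 16  # illustrative values only

# state_is_tuple=True, NHWC: the state is a (c, h) pair of equally shaped tensors
c = tf.zeros([batch_size, height, width, num_features])  # [2, 8, 8, 16]
h = tf.zeros([batch_size, height, width, num_features])  # [2, 8, 8, 16]

# state_is_tuple=False, NHWC: c and h are concatenated on the channel axis instead
packed = tf.zeros([batch_size, height, width, num_features * 2])  # [2, 8, 8, 32]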
@@ -110,6 +109,7 @@ def __call__(self, inputs, state, scope=None, reuse=None, clip_cell=None):
       new_state = tf.concat(axis=channel_axis, values=[new_c, new_h])
       return new_h, new_state
 
+
 class LayerNormConvLSTMCell(ConvRNNCell):
   """Basic Conv LSTM recurrent network cell. The
   """
@@ -155,23 +155,36 @@ def __call__(self, inputs, state, scope=None, reuse=None):
       else:
         c, h = tf.split(axis=channel_axis, num_or_size_splits=2, value=state)
       concat_i = _conv_linear([inputs], self.filter_size, self.num_features * 4, False,
-                  data_format=self._data_format, scope='Input')
+                              data_format=self._data_format, scope='Input')
       concat_h = _conv_linear([inputs], self.filter_size, self.num_features * 4, False,
-                  data_format=self._data_format, scope='H')
+                              data_format=self._data_format, scope='H')
       bias_term = tf.get_variable("Bias", [self.num_features * 4], dtype=inputs.dtype,
                                   initializer=tf.constant_initializer(0.0, dtype=inputs.dtype))
-      concat = tf.nn.bias_add(tf.contrib.layers.layer_norm(concat_i) + tf.contrib.layers.layer_norm(concat_h),
+      if self._data_format == 'NCHW':
+        concat_i = tf.transpose(concat_i, (0, 2, 3, 1))
+        concat_h = tf.transpose(concat_h, (0, 2, 3, 1))
+      normed_i = layer_norm(concat_i)
+      normed_h = layer_norm(concat_h)
+      if self._data_format == 'NCHW':
+        normed_i = tf.transpose(normed_i, (0, 3, 1, 2))
+        normed_h = tf.transpose(normed_h, (0, 3, 1, 2))
+
+      concat = tf.nn.bias_add(normed_i + normed_h,
                               bias_term, data_format=self._data_format)
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = tf.split(axis=channel_axis, num_or_size_splits=4, value=concat)
       new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * self._activation(j))
-      new_c_norm = tf.contrib.layers.layer_norm(new_c)
+      if self._data_format == 'NCHW':
+        new_c = tf.transpose(new_c, (0, 2, 3, 1))
+      new_c_norm = layer_norm(new_c)
+      if self._data_format == 'NCHW':
+        new_c_norm = tf.transpose(new_c_norm, (0, 3, 1, 2))
       new_h = self._activation(new_c_norm) * tf.nn.sigmoid(o)
 
       if self._state_is_tuple:
-        new_state = (new_c, new_h)
+        new_state = (new_c_norm, new_h)
       else:
-        new_state = tf.concat(axis=channel_axis, values=[new_c, new_h])
+        new_state = tf.concat(axis=channel_axis, values=[new_c_norm, new_h])
       return new_h, new_state
 
 
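The transpose pairs added in this hunk exist because the contrib layer_norm normalises over all axes after the batch dimension and, by default, creates its learned scale/offset for the last axis; transposing NCHW tensors to channels-last before normalising keeps those parameters per channel, which is presumably the motivation here. A minimal sketch of the same pattern, assuming TF 1.x; the helper name layer_norm_channels_last is illustrative and not part of the diff:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers.layers import layer_norm

def layer_norm_channels_last(x, data_format='NHWC'):
  """Apply contrib layer_norm with the channel axis last, then restore the layout."""
  if data_format == 'NCHW':
    x = tf.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
  x = layer_norm(x)                    # gamma/beta get shape [channels]
  if data_format == 'NCHW':
    x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
  return x

Note the commit also carries the normalised cell (new_c_norm) into the returned state, so the next step resumes from the same tensor that produced new_h.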
@@ -261,13 +274,11 @@ def __call__(self, inputs, h, scope=None, reuse=None):
       channel_axis = 3 if self._data_format == 'NHWC' else 1
 
       concat = tf.nn.sigmoid(_conv_linear([inputs, h], self.filter_size, self.num_features * 2, True,
-                  data_format=self._data_format, scope='gates'))
+                                          data_format=self._data_format, scope='gates'))
 
       z, r = tf.split(axis=channel_axis, num_or_size_splits=2, value=concat)
       i = tf.nn.tanh(_conv_linear([inputs, tf.multiply(r, h)], self.filter_size, self.num_features, True,
                                   data_format=self._data_format, scope='input'))
-      new_h = tf.add(tf.multiply(z, h), tf.multiply(1-z, i))
+      new_h = tf.add(tf.multiply(z, h), tf.multiply(1 - z, i))
 
       return new_h
-
-
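The reformatted line in this last hunk is the standard convolutional GRU blend: the update gate z interpolates element-wise between the previous hidden state h and the tanh candidate i. A minimal numeric sketch of that blend, with illustrative values not taken from the diff:

import tensorflow as tf

z = tf.constant([[0.9, 0.1]])          # update gate, already passed through a sigmoid
h = tf.constant([[1.0, 1.0]])          # previous hidden state
candidate = tf.constant([[0.0, 0.0]])  # tanh candidate (called `i` in the cell)

# new_h = z * h + (1 - z) * candidate: gate values near 1 keep the old state,
# values near 0 adopt the candidate, so new_h evaluates to [[0.9, 0.1]] here.
new_h = tf.add(tf.multiply(z, h), tf.multiply(1 - z, candidate))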