
Commit 836b56a

Add alternative implementation of tile_concat that uses far fewer weights and much less memory.

1 parent f6fe198

File tree (3 files changed: +75 −12 lines)

    video_prediction/models/savp_model.py
    video_prediction/ops.py
    video_prediction/rnn_ops.py
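Why this saves weights and memory: tile_concat broadcasts a non-spatial vector z across the spatial grid and concatenates it to the feature channels, so a following k x k convolution from C + Z channels to F filters needs k*k*(C+Z)*F kernel weights and materializes the tiled tensor. Convolving channels that are constant over space reduces (up to zero-padding effects at the borders) to a per-position linear map of z, so nearly the same output can be computed as conv(h) + dense(z), which needs only k*k*C*F + Z*F weights and never tiles z. A minimal TF1-style sketch of the equivalence (shapes and layer calls are illustrative, not code from this commit):

    import tensorflow as tf

    h = tf.random_normal([4, 8, 8, 32])   # spatial features (C = 32)
    z = tf.random_normal([4, 10])         # non-spatial vector (Z = 10)

    # Old path: tile z over space, concat on channels, then convolve.
    z_tiled = tf.tile(z[:, None, None, :], [1, 8, 8, 1])
    y_old = tf.layers.conv2d(tf.concat([h, z_tiled], axis=-1), 64, 3,
                             padding='same')  # 3*3*(32+10)*64 kernel weights

    # New path: convolve h alone, add a broadcast linear projection of z.
    y_new = (tf.layers.conv2d(h, 64, 3, padding='same') +              # 3*3*32*64
             tf.layers.dense(z, 64, use_bias=False)[:, None, None, :])  # 10*64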

video_prediction/models/savp_model.py

Lines changed: 41 additions & 9 deletions
@@ -357,7 +357,10 @@ def _rnn_func(self, inputs, state, num_units):
         return rnn_cell(inputs, state)

     def _conv_rnn_func(self, inputs, state, filters):
-        inputs_shape = inputs.get_shape().as_list()
+        if isinstance(inputs, (list, tuple)):
+            inputs_shape = inputs[0].shape.as_list()
+        else:
+            inputs_shape = inputs.shape.as_list()
         input_shape = inputs_shape[1:]
         if self.hparams.conv_rnn_norm_layer == 'none':
             normalizer_fn = None
@@ -446,18 +449,26 @@ def concat(tensors, axis):
                     h = layers[-1][-1]
                     kernel_size = (3, 3)
                 if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0):
-                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
-                h = downsample_layer(h, out_channels, kernel_size=kernel_size, strides=(2, 2))
+                    if self.hparams.use_tile_concat:
+                        h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                    else:
+                        h = [h, state_action_z]
+                h = _maybe_tile_concat_layer(downsample_layer)(
+                    h, out_channels, kernel_size=kernel_size, strides=(2, 2))
                 h = norm_layer(h)
                 h = activation_layer(h)
             if use_conv_rnn:
                 with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, i)):
                     if self.hparams.where_add == 'all':
-                        conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        if self.hparams.use_tile_concat:
+                            conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        else:
+                            conv_rnn_h = [h, state_action_z]
                     else:
                         conv_rnn_h = h
                     if self.hparams.ablation_rnn:
-                        conv_rnn_h = conv2d(conv_rnn_h, out_channels, kernel_size=(5, 5))
+                        conv_rnn_h = _maybe_tile_concat_layer(conv2d)(
+                            conv_rnn_h, out_channels, kernel_size=(5, 5))
                         conv_rnn_h = norm_layer(conv_rnn_h)
                         conv_rnn_h = activation_layer(conv_rnn_h)
                     else:
@@ -474,18 +485,25 @@ def concat(tensors, axis):
                 else:
                     h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1)
                 if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0):
-                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
-                h = upsample_layer(h, out_channels, kernel_size=(3, 3), strides=(2, 2))
+                    if self.hparams.use_tile_concat:
+                        h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                    else:
+                        h = [h, state_action_z]
+                h = _maybe_tile_concat_layer(upsample_layer)(
+                    h, out_channels, kernel_size=(3, 3), strides=(2, 2))
                 h = norm_layer(h)
                 h = activation_layer(h)
             if use_conv_rnn:
                 with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, len(layers))):
                     if self.hparams.where_add == 'all':
-                        conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        if self.hparams.use_tile_concat:
+                            conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        else:
+                            conv_rnn_h = [h, state_action_z]
                     else:
                         conv_rnn_h = h
                     if self.hparams.ablation_rnn:
-                        conv_rnn_h = conv2d(conv_rnn_h, out_channels, kernel_size=(5, 5))
+                        conv_rnn_h = _maybe_tile_concat_layer(conv2d)(conv_rnn_h, out_channels, kernel_size=(5, 5))
                         conv_rnn_h = norm_layer(conv_rnn_h)
                         conv_rnn_h = activation_layer(conv_rnn_h)
                     else:
@@ -770,6 +788,7 @@ def get_default_hparams_dict(self):
             kernel_size=(5, 5),
             dilation_rate=(1, 1),
             where_add='all',
+            use_tile_concat=True,
             learn_initial_state=False,
             rnn='lstm',
             conv_rnn='lstm',
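The alternative implementation is opt-in: the default use_tile_concat=True keeps the original tile_concat path, and the lighter conv + dense path is selected by overriding the hparam to false. A hypothetical invocation, assuming this repo's usual comma-separated --model_hparams overrides (exact flag handling depends on the training script):

    python scripts/train.py --model_hparams use_tile_concat=false ...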
@@ -950,3 +969,16 @@ def center_slice(k):
     kernel[center_slice(kh), center_slice(kw)] = 1.0
     kernel /= np.sum(kernel)
     return kernel
+
+
+def _maybe_tile_concat_layer(conv2d_layer):
+    def layer(inputs, out_channels, *args, **kwargs):
+        if isinstance(inputs, (list, tuple)):
+            inputs_spatial, inputs_non_spatial = inputs
+            outputs = (conv2d_layer(inputs_spatial, out_channels, *args, **kwargs) +
+                       dense(inputs_non_spatial, out_channels, use_bias=False)[:, None, None, :])
+        else:
+            outputs = conv2d_layer(inputs, out_channels, *args, **kwargs)
+        return outputs
+
+    return layer
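Usage sketch: the wrapper leaves plain tensors on the original conv path and routes [spatial, non-spatial] pairs through the conv + dense sum. The shapes below are illustrative; conv2d is the op used in the diff above:

    h = tf.random_normal([8, 16, 16, 32])   # spatial features
    z = tf.random_normal([8, 10])           # non-spatial state/action/latent

    layer = _maybe_tile_concat_layer(conv2d)
    y1 = layer(h, 64, kernel_size=(3, 3))        # plain conv2d path
    y2 = layer([h, z], 64, kernel_size=(3, 3))   # conv2d(h) + dense(z) path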

video_prediction/ops.py

Lines changed: 5 additions & 3 deletions
@@ -2,15 +2,17 @@
 import tensorflow as tf


-def dense(inputs, units, use_spectral_norm=False):
+def dense(inputs, units, use_spectral_norm=False, use_bias=True):
     with tf.variable_scope('dense'):
         input_shape = inputs.get_shape().as_list()
         kernel_shape = [input_shape[1], units]
         kernel = tf.get_variable('kernel', kernel_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))
         if use_spectral_norm:
             kernel = spectral_normed_weight(kernel)
-        bias = tf.get_variable('bias', [units], dtype=tf.float32, initializer=tf.zeros_initializer())
-        outputs = tf.matmul(inputs, kernel) + bias
+        outputs = tf.matmul(inputs, kernel)
+        if use_bias:
+            bias = tf.get_variable('bias', [units], dtype=tf.float32, initializer=tf.zeros_initializer())
+            outputs = tf.nn.bias_add(outputs, bias)
     return outputs
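A note on the new use_bias flag: _maybe_tile_concat_layer creates the dense branch with use_bias=False because the conv branch already contributes a bias; keeping both would just sum two learned constants. The intended combination (names from the diffs above, shapes illustrative):

    # conv2d contributes the bias; the dense projection stays bias-free.
    out = conv2d(h, units, kernel_size=(3, 3)) + dense(z, units, use_bias=False)[:, None, None, :]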
video_prediction/rnn_ops.py

Lines changed: 29 additions & 0 deletions
@@ -125,11 +125,25 @@ def _conv2d(self, inputs):
         outputs = nn_ops.bias_add(outputs, bias)
         return outputs

+    def _dense(self, inputs):
+        num_units = 4 * self._filters
+        input_shape = inputs.shape.as_list()
+        kernel_shape = [input_shape[-1], num_units]
+        kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32,
+                                 initializer=init_ops.truncated_normal_initializer(stddev=0.02))
+        outputs = tf.matmul(inputs, kernel)
+        return outputs
+
     def call(self, inputs, state):
         """2D Convolutional LSTM cell with (optional) normalization and recurrent dropout."""
         c, h = state
+        tile_concat = isinstance(inputs, (list, tuple))
+        if tile_concat:
+            inputs, inputs_non_spatial = inputs
         args = array_ops.concat([inputs, h], -1)
         concat = self._conv2d(args)
+        if tile_concat:
+            concat = concat + self._dense(inputs_non_spatial)[:, None, None, :]

         if self._normalizer_fn and not self._separate_norms:
             concat = self._norm(concat, "input_transform_forget_output")
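The projection width of 4 * self._filters matches the four LSTM gate pre-activations that concat is split into further down, so the broadcast add shifts every gate. A sketch of the downstream split it has to line up with (standard LSTM gate names assumed, not quoted from this file):

    # concat: [batch, H, W, 4 * filters] -> input, cell, forget, output gates
    i, j, f, o = array_ops.split(concat, 4, axis=-1)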
@@ -209,13 +223,26 @@ def _conv2d(self, inputs, output_filters, bias_initializer):
         outputs = nn_ops.bias_add(outputs, bias)
         return outputs

+    def _dense(self, inputs, num_units):
+        input_shape = inputs.shape.as_list()
+        kernel_shape = [input_shape[-1], num_units]
+        kernel = vs.get_variable("weights", kernel_shape, dtype=dtypes.float32,
+                                 initializer=init_ops.truncated_normal_initializer(stddev=0.02))
+        outputs = tf.matmul(inputs, kernel)
+        return outputs
+
     def call(self, inputs, state):
         bias_ones = self._bias_initializer
         if self._bias_initializer is None:
             bias_ones = init_ops.ones_initializer()
+        tile_concat = isinstance(inputs, (list, tuple))
+        if tile_concat:
+            inputs, inputs_non_spatial = inputs
         with vs.variable_scope('gates'):
             inputs = array_ops.concat([inputs, state], axis=-1)
             concat = self._conv2d(inputs, 2 * self._filters, bias_ones)
+            if tile_concat:
+                concat = concat + self._dense(inputs_non_spatial, concat.shape[-1].value)[:, None, None, :]
             if self._normalizer_fn and not self._separate_norms:
                 concat = self._norm(concat, "reset_update", bias_ones)
             r, u = array_ops.split(concat, 2, axis=-1)
@@ -230,6 +257,8 @@ def call(self, inputs, state):
         with vs.variable_scope('candidate'):
             inputs = array_ops.concat([inputs, r * state], axis=-1)
             candidate = self._conv2d(inputs, self._filters, bias_zeros)
+            if tile_concat:
+                candidate = candidate + self._dense(inputs_non_spatial, candidate.shape[-1].value)[:, None, None, :]
             if self._normalizer_fn:
                 candidate = self._norm(candidate, "state", bias_zeros)

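With these changes the conv RNN cells accept either a plain tensor or a [spatial, non-spatial] pair. A hypothetical call, assuming the ConvLSTM cell in this file is constructed as BasicConv2DLSTMCell(input_shape, filters, kernel_size) (class name and constructor details assumed, not confirmed by the diff):

    h = tf.random_normal([8, 16, 16, 32])   # spatial input
    z = tf.random_normal([8, 10])           # non-spatial state/action/latent

    cell = BasicConv2DLSTMCell([16, 16, 32], 32, kernel_size=(5, 5))
    state = cell.zero_state(batch_size=8, dtype=tf.float32)

    # The cell unpacks the pair, convolves [h, hidden] as before, and adds
    # dense(z), broadcast over space, to the gate pre-activations.
    output, new_state = cell([h, z], state)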