@@ -357,7 +357,10 @@ def _rnn_func(self, inputs, state, num_units):
         return rnn_cell(inputs, state)

     def _conv_rnn_func(self, inputs, state, filters):
-        inputs_shape = inputs.get_shape().as_list()
+        if isinstance(inputs, (list, tuple)):
+            inputs_shape = inputs[0].shape.as_list()
+        else:
+            inputs_shape = inputs.shape.as_list()
         input_shape = inputs_shape[1:]
         if self.hparams.conv_rnn_norm_layer == 'none':
             normalizer_fn = None
@@ -446,18 +449,26 @@ def concat(tensors, axis):
                     h = layers[-1][-1]
                     kernel_size = (3, 3)
                 if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0):
-                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
-                h = downsample_layer(h, out_channels, kernel_size=kernel_size, strides=(2, 2))
+                    if self.hparams.use_tile_concat:
+                        h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                    else:
+                        h = [h, state_action_z]
+                h = _maybe_tile_concat_layer(downsample_layer)(
+                    h, out_channels, kernel_size=kernel_size, strides=(2, 2))
                 h = norm_layer(h)
                 h = activation_layer(h)
             if use_conv_rnn:
                 with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, i)):
                     if self.hparams.where_add == 'all':
-                        conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        if self.hparams.use_tile_concat:
+                            conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        else:
+                            conv_rnn_h = [h, state_action_z]
                     else:
                         conv_rnn_h = h
                     if self.hparams.ablation_rnn:
-                        conv_rnn_h = conv2d(conv_rnn_h, out_channels, kernel_size=(5, 5))
+                        conv_rnn_h = _maybe_tile_concat_layer(conv2d)(
+                            conv_rnn_h, out_channels, kernel_size=(5, 5))
                         conv_rnn_h = norm_layer(conv_rnn_h)
                         conv_rnn_h = activation_layer(conv_rnn_h)
                     else:
@@ -474,18 +485,25 @@ def concat(tensors, axis):
                 else:
                     h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1)
                 if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0):
-                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
-                h = upsample_layer(h, out_channels, kernel_size=(3, 3), strides=(2, 2))
+                    if self.hparams.use_tile_concat:
+                        h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                    else:
+                        h = [h, state_action_z]
+                h = _maybe_tile_concat_layer(upsample_layer)(
+                    h, out_channels, kernel_size=(3, 3), strides=(2, 2))
                 h = norm_layer(h)
                 h = activation_layer(h)
             if use_conv_rnn:
                 with tf.variable_scope('%s_h%d' % ('conv' if self.hparams.ablation_rnn else self.hparams.conv_rnn, len(layers))):
                     if self.hparams.where_add == 'all':
-                        conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        if self.hparams.use_tile_concat:
+                            conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
+                        else:
+                            conv_rnn_h = [h, state_action_z]
                     else:
                         conv_rnn_h = h
                     if self.hparams.ablation_rnn:
-                        conv_rnn_h = conv2d(conv_rnn_h, out_channels, kernel_size=(5, 5))
+                        conv_rnn_h = _maybe_tile_concat_layer(conv2d)(
+                            conv_rnn_h, out_channels, kernel_size=(5, 5))
                         conv_rnn_h = norm_layer(conv_rnn_h)
                         conv_rnn_h = activation_layer(conv_rnn_h)
                     else:
@@ -770,6 +788,7 @@ def get_default_hparams_dict(self):
             kernel_size=(5, 5),
             dilation_rate=(1, 1),
             where_add='all',
+            use_tile_concat=True,
             learn_initial_state=False,
             rnn='lstm',
             conv_rnn='lstm',
@@ -950,3 +969,16 @@ def center_slice(k):
     kernel[center_slice(kh), center_slice(kw)] = 1.0
     kernel /= np.sum(kernel)
     return kernel
+
+
+def _maybe_tile_concat_layer(conv2d_layer):
+    def layer(inputs, out_channels, *args, **kwargs):
+        if isinstance(inputs, (list, tuple)):
+            inputs_spatial, inputs_non_spatial = inputs
+            outputs = (conv2d_layer(inputs_spatial, out_channels, *args, **kwargs) +
+                       dense(inputs_non_spatial, out_channels, use_bias=False)[:, None, None, :])
+        else:
+            outputs = conv2d_layer(inputs, out_channels, *args, **kwargs)
+        return outputs
+
+    return layer
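
With the new `use_tile_concat=False` path, the network passes `[h, state_action_z]` through the layers as a pair, and `_maybe_tile_concat_layer` conditions the convolution additively, as conv(h) + dense(z) broadcast over the spatial dimensions, instead of tiling z and concatenating it channel-wise. Below is a minimal usage sketch of the wrapper; the `conv2d` stand-in, the `tf.layers` calls, and the example shapes are illustrative assumptions, not code from this commit (inside the real wrapper, `dense` resolves to whatever dense layer the module imports):

import tensorflow as tf

def conv2d(inputs, out_channels, kernel_size=(3, 3), strides=(1, 1)):
    # Hypothetical stand-in for the module's own conv layer.
    return tf.layers.conv2d(inputs, out_channels, kernel_size, strides, padding='same')

layer = _maybe_tile_concat_layer(conv2d)
h = tf.zeros([4, 16, 16, 32])  # spatial features (NHWC)
z = tf.zeros([4, 10])          # non-spatial state/action/latent vector
out_spatial = layer(h, 64)     # tensor input: plain conv path
out_pair = layer([h, z], 64)   # pair input: conv(h) + dense(z)[:, None, None, :]

Since convolution is linear, convolving a channel-wise concatenation of h with a spatially tiled z decomposes into conv(h) plus a per-position linear map of z, so the additive path computes essentially the same function as tile_concat followed by a convolution (up to padding edge effects) without materializing the tiled intermediate tensor.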