@@ -454,6 +454,44 @@ def convolution_block_attention(x, channels, ratio=16, use_bias=True, sn=False,
454454 return x
455455
456456
457+ def global_context_block (x , channels , use_bias = True , sn = False , scope = 'gc_block' ):
458+ with tf .variable_scope (scope ):
459+ with tf .variable_scope ('context_modeling' ):
460+ bs , h , w , c = x .get_shape ().as_list ()
461+ input_x = x
462+ input_x = hw_flatten (input_x ) # [N, H*W, C]
463+ input_x = tf .transpose (input_x , perm = [0 , 2 , 1 ])
464+ input_x = tf .expand_dims (input_x , axis = 1 )
465+
466+ context_mask = conv (x , channels = 1 , kernel = 1 , stride = 1 , use_bias = use_bias , sn = sn , scope = 'conv' )
467+ context_mask = hw_flatten (context_mask )
468+ context_mask = tf .nn .softmax (context_mask , axis = 1 ) # [N, H*W, 1]
469+ context_mask = tf .transpose (context_mask , perm = [0 , 2 , 1 ])
470+ context_mask = tf .expand_dims (context_mask , axis = - 1 )
471+
472+ context = tf .matmul (input_x , context_mask )
473+ context = tf .reshape (context , shape = [bs , 1 , 1 , c ])
474+
475+ with tf .variable_scope ('transform_0' ):
476+ context_transform = conv (context , channels , kernel = 1 , stride = 1 , use_bias = use_bias , sn = sn , scope = 'conv_0' )
477+ context_transform = layer_norm (context_transform )
478+ context_transform = relu (context_transform )
479+ context_transform = conv (context_transform , channels = c , kernel = 1 , stride = 1 , use_bias = use_bias , sn = sn , scope = 'conv_1' )
480+ context_transform = sigmoid (context_transform )
481+
482+ x = x * context_transform
483+
484+ with tf .variable_scope ('transform_1' ):
485+ context_transform = conv (context , channels , kernel = 1 , stride = 1 , use_bias = use_bias , sn = sn , scope = 'conv_0' )
486+ context_transform = layer_norm (context_transform )
487+ context_transform = relu (context_transform )
488+ context_transform = conv (context_transform , channels = c , kernel = 1 , stride = 1 , use_bias = use_bias , sn = sn , scope = 'conv_1' )
489+
490+ x = x + context_transform
491+
492+ return x
493+
494+
457495##################################################################################
458496# Normalization
459497##################################################################################
@@ -728,7 +766,6 @@ def histogram_loss(x, y):
728766
729767 return hist_loss
730768
731-
732769def get_histogram (img , bin_size = 0.2 ):
733770 hist_entries = []
734771
@@ -788,9 +825,9 @@ def dice_loss(n_classes, logits, labels):
788825 :param labels: [batch_size, m, n, 1] int32, class label
789826 :return:
790827 """
791-
828+
792829 # https://github.com/keras-team/keras/issues/9395
793-
830+
794831 smooth = 1e-7
795832 dtype = tf .float32
796833
0 commit comments