Skip to content

Commit ea6b360

Browse files
committed
add global_context_block
1 parent 28235da commit ea6b360

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

ops.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,44 @@ def convolution_block_attention(x, channels, ratio=16, use_bias=True, sn=False,
454454
return x
455455

456456

457+
def global_context_block(x, channels, use_bias=True, sn=False, scope='gc_block'):
    """Global Context (GC) block: pools a single global context vector via a
    softmax attention map, then fuses it back into the feature map with two
    bottleneck transforms (a sigmoid gate and an additive residual).

    :param x: input feature map tensor, [N, H, W, C]
    :param channels: bottleneck width of the transform convolutions
    :param use_bias: whether the 1x1 convs use a bias term
    :param sn: whether the convs use spectral normalization
    :param scope: variable scope name for the block
    :return: feature map with the same shape as `x`
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('context_modeling'):
            # NOTE: h and w are unused; bs may be None for a dynamic batch,
            # which is why the reshape below uses -1 for the batch axis.
            bs, h, w, c = x.get_shape().as_list()

            input_x = x
            input_x = hw_flatten(input_x)  # [N, H*W, C]
            input_x = tf.transpose(input_x, perm=[0, 2, 1])  # [N, C, H*W]
            input_x = tf.expand_dims(input_x, axis=1)  # [N, 1, C, H*W]

            # Single-channel attention logits over spatial positions.
            context_mask = conv(x, channels=1, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv')
            context_mask = hw_flatten(context_mask)
            context_mask = tf.nn.softmax(context_mask, axis=1)  # [N, H*W, 1]
            context_mask = tf.transpose(context_mask, perm=[0, 2, 1])  # [N, 1, H*W]
            context_mask = tf.expand_dims(context_mask, axis=-1)  # [N, 1, H*W, 1]

            # Attention-weighted sum over all spatial positions -> one
            # context vector of length c per example.
            context = tf.matmul(input_x, context_mask)  # [N, 1, C, 1]
            # FIX: use -1 for the batch axis instead of the static `bs`.
            # `get_shape().as_list()` yields None for a dynamic batch
            # dimension, and tf.reshape rejects None; -1 is equivalent when
            # bs is static and correct when it is dynamic.
            context = tf.reshape(context, shape=[-1, 1, 1, c])

        with tf.variable_scope('transform_0'):
            # Bottleneck transform producing a per-channel sigmoid gate.
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')
            context_transform = sigmoid(context_transform)

            # Channel-wise scaling (broadcast over H and W).
            x = x * context_transform

        with tf.variable_scope('transform_1'):
            # Second bottleneck transform, fused additively. Note this is
            # applied on top of the already-gated x from transform_0 —
            # presumably intentional (both fusion modes combined); confirm
            # against the intended GC-block variant.
            context_transform = conv(context, channels, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_0')
            context_transform = layer_norm(context_transform)
            context_transform = relu(context_transform)
            context_transform = conv(context_transform, channels=c, kernel=1, stride=1, use_bias=use_bias, sn=sn, scope='conv_1')

            x = x + context_transform

        return x
457495
##################################################################################
458496
# Normalization
459497
##################################################################################
@@ -728,7 +766,6 @@ def histogram_loss(x, y):
728766

729767
return hist_loss
730768

731-
732769
def get_histogram(img, bin_size=0.2):
733770
hist_entries = []
734771

@@ -788,9 +825,9 @@ def dice_loss(n_classes, logits, labels):
788825
:param labels: [batch_size, m, n, 1] int32, class label
789826
:return:
790827
"""
791-
828+
792829
# https://github.com/keras-team/keras/issues/9395
793-
830+
794831
smooth = 1e-7
795832
dtype = tf.float32
796833

0 commit comments

Comments
 (0)