import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from ..builder import HEADS
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
    """Average a tensor across all processes in distributed training.

    Returns the input unchanged when no process group is initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    # Clone so the caller's tensor is not mutated; dividing by the world
    # size before the SUM all_reduce yields the mean across processes.
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor

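# A minimal usage sketch for reduce_mean (assumes torch.distributed has been
# initialized, e.g. via dist.init_process_group; not part of the original
# file):
#
#     stat = loss.detach()
#     stat = reduce_mean(stat)  # identical averaged value on every rank
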
class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        momentum (float): Momentum to update the bases.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)

        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 - self.momentum) * self.bases \
                + self.momentum * bases

        return feats_recon

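# A minimal usage sketch for EMAModule (shapes and values are illustrative
# assumptions, not part of the original file). Each EM iteration is an
# E-step (softmax attention softly assigning every pixel to the K bases)
# followed by an M-step (re-estimating the bases as the attention-weighted
# mean of pixel features), run under no_grad; gradients flow only through
# the final low-rank reconstruction:
#
#     ema = EMAModule(channels=256, num_bases=64, num_stages=3, momentum=0.1)
#     feats = torch.rand(2, 256, 32, 32)
#     recon = ema(feats)  # same shape as feats: [2, 256, 32, 32]
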

@HEADS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head is the implementation of `EMANet
    <https://arxiv.org/abs/1907.13426>`_.

    Args:
        ema_channels (int): Channels of the EMA module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether to concatenate the input and output of
            convs before the classification layer. Default: True.
        momentum (float): Momentum to update the bases. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        self.ema_in_conv = ConvModule(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # project (0, inf) -> (-inf, inf)
        self.ema_mid_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False

        self.ema_out_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            self.ema_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        recon = self.ema_module(feats)
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
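

# A minimal construction sketch (all argument values below are illustrative
# assumptions; in practice they come from the model config, and kwargs such
# as in_channels/channels/num_classes are consumed by BaseDecodeHead):
#
#     head = EMAHead(
#         ema_channels=256,
#         num_bases=64,
#         num_stages=3,
#         in_channels=512,
#         channels=128,
#         num_classes=19)
#     seg_logits = head([torch.rand(2, 512, 32, 32)])
#     # seg_logits: [2, 19, 32, 32]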