Hi! I was just wondering why the module at line 171 of model.py is called ResidualAttentionBlock. It looks like a standard attention block to me rather than residual attention. Am I misunderstanding something here? Thanks!
For reference, I thought "residual attention" referred to https://arxiv.org/abs/1704.06904, but this implementation doesn't look like that paper to me.
import torch
import torch.nn as nn
from collections import OrderedDict

# LayerNorm and QuickGELU are defined earlier in model.py.

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))  # skip/residual connection around attention
        x = x + self.mlp(self.ln_2(x))        # skip/residual connection around the MLP
        return x
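If it helps to see the two skip connections in action, here is a minimal sketch, not from the repo: it assumes the ResidualAttentionBlock snippet above has been run, the LayerNorm/QuickGELU stand-ins and the tensor sizes are illustrative choices of mine, and it just shows that the block maps (seq_len, batch, d_model) to the same shape, which is what the x + f(x) additions in forward require.

import torch
import torch.nn as nn


class LayerNorm(nn.LayerNorm):
    # stand-in; model.py's version additionally casts fp16 inputs to fp32 inside forward
    pass


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


# assumes ResidualAttentionBlock from the snippet above is defined in this session
block = ResidualAttentionBlock(d_model=512, n_head=8)  # sizes are illustrative only
x = torch.randn(77, 2, 512)   # (seq_len, batch, d_model): nn.MultiheadAttention's default layout
out = block(x)
print(out.shape)              # torch.Size([77, 2, 512]); shape preserved across both residual additions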