[fbsync] Remove unneeded modules after using nn.Module for BetterTransformer #1696

Merged
7 commits merged on May 3, 2022
Changes from all commits
31 changes: 0 additions & 31 deletions test/models/test_models.py
@@ -8,37 +8,6 @@
from ..common.torchtext_test_case import TorchtextTestCase


class TestModules(TorchtextTestCase):
    def test_self_attn_mask(self):
        from torchtext.models.roberta.modules import MultiheadSelfAttention

        embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
        mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
        query = torch.ones((source_len, batch_size, embed_dim))
        query[0, ...] = 0
        key_padding_mask = torch.zeros((batch_size, source_len))
        float_attn_mask = torch.zeros((source_len, source_len))
        float_attn_mask[0][1] = -1e8
        bool_attn_mask = float_attn_mask.to(dtype=bool)
        with torch.no_grad():
            mha.input_projection.weight.fill_(1.0 / embed_dim)
            mha.input_projection.bias.fill_(0.0)
            mha.output_projection.weight.fill_(1.0 / embed_dim)
            mha.output_projection.bias.fill_(0.0)

        # with float attention mask
        output = mha(query, key_padding_mask, float_attn_mask)
        actual = output[0].flatten()
        expected = torch.tensor([0.0, 0.0, 0.0, 0.0])
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

        # with bool attention mask
        output = mha(query, key_padding_mask, bool_attn_mask)
        actual = output[0].flatten()
        expected = torch.tensor([0.0, 0.0, 0.0, 0.0])
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)


class TestModels(TorchtextTestCase):
    def test_roberta_bundler_build_model(self):
        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
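Note: the deleted test above pinned down how the in-tree MultiheadSelfAttention handled float and bool attention masks; with the module gone, the same behaviour lives in core PyTorch. For orientation only, a minimal sketch of the same float-vs-bool mask check against torch.nn.MultiheadAttention; the snippet is not part of this PR and the module choice is an assumption:

import torch
from torch import nn

# Sketch only: a float mask with -inf and the equivalent bool mask should
# produce identical outputs from the stock attention module.
embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)  # expects (seq_len, batch, embed_dim) inputs by default

query = torch.ones((source_len, batch_size, embed_dim))
key_padding_mask = torch.zeros((batch_size, source_len), dtype=torch.bool)  # nothing is padded
float_attn_mask = torch.zeros((source_len, source_len))
float_attn_mask[0][1] = float("-inf")      # position 0 may not attend to position 1
bool_attn_mask = float_attn_mask.isinf()   # same mask expressed as bool (True = blocked)

out_float, _ = mha(query, query, query, key_padding_mask=key_padding_mask, attn_mask=float_attn_mask)
out_bool, _ = mha(query, query, query, key_padding_mask=key_padding_mask, attn_mask=bool_attn_mask)
torch.testing.assert_close(out_float, out_bool)  # both mask forms give the same result
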
157 changes: 1 addition & 156 deletions torchtext/models/roberta/modules.py
@@ -1,10 +1,9 @@
import logging
import math
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F, Module
from torch.nn import Module

logger = logging.getLogger(__name__)

@@ -30,160 +29,6 @@ def _make_positions(self, tensor, pad_index: int):
        return torch.cumsum(masked, dim=1) * masked + pad_index


class ResidualMLP(Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        dropout: float = 0.1,
        activation=nn.GELU,
        add_residual=True,
    ):
        super().__init__()
        modules = []
        for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
            modules.extend([nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)])

        last_dim = hidden_dims[-1] if hidden_dims else input_dim
        modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])

        self.mlp = nn.Sequential(*modules)
        self.add_residual = add_residual
        self.hidden_dim = hidden_dims[0] if hidden_dims else input_dim

    def forward(self, input):
        bias = self.mlp(input)
        if not hasattr(self, "add_residual"):
            self.add_residual = True
        if self.add_residual:
            return input + bias
        else:
            return bias
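
Note: the ResidualMLP removed here is a position-wise feed-forward block with an optional residual connection, left unused once the encoder relies on torch.nn modules. A minimal usage sketch for context (the dimensions are illustrative, not taken from this PR):

import torch

# Illustrative only: exercising the ResidualMLP defined above, prior to its removal.
mlp = ResidualMLP(input_dim=768, hidden_dims=[3072], dropout=0.1)
x = torch.randn(2, 5, 768)   # (batch, sequence, embed_dim)
y = mlp(x)                   # Linear -> GELU -> Dropout -> Linear -> Dropout, plus the residual x
assert y.shape == x.shape    # the residual connection keeps the output shape equal to the input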


class MultiheadSelfAttention(Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        scaling: Optional[float] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        expected_scaling = float(1 / math.sqrt(self.head_dim))

        assert embed_dim % num_heads == 0, f"embed_dim={embed_dim} should be a multiple of num_heads={num_heads}"

        if not scaling:
            logger.warn(
                f"""
                Scaling not set. Please manually set scaling for transformers.
                In this case the suggested value {expected_scaling} will be inferred,
                or float(1 / math.sqrt(head_dim))
                where head_dim = embed_dim // num_heads = {self.head_dim}
                and embed_dim = {embed_dim} and num_heads = {num_heads}.
                """
            )
            scaling = expected_scaling

        self.scaling = scaling
        self.dropout = nn.Dropout(dropout)
        self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
        self.output_projection = nn.Linear(embed_dim, embed_dim)

    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        target_length, batch_size, embed_dim = query.size()
        mask_batch_size, source_length = key_padding_mask.size()

        torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
        torch._assert(
            batch_size == mask_batch_size,
            "query and key_padding_mask batch sizes differed",
        )

        projection = self.input_projection(query)
        q, k, v = projection.chunk(3, dim=-1)
        q = self.scaling * q

        batch_heads = batch_size * self.num_heads

        q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)

        torch._assert(k.size(1) == source_length, "key size should be equal to source length")

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if attn_mask is not None:
            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
            torch._assert(
                attn_mask.size(0) == target_length,
                "attn_mask shape didn't match for target length {}".format(target_length),
            )
            torch._assert(
                attn_mask.size(1) == source_length,
                "attn_mask shape didn't match for source length {}".format(source_length),
            )
            torch._assert(
                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
                f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
            )
            if attn_mask.dtype == torch.bool:
                new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
                new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
                attn_mask = new_attn_mask
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
        torch._assert(
            attn_weights.size(0) == batch_heads,
            "attn_weights shape didn't match for batch heads",
        )
        torch._assert(
            attn_weights.size(1) == target_length,
            "attn_weights shape didn't match for target length",
        )
        torch._assert(
            attn_weights.size(2) == source_length,
            "attn_weights shape didn't match for source length",
        )

        attn_weights = attn_weights.view(batch_size, self.num_heads, target_length, source_length)
        attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
        attn_weights = attn_weights.view(batch_heads, target_length, source_length)

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
        attn_weights = self.dropout(attn_weights)

        attn = torch.bmm(attn_weights, v)

        torch._assert(
            attn.dim() == 3,
            "unexpected attn dim size",
        )
        torch._assert(
            attn.size(0) == batch_heads,
            "attn shape didn't match for batch heads",
        )
        torch._assert(
            attn.size(1) == target_length,
            "attn shape didn't match for target length",
        )
        torch._assert(
            attn.size(2) == self.head_dim,
            "attn shape didn't match for head dim",
        )
        attn = attn.transpose(0, 1).contiguous().view(target_length, batch_size, self.head_dim * self.num_heads)
        attn = self.output_projection(attn)

        return attn
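
Note: the removed MultiheadSelfAttention mirrors the packed-QKV layout of torch.nn.MultiheadAttention (a single 3 * embed_dim input projection plus an output projection), which is the layout torch.nn's transformer stack relies on, so its computation is covered by the stock module. A hedged sketch of the stock call with the same tensor conventions (shapes here are illustrative, not taken from this PR):

import torch
from torch import nn

# Sketch only: the stock module covers the same computation as the removed class.
embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)

# Same tensor conventions as the removed forward(): (seq_len, batch, embed_dim) inputs,
# a (batch, seq_len) key_padding_mask, and an optional (tgt_len, src_len) attn_mask.
x = torch.randn(5, 3, embed_dim)
key_padding_mask = torch.zeros(3, 5, dtype=torch.bool)  # True would mark padded positions

out, _ = mha(x, x, x, key_padding_mask=key_padding_mask)  # self-attention: query = key = value
print(out.shape)  # torch.Size([5, 3, 8])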


class TransformerEncoderLayer(Module):
    def __init__(
        self,