Remove unneeded modules after using nn.Module for BetterTransformer #1693

Closed · wants to merge 1 commit
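
Context for the diff below: the two hand-written blocks being deleted (ResidualMLP and MultiheadSelfAttention) duplicate functionality that torch.nn already provides, and with the RoBERTa encoder layers built on the nn.Module stack that BetterTransformer can fast-path, they are no longer needed. The following sketch illustrates the built-in equivalent; the layer sizes, dropout, and layer count are illustrative assumptions, not values taken from this PR.

```python
# Minimal sketch of the torch.nn built-ins that cover the deleted modules.
# All sizes below (768/12/3072, 12 layers) are RoBERTa-base-style assumptions,
# not values read from this PR.
import torch
from torch import nn

embed_dim, num_heads, ffn_dim = 768, 12, 3072

# nn.TransformerEncoderLayer bundles self-attention (an nn.MultiheadAttention,
# standing in for the custom MultiheadSelfAttention) with the feed-forward block
# that ResidualMLP implemented (Linear -> GELU -> Dropout -> Linear -> Dropout + residual).
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=num_heads,
    dim_feedforward=ffn_dim,
    dropout=0.1,
    activation="gelu",
    batch_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=12)

tokens = torch.rand(2, 16, embed_dim)                # (batch, seq, embed)
padding_mask = torch.zeros(2, 16, dtype=torch.bool)  # True marks padded positions
encoder.eval()
with torch.no_grad():  # inference-only conditions under which the fast path can kick in
    out = encoder(tokens, src_key_padding_mask=padding_mask)
```

Under eval() and no_grad(), PyTorch 1.12+ can dispatch this stack to the fused BetterTransformer kernels, which is what makes the hand-rolled modules below redundant.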
156 changes: 1 addition & 155 deletions torchtext/models/roberta/modules.py
@@ -1,10 +1,9 @@
import logging
import math
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F, Module
from torch.nn import Module

logger = logging.getLogger(__name__)

@@ -30,159 +29,6 @@ def _make_positions(self, tensor, pad_index: int):
return torch.cumsum(masked, dim=1) * masked + pad_index


class ResidualMLP(Module):
def __init__(
self,
input_dim: int,
hidden_dims: List[int],
dropout: float = 0.1,
activation=nn.GELU,
add_residual=True,
):
super().__init__()
modules = []
for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
modules.extend([nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)])

last_dim = hidden_dims[-1] if hidden_dims else input_dim
modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])

self.mlp = nn.Sequential(*modules)
self.add_residual = add_residual
self.hidden_dim = hidden_dims[0] if hidden_dims else input_dim

def forward(self, input):
bias = self.mlp(input)
if not hasattr(self, "add_residual"):
self.add_residual = True
if self.add_residual:
return input + bias
else:
return bias


class MultiheadSelfAttention(Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
scaling: Optional[float] = None,
dropout: float = 0.1,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads

expected_scaling = float(1 / math.sqrt(self.head_dim))

assert embed_dim % num_heads == 0, f"embed_dim={embed_dim} should be a multiple of num_heads={num_heads}"

if not scaling:
logger.warn(
f"""
Scaling not set. Please manually set scaling for transformers.
In this case the suggested value {expected_scaling} will be inferred,
or float(1 / math.sqrt(head_dim))
where head_dim = embed_dim // num_heads = {self.head_dim}
and embed_dim = {embed_dim} and num_heads = {num_heads}.
"""
)
scaling = expected_scaling

self.scaling = scaling
self.dropout = nn.Dropout(dropout)
self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
self.output_projection = nn.Linear(embed_dim, embed_dim)

def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
target_length, batch_size, embed_dim = query.size()
mask_batch_size, source_length = key_padding_mask.size()

torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
torch._assert(
batch_size == mask_batch_size,
"query and key_padding_mask batch sizes differed",
)

projection = self.input_projection(query)
q, k, v = projection.chunk(3, dim=-1)
q = self.scaling * q

batch_heads = batch_size * self.num_heads

q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)

torch._assert(k.size(1) == source_length, "key size should be equal to source length")

attn_weights = torch.bmm(q, k.transpose(1, 2))
if attn_mask is not None:
torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
torch._assert(
attn_mask.size(0) == target_length,
"attn_mask shape didn't match for target length {}".format(target_length),
)
torch._assert(
attn_mask.size(1) == source_length,
"attn_mask shape didn't match for source length {}".format(source_length),
)
torch._assert(
attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
)
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
attn_mask = new_attn_mask
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask

torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
torch._assert(
attn_weights.size(0) == batch_heads,
"attn_weights shape didn't match for batch heads",
)
torch._assert(
attn_weights.size(1) == target_length,
"attn_weights shape didn't match for target length",
)
torch._assert(
attn_weights.size(2) == source_length,
"attn_weights shape didn't match for source length",
)

attn_weights = attn_weights.view(batch_size, self.num_heads, target_length, source_length)
attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
attn_weights = attn_weights.view(batch_heads, target_length, source_length)

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
attn_weights = self.dropout(attn_weights)

attn = torch.bmm(attn_weights, v)

torch._assert(
attn.dim() == 3,
"unexpected attn dim size",
)
torch._assert(
attn.size(0) == batch_heads,
"attn shape didn't match for batch heads",
)
torch._assert(
attn.size(1) == target_length,
"attn shape didn't match for target length",
)
torch._assert(
attn.size(2) == self.head_dim,
"attn shape didn't match for head dim",
)
attn = attn.transpose(0, 1).contiguous().view(target_length, batch_size, self.head_dim * self.num_heads)
attn = self.output_projection(attn)

return attn


class TransformerEncoderLayer(Module):
def __init__(
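
A closing note on migration: the deleted MultiheadSelfAttention used the same mask conventions as nn.MultiheadAttention (boolean key_padding_mask with True at padded positions, a float or bool attn_mask added to the attention logits, scaling of 1/sqrt(head_dim)), so call sites can presumably switch to the built-in module without touching their masks. A small illustrative sketch with made-up shapes:

```python
# Illustrative only: shapes and values are assumptions, not torchtext code.
import torch
from torch import nn

embed_dim, num_heads = 64, 4
seq_len, batch = 10, 2

attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)  # default batch_first=False

x = torch.rand(seq_len, batch, embed_dim)                  # (T, B, E), as in the deleted forward()
key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True                            # True marks padding, same as the deleted code
attn_mask = torch.zeros(seq_len, seq_len)                  # float mask is added to the logits

out, _ = attn(x, x, x, key_padding_mask=key_padding_mask,
              attn_mask=attn_mask, need_weights=False)
print(out.shape)  # torch.Size([10, 2, 64])
```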