diff --git a/test/models/test_models.py b/test/models/test_models.py
index 3303f7999b..e95f9cf090 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -8,37 +8,6 @@
 from ..common.torchtext_test_case import TorchtextTestCase
 
 
-class TestModules(TorchtextTestCase):
-    def test_self_attn_mask(self):
-        from torchtext.models.roberta.modules import MultiheadSelfAttention
-
-        embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
-        mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
-        query = torch.ones((source_len, batch_size, embed_dim))
-        query[0, ...] = 0
-        key_padding_mask = torch.zeros((batch_size, source_len))
-        float_attn_mask = torch.zeros((source_len, source_len))
-        float_attn_mask[0][1] = -1e8
-        bool_attn_mask = float_attn_mask.to(dtype=bool)
-        with torch.no_grad():
-            mha.input_projection.weight.fill_(1.0 / embed_dim)
-            mha.input_projection.bias.fill_(0.0)
-            mha.output_projection.weight.fill_(1.0 / embed_dim)
-            mha.output_projection.bias.fill_(0.0)
-
-        # with float attention mask
-        output = mha(query, key_padding_mask, float_attn_mask)
-        actual = output[0].flatten()
-        expected = torch.tensor([0.0, 0.0, 0.0, 0])
-        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
-
-        # with bool attention mask
-        output = mha(query, key_padding_mask, bool_attn_mask)
-        actual = output[0].flatten()
-        expected = torch.tensor([0.0, 0.0, 0.0, 0])
-        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
-
-
 class TestModels(TorchtextTestCase):
     def test_roberta_bundler_build_model(self):
         from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 2a9a866267..85127c9645 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -1,10 +1,9 @@
 import logging
-import math
 from typing import List, Optional, Union
 
 import torch
 from torch import nn
-from torch.nn import functional as F, Module
+from torch.nn import Module
 
 logger = logging.getLogger(__name__)
 
@@ -30,160 +29,6 @@ def _make_positions(self, tensor, pad_index: int):
         return torch.cumsum(masked, dim=1) * masked + pad_index
 
 
-class ResidualMLP(Module):
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dims: List[int],
-        dropout: float = 0.1,
-        activation=nn.GELU,
-        add_residual=True,
-    ):
-        super().__init__()
-        modules = []
-        for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
-            modules.extend([nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)])
-
-        last_dim = hidden_dims[-1] if hidden_dims else input_dim
-        modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])
-
-        self.mlp = nn.Sequential(*modules)
-        self.add_residual = add_residual
-        self.hidden_dim = hidden_dims[0] if hidden_dims else input_dim
-
-    def forward(self, input):
-        bias = self.mlp(input)
-        if not hasattr(self, "add_residual"):
-            self.add_residual = True
-        if self.add_residual:
-            return input + bias
-        else:
-            return bias
-
-
-class MultiheadSelfAttention(Module):
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        scaling: Optional[float] = None,
-        dropout: float = 0.1,
-    ):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-
-        expected_scaling = float(1 / math.sqrt(self.head_dim))
-
-        assert embed_dim % num_heads == 0, f"embed_dim={embed_dim} should be a multiple of num_heads={num_heads}"
-
-        if not scaling:
-            logger.warn(
-                f"""
-                Scaling not set. Please manually set scaling for transformers.
-                In this case the suggested value {expected_scaling} will be inferred,
-                or float(1 / math.sqrt(head_dim))
-                where head_dim = embed_dim // num_heads = {self.head_dim}
-                and embed_dim = {embed_dim} and num_heads = {num_heads}.
-                """
-            )
-            scaling = expected_scaling
-
-        self.scaling = scaling
-        self.dropout = nn.Dropout(dropout)
-        self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
-        self.output_projection = nn.Linear(embed_dim, embed_dim)
-
-    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        target_length, batch_size, embed_dim = query.size()
-        mask_batch_size, source_length = key_padding_mask.size()
-
-        torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
-        torch._assert(
-            batch_size == mask_batch_size,
-            "query and key_padding_mask batch sizes differed",
-        )
-
-        projection = self.input_projection(query)
-        q, k, v = projection.chunk(3, dim=-1)
-        q = self.scaling * q
-
-        batch_heads = batch_size * self.num_heads
-
-        q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-        k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-        v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-
-        torch._assert(k.size(1) == source_length, "key size should be equal to source length")
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-        if attn_mask is not None:
-            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
-            torch._assert(
-                attn_mask.size(0) == target_length,
-                "attn_mask shape didn't match for target length {}".format(target_length),
-            )
-            torch._assert(
-                attn_mask.size(1) == source_length,
-                "attn_mask shape didn't match for source length {}".format(source_length),
-            )
-            torch._assert(
-                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
-                f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
-            )
-            if attn_mask.dtype == torch.bool:
-                new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
-                new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
-                attn_mask = new_attn_mask
-            attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
-
-        torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
-        torch._assert(
-            attn_weights.size(0) == batch_heads,
-            "attn_weights shape didn't match for batch heads",
-        )
-        torch._assert(
-            attn_weights.size(1) == target_length,
-            "attn_weights shape didn't match for target length",
-        )
-        torch._assert(
-            attn_weights.size(2) == source_length,
-            "attn_weights shape didn't match for source length",
-        )
-
-        attn_weights = attn_weights.view(batch_size, self.num_heads, target_length, source_length)
-        attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
-        attn_weights = attn_weights.view(batch_heads, target_length, source_length)
-
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
-        attn_weights = self.dropout(attn_weights)
-
-        attn = torch.bmm(attn_weights, v)
-
-        torch._assert(
-            attn.dim() == 3,
-            "unexpected attn dim size",
-        )
-        torch._assert(
-            attn.size(0) == batch_heads,
-            "attn shape didn't match for batch heads",
-        )
-        torch._assert(
-            attn.size(1) == target_length,
-            "attn shape didn't match for target length",
-        )
-        torch._assert(
-            attn.size(2) == self.head_dim,
-            "attn shape didn't match for head dim",
-        )
-        attn = attn.transpose(0, 1).contiguous().view(target_length, batch_size, self.head_dim * self.num_heads)
-        attn = self.output_projection(attn)
-
-        return attn
-
-
 class TransformerEncoderLayer(Module):
     def __init__(
         self,