
Commit 204777a

Michael Gschwind authored and facebook-github-bot committed
Remove unneeded modules after using nn.Module for BetterTransformer (#1693)
Summary:
Pull Request resolved: #1693

Remove unneeded modules after using nn.Module for BetterTransformer

Differential Revision: D36038830

fbshipit-source-id: 5663217ca4989dc07962f239b3365bc48edad490
1 parent 6e8b4d6 commit 204777a
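
For orientation only (not part of this commit): the self-attention computation implemented by the deleted MultiheadSelfAttention in the diff below is also what the stock torch.nn.MultiheadAttention layer computes, which is presumably why the custom module is no longer needed once the encoder is built on nn.Module-based BetterTransformer components. The sketch below is an illustrative assumption, not code from this PR; the hyperparameters and shapes are made up.

import torch
from torch import nn

# Illustrative hyperparameters and shapes (assumptions, not taken from the PR).
embed_dim, num_heads, dropout = 768, 12, 0.1
seq_len, batch_size = 16, 2

# Stock PyTorch layer. Like the deleted MultiheadSelfAttention it uses a fused
# input projection of size 3 * embed_dim, scales queries by 1 / sqrt(head_dim),
# applies dropout to the attention weights, and ends with an
# (embed_dim, embed_dim) output projection. (The deleted module additionally
# allowed a custom `scaling`; nn.MultiheadAttention always uses the default.)
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)  # batch_first=False -> (L, N, E)

query = torch.randn(seq_len, batch_size, embed_dim)  # same (target_length, batch_size, embed_dim) layout
key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # True marks padded positions

attn_output, _ = mha(
    query, query, query,                # self-attention: q = k = v
    key_padding_mask=key_padding_mask,  # padded positions get -inf before softmax, as in the deleted code
    attn_mask=None,
    need_weights=False,
)
print(attn_output.shape)  # torch.Size([16, 2, 768])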

File tree

1 file changed: +0 -153 lines changed


torchtext/models/roberta/modules.py

Lines changed: 0 additions & 153 deletions
@@ -30,159 +30,6 @@ def _make_positions(self, tensor, pad_index: int):
         return torch.cumsum(masked, dim=1) * masked + pad_index
 
 
-class ResidualMLP(Module):
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dims: List[int],
-        dropout: float = 0.1,
-        activation=nn.GELU,
-        add_residual=True,
-    ):
-        super().__init__()
-        modules = []
-        for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
-            modules.extend([nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)])
-
-        last_dim = hidden_dims[-1] if hidden_dims else input_dim
-        modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])
-
-        self.mlp = nn.Sequential(*modules)
-        self.add_residual = add_residual
-        self.hidden_dim = hidden_dims[0] if hidden_dims else input_dim
-
-    def forward(self, input):
-        bias = self.mlp(input)
-        if not hasattr(self, "add_residual"):
-            self.add_residual = True
-        if self.add_residual:
-            return input + bias
-        else:
-            return bias
-
-
-class MultiheadSelfAttention(Module):
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        scaling: Optional[float] = None,
-        dropout: float = 0.1,
-    ):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-
-        expected_scaling = float(1 / math.sqrt(self.head_dim))
-
-        assert embed_dim % num_heads == 0, f"embed_dim={embed_dim} should be a multiple of num_heads={num_heads}"
-
-        if not scaling:
-            logger.warn(
-                f"""
-                Scaling not set. Please manually set scaling for transformers.
-                In this case the suggested value {expected_scaling} will be inferred,
-                or float(1 / math.sqrt(head_dim))
-                where head_dim = embed_dim // num_heads = {self.head_dim}
-                and embed_dim = {embed_dim} and num_heads = {num_heads}.
-                """
-            )
-            scaling = expected_scaling
-
-        self.scaling = scaling
-        self.dropout = nn.Dropout(dropout)
-        self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
-        self.output_projection = nn.Linear(embed_dim, embed_dim)
-
-    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        target_length, batch_size, embed_dim = query.size()
-        mask_batch_size, source_length = key_padding_mask.size()
-
-        torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
-        torch._assert(
-            batch_size == mask_batch_size,
-            "query and key_padding_mask batch sizes differed",
-        )
-
-        projection = self.input_projection(query)
-        q, k, v = projection.chunk(3, dim=-1)
-        q = self.scaling * q
-
-        batch_heads = batch_size * self.num_heads
-
-        q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-        k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-        v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
-
-        torch._assert(k.size(1) == source_length, "key size should be equal to source length")
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-        if attn_mask is not None:
-            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
-            torch._assert(
-                attn_mask.size(0) == target_length,
-                "attn_mask shape didn't match for target length {}".format(target_length),
-            )
-            torch._assert(
-                attn_mask.size(1) == source_length,
-                "attn_mask shape didn't match for source length {}".format(source_length),
-            )
-            torch._assert(
-                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
-                f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
-            )
-            if attn_mask.dtype == torch.bool:
-                new_attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
-                new_attn_mask.masked_fill_(attn_mask, -1e8 if query.dtype == torch.float32 else -1e4)
-                attn_mask = new_attn_mask
-            attn_mask = attn_mask.unsqueeze(0)
-            attn_weights += attn_mask
-
-        torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
-        torch._assert(
-            attn_weights.size(0) == batch_heads,
-            "attn_weights shape didn't match for batch heads",
-        )
-        torch._assert(
-            attn_weights.size(1) == target_length,
-            "attn_weights shape didn't match for target length",
-        )
-        torch._assert(
-            attn_weights.size(2) == source_length,
-            "attn_weights shape didn't match for source length",
-        )
-
-        attn_weights = attn_weights.view(batch_size, self.num_heads, target_length, source_length)
-        attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
-        attn_weights = attn_weights.view(batch_heads, target_length, source_length)
-
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
-        attn_weights = self.dropout(attn_weights)
-
-        attn = torch.bmm(attn_weights, v)
-
-        torch._assert(
-            attn.dim() == 3,
-            "unexpected attn dim size",
-        )
-        torch._assert(
-            attn.size(0) == batch_heads,
-            "attn shape didn't match for batch heads",
-        )
-        torch._assert(
-            attn.size(1) == target_length,
-            "attn shape didn't match for target length",
-        )
-        torch._assert(
-            attn.size(2) == self.head_dim,
-            "attn shape didn't match for head dim",
-        )
-        attn = attn.transpose(0, 1).contiguous().view(target_length, batch_size, self.head_dim * self.num_heads)
-        attn = self.output_projection(attn)
-
-        return attn
-
 
 class TransformerEncoderLayer(Module):
     def __init__(
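
Similarly, the deleted ResidualMLP is the standard transformer feed-forward block (linear, activation, dropout, projection back to input_dim, dropout, residual add), which nn.TransformerEncoderLayer already provides. Below is a minimal sketch of that computation with plain torch.nn modules for a single hidden dimension; the sizes are illustrative assumptions, not values from this PR.

import torch
from torch import nn

# Illustrative sizes (assumptions, not from the PR).
embed_dim, ffn_dim, dropout = 768, 3072, 0.1

# The same stack ResidualMLP built in its constructor for hidden_dims=[ffn_dim]:
# Linear -> GELU -> Dropout -> Linear back to embed_dim -> Dropout.
ffn = nn.Sequential(
    nn.Linear(embed_dim, ffn_dim),
    nn.GELU(),
    nn.Dropout(dropout),
    nn.Linear(ffn_dim, embed_dim),
    nn.Dropout(dropout),
)

x = torch.randn(16, 2, embed_dim)  # (seq_len, batch_size, embed_dim)
out = x + ffn(x)                   # residual add, as with ResidualMLP(add_residual=True)
print(out.shape)                   # torch.Size([16, 2, 768])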

0 commit comments
