Replace TransformerEncoder in torchtext with better transformer #1700

Closed
69 changes: 49 additions & 20 deletions test/integration_tests/test_models.py
@@ -1,34 +1,26 @@
import torch
from torchtext.models import ROBERTA_BASE_ENCODER, ROBERTA_LARGE_ENCODER, XLMR_BASE_ENCODER, XLMR_LARGE_ENCODER
from torchtext.models import (
ROBERTA_BASE_ENCODER,
ROBERTA_LARGE_ENCODER,
XLMR_BASE_ENCODER,
XLMR_LARGE_ENCODER,
)

from ..common.assets import get_asset_path
from ..common.parameterized_utils import nested_params
from ..common.torchtext_test_case import TorchtextTestCase


class TestModels(TorchtextTestCase):
@nested_params(
[
("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
(
"roberta.base.output.pt",
"Roberta base Model Comparison",
ROBERTA_BASE_ENCODER,
),
(
"roberta.large.output.pt",
"Roberta base Model Comparison",
ROBERTA_LARGE_ENCODER,
),
],
[True, False],
)
def test_model(self, model_args, is_jit):
def _xlmr_base_model(self, is_jit):
"""Verify pre-trained XLM-R and Roberta models in torchtext produce
the same output as the reference implementation within fairseq
"""
expected_asset_name, test_text, model_bundler = model_args
expected_asset_name, test_text, model_bundler = (
"xlmr.base.output.pt",
"XLMR base Model Comparison",
XLMR_BASE_ENCODER,
)

expected_asset_path = get_asset_path(expected_asset_name)

@@ -44,3 +36,40 @@ def test_model(self, model_args, is_jit):
actual = model(model_input)
expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected)

def test_xlmr_base_model(self):
self._xlmr_base_model(is_jit=False)

def test_xlmr_base_model_jit(self):
self._xlmr_base_model(is_jit=True)

def _xlmr_large_model(self, is_jit):
"""Verify pre-trained XLM-R and Roberta models in torchtext produce
the same output as the reference implementation within fairseq
"""
expected_asset_name, test_text, model_bundler = (
"xlmr.large.output.pt",
"XLMR base Model Comparison",
XLMR_LARGE_ENCODER,
)

expected_asset_path = get_asset_path(expected_asset_name)

transform = model_bundler.transform()
model = model_bundler.get_model()
model = model.eval()

if is_jit:
transform = torch.jit.script(transform)
model = torch.jit.script(model)

model_input = torch.tensor(transform([test_text]))
actual = model(model_input)
expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected)

def test_xlmr_large_model(self):
self._xlmr_large_model(is_jit=False)

def test_xlmr_large_model_jit(self):
self._xlmr_large_model(is_jit=True)
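
For context, the flow these tests exercise can be reproduced outside the test harness. The sketch below is a minimal, hedged usage example based on the bundler API used in the tests above (XLMR_BASE_ENCODER.transform() / .get_model()); running it downloads the pre-trained XLM-R base weights.

import torch
from torchtext.models import XLMR_BASE_ENCODER

# The bundler supplies both the text transform and the pre-trained encoder.
transform = XLMR_BASE_ENCODER.transform()
model = XLMR_BASE_ENCODER.get_model().eval()

# Both are TorchScript-scriptable, which is what the *_jit test variants check.
transform = torch.jit.script(transform)
model = torch.jit.script(model)

model_input = torch.tensor(transform(["XLMR base Model Comparison"]))
with torch.no_grad():
    output = model(model_input)  # (batch, sequence_length, embedding_dim)
print(output.shape)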
75 changes: 50 additions & 25 deletions torchtext/models/roberta/modules.py
@@ -110,19 +110,17 @@ def __init__(
super().__init__()
self.padding_idx = padding_idx
self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
self.layers = nn.ModuleList(
[
TransformerEncoderLayer(
embedding_dim=embedding_dim,
num_attention_heads=num_attention_heads,
ffn_dimension=ffn_dimension,
ffn_dimension = ffn_dimension or 4 * embedding_dim
layer = torch.nn.TransformerEncoderLayer(
d_model=embedding_dim,
nhead=num_attention_heads,
dim_feedforward=ffn_dimension,
dropout=dropout,
normalize_before=normalize_before,
scaling=scaling,
activation="gelu",
batch_first=True,
norm_first=normalize_before,
)
for _ in range(num_encoder_layers)
]
)
self.layers = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)
self.positional_embedding = PositionalEmbedding(max_seq_len, embedding_dim, padding_idx)
self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
self.dropout = nn.Dropout(dropout)
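
As a side note, the construction introduced above can be exercised in isolation. The following is a minimal sketch, not part of the PR, of building the same torch.nn.TransformerEncoder stack with batch-first inputs; the hyperparameter values are illustrative placeholders.

import torch

embedding_dim, num_attention_heads, num_encoder_layers = 768, 12, 12
layer = torch.nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=num_attention_heads,
    dim_feedforward=4 * embedding_dim,  # mirrors the `ffn_dimension or 4 * embedding_dim` default
    dropout=0.1,
    activation="gelu",
    batch_first=True,   # inputs are B x T x C, matching the new forward()
    norm_first=True,    # plays the role of the old normalize_before flag
)
encoder = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers).eval()

tokens = torch.rand(2, 16, embedding_dim)          # B x T x C
padding = torch.zeros(2, 16, dtype=torch.bool)     # True marks padded positions
with torch.no_grad():
    out = encoder(tokens, src_key_padding_mask=padding)  # still B x T x C
print(out.shape)  # torch.Size([2, 16, 768])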
@@ -153,27 +151,54 @@ def forward(

padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))

encoded = padded_embedded.transpose(0, 1)

if self.return_all_layers:
states = [encoded]

for layer in self.layers:
encoded = padded_embedded
# B x T x C
# Then transpose back to T x B x C
states = [encoded.transpose(1, 0)]
for layer in self.layers.layers:
encoded = layer(encoded, padding_mask, attn_mask)
states.append(encoded)

encoded_t = encoded.transpose(1, 0)
states.append(encoded_t)
if self.normalize_before:
for i, state in enumerate(states):
states[i] = self.embedding_layer_norm(state)

# states are returned as T x B x C
return states
else:
for layer in self.layers:
encoded = layer(encoded, padding_mask, attn_mask)

# B x T x C
# Then transpose back to T x B x C
encoded = self.layers(padded_embedded).transpose(1, 0)
if self.normalize_before:
encoded = self.embedding_layer_norm(encoded)

# states are returned as T x B x C
return encoded

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
better_to_old_names = {
"self_attn.in_proj_weight": "attention.input_projection.weight",
"self_attn.in_proj_bias": "attention.input_projection.bias",
"self_attn.out_proj.weight": "attention.output_projection.weight",
"self_attn.out_proj.bias": "attention.output_projection.bias",
"linear1.weight": "residual_mlp.mlp.0.weight",
"linear1.bias": "residual_mlp.mlp.0.bias",
"linear2.weight": "residual_mlp.mlp.3.weight",
"linear2.bias": "residual_mlp.mlp.3.bias",
"norm1.weight": "attention_layer_norm.weight",
"norm1.bias": "attention_layer_norm.bias",
"norm2.weight": "final_layer_norm.weight",
"norm2.bias": "final_layer_norm.bias",
}
for i in range(self.layers.num_layers):
for better, old in better_to_old_names.items():
better_name = prefix + "layers.layers.{}.".format(i) + better
old_name = prefix + "layers.{}.".format(i) + old
if old_name in state_dict:
state_dict[better_name] = state_dict[old_name]
state_dict.pop(old_name)
elif better_name in state_dict:
# Do nothing
pass
elif strict:
missing_keys.append(better_name)

super(TransformerEncoder, self)._load_from_state_dict(
    state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
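
The _load_from_state_dict override above exists so that checkpoints saved with the old torchtext layer names still load into the new torch.nn modules. The toy sketch below is not part of the PR; it illustrates the renaming on a plain dict, with a fabricated prefix and an abbreviated two-entry mapping.

# Fabricated, abbreviated mapping and prefix for illustration only.
better_to_old_names = {
    "self_attn.in_proj_weight": "attention.input_projection.weight",
    "linear1.weight": "residual_mlp.mlp.0.weight",
}

def migrate_keys(state_dict, prefix="encoder.transformer.", num_layers=1):
    # Copy each old-style key to its new-style name, as the override does.
    for i in range(num_layers):
        for better, old in better_to_old_names.items():
            old_name = prefix + "layers.{}.".format(i) + old
            better_name = prefix + "layers.layers.{}.".format(i) + better
            if old_name in state_dict:
                state_dict[better_name] = state_dict.pop(old_name)
    return state_dict

old_checkpoint = {"encoder.transformer.layers.0.attention.input_projection.weight": "..."}
print(migrate_keys(old_checkpoint))
# {'encoder.transformer.layers.layers.0.self_attn.in_proj_weight': '...'}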