From 31c12c63f984a182ef4b5effabde7d7f6e1970c7 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Wed, 4 May 2022 22:06:37 -0700
Subject: [PATCH 1/2] Enable model testing in FBCode

Differential Revision: D35973306

fbshipit-source-id: 4a2a9f90bb0d3be32689541ad4c4b8e7b1a9cfdf
---
 test/integration_tests/test_models.py | 69 +++++++++++++++++++--------
 1 file changed, 49 insertions(+), 20 deletions(-)

diff --git a/test/integration_tests/test_models.py b/test/integration_tests/test_models.py
index 7c1dd60d8a..2004a65b8d 100644
--- a/test/integration_tests/test_models.py
+++ b/test/integration_tests/test_models.py
@@ -1,5 +1,10 @@
 import torch
-from torchtext.models import ROBERTA_BASE_ENCODER, ROBERTA_LARGE_ENCODER, XLMR_BASE_ENCODER, XLMR_LARGE_ENCODER
+from torchtext.models import (
+    ROBERTA_BASE_ENCODER,
+    ROBERTA_LARGE_ENCODER,
+    XLMR_BASE_ENCODER,
+    XLMR_LARGE_ENCODER,
+)
 
 from ..common.assets import get_asset_path
 from ..common.parameterized_utils import nested_params
@@ -7,28 +12,15 @@
 
 
 class TestModels(TorchtextTestCase):
-    @nested_params(
-        [
-            ("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
-            ("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
-            (
-                "roberta.base.output.pt",
-                "Roberta base Model Comparison",
-                ROBERTA_BASE_ENCODER,
-            ),
-            (
-                "roberta.large.output.pt",
-                "Roberta base Model Comparison",
-                ROBERTA_LARGE_ENCODER,
-            ),
-        ],
-        [True, False],
-    )
-    def test_model(self, model_args, is_jit):
+    def _xlmr_base_model(self, is_jit):
         """Verify pre-trained XLM-R and Roberta models in torchtext produce
         the same output as the reference implementation within fairseq
         """
-        expected_asset_name, test_text, model_bundler = model_args
+        expected_asset_name, test_text, model_bundler = (
+            "xlmr.base.output.pt",
+            "XLMR base Model Comparison",
+            XLMR_BASE_ENCODER,
+        )
 
         expected_asset_path = get_asset_path(expected_asset_name)
 
@@ -44,3 +36,40 @@ def test_model(self, model_args, is_jit):
         actual = model(model_input)
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
+
+    def test_xlmr_base_model(self):
+        self._xlmr_base_model(is_jit=False)
+
+    def test_xlmr_base_model_jit(self):
+        self._xlmr_base_model(is_jit=True)
+
+    def _xlmr_large_model(self, is_jit):
+        """Verify pre-trained XLM-R and Roberta models in torchtext produce
+        the same output as the reference implementation within fairseq
+        """
+        expected_asset_name, test_text, model_bundler = (
+            "xlmr.large.output.pt",
+            "XLMR base Model Comparison",
+            XLMR_LARGE_ENCODER,
+        )
+
+        expected_asset_path = get_asset_path(expected_asset_name)
+
+        transform = model_bundler.transform()
+        model = model_bundler.get_model()
+        model = model.eval()
+
+        if is_jit:
+            transform = torch.jit.script(transform)
+            model = torch.jit.script(model)
+
+        model_input = torch.tensor(transform([test_text]))
+        actual = model(model_input)
+        expected = torch.load(expected_asset_path)
+        torch.testing.assert_close(actual, expected)
+
+    def test_xlmr_large_model(self):
+        self._xlmr_large_model(is_jit=False)
+
+    def test_xlmr_large_model_jit(self):
+        self._xlmr_large_model(is_jit=True)

From a43f362967e153e97f85046442bf500d7d6f6445 Mon Sep 17 00:00:00 2001
From: Rui Zhu
Date: Wed, 4 May 2022 22:06:53 -0700
Subject: [PATCH 2/2] Replace TransformerEncoder in torchtext with better transformer

Summary: Replace the usage of TransformerEncoder by BetterTransformerEncoder

In theory we should be able to remove torchtext.TransformerEncoderLayer after this diff.

Reviewed By: parmeet

Differential Revision: D36084653

fbshipit-source-id: 2578243f67678fa6a6437d7e275f34cccee14ab2
---
 torchtext/models/roberta/modules.py | 75 +++++++++++++++++++----------
 1 file changed, 50 insertions(+), 25 deletions(-)

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 85127c9645..f7b911a6fd 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -110,19 +110,17 @@ def __init__(
         super().__init__()
         self.padding_idx = padding_idx
         self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
-        self.layers = nn.ModuleList(
-            [
-                TransformerEncoderLayer(
-                    embedding_dim=embedding_dim,
-                    num_attention_heads=num_attention_heads,
-                    ffn_dimension=ffn_dimension,
-                    dropout=dropout,
-                    normalize_before=normalize_before,
-                    scaling=scaling,
-                )
-                for _ in range(num_encoder_layers)
-            ]
-        )
+        ffn_dimension = ffn_dimension or 4 * embedding_dim
+        layer = torch.nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=num_attention_heads,
+            dim_feedforward=ffn_dimension,
+            dropout=dropout,
+            activation="gelu",
+            batch_first=True,
+            norm_first=normalize_before,
+        )
+        self.layers = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)
         self.positional_embedding = PositionalEmbedding(max_seq_len, embedding_dim, padding_idx)
         self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
         self.dropout = nn.Dropout(dropout)
@@ -153,27 +151,54 @@ def forward(
 
         padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
 
-        encoded = padded_embedded.transpose(0, 1)
-
         if self.return_all_layers:
-            states = [encoded]
-
-            for layer in self.layers:
+            encoded = padded_embedded
+            # B x T x C
+            # Then transpose back to T x B x C
+            states = [encoded.transpose(1, 0)]
+            for layer in self.layers.layers:
                 encoded = layer(encoded, padding_mask, attn_mask)
-                states.append(encoded)
-
+                encoded_t = encoded.transpose(1, 0)
+                states.append(encoded_t)
             if self.normalize_before:
                 for i, state in enumerate(states):
                     states[i] = self.embedding_layer_norm(state)
-
-            # states are returned as T x B x C
            return states
         else:
-            for layer in self.layers:
-                encoded = layer(encoded, padding_mask, attn_mask)
-
+            # B x T x C
+            # Then transpose back to T x B x C
+            encoded = self.layers(padded_embedded).transpose(1, 0)
             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)
-
-            # states are returned as T x B x C
             return encoded
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        better_to_old_names = {
+            "self_attn.in_proj_weight": "attention.input_projection.weight",
+            "self_attn.in_proj_bias": "attention.input_projection.bias",
+            "self_attn.out_proj.weight": "attention.output_projection.weight",
+            "self_attn.out_proj.bias": "attention.output_projection.bias",
+            "linear1.weight": "residual_mlp.mlp.0.weight",
+            "linear1.bias": "residual_mlp.mlp.0.bias",
+            "linear2.weight": "residual_mlp.mlp.3.weight",
+            "linear2.bias": "residual_mlp.mlp.3.bias",
+            "norm1.weight": "attention_layer_norm.weight",
+            "norm1.bias": "attention_layer_norm.bias",
+            "norm2.weight": "final_layer_norm.weight",
+            "norm2.bias": "final_layer_norm.bias",
+        }
+        for i in range(self.layers.num_layers):
+            for better, old in better_to_old_names.items():
+                better_name = prefix + "layers.layers.{}.".format(i) + better
+                old_name = prefix + "layers.{}.".format(i) + old
+                if old_name in state_dict:
+                    state_dict[better_name] = state_dict[old_name]
+                    state_dict.pop(old_name)
+                elif better_name in state_dict:
+                    # Do nothing
+                    pass
+                elif strict:
+                    missing_keys.append(better_name)
+
+        super(TransformerEncoder,
+              self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
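
For reference, a minimal standalone sketch (not part of either patch) of how the replacement stack from the second commit is wired: a torch.nn.TransformerEncoderLayer configured with batch_first=True, GELU activation, and norm_first, stacked by torch.nn.TransformerEncoder, fed B x T x C input and transposed afterwards to the T x B x C layout the old forward() returned. It assumes a PyTorch version where batch_first/norm_first exist (>=1.10); the hyperparameter values (768 dims, 12 heads, 12 layers, 0.1 dropout) are illustrative placeholders, not values taken from the diff.

# Standalone sketch mirroring the construction in __init__ after this change.
import torch

embedding_dim = 768
num_attention_heads = 12
num_encoder_layers = 12
normalize_before = False
ffn_dimension = None

# Same default the patch applies when ffn_dimension is not given.
ffn_dimension = ffn_dimension or 4 * embedding_dim

layer = torch.nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=num_attention_heads,
    dim_feedforward=ffn_dimension,
    dropout=0.1,
    activation="gelu",
    batch_first=True,  # the stack consumes B x T x C directly, no transpose up front
    norm_first=normalize_before,
)
encoder = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)

src = torch.rand(2, 16, embedding_dim)  # B x T x C
out = encoder(src)                      # B x T x C
out_t = out.transpose(1, 0)             # T x B x C, the layout forward() still returns
print(out_t.shape)                      # torch.Size([16, 2, 768])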