Add XLMR Base and Large pre-trained models and corresponding transformations #1406
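This PR adds pre-trained XLM-R Base and Large encoder bundles (XLMR_BASE_ENCODER, XLMR_LARGE_ENCODER) together with their text transforms and a set of batch helpers in torchtext.functional. A minimal usage sketch, assembled from the RobertaModelBundle docstring examples further down in this diff (downloading the pre-trained weights requires network access):

import torch
import torchtext
from torchtext.functional import to_tensor

xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model().eval()   # fetches the pre-trained encoder weights on first use
transform = xlmr_base.transform()      # SentencePiece tokenization + vocab lookup + BOS/EOS ids

batch = ["Hello world", "How are you!"]
model_input = to_tensor(transform(batch), padding_value=transform.pad_idx)
with torch.no_grad():
    output = model(model_input)        # (batch, seq_len, 768) for the base encoder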
Merged
Changes from all commits
Commits (19)
7ee3ee4  initial commit (parmeet)
9f4f058  minor edits (parmeet)
3f142ea  minor edits (parmeet)
6737229  minor edit (parmeet)
d34d706  Merge branch 'main' of github.com:pytorch/text into xlmr_model (parmeet)
b3a3665  fix examples (parmeet)
6e919e6  fix flake (parmeet)
8573895  minor fix (parmeet)
acd1b67  add transforms and tests (parmeet)
4c41ded  add test for xlmr large (parmeet)
147994b  fix flake (parmeet)
6287497  remove pretrained spm model (parmeet)
0127f7e  fix doc (parmeet)
8e7a467  minor fix (parmeet)
9a17579  reverting default path in download root (parmeet)
0915572  removing get_from_local_path function (parmeet)
bd79785  Update torchtext/models/roberta/transforms.py (parmeet)
96f0c27  add __all__ in functionals and transforms (parmeet)
2282d8d  Merge branch 'xlmr_model' of github.com:parmeet/text into xlmr_model (parmeet)
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,70 @@
import torchtext
import torch

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path


class TestModels(TorchtextTestCase):
    def test_xlmr_base_output(self):
        asset_name = "xlmr.base.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        model = xlmr_base.get_model()
        model = model.eval()
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_base_jit_output(self):
        asset_name = "xlmr.base.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        model = xlmr_base.get_model()
        model = model.eval()
        model_jit = torch.jit.script(model)
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model_jit(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_large_output(self):
        asset_name = "xlmr.large.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        model = xlmr_large.get_model()
        model = model.eval()
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_large_jit_output(self):
        asset_name = "xlmr.large.output.pt"
        asset_path = get_asset_path(asset_name)
        xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
        model = xlmr_large.get_model()
        model = model.eval()
        model_jit = torch.jit.script(model)
        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
        actual = model_jit(model_input)
        expected = torch.load(asset_path)
        torch.testing.assert_close(actual, expected)

    def test_xlmr_transform(self):
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        transform = xlmr_base.transform()
        test_text = "XLMR base Model Comparison"
        actual = transform([test_text])
        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
        torch.testing.assert_close(actual, expected)

    def test_xlmr_transform_jit(self):
        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        transform = xlmr_base.transform()
        transform_jit = torch.jit.script(transform)
        test_text = "XLMR base Model Comparison"
        actual = transform_jit([test_text])
        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
        torch.testing.assert_close(actual, expected)
@@ -0,0 +1,58 @@
import torch
from torchtext import functional
from .common.torchtext_test_case import TorchtextTestCase


class TestFunctional(TorchtextTestCase):
    def test_to_tensor(self):
        input = [[1, 2], [1, 2, 3]]
        padding_value = 0
        actual = functional.to_tensor(input, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def test_to_tensor_jit(self):
        input = [[1, 2], [1, 2, 3]]
        padding_value = 0
        to_tensor_jit = torch.jit.script(functional.to_tensor)
        actual = to_tensor_jit(input, padding_value=padding_value)
        expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
        torch.testing.assert_close(actual, expected)

    def test_truncate(self):
        input = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        actual = functional.truncate(input, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def test_truncate_jit(self):
        input = [[1, 2], [1, 2, 3]]
        max_seq_len = 2
        truncate_jit = torch.jit.script(functional.truncate)
        actual = truncate_jit(input, max_seq_len=max_seq_len)
        expected = [[1, 2], [1, 2]]
        self.assertEqual(actual, expected)

    def test_add_token(self):
        input = [[1, 2], [1, 2, 3]]
        token_id = 0
        actual = functional.add_token(input, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = functional.add_token(input, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)

    def test_add_token_jit(self):
        input = [[1, 2], [1, 2, 3]]
        token_id = 0
        add_token_jit = torch.jit.script(functional.add_token)
        actual = add_token_jit(input, token_id=token_id)
        expected = [[0, 1, 2], [0, 1, 2, 3]]
        self.assertEqual(actual, expected)

        actual = add_token_jit(input, token_id=token_id, begin=False)
        expected = [[1, 2, 0], [1, 2, 3, 0]]
        self.assertEqual(actual, expected)
@@ -0,0 +1,33 @@
import torch
from torchtext import transforms
from torchtext.vocab import vocab
from collections import OrderedDict

from .common.torchtext_test_case import TorchtextTestCase
from .common.assets import get_asset_path


class TestTransforms(TorchtextTestCase):
    def test_spmtokenizer_transform(self):
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        actual = transform(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_spmtokenizer_transform_jit(self):
        asset_name = "spm_example.model"
        asset_path = get_asset_path(asset_name)
        transform = transforms.SpmTokenizerTransform(asset_path)
        transform_jit = torch.jit.script(transform)
        actual = transform_jit(["Hello World!, how are you?"])
        expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
        self.assertEqual(actual, expected)

    def test_vocab_transform(self):
        vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
        transform = transforms.VocabTransform(vocab_obj)
        actual = transform([['a', 'b', 'c']])
        expected = [[0, 1, 2]]
        self.assertEqual(actual, expected)
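Together, SpmTokenizerTransform and VocabTransform cover the tokenize-then-numericalize steps that the XLM-R bundles rely on. A hedged sketch of chaining them by hand; the SentencePiece model path, the toy vocabulary, and the text_pipeline helper below are illustrative stand-ins, not the actual get_xlmr_transform implementation:

import torch
from collections import OrderedDict
from torchtext import transforms
from torchtext.vocab import vocab

# Illustrative assets: any SentencePiece model plus a matching vocabulary works.
spm_model_path = "spm_example.model"
vocab_obj = vocab(OrderedDict([('▁Hello', 1), ('▁World', 1), ('!', 1)]))

tokenize = transforms.SpmTokenizerTransform(spm_model_path)  # List[str] -> List[List[str]]
numericalize = transforms.VocabTransform(vocab_obj)          # List[List[str]] -> List[List[int]]


def text_pipeline(batch):
    # strings -> subword tokens -> token ids
    return numericalize(tokenize(batch))


token_ids = text_pipeline(["Hello World!"])  # e.g. [[0, 1, 2]] with the toy vocab above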
@@ -0,0 +1,45 @@
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional

__all__ = [
    'to_tensor',
    'truncate',
    'add_token',
]


def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor:
    # Convert a batch of token-id lists into a LongTensor, padding to the
    # longest sequence when a padding value is given.
    if padding_value is None:
        output = torch.tensor(input, dtype=torch.long)
        return output
    else:
        output = pad_sequence(
            [torch.tensor(ids, dtype=torch.long) for ids in input],
            batch_first=True,
            padding_value=float(padding_value)
        )
        return output


def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]:
    # Clip every sequence in the batch to at most max_seq_len ids.
    output: List[List[int]] = []

    for ids in input:
        output.append(ids[:max_seq_len])

    return output


def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]:
    # Prepend (begin=True) or append (begin=False) token_id to every sequence.
    output: List[List[int]] = []

    if begin:
        for ids in input:
            output.append([token_id] + ids)
    else:
        for ids in input:
            output.append(ids + [token_id])

    return output
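For reference, a short sketch of how these three helpers chain together into a typical pre-processing step; the token ids, special-token ids, and length limit below are made-up values for illustration:

from torchtext.functional import add_token, to_tensor, truncate

token_ids = [[5, 6, 7, 8, 9], [5, 6]]                       # hypothetical per-sentence ids
token_ids = truncate(token_ids, max_seq_len=4)              # [[5, 6, 7, 8], [5, 6]]
token_ids = add_token(token_ids, token_id=0)                # prepend BOS: [[0, 5, 6, 7, 8], [0, 5, 6]]
token_ids = add_token(token_ids, token_id=2, begin=False)   # append EOS: [[0, 5, 6, 7, 8, 2], [0, 5, 6, 2]]
batch = to_tensor(token_ids, padding_value=1)               # padded LongTensor of shape (2, 6)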
@@ -0,0 +1 @@
from .roberta import *  # noqa: F401, F403
@@ -0,0 +1,18 @@
from .model import (
    RobertaEncoderParams,
    RobertaClassificationHead,
)

from .bundler import (
    RobertaModelBundle,
    XLMR_BASE_ENCODER,
    XLMR_LARGE_ENCODER,
)

__all__ = [
    "RobertaEncoderParams",
    "RobertaClassificationHead",
    "RobertaModelBundle",
    "XLMR_BASE_ENCODER",
    "XLMR_LARGE_ENCODER",
]
@@ -0,0 +1,99 @@
import os
import logging
from dataclasses import dataclass
from functools import partial
from typing import Optional, Callable

from torch.hub import load_state_dict_from_url
from torch.nn import Module

from .model import (
    RobertaEncoderParams,
    RobertaModel,
    _get_model,
)

from .transforms import get_xlmr_transform

from torchtext import _TEXT_BUCKET

logger = logging.getLogger(__name__)


@dataclass
class RobertaModelBundle:
    """
    Example - Pretrained encoder
    >>> import torch, torchtext
    >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
    >>> model = xlmr_base.get_model()
    >>> transform = xlmr_base.transform()
    >>> model_input = torch.tensor(transform(["Hello World"]))
    >>> output = model(model_input)
    >>> output.shape
    torch.Size([1, 4, 768])
    >>> input_batch = ["Hello world", "How are you!"]
    >>> from torchtext.functional import to_tensor
    >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
    >>> output = model(model_input)
    >>> output.shape
    torch.Size([2, 6, 768])

    Example - Pretrained encoder attached to an un-initialized classification head
    >>> import torch, torchtext
    >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
    >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
    >>> classification_model = xlmr_large.get_model(head=classifier_head)
    >>> transform = xlmr_large.transform()
    >>> model_input = torch.tensor(transform(["Hello World"]))
    >>> output = classification_model(model_input)
    >>> output.shape
    torch.Size([1, 2])
    """
    _params: RobertaEncoderParams
    _path: Optional[str] = None
    _head: Optional[Module] = None
    transform: Optional[Callable] = None

    def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
        if head is not None:
            input_head = head
            if self._head is not None:
                logger.info("A custom head module was provided, discarding the default head module.")
        else:
            input_head = self._head

        model = _get_model(self._params, input_head)

        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
        # The pre-trained weights do not cover a user-supplied head, so loading
        # is non-strict whenever a head is attached.
        if input_head is not None:
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=True)
        return model

    @property
    def params(self) -> RobertaEncoderParams:
        return self._params


XLMR_BASE_ENCODER = RobertaModelBundle(
    _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
    _params=RobertaEncoderParams(vocab_size=250002),
    transform=partial(
        get_xlmr_transform,
        vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
        spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
    )
)

XLMR_LARGE_ENCODER = RobertaModelBundle(
    _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
    _params=RobertaEncoderParams(
        vocab_size=250002,
        embedding_dim=1024,
        ffn_dimension=4096,
        num_attention_heads=16,
        num_encoder_layers=24,
    ),
    transform=partial(
        get_xlmr_transform,
        vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
        spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
    )
)
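Building on the classification-head example in the docstring above, a hedged sketch of a single fine-tuning step; the labels, optimizer, and learning rate are placeholders and not part of this PR:

import torch
import torchtext
from torchtext.functional import to_tensor

xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
model = xlmr_large.get_model(head=head)   # pre-trained encoder, randomly initialized head
transform = xlmr_large.transform()

batch = ["Hello world", "How are you!"]
labels = torch.tensor([0, 1])             # placeholder labels
model_input = to_tensor(transform(batch), padding_value=transform.pad_idx)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
logits = model(model_input)               # (batch, num_classes)
loss = torch.nn.functional.cross_entropy(logits, labels)
loss.backward()
optimizer.step()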