
Add XLMR Base and Large pre-trained models and corresponding transformations #1406


Merged · 19 commits · Oct 19, 2021
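For reviewers, a condensed usage sketch of the API this PR adds, assembled from the RobertaModelBundle docstring and the new torchtext.functional helpers shown in the diff below (transform.pad_idx and the expected output shape are taken from that docstring):

import torch
import torchtext
from torchtext.functional import to_tensor

xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model()          # downloads the pre-trained encoder weights
model = model.eval()
transform = xlmr_base.transform()      # sentencepiece tokenization + vocab lookup + BOS/EOS

batch = ["Hello world", "How are you!"]
model_input = to_tensor(transform(batch), padding_value=transform.pad_idx)
output = model(model_input)            # torch.Size([2, 6, 768]) per the docstring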
Binary file added test/asset/xlmr.base.output.pt
Binary file added test/asset/xlmr.large.output.pt
70 changes: 70 additions & 0 deletions test/models/test_models.py
@@ -0,0 +1,70 @@
import torchtext
import torch

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path


class TestModels(TorchtextTestCase):
def test_xlmr_base_output(self):
asset_name = "xlmr.base.output.pt"
asset_path = get_asset_path(asset_name)
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model()
model = model.eval()
model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
actual = model(model_input)
expected = torch.load(asset_path)
torch.testing.assert_close(actual, expected)

def test_xlmr_base_jit_output(self):
asset_name = "xlmr.base.output.pt"
asset_path = get_asset_path(asset_name)
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model()
model = model.eval()
model_jit = torch.jit.script(model)
model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
actual = model_jit(model_input)
expected = torch.load(asset_path)
torch.testing.assert_close(actual, expected)

def test_xlmr_large_output(self):
asset_name = "xlmr.large.output.pt"
asset_path = get_asset_path(asset_name)
xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
model = xlmr_large.get_model()
model = model.eval()
model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
actual = model(model_input)
expected = torch.load(asset_path)
torch.testing.assert_close(actual, expected)

def test_xlmr_large_jit_output(self):
asset_name = "xlmr.large.output.pt"
asset_path = get_asset_path(asset_name)
xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
model = xlmr_large.get_model()
model = model.eval()
model_jit = torch.jit.script(model)
model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
actual = model_jit(model_input)
expected = torch.load(asset_path)
torch.testing.assert_close(actual, expected)

def test_xlmr_transform(self):
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
transform = xlmr_base.transform()
test_text = "XLMR base Model Comparison"
actual = transform([test_text])
expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
torch.testing.assert_close(actual, expected)

def test_xlmr_transform_jit(self):
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
transform = xlmr_base.transform()
transform_jit = torch.jit.script(transform)
test_text = "XLMR base Model Comparison"
actual = transform_jit([test_text])
expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
torch.testing.assert_close(actual, expected)
58 changes: 58 additions & 0 deletions test/test_functional.py
@@ -0,0 +1,58 @@
import torch
from torchtext import functional
from .common.torchtext_test_case import TorchtextTestCase


class TestFunctional(TorchtextTestCase):
def test_to_tensor(self):
input = [[1, 2], [1, 2, 3]]
padding_value = 0
actual = functional.to_tensor(input, padding_value=padding_value)
expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
torch.testing.assert_close(actual, expected)

def test_to_tensor_jit(self):
input = [[1, 2], [1, 2, 3]]
padding_value = 0
to_tensor_jit = torch.jit.script(functional.to_tensor)
actual = to_tensor_jit(input, padding_value=padding_value)
expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
torch.testing.assert_close(actual, expected)

def test_truncate(self):
input = [[1, 2], [1, 2, 3]]
max_seq_len = 2
actual = functional.truncate(input, max_seq_len=max_seq_len)
expected = [[1, 2], [1, 2]]
self.assertEqual(actual, expected)

def test_truncate_jit(self):
input = [[1, 2], [1, 2, 3]]
max_seq_len = 2
truncate_jit = torch.jit.script(functional.truncate)
actual = truncate_jit(input, max_seq_len=max_seq_len)
expected = [[1, 2], [1, 2]]
self.assertEqual(actual, expected)

def test_add_token(self):
input = [[1, 2], [1, 2, 3]]
token_id = 0
actual = functional.add_token(input, token_id=token_id)
expected = [[0, 1, 2], [0, 1, 2, 3]]
self.assertEqual(actual, expected)

actual = functional.add_token(input, token_id=token_id, begin=False)
expected = [[1, 2, 0], [1, 2, 3, 0]]
self.assertEqual(actual, expected)

def test_add_token_jit(self):
input = [[1, 2], [1, 2, 3]]
token_id = 0
add_token_jit = torch.jit.script(functional.add_token)
actual = add_token_jit(input, token_id=token_id)
expected = [[0, 1, 2], [0, 1, 2, 3]]
self.assertEqual(actual, expected)

actual = add_token_jit(input, token_id=token_id, begin=False)
expected = [[1, 2, 0], [1, 2, 3, 0]]
self.assertEqual(actual, expected)
33 changes: 33 additions & 0 deletions test/test_transforms.py
@@ -0,0 +1,33 @@
import torch
from torchtext import transforms
from torchtext.vocab import vocab
from collections import OrderedDict

from .common.torchtext_test_case import TorchtextTestCase
from .common.assets import get_asset_path


class TestTransforms(TorchtextTestCase):
def test_spmtokenizer_transform(self):
asset_name = "spm_example.model"
asset_path = get_asset_path(asset_name)
transform = transforms.SpmTokenizerTransform(asset_path)
actual = transform(["Hello World!, how are you?"])
expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
self.assertEqual(actual, expected)

def test_spmtokenizer_transform_jit(self):
asset_name = "spm_example.model"
asset_path = get_asset_path(asset_name)
transform = transforms.SpmTokenizerTransform(asset_path)
transform_jit = torch.jit.script(transform)
actual = transform_jit(["Hello World!, how are you?"])
expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
self.assertEqual(actual, expected)

def test_vocab_transform(self):
vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
transform = transforms.VocabTransform(vocab_obj)
actual = transform([['a', 'b', 'c']])
expected = [[0, 1, 2]]
self.assertEqual(actual, expected)
10 changes: 10 additions & 0 deletions torchtext/__init__.py
@@ -1,8 +1,15 @@
import os
_TEXT_BUCKET = 'https://download.pytorch.org/models/text'
_CACHE_DIR = os.path.expanduser('~/.torchtext/cache')

from . import data
from . import nn
from . import datasets
from . import utils
from . import vocab
from . import transforms
from . import functional
from . import models
from . import experimental
from . import legacy
from ._extension import _init_extension
@@ -18,6 +25,9 @@
'datasets',
'utils',
'vocab',
'transforms',
'functional',
'models',
'experimental',
'legacy']

6 changes: 4 additions & 2 deletions torchtext/data/datasets_utils.py
@@ -15,6 +15,8 @@
import defusedxml.ElementTree as ET
except ImportError:
import xml.etree.ElementTree as ET

from torchtext import _CACHE_DIR
"""
These functions and classes are meant solely for use in torchtext.datasets and not
for public consumption yet.
@@ -213,7 +215,7 @@ def _wrap_split_argument_with_fn(fn, splits):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

@functools.wraps(fn)
def new_fn(root=os.path.expanduser('~/.torchtext/cache'), split=splits, **kwargs):
def new_fn(root=_CACHE_DIR, split=splits, **kwargs):
result = []
for item in _check_default_set(split, splits, fn.__name__):
result.append(fn(root, item, **kwargs))
@@ -250,7 +252,7 @@ def decorator(func):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

@functools.wraps(func)
def wrapper(root=os.path.expanduser('~/.torchtext/cache'), *args, **kwargs):
def wrapper(root=_CACHE_DIR, *args, **kwargs):
new_root = os.path.join(root, dataset_name)
if not os.path.exists(new_root):
os.makedirs(new_root)
45 changes: 45 additions & 0 deletions torchtext/functional.py
@@ -0,0 +1,45 @@
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional

__all__ = [
'to_tensor',
'truncate',
'add_token',
]


def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor:
if padding_value is None:
output = torch.tensor(input, dtype=torch.long)
return output
else:
output = pad_sequence(
[torch.tensor(ids, dtype=torch.long) for ids in input],
batch_first=True,
padding_value=float(padding_value)
)
return output


def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]:
output: List[List[int]] = []

for ids in input:
output.append(ids[:max_seq_len])

return output


def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]:
output: List[List[int]] = []

if begin:
for ids in input:
output.append([token_id] + ids)
else:
for ids in input:
output.append(ids + [token_id])

return output
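
Taken together, these helpers cover the post-tokenization steps of the pipeline. A quick illustrative composition (token ids are made up; BOS=0, pad=1, EOS=2 follow the XLM-R convention visible in the tests above):

from torchtext.functional import add_token, to_tensor, truncate

token_ids = [[9, 8, 7, 6, 5], [3, 4]]
token_ids = truncate(token_ids, max_seq_len=4)              # [[9, 8, 7, 6], [3, 4]]
token_ids = add_token(token_ids, token_id=0)                # prepend BOS: [[0, 9, 8, 7, 6], [0, 3, 4]]
token_ids = add_token(token_ids, token_id=2, begin=False)   # append EOS
batch = to_tensor(token_ids, padding_value=1)               # padded LongTensor of shape (2, 6)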
1 change: 1 addition & 0 deletions torchtext/models/__init__.py
@@ -0,0 +1 @@
from .roberta import * # noqa: F401, F403
18 changes: 18 additions & 0 deletions torchtext/models/roberta/__init__.py
@@ -0,0 +1,18 @@
from .model import (
RobertaEncoderParams,
RobertaClassificationHead,
)

from .bundler import (
RobertaModelBundle,
XLMR_BASE_ENCODER,
XLMR_LARGE_ENCODER,
)

__all__ = [
"RobertaEncoderParams",
"RobertaClassificationHead",
"RobertaModelBundle",
"XLMR_BASE_ENCODER",
"XLMR_LARGE_ENCODER",
]
99 changes: 99 additions & 0 deletions torchtext/models/roberta/bundler.py
@@ -0,0 +1,99 @@

import logging
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional

from torch.hub import load_state_dict_from_url
from torch.nn import Module

from torchtext import _TEXT_BUCKET

from .model import (
    RobertaEncoderParams,
    RobertaModel,
    _get_model,
)
from .transforms import get_xlmr_transform

logger = logging.getLogger(__name__)


@dataclass
class RobertaModelBundle:
"""
Example - Pretrained encoder
>>> import torch, torchtext
>>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
>>> model = xlmr_base.get_model()
>>> transform = xlmr_base.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
>>> output = model(model_input)
>>> output.shape
torch.Size([1, 4, 768])
>>> input_batch = ["Hello world", "How are you!"]
>>> from torchtext.functional import to_tensor
>>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
>>> output = model(model_input)
>>> output.shape
torch.Size([2, 6, 768])

Example - Pretrained encoder attached to an uninitialized classification head
>>> import torch, torchtext
>>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
>>> classification_model = xlmr_large.get_model(head=classifier_head)
>>> transform = xlmr_large.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
>>> output = classification_model(model_input)
>>> output.shape
torch.Size([1, 2])
"""
_params: RobertaEncoderParams
_path: Optional[str] = None
_head: Optional[Module] = None
transform: Optional[Callable] = None

def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:

if head is not None:
input_head = head
if self._head is not None:
logger.info("A custom head module was provided, discarding the default head module.")
else:
input_head = self._head

model = _get_model(self._params, input_head)

dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
if input_head is not None:
model.load_state_dict(state_dict, strict=False)
else:
model.load_state_dict(state_dict, strict=True)
return model

@property
def params(self) -> RobertaEncoderParams:
return self._params


XLMR_BASE_ENCODER = RobertaModelBundle(
_path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
_params=RobertaEncoderParams(vocab_size=250002),
transform=partial(get_xlmr_transform,
vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
)
)

XLMR_LARGE_ENCODER = RobertaModelBundle(
_path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
_params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
transform=partial(get_xlmr_transform,
vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
)
)
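
Beyond the docstring examples, get_model() forwards dl_kwargs to torch.hub.load_state_dict_from_url, so the usual download controls apply. A sketch of attaching a classification head with a custom download directory (the path here is purely illustrative):

import torchtext

xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
head = torchtext.models.RobertaClassificationHead(
    num_classes=2, input_dim=xlmr_large.params.embedding_dim
)
# With a head attached, the encoder checkpoint is loaded with strict=False,
# since the head parameters are not part of the pre-trained state dict.
classifier = xlmr_large.get_model(
    head=head,
    dl_kwargs={"model_dir": "/tmp/torchtext_weights", "progress": False},
)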