
[BC-breaking] rename Roberta Bundle #1635

Merged · 3 commits · Mar 2, 2022
6 changes: 3 additions & 3 deletions docs/source/models.rst
@@ -7,10 +7,10 @@ torchtext.models
.. automodule:: torchtext.models
.. currentmodule:: torchtext.models

RobertaModelBundle
------------------
RobertaBundle
-------------

.. autoclass:: RobertaModelBundle
.. autoclass:: RobertaBundle
:members: transform

.. automethod:: get_model
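The rename only affects the documented class name; usage through the pretrained bundles is unchanged. A minimal sketch of the pattern the renamed docs describe (assuming the bundle's get_model/transform interface, as in the class docstring further below):

import torch
import torchtext

xlmr_base = torchtext.models.XLMR_BASE_ENCODER   # a RobertaBundle instance
model = xlmr_base.get_model()                     # pretrained encoder, no task head
transform = xlmr_base.transform()                 # matching text transform
model_input = torch.tensor(transform(["Hello World"]))
features = model(model_input)                     # encoder output, shape (1, seq_len, 768)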
16 changes: 8 additions & 8 deletions test/models/test_models.py
@@ -40,22 +40,22 @@ def test_self_attn_mask(self):

class TestModels(TorchtextTestCase):
def test_roberta_bundler_build_model(self):
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaModelBundle
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle

dummy_encoder_conf = RobertaEncoderConf(
vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
)

# case: user provide encoder checkpoint state dict
dummy_encoder = RobertaModel(dummy_encoder_conf)
model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
model = RobertaBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

# case: user provide classifier checkpoint state dict when head is given and override_head is False (by default)
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.build_model(
model = RobertaBundle.build_model(
encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict(),
@@ -64,7 +64,7 @@ def test_roberta_bundler_build_model(self):

# case: user provide classifier checkpoint state dict when head is given and override_head is set True
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.build_model(
model = RobertaBundle.build_model(
encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict(),
@@ -78,13 +78,13 @@ def test_roberta_bundler_build_model(self):
encoder_state_dict = {}
for k, v in dummy_classifier.encoder.state_dict().items():
encoder_state_dict["encoder." + k] = v
model = torchtext.models.RobertaModelBundle.build_model(
model = torchtext.models.RobertaBundle.build_model(
encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict
)
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

def test_roberta_bundler_train(self):
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaModelBundle
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle

dummy_encoder_conf = RobertaEncoderConf(
vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
@@ -103,7 +103,7 @@ def _train(model):
# does not freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.build_model(
model = RobertaBundle.build_model(
encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=False,
@@ -121,7 +121,7 @@ def _train(model):
# freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.build_model(
model = RobertaBundle.build_model(
encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=True,
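For context, a standalone sketch of what the renamed builder call in these tests exercises; the dummy config mirrors the one above and is assumed to run on a torchtext build that includes this rename:

from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaBundle

encoder_conf = RobertaEncoderConf(
    vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
)
reference = RobertaModel(encoder_conf)
# build_model constructs a fresh RobertaModel from the config and loads the given state dict
model = RobertaBundle.build_model(encoder_conf=encoder_conf, checkpoint=reference.state_dict())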
4 changes: 2 additions & 2 deletions torchtext/models/roberta/__init__.py
@@ -1,7 +1,7 @@
from .bundler import (
ROBERTA_BASE_ENCODER,
ROBERTA_LARGE_ENCODER,
RobertaModelBundle,
RobertaBundle,
XLMR_BASE_ENCODER,
XLMR_LARGE_ENCODER,
)
@@ -11,7 +11,7 @@
"RobertaEncoderConf",
"RobertaClassificationHead",
"RobertaModel",
"RobertaModelBundle",
"RobertaBundle",
"XLMR_BASE_ENCODER",
"XLMR_LARGE_ENCODER",
"ROBERTA_BASE_ENCODER",
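Because the public export changes, downstream imports break with this release. A user-side migration sketch (the alias is hypothetical, not something torchtext ships):

# before: from torchtext.models import RobertaModelBundle
from torchtext.models import RobertaBundle

RobertaModelBundle = RobertaBundle  # temporary local alias while call sites are updated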
26 changes: 13 additions & 13 deletions torchtext/models/roberta/bundler.py
@@ -22,8 +22,8 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict):


@dataclass
class RobertaModelBundle:
"""RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
class RobertaBundle:
"""RobertaBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)

Example - Pretrained base xlmr encoder
>>> import torch, torchtext
@@ -52,11 +52,11 @@ class RobertaModelBundle:
torch.Size([1, 2])

Example - User-specified configuration and checkpoint
>>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
>>> from torchtext.models import RobertaEncoderConf, RobertaBundle, RobertaClassificationHead
>>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
>>> encoder_conf = RobertaEncoderConf(vocab_size=250002)
>>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
>>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
>>> model = RobertaBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
"""

_encoder_conf: RobertaEncoderConf
@@ -99,7 +99,7 @@ def get_model(
else:
input_head = self._head

return RobertaModelBundle.build_model(
return RobertaBundle.build_model(
encoder_conf=self._encoder_conf,
head=input_head,
freeze_encoder=freeze_encoder,
@@ -160,7 +160,7 @@ def encoderConf(self) -> RobertaEncoderConf:
return self._encoder_conf


XLMR_BASE_ENCODER = RobertaModelBundle(
XLMR_BASE_ENCODER = RobertaBundle(
_path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
_encoder_conf=RobertaEncoderConf(vocab_size=250002),
transform=lambda: T.Sequential(
@@ -184,11 +184,11 @@ def encoderConf(self) -> RobertaEncoderConf:
[`License <https://github.com/pytorch/fairseq/blob/main/LICENSE>`__,
`Source <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`__]

Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
"""


XLMR_LARGE_ENCODER = RobertaModelBundle(
XLMR_LARGE_ENCODER = RobertaBundle(
_path=urljoin(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
_encoder_conf=RobertaEncoderConf(
vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
@@ -214,11 +214,11 @@ def encoderConf(self) -> RobertaEncoderConf:
[`License <https://github.com/pytorch/fairseq/blob/main/LICENSE>`__,
`Source <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`__]

Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
"""


ROBERTA_BASE_ENCODER = RobertaModelBundle(
ROBERTA_BASE_ENCODER = RobertaBundle(
_path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
_encoder_conf=RobertaEncoderConf(vocab_size=50265),
transform=lambda: T.Sequential(
@@ -250,11 +250,11 @@ def encoderConf(self) -> RobertaEncoderConf:
[`License <https://github.com/pytorch/fairseq/blob/main/LICENSE>`__,
`Source <https://github.com/pytorch/fairseq/tree/main/examples/roberta#pre-trained-models>`__]

Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
"""


ROBERTA_LARGE_ENCODER = RobertaModelBundle(
ROBERTA_LARGE_ENCODER = RobertaBundle(
_path=urljoin(_TEXT_BUCKET, "roberta.large.encoder.pt"),
_encoder_conf=RobertaEncoderConf(
vocab_size=50265,
@@ -292,5 +292,5 @@ def encoderConf(self) -> RobertaEncoderConf:
[`License <https://github.com/pytorch/fairseq/blob/main/LICENSE>`__,
`Source <https://github.com/pytorch/fairseq/tree/main/examples/roberta#pre-trained-models>`__]

Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
"""