diff --git a/docs/source/models.rst b/docs/source/models.rst
index b5425c8da4..d34d9dd845 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -7,10 +7,10 @@ torchtext.models
 .. automodule:: torchtext.models
 .. currentmodule:: torchtext.models
 
-RobertaModelBundle
-------------------
+RobertaBundle
+-------------
 
-.. autoclass:: RobertaModelBundle
+.. autoclass:: RobertaBundle
     :members: transform
 
     .. automethod:: get_model
diff --git a/test/models/test_models.py b/test/models/test_models.py
index 8c2c61c195..2001e799ed 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -40,7 +40,7 @@ def test_self_attn_mask(self):
 
 class TestModels(TorchtextTestCase):
     def test_roberta_bundler_build_model(self):
-        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaModelBundle
+        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
 
         dummy_encoder_conf = RobertaEncoderConf(
             vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
@@ -48,14 +48,14 @@ def test_roberta_bundler_build_model(self):
 
         # case: user provide encoder checkpoint state dict
         dummy_encoder = RobertaModel(dummy_encoder_conf)
-        model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
+        model = RobertaBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
         self.assertEqual(model.state_dict(), dummy_encoder.state_dict())
 
         # case: user provide classifier checkpoint state dict when head is given and override_head is False (by default)
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaModelBundle.build_model(
+        model = RobertaBundle.build_model(
             encoder_conf=dummy_encoder_conf,
             head=another_dummy_classifier_head,
             checkpoint=dummy_classifier.state_dict(),
@@ -64,7 +64,7 @@ def test_roberta_bundler_build_model(self):
 
         # case: user provide classifier checkpoint state dict when head is given and override_head is set True
         another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
-        model = RobertaModelBundle.build_model(
+        model = RobertaBundle.build_model(
             encoder_conf=dummy_encoder_conf,
             head=another_dummy_classifier_head,
             checkpoint=dummy_classifier.state_dict(),
@@ -78,13 +78,13 @@ def test_roberta_bundler_build_model(self):
         encoder_state_dict = {}
         for k, v in dummy_classifier.encoder.state_dict().items():
             encoder_state_dict["encoder." + k] = v
-        model = torchtext.models.RobertaModelBundle.build_model(
+        model = torchtext.models.RobertaBundle.build_model(
             encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict
         )
         self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
 
     def test_roberta_bundler_train(self):
-        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaModelBundle
+        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
 
         dummy_encoder_conf = RobertaEncoderConf(
             vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
@@ -103,7 +103,7 @@ def _train(model):
         # does not freeze encoder
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaModelBundle.build_model(
+        model = RobertaBundle.build_model(
             encoder_conf=dummy_encoder_conf,
             head=dummy_classifier_head,
             freeze_encoder=False,
@@ -121,7 +121,7 @@ def _train(model):
         # freeze encoder
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaModelBundle.build_model(
+        model = RobertaBundle.build_model(
             encoder_conf=dummy_encoder_conf,
             head=dummy_classifier_head,
             freeze_encoder=True,
diff --git a/torchtext/models/roberta/__init__.py b/torchtext/models/roberta/__init__.py
index 1db7a6a29a..57265aeb3d 100644
--- a/torchtext/models/roberta/__init__.py
+++ b/torchtext/models/roberta/__init__.py
@@ -1,7 +1,7 @@
 from .bundler import (
     ROBERTA_BASE_ENCODER,
     ROBERTA_LARGE_ENCODER,
-    RobertaModelBundle,
+    RobertaBundle,
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
 )
@@ -11,7 +11,7 @@
     "RobertaEncoderConf",
     "RobertaClassificationHead",
     "RobertaModel",
-    "RobertaModelBundle",
+    "RobertaBundle",
     "XLMR_BASE_ENCODER",
     "XLMR_LARGE_ENCODER",
     "ROBERTA_BASE_ENCODER",
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 3805f1a12b..be4752e1e5 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -22,8 +22,8 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
 
 
 @dataclass
-class RobertaModelBundle:
-    """RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
+class RobertaBundle:
+    """RobertaBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
 
     Example - Pretrained base xlmr encoder
         >>> import torch, torchtext
@@ -52,11 +52,11 @@ class RobertaModelBundle:
         torch.Size([1, 2])
 
     Example - User-specified configuration and checkpoint
-        >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
+        >>> from torchtext.models import RobertaEncoderConf, RobertaBundle, RobertaClassificationHead
         >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
        >>> encoder_conf = RobertaEncoderConf(vocab_size=250002)
         >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
-        >>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
+        >>> model = RobertaBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
     """
 
     _encoder_conf: RobertaEncoderConf
@@ -99,7 +99,7 @@ def get_model(
         else:
             input_head = self._head
 
-        return RobertaModelBundle.build_model(
+        return RobertaBundle.build_model(
             encoder_conf=self._encoder_conf,
             head=input_head,
             freeze_encoder=freeze_encoder,
@@ -160,7 +160,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
-XLMR_BASE_ENCODER = RobertaModelBundle(
+XLMR_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=250002),
     transform=lambda: T.Sequential(
@@ -184,11 +184,11 @@ def encoderConf(self) -> RobertaEncoderConf:
 
     [`License `__, `Source `__]
 
-    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
+    Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
     """
 
 
-XLMR_LARGE_ENCODER = RobertaModelBundle(
+XLMR_LARGE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
@@ -214,11 +214,11 @@ def encoderConf(self) -> RobertaEncoderConf:
 
     [`License `__, `Source `__]
 
-    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
+    Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
     """
 
 
-ROBERTA_BASE_ENCODER = RobertaModelBundle(
+ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
     transform=lambda: T.Sequential(
@@ -250,11 +250,11 @@ def encoderConf(self) -> RobertaEncoderConf:
 
     [`License `__, `Source `__]
 
-    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
+    Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
     """
 
 
-ROBERTA_LARGE_ENCODER = RobertaModelBundle(
+ROBERTA_LARGE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.large.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(
         vocab_size=50265,
@@ -292,5 +292,5 @@ def encoderConf(self) -> RobertaEncoderConf:
 
     [`License `__, `Source `__]
 
-    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
+    Please refer to :func:`torchtext.models.RobertaBundle` for the usage.
     """