@@ -1,5 +1,7 @@
 import torchtext
 import torch
+from torch.nn import functional as torch_F
+import copy
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path
 
@@ -126,3 +128,49 @@ def test_roberta_bundler_from_config(self):
             encoder_state_dict['encoder.' + k] = v
         model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
         self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
+
+    def test_roberta_bundler_train(self):
+        from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
+        dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
+        from torch.optim import SGD
+
+        def _train(model):
+            optim = SGD(model.parameters(), lr=1)
+            model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
+            target = torch.tensor([0])
+            logits = model(model_input)
+            loss = torch_F.cross_entropy(logits, target)
+            loss.backward()
+            optim.step()
+
+        # case: encoder not frozen, so one training step should update both encoder and head
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=dummy_classifier_head,
+                                               freeze_encoder=False,
+                                               checkpoint=dummy_classifier.state_dict())
+
+        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
+        head_current_state_dict = copy.deepcopy(model.head.state_dict())
+
+        _train(model)
+
+        self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
+        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
+
+        # case: encoder frozen, so one training step should update the head but leave the encoder unchanged
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=dummy_classifier_head,
+                                               freeze_encoder=True,
+                                               checkpoint=dummy_classifier.state_dict())
+
+        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
+        head_current_state_dict = copy.deepcopy(model.head.state_dict())
+
+        _train(model)
+
+        self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
+        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
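
Note: the frozen-encoder case above verifies freezing only indirectly, by asserting that the encoder's weights are unchanged after an optimizer step. A minimal complementary check, assuming `freeze_encoder=True` is implemented by clearing `requires_grad` on the encoder's parameters (that mechanism is not shown in this diff, so treat it as an assumption), would assert the flag directly:

    # Hedged sketch, not part of the commit: assumes freeze_encoder=True
    # sets requires_grad=False on every encoder parameter.
    from torchtext.models import (RobertaEncoderConf, RobertaClassificationHead,
                                  RobertaModelBundle)

    conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64,
                              num_attention_heads=2, num_encoder_layers=2)
    head = RobertaClassificationHead(num_classes=2, input_dim=16)
    model = RobertaModelBundle.from_config(encoder_conf=conf, head=head,
                                           freeze_encoder=True)

    # Frozen encoder parameters should be excluded from autograd...
    assert all(not p.requires_grad for p in model.encoder.parameters())
    # ...while the classification head stays trainable.
    assert all(p.requires_grad for p in model.head.parameters())

Such a check would fail fast at construction time rather than after a training step, but the state-dict comparison in the test above also catches freezing schemes that do not touch `requires_grad` (e.g. detaching activations), so the two approaches are complementary.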