Skip to content

Commit aea6ad6

Browse files
authored
add unit tests to for testing model training (#1449)
1 parent c7110b5 commit aea6ad6

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

test/models/test_models.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torchtext
22
import torch
3+
from torch.nn import functional as torch_F
4+
import copy
35
from ..common.torchtext_test_case import TorchtextTestCase
46
from ..common.assets import get_asset_path
57

@@ -126,3 +128,49 @@ def test_roberta_bundler_from_config(self):
126128
encoder_state_dict['encoder.' + k] = v
127129
model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
128130
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
131+
132+
def test_roberta_bundler_train(self):
133+
from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
134+
dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
135+
from torch.optim import SGD
136+
137+
def _train(model):
138+
optim = SGD(model.parameters(), lr=1)
139+
model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
140+
target = torch.tensor([0])
141+
logits = model(model_input)
142+
loss = torch_F.cross_entropy(logits, target)
143+
loss.backward()
144+
optim.step()
145+
146+
# does not freeze encoder
147+
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
148+
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
149+
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
150+
head=dummy_classifier_head,
151+
freeze_encoder=False,
152+
checkpoint=dummy_classifier.state_dict())
153+
154+
encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
155+
head_current_state_dict = copy.deepcopy(model.head.state_dict())
156+
157+
_train(model)
158+
159+
self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
160+
self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
161+
162+
# freeze encoder
163+
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
164+
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
165+
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
166+
head=dummy_classifier_head,
167+
freeze_encoder=True,
168+
checkpoint=dummy_classifier.state_dict())
169+
170+
encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
171+
head_current_state_dict = copy.deepcopy(model.head.state_dict())
172+
173+
_train(model)
174+
175+
self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
176+
self.assertNotEqual(model.head.state_dict(), head_current_state_dict)

0 commit comments

Comments
 (0)