
[fbsync] Enable model testing in FBCode #1720

Merged: 1 commit, May 12, 2022
test/integration_tests/test_models.py: 80 changes (55 additions, 25 deletions)
@@ -1,39 +1,25 @@
 import torch
-from torchtext.models import ROBERTA_BASE_ENCODER, ROBERTA_LARGE_ENCODER, XLMR_BASE_ENCODER, XLMR_LARGE_ENCODER
+from parameterized import parameterized
+from torchtext.models import (
+    ROBERTA_BASE_ENCODER,
+    ROBERTA_LARGE_ENCODER,
+    XLMR_BASE_ENCODER,
+    XLMR_LARGE_ENCODER,
+)
 
 from ..common.assets import get_asset_path
-from ..common.parameterized_utils import nested_params
 from ..common.torchtext_test_case import TorchtextTestCase
 
 
-class TestModels(TorchtextTestCase):
-    @nested_params(
-        [
-            ("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
-            ("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
-            (
-                "roberta.base.output.pt",
-                "Roberta base Model Comparison",
-                ROBERTA_BASE_ENCODER,
-            ),
-            (
-                "roberta.large.output.pt",
-                "Roberta base Model Comparison",
-                ROBERTA_LARGE_ENCODER,
-            ),
-        ],
-        [True, False],
-    )
-    def test_model(self, model_args, is_jit):
+class TestRobertaEncoders(TorchtextTestCase):
+    def _roberta_encoders(self, is_jit, encoder, expected_asset_name, test_text):
         """Verify pre-trained XLM-R and Roberta models in torchtext produce
         the same output as the reference implementation within fairseq
         """
-        expected_asset_name, test_text, model_bundler = model_args
-
         expected_asset_path = get_asset_path(expected_asset_name)
 
-        transform = model_bundler.transform()
-        model = model_bundler.get_model()
+        transform = encoder.transform()
+        model = encoder.get_model()
         model = model.eval()
 
         if is_jit:
@@ -44,3 +30,47 @@ def test_model(self, model_args, is_jit):
         actual = model(model_input)
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
+
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_xlmr_base_model(self, name, is_jit):
+        expected_asset_name = "xlmr.base.output.pt"
+        test_text = "XLMR base Model Comparison"
+        self._roberta_encoders(
+            is_jit=is_jit,
+            encoder=XLMR_BASE_ENCODER,
+            expected_asset_name=expected_asset_name,
+            test_text=test_text,
+        )
+
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_xlmr_large_model(self, name, is_jit):
+        expected_asset_name = "xlmr.large.output.pt"
+        test_text = "XLMR base Model Comparison"
+        self._roberta_encoders(
+            is_jit=is_jit,
+            encoder=XLMR_LARGE_ENCODER,
+            expected_asset_name=expected_asset_name,
+            test_text=test_text,
+        )
+
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_roberta_base_model(self, name, is_jit):
+        expected_asset_name = "roberta.base.output.pt"
+        test_text = "Roberta base Model Comparison"
+        self._roberta_encoders(
+            is_jit=is_jit,
+            encoder=ROBERTA_BASE_ENCODER,
+            expected_asset_name=expected_asset_name,
+            test_text=test_text,
+        )
+
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_roberta_large_model(self, name, is_jit):
+        expected_asset_name = "roberta.large.output.pt"
+        test_text = "Roberta base Model Comparison"
+        self._roberta_encoders(
+            is_jit=is_jit,
+            encoder=ROBERTA_LARGE_ENCODER,
+            expected_asset_name=expected_asset_name,
+            test_text=test_text,
+        )
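
For context on the new test layout: replacing the `nested_params` decorator with `parameterized.expand` gives every encoder/JIT combination its own named test method, which presumably makes individual model tests easier to select and report in an internal runner such as FBCode's. Below is a minimal, hypothetical sketch (not part of this PR) of how `parameterized.expand` generates those method names; `ExampleTest` and `test_example` are illustrative names only.

import unittest

from parameterized import parameterized


class ExampleTest(unittest.TestCase):
    # Each tuple becomes a separate test case; the leading string is passed
    # in as the `name` argument and is also appended to the generated method
    # name, producing test_example_0_jit and test_example_1_not_jit.
    @parameterized.expand([("jit", True), ("not_jit", False)])
    def test_example(self, name, is_jit):
        self.assertIsInstance(is_jit, bool)


if __name__ == "__main__":
    unittest.main()

With per-model methods like the ones added in this diff, a single variant can then be run in isolation, for example with `pytest test/integration_tests/test_models.py -k test_xlmr_base_model` (assuming a standard pytest setup for the repository's test suite).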