diff --git a/docs/source/index.rst b/docs/source/index.rst
index 2072db4adf..4e812ee286 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -44,6 +44,7 @@ popular datasets for natural language.
    experimental_transforms
    experimental_vectors
    experimental_vocab
+   models_utils
    examples
 
 .. automodule:: torchtext
diff --git a/docs/source/models_utils.rst b/docs/source/models_utils.rst
new file mode 100644
index 0000000000..1bcb2f73e7
--- /dev/null
+++ b/docs/source/models_utils.rst
@@ -0,0 +1,13 @@
+.. role:: hidden
+    :class: hidden-section
+
+torchtext.experimental.models.utils
+===================================
+
+.. automodule:: torchtext.experimental.models.utils
+.. currentmodule:: torchtext.experimental.models.utils
+
+:hidden:`count_model_param`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: count_model_param
diff --git a/test/models/__init__.py b/test/models/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/models/test_utils.py b/test/models/test_utils.py
new file mode 100644
index 0000000000..162ec932c4
--- /dev/null
+++ b/test/models/test_utils.py
@@ -0,0 +1,9 @@
+import torch
+from torchtext.experimental.models.utils import count_model_param
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+class TestModelsUtils(TorchtextTestCase):
+    def test_count_model_parameters_func(self):
+        model = torch.nn.Embedding(100, 200)
+        self.assertEqual(count_model_param(model, unit=10**3), 20.0)
diff --git a/torchtext/experimental/__init__.py b/torchtext/experimental/__init__.py
index 196d0dd84b..18eba857a5 100644
--- a/torchtext/experimental/__init__.py
+++ b/torchtext/experimental/__init__.py
@@ -1,4 +1,5 @@
 from . import datasets
 from . import transforms
+from . import models
 
-__all__ = ['datasets', 'transforms']
+__all__ = ['datasets', 'transforms', 'models']
diff --git a/torchtext/experimental/models/__init__.py b/torchtext/experimental/models/__init__.py
new file mode 100644
index 0000000000..2065636d77
--- /dev/null
+++ b/torchtext/experimental/models/__init__.py
@@ -0,0 +1,3 @@
+from .utils import count_model_param
+
+__all__ = ["count_model_param"]
diff --git a/torchtext/experimental/models/utils.py b/torchtext/experimental/models/utils.py
new file mode 100644
index 0000000000..b24ddd0404
--- /dev/null
+++ b/torchtext/experimental/models/utils.py
@@ -0,0 +1,22 @@
+import torch
+
+
+def count_model_param(nn_model, unit=10**6):
+    r"""
+    Count the number of trainable parameters in a model.
+
+    Args:
+        nn_model: the model (torch.nn.Module)
+        unit: the unit of the returned value. Default: 10**6 (i.e. millions).
+
+    Examples:
+        >>> import torch
+        >>> import torchtext
+        >>> from torchtext.experimental.models.utils import count_model_param
+        >>> model = torch.nn.Embedding(100, 200)
+        >>> count_model_param(model, unit=10**3)
+        20.0
+    """
+    model_parameters = filter(lambda p: p.requires_grad, nn_model.parameters())
+    params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
+    return params.item() / unit
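
A quick usage sketch, not part of the patch above: it assumes the new torchtext.experimental.models.utils module is importable after this change lands, and the small Sequential toy model is an illustrative assumption chosen to show that count_model_param only sums parameters with requires_grad=True.

    import torch
    from torchtext.experimental.models.utils import count_model_param

    # Embedding(100, 200) contributes 20,000 parameters; Linear(200, 10) contributes 2,010.
    model = torch.nn.Sequential(
        torch.nn.Embedding(100, 200),
        torch.nn.Linear(200, 10),
    )
    print(count_model_param(model, unit=10**3))  # 22.01

    # Freezing the embedding removes its weights from the count,
    # since only parameters with requires_grad=True are summed.
    model[0].weight.requires_grad_(False)
    print(count_model_param(model, unit=10**3))  # 2.01

Dividing by unit reports the count in familiar K/M figures rather than a raw integer, matching the 20.0 expected by the new test.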