
Add pad transform, string to int transform #1683


Merged: 9 commits, merged on Apr 21, 2022
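
As a quick orientation (not part of the PR itself), here is a minimal sketch of how the two transforms added in this PR might be used together; the constructor and forward signatures follow the code in the diff below.

import torch
from torchtext import transforms

# Convert string tokens to integers, then pad the resulting tensor to a fixed length.
str_to_int = transforms.StrToIntTransform()
pad = transforms.PadTransform(max_length=7, pad_value=0)

token_ids = str_to_int(["1", "2", "3", "4", "5"])  # [1, 2, 3, 4, 5]
padded = pad(torch.tensor(token_ids))              # tensor([1, 2, 3, 4, 5, 0, 0])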
5 changes: 5 additions & 0 deletions docs/source/functional.rst
@@ -23,3 +23,8 @@ add_token
---------

.. autofunction:: add_token

str_to_int
----------

.. autofunction:: str_to_int
14 changes: 14 additions & 0 deletions docs/source/transforms.rst
@@ -71,3 +71,17 @@ Sequential
.. autoclass:: Sequential

.. automethod:: forward

PadTransform
------------

.. autoclass:: PadTransform

.. automethod:: forward

StrToIntTransform
-----------------

.. autoclass:: StrToIntTransform

.. automethod:: forward
80 changes: 80 additions & 0 deletions test/test_transforms.py
@@ -205,6 +205,86 @@ def test_add_token(self):
def test_add_token_jit(self):
self._add_token(test_scripting=True)

def _pad_transform(self, test_scripting):
Contributor review comment: Can you add a brief description of what this test does? In the form of "(under some condition), an API produces a result that satisfies this property".

"""
Test padding transform on 1D and 2D tensors.
When max_length < tensor length at dim -1, this should be a no-op.
Otherwise the tensor should be padded to max_length in dim -1.
"""

input_1d_tensor = torch.ones(5)
input_2d_tensor = torch.ones((8, 5))
pad_long = transforms.PadTransform(max_length=7, pad_value=0)
if test_scripting:
pad_long = torch.jit.script(pad_long)
padded_1d_tensor_actual = pad_long(input_1d_tensor)
padded_1d_tensor_expected = torch.cat([torch.ones(5), torch.zeros(2)])
torch.testing.assert_close(
padded_1d_tensor_actual,
padded_1d_tensor_expected,
msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
)

padded_2d_tensor_actual = pad_long(input_2d_tensor)
padded_2d_tensor_expected = torch.cat([torch.ones(8, 5), torch.zeros(8, 2)], axis=-1)
torch.testing.assert_close(
padded_2d_tensor_actual,
padded_2d_tensor_expected,
msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
)

pad_short = transforms.PadTransform(max_length=3, pad_value=0)
if test_scripting:
pad_short = torch.jit.script(pad_short)
padded_1d_tensor_actual = pad_short(input_1d_tensor)
padded_1d_tensor_expected = input_1d_tensor
torch.testing.assert_close(
padded_1d_tensor_actual,
padded_1d_tensor_expected,
msg=f"actual: {padded_1d_tensor_actual}, expected: {padded_1d_tensor_expected}",
)

padded_2d_tensor_actual = pad_short(input_2d_tensor)
padded_2d_tensor_expected = input_2d_tensor
torch.testing.assert_close(
padded_2d_tensor_actual,
padded_2d_tensor_expected,
msg=f"actual: {padded_2d_tensor_actual}, expected: {padded_2d_tensor_expected}",
)

def test_pad_transform(self):
self._pad_transform(test_scripting=False)

def test_pad_transform_jit(self):
self._pad_transform(test_scripting=True)

def _str_to_int_transform(self, test_scripting):
"""
Test StrToIntTransform on list and list of lists.
The result should be the same shape as the input but with all strings converted to ints.
"""
input_1d_string_list = ["1", "2", "3", "4", "5"]
input_2d_string_list = [["1", "2", "3"], ["4", "5", "6"]]

str_to_int = transforms.StrToIntTransform()
if test_scripting:
str_to_int = torch.jit.script(str_to_int)

expected_1d_int_list = [1, 2, 3, 4, 5]
actual_1d_int_list = str_to_int(input_1d_string_list)
self.assertListEqual(expected_1d_int_list, actual_1d_int_list)

expected_2d_int_list = [[1, 2, 3], [4, 5, 6]]
actual_2d_int_list = str_to_int(input_2d_string_list)
for i in range(len(expected_2d_int_list)):
self.assertListEqual(expected_2d_int_list[i], actual_2d_int_list[i])

def test_str_to_int_transform(self):
self._str_to_int_transform(test_scripting=False)

def test_str_to_int_transform_jit(self):
self._str_to_int_transform(test_scripting=True)


class TestSequential(TorchtextTestCase):
def _sequential(self, test_scripting):
26 changes: 26 additions & 0 deletions torchtext/functional.py
@@ -8,6 +8,7 @@
"to_tensor",
"truncate",
"add_token",
"str_to_int",
]


@@ -110,3 +111,28 @@ def add_token(input: Any, token_id: Any, begin: bool = True) -> Any:
return output
else:
raise TypeError("Input type not supported")


def str_to_int(input: Any) -> Any:
"""Convert string tokens to integers (either single sequence or batch).

:param input: Input sequence or batch
:type input: Union[List[str], List[List[str]]]
:return: Sequence or batch of string tokens converted to integers
:rtype: Union[List[int], List[List[int]]]
"""
if torch.jit.isinstance(input, List[str]):
output: List[int] = []
for element in input:
output.append(int(element))
return output
if torch.jit.isinstance(input, List[List[str]]):
output: List[List[int]] = []
for ids in input:
current: List[int] = []
for element in ids:
current.append(int(element))
output.append(current)
return output
else:
raise TypeError("Input type not supported")
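
As an illustration, here is a minimal call to the functional form defined above, assuming torchtext.functional is imported under the alias F as in the rest of the library:

from torchtext import functional as F

F.str_to_int(["1", "2", "3"])            # [1, 2, 3]
F.str_to_int([["1", "2"], ["3", "4"]])   # [[1, 2], [3, 4]]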
46 changes: 46 additions & 0 deletions torchtext/transforms.py
@@ -21,6 +21,8 @@
"LabelToIndex",
"Truncate",
"AddToken",
"PadTransform",
"StrToIntTransform",
"GPT2BPETokenizer",
"Sequential",
]
@@ -221,6 +223,50 @@ def forward(self, input: Any) -> Any:
return F.add_token(input, self.token, self.begin)


class PadTransform(Module):
"""Pad tensor to a fixed length with given padding value.

:param max_length: Maximum length to pad to
:type max_length: int
:param pad_value: Value to pad the tensor with
:type pad_value: int
"""

def __init__(self, max_length: int, pad_value: int):
super().__init__()
self.max_length = max_length
self.pad_value = pad_value

def forward(self, x: Tensor) -> Tensor:
"""
:param x: The tensor to pad
:type x: Tensor
:return: Tensor padded up to max_length with pad_value
:rtype: Tensor
"""
max_encoded_length = x.size(-1)
if max_encoded_length < self.max_length:
pad_amount = self.max_length - max_encoded_length
x = torch.nn.functional.pad(x, (0, pad_amount), value=self.pad_value)
return x
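
For illustration, a small sketch of the padding behavior this class implements; the shapes mirror the unit tests added in this PR.

import torch
from torchtext import transforms

pad = transforms.PadTransform(max_length=7, pad_value=0)
pad(torch.ones(5))     # tensor([1., 1., 1., 1., 1., 0., 0.])
pad(torch.ones(8, 5))  # padded along the last dim to shape (8, 7)
pad(torch.ones(9))     # already at least max_length, so returned unchanged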


class StrToIntTransform(Module):
Contributor review comment: Same comment for doc-strings :)

"""Convert string tokens to integers (either single sequence or batch)."""

def __init__(self):
super().__init__()

def forward(self, input: Any) -> Any:
"""
:param input: sequence or batch of string tokens to convert
:type input: Union[List[str], List[List[str]]]
:return: sequence or batch converted into corresponding token ids
:rtype: Union[List[int], List[List[int]]]
"""
return F.str_to_int(input)
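
A brief sketch of using and scripting the module form, mirroring what the tests in this PR exercise:

import torch
from torchtext import transforms

str_to_int = transforms.StrToIntTransform()
scripted = torch.jit.script(str_to_int)   # TorchScript-compatible, as the tests verify
scripted([["10", "20"], ["30", "40"]])    # [[10, 20], [30, 40]]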


class GPT2BPETokenizer(Module):
__jit_unused_properties__ = ["is_jitable"]
"""