Skip to content

fixed ParseTextField postprocessing, added genre field to MultiNLI #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 37 additions & 30 deletions test/nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,65 @@
from torchtext import datasets

# NOTE(review): `data` is used below but imported on the (hidden) first line of
# this file — presumably `from torchtext import data`; confirm against the repo.

# Testing SNLI
print("Run test on SNLI...")
TEXT = datasets.nli.ParsedTextField()
LABEL = data.LabelField()
TREE = datasets.nli.ShiftReduceField()

# Pass the shift-reduce TREE field so transition sequences are loaded too.
train, val, test = datasets.SNLI.splits(TEXT, LABEL, TREE)

print("Fields:", train.fields)
print("Number of examples:\n", len(train))
print("First Example instance:\n", vars(train[0]))

TEXT.build_vocab(train)
LABEL.build_vocab(train)

train_iter, val_iter, test_iter = data.Iterator.splits((train, val, test), batch_size=3)

batch = next(iter(train_iter))
print("Numericalize premises:\n", batch.premise)
print("Numericalize hypotheses:\n", batch.hypothesis)
print("Entailment labels:\n", batch.label)

print("Test iters function")
train_iter, val_iter, test_iter = datasets.SNLI.iters(batch_size=4, trees=True)

batch = next(iter(train_iter))
print("Numericalize premises:\n", batch.premise)
print("Numericalize hypotheses:\n", batch.hypothesis)
print("Entailment labels:\n", batch.label)


# Testing MultiNLI
print("Run test on MultiNLI...")
TEXT = datasets.nli.ParsedTextField()
LABEL = data.LabelField()
GENRE = data.LabelField()
TREE = datasets.nli.ShiftReduceField()

# MultiNLI additionally exposes the per-example genre annotation.
train, val, test = datasets.MultiNLI.splits(TEXT, LABEL, TREE, GENRE)

print("Fields:", train.fields)
print("Number of examples:\n", len(train))
print("First Example instance:\n", vars(train[0]))

TEXT.build_vocab(train)
LABEL.build_vocab(train)
# Build the genre vocab over all splits so val/test genres are covered.
GENRE.build_vocab(train, val, test)

train_iter, val_iter, test_iter = data.Iterator.splits((train, val, test), batch_size=3)

batch = next(iter(train_iter))
print("Numericalize premises:\n", batch.premise)
print("Numericalize hypotheses:\n", batch.hypothesis)
print("Entailment labels:\n", batch.label)
print("Genre categories:\n", batch.genre)

print("Test iters function")
train_iter, val_iter, test_iter = datasets.MultiNLI.iters(batch_size=4, trees=True)

batch = next(iter(train_iter))
print("Numericalize premises:\n", batch.premise)
print("Numericalize hypotheses:\n", batch.hypothesis)
print("Entailment labels:\n", batch.label)
78 changes: 50 additions & 28 deletions torchtext/datasets/nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,23 @@ def __init__(self):


class ParsedTextField(data.Field):
    """
    Field for parsed sentences data in NLI datasets.

    Expensive tokenization could be omitted from the pipeline as
    the parse tree annotations are already in tokenized form.
    """

    def __init__(self, eos_token='<pad>', lower=False, reverse=False):
        # Preprocessing strips the binary-parse bracket tokens, leaving the
        # plain token sequence. With reverse=True, postprocessing reverses
        # each (already numericalized) example; note the postprocessing
        # callable takes (batch, vocab) — two arguments, not three.
        if reverse:
            super(ParsedTextField, self).__init__(
                eos_token=eos_token, lower=lower,
                preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')],
                postprocessing=lambda parse, _: [list(reversed(p)) for p in parse],
                include_lengths=True)
        else:
            super(ParsedTextField, self).__init__(
                eos_token=eos_token, lower=lower,
                preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')],
                include_lengths=True)


class NLIDataset(data.TabularDataset):
Expand All @@ -34,8 +43,9 @@ def sort_key(ex):
len(ex.premise), len(ex.hypothesis))

@classmethod
def splits(cls, text_field, label_field, parse_field=None, root='.data',
train='train.jsonl', validation='val.jsonl', test='test.jsonl'):
def splits(cls, text_field, label_field, parse_field=None,
extra_fields={}, root='.data', train='train.jsonl',
validation='val.jsonl', test='test.jsonl'):
"""Create dataset objects for splits of the SNLI dataset.

This is the most flexible way to use the dataset.
Expand All @@ -46,6 +56,7 @@ def splits(cls, text_field, label_field, parse_field=None, root='.data',
label_field: The field that will be used for label data.
parse_field: The field that will be used for shift-reduce parser
transitions, or None to not include them.
extra_fields: A dict[json_key: Tuple(field_name, Field)]
root: The root directory that the dataset's zip archive will be
expanded into.
train: The filename of the train data. Default: 'train.jsonl'.
Expand All @@ -57,21 +68,23 @@ def splits(cls, text_field, label_field, parse_field=None, root='.data',
path = cls.download(root)

if parse_field is None:
return super(NLIDataset, cls).splits(
path, root, train, validation, test,
format='json', fields={'sentence1': ('premise', text_field),
'sentence2': ('hypothesis', text_field),
'gold_label': ('label', label_field)},
filter_pred=lambda ex: ex.label != '-')
fields = {'sentence1': ('premise', text_field),
'sentence2': ('hypothesis', text_field),
'gold_label': ('label', label_field)}
else:
fields = {'sentence1_binary_parse': [('premise', text_field),
('premise_transitions', parse_field)],
'sentence2_binary_parse': [('hypothesis', text_field),
('hypothesis_transitions', parse_field)],
'gold_label': ('label', label_field)}

for key in extra_fields:
if key not in fields.keys():
fields[key] = extra_fields[key]

return super(NLIDataset, cls).splits(
path, root, train, validation, test,
format='json', fields={'sentence1_binary_parse':
[('premise', text_field),
('premise_transitions', parse_field)],
'sentence2_binary_parse':
[('hypothesis', text_field),
('hypothesis_transitions', parse_field)],
'gold_label': ('label', label_field)},
format='json', fields=fields,
filter_pred=lambda ex: ex.label != '-')

@classmethod
Expand Down Expand Up @@ -122,8 +135,9 @@ class SNLI(NLIDataset):
def splits(cls, text_field, label_field, parse_field=None, root='.data',
train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl',
test='snli_1.0_test.jsonl'):
return super(SNLI, cls).splits(text_field, label_field, parse_field,
root, train, validation, test)
return super(SNLI, cls).splits(text_field, label_field, parse_field=parse_field,
root=root, train=train, validation=validation,
test=test)


class MultiNLI(NLIDataset):
Expand All @@ -132,9 +146,17 @@ class MultiNLI(NLIDataset):
name = 'multinli'

@classmethod
def splits(cls, text_field, label_field, parse_field=None, root='.data',
def splits(cls, text_field, label_field, parse_field=None, genre_field=None,
root='.data',
train='multinli_1.0_train.jsonl',
validation='multinli_1.0_dev_matched.jsonl',
test='multinli_1.0_dev_mismatched.jsonl'):
return super(MultiNLI, cls).splits(text_field, label_field, parse_field,
root, train, validation, test)
extra_fields = {}
if genre_field is not None:
extra_fields["genre"] = ("genre", genre_field)

return super(MultiNLI, cls).splits(text_field, label_field,
parse_field=parse_field,
extra_fields=extra_fields,
root=root, train=train,
validation=validation, test=test)