Skip to content

[WIP] added xnli dataset #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 287 additions & 49 deletions test/nli.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,304 @@
from torchtext import data
from torchtext import datasets
import torch
from .common.torchtext_test_case import TorchtextTestCase

# Testing SNLI
print("Run test on SNLI...")
TEXT = datasets.nli.ParsedTextField()
LABEL = data.LabelField()
TREE = datasets.nli.ShiftReduceField()
from torchtext.datasets import SNLI, MultiNLI, XNLI
from torchtext.datasets.nli import ParsedTextField, ShiftReduceField
from torchtext.data import Field, LabelField, Iterator

train, val, test = datasets.SNLI.splits(TEXT, LABEL, TREE)
import shutil

print("Fields:", train.fields)
print("Number of examples:\n", len(train))
print("First Example instance:\n", vars(train[0]))

TEXT.build_vocab(train)
LABEL.build_vocab(train)
class TestNLI(TorchtextTestCase):

    def test_snli(self):
        """Download SNLI and sanity-check splits, fields, vocabulary and
        batch tensor shapes, for both the ``splits`` and ``iters`` APIs."""
        batch_size = 4

        # create fields
        TEXT = ParsedTextField()
        TREE = ShiftReduceField()
        LABEL = LabelField()

        # create train/val/test splits (downloads the corpus on first run)
        train, val, test = SNLI.splits(TEXT, LABEL, TREE)

        # check all are SNLI datasets
        assert type(train) == type(val) == type(test) == SNLI

        # check all have correct number of fields
        assert len(train.fields) == len(val.fields) == len(test.fields) == 5

        # check fields are the correct type (schema is identical per split)
        for split in (train, val, test):
            assert type(split.fields['premise']) == ParsedTextField
            assert type(split.fields['premise_transitions']) == ShiftReduceField
            assert type(split.fields['hypothesis']) == ParsedTextField
            assert type(split.fields['hypothesis_transitions']) == ShiftReduceField
            assert type(split.fields['label']) == LabelField

        # check each split is the correct length
        assert len(train) == 549367
        assert len(val) == 9842
        assert len(test) == 9824

        # build vocabulary
        TEXT.build_vocab(train)
        LABEL.build_vocab(train)

        # ensure vocabulary has been created
        assert hasattr(TEXT, 'vocab')
        assert hasattr(TEXT.vocab, 'itos')
        assert hasattr(TEXT.vocab, 'stoi')

        # create iterators
        train_iter, val_iter, test_iter = Iterator.splits(
            (train, val, test), batch_size=batch_size)

        # get a batch to test
        batch = next(iter(train_iter))

        # split premise and hypothesis from tuples to tensors
        premise, premise_transitions = batch.premise
        hypothesis, hypothesis_transitions = batch.hypothesis
        label = batch.label

        # check each is actually a tensor
        assert type(premise) == torch.Tensor
        assert type(premise_transitions) == torch.Tensor
        assert type(hypothesis) == torch.Tensor
        assert type(hypothesis_transitions) == torch.Tensor
        assert type(label) == torch.Tensor

        # check have the correct batch dimension
        assert premise.shape[-1] == batch_size
        assert premise_transitions.shape[-1] == batch_size
        assert hypothesis.shape[-1] == batch_size
        assert hypothesis_transitions.shape[-1] == batch_size
        assert label.shape[-1] == batch_size

        # repeat the same tests with iters instead of splits
        train_iter, val_iter, test_iter = SNLI.iters(batch_size=batch_size,
                                                     trees=True)

        # BUG FIX: draw a fresh batch from the iters()-built iterator;
        # previously the stale batch from Iterator.splits was re-tested,
        # so the iters() code path was never actually exercised
        batch = next(iter(train_iter))

        # split premise and hypothesis from tuples to tensors
        premise, premise_transitions = batch.premise
        hypothesis, hypothesis_transitions = batch.hypothesis
        label = batch.label

        # check each is actually a tensor
        assert type(premise) == torch.Tensor
        assert type(premise_transitions) == torch.Tensor
        assert type(hypothesis) == torch.Tensor
        assert type(hypothesis_transitions) == torch.Tensor
        assert type(label) == torch.Tensor

        # check have the correct batch dimension
        assert premise.shape[-1] == batch_size
        assert premise_transitions.shape[-1] == batch_size
        assert hypothesis.shape[-1] == batch_size
        assert hypothesis_transitions.shape[-1] == batch_size
        assert label.shape[-1] == batch_size

        # clean up: remove downloaded snli directory
        shutil.rmtree('.data/snli')

def test_multinli(self):
batch_size = 4

# create fields
TEXT = ParsedTextField()
TREE = ShiftReduceField()
GENRE = LabelField()
LABEL = LabelField()

# create train/val/test splits
train, val, test = MultiNLI.splits(TEXT, LABEL, TREE, GENRE)

# check all are MultiNLI datasets
assert type(train) == type(val) == type(test) == MultiNLI

# check all have correct number of fields
assert len(train.fields) == len(val.fields) == len(test.fields) == 6

# check fields are the correct type
assert type(train.fields['premise']) == ParsedTextField
assert type(train.fields['premise_transitions']) == ShiftReduceField
assert type(train.fields['hypothesis']) == ParsedTextField
assert type(train.fields['hypothesis_transitions']) == ShiftReduceField
assert type(train.fields['label']) == LabelField
assert type(train.fields['genre']) == LabelField

assert type(val.fields['premise']) == ParsedTextField
assert type(val.fields['premise_transitions']) == ShiftReduceField
assert type(val.fields['hypothesis']) == ParsedTextField
assert type(val.fields['hypothesis_transitions']) == ShiftReduceField
assert type(val.fields['label']) == LabelField
assert type(val.fields['genre']) == LabelField

assert type(test.fields['premise']) == ParsedTextField
assert type(test.fields['premise_transitions']) == ShiftReduceField
assert type(test.fields['hypothesis']) == ParsedTextField
assert type(test.fields['hypothesis_transitions']) == ShiftReduceField
assert type(test.fields['label']) == LabelField
assert type(test.fields['genre']) == LabelField

# check each is the correct length
assert len(train) == 392702
assert len(val) == 9815
assert len(test) == 9832

# build vocabulary
TEXT.build_vocab(train)
LABEL.build_vocab(train)
GENRE.build_vocab(train)

# ensure vocabulary has been created
assert hasattr(TEXT, 'vocab')
assert hasattr(TEXT.vocab, 'itos')
assert hasattr(TEXT.vocab, 'stoi')

# create iterators
train_iter, val_iter, test_iter = Iterator.splits((train, val, test),
batch_size=batch_size)

# get a batch to test
batch = next(iter(train_iter))

# split premise and hypothesis from tuples to tensors
premise, premise_transitions = batch.premise
hypothesis, hypothesis_transitions = batch.hypothesis
label = batch.label
genre = batch.genre

# check each is actually a tensor
assert type(premise) == torch.Tensor
assert type(premise_transitions) == torch.Tensor
assert type(hypothesis) == torch.Tensor
assert type(hypothesis_transitions) == torch.Tensor
assert type(label) == torch.Tensor
assert type(genre) == torch.Tensor

# check have the correct batch dimension
assert premise.shape[-1] == batch_size
assert premise_transitions.shape[-1] == batch_size
assert hypothesis.shape[-1] == batch_size
assert hypothesis_transitions.shape[-1] == batch_size
assert label.shape[-1] == batch_size
assert genre.shape[-1] == batch_size

# repeat the same tests with iters instead of split
train_iter, val_iter, test_iter = MultiNLI.iters(batch_size=batch_size,
trees=True)

# split premise and hypothesis from tuples to tensors
premise, premise_transitions = batch.premise
hypothesis, hypothesis_transitions = batch.hypothesis
label = batch.label

# check each is actually a tensor
assert type(premise) == torch.Tensor
assert type(premise_transitions) == torch.Tensor
assert type(hypothesis) == torch.Tensor
assert type(hypothesis_transitions) == torch.Tensor
assert type(label) == torch.Tensor

# check have the correct batch dimension
assert premise.shape[-1] == batch_size
assert premise_transitions.shape[-1] == batch_size
assert hypothesis.shape[-1] == batch_size
assert hypothesis_transitions.shape[-1] == batch_size
assert label.shape[-1] == batch_size

# remove downloaded multinli directory
shutil.rmtree('.data/multinli')

def test_xnli(self):
    """Download XNLI and sanity-check its splits, fields, vocabulary and
    batch tensor shapes; also verify that ``iters`` is unsupported."""
    batch_size = 4

    # create fields (plain Field: XNLI provides no parse trees)
    TEXT = Field()
    GENRE = LabelField()
    LABEL = LabelField()
    LANGUAGE = LabelField()

    # create val/test splits; XNLI does not have a training set,
    # so splits() returns only validation and test
    val, test = XNLI.splits(TEXT, LABEL, GENRE, LANGUAGE)

    # check both are XNLI datasets
    assert type(val) == type(test) == XNLI

    # check all have the correct number of fields
    assert len(val.fields) == len(test.fields) == 5

    # check fields are the correct type
    assert type(val.fields['premise']) == Field
    assert type(val.fields['hypothesis']) == Field
    assert type(val.fields['label']) == LabelField
    assert type(val.fields['genre']) == LabelField
    assert type(val.fields['language']) == LabelField

    assert type(test.fields['premise']) == Field
    assert type(test.fields['hypothesis']) == Field
    assert type(test.fields['label']) == LabelField
    assert type(test.fields['genre']) == LabelField
    assert type(test.fields['language']) == LabelField

    # check each is the correct length
    assert len(val) == 37350
    assert len(test) == 75150

    # build vocabulary from the validation split (no training split exists)
    TEXT.build_vocab(val)
    LABEL.build_vocab(val)
    GENRE.build_vocab(val)
    LANGUAGE.build_vocab(val)

    # ensure vocabulary has been created
    assert hasattr(TEXT, 'vocab')
    assert hasattr(TEXT.vocab, 'itos')
    assert hasattr(TEXT.vocab, 'stoi')

    # create iterators
    val_iter, test_iter = Iterator.splits((val, test),
        batch_size=batch_size)

    # get a batch to test
    batch = next(iter(val_iter))

    # unlike SNLI/MultiNLI there are no transitions, so each attribute
    # is already a tensor rather than a (tokens, transitions) tuple
    premise = batch.premise
    hypothesis = batch.hypothesis
    label = batch.label
    genre = batch.genre
    language = batch.language

    # check each is actually a tensor
    assert type(premise) == torch.Tensor
    assert type(hypothesis) == torch.Tensor
    assert type(label) == torch.Tensor
    assert type(genre) == torch.Tensor
    assert type(language) == torch.Tensor

    # check have the correct batch dimension
    assert premise.shape[-1] == batch_size
    assert hypothesis.shape[-1] == batch_size
    assert label.shape[-1] == batch_size
    assert genre.shape[-1] == batch_size
    assert language.shape[-1] == batch_size

    # xnli cannot use the iters method (it needs a training split),
    # ensure it raises an error
    with self.assertRaises(NotImplementedError):
        val_iter, test_iter = XNLI.iters(batch_size=batch_size)

    # clean up: remove downloaded xnli directory
    shutil.rmtree('.data/xnli')
3 changes: 2 additions & 1 deletion torchtext/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA
from .nli import SNLI, MultiNLI
from .nli import SNLI, MultiNLI, XNLI
from .sst import SST
from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA
from .sequence_tagging import SequenceTaggingDataset, UDPOS, CoNLL2000Chunking # NOQA
Expand All @@ -15,6 +15,7 @@
__all__ = ['LanguageModelingDataset',
'SNLI',
'MultiNLI',
'XNLI',
'SST',
'TranslationDataset',
'Multi30k',
Expand Down
26 changes: 26 additions & 0 deletions torchtext/datasets/nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,29 @@ def splits(cls, text_field, label_field, parse_field=None, genre_field=None,
extra_fields=extra_fields,
root=root, train=train,
validation=validation, test=test)


class XNLI(NLIDataset):
    """Cross-lingual NLI corpus (XNLI).

    XNLI ships only development and test data, so this dataset exposes
    ``splits`` without a training set and does not support ``iters``.
    """

    urls = ['http://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip']
    dirname = 'XNLI-1.0'
    name = 'xnli'

    @classmethod
    def splits(cls, text_field, label_field, genre_field=None, language_field=None,
               root='.data',
               validation='xnli.dev.jsonl',
               test='xnli.test.jsonl'):
        """Create dataset objects for the validation and test splits of XNLI.

        Arguments:
            text_field: field for the premise and hypothesis text.
            label_field: field for the entailment label.
            genre_field: optional field for the genre annotation.
            language_field: optional field for the language annotation.
            root: root directory the zip is downloaded/extracted into.
            validation, test: json-lines filenames of the two splits.

        Returns the (validation, test) datasets; there is no training split.
        """
        # attach the optional annotations only when a field was supplied
        optional = {}
        for key, field in (('genre', genre_field), ('language', language_field)):
            if field is not None:
                optional[key] = (key, field)

        return super(XNLI, cls).splits(text_field, label_field,
                                       extra_fields=optional,
                                       root=root, train=None,
                                       validation=validation, test=test)

    @classmethod
    def iters(cls, *args, **kwargs):
        """Unsupported: the ``iters`` convenience helper builds a training
        iterator, and XNLI has no training split."""
        raise NotImplementedError('XNLI dataset does not support iters')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So why doesn't the XNLI dataset support iters?

Copy link
Contributor Author

@bentrevett bentrevett Oct 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of this line onward. NLIDataset always assumes there is a training, validation and test set, which is not the case for the XNLI dataset - it only has a validation and test set. I can edit the NLIDataset class to check if train is None and act accordingly?