From 56bbef66b8f157949b54029d79106708cbc8df51 Mon Sep 17 00:00:00 2001 From: Ajie Utama Date: Mon, 17 Sep 2018 12:18:25 +0200 Subject: [PATCH 1/3] fixed ParseTextField postprocessing function, added option to whether reverse ParseTextField, added genre field multinli --- test/nli.py | 67 +++++++++++++++++++++------------------ torchtext/datasets/nli.py | 64 ++++++++++++++++++++++--------------- 2 files changed, 75 insertions(+), 56 deletions(-) diff --git a/test/nli.py b/test/nli.py index 86e7c840a8..cb88170fe5 100644 --- a/test/nli.py +++ b/test/nli.py @@ -2,58 +2,65 @@ from torchtext import datasets # Testing SNLI -TEXT = data.Field() -LABEL = data.Field(sequential=False) +print("Run test on SNLI...") +TEXT = datasets.nli.ParsedTextField() +LABEL = data.LabelField() +TREE = datasets.nli.ShiftReduceField() -train, val, test = datasets.SNLI.splits(TEXT, LABEL) +train, val, test = datasets.SNLI.splits(TEXT, LABEL, TREE) -print(train.fields) -print(len(train)) -print(vars(train[0])) +print("Fields:", train.fields) +print("Number of examples:\n", len(train)) +print("First Example instance:\n", vars(train[0])) TEXT.build_vocab(train) LABEL.build_vocab(train) -train_iter, val_iter, test_iter = data.BucketIterator.splits( - (train, val, test), batch_size=3) +train_iter, val_iter, test_iter = data.Iterator.splits((train, val, test), batch_size=3) batch = next(iter(train_iter)) -print(batch.premise) -print(batch.hypothesis) -print(batch.label) +print("Numericalize premises:\n", batch.premise) +print("Numericalize hypotheses:\n", batch.hypothesis) +print("Entailment labels:\n", batch.label) -train_iter, val_iter, test_iter = datasets.SNLI.iters(batch_size=4) +print("Test iters function") +train_iter, val_iter, test_iter = datasets.SNLI.iters(batch_size=4, trees=True) batch = next(iter(train_iter)) -print(batch.premise) -print(batch.hypothesis) -print(batch.label) +print("Numericalize premises:\n", batch.premise) +print("Numericalize hypotheses:\n", batch.hypothesis) 
+print("Entailment labels:\n", batch.label) # Testing MultiNLI -TEXT = data.Field() -LABEL = data.Field(sequential=False) +print("Run test on MultiNLI...") +TEXT = datasets.nli.ParsedTextField() +LABEL = data.LabelField() +GENRE = data.LabelField() +TREE = datasets.nli.ShiftReduceField() -train, val, test = datasets.MultiNLI.splits(TEXT, LABEL) +train, val, test = datasets.MultiNLI.splits(TEXT, LABEL, TREE, GENRE) -print(train.fields) -print(len(train)) -print(vars(train[0])) +print("Fields:", train.fields) +print("Number of examples:\n", len(train)) +print("First Example instance:\n", vars(train[0])) TEXT.build_vocab(train) LABEL.build_vocab(train) +GENRE.build_vocab(train, val, test) -train_iter, val_iter, test_iter = data.BucketIterator.splits( - (train, val, test), batch_size=3) +train_iter, val_iter, test_iter = data.Iterator.splits((train, val, test), batch_size=3) batch = next(iter(train_iter)) -print(batch.premise) -print(batch.hypothesis) -print(batch.label) +print("Numericalize premises:\n", batch.premise) +print("Numericalize hypotheses:\n", batch.hypothesis) +print("Entailment labels:\n", batch.label) +print("Genre categories:\n", batch.genre) -train_iter, val_iter, test_iter = datasets.MultiNLI.iters(batch_size=4) +print("Test iters function") +train_iter, val_iter, test_iter = datasets.MultiNLI.iters(batch_size=4, trees=True) batch = next(iter(train_iter)) -print(batch.premise) -print(batch.hypothesis) -print(batch.label) +print("Numericalize premises:\n", batch.premise) +print("Numericalize hypotheses:\n", batch.hypothesis) +print("Entailment labels:\n", batch.label) diff --git a/torchtext/datasets/nli.py b/torchtext/datasets/nli.py index 0eb4d28b81..403f50aebf 100644 --- a/torchtext/datasets/nli.py +++ b/torchtext/datasets/nli.py @@ -12,14 +12,21 @@ def __init__(self): class ParsedTextField(data.Field): - - def __init__(self, eos_token='', lower=False): - + """ + Field for parsed sentences data in NLI datasets. 
+    Expensive tokenization could be omitted from the pipeline as +    the parse tree annotations are already in tokenized form. +    """ +    def __init__(self, eos_token='', lower=False, reverse=False): +        """ remove parentheses to recover the original sentences """ +        preprocessing = lambda parse: [t for t in parse if t not in ('(', ')')] +        if reverse: +            postprocessing = lambda parse, _: [list(reversed(p)) for p in parse] +        else: +            postprocessing = None super(ParsedTextField, self).__init__( -            eos_token=eos_token, lower=lower, preprocessing=lambda parse: [ -                t for t in parse if t not in ('(', ')')], -            postprocessing=lambda parse, _, __: [ -                list(reversed(p)) for p in parse]) +            eos_token=eos_token, lower=lower, preprocessing=preprocessing, +            postprocessing=postprocessing, include_lengths=True) class NLIDataset(data.TabularDataset): @@ -34,7 +41,7 @@ def sort_key(ex): len(ex.premise), len(ex.hypothesis)) @classmethod -    def splits(cls, text_field, label_field, parse_field=None, root='.data', +    def splits(cls, text_field, label_field, parse_field=None, extra_fields={}, root='.data', train='train.jsonl', validation='val.jsonl', test='test.jsonl'): """Create dataset objects for splits of the SNLI dataset. This is the most flexible way to use the dataset. Arguments: text_field: The field that will be used for premise and hypothesis data. label_field: The field that will be used for label data. parse_field: The field that will be used for shift-reduce parser transitions, or None to not include them. +            extra_fields: A dict[json_key: Tuple(field_name, Field)] root: The root directory that the dataset's zip archive will be expanded into. train: The filename of the train data. Default: 'train.jsonl'. 
@@ -57,21 +65,21 @@ def splits(cls, text_field, label_field, parse_field=None, root='.data', path = cls.download(root) if parse_field is None: - return super(NLIDataset, cls).splits( - path, root, train, validation, test, - format='json', fields={'sentence1': ('premise', text_field), - 'sentence2': ('hypothesis', text_field), - 'gold_label': ('label', label_field)}, - filter_pred=lambda ex: ex.label != '-') + fields = {'sentence1': ('premise', text_field), + 'sentence2': ('hypothesis', text_field), + 'gold_label': ('label', label_field)} + else: + fields = {'sentence1_binary_parse': [('premise', text_field), ('premise_transitions', parse_field)], + 'sentence2_binary_parse': [('hypothesis', text_field), ('hypothesis_transitions', parse_field)], + 'gold_label': ('label', label_field)} + + for key in extra_fields: + if key not in fields.keys(): + fields[key] = extra_fields[key] + return super(NLIDataset, cls).splits( path, root, train, validation, test, - format='json', fields={'sentence1_binary_parse': - [('premise', text_field), - ('premise_transitions', parse_field)], - 'sentence2_binary_parse': - [('hypothesis', text_field), - ('hypothesis_transitions', parse_field)], - 'gold_label': ('label', label_field)}, + format='json', fields=fields, filter_pred=lambda ex: ex.label != '-') @classmethod @@ -122,8 +130,8 @@ class SNLI(NLIDataset): def splits(cls, text_field, label_field, parse_field=None, root='.data', train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl', test='snli_1.0_test.jsonl'): - return super(SNLI, cls).splits(text_field, label_field, parse_field, - root, train, validation, test) + return super(SNLI, cls).splits(text_field, label_field, parse_field=parse_field, + root=root, train=train, validation=validation, test=test) class MultiNLI(NLIDataset): @@ -132,9 +140,13 @@ class MultiNLI(NLIDataset): name = 'multinli' @classmethod - def splits(cls, text_field, label_field, parse_field=None, root='.data', + def splits(cls, text_field, label_field, 
parse_field=None, genre_field=None, root='.data', train='multinli_1.0_train.jsonl', validation='multinli_1.0_dev_matched.jsonl', test='multinli_1.0_dev_mismatched.jsonl'): - return super(MultiNLI, cls).splits(text_field, label_field, parse_field, - root, train, validation, test) + extra_fields = {} + if genre_field is not None: + extra_fields["genre"] = ("genre", genre_field) + + return super(MultiNLI, cls).splits(text_field, label_field, parse_field=parse_field, extra_fields=extra_fields, + root=root, train=train, validation=validation, test=test) From e6bae408eba647966ce90e8d2d896cb9d00a0f5f Mon Sep 17 00:00:00 2001 From: Ajie Utama Date: Mon, 17 Sep 2018 19:38:54 +0200 Subject: [PATCH 2/3] fix code lines length --- torchtext/datasets/nli.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/torchtext/datasets/nli.py b/torchtext/datasets/nli.py index 403f50aebf..b1d9608730 100644 --- a/torchtext/datasets/nli.py +++ b/torchtext/datasets/nli.py @@ -41,8 +41,9 @@ def sort_key(ex): len(ex.premise), len(ex.hypothesis)) @classmethod - def splits(cls, text_field, label_field, parse_field=None, extra_fields={}, root='.data', - train='train.jsonl', validation='val.jsonl', test='test.jsonl'): + def splits(cls, text_field, label_field, parse_field=None, + extra_fields={}, root='.data', train='train.jsonl', + validation='val.jsonl', test='test.jsonl'): """Create dataset objects for splits of the SNLI dataset. This is the most flexible way to use the dataset. 
@@ -69,8 +70,10 @@ def splits(cls, text_field, label_field, parse_field=None, extra_fields={}, root 'sentence2': ('hypothesis', text_field), 'gold_label': ('label', label_field)} else: - fields = {'sentence1_binary_parse': [('premise', text_field), ('premise_transitions', parse_field)], - 'sentence2_binary_parse': [('hypothesis', text_field), ('hypothesis_transitions', parse_field)], + fields = {'sentence1_binary_parse': [('premise', text_field), + ('premise_transitions', parse_field)], + 'sentence2_binary_parse': [('hypothesis', text_field), + ('hypothesis_transitions', parse_field)], 'gold_label': ('label', label_field)} for key in extra_fields: @@ -131,7 +134,8 @@ def splits(cls, text_field, label_field, parse_field=None, root='.data', train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl', test='snli_1.0_test.jsonl'): return super(SNLI, cls).splits(text_field, label_field, parse_field=parse_field, - root=root, train=train, validation=validation, test=test) + root=root, train=train, validation=validation, + test=test) class MultiNLI(NLIDataset): @@ -140,7 +144,8 @@ class MultiNLI(NLIDataset): name = 'multinli' @classmethod - def splits(cls, text_field, label_field, parse_field=None, genre_field=None, root='.data', + def splits(cls, text_field, label_field, parse_field=None, genre_field=None, + root='.data', train='multinli_1.0_train.jsonl', validation='multinli_1.0_dev_matched.jsonl', test='multinli_1.0_dev_mismatched.jsonl'): @@ -148,5 +153,8 @@ def splits(cls, text_field, label_field, parse_field=None, genre_field=None, roo if genre_field is not None: extra_fields["genre"] = ("genre", genre_field) - return super(MultiNLI, cls).splits(text_field, label_field, parse_field=parse_field, extra_fields=extra_fields, - root=root, train=train, validation=validation, test=test) + return super(MultiNLI, cls).splits(text_field, label_field, + parse_field=parse_field, + extra_fields=extra_fields, + root=root, train=train, + validation=validation, test=test) From 
cb8e946cc6a9344fb58cea3792bb5fa560a9c566 Mon Sep 17 00:00:00 2001 From: Ajie Utama Date: Tue, 18 Sep 2018 07:35:16 +0200 Subject: [PATCH 3/3] remove lambda expression --- torchtext/datasets/nli.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torchtext/datasets/nli.py b/torchtext/datasets/nli.py index b1d9608730..745da48190 100644 --- a/torchtext/datasets/nli.py +++ b/torchtext/datasets/nli.py @@ -18,15 +18,17 @@ class ParsedTextField(data.Field): the parse tree annotations are already in tokenized form. """ def __init__(self, eos_token='', lower=False, reverse=False): - """ remove parentheses to recover the original sentences """ - preprocessing = lambda parse: [t for t in parse if t not in ('(', ')')] if reverse: - postprocessing = lambda parse, _: [list(reversed(p)) for p in parse] + super(ParsedTextField, self).__init__( + eos_token=eos_token, lower=lower, + preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')], + postprocessing=lambda parse, _: [list(reversed(p)) for p in parse], + include_lengths=True) else: - postprocessing = None - super(ParsedTextField, self).__init__( - eos_token=eos_token, lower=lower, preprocessing=preprocessing, - postprocessing=postprocessing, include_lengths=True) + super(ParsedTextField, self).__init__( + eos_token=eos_token, lower=lower, + preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')], + include_lengths=True) class NLIDataset(data.TabularDataset):