From 961b821ccffbcef4766bee8920366fe7d31a9c7a Mon Sep 17 00:00:00 2001 From: mttk Date: Mon, 17 Sep 2018 12:05:53 +0200 Subject: [PATCH 1/2] Fix translation dataset splits errors --- test/translation.py | 2 +- torchtext/datasets/translation.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/translation.py b/test/translation.py index abd3180908..eb3c47a349 100644 --- a/test/translation.py +++ b/test/translation.py @@ -78,7 +78,7 @@ def tokenize_en(text): train, val = datasets.TranslationDataset.splits( path='.data/multi30k/', train='train', - validation='val', exts=('.de', '.en'), + validation='val', test=None, exts=('.de', '.en'), fields=(DE, EN)) print(train.fields) diff --git a/torchtext/datasets/translation.py b/torchtext/datasets/translation.py index 919c04d4d0..a1c335adb2 100644 --- a/torchtext/datasets/translation.py +++ b/torchtext/datasets/translation.py @@ -97,7 +97,8 @@ def splits(cls, exts, fields, root='.data', Remaining keyword arguments: Passed to the splits method of Dataset. """ - path = os.path.join('data', cls.name) + expected_folder = os.path.join(root, cls.name) + path = expected_folder if os.path.exists(expected_folder) else None return super(Multi30k, cls).splits( exts, fields, path, root, train, validation, test, **kwargs) @@ -206,6 +207,8 @@ def splits(cls, exts, fields, root='.data', Remaining keyword arguments: Passed to the splits method of Dataset. """ - path = os.path.join('data', cls.name) + expected_folder = os.path.join(root, cls.name) + path = expected_folder if os.path.exists(expected_folder) else None + return super(WMT14, cls).splits( exts, fields, path, root, train, validation, test, **kwargs) From 5fe4cd74b1473637311ae39ecdbae46d31b82c1a Mon Sep 17 00:00:00 2001 From: mttk Date: Mon, 17 Sep 2018 12:09:11 +0200 Subject: [PATCH 2/2] One blank line --- torchtext/datasets/translation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtext/datasets/translation.py b/torchtext/datasets/translation.py index a1c335adb2..768939bfb9 100644 --- a/torchtext/datasets/translation.py +++ b/torchtext/datasets/translation.py @@ -99,6 +99,7 @@ def splits(cls, exts, fields, root='.data', """ expected_folder = os.path.join(root, cls.name) path = expected_folder if os.path.exists(expected_folder) else None + return super(Multi30k, cls).splits( exts, fields, path, root, train, validation, test, **kwargs)