Refactor text_classification, improve extract_archive and add a small example #565

Merged: 48 commits, Jul 26, 2019

Commits
c30ec38 - add new APIs to build dataset. (Jul 9, 2019)
ffd1f49 - Add new datasets for text classification. (Jul 11, 2019)
3b7b0e2 - Add docs and examples. (Jul 9, 2019)
5a31dd3 - Split text_normalize out of preprocess function. (Jul 11, 2019)
5efa58e - Add docs and test case. (Jul 11, 2019)
844242a - Update README file. (Jul 12, 2019)
b373de9 - revise generate_iters() function. (Jul 22, 2019)
6d5cb03 - Remove TextDataset class. (Jul 22, 2019)
3f0c523 - Remove generate_iterators() API (Jul 22, 2019)
2f20914 - remove unnecessary library loading (Jul 22, 2019)
57d0d03 - Re-name build_vocab to build_dictionary (Jul 22, 2019)
4cf4099 - change build_vocab to build_dictionary. (Jul 22, 2019)
c8ec403 - convert two functions to the interanls. (Jul 22, 2019)
0568a04 - Change the API of _load_text_classification_data() function. (Jul 22, 2019)
78673a5 - use a static list for url. (Jul 22, 2019)
58e3bac - use logging.info as print. (Jul 22, 2019)
81e5a31 - combine download and extract_archive (Jul 22, 2019)
e05d7fe - Merge branch 'master' into new_pattern (cpuhrsch, Jul 23, 2019)
7ffb267 - Merge branch 'new_supervised_learning_dataset' into new_pattern (Jul 23, 2019)
e138fa8 - examples (Jul 23, 2019)
1e9f0e1 - remove more (Jul 23, 2019)
c746d86 - less (Jul 23, 2019)
fea3bad - split (Jul 24, 2019)
5c90fbc - ordered dict (Jul 24, 2019)
ba23ae1 - Merge remote-tracking branch 'upstream/master' into tutorial (Jul 24, 2019)
3df4dc1 - rename (Jul 24, 2019)
ea639c2 - Simplifications (Jul 24, 2019)
193a670 - clean more (Jul 24, 2019)
285a515 - more efficient dictionary building (Jul 24, 2019)
fc1fcc1 - Merge branch 'master' into tutorial (Jul 24, 2019)
3e27dcd - Reduce code (Jul 24, 2019)
4678478 - tar and extraction (Jul 24, 2019)
2a18586 - Merge branch 'additionalstuff' into tutorial (Jul 24, 2019)
ee9894f - rebase (Jul 24, 2019)
197c70d - remove legacy (Jul 24, 2019)
bc2369f - more logging and args (Jul 25, 2019)
0e81889 - more (Jul 25, 2019)
75fd515 - small changes (Jul 25, 2019)
e7ea6c2 - More small changes (Jul 25, 2019)
accf587 - Update docs (Jul 25, 2019)
5506a2e - bring back examples (Jul 25, 2019)
2c8c4bf - bring back examples (Jul 25, 2019)
28b0976 - small fix (Jul 25, 2019)
5520b51 - Small test fix (Jul 25, 2019)
ff9132d - Use io.open (Jul 25, 2019)
331bf79 - flake8 (Jul 25, 2019)
73772e2 - flake8 (Jul 25, 2019)
fdbeb26 - remove print (Jul 26, 2019)
4 changes: 4 additions & 0 deletions .gitignore
@@ -116,3 +116,7 @@ venv.bak/
.mypy_cache/

# End of https://www.gitignore.io/api/python

# vim
*.swp
*.swo
19 changes: 19 additions & 0 deletions examples/text_classification/model.py
@@ -0,0 +1,19 @@
import torch.nn as nn


class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
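
A minimal sketch of how this model is fed, assuming the TextSentiment class above (the sizes are hypothetical): a batch is one flattened tensor of token ids plus an offsets tensor marking where each example starts, which nn.EmbeddingBag mean-pools per example before the linear layer.

import torch

from model import TextSentiment

# Hypothetical sizes, for illustration only.
model = TextSentiment(vocab_size=10, embed_dim=4, num_class=2)
text = torch.tensor([1, 3, 5, 2, 4])  # two examples, [1, 3, 5] and [2, 4], flattened
offsets = torch.tensor([0, 3])        # start index of each example within `text`
logits = model(text, offsets)         # shape (2, 2): one row of class scores per example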
26 changes: 26 additions & 0 deletions examples/text_classification/predict.py
@@ -0,0 +1,26 @@
import torch
import sys
import argparse

from torchtext.datasets.text_classification import text_normalize


def predict(text, model, dictionary):
    with torch.no_grad():
        text = torch.tensor([dictionary.get(token, dictionary['<unk>'])
                             for token in text_normalize(text)])
        # A batch of one example: the single entry starts at offset 0.
        output = model(text, torch.tensor([0]))
        # argmax is 0-based; the dataset labels are 1-based.
        return output.argmax(1).item() + 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Predict text from stdin given model and dictionary')
    parser.add_argument('model')
    parser.add_argument('dictionary')
    args = parser.parse_args()

    model = torch.load(args.model)
    dictionary = torch.load(args.dictionary)
    for line in sys.stdin:
        print(predict(line, model, dictionary))
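
To try it out, pipe one document per line on stdin, e.g. echo "Wall St. rallies" | python predict.py model.pt dictionary.pt (file names hypothetical; the files are produced by train.py below). It prints one 1-based label per input line.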
105 changes: 105 additions & 0 deletions examples/text_classification/train.py
@@ -0,0 +1,105 @@
import os
import logging
import random
import argparse

import torch

from torchtext.datasets.text_classification import AG_NEWS

from model import TextSentiment


def generate_offsets(data_batch):
    # offsets[i] is the index at which example i starts in the flattened
    # text tensor, as expected by nn.EmbeddingBag.
    offsets = [0]
    for entry in data_batch:
        offsets.append(offsets[-1] + len(entry))
    offsets = torch.tensor(offsets[:-1])
    return offsets


def generate_batch(data, labels, i, batch_size):
    data_batch = data[i:i + batch_size]
    text = torch.cat(data_batch)
    offsets = generate_offsets(data_batch)
    cls = torch.tensor(labels[i:i + batch_size])
    text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
    return text, offsets, cls


def train(lr_, num_epochs, data, labels):
    num_lines = num_epochs * len(data)
    for epoch in range(num_epochs):
        # Shuffle data and labels with the same permutation.
        perm = list(range(len(data)))
        random.shuffle(perm)
        data = [data[i] for i in perm]
        labels = [labels[i] for i in perm]

        for i in range(0, len(data), batch_size):
            text, offsets, cls = generate_batch(data, labels, i, batch_size)
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss.backward()
            progress = (i + len(data) * epoch) / float(num_lines)
            # Manual SGD step with a learning rate that decays linearly
            # from lr_ to zero over the course of training.
            lr = lr_ * (1 - progress)
            for p in model.parameters():
                p.data.add_(p.grad.data * -lr)
                p.grad.detach_()
                p.grad.zero_()
        print("")


def test(data, labels):
    total_accuracy = []
    for i in range(0, len(data), batch_size):
        with torch.no_grad():
            text, offsets, cls = generate_batch(data, labels, i, batch_size)
            output = model(text, offsets)
            accuracy = (output.argmax(1) == cls).float().mean().item()
            total_accuracy.append(accuracy)
    print("Test - Accuracy: {}".format(sum(total_accuracy) / len(total_accuracy)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Train a text classification model on AG_NEWS')
    parser.add_argument('--num-epochs', type=int, default=3)
    parser.add_argument('--embed-dim', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=64.0)
    parser.add_argument('--ngrams', type=int, default=2)
    parser.add_argument('--device', default='cpu')
    parser.add_argument('--data', default='.data')
    parser.add_argument('--save-model-path')
    parser.add_argument('--save-dictionary-path')
    parser.add_argument('--logging-level', default='WARNING')
    args = parser.parse_args()

    num_epochs = args.num_epochs
    embed_dim = args.embed_dim
    batch_size = args.batch_size
    lr = args.lr
    device = args.device
    data = args.data

    logging.basicConfig(level=getattr(logging, args.logging_level))

    if not os.path.exists(data):
        print("Creating directory {}".format(data))
        os.mkdir(data)

    dataset = AG_NEWS(root=data, ngrams=args.ngrams)
    model = TextSentiment(len(dataset.dictionary), embed_dim,
                          len(set(dataset.labels))).to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    train(lr, num_epochs, dataset.train_data, dataset.train_labels)
    test(dataset.test_data, dataset.test_labels)

    if args.save_model_path:
        print("Saving model to {}".format(args.save_model_path))
        torch.save(model.to('cpu'), args.save_model_path)
    if args.save_dictionary_path:
        print("Saving dictionary to {}".format(args.save_dictionary_path))
        torch.save(dataset.dictionary, args.save_dictionary_path)
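
An example invocation, using the flags defined above (output paths hypothetical):

python train.py --device cuda --ngrams 2 --save-model-path model.pt --save-dictionary-path dictionary.pt --logging-level INFO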
2 changes: 2 additions & 0 deletions test/data/test_builtin_datasets.py
@@ -53,6 +53,8 @@ def test_text_classification(self):
        # smoke test to ensure ag_news dataset works properly

        datadir = os.path.join(self.project_root, ".data")
        if not os.path.exists(datadir):
            os.mkdir(datadir)
        ag_news_cls = AG_NEWS(root=datadir, ngrams=3)
        self.assertEqual(len(ag_news_cls.train_examples), 120000)
        self.assertEqual(len(ag_news_cls.test_examples), 7600)
6 changes: 3 additions & 3 deletions test/test_vocab.py
@@ -145,7 +145,7 @@ def test_vocab_extend(self):
self.assertGreater(len(v), n_vocab)

self.assertEqual(v.itos[:6], ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'])
vectors = v.vectors.numpy()

# The first 5 entries in each vector.
@@ -267,7 +267,7 @@ def test_vocab_download_glove_vectors(self):
conditional_remove(zip_file)
for dim in ["25", "50", "100", "200"]:
conditional_remove(os.path.join(self.project_root, ".vector_cache",
"glove.twitter.27B.{}d.txt".format(dim)))
"glove.twitter.27B.{}d.txt".format(dim)))

@slow
def test_vocab_download_charngram_vectors(self):
@@ -355,4 +355,4 @@ def test_vectors_get_vecs(self):
conditional_remove(zip_file)
for dim in ["50", "100", "200", "300"]:
conditional_remove(os.path.join(self.project_root, ".vector_cache",
"glove.6B.{}d.txt".format(dim)))
"glove.6B.{}d.txt".format(dim)))
2 changes: 1 addition & 1 deletion torchtext/data/iterator.py
@@ -62,7 +62,7 @@ def __init__(self, dataset, batch_size, sort_key=None, device=None,
else:
self.sort_key = sort_key

if type(device) == int:
if isinstance(device, int):
logger.warning("The `device` argument should be set by using `torch.device`"
+ " or passing a string as an argument. This behavior will be"
+ " deprecated soon and currently defaults to cpu.")
1 change: 1 addition & 0 deletions torchtext/data/pipeline.py
@@ -9,6 +9,7 @@ class Pipeline(object):
        pipes: The Pipelines that will be applied to input sequence
            data in order.
    """

    def __init__(self, convert_token=None):
        """Create a pipeline.
1 change: 1 addition & 0 deletions torchtext/datasets/nli.py
@@ -17,6 +17,7 @@ class ParsedTextField(data.Field):
    Expensive tokenization could be omitted from the pipeline as
    the parse tree annotations are already in tokenized form.
    """

    def __init__(self, eos_token='<pad>', lower=False, reverse=False):
        if reverse:
            super(ParsedTextField, self).__init__(