Add BLEU score metric #627
Conversation
You can use flake8 to catch the lint error.
(force-pushed from a35e89f to a8ae4c8)
torchtext/data/metrics.py (Outdated)

    assert max_n > 0

    ngrams = [tuple(x.split(' ')) for x in ngrams_iterator(tokens, max_n)]
    ngrams_counter = collections.Counter(ngrams)
Counter accepts iterators, so you can use collections.Counter(tuple(x.split(' ')) for x in ngrams_iterator(tokens, max_n)). The advantage is that you don't need to materialize the entire list of n-grams first, so you'll save on memory roundtrips.
I think this is a good idea. In general, we only materialize the sequence if it's really necessary; otherwise, we just pass it along to the next generator in the pipeline.
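For reference, a minimal sketch of the two variants side by side (assuming tokens is a list of token strings and ngrams_iterator yields space-joined n-gram strings, as in torchtext.data.utils):

    import collections
    from torchtext.data.utils import ngrams_iterator

    tokens = ['My', 'full', 'pytorch', 'test']
    max_n = 4

    # Original: materializes the full list of n-gram tuples before counting
    ngrams = [tuple(x.split(' ')) for x in ngrams_iterator(tokens, max_n)]
    ngrams_counter = collections.Counter(ngrams)

    # Suggested: feed a generator straight into Counter, no intermediate list
    ngrams_counter = collections.Counter(
        tuple(x.split(' ')) for x in ngrams_iterator(tokens, max_n))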
test/data/test_metrics.py (Outdated)

    # Partial match
    candidate = [['My', 'full', 'pytorch', 'test']]
    refs = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]]
    assert round(metrics.bleu_score(candidate, refs), 4) == 0.7788
What's the source for these scores? Or how were they computed?
I computed them "by hand" by applying the math mentioned in the paper. Should I make this more explicit?
That's fine with me. Do you have more than 4 digits in your calculations?
By the way, you could assert that the result is close to 0.7788 instead of rounding and checking for equality.
or torch.testing.assert_allclose
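A sketch of what that could look like in the test, reusing the partial-match case above (the import path is assumed from the file layout in this PR, and 0.7788 is the hand-computed value discussed here):

    import torch
    from torchtext.data import metrics

    candidate = [['My', 'full', 'pytorch', 'test']]
    refs = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]]

    # Compare with a tolerance instead of rounding and checking exact equality
    torch.testing.assert_allclose(metrics.bleu_score(candidate, refs),
                                  0.7788, rtol=0, atol=1e-4)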
torchtext/data/metrics.py
Outdated
if min(clipped_counts) == 0: | ||
return 0.0 | ||
else: | ||
pn = [clipped_counts[i] / total_counts[i] for i in range(max_n)] |
A lot of the math here can be done using torch.Tensors. Since this is part of the DAPI ecosystem, it's ok to have a hard dependency on that. NumPy, on the other hand, is something we can't have a hard dependency on, because pytorch/pytorch doesn't either.
good point, I'll look into it
I replaced the math part with tensor operations in 4fc19dc. It's more elegant without the list comprehensions!
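A minimal sketch of what the tensor-based version might look like (the variable names follow the diff above; the sample values, the uniform weights, and the final geometric mean are illustrative assumptions based on the standard BLEU formula, not necessarily the exact code in 4fc19dc):

    import torch

    # Per-n-gram-order counts accumulated over the corpus (illustrative values)
    clipped_counts = torch.tensor([9., 7., 5., 3.])
    total_counts = torch.tensor([10., 9., 8., 7.])
    weights = torch.tensor([0.25, 0.25, 0.25, 0.25])

    if torch.min(clipped_counts) == 0:
        score = 0.0
    else:
        # Element-wise modified precisions, then a weighted geometric mean via log/exp
        pn = clipped_counts / total_counts
        score = torch.exp(torch.sum(weights * torch.log(pn))).item()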
torchtext/data/metrics.py (Outdated)

    # Get the length of the reference that's closest in length to the candidate
    refs_len_list = [float(len(ref)) for ref in refs]
    refs_len = min(refs_len_list, key=lambda x: abs(len(candidate) - x))
If I'm not missing something, refs_len is being overwritten in each iteration. So in the end we only use the refs_len of the last (candidate, refs) pair?
That's a great catch!! It should be +=. My tests didn't catch it because I only tested on very small corpora.
Related to this, I was wondering how I should test that this works for a large corpus. Should I check that I get the same result as the nltk implementation for a random corpus hardcoded in the test_metrics.py file?
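A sketch of the += fix being discussed, assuming the loop runs over (candidate, refs) pairs and the accumulated lengths later feed the brevity penalty; the corpus variable names are placeholders and the data is just the small example from the test:

    # Placeholder names for the function's inputs; data from the test above
    candidate_corpus = [['My', 'full', 'pytorch', 'test']]
    references_corpus = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]]

    candidate_len = 0.0
    refs_len = 0.0
    for (candidate, refs) in zip(candidate_corpus, references_corpus):
        candidate_len += len(candidate)
        # Accumulate with += instead of overwriting, so every (candidate, refs)
        # pair contributes to the corpus-level brevity penalty
        refs_len_list = [float(len(ref)) for ref in refs]
        refs_len += min(refs_len_list, key=lambda x: abs(len(candidate) - x))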
Yes, using a well-established reference implementation isn't a bad idea. However, the values to compare against are better stored as a static asset, so that even if the reference implementation changes, you still have the correct test data (we don't want a bug on their end influencing us).
So, use nltk to write out some test data, verify it's what you want, and then use that as a reference within the test.
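A sketch of how such reference values could be produced once with nltk and then hard-coded (or stored as a static asset) for the test; the corpus here is just the small example from above:

    from nltk.translate.bleu_score import corpus_bleu

    # One-off script: compute the expected score with nltk, then paste the
    # result into test_metrics.py (or save it to a static asset) so the test
    # itself doesn't depend on nltk
    candidate_corpus = [['My', 'full', 'pytorch', 'test']]
    references_corpus = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]]

    expected = corpus_bleu(references_corpus, candidate_corpus)
    print(expected)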
Great job! Thanks for writing this :) Most of my comments are around computational efficiency.
    from torchtext.data.utils import ngrams_iterator


    def _compute_ngram_counter(tokens, max_n):
It seems that this func has only one line. You may not need a func here.
The assumption I made to keep this as a function was:
- Right now I use this function in these 2 places (text/torchtext/data/metrics.py, lines 73 to 75 in c1fc7b0):

      reference_counters = _compute_ngram_counter(refs[0], max_n)
      for ref in refs[1:]:
          reference_counters = reference_counters | _compute_ngram_counter(ref, max_n)

- We might want to use it in other metrics in the future (it's a long shot as of now, though)
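Putting the pieces from the diffs above together, the helper under discussion is roughly the following one-liner (shown here with the generator-based Counter suggested earlier; not necessarily the exact final code):

    import collections
    from torchtext.data.utils import ngrams_iterator


    def _compute_ngram_counter(tokens, max_n):
        """Count all n-grams up to length max_n in a list of tokens."""
        assert max_n > 0
        # ngrams_iterator yields space-joined n-gram strings; split them back
        # into tuples so they can serve as Counter keys
        return collections.Counter(tuple(x.split(' '))
                                   for x in ngrams_iterator(tokens, max_n))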
If you want, you could, for fun (for now), benchmark this algorithm on some fixed dataset and see how some of the changes impact your runtime. You can then also compare this to known implementations such as https://github.com/mjpost/sacreBLEU. Once something like this is merged, we want to make sure that our implementations of these metrics are at least within 10x of other implementations. Then we might decide to bind against these or write C++ code, etc.
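A minimal timing sketch along those lines, assuming a synthetic corpus (the corpus shape, size, and use of time.perf_counter are illustrative; a real comparison would run sacreBLEU or nltk on the same data):

    import random
    import time

    from torchtext.data import metrics

    # Hypothetical synthetic corpus: 1,000 candidates of 20 tokens each,
    # with two references per candidate
    vocab = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog']
    candidates = [[random.choice(vocab) for _ in range(20)] for _ in range(1000)]
    references = [[[random.choice(vocab) for _ in range(20)] for _ in range(2)]
                  for _ in range(1000)]

    start = time.perf_counter()
    score = metrics.bleu_score(candidates, references)
    print('bleu=%.4f, elapsed=%.3fs' % (score, time.perf_counter() - start))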
(force-pushed from 0b04716 to c1fc7b0)
Generally this looks great to me! I'd be happy to merge that in its current form. There's still plenty of room for performance optimizations, but the overall setup seems fine and we have tests :D - Great job!
Yup, I second @cpuhrsch. This looks good to me too!
@cpuhrsch @zhangguanheng66 In terms of GitHub etiquette, should I rebase all these commits into 3 (one for metrics, one for the tests, and one for docs) or into 1 only? For now I kept all the changes compared to the original PR in separate commits to keep track of each change.
@sluks - when merging, we're choosing to squash and merge this PR into one commit. So, no need to worry about that.
All right, the PR passed all the tests and I added a few comments to clarify the tests; I'm comfortable merging this from my side. I also rebased on top of master.
A quick review of MT quality scores:
BLEU (Bilingual Evaluation Understudy), also called N-Gram Co-Occurrence, was developed by IBM in 2001. It is a measure of similarity between a candidate translation and a set of reference translations, where each reference translation is equally valid. The essence of it is the following: we check how many n-grams the candidate and reference translations have in common (up to 4-grams). The idea is to focus on precision: how much of the candidate translation consists of correct words and n-grams? However, if we relied on precision alone, the system could output just a single word, precision would be very high, and the translation (and recall) would still be very bad. That's why BLEU adds a "brevity penalty" that penalizes translations that are too short, compensating for the missing recall term.
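As a usage sketch with the metric added in this PR (the values are the partial-match case from the test above: every candidate n-gram appears in the first reference, so the score reduces to the brevity penalty exp(1 - 5/4) ≈ 0.7788):

    from torchtext.data import metrics

    candidate = [['My', 'full', 'pytorch', 'test']]
    refs = [[['My', 'full', 'pytorch', 'test', '!'], ['Different']]]

    # All modified n-gram precisions are 1; only the brevity penalty remains
    print(metrics.bleu_score(candidate, refs))  # ~0.7788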
NIST is an extension of BLEU that weights n-gram matches by how informative they are: matching a "rare" n-gram counts for more than matching a common one. The idea is to give little weight to frequent, "stop word"-type n-grams.
ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is based on the same idea as BLEU but focuses on recall instead of precision: we look at how many n-grams from the reference translations appear in the candidate translation (the reverse of BLEU). There are 5 different ROUGE variants: ROUGE-N, ROUGE-L, ROUGE-W, ROUGE-S, and ROUGE-SU.
METEOR is similar to BLEU, but it also takes synonyms and stemming into account, which makes it heavier to compute.