Skip to content

Commit d7adb3f

Browse files
tf-model-analysis-team and tfx-copybara
tf-model-analysis-team
authored and committed
Refactor BLEU Metric.
PiperOrigin-RevId: 601823244
1 parent 80b2a02 commit d7adb3f

File tree

2 files changed

+284
-85
lines changed

2 files changed

+284
-85
lines changed

tensorflow_model_analysis/metrics/bleu.py

Lines changed: 133 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,51 @@
3232
_BLEU_NAME_DEFAULT = 'BLEU'
3333

3434

35-
# TODO: b/287700355 - Add __slots__ to _Accumulator
@dataclasses.dataclass
class _Accumulator:
  """Accumulator for _BleuCombiner.

  Attributes:
    matching_ngrams: A list containing the number of matching n-grams between
      the hypothesis and the reference for each n. This should be initialized
      as np.zeros(max_ngram_order).
    total_ngrams: A list containing the total number of n-grams for each n.
      Like 'matching_ngrams', this should be initialized as
      np.zeros(max_ngram_order).
    hyp_len: The number of unigrams (words) in the hypothesis.
    ref_len: The number of unigrams (words) in the reference.

  matching_ngrams[n - 1] = number of matching n-grams for n > 0
    matching_ngrams[0] = number of matching unigrams
    matching_ngrams[1] = number of matching bigrams
    ...

  total_ngrams[n - 1] = (
      max(number of n-grams in hyp, number of n-grams in ref) for n > 0
  )
  total_ngrams[] follows same pattern as matching_ngrams[]

  For hypotheses and references, ending punctuation (periods, exclamation
  points, etc.) count as their own unigram.
  For example, 'Google.' has 2 unigrams: 'Google' and '.'.
  """

  matching_ngrams: np.ndarray
  total_ngrams: np.ndarray
  hyp_len: int = 0
  ref_len: int = 0

  def __eq__(self, other):
    # The fields hold np.ndarrays, whose '==' yields element-wise arrays, so
    # the dataclass-generated __eq__ (tuple comparison of fields) would be
    # ambiguous; compare the arrays explicitly instead. Per the dataclasses
    # docs, defining __eq__ in the class body suppresses the generated one.
    if not isinstance(other, _Accumulator):
      # Fix: previously a comparison against a non-_Accumulator raised
      # AttributeError; returning NotImplemented lets Python fall back to
      # its default handling (e.g. '==' evaluates to False).
      return NotImplemented
    return (
        np.array_equal(self.matching_ngrams, other.matching_ngrams)
        and np.array_equal(self.total_ngrams, other.total_ngrams)
        and self.hyp_len == other.hyp_len
        and self.ref_len == other.ref_len
    )
77+
78+
79+
# TODO: b/287700355 - Add __slots__ to this dataclass.
3680
@dataclasses.dataclass
3781
class _RefInfo:
3882
ngrams: collections.Counter[dict[tuple[str], int]] # n-grams and counts
@@ -92,6 +136,36 @@ def __init__(
92136
self.key = key
93137
self.bleu_metric = sacrebleu.BLEU(**bleu_kwargs)
94138

139+
def _extract_statistics_for_empty_reference(
    self, hypotheses: Sequence[str]
) -> list[_Accumulator]:
  """Returns sentence-level statistics when there are no references.

  Args:
    hypotheses: A sequence of hypothesis strings.

  Returns:
    A list of _Accumulators of segment statistics.
  """
  max_order = self.bleu_metric.max_ngram_order

  # Sum hypothesis lengths over the whole batch. extract_all_word_ngrams()
  # also tokenizes n-grams, but only the returned length is needed here.
  total_hyp_len = 0
  for hyp in hypotheses:
    total_hyp_len += sacrebleu.helpers.extract_all_word_ngrams(
        hyp, 1, max_order
    )[1]

  # An empty reference means there are no matching and no total n-grams.
  return [
      _Accumulator(
          matching_ngrams=np.zeros(max_order, dtype=int),
          total_ngrams=np.zeros(max_order, dtype=int),
          hyp_len=total_hyp_len,
      )
  ]
168+
95169
def _preprocess_segment(self, sentence: str) -> str:
96170
"""Given a sentence, lowercases (optionally) and tokenizes it."""
97171
if self.bleu_metric.lowercase:
@@ -153,7 +227,7 @@ def _compute_segment_statistics(
153227
self,
154228
hypothesis: str,
155229
ref_info: _RefInfo,
156-
) -> list[int]:
230+
) -> _Accumulator:
157231
"""Given a (pre-processed) hypothesis sentence and already computed reference n-grams & lengths, returns the best match statistics across the references.
158232
159233
Args:
@@ -162,7 +236,7 @@ def _compute_segment_statistics(
162236
the list of reference lengths.
163237
164238
Returns:
165-
A list of integers with match statistics.
239+
An _Accumulator with match statistics.
166240
"""
167241
# Extract n-grams for the hypothesis.
168242
hyp_ngrams, hyp_len = sacrebleu.helpers.extract_all_word_ngrams(
@@ -173,8 +247,8 @@ def _compute_segment_statistics(
173247

174248
# Count the stats.
175249
# Although counter has its internal & and | operators, this is faster.
176-
matching_ngrams = [0] * self.bleu_metric.max_ngram_order
177-
total_ngrams = matching_ngrams[:]
250+
matching_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int)
251+
total_ngrams = np.zeros(self.bleu_metric.max_ngram_order, dtype=int)
178252

179253
for hyp_ngram, hyp_count in hyp_ngrams.items():
180254
# n-gram order.
@@ -188,14 +262,18 @@ def _compute_segment_statistics(
188262
if hyp_ngram in ref_ngrams:
189263
matching_ngrams[n] += min(hyp_count, ref_ngrams[hyp_ngram])
190264

191-
# Return a flattened list as per 'stats' semantics.
192-
return [hyp_len, ref_len] + matching_ngrams + total_ngrams
265+
return _Accumulator(
266+
matching_ngrams=matching_ngrams,
267+
total_ngrams=total_ngrams,
268+
hyp_len=hyp_len,
269+
ref_len=ref_len,
270+
)
193271

194272
def _extract_corpus_statistics(
195273
self,
196274
hypotheses: Sequence[str],
197-
references: Optional[Sequence[Sequence[str]]],
198-
) -> list[list[int]]:
275+
references: Sequence[Sequence[str]],
276+
) -> list[_Accumulator]:
199277
"""Reads the corpus and returns sentence-level match statistics for faster re-computations esp during statistical tests.
200278
201279
Args:
@@ -205,8 +283,12 @@ def _extract_corpus_statistics(
205283
batch_size_of_hypotheses).
206284
207285
Returns:
208-
A list where each sublist corresponds to segment statistics.
286+
A list of _Accumulators of segment statistics.
209287
"""
288+
if np.all((np.array(references) == [''])):
289+
# Empty Reference.
290+
return self._extract_statistics_for_empty_reference(hypotheses)
291+
210292
stats = []
211293
tok_count = 0
212294

@@ -238,73 +320,48 @@ def _extract_corpus_statistics(
238320

239321
return stats
240322

241-
def _compute_score_from_accumulator(
    self, accumulator: _Accumulator
) -> sacrebleu.BLEUScore:
  """Computes the final score from already aggregated statistics.

  Args:
    accumulator: An accumulator containing segment-level statistics.

  Returns:
    A 'BLEUScore' object.
  """
  metric = self.bleu_metric

  # TODO: b/319702245 - Resolve the issue below in compute_bleu().
  # The np.ndarrays must be converted to lists here.
  # Left as np.ndarrays of ints, sacrebleu cannot add decimal smooth values
  # to the stats list within compute_bleu(). Converted to np.ndarrays of
  # floats, sacrebleu cannot properly set BLEUScore._verbose because there
  # is no format code 'd' for floats.
  return metric.compute_bleu(
      correct=accumulator.matching_ngrams.tolist(),
      total=accumulator.total_ngrams.tolist(),
      sys_len=accumulator.hyp_len,
      ref_len=accumulator.ref_len,
      smooth_method=metric.smooth_method,
      smooth_value=metric.smooth_value,
      effective_order=metric.effective_order,
      max_ngram_order=metric.max_ngram_order,
  )
288353

289354
def create_accumulator(self):
  """Returns an empty _Accumulator with all-zero n-gram statistics."""
  order = self.bleu_metric.max_ngram_order
  return _Accumulator(
      matching_ngrams=np.zeros(order, dtype=int),
      total_ngrams=np.zeros(order, dtype=int),
  )
302359

303360
def add_input(
304361
self,
305-
accumulator: np.ndarray,
362+
accumulator: _Accumulator,
306363
element: metric_types.StandardMetricInputs,
307-
) -> np.ndarray:
364+
) -> _Accumulator:
308365
# references = labels, hypotheses = predictions
309366
references, hypotheses, _ = next(
310367
metric_util.to_label_prediction_example_weight(
@@ -318,28 +375,31 @@ def add_input(
318375
)
319376
)
320377

321-
# Sum accumulator and new stats
322-
return accumulator + np.sum(
323-
self._extract_corpus_statistics(hypotheses, references), axis=0
324-
)
378+
corpus_stats = self._extract_corpus_statistics(hypotheses, references)
379+
corpus_stats.append(accumulator)
380+
381+
return self.merge_accumulators(corpus_stats)
325382

326383
def merge_accumulators(
    self, accumulators: Iterable[_Accumulator]
) -> _Accumulator:
  """Folds all accumulators into the first one and returns it.

  Args:
    accumulators: A non-empty iterable of _Accumulators.

  Returns:
    A single _Accumulator with the element-wise sums of all statistics.
  """
  stream = iter(accumulators)
  merged = next(stream)
  for acc in stream:
    merged.hyp_len += acc.hyp_len
    merged.ref_len += acc.ref_len
    merged.matching_ngrams = merged.matching_ngrams + acc.matching_ngrams
    merged.total_ngrams = merged.total_ngrams + acc.total_ngrams
  return merged
331398

332399
def extract_output(
    self, accumulator: _Accumulator
) -> dict[metric_types.MetricKey, sacrebleu.BLEUScore]:
  """Converts the merged accumulator into the final keyed BLEU score."""
  score = self._compute_score_from_accumulator(accumulator)
  return {self.key: score}
343403

344404

345405
def _bleu(

0 commit comments

Comments
 (0)