32
32
_BLEU_NAME_DEFAULT = 'BLEU'
33
33
34
34
35
- # TODO: b/287700355) - Add __slots__ to this dataclass.
35
+ # TODO: b/287700355 - Add __slots__ to _Accumulator
36
+ @dataclasses .dataclass
37
+ class _Accumulator :
38
+ """Accumulator for _BleuCombiner.
39
+
40
+ Attributes:
41
+ matching_ngrams: A list containing the number of matching n-grams between
42
+ the hypothesis and the reference for each n. This should be initialized as
43
+ np.zeros(max_ngram_order).
44
+ total_ngrams: A list containing the total number of n-grams for each n. Like
45
+ 'matching_ngrams', this should be initialized as an
46
+ np.zeros(max_ngram_order).
47
+ hyp_len: The number of unigrams (words) in the hypothesis.
48
+ ref_len: The number of unigrams (words) in the reference.
49
+
50
+ matching_ngrams[n - 1] = number of matching n-grams for n > 0
51
+ matching_ngrams[0] = number of matching unigrams
52
+ matching_ngrams[1] = number of matching bigrams
53
+ ...
54
+
55
+ total_ngrams[n - 1] = (
56
+ max(number of n-grams in hyp, number of n-grams in ref) for n > 0
57
+ )
58
+ total_ngrams[] follows same pattern as matching_ngrams[]
59
+
60
+ For hypotheses and references, ending punctuation (periods, exclamation
61
+ points, etc.) count as their own unigram.
62
+ For example, 'Google.' has 2 unigrams: 'Google' and '.'.
63
+ """
64
+
65
+ matching_ngrams : np .ndarray
66
+ total_ngrams : np .ndarray
67
+ hyp_len : int = 0
68
+ ref_len : int = 0
69
+
70
+ def __eq__ (self , other ):
71
+ return (
72
+ np .array_equal (self .matching_ngrams , other .matching_ngrams )
73
+ and np .array_equal (self .total_ngrams , other .total_ngrams )
74
+ and self .hyp_len == other .hyp_len
75
+ and self .ref_len == other .ref_len
76
+ )
77
+
78
+
79
+ # TODO: b/287700355 - Add __slots__ to this dataclass.
36
80
@dataclasses .dataclass
37
81
class _RefInfo :
38
82
ngrams : collections .Counter [dict [tuple [str ], int ]] # n-grams and counts
@@ -92,6 +136,36 @@ def __init__(
92
136
self .key = key
93
137
self .bleu_metric = sacrebleu .BLEU (** bleu_kwargs )
94
138
139
def _extract_statistics_for_empty_reference(
    self, hypotheses: Sequence[str]
) -> list[_Accumulator]:
  """Returns sentence-level statistics when there are no references.

  Args:
    hypotheses: A sequence of hypothesis strings.

  Returns:
    A list of _Accumulators of segment statistics.
  """
  order = self.bleu_metric.max_ngram_order

  # Only the hypothesis lengths matter here; sum the unigram count of every
  # hypothesis. extract_all_word_ngrams returns (ngrams, length).
  total_hyp_len = sum(
      sacrebleu.helpers.extract_all_word_ngrams(hypothesis, 1, order)[1]
      for hypothesis in hypotheses
  )

  # With no reference there are no matching and no total n-grams.
  return [
      _Accumulator(
          matching_ngrams=np.zeros(order, dtype=int),
          total_ngrams=np.zeros(order, dtype=int),
          hyp_len=total_hyp_len,
      )
  ]
95
169
def _preprocess_segment (self , sentence : str ) -> str :
96
170
"""Given a sentence, lowercases (optionally) and tokenizes it."""
97
171
if self .bleu_metric .lowercase :
@@ -153,7 +227,7 @@ def _compute_segment_statistics(
153
227
self ,
154
228
hypothesis : str ,
155
229
ref_info : _RefInfo ,
156
- ) -> list [ int ] :
230
+ ) -> _Accumulator :
157
231
"""Given a (pre-processed) hypothesis sentence and already computed reference n-grams & lengths, returns the best match statistics across the references.
158
232
159
233
Args:
@@ -162,7 +236,7 @@ def _compute_segment_statistics(
162
236
the list of reference lengths.
163
237
164
238
Returns:
165
- A list of integers with match statistics.
239
+ An _Accumulator with match statistics.
166
240
"""
167
241
# Extract n-grams for the hypothesis.
168
242
hyp_ngrams , hyp_len = sacrebleu .helpers .extract_all_word_ngrams (
@@ -173,8 +247,8 @@ def _compute_segment_statistics(
173
247
174
248
# Count the stats.
175
249
# Although counter has its internal & and | operators, this is faster.
176
- matching_ngrams = [ 0 ] * self .bleu_metric .max_ngram_order
177
- total_ngrams = matching_ngrams [:]
250
+ matching_ngrams = np . zeros ( self .bleu_metric .max_ngram_order , dtype = int )
251
+ total_ngrams = np . zeros ( self . bleu_metric . max_ngram_order , dtype = int )
178
252
179
253
for hyp_ngram , hyp_count in hyp_ngrams .items ():
180
254
# n-gram order.
@@ -188,14 +262,18 @@ def _compute_segment_statistics(
188
262
if hyp_ngram in ref_ngrams :
189
263
matching_ngrams [n ] += min (hyp_count , ref_ngrams [hyp_ngram ])
190
264
191
- # Return a flattened list as per 'stats' semantics.
192
- return [hyp_len , ref_len ] + matching_ngrams + total_ngrams
265
+ return _Accumulator (
266
+ matching_ngrams = matching_ngrams ,
267
+ total_ngrams = total_ngrams ,
268
+ hyp_len = hyp_len ,
269
+ ref_len = ref_len ,
270
+ )
193
271
194
272
def _extract_corpus_statistics (
195
273
self ,
196
274
hypotheses : Sequence [str ],
197
- references : Optional [ Sequence [Sequence [str ] ]],
198
- ) -> list [list [ int ] ]:
275
+ references : Sequence [Sequence [str ]],
276
+ ) -> list [_Accumulator ]:
199
277
"""Reads the corpus and returns sentence-level match statistics for faster re-computations esp during statistical tests.
200
278
201
279
Args:
@@ -205,8 +283,12 @@ def _extract_corpus_statistics(
205
283
batch_size_of_hypotheses).
206
284
207
285
Returns:
208
- A list where each sublist corresponds to segment statistics.
286
+ A list of _Accumulators of segment statistics.
209
287
"""
288
+ if np .all ((np .array (references ) == ['' ])):
289
+ # Empty Reference.
290
+ return self ._extract_statistics_for_empty_reference (hypotheses )
291
+
210
292
stats = []
211
293
tok_count = 0
212
294
@@ -238,73 +320,48 @@ def _extract_corpus_statistics(
238
320
239
321
return stats
240
322
241
- def _compute_score_from_stats (self , stats : list [int ]) -> sacrebleu .BLEUScore :
323
+ def _compute_score_from_accumulator (
324
+ self , accumulator : _Accumulator
325
+ ) -> sacrebleu .BLEUScore :
242
326
"""Computes the final score from already aggregated statistics.
243
327
244
- 'stats' semantics are preserved here from the wrapped implementation.
245
- stats = [hyp_len, ref_len, matching_ngrams, total_ngrams] where
246
- hyp_len = number of unigrams (words) in the hypothesis
247
- ref_len = number of unigrams (words) in the reference
248
- Note, ending punctuation (periods, exclamation points, etc.) count as
249
- their own unigram.
250
- For example, 'Google.' has 2 unigrams: 'Google' and '.'
251
- matching_ngrams[n - 1] = number of matching n-grams for n > 0
252
- matching_ngrams[0] = number of matching unigrams
253
- matching_ngrams[1] = number of matching bigrams
254
- ...
255
- total_ngrams[n - 1] = number of n-grams in hyp for n > 0
256
- total_ngrams[] follows same pattern as matching_ngrams[]
257
-
258
328
Args:
259
- stats: A list of segment-level statistics.
329
+ accumulator: An accumulator containing segment-level statistics.
260
330
261
331
Returns:
262
332
A 'BLEUScore' object.
263
333
"""
264
334
bleu_metric = self .bleu_metric
265
335
266
- # matching_ngrams[n - 1] = number of matching n-grams for n > 0
267
- matching_ngrams = stats [2 : 2 + bleu_metric .max_ngram_order ]
268
-
269
- # total_ngrams[n - 1] = number of n-grams in hyp for n > 0
270
- total_ngrams = stats [2 + bleu_metric .max_ngram_order :]
271
-
272
- # hyp_len = number of unigrams (words) in the hypothesis
273
- hyp_len = int (stats [0 ])
274
-
275
- # ref_len = number of unigrams (words) in the reference
276
- ref_len = int (stats [1 ])
277
-
336
+ # TODO: b/319702245 - Resolve the issue below in compute_bleu().
337
+ # We need to convert the np.ndarray's to a lists here.
338
+ # If we leave it as a np.ndarray of ints, then sacrebleu will not be able to
339
+ # add decimal smooth values to the stats list within compute_bleu().
340
+ # If we convert it to an np.ndarray of floats, then sacrebleu will not be
341
+ # able to propely set BLEUScore._verbose because there is no format code 'd'
342
+ # for floats.
278
343
return self .bleu_metric .compute_bleu (
279
- correct = matching_ngrams ,
280
- total = total_ngrams ,
281
- sys_len = hyp_len ,
282
- ref_len = ref_len ,
344
+ correct = accumulator . matching_ngrams . tolist () ,
345
+ total = accumulator . total_ngrams . tolist () ,
346
+ sys_len = accumulator . hyp_len ,
347
+ ref_len = accumulator . ref_len ,
283
348
smooth_method = bleu_metric .smooth_method ,
284
349
smooth_value = bleu_metric .smooth_value ,
285
350
effective_order = bleu_metric .effective_order ,
286
351
max_ngram_order = bleu_metric .max_ngram_order ,
287
352
)
288
353
289
354
def create_accumulator(self):
  """Returns an all-zeros _Accumulator sized to this combiner's n-gram order."""
  order = self.bleu_metric.max_ngram_order
  return _Accumulator(
      matching_ngrams=np.zeros(order, dtype=int),
      total_ngrams=np.zeros(order, dtype=int),
  )
+ )
302
359
303
360
def add_input (
304
361
self ,
305
- accumulator : np . ndarray ,
362
+ accumulator : _Accumulator ,
306
363
element : metric_types .StandardMetricInputs ,
307
- ) -> np . ndarray :
364
+ ) -> _Accumulator :
308
365
# references = labels, hypotheses = predictions
309
366
references , hypotheses , _ = next (
310
367
metric_util .to_label_prediction_example_weight (
@@ -318,28 +375,31 @@ def add_input(
318
375
)
319
376
)
320
377
321
- # Sum accumulator and new stats
322
- return accumulator + np . sum (
323
- self . _extract_corpus_statistics ( hypotheses , references ), axis = 0
324
- )
378
+ corpus_stats = self . _extract_corpus_statistics ( hypotheses , references )
379
+ corpus_stats . append ( accumulator )
380
+
381
+ return self . merge_accumulators ( corpus_stats )
325
382
326
383
def merge_accumulators(
    self, accumulators: Iterable[_Accumulator]
) -> _Accumulator:
  """Merges accumulators by summing their statistics element-wise.

  The first accumulator is reused (mutated) as the merge target, which is
  permitted for combiner merge steps.

  Args:
    accumulators: An iterable of _Accumulators to merge.

  Returns:
    A single _Accumulator holding the summed statistics.
  """
  iterator = iter(accumulators)
  # Fix: an empty iterable previously raised StopIteration from
  # next(); return a fresh zero-valued accumulator instead.
  result = next(iterator, None)
  if result is None:
    return self.create_accumulator()
  for accumulator in iterator:
    result.hyp_len += accumulator.hyp_len
    result.ref_len += accumulator.ref_len
    # Direct ndarray addition replaces the equivalent but more costly
    # np.sum([a, b], axis=0).
    result.matching_ngrams = result.matching_ngrams + accumulator.matching_ngrams
    result.total_ngrams = result.total_ngrams + accumulator.total_ngrams
  return result
331
398
332
399
def extract_output(
    self, accumulator: _Accumulator
) -> dict[metric_types.MetricKey, sacrebleu.BLEUScore]:
  """Converts the merged accumulator into the final keyed BLEU score."""
  score = self._compute_score_from_accumulator(accumulator)
  return {self.key: score}
343
403
344
404
345
405
def _bleu (
0 commit comments