Commit 624c00b

feat: add Categorical Generalized Cross Entropy (GCE) loss (keras-team#21024)
* feat: add Categorical Generalized Cross Entropy (GCE) loss
* run api generation
* docs: Align docstrings with Keras style guide
* docs: more docstring changes
1 parent 33d97b0 commit 624c00b
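
For orientation, a minimal sketch of how the new loss is meant to be used once a Keras 3 build includes this commit. The export name `keras.losses.CategoricalGeneralizedCrossEntropy` comes from the `keras_export` decorator in the diff below; the model, data, and `q` values here are illustrative only.

```python
import numpy as np
import keras

# Standalone call: y_true holds integer class indices, y_pred holds probabilities.
y_true = np.array([0, 1, 0, 1])
y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])
gce = keras.losses.CategoricalGeneralizedCrossEntropy(q=0.5)
print(float(gce(y_true, y_pred)))  # mean of (1 - p**q) / q over the batch

# As a training loss: the model must output probabilities (e.g. softmax),
# since the loss reads off the probability assigned to the true class.
model = keras.Sequential(
    [
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(2, activation="softmax"),
    ]
)
model.compile(
    optimizer="adam",
    loss=keras.losses.CategoricalGeneralizedCrossEntropy(q=0.7),
)
```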

File tree

4 files changed: 340 additions & 0 deletions


keras/api/_tf_keras/keras/losses/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from keras.src.losses.losses import BinaryFocalCrossentropy
 from keras.src.losses.losses import CategoricalCrossentropy
 from keras.src.losses.losses import CategoricalFocalCrossentropy
+from keras.src.losses.losses import CategoricalGeneralizedCrossEntropy
 from keras.src.losses.losses import CategoricalHinge
 from keras.src.losses.losses import Circle
 from keras.src.losses.losses import CosineSimilarity
@@ -34,6 +35,7 @@
 from keras.src.losses.losses import binary_focal_crossentropy
 from keras.src.losses.losses import categorical_crossentropy
 from keras.src.losses.losses import categorical_focal_crossentropy
+from keras.src.losses.losses import categorical_generalized_cross_entropy
 from keras.src.losses.losses import categorical_hinge
 from keras.src.losses.losses import circle
 from keras.src.losses.losses import cosine_similarity

keras/api/losses/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from keras.src.losses.losses import BinaryFocalCrossentropy
 from keras.src.losses.losses import CategoricalCrossentropy
 from keras.src.losses.losses import CategoricalFocalCrossentropy
+from keras.src.losses.losses import CategoricalGeneralizedCrossEntropy
 from keras.src.losses.losses import CategoricalHinge
 from keras.src.losses.losses import Circle
 from keras.src.losses.losses import CosineSimilarity
@@ -33,6 +34,7 @@
 from keras.src.losses.losses import binary_focal_crossentropy
 from keras.src.losses.losses import categorical_crossentropy
 from keras.src.losses.losses import categorical_focal_crossentropy
+from keras.src.losses.losses import categorical_generalized_cross_entropy
 from keras.src.losses.losses import categorical_hinge
 from keras.src.losses.losses import circle
 from keras.src.losses.losses import cosine_similarity
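
Both generated API stubs re-export the same two symbols, so after this change the class and the plain function should both resolve on the public `keras.losses` namespace. A hedged import check (the numbers are illustrative, and the mapping of the `keras/api` stubs onto `keras.losses` is assumed here rather than shown in the diff):

```python
import numpy as np
from keras.losses import (
    CategoricalGeneralizedCrossEntropy,
    categorical_generalized_cross_entropy,
)

y_true = np.array([0, 1])
y_pred = np.array([[0.9, 0.1], [0.3, 0.7]])

per_sample = categorical_generalized_cross_entropy(y_true, y_pred, q=0.5)  # shape [2]
reduced = CategoricalGeneralizedCrossEntropy(q=0.5)(y_true, y_pred)  # scalar after reduction
```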

keras/src/losses/losses.py

Lines changed: 131 additions & 0 deletions
@@ -1504,6 +1504,86 @@ def get_config(self):
         return config
 
 
+@keras_export("keras.losses.CategoricalGeneralizedCrossEntropy")
+class CategoricalGeneralizedCrossEntropy(LossFunctionWrapper):
+    """Computes the Generalized Cross Entropy loss between `y_true` & `y_pred`.
+
+    Generalized Cross Entropy (GCE) is a noise-robust loss function
+    that provides better robustness against noisy labels than
+    standard cross entropy.
+    It generalizes both cross entropy and mean absolute error through
+    the parameter `q`, where values closer to 1 make the loss more robust
+    to noisy labels.
+
+    Formula:
+    ```python
+    loss = (1 - p**q) / q
+    ```
+    where `p` is the predicted probability for the true class and `q`
+    is the noise parameter.
+
+    Args:
+        q: Float in range `(0, 1)`. It is the noise parameter.
+            Controls the behavior of the loss:
+            - As `q` approaches 0: behaves more like cross entropy.
+            - As `q` approaches 1: behaves more like mean absolute error.
+            Defaults to `0.5`.
+        reduction: Type of reduction to apply to the loss. In almost all cases
+            this should be `"sum_over_batch_size"`. Supported options are
+            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
+            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
+            sample size, and `"mean_with_sample_weight"` sums the loss and
+            divides by the sum of the sample weights. `"none"` and `None`
+            perform no aggregation. Defaults to `"sum_over_batch_size"`.
+        name: Optional name for the loss instance.
+        dtype: The dtype of the loss's computations. Defaults to `None`, which
+            means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+            `"float32"` unless set to a different value
+            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+            provided, then the `compute_dtype` will be utilized.
+
+    Example:
+    ```python
+    y_true = np.array([0, 1, 0, 1])
+    y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])
+    keras.losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred)
+    ```
+
+    References:
+        - [Zhang, Sabuncu, 2018](https://arxiv.org/abs/1805.07836)
+          ("Generalized Cross Entropy Loss for Training
+          Deep Neural Networks with Noisy Labels")
+    """
+
+    def __init__(
+        self,
+        q=0.5,
+        reduction="sum_over_batch_size",
+        name="categorical_generalized_cross_entropy",
+        dtype=None,
+    ):
+        if not 0 < q < 1:
+            raise ValueError("q must be in the interval (0, 1)")
+        super().__init__(
+            categorical_generalized_cross_entropy,
+            name=name,
+            reduction=reduction,
+            dtype=dtype,
+            q=q,
+        )
+        self.q = q
+
+    def get_config(self):
+        config = Loss.get_config(self)
+        config.update(
+            {
+                "q": self.q,
+            }
+        )
+        return config
+
+
 def convert_binary_labels_to_hinge(y_true):
     """Converts binary labels into -1/1 for hinge loss/metric calculation."""
     are_zeros = ops.equal(y_true, 0)
@@ -2609,3 +2689,54 @@ def circle(
     circle_loss = ops.softplus(p_loss + n_loss)
     backend.set_keras_mask(circle_loss, circle_loss > 0)
     return circle_loss
+
+
+@keras_export("keras.losses.categorical_generalized_cross_entropy")
+def categorical_generalized_cross_entropy(y_true, y_pred, q):
+    """Computes the Generalized Cross Entropy loss.
+
+    Generalized Cross Entropy (GCE) is a noise-robust loss function that
+    provides better robustness against noisy labels than standard cross
+    entropy.
+    It generalizes both cross entropy and mean absolute error through
+    the parameter `q`, where values closer to 1 make the loss more robust
+    to noisy labels.
+
+    Formula:
+    ```python
+    loss = (1 - p**q) / q
+    ```
+    where `p` is the predicted probability for the true class and `q`
+    is the noise parameter.
+
+    Args:
+        y_true: Ground truth labels. Expected to contain *integer class
+            indices* with shape `[batch_size]` or `[batch_size, 1]`.
+        y_pred: The predicted class probabilities, with shape
+            `[batch_size, num_classes]`.
+        q: Float in range `(0, 1)`. It is the noise parameter.
+            Controls the behavior of the loss:
+            - As `q` approaches 0: behaves more like cross entropy.
+            - As `q` approaches 1: behaves more like mean absolute error.
+
+    Returns:
+        GCE loss values with shape `[batch_size]`.
+
+    References:
+        - [Zhang, Sabuncu, 2018](https://arxiv.org/abs/1805.07836)
+          ("Generalized Cross Entropy Loss for Training
+          Deep Neural Networks with Noisy Labels")
+    """
+
+    # Convert y_true to integer type and one-hot encode
+    y_true_one_hot = ops.one_hot(
+        ops.cast(y_true, "int"), num_classes=ops.shape(y_pred)[-1]
+    )
+    y_true_one_hot = ops.cast(y_true_one_hot, y_pred.dtype)
+    # Calculate the probability of the true class
+    p = ops.sum(y_pred * y_true_one_hot, axis=-1)
+
+    # Compute the GCE loss for q in (0, 1)
+    gce_loss = (1 - ops.power(p, q)) / q
+
+    return gce_loss
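
As a sanity check on the formula, the per-sample values and the default `"sum_over_batch_size"` reduction can be reproduced with plain NumPy. This mirrors the docstring example; it is a verification sketch, not part of the commit:

```python
import numpy as np

q = 0.5
y_true = np.array([0, 1, 0, 1])
y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])

# Probability assigned to the true class of each sample.
p = y_pred[np.arange(len(y_true)), y_true]  # [0.7, 0.8, 0.6, 0.6]

per_sample = (1.0 - p**q) / q  # approx [0.327, 0.211, 0.451, 0.451]
print(per_sample.mean())       # approx 0.360, the "sum_over_batch_size" result
```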

keras/src/losses/losses_test.py

Lines changed: 205 additions & 0 deletions
@@ -1763,3 +1763,208 @@ def test_dtype_arg(self):
         circle_loss = losses.Circle(dtype="bfloat16")
         loss = circle_loss(self.y_true, self.y_pred)
         self.assertDType(loss, "bfloat16")
+
+
+class CategoricalGeneralizedCrossEntropyTest(testing.TestCase):
+    def test_config(self):
+        self.run_class_serialization_test(
+            losses.CategoricalGeneralizedCrossEntropy(name="gce")
+        )
+        self.run_class_serialization_test(
+            losses.CategoricalGeneralizedCrossEntropy(q=0.1, name="gce")
+        )
+
+    def test_basic_correctness_for_binary(self):
+        y_true = np.array([0, 1, 0, 1])
+        y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])
+        # Calculate expected GCE loss manually
+        # For q=0.5:
+        # First sample (class 0): gce = (1 - 0.7^0.5) / 0.5
+        # Second sample (class 1): gce = (1 - 0.8^0.5) / 0.5
+        # Third sample (class 0): gce = (1 - 0.6^0.5) / 0.5
+        # Fourth sample (class 1): gce = (1 - 0.6^0.5) / 0.5
+        expected = np.array(
+            [
+                (1 - np.power(0.7, 0.5)) / 0.5,
+                (1 - np.power(0.8, 0.5)) / 0.5,
+                (1 - np.power(0.6, 0.5)) / 0.5,
+                (1 - np.power(0.6, 0.5)) / 0.5,
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred)
+        self.assertAllClose(output, expected.sum() / len(expected))
+
+        expected_q_08 = np.array(
+            [
+                (1 - np.power(0.7, 0.8)) / 0.8,
+                (1 - np.power(0.8, 0.8)) / 0.8,
+                (1 - np.power(0.6, 0.8)) / 0.8,
+                (1 - np.power(0.6, 0.8)) / 0.8,
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy(q=0.8)(
+            y_true, y_pred
+        )
+        self.assertAllClose(output, expected_q_08.sum() / len(expected_q_08))
+
+    def test_basic_correctness_for_multi_class(self):
+        y_true = np.array([0, 1, 0, 1])
+        y_pred = np.array(
+            [[0.7, 0.3, 0.0], [0.2, 0.2, 0.6], [0.6, 0.4, 0.0], [0.2, 0.2, 0.6]]
+        )
+        # Calculate expected GCE loss manually
+        # For q=0.5:
+        # First sample (class 0): gce = (1 - 0.7^0.5) / 0.5
+        # Second sample (class 1): gce = (1 - 0.2^0.5) / 0.5
+        # Third sample (class 0): gce = (1 - 0.6^0.5) / 0.5
+        # Fourth sample (class 1): gce = (1 - 0.2^0.5) / 0.5
+        expected = np.array(
+            [
+                (1 - np.power(0.7, 0.5)) / 0.5,
+                (1 - np.power(0.2, 0.5)) / 0.5,
+                (1 - np.power(0.6, 0.5)) / 0.5,
+                (1 - np.power(0.2, 0.5)) / 0.5,
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred)
+        self.assertAllClose(output, expected.sum() / len(expected))
+
+        expected_q_08 = np.array(
+            [
+                (1 - np.power(0.7, 0.8)) / 0.8,
+                (1 - np.power(0.2, 0.8)) / 0.8,
+                (1 - np.power(0.6, 0.8)) / 0.8,
+                (1 - np.power(0.2, 0.8)) / 0.8,
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy(q=0.8)(
+            y_true, y_pred
+        )
+        self.assertAllClose(output, expected_q_08.sum() / len(expected_q_08))
+
+    def test_binary_segmentation(self):
+        y_true = np.array(
+            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]
+        )
+        y_pred = np.array(
+            [
+                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],
+                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
+                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
+                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(
+            y_true, y_pred
+        )
+        self.assertAllClose(output, 0.0)
+
+        y_true = np.array(
+            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]
+        )
+        y_pred = np.array(
+            [
+                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]],
+                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
+                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
+                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]],
+            ]
+        )
+        expected = np.array(
+            [
+                (1 - np.power(0.2, 0.5)) / 0.5,
+                (1 - np.power(0.4, 0.5)) / 0.5,
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(
+            y_true, y_pred
+        )
+        self.assertAllClose(output, expected.sum() / 16.0)  # 16 pixels
+
+    def test_multi_class_segmentation(self):
+        y_true = np.array(
+            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]
+        )
+        y_pred = np.array(
+            [
+                [
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.0, 0.0, 1.0],
+                    [1.0, 0.0, 0.0],
+                ],
+                [
+                    [0.0, 1.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                ],
+                [
+                    [1.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                ],
+                [
+                    [0.0, 1.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                ],
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(
+            y_true, y_pred
+        )
+        self.assertAllClose(output, 0.0)
+
+        y_true = np.array(
+            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]
+        )
+        y_pred = np.array(
+            [
+                [
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.0, 0.0, 1.0],
+                    [0.2, 0.0, 0.8],
+                ],
+                [
+                    [1.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                ],
+                [
+                    [1.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                ],
+                [
+                    [0.0, 1.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.5, 0.5, 0.0],
+                    [0.0, 1.0, 0.0],
+                ],
+            ]
+        )
+        expected = np.array(
+            [
+                (1 - np.power(0.2, 0.5)) / 0.5,
+                (1 - np.power(0.0, 0.5)) / 0.5,
+                (1 - np.power(0.5, 0.5)) / 0.5,
+            ]
+        )
+        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(
+            y_true, y_pred
+        )
+        self.assertAllClose(output, expected.sum() / 16.0)  # 16 pixels
+
+    def test_dtype_arg(self):
+        y_true = np.array([0, 1, 0, 1])
+        y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])
+        output = losses.CategoricalGeneralizedCrossEntropy(dtype="bfloat16")(
+            y_true, y_pred
+        )
+        self.assertDType(output, "bfloat16")
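
The tests pin specific values; the limiting behaviour described in the docstrings can also be checked numerically. As `q` approaches 0, `(1 - p**q) / q` approaches `-log(p)` (the cross-entropy contribution of the true class), and at `q = 1` it reduces to `1 - p`, a mean-absolute-error-style penalty. A small NumPy sketch of that interpolation (the class constructor restricts `q` to the open interval `(0, 1)`, so the endpoints are evaluated on the bare formula here):

```python
import numpy as np

p = np.array([0.9, 0.5, 0.1])  # probability assigned to the true class


def gce(p, q):
    # Bare GCE formula, without the (0, 1) validation done by the Keras class.
    return (1.0 - p**q) / q


print(gce(p, 1e-6))  # approx [0.105, 0.693, 2.303], close to -log(p)
print(-np.log(p))    # [0.105, 0.693, 2.303]
print(gce(p, 1.0))   # [0.1, 0.5, 0.9], i.e. 1 - p
```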
