Commit 08b0fd9

Increase coverage in optimizers (#21337)

* Update .gitignore
* Add tests for invalid parameters in optimizers: AdamW, Ftrl, Lion, Nadam, and SGD
* Ran formatting
* Revert "Update .gitignore" (reverts commit ae56fe7)

1 parent 802299e commit 08b0fd9

File tree

5 files changed: +85 -0 lines changed

keras/src/optimizers/adamw_test.py

Lines changed: 8 additions & 0 deletions

@@ -52,6 +52,14 @@ def test_weight_decay(self):
         self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
         self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)
 
+    def test_weight_decay_is_none(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Argument `weight_decay` must be a float. "
+            "Received: weight_decay=None",
+        ):
+            AdamW(learning_rate=1.0, weight_decay=None)
+
     def test_correctness_with_golden(self):
         optimizer = AdamW(learning_rate=1.0, weight_decay=0.5, epsilon=2)
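The new `test_weight_decay_is_none` case pins down the exact error text for a non-float `weight_decay`. A minimal standalone sketch of the kind of guard it exercises (this free function is a hypothetical illustration; the real check lives inside the AdamW constructor):

```python
def check_weight_decay(weight_decay):
    # Hypothetical standalone form of the constructor guard: anything
    # that is not a float (including None) is rejected with the exact
    # message the test's regex matches.
    if not isinstance(weight_decay, float):
        raise ValueError(
            "Argument `weight_decay` must be a float. "
            f"Received: weight_decay={weight_decay}"
        )


check_weight_decay(0.004)  # OK
try:
    check_weight_decay(None)
except ValueError as e:
    # Prints: Argument `weight_decay` must be a float.
    # Received: weight_decay=None
    print(e)
```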

keras/src/optimizers/ftrl_test.py

Lines changed: 41 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 
 import numpy as np
+from unittest import mock
 
 from keras.src import backend
 from keras.src import testing
@@ -71,3 +72,43 @@ def test_clip_value(self):
         grad = [np.array([100.0, 100.0])]
         clipped_grad = optimizer._clip_gradients(grad)
         self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    def test_invalid_initial_accumulator_value(self):
+        invalid_value = -0.1
+        with self.assertRaisesRegex(
+            ValueError,
+            f"^`initial_accumulator_value` needs to be positive or zero. Received: initial_accumulator_value={invalid_value}.$",
+        ):
+            Ftrl(initial_accumulator_value=invalid_value)
+
+    def test_invalid_learning_rate_power(self):
+        invalid_value = 0.1
+        with self.assertRaisesRegex(
+            ValueError,
+            f"^`learning_rate_power` needs to be negative or zero. Received: learning_rate_power={invalid_value}.$",
+        ):
+            Ftrl(learning_rate_power=invalid_value)
+
+    def test_invalid_l1_regularization_strength(self):
+        invalid_value = -0.1
+        with self.assertRaisesRegex(
+            ValueError,
+            f"^`l1_regularization_strength` needs to be positive or zero. Received: l1_regularization_strength={invalid_value}.$",
+        ):
+            Ftrl(l1_regularization_strength=invalid_value)
+
+    def test_invalid_l2_regularization_strength(self):
+        invalid_value = -0.1
+        with self.assertRaisesRegex(
+            ValueError,
+            f"^`l2_regularization_strength` needs to be positive or zero. Received: l2_regularization_strength={invalid_value}.$",
+        ):
+            Ftrl(l2_regularization_strength=invalid_value)
+
+    def test_invalid_l2_shrinkage_regularization_strength(self):
+        invalid_value = -0.1
+        with self.assertRaisesRegex(
+            ValueError,
+            f"^`l2_shrinkage_regularization_strength` needs to be positive or zero. Received: l2_shrinkage_regularization_strength={invalid_value}.$",
+        ):
+            Ftrl(l2_shrinkage_regularization_strength=invalid_value)
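All five new Ftrl cases share one shape: a constructor argument with a sign constraint, and an anchored regex (`^...$`) that pins the full error message. A hedged sketch of a helper that would produce messages in this format (the helper and its name are hypothetical; the real checks sit in Ftrl's constructor):

```python
def check_sign_constraint(name, value, non_negative=True):
    """Hypothetical helper mirroring the constraints the tests assert."""
    if non_negative and value < 0.0:
        # initial_accumulator_value and the regularization strengths
        # must be >= 0.
        raise ValueError(
            f"`{name}` needs to be positive or zero. "
            f"Received: {name}={value}."
        )
    if not non_negative and value > 0.0:
        # learning_rate_power must be <= 0.
        raise ValueError(
            f"`{name}` needs to be negative or zero. "
            f"Received: {name}={value}."
        )


for name, value, non_negative in [
    ("initial_accumulator_value", -0.1, True),
    ("learning_rate_power", 0.1, False),
]:
    try:
        check_sign_constraint(name, value, non_negative)
    except ValueError as e:
        print(e)  # the full string each anchored regex must match
```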

keras/src/optimizers/lion_test.py

Lines changed: 20 additions & 0 deletions

@@ -9,6 +9,26 @@
 
 
 class LionTest(testing.TestCase):
+    def test_invalid_beta_1(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Argument `beta_1` must be in the \\[0, 1\\] range. Otherwise, the "
+            "optimizer degenerates to SignSGD. Received: beta_1=-0.1.",
+        ):
+            Lion(beta_1=-0.1)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Argument `beta_1` must be in the \\[0, 1\\] range. Otherwise, the "
+            "optimizer degenerates to SignSGD. Received: beta_1=0.0.",
+        ):
+            Lion(beta_1=0.0)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Argument `beta_1` must be in the \\[0, 1\\] range. Otherwise, the "
+            "optimizer degenerates to SignSGD. Received: beta_1=1.1.",
+        ):
+            Lion(beta_1=1.1)
+
     def test_config(self):
         optimizer = Lion(
             learning_rate=0.5,
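One detail worth noticing: the test expects a ValueError for `beta_1=0.0` as well as for the out-of-range values, so the accepted interval appears to be (0, 1] even though the message reads "[0, 1]". A sketch of a guard consistent with all three rejected values (an assumption for illustration; the real condition lives in Lion's constructor):

```python
def check_beta_1(beta_1):
    # Consistent with the three cases the test covers: -0.1, 0.0, and
    # 1.1 are all rejected, so accepted values satisfy 0 < beta_1 <= 1.
    if beta_1 <= 0.0 or beta_1 > 1.0:
        raise ValueError(
            "Argument `beta_1` must be in the [0, 1] range. Otherwise, "
            f"the optimizer degenerates to SignSGD. "
            f"Received: beta_1={beta_1}."
        )


for value in (-0.1, 0.0, 1.1):
    try:
        check_beta_1(value)
    except ValueError as e:
        print(e)  # each prints the message the test's regex matches
```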

keras/src/optimizers/nadam_test.py

Lines changed: 5 additions & 0 deletions

@@ -19,6 +19,11 @@ def test_config(self):
         )
         self.run_class_serialization_test(optimizer)
 
+    def test_build_with_empty_var_list(self):
+        optimizer = Nadam()
+        optimizer.build([])
+        self.assertEqual(optimizer._u_product.dtype, backend.floatx())
+
     def test_single_step(self):
         optimizer = Nadam(learning_rate=0.5)
         grads = ops.array([1.0, 6.0, 7.0, 2.0])
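This covers a real edge case: building the optimizer against an empty variable list must still initialize Nadam's `_u_product` accumulator (the running product of the momentum schedule) in the default float dtype. Assuming a working Keras 3 install, the checked behavior can be reproduced via the public API (`_u_product` is private; it is referenced here only because the test does):

```python
import keras

optimizer = keras.optimizers.Nadam()
optimizer.build([])  # no variables to track, but state is still created

# Mirrors the test's assertion: the accumulator exists and uses the
# configured default float type.
print(optimizer._u_product.dtype)  # e.g. "float32"
print(keras.config.floatx())       # the same default dtype
```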

keras/src/optimizers/sgd_test.py

Lines changed: 11 additions & 0 deletions

@@ -30,6 +30,17 @@ def test_single_step(self):
         self.assertEqual(optimizer.variables[0], 1)
         self.assertEqual(optimizer.variables[1], 0.5)
 
+    def test_invalid_momentum(self):
+        with self.assertRaisesRegex(
+            ValueError, "`momentum` must be a float between \\[0, 1\\]."
+        ):
+            SGD(momentum=-1.0)
+
+        with self.assertRaisesRegex(
+            ValueError, "`momentum` must be a float between \\[0, 1\\]."
+        ):
+            SGD(momentum=2.0)
+
     def test_weight_decay(self):
         grads, var1, var2, var3 = (
             ops.zeros(()),
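The SGD test asserts the same message for both out-of-range values, one per side of the interval. A minimal sketch of a guard matching that behavior (a hypothetical standalone form of the check in SGD's constructor):

```python
def check_momentum(momentum):
    # Both SGD(momentum=-1.0) and SGD(momentum=2.0) in the test hit the
    # same message, so a closed-interval range check covers both sides.
    if momentum < 0.0 or momentum > 1.0:
        raise ValueError("`momentum` must be a float between [0, 1].")


check_momentum(0.9)  # OK: within [0, 1]
try:
    check_momentum(2.0)
except ValueError as e:
    print(e)  # `momentum` must be a float between [0, 1].
```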
