
Commit 63c7f29

Optimize Adam and AdaMax Optimizers
1 parent 52993b1 commit 63c7f29

File tree

5 files changed (+16, -7 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+- 2.3.3
+    - Optimize Adam and AdaMax Optimizers
+
 - 2.3.2
     - Update PHP Stemmer to version 3
 

src/NeuralNet/Optimizers/AdaMax.php

Lines changed: 4 additions & 2 deletions
@@ -82,8 +82,10 @@ public function step(Parameter $param, Tensor $gradient) : Tensor
     {
         [$velocity, $norm] = $this->cache[$param->id()];
 
-        $velocity = $velocity->multiply(1.0 - $this->momentumDecay)
-            ->add($gradient->multiply($this->momentumDecay));
+        $vHat = $gradient->subtract($velocity)
+            ->multiply($this->momentumDecay);
+
+        $velocity = $velocity->add($vHat);
 
         $norm = $norm->multiply(1.0 - $this->normDecay);
 
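The change above rewrites the exponential moving average of the gradient from (1 - β)·v + β·g into the algebraically equivalent interpolation form v + β·(g - v). A minimal plain-float sketch of that equivalence, using scalar values instead of the Tensor objects the optimizer actually operates on and an illustrative decay of 0.1:

<?php

// Scalar sketch of the AdaMax velocity update before and after this commit.
// Values are illustrative only; the optimizer itself works on Tensor objects.
$momentumDecay = 0.1;

$velocity = 0.05;  // running estimate of the gradient (first moment)
$gradient = 0.25;  // current gradient

// Old form: scale both terms explicitly.
$old = $velocity * (1.0 - $momentumDecay) + $gradient * $momentumDecay;

// New form: move the estimate a fraction of the way toward the gradient.
$new = $velocity + ($gradient - $velocity) * $momentumDecay;

var_dump(abs($old - $new) < 1e-8); // bool(true)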

src/NeuralNet/Optimizers/Adam.php

Lines changed: 8 additions & 4 deletions
@@ -122,11 +122,15 @@ public function step(Parameter $param, Tensor $gradient) : Tensor
     {
         [$velocity, $norm] = $this->cache[$param->id()];
 
-        $velocity = $velocity->multiply(1.0 - $this->momentumDecay)
-            ->add($gradient->multiply($this->momentumDecay));
+        $vHat = $gradient->subtract($velocity)
+            ->multiply($this->momentumDecay);
 
-        $norm = $norm->multiply(1.0 - $this->normDecay)
-            ->add($gradient->square()->multiply($this->normDecay));
+        $velocity = $velocity->add($vHat);
+
+        $nHat = $gradient->square()->subtract($norm)
+            ->multiply($this->normDecay);
+
+        $norm = $norm->add($nHat);
 
         $this->cache[$param->id()] = [$velocity, $norm];
 
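Adam applies the same interpolation rewrite to both cached statistics: the velocity tracks the gradient and the norm tracks the squared gradient. A scalar sketch of one cache update under both formulations; the decay constants below are illustrative values, not necessarily the library defaults:

<?php

// Scalar sketch of the Adam cache update before and after this commit.
// Decay constants and inputs are illustrative values only.
$momentumDecay = 0.1;
$normDecay = 0.001;

$gradient = 0.25;
$velocity = 0.05; // first moment estimate
$norm = 0.02;     // second moment estimate

// Old form: explicit (1 - decay) scaling of the running estimates.
$velocityOld = $velocity * (1.0 - $momentumDecay) + $gradient * $momentumDecay;
$normOld = $norm * (1.0 - $normDecay) + ($gradient ** 2) * $normDecay;

// New form: interpolate each estimate toward the current gradient statistic.
$velocityNew = $velocity + ($gradient - $velocity) * $momentumDecay;
$normNew = $norm + (($gradient ** 2) - $norm) * $normDecay;

var_dump(abs($velocityOld - $velocityNew) < 1e-8); // bool(true)
var_dump(abs($normOld - $normNew) < 1e-8);         // bool(true)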

src/NeuralNet/Snapshotter.php

Whitespace-only changes.

tests/NeuralNet/Optimizers/AdamTest.php

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ public function step(Parameter $param, Tensor $gradient, array $expected) : void
 
         $step = $this->optimizer->step($param, $gradient);
 
-        $this->assertEquals($expected, $step->asArray());
+        $this->assertEqualsWithDelta($expected, $step->asArray(), 1e-8);
     }
 
     /**
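The rewritten updates are algebraically equivalent to the old ones but associate the floating-point operations differently, so the computed steps can drift from the hard-coded expectations by a few units in the last place; comparing with assertEqualsWithDelta and a tolerance of 1e-8 absorbs that. A small sketch of why strict equality can fail, with arbitrary values:

<?php

// Two algebraically identical expressions can round differently.
$beta = 0.1;
$v = 0.123456789;
$g = 0.987654321;

$old = (1.0 - $beta) * $v + $beta * $g;
$new = $v + $beta * ($g - $v);

var_dump($old === $new);            // may be false: different rounding order
var_dump(abs($old - $new) <= 1e-8); // bool(true)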
