
Commit c03d5cc

Ali Siddiqui authored and rmlarsen committed
NADAM Optimizer (tensorflow#9889)
* Initial commit for NADAM
* Add GPU and sparse implementations and add missing arguments
* Add tester
* Revert changes to files made for testing
* Add nadam optimizer in a class of its own
* Reverse changes to adam_test.py
* Reverse changes to adam.py
* Actually reverse adam_test.py
* Actually reverse adam.py
* Delete nadam_optimizer_test.py
* Create nadam_optimizer_test.py
* Fix BUILD
* Run buildifier on BUILD
1 parent 15d9f00 commit c03d5cc
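
For context, the update rule this commit implements (Nesterov-accelerated Adam, in the form used by the nadam_update_numpy reference in the new test) can be written as:

\begin{aligned}
m_t       &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t       &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\bar{m}_t &= (1 - \beta_1)\, g_t + \beta_1 m_t \\
\alpha_t  &= \alpha \, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \\
\theta_t  &= \theta_{t-1} - \alpha_t \, \frac{\bar{m}_t}{\sqrt{v_t} + \epsilon}
\end{aligned}

This differs from plain Adam only in the \bar{m}_t term, which replaces m_t in the parameter update; see Dozat (2015), cited in the class docstring below.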

8 files changed: +318 -15 lines changed


tensorflow/contrib/opt/BUILD

Lines changed: 18 additions & 0 deletions
@@ -18,6 +18,7 @@ py_library(
         "python/training/external_optimizer.py",
         "python/training/lazy_adam_optimizer.py",
         "python/training/moving_average_optimizer.py",
+        "python/training/nadam_optimizer.py",
         "python/training/variable_clipping_optimizer.py",
     ],
     srcs_version = "PY2AND3",
@@ -106,6 +107,23 @@ py_test(
     ],
 )
 
+py_test(
+    name = "nadam_optimizer_test",
+    srcs = ["python/training/nadam_optimizer_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":opt_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:variables",
+        "//third_party/py/numpy",
+    ],
+)
+
 tf_py_test(
     name = "drop_stale_gradient_optimizer_test",
     srcs = ["python/training/drop_stale_gradient_optimizer_test.py"],

tensorflow/contrib/opt/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
 from tensorflow.contrib.opt.python.training.external_optimizer import *
 from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
+from tensorflow.contrib.opt.python.training.nadam_optimizer import *
 from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
 from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
 # pylint: enable=wildcard-import
@@ -31,6 +32,7 @@
 _allowed_symbols = ['DropStaleGradientOptimizer',
                     'ExternalOptimizerInterface',
                     'LazyAdamOptimizer',
+                    'NadamOptimizer',
                     'MovingAverageOptimizer',
                     'ScipyOptimizerInterface',
                     'VariableClippingOptimizer']

tensorflow/contrib/opt/python/training/nadam_optimizer.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Nadam for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import training_ops
+from tensorflow.python.training import adam
+
+
+class NadamOptimizer(adam.AdamOptimizer):
+  """Optimizer that implements the Nadam algorithm.
+
+  See [Dozat, T., 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
+  """
+
+  def _apply_dense(self, grad, var):
+    m = self.get_slot(var, "m")
+    v = self.get_slot(var, "v")
+    return training_ops.apply_adam(
+        var, m, v,
+        math_ops.cast(self._beta1_power, var.dtype.base_dtype),
+        math_ops.cast(self._beta2_power, var.dtype.base_dtype),
+        math_ops.cast(self._lr_t, var.dtype.base_dtype),
+        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
+        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
+        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
+        grad, use_locking=self._use_locking,
+        use_nesterov=True).op
+
+  def _resource_apply_dense(self, grad, var):
+    m = self.get_slot(var, "m")
+    v = self.get_slot(var, "v")
+    return training_ops.resource_apply_adam(
+        var.handle, m.handle, v.handle,
+        math_ops.cast(self._beta1_power, grad.dtype.base_dtype),
+        math_ops.cast(self._beta2_power, grad.dtype.base_dtype),
+        math_ops.cast(self._lr_t, grad.dtype.base_dtype),
+        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
+        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
+        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
+        grad, use_locking=self._use_locking,
+        use_nesterov=True)
+
+  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
+    beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
+    beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
+    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+    # m_t = beta1 * m + (1 - beta1) * g_t
+    m = self.get_slot(var, "m")
+    m_scaled_g_values = grad * (1 - beta1_t)
+    m_t = state_ops.assign(m, m * beta1_t,
+                           use_locking=self._use_locking)
+    with ops.control_dependencies([m_t]):
+      m_t = scatter_add(m, indices, m_scaled_g_values)
+    # m_bar = (1 - beta1) * g_t + beta1 * m_t
+    m_bar = m_scaled_g_values + beta1_t * m_t
+    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+    v = self.get_slot(var, "v")
+    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
+    with ops.control_dependencies([v_t]):
+      v_t = scatter_add(v, indices, v_scaled_g_values)
+    v_sqrt = math_ops.sqrt(v_t)
+    var_update = state_ops.assign_sub(var,
+                                      lr * m_bar / (v_sqrt + epsilon_t),
+                                      use_locking=self._use_locking)
+    return control_flow_ops.group(*[var_update, m_bar, v_t])
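
A minimal usage sketch of the new optimizer follows. This is not part of the commit: the variable and loss names are illustrative, and it assumes the TF 1.x graph-mode API plus the contrib export added in the __init__.py change above.

  # Minimal sketch, assuming TF 1.x graph mode; w and loss are illustrative names.
  import tensorflow as tf

  w = tf.Variable([1.0, 2.0])
  loss = tf.reduce_sum(tf.square(w))  # toy quadratic objective

  # Constructor arguments are inherited unchanged from AdamOptimizer.
  opt = tf.contrib.opt.NadamOptimizer(learning_rate=0.001)
  train_op = opt.minimize(loss)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
      sess.run(train_op)  # each run applies one Nadam update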

tensorflow/contrib/opt/python/training/nadam_optimizer_test.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Nadam."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.contrib.opt.python.training import nadam_optimizer
+
+
+def nadam_update_numpy(param,
+                       g_t,
+                       t,
+                       m,
+                       v,
+                       alpha=0.001,
+                       beta1=0.9,
+                       beta2=0.999,
+                       epsilon=1e-8):
+  alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+  m_t = beta1 * m + (1 - beta1) * g_t
+  v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+  m_bar = (1 - beta1) * g_t + beta1 * m_t
+
+  param_t = param - alpha_t * m_bar / (np.sqrt(v_t) + epsilon)
+  return param_t, m_t, v_t
+
+
+class NadamOptimizerTest(test.TestCase):
+
+  def doTestSparse(self, use_resource=False):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.test_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        if use_resource:
+          var0 = resource_variable_ops.ResourceVariable(var0_np)
+          var1 = resource_variable_ops.ResourceVariable(var1_np)
+        else:
+          var0 = variables.Variable(var0_np)
+          var1 = variables.Variable(var1_np)
+        grads0_np_indices = np.array([0, 1], dtype=np.int32)
+        grads0 = ops.IndexedSlices(
+            constant_op.constant(grads0_np),
+            constant_op.constant(grads0_np_indices), constant_op.constant([2]))
+        grads1_np_indices = np.array([0, 1], dtype=np.int32)
+        grads1 = ops.IndexedSlices(
+            constant_op.constant(grads1_np),
+            constant_op.constant(grads1_np_indices), constant_op.constant([2]))
+        opt = nadam_optimizer.NadamOptimizer()
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Nadam
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+          self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+          update.run()
+
+          var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(var0_np, var0.eval())
+          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+  def testSparse(self):
+    self.doTestSparse(use_resource=False)
+
+  def testResourceSparse(self):
+    self.doTestSparse(use_resource=True)
+
+  def doTestBasic(self, use_resource=False):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.test_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        if use_resource:
+          var0 = resource_variable_ops.ResourceVariable(var0_np)
+          var1 = resource_variable_ops.ResourceVariable(var1_np)
+        else:
+          var0 = variables.Variable(var0_np)
+          var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+        opt = nadam_optimizer.NadamOptimizer()
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Nadam
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+          self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+          update.run()
+
+          var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(var0_np, var0.eval())
+          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+  def testBasic(self):
+    self.doTestBasic(use_resource=False)
+
+  def testResourceBasic(self):
+    self.doTestBasic(use_resource=True)
+
+if __name__ == "__main__":
+  test.main()
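
The new py_test target can be exercised with Bazel from a configured TensorFlow source checkout, e.g. bazel test //tensorflow/contrib/opt:nadam_optimizer_test (target name taken from the BUILD change above).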

tensorflow/core/kernels/training_ops.cc

Lines changed: 18 additions & 4 deletions
@@ -245,12 +245,22 @@ struct ApplyAdamNonCuda {
                   typename TTypes<T>::ConstScalar beta1,
                   typename TTypes<T>::ConstScalar beta2,
                   typename TTypes<T>::ConstScalar epsilon,
-                  typename TTypes<T>::ConstFlat grad) {
+                  typename TTypes<T>::ConstFlat grad,
+                  bool use_nesterov) {
     const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
                     (T(1) - beta1_power());
+    // beta1 == μ
+    // beta2 == ν
+    // v     == n
+    // var   == θ
+
     m.device(d) += (grad - m) * (T(1) - beta1());
     v.device(d) += (grad.square() - v) * (T(1) - beta2());
-    var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
+    if (use_nesterov) {
+      var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) / (v.sqrt() + epsilon());
+    } else {
+      var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
+    }
   }
 };
 
@@ -2248,6 +2258,7 @@ class ApplyAdamOp : public OpKernel {
  public:
   explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -2322,13 +2333,15 @@ class ApplyAdamOp : public OpKernel {
         v.flat<T>(), beta1_power.scalar<T>(),
         beta2_power.scalar<T>(), lr.scalar<T>(),
         beta1.scalar<T>(), beta2.scalar<T>(),
-        epsilon.scalar<T>(), grad.flat<T>());
+        epsilon.scalar<T>(), grad.flat<T>(),
+        use_nesterov_);
 
     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }
 
  private:
   bool use_exclusive_lock_;
+  bool use_nesterov_;
 };
 
 using CPUDevice = Eigen::ThreadPoolDevice;
@@ -2372,7 +2385,8 @@ namespace functor {
       typename TTypes<T>::ConstScalar beta1,               \
       typename TTypes<T>::ConstScalar beta2,               \
       typename TTypes<T>::ConstScalar epsilon,             \
-      typename TTypes<T>::ConstFlat grad);                 \
+      typename TTypes<T>::ConstFlat grad,                  \
+      bool use_nesterov);                                  \
   extern template struct ApplyAdam<GPUDevice, T>;
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);
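
For readers less comfortable with Eigen expressions, the two branches added to ApplyAdamNonCuda are transcribed to NumPy below. This is an illustrative sketch only (one step, made-up values), not code from the commit.

  # NumPy transcription of the use_nesterov branch vs. plain Adam in ApplyAdamNonCuda.
  import numpy as np

  beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 0.001
  beta1_power, beta2_power = beta1, beta2          # accumulator values after one step
  m, v = np.zeros(2), np.zeros(2)
  var = np.array([1.0, 2.0])
  grad = np.array([0.1, 0.1])

  alpha = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
  m += (grad - m) * (1 - beta1)                    # m <- beta1*m + (1-beta1)*grad
  v += (grad**2 - v) * (1 - beta2)                 # v <- beta2*v + (1-beta2)*grad^2

  use_nesterov = True
  if use_nesterov:
    # Nadam: apply the lookahead combination (1-beta1)*grad + beta1*m
    var -= ((grad * (1 - beta1) + beta1 * m) * alpha) / (np.sqrt(v) + eps)
  else:
    # Plain Adam: use m directly
    var -= (m * alpha) / (np.sqrt(v) + eps)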

tensorflow/core/kernels/training_ops.h

Lines changed: 2 additions & 1 deletion
@@ -123,7 +123,8 @@ struct ApplyAdam {
                   typename TTypes<T>::ConstScalar beta1,
                   typename TTypes<T>::ConstScalar beta2,
                   typename TTypes<T>::ConstScalar epsilon,
-                  typename TTypes<T>::ConstFlat grad);
+                  typename TTypes<T>::ConstFlat grad,
+                  bool use_nesterov);
 };
 
 template <typename Device, typename T>
