
Commit 20d5683

Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
PiperOrigin-RevId: 210648271
1 parent 069f808 commit 20d5683

4 files changed: +154 -47 lines
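In essence, the change keeps the L2-shrinkage term in the gradient that drives the update, but keeps it out of the squared-gradient accumulator that drives the per-coordinate learning-rate schedule. A minimal NumPy sketch of just that distinction (illustrative values, not the kernel code):

import numpy as np

grad = np.array([0.1, 0.2])
var = np.array([1.0, 2.0])
accum = np.full_like(var, 0.1)
l2_shrinkage = 0.1

grad_with_shrinkage = grad + 2.0 * l2_shrinkage * var
# Before this commit the accumulator absorbed the shrinkage term:
#   accum += grad_with_shrinkage ** 2
# After this commit only the raw gradient feeds the accumulator, so
# l2_shrinkage no longer changes the learning-rate schedule:
accum += grad ** 2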

tensorflow/compiler/tests/ftrl_test.py

Lines changed: 42 additions & 2 deletions
@@ -259,9 +259,49 @@ def testFtrlWithL1_L2_L2Shrinkage(self):
 
       # Validate updated params
       self.assertAllCloseAccordingToType(
-          np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4)
+          np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
       self.assertAllCloseAccordingToType(
-          np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4)
+          np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+
+  def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+    """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+    for dtype in self.float_types:
+      with self.test_session(), self.test_scope():
+        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+        var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+        opt0 = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0,
+            l2_shrinkage_regularization_strength=0.1)
+        opt1 = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0)
+        update0 = opt0.apply_gradients([(grads0, var0)])
+        update1 = opt1.apply_gradients([(grads1, var1)])
+        variables.global_variables_initializer().run()
+
+        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+        self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update0.run()
+          update1.run()
+
+        # var0 is experiencing L2 shrinkage so it should be smaller than var1
+        # in magnitude.
+        self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
+        accum0 = list(opt0._slots["accum"].values())[0].eval()
+        accum1 = list(opt1._slots["accum"].values())[0].eval()
+        # L2 shrinkage should not change how we update grad accumulator.
+        self.assertAllCloseAccordingToType(accum0, accum1)
 
   # When variables are initialized with Zero, FTRL-Proximal has two properties:
   # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical

tensorflow/compiler/tf2xla/kernels/training_ops.cc

Lines changed: 2 additions & 2 deletions
@@ -688,7 +688,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
   }
 
   // grad_to_use = grad + 2 * l2_shrinkage * var
-  // new_accum = accum + grad_to_use * grad_to_use
+  // new_accum = accum + grad * grad
   // linear += grad_to_use -
   //     (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
   // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
@@ -704,7 +704,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
     grad_to_use = grad;
   }
 
-  xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+  xla::XlaOp new_accum = accum + xla::Square(grad);
   xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
   xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
   linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
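Spelled out, the commented equations above amount to the following single dense FTRL step. This is a NumPy transcription for illustration only (hyperparameter values are arbitrary), not the XLA implementation:

import numpy as np

lr, lr_power = 3.0, -0.5
l1, l2, l2_shrinkage = 0.001, 2.0, 0.1

var = np.array([1.0, 2.0])
accum = np.full_like(var, 0.1)
linear = np.zeros_like(var)
grad = np.array([0.1, 0.2])

grad_to_use = grad + 2.0 * l2_shrinkage * var
new_accum = accum + grad * grad            # shrinkage term excluded (the fix)
linear += grad_to_use - (new_accum ** (-lr_power) - accum ** (-lr_power)) / lr * var
quadratic = new_accum ** (-lr_power) / lr + 2.0 * l2
# Proximal L1 step: coordinates with |linear| <= l1 snap to zero.
var = np.where(np.abs(linear) > l1,
               (np.sign(linear) * l1 - linear) / quadratic, 0.0)
accum = new_accum

With lr_power = -0.5, new_accum ** (-lr_power) is simply sqrt(new_accum), which is the special case the CPU kernel below handles separately.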

tensorflow/core/kernels/training_ops.cc

Lines changed: 25 additions & 27 deletions
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 
 #define EIGEN_USE_THREADS
-
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 
 #include <algorithm>
@@ -201,7 +200,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
                   typename TTypes<T>::ConstScalar l2_shrinkage,
                   typename TTypes<T>::ConstScalar lr_power) {
     auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
-    auto new_accum = accum + grad_with_shrinkage.square();
+    auto new_accum = accum + grad * grad;
     // special case for which lr_power=-0.5.
     if (lr_power() == static_cast<T>(-0.5)) {
       linear.device(d) +=
@@ -226,7 +225,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
       var.device(d) = (linear.abs() > linear.constant(l1()))
                           .select(pre_shrink, var.constant(static_cast<T>(0)));
     }
-    accum.device(d) += grad_with_shrinkage.square();
+    accum.device(d) += grad * grad;
   }
 };
 
@@ -2167,15 +2166,15 @@ class SparseApplyFtrlOp : public OpKernel {
 
   // Use a macro to implement the computation here due to the templating of the
   // eigen tensor library.
-#define COMPUTE_FTRL(grad_to_use)                                            \
-  auto new_accum = accum + grad_to_use.square();                             \
+#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage)                        \
+  auto new_accum = accum + grad.square();                                    \
   if (lr_power_scalar == static_cast<T>(-0.5)) {                             \
-    linear +=                                                                \
-        grad_to_use - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;   \
+    linear += grad_maybe_with_shrinkage -                                    \
+              (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;           \
   } else {                                                                   \
-    linear += grad_to_use - (new_accum.pow(-lr_power_scalar) -               \
-                             accum.pow(-lr_power_scalar)) /                  \
-                                 lr_scalar * var;                            \
+    linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
+                                           accum.pow(-lr_power_scalar)) /    \
+                                              lr_scalar * var;               \
   }                                                                          \
   auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar);      \
   auto x = l1_reg_adjust - linear;                                           \
@@ -2188,14 +2187,14 @@ class SparseApplyFtrlOp : public OpKernel {
               linear.constant(static_cast<T>(2) * l2_scalar);               \
     var = x / y;                                                            \
   }                                                                         \
-  accum += grad_to_use.square();
+  accum += grad.square();
 
         if (has_l2_shrinkage) {
           auto grad_with_shrinkage =
               grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
-          COMPUTE_FTRL(grad_with_shrinkage);
+          COMPUTE_FTRL(grad, grad_with_shrinkage);
         } else {
-          COMPUTE_FTRL(grad);
+          COMPUTE_FTRL(grad, grad);
         }
       }
 #undef COMPUTE_FTRL
@@ -2228,12 +2227,12 @@ class SparseApplyFtrlOp : public OpKernel {
           T g;
           if (has_l2_shrinkage) {
             g = grad_flat(i) +
-                (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(i));
+                (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index));
           } else {
             g = grad_flat(i);
           }
 
-          T updated_a = a + g * g;
+          T updated_a = a + grad_flat(i) * grad_flat(i);
           using Eigen::numext::pow;
           T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
           sigma /= lr_scalar;
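The sparse scalar path above mirrors the dense math row by row, and it carries a second fix: the shrinkage term now reads the variable row actually being updated (var_flat(index)) rather than indexing with the gradient row (var_flat(i)). A rough Python sketch of the intent (illustrative only; the L1/variable update that follows is omitted):

import numpy as np

lr, lr_power, l2_shrinkage = 3.0, -0.5, 0.1
var = np.array([1.0, 2.0])
accum = np.full_like(var, 0.1)
linear = np.zeros_like(var)
grad = np.array([0.1])        # one sparse gradient row...
indices = np.array([1])       # ...that updates variable row 1

for i, index in enumerate(indices):
    g = grad[i] + 2.0 * l2_shrinkage * var[index]   # shrinkage uses the updated row
    a = accum[index]
    new_a = a + grad[i] * grad[i]                   # accumulator sees only the raw gradient
    sigma = (new_a ** (-lr_power) - a ** (-lr_power)) / lr
    linear[index] += g - sigma * var[index]
    accum[index] = new_a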
@@ -2856,9 +2855,8 @@ class ApplyAdaMaxOp : public OpKernel {
     const Device& device = ctx->template eigen_device<Device>();
     functor::ApplyAdaMax<Device, T>()(
         device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
-        beta1_power.scalar<T>(), lr.scalar<T>(),
-        beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
-        grad.flat<T>());
+        beta1_power.scalar<T>(), lr.scalar<T>(), beta1.scalar<T>(),
+        beta2.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
 
     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }
@@ -2867,16 +2865,16 @@ class ApplyAdaMaxOp : public OpKernel {
   bool use_exclusive_lock_;
 };
 
-#define REGISTER_KERNELS(D, T)                                          \
-  REGISTER_KERNEL_BUILDER(                                              \
+#define REGISTER_KERNELS(D, T)                                        \
+  REGISTER_KERNEL_BUILDER(                                            \
       Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"),  \
       ApplyAdaMaxOp<D##Device, T>);                                   \
   REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax")                 \
-                              .HostMemory("var")                        \
-                              .HostMemory("m")                          \
-                              .HostMemory("v")                          \
-                              .Device(DEVICE_##D)                       \
-                              .TypeConstraint<T>("T"),                  \
+                              .HostMemory("var")                      \
+                              .HostMemory("m")                        \
+                              .HostMemory("v")                        \
+                              .Device(DEVICE_##D)                     \
+                              .TypeConstraint<T>("T"),                \
                           ApplyAdaMaxOp<D##Device, T>);
 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
 
@@ -2889,15 +2887,15 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
 namespace functor {
 #define DECLARE_GPU_SPEC(T)                                      \
   template <>                                                    \
-  void ApplyAdaMax<GPUDevice, T>::operator()(                      \
+  void ApplyAdaMax<GPUDevice, T>::operator()(                    \
       const GPUDevice& d, typename TTypes<T>::Flat var,          \
       typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,    \
       typename TTypes<T>::ConstScalar beta1_power,               \
       typename TTypes<T>::ConstScalar lr,                        \
       typename TTypes<T>::ConstScalar beta1,                     \
       typename TTypes<T>::ConstScalar beta2,                     \
       typename TTypes<T>::ConstScalar epsilon,                   \
-      typename TTypes<T>::ConstFlat grad);                         \
+      typename TTypes<T>::ConstFlat grad);                       \
   extern template struct ApplyAdaMax<GPUDevice, T>;
 DECLARE_GPU_SPEC(Eigen::half);
 DECLARE_GPU_SPEC(float);

tensorflow/python/training/ftrl_test.py

Lines changed: 85 additions & 16 deletions
@@ -117,8 +117,7 @@ def testMinimizeSparseResourceVariable(self):
       # Run 1 step of sgd
       sgd_op.run()
       # Validate updated params
-      self.assertAllCloseAccordingToType(
-          [[0, 1]], var0.eval(), atol=0.01)
+      self.assertAllCloseAccordingToType([[0, 1]], var0.eval(), atol=0.01)
 
   def testFtrlWithL1(self):
     for dtype in [dtypes.half, dtypes.float32]:
@@ -212,24 +211,96 @@ def testFtrlWithL1_L2_L2Shrinkage(self):
 
       v0_val, v1_val = sess.run([var0, var1])
       self.assertAllCloseAccordingToType(
-          np.array([-0.22078767, -0.41378114]), v0_val)
+          np.array([-0.22578995, -0.44345796]), v0_val)
       self.assertAllCloseAccordingToType(
-          np.array([-0.02919818, -0.07343706]), v1_val)
+          np.array([-0.14378493, -0.13229476]), v1_val)
+
+  def testFtrlWithL1_L2_L2ShrinkageSparse(self):
+    """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
+    for dtype in [dtypes.half, dtypes.float32]:
+      with self.test_session() as sess:
+        var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+        var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
+        grads0 = ops.IndexedSlices(
+            constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+            constant_op.constant([0]), constant_op.constant([2, 1]))
+        grads1 = ops.IndexedSlices(
+            constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+            constant_op.constant([1]), constant_op.constant([2, 1]))
+
+        opt = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0,
+            l2_shrinkage_regularization_strength=0.1)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        v0_val, v1_val = sess.run([var0, var1])
+        self.assertAllCloseAccordingToType([[1.0], [2.0]], v0_val)
+        self.assertAllCloseAccordingToType([[4.0], [3.0]], v1_val)
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update.run()
+
+        v0_val, v1_val = sess.run([var0, var1])
+        self.assertAllCloseAccordingToType([[-0.22578995], [2.]], v0_val)
+        self.assertAllCloseAccordingToType([[4.], [-0.13229476]], v1_val)
+
+  def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+    """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+    for dtype in [dtypes.half, dtypes.float32]:
+      with self.test_session() as sess:
+        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+        var1 = variables.Variable([1.0, 2.0], dtype=dtype)
+        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+        grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+        opt0 = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0,
+            l2_shrinkage_regularization_strength=0.1)
+        opt1 = ftrl.FtrlOptimizer(
+            3.0,
+            initial_accumulator_value=0.1,
+            l1_regularization_strength=0.001,
+            l2_regularization_strength=2.0)
+        update0 = opt0.apply_gradients([(grads0, var0)])
+        update1 = opt1.apply_gradients([(grads1, var1)])
+        variables.global_variables_initializer().run()
+
+        v0_val, v1_val = sess.run([var0, var1])
+        self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+        self.assertAllCloseAccordingToType([1.0, 2.0], v1_val)
+
+        # Run 10 steps FTRL
+        for _ in range(10):
+          update0.run()
+          update1.run()
+
+        v0_val, v1_val = sess.run([var0, var1])
+        # var0 is experiencing L2 shrinkage so it should be smaller than var1
+        # in magnitude.
+        self.assertTrue((v0_val**2 < v1_val**2).all())
+        accum0 = list(sess.run(opt0._slots)["accum"].values())[0]
+        accum1 = list(sess.run(opt1._slots)["accum"].values())[0]
+        # L2 shrinkage should not change how we update grad accumulator.
+        self.assertAllCloseAccordingToType(accum0, accum1)
 
   def applyOptimizer(self, opt, dtype, steps=5, is_sparse=False):
     if is_sparse:
       var0 = variables.Variable([[0.0], [0.0]], dtype=dtype)
       var1 = variables.Variable([[0.0], [0.0]], dtype=dtype)
       grads0 = ops.IndexedSlices(
-          constant_op.constant(
-              [0.1], shape=[1, 1], dtype=dtype),
-          constant_op.constant([0]),
-          constant_op.constant([2, 1]))
+          constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+          constant_op.constant([0]), constant_op.constant([2, 1]))
       grads1 = ops.IndexedSlices(
-          constant_op.constant(
-              [0.02], shape=[1, 1], dtype=dtype),
-          constant_op.constant([1]),
-          constant_op.constant([2, 1]))
+          constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+          constant_op.constant([1]), constant_op.constant([2, 1]))
     else:
       var0 = variables.Variable([0.0, 0.0], dtype=dtype)
       var1 = variables.Variable([0.0, 0.0], dtype=dtype)
@@ -277,8 +348,7 @@ def testEquivAdagradwithoutRegularization(self):
 
     with self.test_session():
       val2, val3 = self.applyOptimizer(
-          adagrad.AdagradOptimizer(
-              3.0, initial_accumulator_value=0.1), dtype)
+          adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
 
     self.assertAllCloseAccordingToType(val0, val2)
     self.assertAllCloseAccordingToType(val1, val3)
@@ -299,8 +369,7 @@ def testEquivSparseAdagradwithoutRegularization(self):
 
     with self.test_session():
       val2, val3 = self.applyOptimizer(
-          adagrad.AdagradOptimizer(
-              3.0, initial_accumulator_value=0.1),
+          adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
           dtype,
           is_sparse=True)
 
