Skip to content

Commit 0f4fec3

Browse files
moustaki authored and fchollet committed
Handling ndim == 2 in TF batch_dot (keras-team#5280)
* Handling ndim == 2 in TF batch_dot
* Adding support for swapped axes when ndim == 2
1 parent ff1f796 commit 0f4fec3

File tree

3 files changed

+35
-27
lines changed

3 files changed

+35
-27
lines changed

keras/backend/tensorflow_backend.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,6 @@ def batch_dot(x, y, axes=None):
831831
832832
# Arguments
833833
x, y: Keras tensors or variables with `ndim >= 2`
834-
(With TensorFlow backend, `batch_dot()` only supports `ndim >= 3`)
835834
axes: list of (or single) int with target dimensions.
836835
The lengths of `axes[0]` and `axes[1]` should be the same.
837836
@@ -870,24 +869,28 @@ def batch_dot(x, y, axes=None):
870869
(32, 1, 30)
871870
```
872871
"""
873-
if ndim(x) < 3 or ndim(y) < 3:
874-
raise ValueError('Invalid dimensions for batch_dot: ', ndim(x), ndim(y))
875872
if isinstance(axes, int):
876873
axes = (axes, axes)
877-
if axes is not None:
878-
adj_x = None if axes[0] == ndim(x) - 1 else True
879-
adj_y = True if axes[1] == ndim(y) - 1 else None
880-
else:
881-
adj_x = None
882-
adj_y = None
883-
# TODO: remove later.
884-
if hasattr(tf, 'batch_matmul'):
885-
try:
886-
out = tf.batch_matmul(x, y, adj_a=adj_x, adj_b=adj_y)
887-
except TypeError:
888-
out = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
874+
if ndim(x) == 2 and ndim(y) == 2:
875+
if axes[0] == axes[1]:
876+
out = tf.reduce_sum(tf.mul(x, y), axes[0])
877+
else:
878+
out = tf.reduce_sum(tf.mul(tf.transpose(x, [1, 0]), y), axes[1])
889879
else:
890-
out = tf.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
880+
if axes is not None:
881+
adj_x = None if axes[0] == ndim(x) - 1 else True
882+
adj_y = True if axes[1] == ndim(y) - 1 else None
883+
else:
884+
adj_x = None
885+
adj_y = None
886+
# TODO: remove later.
887+
if hasattr(tf, 'batch_matmul'):
888+
try:
889+
out = tf.batch_matmul(x, y, adj_a=adj_x, adj_b=adj_y)
890+
except TypeError:
891+
out = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
892+
else:
893+
out = tf.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
891894
if ndim(out) == 1:
892895
out = expand_dims(out, 1)
893896
return out

tests/keras/backend/test_backends.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,26 @@ def test_linear_operations(self):
7777

7878
check_two_tensor_operation('batch_dot', (4, 2, 3), (4, 5, 3),
7979
axes=(2, 2))
80+
check_two_tensor_operation('batch_dot', (32, 20), (32, 20), axes=1)
81+
check_two_tensor_operation('batch_dot', (32, 20), (32, 20), axes=(1, 1))
8082
check_single_tensor_operation('transpose', (4, 2))
8183
check_single_tensor_operation('reverse', (4, 3, 2), axes=1)
8284
check_single_tensor_operation('reverse', (4, 3, 2), axes=(1, 2))
8385

8486
def test_batch_dot_shape(self):
85-
with pytest.raises(ValueError):
86-
x_batch = KTF.ones(shape=(32, 20))
87-
y_batch = KTF.ones(shape=(32, 20))
88-
xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=1)
87+
x_batch = KTF.ones(shape=(32, 20))
88+
y_batch = KTF.ones(shape=(32, 20))
89+
xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=1)
90+
assert_allclose(KTF.eval(xy_batch_dot), np.ones((32, 1)) * 20, atol=1e-05)
91+
xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=0)
92+
assert_allclose(KTF.eval(xy_batch_dot), np.ones((20, 1)) * 32, atol=1e-05)
93+
# making sure swapping axes when ndim == 2 works
94+
x_batch = KTF.ones(shape=(32, 20))
95+
y_batch = KTF.ones(shape=(20, 32))
96+
xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=(0, 1))
97+
assert_allclose(KTF.eval(xy_batch_dot), np.ones((20, 1)) * 32, atol=1e-05)
98+
xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=(1, 0))
99+
assert_allclose(KTF.eval(xy_batch_dot), np.ones((32, 1)) * 20, atol=1e-05)
89100

90101
def test_shape_operations(self):
91102
# concatenate

tests/keras/test_sequential_model.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from keras import backend as K
99
from keras.models import Sequential
10-
from keras.layers.core import Dense, Activation, Merge, Lambda, Reshape
10+
from keras.layers.core import Dense, Activation, Merge, Lambda
1111
from keras.utils import np_utils
1212
from keras.utils.test_utils import get_test_data, keras_test
1313
from keras.models import model_from_json, model_from_yaml
@@ -287,35 +287,29 @@ def test_merge_dot():
287287

288288
left = Sequential()
289289
left.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
290-
left.add(Reshape((nb_hidden, 1)))
291290
left.add(Activation('relu'))
292291

293292
right = Sequential()
294293
right.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
295-
right.add(Reshape((nb_hidden, 1)))
296294
right.add(Activation('relu'))
297295

298296
model = Sequential()
299297
model.add(Merge([left, right], mode='dot', dot_axes=1))
300-
model.add(Reshape((1,)))
301298
model.add(Dense(nb_class))
302299
model.add(Activation('softmax'))
303300

304301
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
305302

306303
left = Sequential()
307304
left.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
308-
left.add(Reshape((nb_hidden, 1)))
309305
left.add(Activation('relu'))
310306

311307
right = Sequential()
312308
right.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
313-
right.add(Reshape((nb_hidden, 1)))
314309
right.add(Activation('relu'))
315310

316311
model = Sequential()
317312
model.add(Merge([left, right], mode='dot', dot_axes=[1, 1]))
318-
model.add(Reshape((1,)))
319313
model.add(Dense(nb_class))
320314
model.add(Activation('softmax'))
321315

0 commit comments

Comments (0)