Skip to content

Commit 618edbe

Browse files
taehoonleefchollet
authored andcommitted
Add cumprod for CNTK (keras-team#12536)
* Add `cumprod` for CNTK * Fix PEP8 * Remove warnings
1 parent f6a30b0 commit 618edbe

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

keras/backend/cntk_backend.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2746,7 +2746,22 @@ def cumsum(x, axis=0):
27462746

27472747

27482748
def cumprod(x, axis=0):
2749-
raise NotImplementedError
2749+
shape = x.shape
2750+
out = x
2751+
for rep in range(shape[axis] - 1):
2752+
sliced_shape = list(shape)
2753+
sliced_shape[axis] = rep + 1
2754+
if axis == 0:
2755+
_x = x[rep:(rep + 1)]
2756+
elif axis == 1:
2757+
_x = x[:, rep:(rep + 1)]
2758+
elif axis == 2:
2759+
_x = x[:, :, rep:(rep + 1)]
2760+
y = concatenate([ones(sliced_shape, dtype=x.dtype),
2761+
repeat_elements(_x, rep=shape[axis] - 1 - rep, axis=axis)],
2762+
axis=axis)
2763+
out = C.element_times(out, y)
2764+
return out
27502765

27512766

27522767
def arange(start, stop=None, step=1, dtype='int32'):

tests/keras/backend/backend_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,6 @@ def test_cumsum(self):
554554
check_single_tensor_operation('cumsum', (4, 2), WITH_NP)
555555
check_single_tensor_operation('cumsum', (4, 2), WITH_NP, axis=1)
556556

557-
@pytest.mark.skipif(K.backend() == 'cntk', reason='cntk does not support '
558-
'cumprod yet')
559557
def test_cumprod(self):
560558
check_single_tensor_operation('cumprod', (4, 2), WITH_NP)
561559
check_single_tensor_operation('cumprod', (4, 2), WITH_NP, axis=1)

0 commit comments

Comments
 (0)