Skip to content

Commit 7d7060f

Browse files
icybladefchollet
authored andcommitted
Patch to increase test coverage (keras-team#8902)
* no need to disable vis test in python 3.x this unit test works in my environment * add test for RemoteMonitor * install pydot and graphvis for python 3.x * monkey patch requests.post remove unnecessary code * travis fix * pep8 fix * add test for _remove_long_seq * add test for HDF5Matrix and ask_to_proceed_with_overwrite * typo
1 parent 7591740 commit 7d7060f

File tree

6 files changed

+86
-5
lines changed

6 files changed

+86
-5
lines changed

.travis.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ install:
6666
fi
6767

6868
# install pydot for visualization tests
69-
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
70-
conda install pydot graphviz;
71-
fi
69+
- conda install pydot graphviz
7270

7371
# exclude different backends to measure a coverage for the designated backend only
7472
- if [[ "$KERAS_BACKEND" != "tensorflow" ]]; then

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
'pytest-pep8',
2222
'pytest-xdist',
2323
'pytest-cov',
24-
'pandas'],
24+
'pandas',
25+
'requests'],
2526
},
2627
classifiers=[
2728
'Development Status :: 5 - Production/Stable',

tests/keras/preprocessing/sequence_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.preprocessing.sequence import pad_sequences
77
from keras.preprocessing.sequence import make_sampling_table
88
from keras.preprocessing.sequence import skipgrams
9+
from keras.preprocessing.sequence import _remove_long_seq
910

1011

1112
def test_pad_sequences():
@@ -82,5 +83,17 @@ def test_skipgrams():
8283
assert len(l) == 2
8384

8485

86+
def test_remove_long_seq():
87+
maxlen = 5
88+
seq = [
89+
[1, 2, 3],
90+
[1, 2, 3, 4, 5, 6],
91+
]
92+
label = ['a', 'b']
93+
new_seq, new_label = _remove_long_seq(maxlen, seq, label)
94+
assert new_seq == [[1, 2, 3]]
95+
assert new_label == ['a']
96+
97+
8598
if __name__ == '__main__':
8699
pytest.main([__file__])

tests/keras/test_callbacks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
from keras.utils.test_utils import keras_test
1818
from keras import backend as K
1919
from keras.utils import np_utils
20+
try:
21+
from unittest.mock import patch
22+
except:
23+
from mock import patch
24+
2025

2126
input_dim = 2
2227
num_hidden = 4
@@ -787,5 +792,27 @@ def test_TensorBoard_with_ReduceLROnPlateau(tmpdir):
787792
assert not tmpdir.listdir()
788793

789794

795+
@keras_test
796+
def tests_RemoteMonitor():
797+
(X_train, y_train), (X_test, y_test) = get_test_data(num_train=train_samples,
798+
num_test=test_samples,
799+
input_shape=(input_dim,),
800+
classification=True,
801+
num_classes=num_classes)
802+
y_test = np_utils.to_categorical(y_test)
803+
y_train = np_utils.to_categorical(y_train)
804+
model = Sequential()
805+
model.add(Dense(num_hidden, input_dim=input_dim, activation='relu'))
806+
model.add(Dense(num_classes, activation='softmax'))
807+
model.compile(loss='categorical_crossentropy',
808+
optimizer='rmsprop',
809+
metrics=['accuracy'])
810+
cbks = [callbacks.RemoteMonitor()]
811+
812+
with patch('requests.post'):
813+
model.fit(X_train, y_train, batch_size=batch_size,
814+
validation_data=(X_test, y_test), callbacks=cbks, epochs=1)
815+
816+
790817
if __name__ == '__main__':
791818
pytest.main([__file__])

tests/keras/utils/io_utils_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
'''Tests for functions in io_utils.py.
22
'''
33
import os
4+
import sys
45
import pytest
56
from keras.models import Sequential
67
from keras.layers import Dense
78
from keras.utils.io_utils import HDF5Matrix
9+
from keras.utils.io_utils import ask_to_proceed_with_overwrite
810
import numpy as np
911
import warnings
1012
import h5py
13+
try:
14+
from unittest.mock import patch
15+
except:
16+
from mock import patch
1117

1218

1319
@pytest.fixture
@@ -76,8 +82,45 @@ def test_io_utils(in_tmpdir):
7682
# test slicing for shortened array
7783
assert len(X_train[0:]) == len(X_train), 'Incorrect shape for sliced data'
7884

85+
# test __getitem__
86+
with pytest.raises(IndexError):
87+
X_train[1000]
88+
with pytest.raises(IndexError):
89+
X_train[1000:1001]
90+
with pytest.raises(IndexError):
91+
X_train[[1000, 1001]]
92+
with pytest.raises(IndexError):
93+
X_train[np.array([1000])]
94+
with pytest.raises(IndexError):
95+
X_train[None]
96+
assert (X_train[0] == X_train[:1][0]).all()
97+
assert (X_train[[0, 1]] == X_train[:2]).all()
98+
assert (X_train[np.array([0, 1])] == X_train[:2]).all()
99+
100+
# test normalizer
101+
normalizer = lambda x: x + 1
102+
normalized_X_train = HDF5Matrix(h5_path, 'my_data', start=0, end=150, normalizer=normalizer)
103+
assert np.isclose(normalized_X_train[0][0], X_train[0][0] + 1)
104+
79105
os.remove(h5_path)
80106

81107

108+
def test_ask_to_proceed_with_overwrite():
109+
if sys.version_info[:2] <= (2, 7):
110+
with patch('__builtin__.raw_input') as mock:
111+
mock.return_value = 'y'
112+
assert ask_to_proceed_with_overwrite('/tmp/not_exists')
113+
114+
mock.return_value = 'n'
115+
assert not ask_to_proceed_with_overwrite('/tmp/not_exists')
116+
else:
117+
with patch('builtins.input') as mock:
118+
mock.return_value = 'y'
119+
assert ask_to_proceed_with_overwrite('/tmp/not_exists')
120+
121+
mock.return_value = 'n'
122+
assert not ask_to_proceed_with_overwrite('/tmp/not_exists')
123+
124+
82125
if __name__ == '__main__':
83126
pytest.main([__file__])

tests/keras/utils/vis_utils_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from keras.utils import vis_utils
1212

1313

14-
@pytest.mark.skipif(sys.version_info > (3, 0), reason='pydot-ng currently supports python 3.4')
1514
def test_plot_model():
1615
model = Sequential()
1716
model.add(Conv2D(filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv'))

0 commit comments

Comments
 (0)