|
1 | 1 | '''Tests for functions in io_utils.py. |
2 | 2 | ''' |
3 | 3 | import os |
| 4 | +import sys |
4 | 5 | import pytest |
5 | 6 | from keras.models import Sequential |
6 | 7 | from keras.layers import Dense |
7 | 8 | from keras.utils.io_utils import HDF5Matrix |
| 9 | +from keras.utils.io_utils import ask_to_proceed_with_overwrite |
8 | 10 | import numpy as np |
9 | 11 | import warnings |
10 | 12 | import h5py |
| 13 | +try: |
| 14 | + from unittest.mock import patch |
| 15 | +except: |
| 16 | + from mock import patch |
11 | 17 |
|
12 | 18 |
|
13 | 19 | @pytest.fixture |
@@ -76,8 +82,45 @@ def test_io_utils(in_tmpdir): |
76 | 82 | # test slicing for shortened array |
77 | 83 | assert len(X_train[0:]) == len(X_train), 'Incorrect shape for sliced data' |
78 | 84 |
|
| 85 | + # test __getitem__ |
| 86 | + with pytest.raises(IndexError): |
| 87 | + X_train[1000] |
| 88 | + with pytest.raises(IndexError): |
| 89 | + X_train[1000:1001] |
| 90 | + with pytest.raises(IndexError): |
| 91 | + X_train[[1000, 1001]] |
| 92 | + with pytest.raises(IndexError): |
| 93 | + X_train[np.array([1000])] |
| 94 | + with pytest.raises(IndexError): |
| 95 | + X_train[None] |
| 96 | + assert (X_train[0] == X_train[:1][0]).all() |
| 97 | + assert (X_train[[0, 1]] == X_train[:2]).all() |
| 98 | + assert (X_train[np.array([0, 1])] == X_train[:2]).all() |
| 99 | + |
| 100 | + # test normalizer |
| 101 | + normalizer = lambda x: x + 1 |
| 102 | + normalized_X_train = HDF5Matrix(h5_path, 'my_data', start=0, end=150, normalizer=normalizer) |
| 103 | + assert np.isclose(normalized_X_train[0][0], X_train[0][0] + 1) |
| 104 | + |
79 | 105 | os.remove(h5_path) |
80 | 106 |
|
81 | 107 |
|
| 108 | +def test_ask_to_proceed_with_overwrite(): |
| 109 | + if sys.version_info[:2] <= (2, 7): |
| 110 | + with patch('__builtin__.raw_input') as mock: |
| 111 | + mock.return_value = 'y' |
| 112 | + assert ask_to_proceed_with_overwrite('/tmp/not_exists') |
| 113 | + |
| 114 | + mock.return_value = 'n' |
| 115 | + assert not ask_to_proceed_with_overwrite('/tmp/not_exists') |
| 116 | + else: |
| 117 | + with patch('builtins.input') as mock: |
| 118 | + mock.return_value = 'y' |
| 119 | + assert ask_to_proceed_with_overwrite('/tmp/not_exists') |
| 120 | + |
| 121 | + mock.return_value = 'n' |
| 122 | + assert not ask_to_proceed_with_overwrite('/tmp/not_exists') |
| 123 | + |
| 124 | + |
82 | 125 | if __name__ == '__main__': |
83 | 126 | pytest.main([__file__]) |
0 commit comments