Commit 8ed57c1

Merge branch 'master' of github.com:fchollet/keras
2 parents: f6eda66 + 35e10e9

3 files changed, +83 -69 lines

keras/engine/training.py

Lines changed: 50 additions & 69 deletions
@@ -62,88 +62,69 @@ def _standardize_input_data(data, names, shapes=None,
         return []
     if data is None:
         return [None for _ in range(len(names))]
+
     if isinstance(data, dict):
         try:
-            arrays = [data[name].values if data[name].__class__.__name__ == 'DataFrame' else data[name]
-                      for name in names]
-
+            data = [data[x].values if data[x].__class__.__name__ == 'DataFrame' else data[x] for x in names]
+            data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data]
         except KeyError as e:
-            raise ValueError('No data provided for "' +
-                             e.args[0] + '". Need data for each key in: ' +
-                             str(names))
-
+            raise ValueError(
+                'No data provided for "' + e.args[0] + '". Need data '
+                'for each key in: ' + str(names))
     elif isinstance(data, list):
-        arrays = [x.values if x.__class__.__name__ == 'DataFrame' else x for x in data]
-        if len(arrays) != len(names):
-            if arrays and hasattr(arrays[0], 'shape'):
-                raise ValueError('Error when checking model ' +
-                                 exception_prefix +
-                                 ': the list of Numpy arrays '
-                                 'that you are passing to your model '
-                                 'is not the size the model expected. '
-                                 'Expected to see ' + str(len(names)) +
-                                 ' array(s), but instead got '
-                                 'the following list of ' + str(len(arrays)) +
-                                 ' arrays: ' + str(arrays)[:200] +
-                                 '...')
-            else:
-                if len(names) == 1:
-                    arrays = [np.asarray(arrays)]
-                else:
-                    raise ValueError(
-                        'Error when checking model ' +
-                        exception_prefix +
-                        ': you are passing a list as '
-                        'input to your model, '
-                        'but the model expects '
-                        'a list of ' + str(len(names)) +
-                        ' Numpy arrays instead. '
-                        'The list you passed was: ' +
-                        str(arrays)[:200])
+        data = [x.values if x.__class__.__name__ == 'DataFrame' else x for x in data]
+        data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
     else:
-        if data.__class__.__name__ == 'DataFrame':
-            # test if data is a DataFrame, without pandas installed
-            data = data.values
-        if not hasattr(data, 'shape'):
-            raise TypeError('Error when checking model ' +
-                            exception_prefix +
-                            ': data should be a Numpy array, '
-                            'or list/dict of Numpy arrays. '
-                            'Found: ' + str(data)[:200] + '...')
-        if len(names) > 1:
-            # Case: model expects multiple inputs but only received
-            # a single Numpy array.
-            raise ValueError('The model expects ' + str(len(names)) + ' ' +
-                             exception_prefix +
-                             ' arrays, but only received one array. '
-                             'Found: array with shape ' + str(data.shape))
-        arrays = [data]
-
-    # Make arrays at least 2D.
-    arrays = [np.expand_dims(array, 1) if array.ndim == 1 else array for array in arrays]
+        data = data.values if data.__class__.__name__ == 'DataFrame' else data
+        data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data]
+
+    if len(data) != len(names):
+        if data and hasattr(data[0], 'shape'):
+            raise ValueError(
+                'Error when checking model ' + exception_prefix +
+                ': the list of Numpy arrays that you are passing to '
+                'your model is not the size the model expected. '
+                'Expected to see ' + str(len(names)) + ' array(s), '
+                'but instead got the following list of ' +
+                str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
+        elif len(names) > 1:
+            raise ValueError(
+                'Error when checking model ' + exception_prefix +
+                ': you are passing a list as input to your model, '
+                'but the model expects a list of ' + str(len(names)) +
+                ' Numpy arrays instead. The list you passed was: ' +
+                str(data)[:200])
+        elif len(data) == 1 and not hasattr(data[0], 'shape'):
+            raise TypeError(
+                'Error when checking model ' + exception_prefix +
+                ': data should be a Numpy array, or list/dict of '
+                'Numpy arrays. Found: ' + str(data)[:200] + '...')
+        elif len(names) == 1:
+            data = [np.asarray(data)]
 
     # Check shapes compatibility.
     if shapes:
-        start = 0 if check_batch_axis else 1
         for i in range(len(names)):
             if shapes[i] is not None:
-                array_shape = arrays[i].shape
-                if arrays[i].ndim != len(shapes[i]):
-                    raise ValueError('Error when checking ' + exception_prefix +
-                                     ': expected ' + names[i] +
-                                     ' to have ' + str(len(shapes[i])) +
-                                     ' dimensions, but got array with shape ' +
-                                     str(array_shape))
-
-                for dim, ref_dim in zip(array_shape[start:], shapes[i][start:]):
+                data_shape = data[i].shape
+                shape = shapes[i]
+                if data[i].ndim != len(shape):
+                    raise ValueError(
+                        'Error when checking ' + exception_prefix +
+                        ': expected ' + names[i] + ' to have ' +
+                        str(len(shape)) + ' dimensions, but got array '
+                        'with shape ' + str(data_shape))
+                if not check_batch_axis:
+                    data_shape = data_shape[1:]
+                    shape = shape[1:]
+                for dim, ref_dim in zip(data_shape, shape):
                     if ref_dim != dim and ref_dim:
                         raise ValueError(
                             'Error when checking ' + exception_prefix +
-                            ': expected ' + names[i] +
-                            ' to have shape ' + str(shapes[i]) +
-                            ' but got array with shape ' +
-                            str(array_shape))
-    return arrays
+                            ': expected ' + names[i] + ' to have shape ' +
+                            str(shape) + ' but got array with shape ' +
+                            str(data_shape))
+    return data
 
 
 def _standardize_sample_or_class_weights(x_weight, output_names, weight_type):
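
For orientation, a minimal sketch of what the consolidated helper now does (an illustration, not part of the commit: _standardize_input_data is a private Keras helper at this point in the tree, and the input name 'dense_input' is an assumption). A dict, a list, or a bare 1-D array all come back as the same list of 2-D Numpy arrays:

    import numpy as np
    from keras.engine.training import _standardize_input_data  # private helper, as of this commit

    x = np.array([1.0, 2.0, 3.0])  # a 1-D input vector
    names = ['dense_input']        # hypothetical model input name

    # Each accepted payload type now funnels through the same list-of-arrays path:
    for payload in (x, [x], {'dense_input': x}):
        out = _standardize_input_data(payload, names)
        assert out[0].shape == (3, 1)  # 1-D arrays are expanded to 2-D in every branch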

keras/preprocessing/text.py

Lines changed: 13 additions & 0 deletions
@@ -124,6 +124,8 @@ class Tokenizer(object):
         lower: boolean. Whether to convert the texts to lowercase.
         split: character or string to use for token splitting.
         char_level: if True, every character will be treated as a token.
+        oov_token: if given, it will be added to word_index and used to
+            replace out-of-vocabulary words during texts_to_sequences calls
 
     By default, all punctuation is removed, turning the texts into
     space-separated sequences of words
@@ -138,6 +140,7 @@ def __init__(self, num_words=None,
                  lower=True,
                  split=' ',
                  char_level=False,
+                 oov_token=None,
                  **kwargs):
         # Legacy support
         if 'nb_words' in kwargs:
@@ -155,6 +158,7 @@ def __init__(self, num_words=None,
         self.num_words = num_words
         self.document_count = 0
         self.char_level = char_level
+        self.oov_token = oov_token
 
     def fit_on_texts(self, texts):
         """Updates internal vocabulary based on a list of texts.
@@ -189,6 +193,11 @@ def fit_on_texts(self, texts):
         # note that index 0 is reserved, never assigned to an existing word
         self.word_index = dict(list(zip(sorted_voc, list(range(1, len(sorted_voc) + 1)))))
 
+        if self.oov_token is not None:
+            i = self.word_index.get(self.oov_token)
+            if i is None:
+                self.word_index[self.oov_token] = len(self.word_index) + 1
+
         self.index_docs = {}
         for w, c in list(self.word_docs.items()):
             self.index_docs[self.word_index[w]] = c
@@ -256,6 +265,10 @@ def texts_to_sequences_generator(self, texts):
                     continue
                 else:
                     vect.append(i)
+            elif self.oov_token is not None:
+                i = self.word_index.get(self.oov_token)
+                if i is not None:
+                    vect.append(i)
             yield vect
 
     def texts_to_matrix(self, texts, mode='binary'):
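
A short usage sketch of the new flag (the toy corpus is an assumption, not from the commit; behavior is as of this change, where the OOV token is appended after the known vocabulary). Unseen words now map to the token's index instead of silently disappearing:

    from keras.preprocessing.text import Tokenizer

    tokenizer = Tokenizer(oov_token='<unk>')
    tokenizer.fit_on_texts(['the cat sat on the mat'])   # assumed toy corpus

    oov_id = tokenizer.word_index['<unk>']               # appended after the known words
    seqs = tokenizer.texts_to_sequences(['the dog sat']) # 'dog' was never seen

    assert len(seqs[0]) == 3     # nothing is dropped from the sequence
    assert seqs[0][1] == oov_id  # 'dog' is replaced by the OOV index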

tests/keras/preprocessing/text_test.py

Lines changed: 20 additions & 0 deletions
@@ -67,5 +67,25 @@ def test_tokenizer_unicode():
     assert len(tokenizer.word_counts) == 5
 
 
+def test_tokenizer_oov_flag():
+    """
+    Test of Out of Vocabulary (OOV) flag in Tokenizer
+    """
+    x_train = ['This text has only known words']
+    x_test = ['This text has some unknown words']  # 2 OOVs: some, unknown
+
+    # Default, without OOV flag
+    tokenizer = Tokenizer()
+    tokenizer.fit_on_texts(x_train)
+    x_test_seq = tokenizer.texts_to_sequences(x_test)
+    assert len(x_test_seq[0]) == 4  # discards 2 OOVs
+
+    # With OOV feature
+    tokenizer = Tokenizer(oov_token='<unk>')
+    tokenizer.fit_on_texts(x_train)
+    x_test_seq = tokenizer.texts_to_sequences(x_test)
+    assert len(x_test_seq[0]) == 6  # OOVs marked in place
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
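
Tying the test back to the fit_on_texts change above, a quick check of where the token lands (same corpus as the test; note this placement is specific to this commit, as later Keras releases moved the OOV token to a reserved low index):

    from keras.preprocessing.text import Tokenizer

    tokenizer = Tokenizer(oov_token='<unk>')
    tokenizer.fit_on_texts(['This text has only known words'])  # 6 distinct words

    # The known vocabulary occupies indices 1-6, so the OOV token is
    # appended at len(word_index) + 1 = 7.
    assert tokenizer.word_index['<unk>'] == 7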
