@@ -62,88 +62,69 @@ def _standardize_input_data(data, names, shapes=None,
6262 return []
6363 if data is None :
6464 return [None for _ in range (len (names ))]
65+
6566 if isinstance (data , dict ):
6667 try :
67- arrays = [data [name ].values if data [name ].__class__ .__name__ == 'DataFrame' else data [name ]
68- for name in names ]
69-
68+ data = [data [x ].values if data [x ].__class__ .__name__ == 'DataFrame' else data [x ] for x in names ]
69+ data = [np .expand_dims (x , 1 ) if x .ndim == 1 else x for x in data ]
7070 except KeyError as e :
71- raise ValueError ('No data provided for "' +
72- e .args [0 ] + '". Need data for each key in: ' +
73- str (names ))
74-
71+ raise ValueError (
72+ 'No data provided for "' + e .args [0 ] + '". Need data '
73+ 'for each key in: ' + str (names ))
7574 elif isinstance (data , list ):
76- arrays = [x .values if x .__class__ .__name__ == 'DataFrame' else x for x in data ]
77- if len (arrays ) != len (names ):
78- if arrays and hasattr (arrays [0 ], 'shape' ):
79- raise ValueError ('Error when checking model ' +
80- exception_prefix +
81- ': the list of Numpy arrays '
82- 'that you are passing to your model '
83- 'is not the size the model expected. '
84- 'Expected to see ' + str (len (names )) +
85- ' array(s), but instead got '
86- 'the following list of ' + str (len (arrays )) +
87- ' arrays: ' + str (arrays )[:200 ] +
88- '...' )
89- else :
90- if len (names ) == 1 :
91- arrays = [np .asarray (arrays )]
92- else :
93- raise ValueError (
94- 'Error when checking model ' +
95- exception_prefix +
96- ': you are passing a list as '
97- 'input to your model, '
98- 'but the model expects '
99- 'a list of ' + str (len (names )) +
100- ' Numpy arrays instead. '
101- 'The list you passed was: ' +
102- str (arrays )[:200 ])
75+ data = [x .values if x .__class__ .__name__ == 'DataFrame' else x for x in data ]
76+ data = [np .expand_dims (x , 1 ) if x is not None and x .ndim == 1 else x for x in data ]
10377 else :
104- if data .__class__ .__name__ == 'DataFrame' :
105- # test if data is a DataFrame, without pandas installed
106- data = data .values
107- if not hasattr (data , 'shape' ):
108- raise TypeError ('Error when checking model ' +
109- exception_prefix +
110- ': data should be a Numpy array, '
111- 'or list/dict of Numpy arrays. '
112- 'Found: ' + str (data )[:200 ] + '...' )
113- if len (names ) > 1 :
114- # Case: model expects multiple inputs but only received
115- # a single Numpy array.
116- raise ValueError ('The model expects ' + str (len (names )) + ' ' +
117- exception_prefix +
118- ' arrays, but only received one array. '
119- 'Found: array with shape ' + str (data .shape ))
120- arrays = [data ]
121-
122- # Make arrays at least 2D.
123- arrays = [np .expand_dims (array , 1 ) if array .ndim == 1 else array for array in arrays ]
78+ data = data .values if data .__class__ .__name__ == 'DataFrame' else data
79+ data = [np .expand_dims (data , 1 )] if data .ndim == 1 else [data ]
80+
81+ if len (data ) != len (names ):
82+ if data and hasattr (data [0 ], 'shape' ):
83+ raise ValueError (
84+ 'Error when checking model ' + exception_prefix +
85+ ': the list of Numpy arrays that you are passing to '
86+ 'your model is not the size the model expected. '
87+ 'Expected to see ' + str (len (names )) + ' array(s), '
88+ 'but instead got the following list of ' +
89+ str (len (data )) + ' arrays: ' + str (data )[:200 ] + '...' )
90+ elif len (names ) > 1 :
91+ raise ValueError (
92+ 'Error when checking model ' + exception_prefix +
93+ ': you are passing a list as input to your model, '
94+ 'but the model expects a list of ' + str (len (names )) +
95+ ' Numpy arrays instead. The list you passed was: ' +
96+ str (data )[:200 ])
97+ elif len (data ) == 1 and not hasattr (data [0 ], 'shape' ):
98+ raise TypeError (
99+ 'Error when checking model ' + exception_prefix +
100+ ': data should be a Numpy array, or list/dict of '
101+ 'Numpy arrays. Found: ' + str (data )[:200 ] + '...' )
102+ elif len (names ) == 1 :
103+ data = [np .asarray (data )]
124104
125105 # Check shapes compatibility.
126106 if shapes :
127- start = 0 if check_batch_axis else 1
128107 for i in range (len (names )):
129108 if shapes [i ] is not None :
130- array_shape = arrays [i ].shape
131- if arrays [i ].ndim != len (shapes [i ]):
132- raise ValueError ('Error when checking ' + exception_prefix +
133- ': expected ' + names [i ] +
134- ' to have ' + str (len (shapes [i ])) +
135- ' dimensions, but got array with shape ' +
136- str (array_shape ))
137-
138- for dim , ref_dim in zip (array_shape [start :], shapes [i ][start :]):
109+ data_shape = data [i ].shape
110+ shape = shapes [i ]
111+ if data [i ].ndim != len (shape ):
112+ raise ValueError (
113+ 'Error when checking ' + exception_prefix +
114+ ': expected ' + names [i ] + ' to have ' +
115+ str (len (shape )) + ' dimensions, but got array '
116+ 'with shape ' + str (data_shape ))
117+ if not check_batch_axis :
118+ data_shape = data_shape [1 :]
119+ shape = shape [1 :]
120+ for dim , ref_dim in zip (data_shape , shape ):
139121 if ref_dim != dim and ref_dim :
140122 raise ValueError (
141123 'Error when checking ' + exception_prefix +
142- ': expected ' + names [i ] +
143- ' to have shape ' + str (shapes [i ]) +
144- ' but got array with shape ' +
145- str (array_shape ))
146- return arrays
124+ ': expected ' + names [i ] + ' to have shape ' +
125+ str (shape ) + ' but got array with shape ' +
126+ str (data_shape ))
127+ return data
147128
148129
149130def _standardize_sample_or_class_weights (x_weight , output_names , weight_type ):
0 commit comments