@@ -30,29 +30,26 @@ def test_mean_variance_axis0():
3030 X_lil = sp .lil_matrix (X )
3131 X_lil [1 , 0 ] = 0
3232 X [1 , 0 ] = 0
33- X_csr = sp .csr_matrix (X_lil )
3433
35- X_means , X_vars = mean_variance_axis (X_csr , axis = 0 )
36- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
37- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
34+ assert_raises (TypeError , mean_variance_axis , X_lil , axis = 0 )
3835
36+ X_csr = sp .csr_matrix (X_lil )
3937 X_csc = sp .csc_matrix (X_lil )
40- X_means , X_vars = mean_variance_axis (X_csc , axis = 0 )
4138
42- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
43- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
44- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 0 )
39+ expected_dtypes = [(np .float32 , np .float32 ),
40+ (np .float64 , np .float64 ),
41+ (np .int32 , np .float64 ),
42+ (np .int64 , np .float64 )]
4543
46- X = X .astype (np .float32 )
47- X_csr = X_csr .astype (np .float32 )
48- X_csc = X_csr .astype (np .float32 )
49- X_means , X_vars = mean_variance_axis (X_csr , axis = 0 )
50- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
51- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
52- X_means , X_vars = mean_variance_axis (X_csc , axis = 0 )
53- assert_array_almost_equal (X_means , np .mean (X , axis = 0 ))
54- assert_array_almost_equal (X_vars , np .var (X , axis = 0 ))
55- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 0 )
44+ for input_dtype , output_dtype in expected_dtypes :
45+ X_test = X .astype (input_dtype )
46+ for X_sparse in (X_csr , X_csc ):
47+ X_sparse = X_sparse .astype (input_dtype )
48+ X_means , X_vars = mean_variance_axis (X_sparse , axis = 0 )
49+ assert_equal (X_means .dtype , output_dtype )
50+ assert_equal (X_vars .dtype , output_dtype )
51+ assert_array_almost_equal (X_means , np .mean (X_test , axis = 0 ))
52+ assert_array_almost_equal (X_vars , np .var (X_test , axis = 0 ))
5653
5754
5855def test_mean_variance_axis1 ():
@@ -64,29 +61,26 @@ def test_mean_variance_axis1():
6461 X_lil = sp .lil_matrix (X )
6562 X_lil [1 , 0 ] = 0
6663 X [1 , 0 ] = 0
67- X_csr = sp .csr_matrix (X_lil )
6864
69- X_means , X_vars = mean_variance_axis (X_csr , axis = 1 )
70- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
71- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
65+ assert_raises (TypeError , mean_variance_axis , X_lil , axis = 1 )
7266
67+ X_csr = sp .csr_matrix (X_lil )
7368 X_csc = sp .csc_matrix (X_lil )
74- X_means , X_vars = mean_variance_axis (X_csc , axis = 1 )
7569
76- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
77- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
78- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 1 )
70+ expected_dtypes = [(np .float32 , np .float32 ),
71+ (np .float64 , np .float64 ),
72+ (np .int32 , np .float64 ),
73+ (np .int64 , np .float64 )]
7974
80- X = X .astype (np .float32 )
81- X_csr = X_csr .astype (np .float32 )
82- X_csc = X_csr .astype (np .float32 )
83- X_means , X_vars = mean_variance_axis (X_csr , axis = 1 )
84- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
85- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
86- X_means , X_vars = mean_variance_axis (X_csc , axis = 1 )
87- assert_array_almost_equal (X_means , np .mean (X , axis = 1 ))
88- assert_array_almost_equal (X_vars , np .var (X , axis = 1 ))
89- assert_raises (TypeError , mean_variance_axis , X_lil , axis = 1 )
75+ for input_dtype , output_dtype in expected_dtypes :
76+ X_test = X .astype (input_dtype )
77+ for X_sparse in (X_csr , X_csc ):
78+ X_sparse = X_sparse .astype (input_dtype )
79+ X_means , X_vars = mean_variance_axis (X_sparse , axis = 0 )
80+ assert_equal (X_means .dtype , output_dtype )
81+ assert_equal (X_vars .dtype , output_dtype )
82+ assert_array_almost_equal (X_means , np .mean (X_test , axis = 0 ))
83+ assert_array_almost_equal (X_vars , np .var (X_test , axis = 0 ))
9084
9185
9286def test_incr_mean_variance_axis ():
@@ -132,34 +126,25 @@ def test_incr_mean_variance_axis():
132126 X = np .vstack (data_chunks )
133127 X_lil = sp .lil_matrix (X )
134128 X_csr = sp .csr_matrix (X_lil )
135- X_means , X_vars = mean_variance_axis (X_csr , axis )
136- X_means_incr , X_vars_incr , n_incr = \
137- incr_mean_variance_axis (X_csr , axis , last_mean , last_var , last_n )
138- assert_array_almost_equal (X_means , X_means_incr )
139- assert_array_almost_equal (X_vars , X_vars_incr )
140- assert_equal (X .shape [axis ], n_incr )
141-
142129 X_csc = sp .csc_matrix (X_lil )
143- X_means , X_vars = mean_variance_axis (X_csc , axis )
144- assert_array_almost_equal (X_means , X_means_incr )
145- assert_array_almost_equal (X_vars , X_vars_incr )
146- assert_equal (X .shape [axis ], n_incr )
147130
148- # All data but as float
149- X = X .astype (np .float32 )
150- X_csr = X_csr .astype (np .float32 )
151- X_means , X_vars = mean_variance_axis (X_csr , axis )
152- X_means_incr , X_vars_incr , n_incr = \
153- incr_mean_variance_axis (X_csr , axis , last_mean , last_var , last_n )
154- assert_array_almost_equal (X_means , X_means_incr )
155- assert_array_almost_equal (X_vars , X_vars_incr )
156- assert_equal (X .shape [axis ], n_incr )
157-
158- X_csc = X_csr .astype (np .float32 )
159- X_means , X_vars = mean_variance_axis (X_csc , axis )
160- assert_array_almost_equal (X_means , X_means_incr )
161- assert_array_almost_equal (X_vars , X_vars_incr )
162- assert_equal (X .shape [axis ], n_incr )
131+ expected_dtypes = [(np .float32 , np .float32 ),
132+ (np .float64 , np .float64 ),
133+ (np .int32 , np .float64 ),
134+ (np .int64 , np .float64 )]
135+
136+ for input_dtype , output_dtype in expected_dtypes :
137+ for X_sparse in (X_csr , X_csc ):
138+ X_sparse = X_sparse .astype (input_dtype )
139+ X_means , X_vars = mean_variance_axis (X_sparse , axis )
140+ X_means_incr , X_vars_incr , n_incr = \
141+ incr_mean_variance_axis (X_sparse , axis , last_mean ,
142+ last_var , last_n )
143+ assert_equal (X_means_incr .dtype , output_dtype )
144+ assert_equal (X_vars_incr .dtype , output_dtype )
145+ assert_array_almost_equal (X_means , X_means_incr )
146+ assert_array_almost_equal (X_vars , X_vars_incr )
147+ assert_equal (X .shape [axis ], n_incr )
163148
164149
165150def test_mean_variance_illegal_axis ():
0 commit comments