Skip to content

Commit 2ace75b

Browse files
yenchenlinjnothman
authored andcommitted
COSMIT Reduce duplicated code (scikit-learn#7053)
1 parent 4f4a580 commit 2ace75b

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

sklearn/cluster/_k_means.pyx

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ from sklearn.utils.fixes import bincount
2222
ctypedef np.float64_t DOUBLE
2323
ctypedef np.int32_t INT
2424

25+
ctypedef floating (*DOT)(int N, floating *X, int incX, floating *Y,
26+
int incY)
27+
2528
cdef extern from "cblas.h":
2629
double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
2730
float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)
@@ -56,41 +59,35 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
5659
DOUBLE inertia = 0.0
5760
DOUBLE min_dist
5861
DOUBLE dist
62+
DOT dot
5963

6064
if floating is float:
6165
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
6266
x_stride = X.strides[1] / sizeof(float)
6367
center_stride = centers.strides[1] / sizeof(float)
68+
dot = sdot
6469
else:
6570
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
6671
x_stride = X.strides[1] / sizeof(DOUBLE)
6772
center_stride = centers.strides[1] / sizeof(DOUBLE)
73+
dot = ddot
6874

6975
if n_samples == distances.shape[0]:
7076
store_distances = 1
7177

7278
for center_idx in range(n_clusters):
73-
if floating is float:
74-
center_squared_norms[center_idx] = sdot(
75-
n_features, &centers[center_idx, 0], center_stride,
76-
&centers[center_idx, 0], center_stride)
77-
else:
78-
center_squared_norms[center_idx] = ddot(
79-
n_features, &centers[center_idx, 0], center_stride,
80-
&centers[center_idx, 0], center_stride)
79+
center_squared_norms[center_idx] = dot(
80+
n_features, &centers[center_idx, 0], center_stride,
81+
&centers[center_idx, 0], center_stride)
8182

8283
for sample_idx in range(n_samples):
8384
min_dist = -1
8485
for center_idx in range(n_clusters):
8586
dist = 0.0
8687
# hardcoded: minimize euclidean distance to cluster center:
8788
# ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
88-
if floating is float:
89-
dist += sdot(n_features, &X[sample_idx, 0], x_stride,
90-
&centers[center_idx, 0], center_stride)
91-
else:
92-
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
93-
&centers[center_idx, 0], center_stride)
89+
dist += dot(n_features, &X[sample_idx, 0], x_stride,
90+
&centers[center_idx, 0], center_stride)
9491
dist *= -2
9592
dist += center_squared_norms[center_idx]
9693
dist += x_squared_norms[sample_idx]
@@ -132,21 +129,20 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
132129
DOUBLE inertia = 0.0
133130
DOUBLE min_dist
134131
DOUBLE dist
132+
DOT dot
135133

136134
if floating is float:
137135
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
136+
dot = sdot
138137
else:
139138
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
139+
dot = ddot
140140

141141
if n_samples == distances.shape[0]:
142142
store_distances = 1
143143

144144
for center_idx in range(n_clusters):
145-
if floating is float:
146-
center_squared_norms[center_idx] = sdot(
147-
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
148-
else:
149-
center_squared_norms[center_idx] = ddot(
145+
center_squared_norms[center_idx] = dot(
150146
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
151147

152148
for sample_idx in range(n_samples):

0 commit comments

Comments
 (0)