Skip to content

Commit 840bcf3

Browse files
committed
Merge pull request scikit-learn#4961 from mrphilroth/issue4959
[MRG] Allowing for different cluster stds
2 parents: c151f6e + ec78cc1; commit: 840bcf3

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

sklearn/datasets/samples_generator.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -745,6 +745,9 @@ def make_blobs(n_samples=100, n_features=2, centers=3, cluster_std=1.0,
745745
centers = check_array(centers)
746746
n_features = centers.shape[1]
747747

748+
if isinstance(cluster_std, numbers.Real):
749+
cluster_std = np.ones(len(centers)) * cluster_std
750+
748751
X = []
749752
y = []
750753

@@ -754,8 +757,8 @@ def make_blobs(n_samples=100, n_features=2, centers=3, cluster_std=1.0,
754757
for i in range(n_samples % n_centers):
755758
n_samples_per_center[i] += 1
756759

757-
for i, n in enumerate(n_samples_per_center):
758-
X.append(centers[i] + generator.normal(scale=cluster_std,
760+
for i, (n, std) in enumerate(zip(n_samples_per_center, cluster_std)):
761+
X.append(centers[i] + generator.normal(scale=std,
759762
size=(n, n_features)))
760763
y += [i] * n
761764

sklearn/datasets/tests/test_samples_generator.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -208,13 +208,16 @@ def test_make_regression_multitarget():
208208

209209

210210
def test_make_blobs():
211-
X, y = make_blobs(n_samples=50, n_features=2,
212-
centers=[[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
213-
random_state=0)
211+
cluster_stds = np.array([0.05, 0.2, 0.4])
212+
cluster_centers = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
213+
X, y = make_blobs(random_state=0, n_samples=50, n_features=2,
214+
centers=cluster_centers, cluster_std=cluster_stds)
214215

215216
assert_equal(X.shape, (50, 2), "X shape mismatch")
216217
assert_equal(y.shape, (50,), "y shape mismatch")
217218
assert_equal(np.unique(y).shape, (3,), "Unexpected number of blobs")
219+
for i, (ctr, std) in enumerate(zip(cluster_centers, cluster_stds)):
220+
assert_almost_equal((X[y == i] - ctr).std(), std, 1, "Unexpected std")
218221

219222

220223
def test_make_friedman1():

0 commit comments

Comments
 (0)