Skip to content

Commit 8022008

Browse files
committed
FIX: make ward_tree work on 1D data
1 parent 99f5813 commit 8022008

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

sklearn/cluster/hierarchical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
7979
is specified, elsewhere 'None' is returned.
8080
"""
8181
X = np.asarray(X)
82-
n_samples, n_features = X.shape
8382
if X.ndim == 1:
8483
X = np.reshape(X, (-1, 1))
84+
n_samples, n_features = X.shape
8585

8686
if connectivity is None:
8787
if n_clusters is not None:

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
from scipy.cluster import hierarchy
9-
from nose.tools import assert_true, assert_raises
9+
from nose.tools import assert_true, assert_raises, assert_equal
1010

1111
from sklearn.cluster import Ward, WardAgglomeration, ward_tree
1212
from sklearn.cluster.hierarchical import _hc_cut
@@ -33,9 +33,10 @@ def test_unstructured_ward_tree():
3333
"""
3434
rnd = np.random.RandomState(0)
3535
X = rnd.randn(50, 100)
36-
children, n_nodes, n_leaves, parent = ward_tree(X.T)
37-
n_nodes = 2 * X.shape[1] - 1
38-
assert_true(len(children) + n_leaves == n_nodes)
36+
for this_X in (X, X[0]):
37+
children, n_nodes, n_leaves, parent = ward_tree(this_X.T)
38+
n_nodes = 2 * X.shape[1] - 1
39+
assert_equal(len(children) + n_leaves, n_nodes)
3940

4041

4142
def test_height_ward_tree():

0 commit comments

Comments
 (0)