Skip to content

Commit bb38ebb

Browse files
viiryamengxr
authored andcommitted
[SPARK-5050][Mllib] Add unit test for sqdist
Related to apache#3643. Follow the previous suggestion to add unit test for `sqdist` in `VectorsSuite`. Author: Liang-Chi Hsieh <[email protected]> Closes apache#3869 from viirya/sqdist_test and squashes the following commits: fb743da [Liang-Chi Hsieh] Modified for comment and fix bug. 90a08f3 [Liang-Chi Hsieh] Modified for comment. 39a3ca6 [Liang-Chi Hsieh] Take care of special case. b789f42 [Liang-Chi Hsieh] More proper unit test with random sparsity pattern. c36be68 [Liang-Chi Hsieh] Add unit test for sqdist.
1 parent 4108e5f commit bb38ebb

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,16 +373,17 @@ object Vectors {
373373
var kv2 = 0
374374
val indices = v1.indices
375375
var squaredDistance = 0.0
376-
var iv1 = indices(kv1)
376+
val nnzv1 = indices.size
377377
val nnzv2 = v2.size
378+
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
378379

379380
while (kv2 < nnzv2) {
380381
var score = 0.0
381382
if (kv2 != iv1) {
382383
score = v2(kv2)
383384
} else {
384385
score = v1.values(kv1) - v2(kv2)
385-
if (kv1 < indices.length - 1) {
386+
if (kv1 < nnzv1 - 1) {
386387
kv1 += 1
387388
iv1 = indices(kv1)
388389
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
package org.apache.spark.mllib.linalg
1919

20-
import breeze.linalg.{DenseMatrix => BDM}
20+
import scala.util.Random
21+
22+
import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
2123
import org.scalatest.FunSuite
2224

2325
import org.apache.spark.SparkException
@@ -175,6 +177,33 @@ class VectorsSuite extends FunSuite {
175177
assert(v.size === x.rows)
176178
}
177179

180+
test("sqdist") {
181+
val random = new Random()
182+
for (m <- 1 until 1000 by 100) {
183+
val nnz = random.nextInt(m)
184+
185+
val indices1 = random.shuffle(0 to m - 1).slice(0, nnz).sorted.toArray
186+
val values1 = Array.fill(nnz)(random.nextDouble)
187+
val sparseVector1 = Vectors.sparse(m, indices1, values1)
188+
189+
val indices2 = random.shuffle(0 to m - 1).slice(0, nnz).sorted.toArray
190+
val values2 = Array.fill(nnz)(random.nextDouble)
191+
val sparseVector2 = Vectors.sparse(m, indices2, values2)
192+
193+
val denseVector1 = Vectors.dense(sparseVector1.toArray)
194+
val denseVector2 = Vectors.dense(sparseVector2.toArray)
195+
196+
val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
197+
198+
// SparseVector vs. SparseVector
199+
assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
200+
// DenseVector vs. SparseVector
201+
assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
202+
// DenseVector vs. DenseVector
203+
assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
204+
}
205+
}
206+
178207
test("foreachActive") {
179208
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
180209
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))

0 commit comments

Comments
 (0)