Skip to content

Commit 8bcfc61

Browse files
author
Ji ZHANG
committed
euclidean
1 parent 4364986 commit 8bcfc61

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

src/main/scala/recommendation/SimilarityRecommender.scala

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -261,19 +261,31 @@ object SimilarityRecommender extends Recommender {
261261

262262
/**
 * Builds a symmetric matrix of Euclidean distances between the column indices of `mat`.
 *
 * This span is a rendered diff: gutter markers ending in `-` precede removed lines,
 * markers ending in `+` precede added lines. The added side works in three steps:
 *  1. each row of `mat` is turned into (column index, value) pairs, passed through
 *     `normalize` (defined outside this view), and NaN results are dropped;
 *  2. for each normalized row, every pair of indices (i, j) with i < j reported by
 *     `toSparse.indices` contributes (v(i) - v(j))^2, summed per pair across all rows;
 *  3. the square root of each sum is emitted at both (i, j) and (j, i).
 *
 * No (i, i) entries are emitted, and a pair of columns that never appears together in
 * a row produces no entry at all.
 *
 * @param mat input matrix. NOTE(review): presumably rows are users and columns are
 *            items — confirm against callers.
 * @return a CoordinateMatrix built from the symmetric distance entries
 */
def euclidean(mat: CoordinateMatrix): CoordinateMatrix = {
263263

264-
// transpose
265-
// normalize
266-
// cartesian
267-
// calc
268-
269-
mat.transpose.toIndexedRowMatrix.rows.map { row =>
270-
val values = row.vector.toSparse.indices.map { i =>
271-
i -> row.vector(i)
// Added side: rows of `mat` are used directly (the removed side transposed first).
264+
val normEntries = mat.toIndexedRowMatrix.rows.flatMap { row =>
265+
val values = row.vector.toSparse.indices.map { j =>
266+
j -> row.vector(j)
272267
}
273-
Vectors.sparse(row.vector.size, normalize(values).toSeq)
// NOTE(review): the NaN filter presumably discards values that normalize cannot scale
// (e.g. a row with zero variance) — confirm against normalize.
268+
normalize(values).filterNot(_._2.isNaN).map { case (j, value) =>
269+
MatrixEntry(row.index, j, value)
270+
}
271+
}
272+
val normMat = new CoordinateMatrix(normEntries)
273+
// Squared differences are keyed by the ordered index pair (i < j) so that reduceByKey
// can sum the contributions of one pair across all rows.
274+
val pairs = normMat.toRowMatrix.rows.flatMap { v =>
// NOTE(review): whether toSparse keeps explicitly stored 0.0 values may depend on the
// Spark version; if it drops them, an entry normalized to exactly 0.0 would be left
// out of the pair sums — TODO confirm.
275+
val indices = v.toSparse.indices
276+
indices.flatMap { i =>
277+
indices.filter(_ > i).map { j =>
278+
(i, j) -> math.pow(v(i) - v(j), 2)
279+
}
280+
}
281+
}.reduceByKey(_ + _)
282+
// Each summed pair becomes two entries so the resulting matrix is symmetric.
283+
val entries = pairs.flatMap { case ((i, j), v) =>
284+
val value = math.sqrt(v)
285+
Seq(MatrixEntry(i, j, value), MatrixEntry(j, i, value))
274286
}
275287

// The removed side returned null as a placeholder; the added side returns the matrix.
276-
null
288+
new CoordinateMatrix(entries)
277289
}
278290

279291
}
Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
package recommendation
2+
3+
import org.scalatest.FunSuite
4+
import SimilarityRecommender._
5+
6+
7+
/**
 * ScalaTest suite for helpers imported from `SimilarityRecommender`.
 * (Lines such as `8+` are diff gutter markers from the rendered commit, not code.)
 */
class SimilarityRecommenderSuite extends FunSuite {
8+
// Feeds normalize the values 1, 2, 3 keyed by their positions 0, 1, 2 and compares only
// the resulting values (keys are not checked). Each expected value is
// (v - 2) / sqrt(2/3): v minus the mean of the inputs, divided by their population
// standard deviation.
// NOTE(review): the sequences are compared with exact Double equality (==), which only
// holds while normalize performs bit-identical arithmetic — consider a tolerance-based
// comparison.
9+
test("normalize") {
10+
val input = Seq[Double](1, 2, 3).zipWithIndex.map(_.swap)
11+
assert(normalize(input).map(_._2) == input.map(_._2).map(v => (v - 2) / math.sqrt(2.0 / 3)))
12+
}
13+
14+
}

0 commit comments

Comments
 (0)