
Commit 512a2f1

Lewuathe authored and mengxr committed
[SPARK-6615][MLLIB] Python API for Word2Vec
This is a sub-task of SPARK-6254: it wraps the missing methods for `Word2Vec` and `Word2VecModel`.

Author: lewuathe <[email protected]>

Closes apache#5296 from Lewuathe/SPARK-6615 and squashes the following commits:

f14c304 [lewuathe] Reorder tests
1d326b9 [lewuathe] Merge master
e2bedfb [lewuathe] Modify test cases
afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec
1 parent b52c7f9 commit 512a2f1
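
In practice, the new PySpark surface added here looks roughly like the sketch below. None of this code is part of the commit itself; it assumes a live SparkContext `sc` and uses made-up toy sentences:

    from pyspark.mllib.feature import Word2Vec

    # Each RDD element is one tokenized sentence.
    sentences = sc.parallelize([
        ["spark", "mllib", "word2vec"],
        ["spark", "python", "api"],
    ] * 10)

    model = Word2Vec() \
        .setVectorSize(10) \
        .setSeed(42) \
        .setMinCount(1) \
        .fit(sentences)           # setMinCount is new in this commit

    vectors = model.getVectors()  # getVectors is new in this commit
    print(len(vectors))           # vocabulary size after minCount filtering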

File tree

3 files changed (+64, -7 lines)


mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 7 additions & 1 deletion
@@ -476,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable {
       learningRate: Double,
       numPartitions: Int,
       numIterations: Int,
-      seed: Long): Word2VecModelWrapper = {
+      seed: Long,
+      minCount: Int): Word2VecModelWrapper = {
     val word2vec = new Word2Vec()
       .setVectorSize(vectorSize)
       .setLearningRate(learningRate)
       .setNumPartitions(numPartitions)
       .setNumIterations(numIterations)
       .setSeed(seed)
+      .setMinCount(minCount)
     try {
       val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
       new Word2VecModelWrapper(model)
@@ -516,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable {
       val words = result.map(_._1)
       List(words, similarity).map(_.asInstanceOf[Object]).asJava
     }
+
+    def getVectors: JMap[String, JList[Float]] = {
+      model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
+    }
   }

   /**
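
Note that the map this Scala `getVectors` returns crosses the Py4J bridge as a JavaMap of word -> JavaList[Float], not a plain Python dict. A hedged sketch of copying it into native Python structures (the helper name `to_py_dict` is ours, not part of the commit; Py4J's JavaMap iterates over its keys like a mapping):

    def to_py_dict(java_map):
        """Copy a Py4J JavaMap of word -> JavaList[Float] into a dict of lists."""
        return dict((word, list(java_map[word])) for word in java_map)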

python/pyspark/mllib/feature.py

Lines changed: 17 additions & 1 deletion
@@ -337,6 +337,12 @@ def findSynonyms(self, word, num):
         words, similarity = self.call("findSynonyms", word, num)
         return zip(words, similarity)

+    def getVectors(self):
+        """
+        Returns a map of words to their vector representations.
+        """
+        return self.call("getVectors")
+

 class Word2Vec(object):
     """
@@ -379,6 +385,7 @@ def __init__(self):
         self.numPartitions = 1
         self.numIterations = 1
         self.seed = random.randint(0, sys.maxint)
+        self.minCount = 5

     def setVectorSize(self, vectorSize):
         """
@@ -417,6 +424,14 @@ def setSeed(self, seed):
         self.seed = seed
         return self

+    def setMinCount(self, minCount):
+        """
+        Sets minCount, the minimum number of times a token must appear
+        to be included in the word2vec model's vocabulary (default: 5).
+        """
+        self.minCount = minCount
+        return self
+
     def fit(self, data):
         """
         Computes the vector representation of each word in vocabulary.
@@ -428,7 +443,8 @@ def fit(self, data):
             raise TypeError("data should be an RDD of list of string")
         jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
                                float(self.learningRate), int(self.numPartitions),
-                               int(self.numIterations), long(self.seed))
+                               int(self.numIterations), long(self.seed),
+                               int(self.minCount))
         return Word2VecModel(jmodel)

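The practical effect of `minCount`: with the default of 5, any token seen fewer than five times in the corpus is dropped from the vocabulary and receives no vector. A small sketch, again assuming an active SparkContext `sc` and an invented corpus (not from the commit):

    from pyspark.mllib.feature import Word2Vec

    corpus = sc.parallelize([["hello", "world"]] * 6 + [["rare", "token"]])

    # "hello" and "world" occur 6 times each; "rare" and "token" once each.
    frequent = Word2Vec().fit(corpus)                  # default minCount=5
    everything = Word2Vec().setMinCount(1).fit(corpus)

    print(len(frequent.getVectors()))    # 2 -- only "hello" and "world"
    print(len(everything.getVectors()))  # 4 -- all tokens kept
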
python/pyspark/mllib/tests.py

Lines changed: 40 additions & 5 deletions
@@ -42,6 +42,7 @@
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import SQLContext
@@ -630,6 +631,12 @@ def test_right_number_of_results(self):
         self.assertIsNotNone(chi[1000])


+class SerDeTest(PySparkTestCase):
+    def test_to_java_object_rdd(self):  # SPARK-6660
+        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
+        self.assertEqual(_to_java_object_rdd(data).count(), 10)
+
+
 class FeatureTest(PySparkTestCase):
     def test_idf_model(self):
         data = [
@@ -643,11 +650,39 @@ def test_idf_model(self):
         self.assertEqual(len(idf), 11)


-class SerDeTest(PySparkTestCase):
-    def test_to_java_object_rdd(self):  # SPARK-6660
-        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
-        self.assertEqual(_to_java_object_rdd(data).count(), 10)
-
+class Word2VecTests(PySparkTestCase):
+    def test_word2vec_setters(self):
+        data = [
+            ["I", "have", "a", "pen"],
+            ["I", "like", "soccer", "very", "much"],
+            ["I", "live", "in", "Tokyo"]
+        ]
+        model = Word2Vec() \
+            .setVectorSize(2) \
+            .setLearningRate(0.01) \
+            .setNumPartitions(2) \
+            .setNumIterations(10) \
+            .setSeed(1024) \
+            .setMinCount(3)
+        self.assertEquals(model.vectorSize, 2)
+        self.assertTrue(model.learningRate < 0.02)
+        self.assertEquals(model.numPartitions, 2)
+        self.assertEquals(model.numIterations, 10)
+        self.assertEquals(model.seed, 1024)
+        self.assertEquals(model.minCount, 3)
+
+    def test_word2vec_get_vectors(self):
+        data = [
+            ["a", "b", "c", "d", "e", "f", "g"],
+            ["a", "b", "c", "d", "e", "f"],
+            ["a", "b", "c", "d", "e"],
+            ["a", "b", "c", "d"],
+            ["a", "b", "c"],
+            ["a", "b"],
+            ["a"]
+        ]
+        model = Word2Vec().fit(self.sc.parallelize(data))
+        self.assertEquals(len(model.getVectors()), 3)

 if __name__ == "__main__":
     if not _have_scipy:
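
The expected count of 3 in test_word2vec_get_vectors follows from the default minCount of 5: in the triangular test data, "a" occurs 7 times, "b" 6, "c" 5, "d" 4, "e" 3, "f" 2, and "g" once, so only "a", "b", and "c" enter the vocabulary. A quick standalone check (illustrative only, not part of the test suite):

    from collections import Counter

    data = [["a", "b", "c", "d", "e", "f", "g"],
            ["a", "b", "c", "d", "e", "f"],
            ["a", "b", "c", "d", "e"],
            ["a", "b", "c", "d"],
            ["a", "b", "c"],
            ["a", "b"],
            ["a"]]
    counts = Counter(tok for sentence in data for tok in sentence)
    print(sorted(w for w, n in counts.items() if n >= 5))  # ['a', 'b', 'c']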
