Commit 2a24c48

lu-wang-dl authored and jkbradley committed
[SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features
## What changes were proposed in this pull request?

- Multiple possible input types are accepted in validateAndTransformSchema() and computeCost() when checking the column type.
- Add an if statement in transform() to support array type as featuresCol.
- Add a case statement in fit() when selecting columns from the dataset.

These changes are applied to KMeans first, then to the other clustering methods.

## How was this patch tested?

A unit test is added.

Author: Lu WANG <[email protected]>

Closes apache#21081 from ludatabricks/SPARK-23975.
1 parent 55c4ca8 commit 2a24c48
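
For context, here is a minimal sketch of what the change enables from the user's side: fitting KMeans directly on a DataFrame whose features column is an Array[Double] (or Array[Float]) instead of a Vector. The data, column name, and variable names below are illustrative, not taken from the patch; `spark` is assumed to be an active SparkSession.

```scala
import org.apache.spark.ml.clustering.KMeans

// Illustrative data: "features" is an Array[Double] column, not a Vector.
val df = spark.createDataFrame(Seq(
  Tuple1(Array(0.0, 0.1)),
  Tuple1(Array(0.2, 0.0)),
  Tuple1(Array(9.0, 8.5)),
  Tuple1(Array(8.7, 9.2))
)).toDF("features")

// With this patch the array column is accepted as-is; previously it had to
// be converted to a Vector column first.
val model = new KMeans().setK(2).setSeed(1L).setFeaturesCol("features").fit(df)
model.transform(df).show()
```

Internally the array column is routed through the new DatasetUtils.columnToVector helper added by this commit (shown below).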

File tree

3 files changed: +126 −7 lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 25 additions & 7 deletions
```diff
@@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils.majorVersion

@@ -86,13 +86,24 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitSteps: Int = $(initSteps)

+  /**
+   * Validates the input schema.
+   * @param schema input schema
+   */
+  private[clustering] def validateSchema(schema: StructType): Unit = {
+    val typeCandidates = List( new VectorUDT,
+      new ArrayType(DoubleType, false),
+      new ArrayType(FloatType, false))
+
+    SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
+  }
   /**
    * Validates and transforms the input schema.
    * @param schema input schema
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    validateSchema(schema)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
 }
```
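
The new validateSchema helper accepts exactly three column types: VectorUDT, ArrayType(DoubleType, false), and ArrayType(FloatType, false). A hedged, test-style sketch of that behavior follows (hypothetical data; it assumes SchemaUtils.checkColumnTypes signals a mismatch with an IllegalArgumentException, as a `require` does, and uses ScalaTest's `intercept`):

```scala
// Hypothetical: a string-array features column should be rejected up front,
// while Array[Double] / Array[Float] / Vector columns pass validation.
val bad = spark.createDataFrame(Seq(Tuple1(Array("a", "b")))).toDF("features")
val err = intercept[IllegalArgumentException] {
  new KMeans().setK(2).fit(bad)
}
// The failure message should name the offending column.
assert(err.getMessage.contains("features"))
```

Note that the candidates use containsNull = false, so an array column whose elements are nullable would not match either.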
```diff
@@ -125,8 +136,11 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+
     val predictUDF = udf((vector: Vector) => predict(vector))
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+
+    dataset.withColumn($(predictionCol),
+      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
   }

   @Since("1.5.0")
@@ -146,8 +160,10 @@ class KMeansModel private[ml] (
   // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
   @Since("2.0.0")
   def computeCost(dataset: Dataset[_]): Double = {
-    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
-    val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+    validateSchema(dataset.schema)
+
+    val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol))
+      .rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
     parentModel.computeCost(data)
@@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") (
     transformSchema(dataset.schema, logging = true)

     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+    val instances: RDD[OldVector] = dataset.select(
+      DatasetUtils.columnToVector(dataset, getFeaturesCol))
+      .rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
```
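
All three consumers of the features column (transform, computeCost, and the fit path) now go through DatasetUtils.columnToVector rather than col($(featuresCol)), so the downstream RDD[OldVector] code is untouched. Continuing the hypothetical df and model from the first sketch, the temporary cost evaluator works on array input as well:

```scala
// computeCost validates the (array-typed) schema, converts the column to
// vectors, and delegates to the old mllib model, same as for Vector input.
val cost: Double = model.computeCost(df)
println(s"Within Set Sum of Squared Errors = $cost")
```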

mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala

Lines changed: 63 additions & 0 deletions
```diff
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
+import org.apache.spark.sql.{Column, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
+
+
+private[spark] object DatasetUtils {
+
+  /**
+   * Cast a column in a Dataset to Vector type.
+   *
+   * The supported data types of the input column are
+   * - Vector
+   * - float/double type Array.
+   *
+   * Note: The returned column does not have Metadata.
+   *
+   * @param dataset input DataFrame
+   * @param colName column name.
+   * @return Vector column
+   */
+  def columnToVector(dataset: Dataset[_], colName: String): Column = {
+    val columnDataType = dataset.schema(colName).dataType
+    columnDataType match {
+      case _: VectorUDT => col(colName)
+      case fdt: ArrayType =>
+        val transferUDF = fdt.elementType match {
+          case _: FloatType => udf(f = (vector: Seq[Float]) => {
+            val inputArray = Array.fill[Double](vector.size)(0.0)
+            vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble)
+            Vectors.dense(inputArray)
+          })
+          case _: DoubleType => udf((vector: Seq[Double]) => {
+            Vectors.dense(vector.toArray)
+          })
+          case other =>
+            throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector")
+        }
+        transferUDF(col(colName))
+      case other =>
+        throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
+    }
+  }
+}
```
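
A small usage sketch for the new helper with hypothetical data. Since columnToVector is private[spark], this only compiles from code inside the org.apache.spark namespace (e.g. Spark's own tests); the column names are made up:

```scala
import org.apache.spark.ml.util.DatasetUtils

// Hypothetical Array[Float] input column, exercising the float branch.
val arrayDF = spark.createDataFrame(Seq(
  Tuple1(Array(1.0f, 2.0f, 3.0f))
)).toDF("feats")

// Returns a Column expression producing a dense Vector per row; note the
// scaladoc caveat that the result carries no metadata.
val vecCol = DatasetUtils.columnToVector(arrayDF, "feats")
arrayDF.select(vecCol.as("features")).show(truncate = false)
```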

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 38 additions & 0 deletions
```diff
@@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
 import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}

 private[clustering] case class TestRow(features: Vector)

@@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
   }

+  test("KMean with Array input") {
+    val featuresColNameD = "array_double_features"
+    val featuresColNameF = "array_float_features"
+
+    val doubleUDF = udf { (features: Vector) =>
+      val featureArray = Array.fill[Double](features.size)(0.0)
+      features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
+      featureArray
+    }
+    val floatUDF = udf { (features: Vector) =>
+      val featureArray = Array.fill[Float](features.size)(0.0f)
+      features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
+      featureArray
+    }
+
+    val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features")))
+      .drop("features")
+    val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features")))
+      .drop("features")
+    assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false)))
+    assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false)))
+
+    val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1)
+    val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1)
+    val modelD = kmeansD.fit(newdatasetD)
+    val modelF = kmeansF.fit(newdatasetF)
+    val transformedD = modelD.transform(newdatasetD)
+    val transformedF = modelF.transform(newdatasetF)
+
+    val predictDifference = transformedD.select("prediction")
+      .except(transformedF.select("prediction"))
+    assert(predictDifference.count() == 0)
+    assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) )
+  }
+
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)
```
