
Commit 5bbcd13

azagrebin authored and rxin committed
[SPARK-6117] [SQL] add describe function to DataFrame for summary statistics
Please review my solution for SPARK-6117

Author: azagrebin <[email protected]>

Closes apache#5073 from azagrebin/SPARK-6117 and squashes the following commits:

f9056ac [azagrebin] [SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case
ddb3950 [azagrebin] [SPARK-6117] [SQL] simplify implementation, add test for DF without numeric columns
9daf31e [azagrebin] [SPARK-6117] [SQL] add describe function to DataFrame for summary statistics
1 parent: f535802 · commit: 5bbcd13

File tree

2 files changed: 97 additions & 1 deletion


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 52 additions & 1 deletion
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
 import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils

@@ -751,6 +751,57 @@ class DataFrame private[sql](
     select(colNames :_*)
   }

+  /**
+   * Compute numerical statistics for given columns of this [[DataFrame]]:
+   * count, mean (avg), stddev (standard deviation), min, max.
+   * Each row of the resulting [[DataFrame]] contains column with statistic name
+   * and columns with statistic results for each given column.
+   * If no columns are given then computes for all numerical columns.
+   *
+   * {{{
+   *   df.describe("age", "height")
+   *
+   *   // summary age   height
+   *   // count   10.0  10.0
+   *   // mean    53.3  178.05
+   *   // stddev  11.6  15.7
+   *   // min     18.0  163.0
+   *   // max     92.0  192.0
+   * }}}
+   */
+  @scala.annotation.varargs
+  def describe(cols: String*): DataFrame = {
+
+    def stddevExpr(expr: Expression) =
+      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
+
+    val statistics = List[(String, Expression => Expression)](
+      "count" -> Count,
+      "mean" -> Average,
+      "stddev" -> stddevExpr,
+      "min" -> Min,
+      "max" -> Max)
+
+    val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+
+    val localAgg = if (aggCols.nonEmpty) {
+      val aggExprs = statistics.flatMap { case (_, colToAgg) =>
+        aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+      }
+
+      agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+        .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
+          Row(statistic :: aggregation.toList: _*)
+        }
+    } else {
+      statistics.map { case (name, _) => Row(name) }
+    }
+
+    val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
+    val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
+    sqlContext.createDataFrame(rowRdd, schema)
+  }
+
   /**
    * Returns the first `n` rows.
   * @group action
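
A note on stddevExpr above: instead of relying on a dedicated standard-deviation aggregate, the patch composes Catalyst's Average, Multiply, Subtract and Sqrt expressions into sqrt(E[X^2] - E[X]^2), i.e. the population standard deviation (divide by n, not n - 1). A minimal local sketch of the same formula on a plain Scala collection, purely for illustration and not part of the patch (the helper name is made up):

  // Population standard deviation via E[X^2] - E[X]^2, mirroring stddevExpr above.
  def populationStddev(xs: Seq[Double]): Double = {
    val mean          = xs.sum / xs.size                   // E[X]
    val meanOfSquares = xs.map(x => x * x).sum / xs.size   // E[X^2]
    math.sqrt(meanOfSquares - mean * mean)
  }

  populationStddev(Seq(16.0, 32.0, 60.0, 24.0))     // ≈ 16.583123951777, the "age" value expected in the test below
  populationStddev(Seq(176.0, 164.0, 192.0, 180.0)) // == 10.0, the "height" value expected in the test below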

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 45 additions & 0 deletions
@@ -443,6 +443,51 @@ class DataFrameSuite extends QueryTest {
     assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
   }

+  test("describe") {
+
+    val describeTestData = Seq(
+      ("Bob",   16, 176),
+      ("Alice", 32, 164),
+      ("David", 60, 192),
+      ("Amy",   24, 180)).toDF("name", "age", "height")
+
+    val describeResult = Seq(
+      Row("count", 4, 4),
+      Row("mean", 33.0, 178.0),
+      Row("stddev", 16.583123951777, 10.0),
+      Row("min", 16, 164),
+      Row("max", 60, 192))
+
+    val emptyDescribeResult = Seq(
+      Row("count", 0, 0),
+      Row("mean", null, null),
+      Row("stddev", null, null),
+      Row("min", null, null),
+      Row("max", null, null))
+
+    def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+
+    val describeTwoCols = describeTestData.describe("age", "height")
+    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
+    checkAnswer(describeTwoCols, describeResult)
+
+    val describeAllCols = describeTestData.describe()
+    assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
+    checkAnswer(describeAllCols, describeResult)
+
+    val describeOneCol = describeTestData.describe("age")
+    assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
+    checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
+
+    val describeNoCol = describeTestData.select("name").describe()
+    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )
+
+    val emptyDescription = describeTestData.limit(0).describe()
+    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
+    checkAnswer(emptyDescription, emptyDescribeResult)
+  }
+
   test("apply on query results (SPARK-5462)") {
     val df = testData.sqlContext.sql("select key from testData")
     checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
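
For reference, the expected values in describeResult follow directly from the four test rows: the ages (16, 32, 60, 24) have mean 33.0 and population variance E[X^2] - E[X]^2 = 1364 - 1089 = 275, so stddev = sqrt(275) ≈ 16.583123951777, while the heights (176, 164, 192, 180) have mean 178.0 and variance 31784 - 31684 = 100, so stddev = 10.0. A rough interactive check is sketched below, assuming a Spark 1.3-era shell with this patch applied and a SQLContext named sqlContext in scope; the variable names are illustrative, not part of the patch:

  import sqlContext.implicits._  // brings toDF on local Seqs into scope, as in the test suite

  val people = Seq(
    ("Bob",   16, 176),
    ("Alice", 32, 164),
    ("David", 60, 192),
    ("Amy",   24, 180)).toDF("name", "age", "height")

  // One row per statistic, one column per described column ("summary", "age", "height").
  people.describe("age", "height").show()

  // With no numeric columns available, only the "summary" column remains.
  people.select("name").describe().show()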
