-
Notifications
You must be signed in to change notification settings - Fork 28.7k
[SPARK-52588][SQL] Approx_top_k: accumulate and estimate #51308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
yhuang-db
wants to merge
21
commits into
apache:master
Choose a base branch
from
yhuang-db:SPARK-52588
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+483
−45
Open
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
bbd8ae4
init ApproxTopKAccumulate
yhuang-db 9015bbb
init ApproxTopKCombine, combineSizeSpecified undone
yhuang-db f369de7
init ApproxTopKCombine, combineSizeSpecified undone
yhuang-db 1f319cf
init ApproxTopKEstimate
yhuang-db c09fd83
init accumulate and estimate tests
yhuang-db 612f8d8
unfinished estimate null check
yhuang-db 03e432d
fix estimate null check
yhuang-db 88cae20
estimate and accumulate invalid parameter test
yhuang-db ccfc661
remove combine for PR
yhuang-db d89c121
remove combine for PR
yhuang-db b95ff7a
separate expression suite and query suite
yhuang-db edaf18e
add expression doc
yhuang-db b3811a7
add accumulation doc
yhuang-db 7e1e519
nit doc
yhuang-db a9153aa
update expression type check test
yhuang-db ae6cc81
remove k and max type check test
yhuang-db d60702d
add upper limit test for accumulate
yhuang-db b616e7c
add invalid value tests
yhuang-db fa8d569
fix ApproxTopKAccumulate doc
yhuang-db 2f69184
Merge branch 'master' into SPARK-52588
yhuang-db 2d24c68
fix sql test
yhuang-db File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
...lyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.datasketches.frequencies.ItemsSketch | ||
import org.apache.datasketches.memory.Memory | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry | ||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK | ||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback | ||
import org.apache.spark.sql.types._ | ||
|
||
/**
 * An expression that estimates the top K items from a sketch.
 *
 * The input is a sketch state that is generated by the ApproxTopKAccumulation function.
 * The output is an array of structs, each containing a frequent item and its estimated frequency.
 * The items are sorted by their estimated frequency in descending order.
 *
 * @param state The sketch state, which is a struct containing the serialized sketch data,
 *              the original data type and the max items tracked of the sketch.
 * @param k The number of top items to estimate.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(state, k) - Returns top k items with their frequency.
      `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(approx_top_k_accumulate(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr);
       [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]

      > SELECT _FUNC_(approx_top_k_accumulate(expr), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' tab(expr);
       [{"item":"c","count":4},{"item":"d","count":2}]
  """,
  group = "misc_funcs",
  since = "4.1.0")
// scalastyle:on line.size.limit
case class ApproxTopKEstimate(state: Expression, k: Expression)
  extends BinaryExpression
  with CodegenFallback
  with ImplicitCastInputTypes {

  def this(child: Expression, topK: Int) = this(child, Literal(topK))

  def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_K))

  // The item type is recorded by ACCUMULATE/COMBINE in the "ItemTypeNull" field of the
  // state struct; only the field's declared type matters, its value is always null.
  private lazy val itemDataType: DataType = {
    state.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
  }

  override def left: Expression = state

  override def right: Expression = k

  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)

  // Returns true iff `dt` has the exact struct layout produced by approx_top_k_accumulate:
  // (Sketch: BINARY, ItemTypeNull: <item type>, MaxItemsTracked: INT).
  private def isValidStateStruct(dt: DataType): Boolean = dt match {
    case st: StructType =>
      st.length == 3 &&
        st(0).name == "Sketch" && st(0).dataType == BinaryType &&
        st(1).name == "ItemTypeNull" &&
        st(2).name == "MaxItemsTracked" && st(2).dataType == IntegerType
    case _ => false
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) {
      defaultCheck
    } else if (!isValidStateStruct(state.dataType)) {
      // Fail at analysis time with a clear message instead of a field-lookup or
      // ClassCastException at eval time when an arbitrary struct is passed as state.
      TypeCheckFailure(
        "State must be a struct of (Sketch: BINARY, ItemTypeNull, MaxItemsTracked: INT)")
    } else if (!k.foldable) {
      TypeCheckFailure("K must be a constant literal")
    } else {
      TypeCheckSuccess
    }
  }

  override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)

  override def eval(input: InternalRow): Any = {
    // k must be a non-null literal; a foldable expression may still evaluate to null,
    // so this is checked at runtime.
    ApproxTopK.checkExpressionNotNull(k, "k")
    val stateEval = left.eval(input)
    val kEval = right.eval(input)
    // State layout: (0) serialized sketch bytes, (1) item-type marker, (2) maxItemsTracked.
    val dataSketchBytes = stateEval.asInstanceOf[InternalRow].getBinary(0)
    val maxItemsTrackedVal = stateEval.asInstanceOf[InternalRow].getInt(2)
    val kVal = kEval.asInstanceOf[Int]
    ApproxTopK.checkK(kVal)
    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
    val itemsSketch = ItemsSketch.getInstance(
      Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
    ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType)
  }

  override protected def withNewChildrenInternal(newState: Expression, newK: Expression)
    : Expression = copy(state = newState, k = newK)

  // The output array is always produced (possibly empty); the expression never returns null.
  override def nullable: Boolean = false

  override def prettyName: String =
    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_estimate")
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} | ||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} | ||
import org.apache.spark.sql.catalyst.expressions.{ArrayOfDecimalsSerDe, Expression, ExpressionDescription, ImplicitCastInputTypes, Literal} | ||
import org.apache.spark.sql.catalyst.trees.TernaryLike | ||
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike} | ||
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData} | ||
import org.apache.spark.sql.errors.QueryExecutionErrors | ||
import org.apache.spark.sql.types._ | ||
|
@@ -53,8 +53,8 @@ import org.apache.spark.unsafe.types.UTF8String | |
usage = """ | ||
_FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency. | ||
`k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5. | ||
`maxItemsTracked` An optional INTEGER literal greater than or equal to k. If maxItemsTracked is not specified, it defaults to 10000. | ||
""", | ||
`maxItemsTracked` An optional INTEGER literal greater than or equal to k and has upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); | ||
|
@@ -173,40 +173,47 @@ case class ApproxTopK( | |
|
||
object ApproxTopK { | ||
|
||
private val DEFAULT_K: Int = 5 | ||
private val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 | ||
val DEFAULT_K: Int = 5 | ||
val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 | ||
private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000 | ||
|
||
// Rejects a missing expression or a (foldable) expression that evaluates to null.
def checkExpressionNotNull(expr: Expression, exprName: String): Unit = {
  val evaluatesToNull = expr == null || expr.eval() == null
  if (evaluatesToNull) {
    throw QueryExecutionErrors.approxTopKNullArg(exprName)
  }
}
|
||
// k must be strictly positive (it is a count of items to return).
def checkK(k: Int): Unit = {
  if (k < 1) {
    throw QueryExecutionErrors.approxTopKNonPositiveValue("k", k)
  }
}
|
||
// Validates maxItemsTracked alone: it must lie in (0, MAX_ITEMS_TRACKED_LIMIT].
// The upper-bound check runs first, preserving the original error precedence.
def checkMaxItemsTracked(maxItemsTracked: Int): Unit = {
  if (maxItemsTracked > MAX_ITEMS_TRACKED_LIMIT) {
    throw QueryExecutionErrors.approxTopKMaxItemsTrackedExceedsLimit(
      maxItemsTracked, MAX_ITEMS_TRACKED_LIMIT)
  }
  if (maxItemsTracked <= 0) {
    throw QueryExecutionErrors.approxTopKNonPositiveValue("maxItemsTracked", maxItemsTracked)
  }
}
|
||
// Validates maxItemsTracked against k: it must pass the standalone range check
// and also be large enough to hold at least k items.
def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = {
  checkMaxItemsTracked(maxItemsTracked)
  if (maxItemsTracked < k) {
    throw QueryExecutionErrors.approxTopKMaxItemsTrackedLessThanK(maxItemsTracked, k)
  }
}
|
||
// Result schema: ARRAY<STRUCT<item: <itemDataType>, count: LONG>> with no null
// entries and no null fields.
def getResultDataType(itemDataType: DataType): DataType = {
  val entryFields = Seq(
    StructField("item", itemDataType, nullable = false),
    StructField("count", LongType, nullable = false))
  ArrayType(StructType(entryFields), containsNull = false)
}
|
||
private def isDataTypeSupported(itemType: DataType): Boolean = { | ||
def isDataTypeSupported(itemType: DataType): Boolean = { | ||
itemType match { | ||
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | | ||
_: LongType | _: FloatType | _: DoubleType | _: DateType | | ||
|
@@ -216,13 +223,14 @@ object ApproxTopK { | |
} | ||
} | ||
|
||
// Computes the sketch's internal map size for a desired item capacity.
// The internal hash map only fills to maxMapCap = 0.75 * maxMapSize, so the map
// must hold at least ceil(maxItemsTracked / 0.75) slots, rounded up to a power of 2.
// https://datasketches.apache.org/docs/Frequency/FrequentItemsOverview.html
def calMaxMapSize(maxItemsTracked: Int): Int = {
  val requiredSlots = math.ceil(maxItemsTracked / 0.75).toInt
  // Round up to the next power of two, as required by the sketch implementation.
  math.pow(2, math.ceil(math.log(requiredSlots) / math.log(2))).toInt
}
|
||
def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = { | ||
|
@@ -242,7 +250,7 @@ object ApproxTopK { | |
} | ||
} | ||
|
||
private def updateSketchBuffer( | ||
def updateSketchBuffer( | ||
itemExpression: Expression, | ||
buffer: ItemsSketch[Any], | ||
input: InternalRow): ItemsSketch[Any] = { | ||
|
@@ -268,7 +276,7 @@ object ApproxTopK { | |
buffer | ||
} | ||
|
||
private def genEvalResult( | ||
def genEvalResult( | ||
itemsSketch: ItemsSketch[Any], | ||
k: Int, | ||
itemDataType: DataType): GenericArrayData = { | ||
|
@@ -290,7 +298,7 @@ object ApproxTopK { | |
new GenericArrayData(result) | ||
} | ||
|
||
private def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = { | ||
def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = { | ||
dataType match { | ||
case _: BooleanType => new ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]] | ||
case _: ByteType | _: ShortType | _: IntegerType | _: FloatType | _: DateType => | ||
|
@@ -305,4 +313,123 @@ object ApproxTopK { | |
new ArrayOfDecimalsSerDe(dt).asInstanceOf[ArrayOfItemsSerDe[Any]] | ||
} | ||
} | ||
|
||
// Sketch-state schema shared by ACCUMULATE/COMBINE/ESTIMATE:
//   Sketch          - the serialized sketch bytes (never null)
//   ItemTypeNull    - an always-null field whose declared type carries the item type
//   MaxItemsTracked - the capacity the sketch was built with (never null)
def getSketchStateDataType(itemDataType: DataType): StructType = {
  val fields = Seq(
    StructField("Sketch", BinaryType, nullable = false),
    StructField("ItemTypeNull", itemDataType),
    StructField("MaxItemsTracked", IntegerType, nullable = false))
  StructType(fields)
}
} | ||
|
||
/**
 * An aggregate function that accumulates items into a sketch, which can then be used
 * to combine with other sketches, via ApproxTopKCombine,
 * or to estimate the top K items, via ApproxTopKEstimate.
 *
 * The output of this function is a struct containing the sketch in binary format,
 * a null object indicating the type of items in the sketch,
 * and the maximum number of items tracked by the sketch.
 *
 * @param expr the child expression to accumulate items from
 * @param maxItemsTracked the maximum number of items to track in the sketch
 * @param mutableAggBufferOffset the offset of this function's first buffer slot in the
 *                               mutable aggregation buffer
 * @param inputAggBufferOffset the offset of this function's first buffer slot in the
 *                             input aggregation buffer
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch.
      `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
  """,
  examples = """
    Examples:
      > SELECT approx_top_k_estimate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr);
       [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
      > SELECT approx_top_k_estimate(_FUNC_(expr, 100), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);
       [{"item":"c","count":4},{"item":"d","count":2}]
  """,
  group = "agg_funcs",
  since = "4.1.0")
// scalastyle:on line.size.limit
case class ApproxTopKAccumulate(
    expr: Expression,
    maxItemsTracked: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[ItemsSketch[Any]]
  with ImplicitCastInputTypes
  with BinaryLike[Expression] {

  def this(child: Expression, maxItemsTracked: Expression) = this(child, maxItemsTracked, 0, 0)

  def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked), 0, 0)

  def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_MAX_ITEMS_TRACKED), 0, 0)

  // Type of the accumulated items; drives the sketch SerDe and the output schema.
  private lazy val itemDataType: DataType = expr.dataType

  // Evaluated once: maxItemsTracked must be a non-null literal within the allowed range.
  private lazy val maxItemsTrackedVal: Int = {
    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
    val tracked = maxItemsTracked.eval().asInstanceOf[Int]
    ApproxTopK.checkMaxItemsTracked(tracked)
    tracked
  }

  override def left: Expression = expr

  override def right: Expression = maxItemsTracked

  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)

  override def checkInputDataTypes(): TypeCheckResult =
    super.checkInputDataTypes() match {
      case failure if failure.isFailure => failure
      case _ if !ApproxTopK.isDataTypeSupported(itemDataType) =>
        TypeCheckFailure(f"${itemDataType.typeName} columns are not supported")
      case _ if !maxItemsTracked.foldable =>
        TypeCheckFailure("Number of items tracked must be a constant literal")
      case _ => TypeCheckSuccess
    }

  // Output is the sketch-state struct: (sketch bytes, item-type marker, maxItemsTracked).
  override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)

  override def createAggregationBuffer(): ItemsSketch[Any] = {
    // The sketch map size is derived from the requested capacity (next power of two).
    val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
  }

  override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] = {
    ApproxTopK.updateSketchBuffer(expr, buffer, input)
  }

  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] = {
    buffer.merge(input)
  }

  override def eval(buffer: ItemsSketch[Any]): Any = {
    // The middle field is a typed null marker: it carries only the item type, never a value.
    InternalRow.apply(serialize(buffer), null, maxItemsTrackedVal)
  }

  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] = {
    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
  }

  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] = {
    ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): Expression =
    copy(expr = newLeft, maxItemsTracked = newRight)

  // The state struct is always produced, so the aggregate result is never null.
  override def nullable: Boolean = false

  override def prettyName: String =
    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we also check the `StructType` of `state`?

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, let's add test for this.