Skip to content

Commit a780703

Browse files
committed
[SPARK-42702][SPARK-42623][SQL] Support parameterized query in subquery and CTE
### What changes were proposed in this pull request? This PR fixes a few issues of parameterized query: 1. replace placeholders in CTE/subqueries 2. don't replace placeholders in non-DML commands as it may store the original SQL text with placeholders and we can't resolve it later (e.g. CREATE VIEW). ### Why are the changes needed? make the parameterized query feature complete ### Does this PR introduce _any_ user-facing change? yes, bug fix ### How was this patch tested? new tests Closes apache#40333 from cloud-fan/parameter. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 9c2135f commit a780703

File tree

11 files changed

+247
-86
lines changed

11 files changed

+247
-86
lines changed

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
3232
import org.apache.spark.connect.proto.Parse.ParseFormat
3333
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
3434
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
35-
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
35+
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
3636
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
3737
import org.apache.spark.sql.catalyst.expressions._
3838
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -209,8 +209,12 @@ class SparkConnectPlanner(val session: SparkSession) {
209209
private def transformSql(sql: proto.SQL): LogicalPlan = {
210210
val args = sql.getArgsMap.asScala.toMap
211211
val parser = session.sessionState.sqlParser
212-
val parsedArgs = args.mapValues(parser.parseExpression).toMap
213-
Parameter.bind(parser.parsePlan(sql.getQuery), parsedArgs)
212+
val parsedPlan = parser.parsePlan(sql.getQuery)
213+
if (args.nonEmpty) {
214+
ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
215+
} else {
216+
parsedPlan
217+
}
214218
}
215219

216220
private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {

core/src/main/resources/error/error-classes.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,11 @@
17301730
"Pandas user defined aggregate function in the PIVOT clause."
17311731
]
17321732
},
1733+
"PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT" : {
1734+
"message" : [
1735+
"Parameter markers in unexpected statement: <statement>. Parameter markers must only be used in a query, or DML statement."
1736+
]
1737+
},
17331738
"PIVOT_AFTER_GROUP_BY" : {
17341739
"message" : [
17351740
"PIVOT clause following a GROUP BY clause. Consider pushing the GROUP BY into a subquery."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
265265
// at the beginning of analysis.
266266
OptimizeUpdateFields,
267267
CTESubstitution,
268+
BindParameters,
268269
WindowsSubstitution,
269270
EliminateUnions,
270271
SubstituteUnresolvedOrdinals),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
336336
case p: Parameter =>
337337
p.failAnalysis(
338338
errorClass = "UNBOUND_SQL_PARAMETER",
339-
messageParameters = Map("name" -> toSQLId(p.name)))
339+
messageParameters = Map("name" -> p.name))
340340

341341
case _ =>
342342
})
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.SparkException
21+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
22+
import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, InsertIntoStatement, LogicalPlan, MergeIntoTable, UnaryNode, UpdateTable}
23+
import org.apache.spark.sql.catalyst.rules.Rule
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
25+
import org.apache.spark.sql.errors.QueryErrorsBase
26+
import org.apache.spark.sql.types.DataType
27+
28+
/**
29+
* The expression represents a named parameter that should be replaced by a literal.
30+
*
31+
* @param name The identifier of the parameter without the marker.
32+
*/
33+
case class Parameter(name: String) extends LeafExpression with Unevaluable {
34+
override lazy val resolved: Boolean = false
35+
36+
private def unboundError(methodName: String): Nothing = {
37+
throw SparkException.internalError(
38+
s"Cannot call `$methodName()` of the unbound parameter `$name`.")
39+
}
40+
override def dataType: DataType = unboundError("dataType")
41+
override def nullable: Boolean = unboundError("nullable")
42+
43+
final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
44+
}
45+
46+
/**
47+
* The logical plan representing a parameterized query. It will be removed during analysis after
48+
* the parameters are bind.
49+
*/
50+
case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) extends UnaryNode {
51+
assert(args.nonEmpty)
52+
override def output: Seq[Attribute] = Nil
53+
override lazy val resolved = false
54+
final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
55+
override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
56+
copy(child = newChild)
57+
}
58+
59+
/**
60+
* Finds all named parameters in `ParameterizedQuery` and substitutes them by literals from the
61+
* user-specified arguments.
62+
*/
63+
object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
64+
override def apply(plan: LogicalPlan): LogicalPlan = {
65+
if (plan.containsPattern(PARAMETERIZED_QUERY)) {
66+
// One unresolved plan can have at most one ParameterizedQuery.
67+
val parameterizedQueries = plan.collect { case p: ParameterizedQuery => p }
68+
assert(parameterizedQueries.length == 1)
69+
}
70+
71+
plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
72+
// We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
73+
// relations are not children of `UnresolvedWith`.
74+
case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
75+
// Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN, etc.
76+
// We can't store the original SQL text with parameters, as we don't store the arguments and
77+
// are not able to resolve it after parsing it back. Since parameterized query is mostly
78+
// used to avoid SQL injection for SELECT queries, we simply forbid non-DML commands here.
79+
child match {
80+
case _: InsertIntoStatement => // OK
81+
case _: UpdateTable => // OK
82+
case _: DeleteFromTable => // OK
83+
case _: MergeIntoTable => // OK
84+
case cmd: Command =>
85+
child.failAnalysis(
86+
errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
87+
messageParameters = Map("statement" -> cmd.nodeName)
88+
)
89+
case _ => // OK
90+
}
91+
92+
args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
93+
expr.failAnalysis(
94+
errorClass = "INVALID_SQL_ARG",
95+
messageParameters = Map("name" -> name))
96+
}
97+
98+
def bind(p: LogicalPlan): LogicalPlan = {
99+
p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
100+
case Parameter(name) if args.contains(name) =>
101+
args(name)
102+
case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
103+
}
104+
}
105+
val res = bind(child)
106+
res.copyTagsFrom(p)
107+
res
108+
109+
case _ => plan
110+
}
111+
}
112+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala

Lines changed: 0 additions & 64 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ object TreePattern extends Enumeration {
7373
val OR: Value = Value
7474
val OUTER_REFERENCE: Value = Value
7575
val PARAMETER: Value = Value
76+
val PARAMETERIZED_QUERY: Value = Value
7677
val PIVOT: Value = Value
7778
val PLAN_EXPRESSION: Value = Value
7879
val PYTHON_UDF: Value = Value

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,17 +1346,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
13461346
}
13471347

13481348
test("SPARK-41271: bind named parameters to literals") {
1349-
comparePlans(
1350-
Parameter.bind(
1351-
plan = parsePlan("SELECT * FROM a LIMIT :limitA"),
1352-
args = Map("limitA" -> Literal(10))),
1353-
parsePlan("SELECT * FROM a LIMIT 10"))
1349+
CTERelationDef.curId.set(0)
1350+
val actual1 = ParameterizedQuery(
1351+
child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
1352+
args = Map("limitA" -> Literal(10))).analyze
1353+
CTERelationDef.curId.set(0)
1354+
val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
1355+
comparePlans(actual1, expected1)
13541356
// Ignore unused arguments
1355-
comparePlans(
1356-
Parameter.bind(
1357-
plan = parsePlan("SELECT c FROM a WHERE c < :param2"),
1358-
args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
1359-
parsePlan("SELECT c FROM a WHERE c < 20"))
1357+
CTERelationDef.curId.set(0)
1358+
val actual2 = ParameterizedQuery(
1359+
child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
1360+
args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
1361+
CTERelationDef.curId.set(0)
1362+
val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
1363+
comparePlans(actual2, expected2)
13601364
}
13611365

13621366
test("SPARK-41489: type of filter expression should be a bool") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
1919

2020
import org.apache.spark.SparkThrowable
2121
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
22-
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
22+
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
2525
import org.apache.spark.sql.catalyst.plans._

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD
3535
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
3636
import org.apache.spark.sql.catalog.Catalog
3737
import org.apache.spark.sql.catalyst._
38-
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
38+
import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
3939
import org.apache.spark.sql.catalyst.encoders._
40-
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
40+
import org.apache.spark.sql.catalyst.expressions.AttributeReference
4141
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
4242
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
4343
import org.apache.spark.sql.connector.ExternalCommandRunner
@@ -623,8 +623,12 @@ class SparkSession private(
623623
val tracker = new QueryPlanningTracker
624624
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
625625
val parser = sessionState.sqlParser
626-
val parsedArgs = args.mapValues(parser.parseExpression).toMap
627-
Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
626+
val parsedPlan = parser.parsePlan(sqlText)
627+
if (args.nonEmpty) {
628+
ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
629+
} else {
630+
parsedPlan
631+
}
628632
}
629633
Dataset.ofRows(self, plan, tracker)
630634
}

0 commit comments

Comments
 (0)