Skip to content

Commit 35fa5e6

Browse files
committed
[SPARK-41271][SQL] Support parameterized SQL queries by sql()
### What changes were proposed in this pull request? In the PR, I propose to extend SparkSession API and override the `sql` method by: ```scala def sql(sqlText: String, args: Map[String, String]): DataFrame ``` which accepts a map with: - keys are parameters names, - values are SQL literal values. And the first argument `sqlText` might have named parameters in the positions of constants like literal values. For example: ```scala spark.sql( sqlText = "SELECT * FROM tbl WHERE date > :startDate LIMIT :maxRows", args = Map( "startDate" -> "DATE'2022-12-01'", "maxRows" -> "100")) ``` The new `sql()` method parses the input SQL statement and provided parameter values, and replaces the named parameters by the literal values. And then it eagerly runs DDL/DML commands, but not for SELECT queries. Closes apache#38712 ### Why are the changes needed? 1. To improve user experience with Spark SQL via - Using Spark as remote service (microservice). - Write SQL code that will power reports, dashboards, charts and other data presentation solutions that need to account for criteria modifiable by users through an interface. - Build a generic integration layer based on the SQL API. The goal is to expose managed data to a wide application ecosystem with a microservice architecture. It is only natural in such a setup to ask for modular and reusable SQL code, that can be executed repeatedly with different parameter values. 2. To achieve feature parity with other systems that support named parameters: - Redshift: https://docs.aws.amazon.com/redshift/latest/mgmt/data-api.html#data-api-calling - BigQuery: https://cloud.google.com/bigquery/docs/parameterized-queries#api - MS DBSQL: https://learn.microsoft.com/en-us/azure/databricks/sql/user/queries/query-parameters ### Does this PR introduce _any_ user-facing change? No, this is an extension of the existing APIs. ### How was this patch tested? 
By running new tests: ``` $ build/sbt "core/testOnly *SparkThrowableSuite" $ build/sbt "test:testOnly *PlanParserSuite" $ build/sbt "test:testOnly *AnalysisSuite" $ build/sbt "test:testOnly *ParametersSuite" ``` Closes apache#38864 from MaxGekk/parameterized-sql-2. Lead-authored-by: Max Gekk <[email protected]> Co-authored-by: Maxim Gekk <[email protected]> Signed-off-by: Max Gekk <[email protected]>
1 parent c25434d commit 35fa5e6

File tree

13 files changed

+246
-10
lines changed

13 files changed

+246
-10
lines changed

core/src/main/resources/error/error-classes.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,11 @@
806806
}
807807
}
808808
},
809+
"INVALID_SQL_ARG" : {
810+
"message" : [
811+
"The argument <name> of `sql()` is invalid. Consider to replace it by a SQL literal statement."
812+
]
813+
},
809814
"INVALID_SQL_SYNTAX" : {
810815
"message" : [
811816
"Invalid SQL syntax: <inputString>"
@@ -1147,6 +1152,11 @@
11471152
"Unable to convert SQL type <toType> to Protobuf type <protobufType>."
11481153
]
11491154
},
1155+
"UNBOUND_SQL_PARAMETER" : {
1156+
"message" : [
1157+
"Found the unbound parameter: <name>. Please, fix `args` and provide a mapping of the parameter to a SQL literal statement."
1158+
]
1159+
},
11501160
"UNCLOSED_BRACKETED_COMMENT" : {
11511161
"message" : [
11521162
"Found an unclosed bracketed comment. Please, append */ at the end of the comment."

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ primaryExpression
930930

931931
constant
932932
: NULL #nullLiteral
933+
| COLON identifier #parameterLiteral
933934
| interval #intervalLiteral
934935
| identifier stringLit #typeConstructor
935936
| number #numericLiteral

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
325325
errorClass = "_LEGACY_ERROR_TEMP_2413",
326326
messageParameters = Map("argName" -> e.prettyName))
327327

328+
case p: Parameter =>
329+
p.failAnalysis(
330+
errorClass = "UNBOUND_SQL_PARAMETER",
331+
messageParameters = Map("name" -> toSQLId(p.name)))
332+
328333
case _ =>
329334
})
330335

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkException
21+
import org.apache.spark.sql.catalyst.analysis.AnalysisErrorAt
22+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
23+
import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern}
24+
import org.apache.spark.sql.errors.QueryErrorsBase
25+
import org.apache.spark.sql.types.DataType
26+
27+
/**
 * Represents a named parameter (parsed from the `:name` syntax) inside a plan that must be
 * substituted by a literal value before analysis can succeed.
 *
 * @param name The identifier of the parameter without the marker.
 */
case class Parameter(name: String) extends LeafExpression with Unevaluable {
  // A placeholder can never be considered resolved; the analyzer reports any
  // parameter that survives binding as UNBOUND_SQL_PARAMETER.
  override lazy val resolved: Boolean = false

  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)

  // Asking an unbound placeholder for type information is a programming error.
  private def failUnbound(methodName: String): Nothing =
    throw SparkException.internalError(
      s"Cannot call `$methodName()` of the unbound parameter `$name`.")

  override def dataType: DataType = failUnbound("dataType")
  override def nullable: Boolean = failUnbound("nullable")
}
44+
45+
46+
/**
 * Finds all named parameters in the given plan and substitutes them by literals of `args` values.
 */
object Parameter extends QueryErrorsBase {
  /**
   * Binds the named parameters occurring in `plan` to the expressions in `args`.
   *
   * @param plan The logical plan that may contain [[Parameter]] placeholders.
   * @param args Mapping from parameter names to argument expressions. Every value must be a
   *             `Literal`; otherwise an `INVALID_SQL_ARG` analysis error is raised.
   * @return The plan with every matched parameter replaced by its argument. Parameters with
   *         no entry in `args` are left in place (the analyzer reports them later).
   */
  def bind(plan: LogicalPlan, args: Map[String, Expression]): LogicalPlan = {
    if (args.nonEmpty) {
      // Reject the first non-literal argument: arbitrary expressions are not
      // allowed as parameter values.
      args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
        expr.failAnalysis(
          errorClass = "INVALID_SQL_ARG",
          messageParameters = Map("name" -> toSQLId(name)))
      }
      // Prune sub-trees without the PARAMETER pattern to avoid a full traversal.
      // Returning the parameter itself when no argument matches leaves it unchanged.
      plan.transformAllExpressionsWithPruning(_.containsPattern(PARAMETER)) {
        case p @ Parameter(name) => args.getOrElse(name, p)
      }
    } else {
      plan
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4837,4 +4837,11 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
48374837
override def visitTimestampdiff(ctx: TimestampdiffContext): Expression = withOrigin(ctx) {
48384838
TimestampDiff(ctx.unit.getText, expression(ctx.startTimestamp), expression(ctx.endTimestamp))
48394839
}
4840+
4841+
/**
 * Creates a named parameter which represents a literal with a non-bound value and unknown type.
 */
override def visitParameterLiteral(ctx: ParameterLiteralContext): Expression = withOrigin(ctx) {
  // The grammar rule is `COLON identifier`; only the identifier (without the
  // `:` marker) becomes the parameter name.
  Parameter(ctx.identifier().getText)
}
48404847
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ object TreePattern extends Enumeration {
7171
val NULL_LITERAL: Value = Value
7272
val SERIALIZE_FROM_OBJECT: Value = Value
7373
val OUTER_REFERENCE: Value = Value
74+
val PARAMETER: Value = Value
7475
val PIVOT: Value = Value
7576
val PLAN_EXPRESSION: Value = Value
7677
val PYTHON_UDF: Value = Value

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,4 +1295,18 @@ class AnalysisSuite extends AnalysisTest with Matchers {
12951295

12961296
assertAnalysisSuccess(finalPlan)
12971297
}
1298+
1299+
test("SPARK-41271: bind named parameters to literals") {
  // A referenced parameter is replaced by the corresponding literal.
  val bound = Parameter.bind(
    plan = parsePlan("SELECT * FROM a LIMIT :limitA"),
    args = Map("limitA" -> Literal(10)))
  comparePlans(bound, parsePlan("SELECT * FROM a LIMIT 10"))
  // Arguments not referenced by the query are silently ignored.
  val boundWithExtraArg = Parameter.bind(
    plan = parsePlan("SELECT c FROM a WHERE c < :param2"),
    args = Map("param1" -> Literal(10), "param2" -> Literal(20)))
  comparePlans(boundWithExtraArg, parsePlan("SELECT c FROM a WHERE c < 20"))
}
12981312
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,4 +1568,30 @@ class PlanParserSuite extends AnalysisTest {
15681568
.toAggregateExpression(false, Some(GreaterThan(UnresolvedAttribute("id"), Literal(10))))
15691569
)
15701570
}
1571+
1572+
test("SPARK-41271: parsing of named parameters") {
  // A bare parameter in the select list parses to a Parameter expression.
  comparePlans(
    parsePlan("SELECT :param_1"),
    Project(UnresolvedAlias(Parameter("param_1"), None) :: Nil, OneRowRelation()))
  // A parameter name may start with a digit and can appear as a function argument.
  comparePlans(
    parsePlan("SELECT abs(:1Abc)"),
    Project(UnresolvedAlias(
      UnresolvedFunction(
        "abs" :: Nil,
        Parameter("1Abc") :: Nil,
        isDistinct = false), None) :: Nil,
      OneRowRelation()))
  // Parameters are accepted in constant positions such as LIMIT.
  comparePlans(
    parsePlan("SELECT * FROM a LIMIT :limitA"),
    table("a").select(star()).limit(Parameter("limitA")))
  // Invalid empty name and invalid symbol in a name
  checkError(
    exception = parseException("SELECT :-"),
    errorClass = "PARSE_SYNTAX_ERROR",
    parameters = Map("error" -> "'-'", "hint" -> ""))
  checkError(
    exception = parseException("SELECT :"),
    errorClass = "PARSE_SYNTAX_ERROR",
    parameters = Map("error" -> "end of input", "hint" -> ""))
}
15711597
}

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog
3737
import org.apache.spark.sql.catalyst._
3838
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
3939
import org.apache.spark.sql.catalyst.encoders._
40-
import org.apache.spark.sql.catalyst.expressions.AttributeReference
40+
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
4141
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
4242
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
4343
import org.apache.spark.sql.connector.ExternalCommandRunner
@@ -609,19 +609,49 @@ class SparkSession private(
609609
* ----------------- */
610610

611611
/**
612-
* Executes a SQL query using Spark, returning the result as a `DataFrame`.
612+
* Executes a SQL query substituting named parameters by the given arguments,
613+
* returning the result as a `DataFrame`.
613614
* This API eagerly runs DDL/DML commands, but not for SELECT queries.
614615
*
615-
* @since 2.0.0
616+
* @param sqlText A SQL statement with named parameters to execute.
617+
* @param args A map of parameter names to literal values.
618+
*
619+
* @since 3.4.0
616620
*/
617-
def sql(sqlText: String): DataFrame = withActive {
621+
@Experimental
622+
def sql(sqlText: String, args: Map[String, String]): DataFrame = withActive {
618623
val tracker = new QueryPlanningTracker
619624
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
620-
sessionState.sqlParser.parsePlan(sqlText)
625+
val parser = sessionState.sqlParser
626+
val parsedArgs = args.mapValues(parser.parseExpression).toMap
627+
Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
621628
}
622629
Dataset.ofRows(self, plan, tracker)
623630
}
624631

632+
/**
633+
* Executes a SQL query substituting named parameters by the given arguments,
634+
* returning the result as a `DataFrame`.
635+
* This API eagerly runs DDL/DML commands, but not for SELECT queries.
636+
*
637+
* @param sqlText A SQL statement with named parameters to execute.
638+
* @param args A map of parameter names to literal values.
639+
*
640+
* @since 3.4.0
641+
*/
642+
@Experimental
643+
def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = {
644+
sql(sqlText, args.asScala.toMap)
645+
}
646+
647+
/**
648+
* Executes a SQL query using Spark, returning the result as a `DataFrame`.
649+
* This API eagerly runs DDL/DML commands, but not for SELECT queries.
650+
*
651+
* @since 2.0.0
652+
*/
653+
def sql(sqlText: String): DataFrame = sql(sqlText, Map.empty[String, String])
654+
625655
/**
626656
* Execute an arbitrary string command inside an external execution engine rather than Spark.
627657
* This could be useful when user wants to execute some commands out of Spark. For
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.sql.test.SharedSparkSession
21+
22+
class ParametersSuite extends QueryTest with SharedSparkSession {

  test("bind parameters") {
    // Substitute two named parameters in different positions of one statement.
    val query =
      """
        |SELECT id, id % :div as c0
        |FROM VALUES (0), (1), (2), (3), (4), (5), (6), (7), (8), (9) AS t(id)
        |WHERE id < :constA
        |""".stripMargin
    val expected = Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 0))
    checkAnswer(spark.sql(query, Map("div" -> "3", "constA" -> "4L")), expected)

    // A parameter value may be a string literal containing escaped quotes.
    val containsResult =
      spark.sql("""SELECT contains('Spark \'SQL\'', :subStr)""", Map("subStr" -> "'SQL'"))
    checkAnswer(containsResult, Row(true))
  }

  test("non-substituted parameters") {
    // Only `abc` is provided, so `:def` stays unbound and analysis fails.
    val partiallyBound = intercept[AnalysisException] {
      spark.sql("select :abc, :def", Map("abc" -> "1"))
    }
    checkError(
      exception = partiallyBound,
      errorClass = "UNBOUND_SQL_PARAMETER",
      parameters = Map("name" -> "`def`"),
      context = ExpectedContext(fragment = ":def", start = 13, stop = 16))
    // Calling the parameterless `sql()` leaves every parameter unbound.
    val fullyUnbound = intercept[AnalysisException] {
      sql("select :abc").collect()
    }
    checkError(
      exception = fullyUnbound,
      errorClass = "UNBOUND_SQL_PARAMETER",
      parameters = Map("name" -> "`abc`"),
      context = ExpectedContext(fragment = ":abc", start = 7, stop = 10))
  }

  test("non-literal argument of `sql()`") {
    // Each of these argument strings parses to a non-literal expression,
    // which `sql()` must reject.
    val nonLiteralArgs = Seq("col1 + 1", "CAST('100' AS INT)", "map('a', 1, 'b', 2)", "array(1)")
    nonLiteralArgs.foreach { arg =>
      val err = intercept[AnalysisException] {
        spark.sql("SELECT :param1 FROM VALUES (1) AS t(col1)", Map("param1" -> arg))
      }
      checkError(
        exception = err,
        errorClass = "INVALID_SQL_ARG",
        parameters = Map("name" -> "`param1`"),
        context = ExpectedContext(fragment = arg, start = 0, stop = arg.length - 1))
    }
  }
}

0 commit comments

Comments
 (0)