Commit 1fbb38e

ueshin authored and HyukjinKwon committed
[SPARK-44249][SQL][PYTHON] Refactor PythonUDTFRunner to send its return type separately
### What changes were proposed in this pull request?

Refactors `PythonUDTFRunner` to send its return type separately.

### Why are the changes needed?

The return type of Python UDTF doesn't need to be included in the Python "command" because `PythonUDTF` knows the return type. It can send the return type separately.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated the related tests and existing tests.

Closes apache#41792 from ueshin/issues/SPARK-44249/return_type.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 2cb1d3b commit 1fbb38e
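
For context, here is a minimal, illustrative sketch (not part of this commit) of what the worker-side read looks like once the return type travels separately from the pickled command: the command now unpickles to just the UDTF handler class, and the schema follows as a `DataType` JSON string. The helper name `read_udtf_payload` is invented for this sketch; the real logic is the updated `read_udtf` in `python/pyspark/worker.py` below, and the matching write side is `PythonUDTFRunner`'s `writeCommand` in `BatchEvalPythonUDTFExec.scala`.

```python
# Illustrative sketch only -- paraphrases the updated `read_udtf` in
# python/pyspark/worker.py (see the diff below); `read_udtf_payload` is a
# made-up name, not part of this commit.
from pyspark.serializers import CPickleSerializer, UTF8Deserializer, read_int
from pyspark.sql.types import _parse_datatype_json_string
from pyspark.worker import read_command  # unpickles one command object from the stream

pickleSer = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()


def read_udtf_payload(infile):
    # 1. argument offsets, written first by PythonUDTFRunner's writeCommand
    num_arg = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(num_arg)]
    # 2. the pickled command now holds only the handler class ...
    handler = read_command(pickleSer, infile)
    # 3. ... and the return type follows separately as a DataType JSON string
    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
    return arg_offsets, handler, return_type
```

Sending the schema as a JSON string rather than pickling it with the function is also what lets the test helper in `IntegratedUDFTestUtils.scala` dump only `$funcName` instead of a `(func, returnType)` tuple.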

File tree: 8 files changed, +94 / -62 lines changed

python/pyspark/sql/udf.py

Lines changed: 6 additions & 2 deletions
@@ -51,9 +51,13 @@
 
 
 def _wrap_function(
-    sc: SparkContext, func: Callable[..., Any], returnType: "DataTypeOrString"
+    sc: SparkContext, func: Callable[..., Any], returnType: Optional[DataType] = None
 ) -> JavaObject:
-    command = (func, returnType)
+    command: Any
+    if returnType is None:
+        command = func
+    else:
+        command = (func, returnType)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     assert sc._jvm is not None
     return sc._jvm.SimplePythonFunction(

python/pyspark/sql/udtf.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def _create_judtf(self, func: Type) -> JavaObject:
         spark = SparkSession._getActiveSessionOrCreate()
         sc = spark.sparkContext
 
-        wrapped_func = _wrap_function(sc, func, self.returnType)
+        wrapped_func = _wrap_function(sc, func)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         assert sc._jvm is not None
         judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(

python/pyspark/worker.py

Lines changed: 5 additions & 14 deletions
@@ -61,10 +61,10 @@
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType
+from pyspark.sql.types import StructType, _parse_datatype_json_string
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError
 
 pickleSer = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -461,20 +461,11 @@ def assign_cols_by_name(runner_conf):
 # ensure the UDTF is valid. This function also prepares a mapper function for applying
 # the UDTF logic to input rows.
 def read_udtf(pickleSer, infile, eval_type):
-    num_udtfs = read_int(infile)
-    if num_udtfs != 1:
-        raise PySparkValueError(f"Unexpected number of UDTFs. Expected 1 but got {num_udtfs}.")
-
-    # See `PythonUDFRunner.writeUDFs`.
+    # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
     num_arg = read_int(infile)
    arg_offsets = [read_int(infile) for _ in range(num_arg)]
-    num_chained_funcs = read_int(infile)
-    if num_chained_funcs != 1:
-        raise PySparkValueError(
-            f"Unexpected number of chained UDTFs. Expected 1 but got {num_chained_funcs}."
-        )
-
-    handler, return_type = read_command(pickleSer, infile)
+    handler = read_command(pickleSer, infile)
+    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
             f"Invalid UDTF handler type. Expected a class (type 'type'), but "

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ abstract class UnevaluableGenerator extends Generator {
 case class PythonUDTF(
     name: String,
     func: PythonFunction,
-    override val elementSchema: StructType,
+    elementSchema: StructType,
     children: Seq[Expression],
     udfDeterministic: Boolean,
     resultId: ExprId = NamedExpression.newExprId)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala

Lines changed: 43 additions & 18 deletions
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.File
+import java.io.{DataOutputStream, File}
+import java.net.Socket
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -69,21 +71,17 @@ case class BatchEvalPythonUDTFExec(
         queue.close()
       }
 
-      val inputs = Seq(udtf.children)
-
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
+      val argOffsets = udtf.children.map { e =>
+        if (allInputs.exists(_.semanticEquals(e))) {
+          allInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          allInputs += e
+          dataTypes += e.dataType
+          allInputs.length - 1
+        }
       }.toArray
       val projection = MutableProjection.create(allInputs.toSeq, child.output)
       projection.initialize(context.partitionId())
@@ -101,7 +99,7 @@
         projection(inputRow)
       }
 
-      val outputRowIterator = evaluate(udtf, argOffsets, projectedRowIter, schema, context)
+      val outputRowIterator = evaluate(argOffsets, projectedRowIter, schema, context)
 
       val pruneChildForResult: InternalRow => InternalRow =
         if (child.outputSet == AttributeSet(requiredChildOutput)) {
@@ -136,8 +134,7 @@
    * an iterator of internal rows for every input row.
    */
   private def evaluate(
-      udtf: PythonUDTF,
-      argOffsets: Array[Array[Int]],
+      argOffsets: Array[Int],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[Iterator[InternalRow]] = {
@@ -147,9 +144,8 @@
     val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
 
     // Output iterator for results from Python.
-    val funcs = Seq(ChainedPythonFunctions(Seq(udtf.func)))
     val outputIterator =
-      new PythonUDFRunner(funcs, PythonEvalType.SQL_TABLE_UDF, argOffsets, pythonMetrics)
+      new PythonUDTFRunner(udtf, argOffsets, pythonMetrics)
        .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
@@ -173,3 +169,32 @@
   override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonUDTFExec =
     copy(child = newChild)
 }
+
+class PythonUDTFRunner(
+    udtf: PythonUDTF,
+    argOffsets: Array[Int],
+    pythonMetrics: Map[String, SQLMetric])
+  extends BasePythonUDFRunner(
+    Seq(ChainedPythonFunctions(Seq(udtf.func))),
+    PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Array[Byte]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        dataOut.writeInt(argOffsets.length)
+        argOffsets.foreach { offset =>
+          dataOut.writeInt(offset)
+        }
+        dataOut.writeInt(udtf.func.command.length)
+        dataOut.write(udtf.func.command.toArray)
+        writeUTF(udtf.elementSchema.json, dataOut)
+      }
+    }
+  }
+}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala

Lines changed: 33 additions & 15 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * A helper class to run Python UDFs in Spark.
  */
-class PythonUDFRunner(
+abstract class BasePythonUDFRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
@@ -43,27 +43,22 @@ class PythonUDFRunner(
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
-  protected override def newWriterThread(
+  abstract class PythonUDFWriterThread(
      env: SparkEnv,
      worker: Socket,
      inputIterator: Iterator[Array[Byte]],
      partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-      }
+      context: TaskContext)
+    extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
 
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        val startData = dataOut.size()
+    protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+      val startData = dataOut.size()
 
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+      dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
 
-        val deltaData = dataOut.size() - startData
-        pythonMetrics("pythonDataSent") += deltaData
-      }
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
     }
   }
 
@@ -106,6 +101,29 @@ class PythonUDFRunner(
   }
 }
 
+class PythonUDFRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    pythonMetrics: Map[String, SQLMetric])
+  extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Array[Byte]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      }
+
+    }
+  }
+}
+
 object PythonUDFRunner {
 
   def writeUDFs(

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 5 additions & 5 deletions
@@ -196,7 +196,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
     if (!shouldTestPythonUDFs) {
       throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
     }
-    var binaryPandasFunc: Array[Byte] = null
+    var binaryPythonUDTF: Array[Byte] = null
     withTempPath { codePath =>
       Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
       withTempPath { path =>
@@ -208,14 +208,14 @@ object IntegratedUDFTestUtils extends SQLHelper {
          s"f = open('$path', 'wb');" +
          s"exec(open('$codePath', 'r').read());" +
          "f.write(CloudPickleSerializer().dumps(" +
-          s"($funcName, returnType)))"),
+          s"$funcName))"),
         None,
         "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
-        binaryPandasFunc = Files.readAllBytes(path.toPath)
+        binaryPythonUDTF = Files.readAllBytes(path.toPath)
      }
    }
-    assert(binaryPandasFunc != null)
-    binaryPandasFunc
+    assert(binaryPythonUDTF != null)
+    binaryPythonUDTF
   }
 
   private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala

Lines changed: 0 additions & 6 deletions
@@ -30,12 +30,6 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
 
   private val pythonScript: String =
     """
-      |from pyspark.sql.types import StructType, StructField, IntegerType
-      |returnType = StructType([
-      |    StructField("a", IntegerType()),
-      |    StructField("b", IntegerType()),
-      |    StructField("c", IntegerType()),
-      |])
      |class SimpleUDTF:
      |    def eval(self, a: int, b: int):
      |        yield a, b, a + b
