[DRAFT][PYTHON] Improve Python UDF Arrow Serializer Performance #51225

Draft: wants to merge 50 commits into base: master

Commits (50)
9dfe326  wrap_arrow_udf and serializer  (asl3, Jun 19, 2025)
779fea9  evaltype  (asl3, Jun 19, 2025)
a3cded6  nit  (asl3, Jun 19, 2025)
33f6f67  test  (asl3, Jun 19, 2025)
8440f9c  tmp  (asl3, Jun 19, 2025)
05d87c8  skip variant tests  (asl3, Jun 19, 2025)
1dd0ceb  arrow batch serializer  (asl3, Jun 19, 2025)
ffa562a  refactor  (asl3, Jun 19, 2025)
e075c0d  rename  (asl3, Jun 19, 2025)
93b44ca  refactor  (asl3, Jun 19, 2025)
b1f8965  nit  (asl3, Jun 19, 2025)
8ef0726  spacing  (asl3, Jun 19, 2025)
b871ab1  fmt  (asl3, Jun 22, 2025)
d9570a1  update test  (asl3, Jun 22, 2025)
8f40352  scalar arrow  (asl3, Jun 23, 2025)
fc844c2  spacing  (asl3, Jun 23, 2025)
8f9420c  spacing  (asl3, Jun 23, 2025)
915f919  comment  (asl3, Jun 23, 2025)
fc26618  sql scalar arrow iter udf  (asl3, Jun 23, 2025)
03cac02  whitespace  (asl3, Jun 23, 2025)
81f977a  restore  (asl3, Jun 23, 2025)
3e0d81f  nit  (asl3, Jun 23, 2025)
a384c81  nit  (asl3, Jun 23, 2025)
d8681dc  skip test  (asl3, Jun 23, 2025)
737acf0  test errors  (asl3, Jun 23, 2025)
6dd239d  SPARK-34545 test  (asl3, Jun 24, 2025)
3e99d5d  cleanup  (asl3, Jun 24, 2025)
502f201  remove skip tests  (asl3, Jun 24, 2025)
a8ab3e1  fmt  (asl3, Jun 24, 2025)
06469b1  fmt  (asl3, Jun 24, 2025)
80e34ec  refactor legacy/non-legacy tests  (asl3, Jun 25, 2025)
9525c5d  tmp  (asl3, Jun 26, 2025)
b107f0f  tmp  (asl3, Jun 26, 2025)
5fbde58  nits  (asl3, Jun 26, 2025)
6c4f3b3  fmt  (asl3, Jun 26, 2025)
a898ed3  spacing  (asl3, Jun 26, 2025)
e0898c4  comment  (asl3, Jun 26, 2025)
e54f7be  fmt  (asl3, Jun 26, 2025)
42e46db  fmt  (asl3, Jun 26, 2025)
039733c  test  (asl3, Jun 26, 2025)
d6afb31  complex and nested type  (asl3, Jun 29, 2025)
cf7c754  nested struct  (asl3, Jun 30, 2025)
d9c8411  complex variant  (asl3, Jun 30, 2025)
27644aa  remove print  (asl3, Jun 30, 2025)
d613c88  nits  (asl3, Jun 30, 2025)
249719b  complex input type  (asl3, Jun 30, 2025)
4b35bee  fix  (asl3, Jul 1, 2025)
46185cc  remove branch  (asl3, Jul 1, 2025)
3d829df  test  (asl3, Jul 7, 2025)
d5edd6c  fmt  (asl3, Jul 7, 2025)

3 changes: 2 additions & 1 deletion python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -24,8 +24,9 @@ Upgrading from PySpark 4.0 to 4.1

* In Spark 4.1, Arrow-optimized Python UDF supports UDT input / output instead of falling back to the regular UDF. To restore the legacy behavior, set ``spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT`` to ``true``.

* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``.
* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled``.

* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``.

Upgrading from PySpark 3.5 to 4.0
---------------------------------
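To make the migration note above concrete, here is a minimal sketch of toggling these configurations from a PySpark session; the conf names come from the note itself, while the session, UDF, and data are illustrative assumptions.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

# Route Python UDFs through the Arrow-optimized path.
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

# Spark 4.1 default: no intermediate pandas conversion, so type coercion follows Arrow.
# Re-enable the legacy conf from the note above to restore the pre-4.1 coercion behavior.
spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "true")

@udf(returnType=IntegerType())
def plus_one(x):
    return x + 1

spark.range(3).select(plus_one("id")).show()
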
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -1050,6 +1050,11 @@
"Type hints for <target> should be specified; however, got <sig>."
]
},
"UDF_ARROW_TYPE_CONVERSION_ERROR": {
"message": [
"Cannot convert the output value of the input '<data>' with type '<schema>' to the specified return type of the column: '<arrow_schema>'. Please check if the data types match and try again."
]
},
"UDF_RETURN_TYPE": {
"message": [
"Return type of the user-defined function should be <expected>, but is <actual>."
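As a sketch of how the new UDF_ARROW_TYPE_CONVERSION_ERROR condition could be surfaced from Python, the snippet below uses PySpark's error framework; the helper name, the call site, and the constructor keyword names (errorClass / messageParameters) are assumptions, not part of this diff.

from pyspark.errors import PySparkRuntimeError

def raise_arrow_conversion_error(data, schema, arrow_schema):
    # Hypothetical helper: map a failed Arrow cast onto the new error condition.
    # Keyword names follow recent PySpark releases and are an assumption here.
    raise PySparkRuntimeError(
        errorClass="UDF_ARROW_TYPE_CONVERSION_ERROR",
        messageParameters={
            "data": str(data),
            "schema": str(schema),
            "arrow_schema": str(arrow_schema),
        },
    )
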
124 changes: 124 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -131,6 +131,9 @@ def load_stream(self, stream):
reader = pa.ipc.open_stream(stream)
for batch in reader:
yield batch

    def arrow_to_pandas(self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None):
        # Base implementation: materialize the Arrow column as a list of Python objects.
        return arrow_column.to_pylist()

def __repr__(self):
return "ArrowStreamSerializer"
@@ -194,6 +197,127 @@ class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
def load_stream(self, stream):
return ArrowStreamSerializer.load_stream(self, stream)

class ArrowBatchUDFSerializer(ArrowStreamSerializer):
"""
Serializer used by Python worker to evaluate Arrow UDFs
"""

def __init__(
self,
timezone,
safecheck,
assign_cols_by_name,
arrow_cast,
struct_in_pandas="row",
ndarray_as_list=False,
):
super(ArrowBatchUDFSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
self._arrow_cast = arrow_cast
self._struct_in_pandas = struct_in_pandas
        self._ndarray_as_list = ndarray_as_list

def _create_array(self, arr, arrow_type, arrow_cast):
import pyarrow as pa

        assert isinstance(arr, pa.Array), arr
        assert isinstance(arrow_type, pa.DataType), arrow_type

        # Return the array as-is when it already has the expected Arrow type;
        # otherwise cast it to the declared return type when casting is enabled.
        if arr.type == arrow_type:
            return arr
        elif arrow_cast:
            return arr.cast(target_type=arrow_type, safe=self._safecheck)
        else:
            raise pa.lib.ArrowTypeError(
                "Expected Arrow type %s, but got %s" % (arrow_type, arr.type)
            )

    def load_stream(self, stream):
        """
        Deserialize an Arrow stream into a list of columns per record batch,
        where each column is materialized as a list of Python values.
        """
        import pyarrow as pa

        batches = super(ArrowBatchUDFSerializer, self).load_stream(stream)
        for batch in batches:
            table = pa.Table.from_batches([batch])
            columns = [
                self.arrow_to_pandas(c, i)
                for i, c in enumerate(table.itercolumns())
            ]
            yield columns

def dump_stream(self, iterator, stream):
"""
Override because Arrow UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""
import pyarrow as pa

def wrap_and_init_stream():
should_write_start_length = True
for packed in iterator:
# Flatten tuple of lists into a single list
if isinstance(packed, tuple) and all(isinstance(x, list) for x in packed):
packed = [item for sublist in packed for item in sublist]

if isinstance(packed, tuple) and len(packed) == 2 and isinstance(packed[1], pa.DataType):
# single array UDF in a projection
arrs = [self._create_array(packed[0], packed[1], self._arrow_cast)]
elif isinstance(packed, list):
# multiple array UDFs in a projection
arrs = [self._create_array(t[0], t[1], self._arrow_cast) for t in packed]
else:
arr = pa.array([packed], type=pa.int32())
arrs = [self._create_array(arr, pa.int32(), self._arrow_cast)]

batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

# Write the first record batch with initialization.
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch

return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream)

def __repr__(self):
return "ArrowBatchUDFSerializer"

def arrow_to_pandas(self, arrow_column, idx, ndarray_as_list=False, spark_type=None):
import pyarrow.types as types
from pyspark.sql import Row

if (
self._struct_in_pandas == "row"
and types.is_struct(arrow_column.type)
and not is_variant(arrow_column.type)
):
series = [
super(ArrowBatchUDFSerializer, self)
.arrow_to_pandas(
column,
i,
self._struct_in_pandas,
self._ndarray_as_list,
spark_type=None,
)
for i, (column, field) in enumerate(zip(arrow_column.flatten(), arrow_column.type))
]
row_cls = Row(*[field.name for field in arrow_column.type])
result = series[0].__class__([
row_cls(*vals) for vals in zip(*series)
])
else:
result = super(ArrowBatchUDFSerializer, self).arrow_to_pandas(
arrow_column,
idx,
struct_in_pandas=self._struct_in_pandas,
ndarray_as_list=ndarray_as_list,
spark_type=spark_type,
)
        if spark_type is not None and hasattr(spark_type, "fromInternal"):
            # `result` is a plain list here, so convert values with a comprehension
            # (e.g. to map internal SQL values back to their user-facing types).
            return [spark_type.fromInternal(v) if v is not None else v for v in result]
        return result


class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
"""
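To illustrate the data path of ArrowBatchUDFSerializer above, here is a self-contained pyarrow-only sketch of the two conversions it performs: on load, each record batch column is materialized as a plain Python list (no pandas involved), and on dump, result arrays are cast to the declared type where needed and packed back into a record batch. All names are local to the example.

import io
import pyarrow as pa

# Simulate the worker's input: an Arrow IPC stream carrying one record batch.
batch = pa.record_batch(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]
)
buf = io.BytesIO()
with pa.ipc.new_stream(buf, batch.schema) as writer:
    writer.write_batch(batch)
buf.seek(0)

# Load path: each column becomes a list of Python values.
reader = pa.ipc.open_stream(buf)
for incoming in reader:
    table = pa.Table.from_batches([incoming])
    columns = [col.to_pylist() for col in table.itercolumns()]
    print(columns)  # [[1, 2, 3], ['a', 'b', 'c']]

# Dump path: UDF results (already Arrow arrays) are cast to the declared return
# type and reassembled into a record batch. (In the real worker, a
# START_ARROW_STREAM marker is written before the IPC stream begins.)
result = pa.array([2, 3, 4]).cast(pa.int32(), safe=True)
out = pa.RecordBatch.from_arrays([result], ["_0"])
print(out.num_rows, out.schema)
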
55 changes: 50 additions & 5 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -244,6 +244,37 @@ def test_udf_use_arrow_and_session_conf(self):
)


@unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
class ArrowPythonUDFLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
finally:
super().tearDownClass()


class ArrowPythonUDFNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", "false")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
finally:
super().tearDownClass()


class ArrowPythonUDFTests(ArrowPythonUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
@@ -258,18 +289,32 @@ def tearDownClass(cls):
super(ArrowPythonUDFTests, cls).tearDownClass()


class AsyncArrowPythonUDFTests(ArrowPythonUDFTests):
class ArrowPythonUDFLegacyTests(ArrowPythonUDFLegacyTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(AsyncArrowPythonUDFTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")
super(ArrowPythonUDFLegacyTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(ArrowPythonUDFLegacyTests, cls).tearDownClass()


class ArrowPythonUDFNonLegacyTests(ArrowPythonUDFNonLegacyTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(ArrowPythonUDFNonLegacyTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(AsyncArrowPythonUDFTests, cls).tearDownClass()
super(ArrowPythonUDFNonLegacyTests, cls).tearDownClass()


if __name__ == "__main__":
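The suites above toggle the session-wide Arrow conf; individual UDFs can also opt into the Arrow path via the useArrow flag. A minimal sketch, with an illustrative UDF and column name:

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Per-UDF opt-in to the Arrow-optimized execution path, independent of the
# session conf toggled by the test suites above.
@udf(returnType=StringType(), useArrow=True)
def greet(name):
    return "hello, %s" % name

# e.g. df.select(greet("name")) would then run through the Arrow serializer.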