Commit f1a161c
[SPARK-44561][PYTHON] Fix AssertionError when converting UDTF output to a complex type
### What changes were proposed in this pull request?

Fixes an `AssertionError` when converting UDTF output to a complex type, by ignoring the assertions in `_create_converter_from_pandas` so that Arrow raises the error instead.

### Why are the changes needed?

There is an assertion in `_create_converter_from_pandas`, but it should not be applied in the Python UDTF case.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added/modified the related tests.

Closes apache#42310 from ueshin/issues/SPARK-44561/udtf_complex_types.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
1 parent 2a23c7a commit f1a161c
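For context, a minimal sketch of the failure mode described above; the UDTF below is illustrative and not part of the patch. Before this fix, yielding a value that does not match a declared complex output type tripped a bare `AssertionError` inside the pandas-to-Arrow converter; after it, the value is passed through so Arrow can raise a descriptive conversion error.

```python
from pyspark.sql.functions import udtf

# Hypothetical UDTF: the schema declares an array column, but eval() yields
# a scalar for it. Pre-patch this failed with AssertionError in
# _create_converter_from_pandas; post-patch the mismatched value reaches
# Arrow, which reports a proper conversion error.
@udtf(returnType="x: array<int>")
class BadArrayUDTF:
    def eval(self):
        yield 42,  # 42 is not an Iterable, so it cannot become array<int>
```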

4 files changed: +314 −49 lines changed


python/pyspark/sql/pandas/serializers.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -571,7 +571,10 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
         dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
         # TODO(SPARK-43579): cache the converter for reuse
         conv = _create_converter_from_pandas(
-            dt, timezone=self._timezone, error_on_duplicated_field_names=False
+            dt,
+            timezone=self._timezone,
+            error_on_duplicated_field_names=False,
+            ignore_unexpected_complex_type_values=True,
         )
         series = conv(series)
```
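This serializer change simply opts the Arrow-optimized path into the new lenient mode. A rough sketch of the converter it now builds, using hypothetical values (and assuming the private helper keeps the signature shown in the diff):

```python
import pandas as pd

from pyspark.sql.pandas.types import _create_converter_from_pandas
from pyspark.sql.types import ArrayType, IntegerType

conv = _create_converter_from_pandas(
    ArrayType(IntegerType()),
    timezone=None,
    error_on_duplicated_field_names=False,
    ignore_unexpected_complex_type_values=True,
)

# Well-formed values convert as before, while the non-iterable 3 is returned
# untouched, leaving Arrow to raise an informative error downstream.
print(conv(pd.Series([[1, 2], None, 3])).tolist())  # [[1, 2], None, 3]
```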

python/pyspark/sql/pandas/types.py

Lines changed: 82 additions & 26 deletions
```diff
@@ -21,7 +21,7 @@
 """
 import datetime
 import itertools
-from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING

 from pyspark.sql.types import (
     cast,
@@ -750,6 +750,7 @@ def _create_converter_from_pandas(
     *,
     timezone: Optional[str],
     error_on_duplicated_field_names: bool = True,
+    ignore_unexpected_complex_type_values: bool = False,
 ) -> Callable[["pd.Series"], "pd.Series"]:
     """
     Create a converter of pandas Series to create Spark DataFrame with Arrow optimization.
@@ -763,6 +764,17 @@ def _create_converter_from_pandas(
     error_on_duplicated_field_names : bool, optional
         Whether raise an exception when there are duplicated field names.
         (default ``True``)
+    ignore_unexpected_complex_type_values : bool, optional
+        Whether to ignore the case where unexpected values are given for complex types.
+        If ``False``, each complex type expects:
+
+        * array type: :class:`Iterable`
+        * map type: :class:`dict`
+        * struct type: :class:`dict` or :class:`tuple`
+
+        and raises an AssertionError when the given value is not the expected type.
+        If ``True``, just ignore and return the given value.
+        (default ``False``)

     Returns
     -------
@@ -781,28 +793,51 @@ def correct_timestamp(pser: pd.Series) -> pd.Series:
     def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:

         if isinstance(dt, ArrayType):
-            _element_conv = _converter(dt.elementType)
-            if _element_conv is None:
-                return None
+            _element_conv = _converter(dt.elementType) or (lambda x: x)

-            def convert_array(value: Any) -> Any:
-                if value is None:
-                    return None
-                else:
-                    return [_element_conv(v) for v in value]  # type: ignore[misc]
+            if ignore_unexpected_complex_type_values:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, Iterable):
+                        return [_element_conv(v) for v in value]
+                    else:
+                        return value
+
+            else:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        assert isinstance(value, Iterable)
+                        return [_element_conv(v) for v in value]

             return convert_array

         elif isinstance(dt, MapType):
             _key_conv = _converter(dt.keyType) or (lambda x: x)
             _value_conv = _converter(dt.valueType) or (lambda x: x)

-            def convert_map(value: Any) -> Any:
-                if value is None:
-                    return None
-                else:
-                    assert isinstance(value, dict)
-                    return [(_key_conv(k), _value_conv(v)) for k, v in value.items()]
+            if ignore_unexpected_complex_type_values:
+
+                def convert_map(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, dict):
+                        return [(_key_conv(k), _value_conv(v)) for k, v in value.items()]
+                    else:
+                        return value
+
+            else:
+
+                def convert_map(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        assert isinstance(value, dict)
+                        return [(_key_conv(k), _value_conv(v)) for k, v in value.items()]

             return convert_map
@@ -820,17 +855,38 @@ def convert_map(value: Any) -> Any:

             field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields]

-            def convert_struct(value: Any) -> Any:
-                if value is None:
-                    return None
-                elif isinstance(value, dict):
-                    return {
-                        dedup_field_names[i]: field_convs[i](value.get(key, None))
-                        for i, key in enumerate(field_names)
-                    }
-                else:
-                    assert isinstance(value, tuple)
-                    return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)}
+            if ignore_unexpected_complex_type_values:
+
+                def convert_struct(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, dict):
+                        return {
+                            dedup_field_names[i]: field_convs[i](value.get(key, None))
+                            for i, key in enumerate(field_names)
+                        }
+                    elif isinstance(value, tuple):
+                        return {
+                            dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)
+                        }
+                    else:
+                        return value
+
+            else:
+
+                def convert_struct(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, dict):
+                        return {
+                            dedup_field_names[i]: field_convs[i](value.get(key, None))
+                            for i, key in enumerate(field_names)
+                        }
+                    else:
+                        assert isinstance(value, tuple)
+                        return {
+                            dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)
+                        }

             return convert_struct
```
python/pyspark/sql/tests/connect/test_parity_udtf.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -45,6 +45,9 @@ def tearDownClass(cls):

     # TODO: use PySpark error classes instead of SparkConnectGrpcException

+    def test_struct_output_type_casting_row(self):
+        self.check_struct_output_type_casting_row(SparkConnectGrpcException)
+
     def test_udtf_with_invalid_return_type(self):
         @udtf(returnType="int")
         class TestUDTF:
```