Commit 916b0d3

ueshin authored and zhengruifeng committed
[SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas
### What changes were proposed in this pull request?

Support `UserDefinedType` in `createDataFrame` from pandas DataFrame and `toPandas`.

For the following schema and pandas DataFrame:

```py
schema = (
    StructType()
    .add("point", ExamplePointUDT())
    .add("struct", StructType().add("point", ExamplePointUDT()))
    .add("array", ArrayType(ExamplePointUDT()))
    .add("map", MapType(StringType(), ExamplePointUDT()))
)
data = [
    Row(
        ExamplePoint(1.0, 2.0),
        Row(ExamplePoint(3.0, 4.0)),
        [ExamplePoint(5.0, 6.0)],
        dict(point=ExamplePoint(7.0, 8.0)),
    )
]
df = spark.createDataFrame(data, schema)
pdf = pd.DataFrame.from_records(data, columns=schema.names)
```

##### `spark.createDataFrame()`

All cases now return the same results:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+----------+------------+------------+---------------------+
|point     |struct      |array       |map                  |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```

##### `df.toPandas()`

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
       point        struct        array                   map
0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
```

### Why are the changes needed?

Currently `UserDefinedType` in `spark.createDataFrame()` with pandas DataFrame and `df.toPandas()` is not supported with Arrow enabled or in Spark Connect.

##### `spark.createDataFrame()`

Works without Arrow:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+----------+------------+------------+---------------------+
|point     |struct      |array       |map                  |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```

whereas:

- With Arrow, it works only via fallback:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
/.../python/pyspark/sql/pandas/conversion.py:351: UserWarning: createDataFrame attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
  [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)
+----------+------------+------------+---------------------+
|point     |struct      |array       |map                  |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```

- In Spark Connect, it fails:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
```

##### `df.toPandas()`

Works without Arrow:

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
       point        struct        array                   map
0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
```

whereas:

- With Arrow, it works only via fallback:

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
/.../python/pyspark/sql/pandas/conversion.py:111: UserWarning: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
  [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)
       point        struct        array                   map
0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
```

- In Spark Connect, it returns the internal (serialized) type instead of the UDT values:

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
        point         struct         array                    map
0  [1.0, 2.0]  ([3.0, 4.0],)  [[5.0, 6.0]]  {'point': [7.0, 8.0]}
```

### Does this PR introduce _any_ user-facing change?

Users will be able to use `UserDefinedType` in `createDataFrame` from pandas DataFrame and `toPandas`, with Arrow enabled and in Spark Connect.

### How was this patch tested?

Added the related tests.

Closes apache#41333 from ueshin/issues/SPARK-43817/udt.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
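The `ExamplePoint` / `ExamplePointUDT` pair used above comes from PySpark's test helpers: a `UserDefinedType` maps a Python object onto a plain SQL storage type via `sqlType()`, `serialize()`, and `deserialize()`. A minimal sketch of such a pair, roughly modeled on those helpers (an illustrative stand-in, not code changed by this commit):

```py
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType


class ExamplePointUDT(UserDefinedType):
    """Sketch of a UDT that stores a 2-D point as array<double>."""

    @classmethod
    def sqlType(cls):
        # Underlying SQL type the UDT values are serialized to.
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        # Module from which the UDT can be imported; '__main__' for this sketch.
        return "__main__"

    def serialize(self, obj):
        # Python object -> SQL value.
        return [obj.x, obj.y]

    def deserialize(self, datum):
        # SQL value -> Python object.
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """Plain Python value carried by ExamplePointUDT."""

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "(%s, %s)" % (self.x, self.y)
```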
1 parent 001da5d commit 916b0d3

File tree

11 files changed (+171, -234 lines)

python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py

Lines changed: 1 addition & 47 deletions
@@ -25,53 +25,7 @@
 class UDTOpsParityTests(
     UDTOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_eq(self):
-        super().test_eq()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_from_to_pandas(self):
-        super().test_from_to_pandas()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_ge(self):
-        super().test_ge()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_gt(self):
-        super().test_gt()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_isnull(self):
-        super().test_isnull()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_le(self):
-        super().test_le()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_lt(self):
-        super().test_lt()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_ne(self):
-        super().test_ne()
+    pass


 if __name__ == "__main__":

python/pyspark/sql/connect/client/core.py

Lines changed: 2 additions & 2 deletions
@@ -75,7 +75,7 @@
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
-from pyspark.sql.pandas.types import _create_converter_to_pandas
+from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
 from pyspark.rdd import PythonEvalType
 from pyspark.storagelevel import StorageLevel
@@ -717,7 +717,7 @@ def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None

-        schema = schema or types.from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
+        schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
         assert schema is not None and isinstance(schema, StructType)

         # Rename columns to avoid duplicated column names.
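`from_arrow_schema` is only consulted when the response carries no Spark schema; it rebuilds a `StructType` from the Arrow schema of the fetched table. A small, self-contained illustration of what that call produces (the field names here are made up for the example):

```py
import pyarrow as pa

from pyspark.sql.pandas.types import from_arrow_schema

arrow_schema = pa.schema(
    [
        pa.field("id", pa.int64(), nullable=False),
        pa.field("ts", pa.timestamp("us"), nullable=True),  # no timezone
    ]
)

# With prefer_timestamp_ntz=True, a timezone-less Arrow timestamp maps to TimestampNTZType.
struct_type = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
print(struct_type)
# Roughly: StructType([StructField('id', LongType(), False), StructField('ts', TimestampNTZType(), True)])
```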

python/pyspark/sql/connect/conversion.py

Lines changed: 3 additions & 4 deletions
@@ -42,9 +42,8 @@
 )

 from pyspark.storagelevel import StorageLevel
-from pyspark.sql.connect.types import to_arrow_schema
 import pyspark.sql.connect.proto as pb2
-from pyspark.sql.pandas.types import _dedup_names, _deduplicate_field_names
+from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names

 from typing import (
     Any,
@@ -246,7 +245,7 @@ def convert_string(value: Any) -> Any:
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType

-            conv = LocalDataToArrowConversion._create_converter(dataType.sqlType())
+            conv = LocalDataToArrowConversion._create_converter(udt.sqlType())

             def convert_udt(value: Any) -> Any:
                 if value is None:
@@ -428,7 +427,7 @@ def convert_timestample_ntz(value: Any) -> Any:
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType

-            conv = ArrowTableToRowsConversion._create_converter(dataType.sqlType())
+            conv = ArrowTableToRowsConversion._create_converter(udt.sqlType())

             def convert_udt(value: Any) -> Any:
                 if value is None:

python/pyspark/sql/connect/dataframe.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@
     lit,
     expr as sql_expression,
 )
-from pyspark.sql.connect.types import from_arrow_schema
+from pyspark.sql.pandas.types import from_arrow_schema


 if TYPE_CHECKING:

python/pyspark/sql/connect/session.py

Lines changed: 17 additions & 5 deletions
@@ -331,8 +331,12 @@ def createDataFrame(

         # Determine arrow types to coerce data when creating batches
         arrow_schema: Optional[pa.Schema] = None
+        spark_types: List[Optional[DataType]]
+        arrow_types: List[Optional[pa.DataType]]
         if isinstance(schema, StructType):
-            arrow_schema = to_arrow_schema(cast(StructType, _deduplicate_field_names(schema)))
+            deduped_schema = cast(StructType, _deduplicate_field_names(schema))
+            spark_types = [field.dataType for field in deduped_schema.fields]
+            arrow_schema = to_arrow_schema(deduped_schema)
             arrow_types = [field.type for field in arrow_schema]
             _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
         elif isinstance(schema, DataType):
@@ -342,14 +346,15 @@ def createDataFrame(
             )
         else:
             # Any timestamps must be coerced to be compatible with Spark
-            arrow_types = [
-                to_arrow_type(TimestampType())
+            spark_types = [
+                TimestampType()
                 if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
-                else to_arrow_type(DayTimeIntervalType())
+                else DayTimeIntervalType()
                 if is_timedelta64_dtype(t)
                 else None
                 for t in data.dtypes
             ]
+            arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]

         timezone, safecheck = self._client.get_configs(
             "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
@@ -358,7 +363,14 @@ def createDataFrame(
         ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")

         _table = pa.Table.from_batches(
-            [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
+            [
+                ser._create_batch(
+                    [
+                        (c, at, st)
+                        for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
+                    ]
+                )
+            ]
         )

         if isinstance(schema, StructType):
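The serializer batch now receives a `(pandas_column, arrow_type, spark_type)` triple per column, so the original Spark type (including any UDT) is available alongside the Arrow type when the record batch is built. A rough, standalone illustration of how those triples line up; the internal serializer call itself is omitted here:

```py
import pandas as pd

from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import DoubleType, StringType, StructType

schema = StructType().add("name", StringType()).add("score", DoubleType())
pdf = pd.DataFrame({"name": ["a", "b"], "score": [1.0, 2.0]})

spark_types = [field.dataType for field in schema.fields]
arrow_types = [to_arrow_type(dt) for dt in spark_types]

# One (column, arrow_type, spark_type) triple per column.
columns = [
    (col, at, st) for (_, col), at, st in zip(pdf.items(), arrow_types, spark_types)
]

for col, at, st in columns:
    print(col.name, at, st)  # e.g. "name string StringType()"
```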

python/pyspark/sql/connect/types.py

Lines changed: 0 additions & 146 deletions
@@ -20,8 +20,6 @@

 import json

-import pyarrow as pa
-
 from typing import Any, Dict, Optional

 from pyspark.sql.types import (
@@ -299,147 +297,3 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
         return UserDefinedType.fromJson(json_value)
     else:
         raise Exception(f"Unsupported data type {schema}")
-
-
-def to_arrow_type(dt: DataType) -> "pa.DataType":
-    """
-    Convert Spark data type to pyarrow type.
-
-    This function refers to 'pyspark.sql.pandas.types.to_arrow_type' but relax the restriction,
-    e.g. it supports nested StructType.
-    """
-    if type(dt) == BooleanType:
-        arrow_type = pa.bool_()
-    elif type(dt) == ByteType:
-        arrow_type = pa.int8()
-    elif type(dt) == ShortType:
-        arrow_type = pa.int16()
-    elif type(dt) == IntegerType:
-        arrow_type = pa.int32()
-    elif type(dt) == LongType:
-        arrow_type = pa.int64()
-    elif type(dt) == FloatType:
-        arrow_type = pa.float32()
-    elif type(dt) == DoubleType:
-        arrow_type = pa.float64()
-    elif type(dt) == DecimalType:
-        arrow_type = pa.decimal128(dt.precision, dt.scale)
-    elif type(dt) == StringType:
-        arrow_type = pa.string()
-    elif type(dt) == BinaryType:
-        arrow_type = pa.binary()
-    elif type(dt) == DateType:
-        arrow_type = pa.date32()
-    elif type(dt) == TimestampType:
-        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
-        arrow_type = pa.timestamp("us", tz="UTC")
-    elif type(dt) == TimestampNTZType:
-        arrow_type = pa.timestamp("us", tz=None)
-    elif type(dt) == DayTimeIntervalType:
-        arrow_type = pa.duration("us")
-    elif type(dt) == ArrayType:
-        field = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull)
-        arrow_type = pa.list_(field)
-    elif type(dt) == MapType:
-        key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
-        value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
-        arrow_type = pa.map_(key_field, value_field)
-    elif type(dt) == StructType:
-        fields = [
-            pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
-            for field in dt
-        ]
-        arrow_type = pa.struct(fields)
-    elif type(dt) == NullType:
-        arrow_type = pa.null()
-    elif isinstance(dt, UserDefinedType):
-        arrow_type = to_arrow_type(dt.sqlType())
-    else:
-        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
-    return arrow_type
-
-
-def to_arrow_schema(schema: StructType) -> "pa.Schema":
-    """Convert a schema from Spark to Arrow"""
-    fields = [
-        pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
-        for field in schema
-    ]
-    return pa.schema(fields)
-
-
-def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
-    """Convert pyarrow type to Spark data type.
-
-    This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but relax the restriction,
-    e.g. it supports nested StructType, Array of TimestampType. However, Arrow DictionaryType is
-    not allowed.
-    """
-    import pyarrow.types as types
-
-    spark_type: DataType
-    if types.is_boolean(at):
-        spark_type = BooleanType()
-    elif types.is_int8(at):
-        spark_type = ByteType()
-    elif types.is_int16(at):
-        spark_type = ShortType()
-    elif types.is_int32(at):
-        spark_type = IntegerType()
-    elif types.is_int64(at):
-        spark_type = LongType()
-    elif types.is_float32(at):
-        spark_type = FloatType()
-    elif types.is_float64(at):
-        spark_type = DoubleType()
-    elif types.is_decimal(at):
-        spark_type = DecimalType(precision=at.precision, scale=at.scale)
-    elif types.is_string(at):
-        spark_type = StringType()
-    elif types.is_binary(at):
-        spark_type = BinaryType()
-    elif types.is_date32(at):
-        spark_type = DateType()
-    elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
-        spark_type = TimestampNTZType()
-    elif types.is_timestamp(at):
-        spark_type = TimestampType()
-    elif types.is_duration(at):
-        spark_type = DayTimeIntervalType()
-    elif types.is_list(at):
-        spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
-    elif types.is_map(at):
-        spark_type = MapType(
-            from_arrow_type(at.key_type, prefer_timestamp_ntz),
-            from_arrow_type(at.item_type, prefer_timestamp_ntz),
-        )
-    elif types.is_struct(at):
-        return StructType(
-            [
-                StructField(
-                    field.name,
-                    from_arrow_type(field.type, prefer_timestamp_ntz),
-                    nullable=field.nullable,
-                )
-                for field in at
-            ]
-        )
-    elif types.is_null(at):
-        spark_type = NullType()
-    else:
-        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
-    return spark_type
-
-
-def from_arrow_schema(arrow_schema: "pa.Schema", prefer_timestamp_ntz: bool = False) -> StructType:
-    """Convert schema from Arrow to Spark."""
-    return StructType(
-        [
-            StructField(
-                field.name,
-                from_arrow_type(field.type, prefer_timestamp_ntz),
-                nullable=field.nullable,
-            )
-            for field in arrow_schema
-        ]
-    )

python/pyspark/sql/pandas/conversion.py

Lines changed: 8 additions & 9 deletions
@@ -598,30 +598,29 @@ def _create_from_pandas_with_arrow(

         # Determine arrow types to coerce data when creating batches
         if isinstance(schema, StructType):
-            arrow_types = [
-                to_arrow_type(_deduplicate_field_names(f.dataType)) for f in schema.fields
-            ]
+            spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
         elif isinstance(schema, DataType):
             raise PySparkTypeError(
                 error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
                 message_parameters={"data_type": str(schema)},
             )
         else:
             # Any timestamps must be coerced to be compatible with Spark
-            arrow_types = [
-                to_arrow_type(TimestampType())
-                if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
-                else None
+            spark_types = [
+                TimestampType() if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                 for t in pdf.dtypes
            ]

         # Slice the DataFrame to be batched
         step = self._jconf.arrowMaxRecordsPerBatch()
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

-        # Create list of Arrow (columns, type) for serializer dump_stream
+        # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
         arrow_data = [
-            [(c, t) for (_, c), t in zip(pdf_slice.items(), arrow_types)]
+            [
+                (c, to_arrow_type(t) if t is not None else None, t)
+                for (_, c), t in zip(pdf_slice.items(), spark_types)
+            ]
             for pdf_slice in pdf_slices
         ]
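With the Spark type carried next to each column, the Arrow path can serialize UDT values before the record batches are built, so the fallback warning from the PR description goes away. A condensed end-to-end sketch, assuming a local Spark session and the `ExamplePoint` / `ExamplePointUDT` helpers sketched earlier on this page:

```py
import pandas as pd

from pyspark.sql import Row, SparkSession
from pyspark.sql.types import StructType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# ExamplePointUDT / ExamplePoint as sketched near the top of this page.
schema = StructType().add("point", ExamplePointUDT())
pdf = pd.DataFrame.from_records([Row(ExamplePoint(1.0, 2.0))], columns=schema.names)

# Arrow-optimized path: the UDT column is serialized to array<double> per batch.
df = spark.createDataFrame(pdf, schema)
df.show(truncate=False)

# Round trip back to pandas; the values come back as ExamplePoint objects.
print(df.toPandas())
```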
