
Commit 631ee67

ueshin authored and HyukjinKwon committed
[SPARK-43055][CONNECT][PYTHON] Support duplicated nested field names
### What changes were proposed in this pull request?

Supports duplicated nested field names in `spark.createDataFrame` and `df.collect`.

### Why are the changes needed?

If there are duplicated nested field names, the following error is raised:

```py
>>> from pyspark.sql.types import *
>>>
>>> data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))]
>>> schema = (
...     StructType()
...     .add("struct", StructType().add("x", StringType()).add("x", IntegerType()))
...     .add(
...         "struct",
...         StructType()
...         .add("a", IntegerType())
...         .add("x", IntegerType())
...         .add("x", StringType())
...         .add("y", IntegerType())
...         .add("y", StringType()),
...     )
... )
>>> df = spark.createDataFrame(data, schema=schema)
Traceback (most recent call last):
...
pyarrow.lib.ArrowTypeError: Expected bytes, got a 'int' object
```

### Does this PR introduce _any_ user-facing change?

Yes. Duplicated nested field names are now supported.

### How was this patch tested?

Added a test.

Closes apache#40692 from ueshin/issues/SPARK-43055/duplicate_fields.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
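With this change, the failing example above is expected to round-trip. The following continues that snippet and mirrors the test added in this patch (`spark` is assumed to be a Spark Connect session):

```py
df = spark.createDataFrame(data, schema=schema)

assert df.schema == schema
assert df.collect() == data
```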
1 parent a31ac04 commit 631ee67

File tree

5 files changed: +135 −49 lines


connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 10 additions & 2 deletions
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncode
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
 import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -60,13 +61,20 @@ private[sql] class SparkResult[T](
   private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
     while (responses.hasNext) {
       val response = responses.next()
-      if (response.hasArrowBatch) {
+      if (response.hasSchema) {
+        // The original schema should arrive before ArrowBatches.
+        structType =
+          DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType]
+      } else if (response.hasArrowBatch) {
         val ipcStreamBytes = response.getArrowBatch.getData
         val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator)
         try {
           val root = reader.getVectorSchemaRoot
           if (batches.isEmpty) {
-            structType = ArrowUtils.fromArrowSchema(root.getSchema)
+            if (structType == null) {
+              // If the schema is not available yet, fallback to the schema from Arrow.
+              structType = ArrowUtils.fromArrowSchema(root.getSchema)
+            }
             // TODO: create encoders that directly operate on arrow vectors.
             boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes)
           }
```

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala

Lines changed: 34 additions & 2 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.util.concurrent.atomic.AtomicInteger
+
 import scala.collection.JavaConverters._
 
 import com.google.protobuf.ByteString
@@ -38,7 +40,7 @@ import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsA
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, UserDefinedType}
 import org.apache.spark.util.{ThreadUtils, Utils}
 
 class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
@@ -120,8 +122,38 @@ object SparkConnectStreamHandler {
       sessionId: String,
       dataframe: DataFrame,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+
+    def deduplicateFieldNames(dt: DataType): DataType = dt match {
+      case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType)
+      case st @ StructType(fields) =>
+        val newNames = if (st.names.toSet.size == st.names.length) {
+          st.names
+        } else {
+          val genNawName = st.names.groupBy(identity).map {
+            case (name, names) if names.length > 1 =>
+              val i = new AtomicInteger()
+              name -> { () => s"${name}_${i.getAndIncrement()}" }
+            case (name, _) => name -> { () => name }
+          }
+          st.names.map(genNawName(_)())
+        }
+        val newFields =
+          fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) =>
+            StructField(name, deduplicateFieldNames(dataType), nullable, metadata)
+          }
+        StructType(newFields)
+      case ArrayType(elementType, containsNull) =>
+        ArrayType(deduplicateFieldNames(elementType), containsNull)
+      case MapType(keyType, valueType, valueContainsNull) =>
+        MapType(
+          deduplicateFieldNames(keyType),
+          deduplicateFieldNames(valueType),
+          valueContainsNull)
+      case _ => dt
+    }
+
     val spark = dataframe.sparkSession
-    val schema = dataframe.schema
+    val schema = deduplicateFieldNames(dataframe.schema).asInstanceOf[StructType]
     val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
     // Conservatively sets it 70% because the size is not accurate but estimated.
```
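In `deduplicateFieldNames` above, only names that occur more than once within a struct are renamed: each duplicated name receives a per-name counter suffix, unique names are left untouched, and the renaming is applied recursively through structs, arrays, and maps. A minimal Python sketch of the same naming scheme (illustrative only, not part of the patch):

```py
import itertools
from collections import Counter

def deduplicate_field_names(names):
    # Names occurring more than once get a per-name counter suffix
    # ("x" -> "x_0", "x_1", ...); unique names are kept as-is.
    counts = Counter(names)
    counters = {name: itertools.count() for name, n in counts.items() if n > 1}
    return [f"{name}_{next(counters[name])}" if name in counters else name for name in names]

print(deduplicate_field_names(["a", "x", "x", "y", "y"]))
# ['a', 'x_0', 'x_1', 'y_0', 'y_1']
```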

python/pyspark/sql/connect/client.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -678,12 +678,13 @@ def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         req.plan.CopyFrom(plan)
         table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
-        pdf = table.rename_columns([f"col_{i}" for i in range(len(table.column_names))]).to_pandas()
-        pdf.columns = table.column_names
 
         schema = schema or types.from_arrow_schema(table.schema)
         assert schema is not None and isinstance(schema, StructType)
 
+        pdf = table.to_pandas()
+        pdf.columns = schema.fieldNames()
+
         for field, pa_field in zip(schema, table.schema):
             if isinstance(field.dataType, TimestampType):
                 assert pa_field.type.tz is not None
```
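Because the Arrow data produced by the server now carries deduplicated column names, the client can call `table.to_pandas()` directly and restore the original (possibly duplicated) names from the schema; pandas itself accepts duplicated column labels. A small illustration with made-up column names:

```py
import pandas as pd

pdf = pd.DataFrame({"struct_0": ["a"], "struct_1": [2]})  # deduplicated Arrow-side names
pdf.columns = ["struct", "struct"]  # restore the original schema names
print(list(pdf.columns))  # ['struct', 'struct']
```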

python/pyspark/sql/connect/conversion.py

Lines changed: 68 additions & 43 deletions
```diff
@@ -19,6 +19,7 @@
 check_dependencies(__name__)
 
 import array
+import itertools
 import datetime
 import decimal
 
@@ -31,20 +32,23 @@
     TimestampType,
     TimestampNTZType,
     MapType,
+    StructField,
     StructType,
     ArrayType,
     BinaryType,
     NullType,
     DecimalType,
     StringType,
     UserDefinedType,
+    cast,
 )
 
 from pyspark.sql.connect.types import to_arrow_schema
 
 from typing import (
     Any,
     Callable,
+    Dict,
     Sequence,
     List,
 )
@@ -99,10 +103,10 @@ def _create_converter(dataType: DataType) -> Callable:
 
             field_names = dataType.fieldNames()
 
-            field_convs = {
-                field.name: LocalDataToArrowConversion._create_converter(field.dataType)
+            field_convs = [
+                LocalDataToArrowConversion._create_converter(field.dataType)
                 for field in dataType.fields
-            }
+            ]
 
             def convert_struct(value: Any) -> Any:
                 if value is None:
@@ -113,24 +117,15 @@
                     ), f"{type(value)} {value}"
 
                     _dict = {}
-                    if isinstance(value, dict):
-                        for k, v in value.items():
-                            assert isinstance(k, str)
-                            _dict[k] = field_convs[k](v)
-                    elif isinstance(value, Row) and hasattr(value, "__fields__"):
-                        for k, v in value.asDict(recursive=False).items():
-                            assert isinstance(k, str)
-                            _dict[k] = field_convs[k](v)
-                    elif not isinstance(value, Row) and hasattr(value, "__dict__"):
-                        for k, v in value.__dict__.items():
-                            assert isinstance(k, str)
-                            _dict[k] = field_convs[k](v)
-                    else:
-                        i = 0
-                        for v in value:
-                            field_name = field_names[i]
-                            _dict[field_name] = field_convs[field_name](v)
-                            i += 1
+                    if not isinstance(value, Row) and hasattr(value, "__dict__"):
+                        value = value.__dict__
+                    for i, field in enumerate(field_names):
+                        if isinstance(value, dict):
+                            v = value.get(field)
+                        else:
+                            v = value[i]
+
+                        _dict[f"col_{i}"] = field_convs[i](v)
 
                     return _dict
 
@@ -255,8 +250,6 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
 
         assert schema is not None and isinstance(schema, StructType)
 
-        pa_schema = to_arrow_schema(schema)
-
         column_names = schema.fieldNames()
 
         column_convs = [
@@ -276,6 +269,27 @@
 
                 pylist[i].append(column_convs[i](value))
 
+        def normalize(dt: DataType) -> DataType:
+            if isinstance(dt, StructType):
+                return StructType(
+                    [
+                        StructField(f"col_{i}", normalize(field.dataType), nullable=field.nullable)
+                        for i, field in enumerate(dt.fields)
+                    ]
+                )
+            elif isinstance(dt, ArrayType):
+                return ArrayType(normalize(dt.elementType), containsNull=dt.containsNull)
+            elif isinstance(dt, MapType):
+                return MapType(
+                    normalize(dt.keyType),
+                    normalize(dt.valueType),
+                    valueContainsNull=dt.valueContainsNull,
+                )
+            else:
+                return dt
+
+        pa_schema = to_arrow_schema(cast(StructType, normalize(schema)))
+
         return pa.Table.from_arrays(pylist, schema=pa_schema)
 
 
@@ -319,28 +333,42 @@ def _create_converter(dataType: DataType) -> Callable:
 
         elif isinstance(dataType, StructType):
 
-            field_convs = {
-                f.name: ArrowTableToRowsConversion._create_converter(f.dataType)
-                for f in dataType.fields
-            }
-            need_conv = any(
-                ArrowTableToRowsConversion._need_converter(f.dataType) for f in dataType.fields
-            )
+            field_names = dataType.names
+
+            if len(set(field_names)) == len(field_names):
+                dedup_field_names = field_names
+            else:
+                gen_new_name: Dict[str, Callable[[], str]] = {}
+                for name, group in itertools.groupby(dataType.names):
+                    if len(list(group)) > 1:
+
+                        def _gen(_name: str) -> Callable[[], str]:
+                            _i = itertools.count()
+                            return lambda: f"{_name}_{next(_i)}"
+
+                    else:
+
+                        def _gen(_name: str) -> Callable[[], str]:
+                            return lambda: _name
+
+                    gen_new_name[name] = _gen(name)
+                dedup_field_names = [gen_new_name[name]() for name in dataType.names]
+
+            field_convs = [
+                ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields
+            ]
 
             def convert_struct(value: Any) -> Any:
                 if value is None:
                     return None
                 else:
                     assert isinstance(value, dict)
 
-                    if need_conv:
-                        _dict = {}
-                        for k, v in value.items():
-                            assert isinstance(k, str)
-                            _dict[k] = field_convs[k](v)
-                        return Row(**_dict)
-                    else:
-                        return Row(**value)
+                    _values = [
+                        field_convs[i](value.get(name, None))
+                        for i, name in enumerate(dedup_field_names)
+                    ]
+                    return _create_row(field_names, _values)
 
             return convert_struct
 
@@ -425,13 +453,10 @@ def convert(table: "pa.Table", schema: StructType) -> List[Row]:
             ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
         ]
 
-        # table.to_pylist() automatically remove columns with duplicated names,
-        # to avoid this, use columnar lists here.
-        # TODO: support duplicated field names in the one struct. e.g. SF.struct("a", "a")
         columnar_data = [column.to_pylist() for column in table.columns]
 
         rows: List[Row] = []
         for i in range(0, table.num_rows):
-            values = [field_converters[j](columnar_data[j][i]) for j in range(0, table.num_columns)]
-            rows.append(_create_row(fields=table.column_names, values=values))
+            values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)]
+            rows.append(_create_row(fields=schema.fieldNames(), values=values))
         return rows
```
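The comment removed above noted that `table.to_pylist()` drops columns with duplicated names; the same collapse happens for nested struct children once Arrow values are exposed as per-row dicts. That is why the server renames duplicated fields for the Arrow transfer and `convert_struct` maps the deduplicated names back to the original ones when rebuilding `Row`s. A minimal pyarrow sketch with hypothetical data (not part of the patch):

```py
import pyarrow as pa

# Arrow itself allows a struct whose children share a name...
arr = pa.StructArray.from_arrays(
    [pa.array(["a", "b"]), pa.array([1, 2])],
    names=["x", "x"],
)
print(arr.type)        # a struct with two children named "x"
print(arr.to_pylist()) # ...but each row dict keeps only one "x" key
```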

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -1700,6 +1700,26 @@ def test_where(self):
             message_parameters={"arg_name": "condition", "arg_type": "int"},
         )
 
+    def test_duplicate_field_names(self):
+        data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))]
+        schema = (
+            StructType()
+            .add("struct", StructType().add("x", StringType()).add("x", IntegerType()))
+            .add(
+                "struct",
+                StructType()
+                .add("a", IntegerType())
+                .add("x", IntegerType())
+                .add("x", StringType())
+                .add("y", IntegerType())
+                .add("y", StringType()),
+            )
+        )
+        df = self.spark.createDataFrame(data, schema=schema)
+
+        self.assertEqual(df.schema, schema)
+        self.assertEqual(df.collect(), data)
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
```
