|
20 | 20 |
|
21 | 21 | import json |
22 | 22 |
|
23 | | -import pyarrow as pa |
24 | | - |
25 | 23 | from typing import Any, Dict, Optional |
26 | 24 |
|
27 | 25 | from pyspark.sql.types import ( |
@@ -299,147 +297,3 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: |
299 | 297 | return UserDefinedType.fromJson(json_value) |
300 | 298 | else: |
301 | 299 | raise Exception(f"Unsupported data type {schema}") |
302 | | - |
303 | | - |
# Scalar Spark -> Arrow mappings that take no parameters.  Keyed on the
# concrete class (looked up with ``type(dt)``, not ``isinstance``) so that
# subclasses are not silently matched — this preserves the exact-type
# semantics of the original if/elif chain, and hoists the pyarrow type
# construction out of every call.
_SIMPLE_SPARK_TO_ARROW = {
    BooleanType: pa.bool_(),
    ByteType: pa.int8(),
    ShortType: pa.int16(),
    IntegerType: pa.int32(),
    LongType: pa.int64(),
    FloatType: pa.float32(),
    DoubleType: pa.float64(),
    StringType: pa.string(),
    BinaryType: pa.binary(),
    DateType: pa.date32(),
    # Timestamps should be in UTC; JVM Arrow timestamps require a timezone
    # to be read.
    TimestampType: pa.timestamp("us", tz="UTC"),
    TimestampNTZType: pa.timestamp("us", tz=None),
    DayTimeIntervalType: pa.duration("us"),
    NullType: pa.null(),
}


def to_arrow_type(dt: DataType) -> "pa.DataType":
    """
    Convert a Spark data type to the equivalent pyarrow type.

    This function refers to 'pyspark.sql.pandas.types.to_arrow_type' but
    relaxes the restriction, e.g. it supports nested StructType.

    Parameters
    ----------
    dt : DataType
        The Spark SQL data type to convert; nested element/key/value/field
        types are converted recursively.

    Returns
    -------
    pa.DataType

    Raises
    ------
    TypeError
        If ``dt`` (or any nested type) has no Arrow equivalent.
    """
    cls = type(dt)
    if cls in _SIMPLE_SPARK_TO_ARROW:
        return _SIMPLE_SPARK_TO_ARROW[cls]
    if cls == DecimalType:
        return pa.decimal128(dt.precision, dt.scale)
    if cls == ArrayType:
        element = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull)
        return pa.list_(element)
    if cls == MapType:
        key = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
        value = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
        return pa.map_(key, value)
    if cls == StructType:
        fields = [
            pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
            for field in dt
        ]
        return pa.struct(fields)
    if isinstance(dt, UserDefinedType):
        # A UDT is represented by its underlying SQL type on the wire.
        return to_arrow_type(dt.sqlType())
    raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
360 | | - |
361 | | - |
def to_arrow_schema(schema: StructType) -> "pa.Schema":
    """Convert a Spark StructType schema into the equivalent Arrow schema."""
    arrow_fields = []
    for struct_field in schema:
        arrow_fields.append(
            pa.field(
                struct_field.name,
                to_arrow_type(struct_field.dataType),
                nullable=struct_field.nullable,
            )
        )
    return pa.schema(arrow_fields)
369 | | - |
370 | | - |
def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
    """Convert a pyarrow type to the corresponding Spark data type.

    This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but
    relax the restriction, e.g. it supports nested StructType, Array of
    TimestampType. However, Arrow DictionaryType is not allowed.
    """
    import pyarrow.types as types

    # Each branch returns directly; branch order matters only for the two
    # timestamp checks (the NTZ case must be tested before the general one).
    if types.is_boolean(at):
        return BooleanType()
    if types.is_int8(at):
        return ByteType()
    if types.is_int16(at):
        return ShortType()
    if types.is_int32(at):
        return IntegerType()
    if types.is_int64(at):
        return LongType()
    if types.is_float32(at):
        return FloatType()
    if types.is_float64(at):
        return DoubleType()
    if types.is_decimal(at):
        return DecimalType(precision=at.precision, scale=at.scale)
    if types.is_string(at):
        return StringType()
    if types.is_binary(at):
        return BinaryType()
    if types.is_date32(at):
        return DateType()
    if types.is_timestamp(at):
        if prefer_timestamp_ntz and at.tz is None:
            return TimestampNTZType()
        return TimestampType()
    if types.is_duration(at):
        return DayTimeIntervalType()
    if types.is_list(at):
        return ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
    if types.is_map(at):
        key_type = from_arrow_type(at.key_type, prefer_timestamp_ntz)
        value_type = from_arrow_type(at.item_type, prefer_timestamp_ntz)
        return MapType(key_type, value_type)
    if types.is_struct(at):
        struct_fields = [
            StructField(
                field.name,
                from_arrow_type(field.type, prefer_timestamp_ntz),
                nullable=field.nullable,
            )
            for field in at
        ]
        return StructType(struct_fields)
    if types.is_null(at):
        return NullType()
    raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
432 | | - |
433 | | - |
def from_arrow_schema(arrow_schema: "pa.Schema", prefer_timestamp_ntz: bool = False) -> StructType:
    """Convert schema from Arrow to Spark."""

    def _to_struct_field(arrow_field: "pa.Field") -> StructField:
        # One Arrow field maps to one Spark StructField, preserving
        # name and nullability and converting the type recursively.
        return StructField(
            arrow_field.name,
            from_arrow_type(arrow_field.type, prefer_timestamp_ntz),
            nullable=arrow_field.nullable,
        )

    return StructType([_to_struct_field(f) for f in arrow_schema])
0 commit comments