2121"""
2222import datetime
2323import itertools
24- from typing import Any , Callable , List , Optional , Union , TYPE_CHECKING
24+ from typing import Any , Callable , Iterable , List , Optional , Union , TYPE_CHECKING
2525
2626from pyspark .sql .types import (
2727 cast ,
@@ -750,6 +750,7 @@ def _create_converter_from_pandas(
750750 * ,
751751 timezone : Optional [str ],
752752 error_on_duplicated_field_names : bool = True ,
753+ ignore_unexpected_complex_type_values : bool = False ,
753754) -> Callable [["pd.Series" ], "pd.Series" ]:
754755 """
755756 Create a converter of pandas Series to create Spark DataFrame with Arrow optimization.
@@ -763,6 +764,17 @@ def _create_converter_from_pandas(
763764 error_on_duplicated_field_names : bool, optional
764765 Whether to raise an exception when there are duplicated field names.
765766 (default ``True``)
767+ ignore_unexpected_complex_type_values : bool, optional
768+ Whether to ignore the case where unexpected values are given for complex types.
769+ If ``False``, each complex type expects:
770+
771+ * array type: :class:`Iterable`
772+ * map type: :class:`dict`
773+ * struct type: :class:`dict` or :class:`tuple`
774+
775+ and raises an AssertionError when the given value is not of the expected type.
776+ If ``True``, just ignore it and return the given value.
777+ (default ``False``)
766778
767779 Returns
768780 -------
@@ -781,28 +793,51 @@ def correct_timestamp(pser: pd.Series) -> pd.Series:
781793 def _converter (dt : DataType ) -> Optional [Callable [[Any ], Any ]]:
782794
783795 if isinstance (dt , ArrayType ):
784- _element_conv = _converter (dt .elementType )
785- if _element_conv is None :
786- return None
796+ _element_conv = _converter (dt .elementType ) or (lambda x : x )
787797
788- def convert_array (value : Any ) -> Any :
789- if value is None :
790- return None
791- else :
792- return [_element_conv (v ) for v in value ] # type: ignore[misc]
798+ if ignore_unexpected_complex_type_values :
799+
800+ def convert_array (value : Any ) -> Any :
801+ if value is None :
802+ return None
803+ elif isinstance (value , Iterable ):
804+ return [_element_conv (v ) for v in value ]
805+ else :
806+ return value
807+
808+ else :
809+
810+ def convert_array (value : Any ) -> Any :
811+ if value is None :
812+ return None
813+ else :
814+ assert isinstance (value , Iterable )
815+ return [_element_conv (v ) for v in value ]
793816
794817 return convert_array
795818
796819 elif isinstance (dt , MapType ):
797820 _key_conv = _converter (dt .keyType ) or (lambda x : x )
798821 _value_conv = _converter (dt .valueType ) or (lambda x : x )
799822
800- def convert_map (value : Any ) -> Any :
801- if value is None :
802- return None
803- else :
804- assert isinstance (value , dict )
805- return [(_key_conv (k ), _value_conv (v )) for k , v in value .items ()]
823+ if ignore_unexpected_complex_type_values :
824+
825+ def convert_map (value : Any ) -> Any :
826+ if value is None :
827+ return None
828+ elif isinstance (value , dict ):
829+ return [(_key_conv (k ), _value_conv (v )) for k , v in value .items ()]
830+ else :
831+ return value
832+
833+ else :
834+
835+ def convert_map (value : Any ) -> Any :
836+ if value is None :
837+ return None
838+ else :
839+ assert isinstance (value , dict )
840+ return [(_key_conv (k ), _value_conv (v )) for k , v in value .items ()]
806841
807842 return convert_map
808843
@@ -820,17 +855,38 @@ def convert_map(value: Any) -> Any:
820855
821856 field_convs = [_converter (f .dataType ) or (lambda x : x ) for f in dt .fields ]
822857
823- def convert_struct (value : Any ) -> Any :
824- if value is None :
825- return None
826- elif isinstance (value , dict ):
827- return {
828- dedup_field_names [i ]: field_convs [i ](value .get (key , None ))
829- for i , key in enumerate (field_names )
830- }
831- else :
832- assert isinstance (value , tuple )
833- return {dedup_field_names [i ]: field_convs [i ](v ) for i , v in enumerate (value )}
858+ if ignore_unexpected_complex_type_values :
859+
860+ def convert_struct (value : Any ) -> Any :
861+ if value is None :
862+ return None
863+ elif isinstance (value , dict ):
864+ return {
865+ dedup_field_names [i ]: field_convs [i ](value .get (key , None ))
866+ for i , key in enumerate (field_names )
867+ }
868+ elif isinstance (value , tuple ):
869+ return {
870+ dedup_field_names [i ]: field_convs [i ](v ) for i , v in enumerate (value )
871+ }
872+ else :
873+ return value
874+
875+ else :
876+
877+ def convert_struct (value : Any ) -> Any :
878+ if value is None :
879+ return None
880+ elif isinstance (value , dict ):
881+ return {
882+ dedup_field_names [i ]: field_convs [i ](value .get (key , None ))
883+ for i , key in enumerate (field_names )
884+ }
885+ else :
886+ assert isinstance (value , tuple )
887+ return {
888+ dedup_field_names [i ]: field_convs [i ](v ) for i , v in enumerate (value )
889+ }
834890
835891 return convert_struct
836892
0 commit comments