1
1
import warnings
2
- from collections import deque
2
+ from collections import defaultdict , deque
3
3
from collections .abc import Iterable as CollectionsIterable
4
4
from typing import (
5
5
TYPE_CHECKING ,
6
6
Any ,
7
+ DefaultDict ,
7
8
Deque ,
8
9
Dict ,
9
10
FrozenSet ,
@@ -249,6 +250,8 @@ def Schema(default: Any, **kwargs: Any) -> Any:
249
250
SHAPE_ITERABLE = 9
250
251
SHAPE_GENERIC = 10
251
252
SHAPE_DEQUE = 11
253
+ SHAPE_DICT = 12
254
+ SHAPE_DEFAULTDICT = 13
252
255
SHAPE_NAME_LOOKUP = {
253
256
SHAPE_LIST : 'List[{}]' ,
254
257
SHAPE_SET : 'Set[{}]' ,
@@ -257,8 +260,12 @@ def Schema(default: Any, **kwargs: Any) -> Any:
257
260
SHAPE_FROZENSET : 'FrozenSet[{}]' ,
258
261
SHAPE_ITERABLE : 'Iterable[{}]' ,
259
262
SHAPE_DEQUE : 'Deque[{}]' ,
263
+ SHAPE_DICT : 'Dict[{}]' ,
264
+ SHAPE_DEFAULTDICT : 'DefaultDict[{}]' ,
260
265
}
261
266
267
+ MAPPING_LIKE_SHAPES : Set [int ] = {SHAPE_DEFAULTDICT , SHAPE_DICT , SHAPE_MAPPING }
268
+
262
269
263
270
class ModelField (Representation ):
264
271
__slots__ = (
@@ -572,6 +579,14 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
572
579
elif issubclass (origin , Sequence ):
573
580
self .type_ = get_args (self .type_ )[0 ]
574
581
self .shape = SHAPE_SEQUENCE
582
+ elif issubclass (origin , DefaultDict ):
583
+ self .key_field = self ._create_sub_type (get_args (self .type_ )[0 ], 'key_' + self .name , for_keys = True )
584
+ self .type_ = get_args (self .type_ )[1 ]
585
+ self .shape = SHAPE_DEFAULTDICT
586
+ elif issubclass (origin , Dict ):
587
+ self .key_field = self ._create_sub_type (get_args (self .type_ )[0 ], 'key_' + self .name , for_keys = True )
588
+ self .type_ = get_args (self .type_ )[1 ]
589
+ self .shape = SHAPE_DICT
575
590
elif issubclass (origin , Mapping ):
576
591
self .key_field = self ._create_sub_type (get_args (self .type_ )[0 ], 'key_' + self .name , for_keys = True )
577
592
self .type_ = get_args (self .type_ )[1 ]
@@ -688,8 +703,8 @@ def validate(
688
703
689
704
if self .shape == SHAPE_SINGLETON :
690
705
v , errors = self ._validate_singleton (v , values , loc , cls )
691
- elif self .shape == SHAPE_MAPPING :
692
- v , errors = self ._validate_mapping (v , values , loc , cls )
706
+ elif self .shape in MAPPING_LIKE_SHAPES :
707
+ v , errors = self ._validate_mapping_like (v , values , loc , cls )
693
708
elif self .shape == SHAPE_TUPLE :
694
709
v , errors = self ._validate_tuple (v , values , loc , cls )
695
710
elif self .shape == SHAPE_ITERABLE :
@@ -806,7 +821,7 @@ def _validate_tuple(
806
821
else :
807
822
return tuple (result ), None
808
823
809
- def _validate_mapping (
824
+ def _validate_mapping_like (
810
825
self , v : Any , values : Dict [str , Any ], loc : 'LocStr' , cls : Optional ['ModelOrDc' ]
811
826
) -> 'ValidateReturn' :
812
827
try :
@@ -832,8 +847,30 @@ def _validate_mapping(
832
847
result [key_result ] = value_result
833
848
if errors :
834
849
return v , errors
835
- else :
850
+ elif self . shape == SHAPE_DICT :
836
851
return result , None
852
+ elif self .shape == SHAPE_DEFAULTDICT :
853
+ return defaultdict (self .type_ , result ), None
854
+ else :
855
+ return self ._get_mapping_value (v , result ), None
856
+
857
+ def _get_mapping_value (self , original : T , converted : Dict [Any , Any ]) -> Union [T , Dict [Any , Any ]]:
858
+ """
859
+ When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid
860
+ coercing to `dict` unwillingly.
861
+ """
862
+ original_cls = original .__class__
863
+
864
+ if original_cls == dict or original_cls == Dict :
865
+ return converted
866
+ elif original_cls in {defaultdict , DefaultDict }:
867
+ return defaultdict (self .type_ , converted )
868
+ else :
869
+ try :
870
+ # Counter, OrderedDict, UserDict, ...
871
+ return original_cls (converted ) # type: ignore
872
+ except TypeError :
873
+ raise RuntimeError (f'Could not convert dictionary to { original_cls .__name__ !r} ' ) from None
837
874
838
875
def _validate_singleton (
839
876
self , v : Any , values : Dict [str , Any ], loc : 'LocStr' , cls : Optional ['ModelOrDc' ]
@@ -876,7 +913,7 @@ def _type_display(self) -> PyObjectStr:
876
913
t = display_as_type (self .type_ )
877
914
878
915
# have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6
879
- if self .shape == SHAPE_MAPPING :
916
+ if self .shape in MAPPING_LIKE_SHAPES :
880
917
t = f'Mapping[{ display_as_type (self .key_field .type_ )} , { t } ]' # type: ignore
881
918
elif self .shape == SHAPE_TUPLE :
882
919
t = 'Tuple[{}]' .format (', ' .join (display_as_type (f .type_ ) for f in self .sub_fields )) # type: ignore
0 commit comments