Skip to content

Commit 4ddf4f1

Browse files
ofekPrettyWoodsamuelcolvin
authored
Properly retain types of Mapping subclasses (pydantic#2325)
* Properly retain types of Mapping subclasses * Create 2325-ofek.md * update with feedback Co-Authored-By: Eric Jolibois <[email protected]> * satisfy mypy? * Update fields.py Co-Authored-By: Eric Jolibois <[email protected]> * show uncovered line numbers * fix coverage * update * address feedback * try * update Co-Authored-By: Eric Jolibois <[email protected]> * rename test * address feedback Co-authored-by: Eric Jolibois <[email protected]> Co-authored-by: Samuel Colvin <[email protected]>
1 parent aa92db5 commit 4ddf4f1

File tree

5 files changed

+115
-11
lines changed

5 files changed

+115
-11
lines changed

changes/2325-ofek.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Prevent `Mapping` subclasses from always being coerced to `dict`

pydantic/fields.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import warnings
2-
from collections import deque
2+
from collections import defaultdict, deque
33
from collections.abc import Iterable as CollectionsIterable
44
from typing import (
55
TYPE_CHECKING,
66
Any,
7+
DefaultDict,
78
Deque,
89
Dict,
910
FrozenSet,
@@ -249,6 +250,8 @@ def Schema(default: Any, **kwargs: Any) -> Any:
249250
SHAPE_ITERABLE = 9
250251
SHAPE_GENERIC = 10
251252
SHAPE_DEQUE = 11
253+
SHAPE_DICT = 12
254+
SHAPE_DEFAULTDICT = 13
252255
SHAPE_NAME_LOOKUP = {
253256
SHAPE_LIST: 'List[{}]',
254257
SHAPE_SET: 'Set[{}]',
@@ -257,8 +260,12 @@ def Schema(default: Any, **kwargs: Any) -> Any:
257260
SHAPE_FROZENSET: 'FrozenSet[{}]',
258261
SHAPE_ITERABLE: 'Iterable[{}]',
259262
SHAPE_DEQUE: 'Deque[{}]',
263+
SHAPE_DICT: 'Dict[{}]',
264+
SHAPE_DEFAULTDICT: 'DefaultDict[{}]',
260265
}
261266

267+
MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING}
268+
262269

263270
class ModelField(Representation):
264271
__slots__ = (
@@ -572,6 +579,14 @@ def _type_analysis(self) -> None: # noqa: C901 (ignore complexity)
572579
elif issubclass(origin, Sequence):
573580
self.type_ = get_args(self.type_)[0]
574581
self.shape = SHAPE_SEQUENCE
582+
elif issubclass(origin, DefaultDict):
583+
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
584+
self.type_ = get_args(self.type_)[1]
585+
self.shape = SHAPE_DEFAULTDICT
586+
elif issubclass(origin, Dict):
587+
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
588+
self.type_ = get_args(self.type_)[1]
589+
self.shape = SHAPE_DICT
575590
elif issubclass(origin, Mapping):
576591
self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True)
577592
self.type_ = get_args(self.type_)[1]
@@ -688,8 +703,8 @@ def validate(
688703

689704
if self.shape == SHAPE_SINGLETON:
690705
v, errors = self._validate_singleton(v, values, loc, cls)
691-
elif self.shape == SHAPE_MAPPING:
692-
v, errors = self._validate_mapping(v, values, loc, cls)
706+
elif self.shape in MAPPING_LIKE_SHAPES:
707+
v, errors = self._validate_mapping_like(v, values, loc, cls)
693708
elif self.shape == SHAPE_TUPLE:
694709
v, errors = self._validate_tuple(v, values, loc, cls)
695710
elif self.shape == SHAPE_ITERABLE:
@@ -806,7 +821,7 @@ def _validate_tuple(
806821
else:
807822
return tuple(result), None
808823

809-
def _validate_mapping(
824+
def _validate_mapping_like(
810825
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
811826
) -> 'ValidateReturn':
812827
try:
@@ -832,8 +847,30 @@ def _validate_mapping(
832847
result[key_result] = value_result
833848
if errors:
834849
return v, errors
835-
else:
850+
elif self.shape == SHAPE_DICT:
836851
return result, None
852+
elif self.shape == SHAPE_DEFAULTDICT:
853+
return defaultdict(self.type_, result), None
854+
else:
855+
return self._get_mapping_value(v, result), None
856+
857+
def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]:
858+
"""
859+
When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid
860+
coercing to `dict` unwillingly.
861+
"""
862+
original_cls = original.__class__
863+
864+
if original_cls == dict or original_cls == Dict:
865+
return converted
866+
elif original_cls in {defaultdict, DefaultDict}:
867+
return defaultdict(self.type_, converted)
868+
else:
869+
try:
870+
# Counter, OrderedDict, UserDict, ...
871+
return original_cls(converted) # type: ignore
872+
except TypeError:
873+
raise RuntimeError(f'Could not convert dictionary to {original_cls.__name__!r}') from None
837874

838875
def _validate_singleton(
839876
self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
@@ -876,7 +913,7 @@ def _type_display(self) -> PyObjectStr:
876913
t = display_as_type(self.type_)
877914

878915
# have to do this since display_as_type(self.outer_type_) is different (and wrong) on python 3.6
879-
if self.shape == SHAPE_MAPPING:
916+
if self.shape in MAPPING_LIKE_SHAPES:
880917
t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore
881918
elif self.shape == SHAPE_TUPLE:
882919
t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore

pydantic/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
2929
from .error_wrappers import ErrorWrapper, ValidationError
3030
from .errors import ConfigError, DictError, ExtraError, MissingError
31-
from .fields import SHAPE_MAPPING, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
31+
from .fields import MAPPING_LIKE_SHAPES, ModelField, ModelPrivateAttr, PrivateAttr, Undefined
3232
from .json import custom_pydantic_encoder, pydantic_encoder
3333
from .parse import Protocol, load_file, load_str_bytes
3434
from .schema import default_ref_template, model_schema
@@ -559,7 +559,8 @@ def json(
559559
@classmethod
560560
def _enforce_dict_if_root(cls, obj: Any) -> Any:
561561
if cls.__custom_root_type__ and (
562-
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) or cls.__fields__[ROOT_KEY].shape == SHAPE_MAPPING
562+
not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY})
563+
or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES
563564
):
564565
return {ROOT_KEY: obj}
565566
else:

pydantic/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
from typing_extensions import Annotated, Literal
3030

3131
from .fields import (
32+
MAPPING_LIKE_SHAPES,
3233
SHAPE_FROZENSET,
3334
SHAPE_GENERIC,
3435
SHAPE_ITERABLE,
3536
SHAPE_LIST,
36-
SHAPE_MAPPING,
3737
SHAPE_SEQUENCE,
3838
SHAPE_SET,
3939
SHAPE_SINGLETON,
@@ -450,7 +450,7 @@ def field_type_schema(
450450
if field.shape in {SHAPE_SET, SHAPE_FROZENSET}:
451451
f_schema['uniqueItems'] = True
452452

453-
elif field.shape == SHAPE_MAPPING:
453+
elif field.shape in MAPPING_LIKE_SHAPES:
454454
f_schema = {'type': 'object'}
455455
key_field = cast(ModelField, field.key_field)
456456
regex = getattr(key_field.type_, 'regex', None)

tests/test_main.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
2+
from collections import defaultdict
23
from enum import Enum
3-
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Type, get_type_hints
4+
from typing import Any, Callable, ClassVar, DefaultDict, Dict, List, Mapping, Optional, Type, get_type_hints
45
from uuid import UUID, uuid4
56

67
import pytest
@@ -1611,6 +1612,70 @@ class Item(BaseModel):
16111612
assert id(image_2) == id(item.images[1])
16121613

16131614

1615+
def test_mapping_retains_type_subclass():
1616+
class CustomMap(dict):
1617+
pass
1618+
1619+
class Model(BaseModel):
1620+
x: Mapping[str, Mapping[str, int]]
1621+
1622+
m = Model(x=CustomMap(outer=CustomMap(inner=42)))
1623+
assert isinstance(m.x, CustomMap)
1624+
assert isinstance(m.x['outer'], CustomMap)
1625+
assert m.x['outer']['inner'] == 42
1626+
1627+
1628+
def test_mapping_retains_type_defaultdict():
1629+
class Model(BaseModel):
1630+
x: Mapping[str, int]
1631+
1632+
d = defaultdict(int)
1633+
d[1] = '2'
1634+
d['3']
1635+
1636+
m = Model(x=d)
1637+
assert isinstance(m.x, defaultdict)
1638+
assert m.x['1'] == 2
1639+
assert m.x['3'] == 0
1640+
1641+
1642+
def test_mapping_retains_type_fallback_error():
1643+
class CustomMap(dict):
1644+
def __init__(self, *args, **kwargs):
1645+
if args or kwargs:
1646+
raise TypeError('test')
1647+
super().__init__(*args, **kwargs)
1648+
1649+
class Model(BaseModel):
1650+
x: Mapping[str, int]
1651+
1652+
d = CustomMap()
1653+
d['one'] = 1
1654+
d['two'] = 2
1655+
1656+
with pytest.raises(RuntimeError, match="Could not convert dictionary to 'CustomMap'"):
1657+
Model(x=d)
1658+
1659+
1660+
def test_typing_coercion_dict():
1661+
class Model(BaseModel):
1662+
x: Dict[str, int]
1663+
1664+
m = Model(x={'one': 1, 'two': 2})
1665+
assert repr(m) == "Model(x={'one': 1, 'two': 2})"
1666+
1667+
1668+
def test_typing_coercion_defaultdict():
1669+
class Model(BaseModel):
1670+
x: DefaultDict[int, str]
1671+
1672+
d = defaultdict(str)
1673+
d['1']
1674+
m = Model(x=d)
1675+
m.x['a']
1676+
assert repr(m) == "Model(x=defaultdict(<class 'str'>, {1: '', 'a': ''}))"
1677+
1678+
16141679
def test_class_kwargs_config():
16151680
class Base(BaseModel, extra='forbid', alias_generator=str.upper):
16161681
a: int

0 commit comments

Comments
 (0)