Skip to content

Commit 3ec3559

Browse files
authored
Support Field in dataclass + 'metadata' kwarg of dataclasses.field (pydantic#2384)
* Support `Field` in `dataclass` + `'metadata'` kwarg of `dataclasses.field` Please enter the commit message for your changes. Lines starting * add `__has_field_info_default__` for minimal effect on perf * lower complexity of `_process_class`
1 parent f32832a commit 3ec3559

File tree

4 files changed

+70
-23
lines changed

4 files changed

+70
-23
lines changed

changes/2384-PrettyWood.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve field declaration for _pydantic_ `dataclass` by allowing the usage of _pydantic_ `Field` or `'metadata'` kwarg of `dataclasses.field`

docs/examples/dataclasses_default_schema.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import dataclasses
2-
from typing import List
2+
from typing import List, Optional
3+
4+
from pydantic import Field
35
from pydantic.dataclasses import dataclass
46

57

@@ -8,6 +10,11 @@ class User:
810
id: int
911
name: str = 'John Doe'
1012
friends: List[int] = dataclasses.field(default_factory=lambda: [0])
13+
age: Optional[int] = dataclasses.field(
14+
default=None,
15+
metadata=dict(title='The age of the user', description='do not lie!')
16+
)
17+
height: Optional[int] = Field(None, title='The height in cm', ge=50, le=300)
1118

1219

1320
user = User(id='42')

pydantic/dataclasses.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from .class_validators import gather_all_validators
44
from .error_wrappers import ValidationError
55
from .errors import DataclassTypeError
6-
from .fields import Required
6+
from .fields import Field, FieldInfo, Required, Undefined
77
from .main import create_model, validate_model
88
from .typing import resolve_annotations
99
from .utils import ClassAttribute
1010

1111
if TYPE_CHECKING:
1212
from .main import BaseConfig, BaseModel # noqa: F401
13-
from .typing import CallableGenerator
13+
from .typing import CallableGenerator, NoArgAnyCallable
1414

1515
DataclassT = TypeVar('DataclassT', bound='Dataclass')
1616

@@ -19,6 +19,7 @@ class Dataclass:
1919
__initialised__: bool
2020
__post_init_original__: Optional[Callable[..., None]]
2121
__processed__: Optional[ClassAttribute]
22+
__has_field_info_default__: bool # whether or not a `pydantic.Field` is used as default value
2223

2324
def __init__(self, *args: Any, **kwargs: Any) -> None:
2425
pass
@@ -80,6 +81,30 @@ def is_builtin_dataclass(_cls: Type[Any]) -> bool:
8081
return not hasattr(_cls, '__processed__') and dataclasses.is_dataclass(_cls)
8182

8283

84+
def _generate_pydantic_post_init(
85+
post_init_original: Optional[Callable[..., None]], post_init_post_parse: Optional[Callable[..., None]]
86+
) -> Callable[..., None]:
87+
def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
88+
if post_init_original is not None:
89+
post_init_original(self, *initvars)
90+
91+
if getattr(self, '__has_field_info_default__', False):
92+
# We need to remove `FieldInfo` values since they are not valid as input
93+
# It's ok to do that because they are obviously the default values!
94+
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
95+
else:
96+
input_data = self.__dict__
97+
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
98+
if validation_error:
99+
raise validation_error
100+
object.__setattr__(self, '__dict__', d)
101+
object.__setattr__(self, '__initialised__', True)
102+
if post_init_post_parse is not None:
103+
post_init_post_parse(self, *initvars)
104+
105+
return _pydantic_post_init
106+
107+
83108
def _process_class(
84109
_cls: Type[Any],
85110
init: bool,
@@ -100,16 +125,7 @@ def _process_class(
100125

101126
post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)
102127

103-
def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
104-
if post_init_original is not None:
105-
post_init_original(self, *initvars)
106-
d, _, validation_error = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)
107-
if validation_error:
108-
raise validation_error
109-
object.__setattr__(self, '__dict__', d)
110-
object.__setattr__(self, '__initialised__', True)
111-
if post_init_post_parse is not None:
112-
post_init_post_parse(self, *initvars)
128+
_pydantic_post_init = _generate_pydantic_post_init(post_init_original, post_init_post_parse)
113129

114130
# If the class is already a dataclass, __post_init__ will not be called automatically
115131
# so no validation will be added.
@@ -144,22 +160,31 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None:
144160
)
145161
cls.__processed__ = ClassAttribute('__processed__', True)
146162

147-
fields: Dict[str, Any] = {}
163+
field_definitions: Dict[str, Any] = {}
148164
for field in dataclasses.fields(cls):
165+
default: Any = Undefined
166+
default_factory: Optional['NoArgAnyCallable'] = None
167+
field_info: FieldInfo
149168

150-
if field.default != dataclasses.MISSING:
151-
field_value = field.default
169+
if field.default is not dataclasses.MISSING:
170+
default = field.default
152171
# mypy issue 7020 and 708
153-
elif field.default_factory != dataclasses.MISSING: # type: ignore
154-
field_value = field.default_factory() # type: ignore
172+
elif field.default_factory is not dataclasses.MISSING: # type: ignore
173+
default_factory = field.default_factory # type: ignore
174+
else:
175+
default = Required
176+
177+
if isinstance(default, FieldInfo):
178+
field_info = default
179+
cls.__has_field_info_default__ = True
155180
else:
156-
field_value = Required
181+
field_info = Field(default=default, default_factory=default_factory, **field.metadata)
157182

158-
fields[field.name] = (field.type, field_value)
183+
field_definitions[field.name] = (field.type, field_info)
159184

160185
validators = gather_all_validators(cls)
161186
cls.__pydantic_model__ = create_model(
162-
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **fields
187+
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
163188
)
164189

165190
cls.__initialised__ = False

tests/test_dataclasses.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ class User:
429429
assert fields['id'].default is None
430430

431431
assert fields['aliases'].required is False
432-
assert fields['aliases'].default == {'John': 'Joey'}
432+
assert fields['aliases'].default_factory() == {'John': 'Joey'}
433433

434434

435435
def test_default_factory_singleton_field():
@@ -456,6 +456,10 @@ class User:
456456
name: str = 'John Doe'
457457
aliases: Dict[str, str] = dataclasses.field(default_factory=lambda: {'John': 'Joey'})
458458
signup_ts: datetime = None
459+
age: Optional[int] = dataclasses.field(
460+
default=None, metadata=dict(title='The age of the user', description='do not lie!')
461+
)
462+
height: Optional[int] = pydantic.Field(None, title='The height in cm', ge=50, le=300)
459463

460464
user = User(id=123)
461465
assert user.__pydantic_model__.schema() == {
@@ -466,11 +470,21 @@ class User:
466470
'name': {'title': 'Name', 'default': 'John Doe', 'type': 'string'},
467471
'aliases': {
468472
'title': 'Aliases',
469-
'default': {'John': 'Joey'},
470473
'type': 'object',
471474
'additionalProperties': {'type': 'string'},
472475
},
473476
'signup_ts': {'title': 'Signup Ts', 'type': 'string', 'format': 'date-time'},
477+
'age': {
478+
'title': 'The age of the user',
479+
'description': 'do not lie!',
480+
'type': 'integer',
481+
},
482+
'height': {
483+
'title': 'The height in cm',
484+
'minimum': 50,
485+
'maximum': 300,
486+
'type': 'integer',
487+
},
474488
},
475489
'required': ['id'],
476490
}

0 commit comments

Comments
 (0)