Skip to content

Commit e1b88ee

Browse files
committed
handle schemas with nested arrays/objects. recast int/float/bool from strings
1 parent b9d8f8c commit e1b88ee

File tree

11 files changed

+497
-7
lines changed

11 files changed

+497
-7
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ disable=
1616

1717
[BASIC]
1818

19-
good-names=e,ex,f,fp,i,j,k,n,_
19+
good-names=e,ex,f,fp,i,j,k,v,n,_
2020

2121
[FORMAT]
2222

python/rpdk/python/codegen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from rpdk.core.jsonutils.resolver import ContainerType, resolve_models
1616
from rpdk.core.plugin_base import LanguagePlugin
1717

18-
from .resolver import translate_type
18+
from .resolver import contains_model, translate_type
1919

2020
LOG = logging.getLogger(__name__)
2121

@@ -44,6 +44,7 @@ def __init__(self):
4444
trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True
4545
)
4646
self.env.filters["translate_type"] = translate_type
47+
self.env.filters["contains_model"] = contains_model
4748
self.env.globals["ContainerType"] = ContainerType
4849
self.namespace = None
4950
self.package_name = None

python/rpdk/python/resolver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,11 @@ def translate_type(resolved_type):
3333
return f"AbstractSet[{item_type}]"
3434

3535
raise ValueError(f"Unknown container type {resolved_type.container}")
36+
37+
38+
def contains_model(resolved_type):
39+
if resolved_type.container == ContainerType.LIST:
40+
return contains_model(resolved_type.type)
41+
if resolved_type.container == ContainerType.MODEL:
42+
return True
43+
return False

python/rpdk/python/templates/models.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# DO NOT modify this file by hand, changes will be overwritten
2+
import sys
23
from dataclasses import dataclass
4+
from inspect import getmembers, isclass
35
from typing import (
46
AbstractSet,
57
Any,
@@ -16,6 +18,8 @@
1618
BaseResourceHandlerRequest,
1719
BaseResourceModel,
1820
)
21+
from cloudformation_cli_python_lib.recast import recast_object
22+
from cloudformation_cli_python_lib.utils import deserialize_list
1923

2024
T = TypeVar("T")
2125

@@ -35,7 +39,7 @@ class ResourceHandlerRequest(BaseResourceHandlerRequest):
3539

3640
{% for model, properties in models.items() %}
3741
@dataclass
38-
class {{ model }}{% if model == "ResourceModel" %}(BaseResourceModel){% endif %}:
42+
class {{ model }}(BaseResourceModel):
3943
{% for name, type in properties.items() %}
4044
{{ name }}: Optional[{{ type|translate_type }}]
4145
{% endfor %}
@@ -47,12 +51,22 @@ def _deserialize(
4751
) -> Optional["_{{ model }}"]:
4852
if not json_data:
4953
return None
54+
{% if model == "ResourceModel" %}
55+
dataclasses = {n: o for n, o in getmembers(sys.modules[__name__]) if isclass(o)}
56+
recast_object(cls, json_data, dataclasses)
57+
{% endif %}
5058
return cls(
5159
{% for name, type in properties.items() %}
5260
{% if type.container == ContainerType.MODEL %}
5361
{{ name }}={{ type.type }}._deserialize(json_data.get("{{ name }}")),
5462
{% elif type.container == ContainerType.SET %}
5563
{{ name }}=set_or_none(json_data.get("{{ name }}")),
64+
{% elif type.container == ContainerType.LIST %}
65+
{% if type | contains_model %}
66+
{{name}}=deserialize_list(json_data.get("{{ name }}"), {{name}}),
67+
{% else %}
68+
{{ name }}=json_data.get("{{ name }}"),
69+
{% endif %}
5670
{% else %}
5771
{{ name }}=json_data.get("{{ name }}"),
5872
{% endif %}

src/cloudformation_cli_python_lib/interface.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,29 @@ class HandlerErrorCode(str, _AutoName):
5959

6060
class BaseResourceModel:
6161
def _serialize(self) -> Mapping[str, Any]:
62-
return self.__dict__
62+
ser = self.__dict__
63+
del_keys = []
64+
for k, v in ser.items():
65+
if isinstance(v, list):
66+
ser[k] = self._serialize_list(v)
67+
elif isinstance(v, BaseResourceModel):
68+
ser[k] = v._serialize() # pylint: disable=protected-access
69+
elif v is None:
70+
del_keys.append(k)
71+
for k in del_keys:
72+
del ser[k]
73+
return ser
74+
75+
def _serialize_list(self, v: List[Any]) -> List[Any]:
76+
ser: List[Any] = []
77+
for i in v:
78+
if isinstance(i, list):
79+
ser.append(self._serialize_list(i))
80+
elif isinstance(i, BaseResourceModel):
81+
ser.append(i._serialize()) # pylint: disable=protected-access
82+
else:
83+
ser.append(i)
84+
return ser
6385

6486
@classmethod
6587
def _deserialize(
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# ignoring mypy on the import as it catches ForwardRef as invalid, and we are using it
2+
# for introspection https://docs.python.org/3/library/typing.html#typing.ForwardRef
3+
from typing import Any, Dict, ForwardRef, List, Mapping # type: ignore
4+
5+
from .exceptions import InvalidRequest
6+
7+
8+
# CloudFormation recasts all primitive types as strings, this tries to set them back to
9+
# the types in the type hints
10+
def recast_object(
11+
cls: Any, json_data: Mapping[str, Any], classes: Dict[str, Any]
12+
) -> None:
13+
if not isinstance(json_data, dict):
14+
raise InvalidRequest(f"Can only parse dict items, not {type(json_data)}")
15+
for k, v in json_data.items():
16+
if isinstance(v, dict):
17+
child_cls = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
18+
recast_object(child_cls, v, classes)
19+
elif isinstance(v, list):
20+
json_data[k] = _recast_lists(cls, k, v, classes)
21+
elif isinstance(v, str):
22+
dest_type = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
23+
json_data[k] = _recast_primitive(dest_type, k, v)
24+
else:
25+
raise InvalidRequest(f"Unsupported type: {type(v)} for {k}")
26+
27+
28+
def _recast_lists(cls: Any, k: str, v: List[Any], classes: Dict[str, Any]) -> List[Any]:
29+
casted_list: List[Any] = []
30+
if k in cls.__dataclass_fields__:
31+
cls = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
32+
for item in v:
33+
if isinstance(item, str):
34+
casted_item: Any = _recast_primitive(cls, k, item)
35+
elif isinstance(item, list):
36+
casted_item = _recast_lists(cls, k, item, classes)
37+
elif isinstance(item, dict):
38+
recast_object(cls, item, classes)
39+
casted_item = item
40+
else:
41+
raise InvalidRequest(f"Unsupported type: {type(v)} for {k}")
42+
casted_list.append(casted_item)
43+
return casted_list
44+
45+
46+
def _recast_primitive(cls: Any, k: str, v: str) -> Any:
47+
if cls == bool:
48+
if v.lower() == "true":
49+
return True
50+
if v.lower() == "false":
51+
return False
52+
raise InvalidRequest(f'value for {k} "{v}" is not boolean')
53+
return cls(v)
54+
55+
56+
def _field_to_type(field: Any, key: str, classes: Dict[str, Any]) -> Any:
57+
if field in [int, float, str, bool]:
58+
return field
59+
# If it's a ForwardRef we need to find base type
60+
if isinstance(field, ForwardRef):
61+
# Assuming codegen added an _ as a prefix, removing it and then gettting the
62+
# class from model classes
63+
return classes[field.__forward_arg__[1:]]
64+
# Assuming this is a generic object created by typing.Union
65+
try:
66+
possible_types = field.__args__
67+
except AttributeError:
68+
raise InvalidRequest(f"Cannot process type {field.__repr__()} for field {key}")
69+
# Assuming that the union is generated from typing.Optional, so only
70+
# contains one type and None
71+
# pylint: disable=unidiomatic-typecheck
72+
fields = [t for t in possible_types if type(None) != t]
73+
if len(fields) != 1:
74+
raise InvalidRequest(f"Cannot process type {field.__repr__()} for field {key}")
75+
field = fields[0]
76+
# If it's a primitive we're done
77+
if field in [int, float, str, bool]:
78+
return field
79+
# If it's a ForwardRef we need to find base type
80+
if isinstance(field, ForwardRef):
81+
# Assuming codegen added an _ as a prefix, removing it and then gettting the
82+
# class from model classes
83+
return classes[field.__forward_arg__[1:]]
84+
# If it's not a type we don't know how to handle we bail
85+
if field._name not in ["Sequence"]: # pylint: disable=protected-access
86+
raise InvalidRequest(f"Cannot process type {field.__repr__()} for field {key}")
87+
return _field_to_type(field.__args__[0], key, classes)

src/cloudformation_cli_python_lib/utils.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,19 @@
22
import json
33
from dataclasses import dataclass, field
44
from datetime import date, datetime, time
5-
from typing import Any, Callable, Mapping, MutableMapping, Optional, Type
6-
5+
from typing import (
6+
Any,
7+
Callable,
8+
Dict,
9+
List,
10+
Mapping,
11+
MutableMapping,
12+
Optional,
13+
Type,
14+
Union,
15+
)
16+
17+
from .exceptions import InvalidRequest
718
from .interface import Action, BaseResourceHandlerRequest, BaseResourceModel
819

920

@@ -121,3 +132,20 @@ def to_modelled(
121132
class LambdaContext:
122133
get_remaining_time_in_millis: Callable[["LambdaContext"], int]
123134
invoked_function_arn: str
135+
136+
137+
def deserialize_list(
138+
json_data: Union[List[Any], Dict[str, Any]], inner_dataclass: Any
139+
) -> Optional[List[Any]]:
140+
if not json_data:
141+
return None
142+
output = []
143+
for item in json_data:
144+
if isinstance(item, list):
145+
output.append(deserialize_list(item, inner_dataclass))
146+
elif isinstance(item, dict):
147+
# pylint: disable=protected-access
148+
output.append(inner_dataclass._deserialize(item))
149+
else:
150+
raise InvalidRequest(f"cannot deserialize lists of {type(item)}")
151+
return output

tests/lib/recast_test.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# pylint: disable=protected-access
2+
from typing import Any, Optional, Union
3+
4+
import pytest
5+
from cloudformation_cli_python_lib.exceptions import InvalidRequest
6+
from cloudformation_cli_python_lib.recast import (
7+
_field_to_type,
8+
_recast_lists,
9+
_recast_primitive,
10+
recast_object,
11+
)
12+
13+
from .sample_model import ResourceModel as ComplexResourceModel, SimpleResourceModel
14+
15+
16+
def test_recast_object_simple():
17+
payload = {
18+
"AnInt": "1",
19+
"ABool": "true",
20+
"AList": [
21+
{
22+
"DeeperBool": "false",
23+
"DeeperList": ["1", "2", "3"],
24+
"DeeperDictInList": {"DeepestBool": "true", "DeepestList": ["3", "4"]},
25+
},
26+
{"DeeperDictInList": {"DeepestBool": "false", "DeepestList": ["6", "7"]}},
27+
],
28+
"ADict": {
29+
"DeepBool": "true",
30+
"DeepList": ["10", "11"],
31+
"DeepDict": {
32+
"DeeperBool": "false",
33+
"DeeperList": ["1", "2", "3"],
34+
"DeeperDict": {"DeepestBool": "true", "DeepestList": ["13", "17"]},
35+
},
36+
},
37+
"NestedList": [
38+
[{"NestedListInt": "true", "NestedListList": ["1", "2", "3"]}],
39+
[{"NestedListInt": "false", "NestedListList": ["11", "12", "13"]}],
40+
],
41+
}
42+
expected = {
43+
"AnInt": 1,
44+
"ABool": True,
45+
"AList": [
46+
{
47+
"DeeperBool": False,
48+
"DeeperList": [1, 2, 3],
49+
"DeeperDictInList": {"DeepestBool": True, "DeepestList": [3, 4]},
50+
},
51+
{"DeeperDictInList": {"DeepestBool": False, "DeepestList": [6, 7]}},
52+
],
53+
"ADict": {
54+
"DeepBool": True,
55+
"DeepList": [10, 11],
56+
"DeepDict": {
57+
"DeeperBool": False,
58+
"DeeperList": [1, 2, 3],
59+
"DeeperDict": {"DeepestBool": True, "DeepestList": [13, 17]},
60+
},
61+
},
62+
"NestedList": [
63+
[{"NestedListInt": True, "NestedListList": [1.0, 2.0, 3.0]}],
64+
[{"NestedListInt": False, "NestedListList": [11.0, 12.0, 13.0]}],
65+
],
66+
}
67+
model = ComplexResourceModel._deserialize(payload)
68+
assert expected == payload
69+
assert expected == model._serialize()
70+
71+
72+
def test_recast_object_invalid_json_type():
73+
with pytest.raises(InvalidRequest) as excinfo:
74+
recast_object(SimpleResourceModel, [], {})
75+
assert str(excinfo.value) == f"Can only parse dict items, not {type([])}"
76+
77+
78+
def test_recast_object_invalid_sub_type():
79+
k = "key"
80+
v = (1, 2)
81+
with pytest.raises(InvalidRequest) as excinfo:
82+
recast_object(SimpleResourceModel, {k: v}, {})
83+
assert str(excinfo.value) == f"Unsupported type: {type(v)} for {k}"
84+
85+
86+
def test_recast_list_invalid_sub_type():
87+
k = "key"
88+
v = (1, 2)
89+
with pytest.raises(InvalidRequest) as excinfo:
90+
_recast_lists(SimpleResourceModel, k, v, {})
91+
assert str(excinfo.value) == f"Unsupported type: {type(v)} for {k}"
92+
93+
94+
def test_recast_boolean_invalid_value():
95+
k = "key"
96+
v = "not-a-bool"
97+
with pytest.raises(InvalidRequest) as excinfo:
98+
_recast_primitive(bool, k, v)
99+
assert str(excinfo.value) == f'value for {k} "{v}" is not boolean'
100+
101+
102+
def test_field_to_type_unhandled_types():
103+
k = "key"
104+
for field in [Union[str, list], Any, Optional[Any]]:
105+
with pytest.raises(InvalidRequest) as excinfo:
106+
_field_to_type(field, k, {})
107+
assert str(excinfo.value).startswith("Cannot process type ")

0 commit comments

Comments
 (0)