Skip to content

Commit 3a77b31

Browse files
committed
handle Set and Map
1 parent a2d0941 commit 3a77b31

File tree

5 files changed

+95
-28
lines changed

5 files changed

+95
-28
lines changed

python/rpdk/python/resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def translate_type(resolved_type):
3636

3737

3838
def contains_model(resolved_type):
39-
if resolved_type.container == ContainerType.LIST:
39+
if resolved_type.container in [ContainerType.LIST, ContainerType.SET]:
4040
return contains_model(resolved_type.type)
4141
if resolved_type.container == ContainerType.MODEL:
4242
return True

src/cloudformation_cli_python_lib/interface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def _serialize(self) -> Mapping[str, Any]:
6868
ser[k] = v._serialize() # pylint: disable=protected-access
6969
elif v is None:
7070
del_keys.append(k)
71+
else:
72+
ser[k] = v
7173
for k in del_keys:
7274
del ser[k]
7375
return ser

src/cloudformation_cli_python_lib/recast.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import typing
2-
from typing import Any, Dict, List, Mapping
2+
from typing import Any, Dict, List, Mapping, Set
33

44
from .exceptions import InvalidRequest
55

@@ -11,12 +11,17 @@ def recast_object(
1111
) -> None:
1212
if not isinstance(json_data, dict):
1313
raise InvalidRequest(f"Can only parse dict items, not {type(json_data)}")
14+
# if type is Any, we leave it as is
15+
if cls == typing.Any:
16+
return
1417
for k, v in json_data.items():
1518
if isinstance(v, dict):
1619
child_cls = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
1720
recast_object(child_cls, v, classes)
1821
elif isinstance(v, list):
1922
json_data[k] = _recast_lists(cls, k, v, classes)
23+
elif isinstance(v, set):
24+
json_data[k] = _recast_sets(cls, k, v, classes)
2025
elif isinstance(v, str):
2126
dest_type = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
2227
json_data[k] = _recast_primitive(dest_type, k, v)
@@ -25,24 +30,46 @@ def recast_object(
2530

2631

2732
def _recast_lists(cls: Any, k: str, v: List[Any], classes: Dict[str, Any]) -> List[Any]:
33+
# Leave as is if type is Any
34+
if cls == typing.Any:
35+
return v
2836
casted_list: List[Any] = []
29-
if k in cls.__dataclass_fields__:
37+
if "__dataclass_fields__" not in dir(cls):
38+
pass
39+
elif k in cls.__dataclass_fields__:
3040
cls = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
3141
for item in v:
32-
if isinstance(item, str):
33-
casted_item: Any = _recast_primitive(cls, k, item)
34-
elif isinstance(item, list):
35-
casted_item = _recast_lists(cls, k, item, classes)
36-
elif isinstance(item, dict):
37-
recast_object(cls, item, classes)
38-
casted_item = item
39-
else:
40-
raise InvalidRequest(f"Unsupported type: {type(v)} for {k}")
41-
casted_list.append(casted_item)
42+
casted_list.append(cast_sequence_item(cls, k, item, classes))
4243
return casted_list
4344

4445

46+
def _recast_sets(cls: Any, k: str, v: Set[Any], classes: Dict[str, Any]) -> Set[Any]:
47+
casted_set: Set[Any] = set()
48+
if "__dataclass_fields__" in dir(cls):
49+
cls = _field_to_type(cls.__dataclass_fields__[k].type, k, classes)
50+
for item in v:
51+
casted_set.add(cast_sequence_item(cls, k, item, classes))
52+
return casted_set
53+
54+
55+
def cast_sequence_item(cls: Any, k: str, item: Any, classes: Dict[str, Any]) -> Any:
56+
if isinstance(item, str):
57+
return _recast_primitive(cls, k, item)
58+
if isinstance(item, list):
59+
return _recast_lists(cls, k, item, classes)
60+
if isinstance(item, set):
61+
return _recast_sets(cls, k, item, classes)
62+
if isinstance(item, dict):
63+
recast_object(cls, item, classes)
64+
return item
65+
raise InvalidRequest(f"Unsupported type: {type(item)} for {k}")
66+
67+
4568
def _recast_primitive(cls: Any, k: str, v: str) -> Any:
69+
if cls == typing.Any:
70+
# If the type is Any, we cannot guess what the original type was, so we leave
71+
# it as a string
72+
return v
4673
if cls == bool:
4774
if v.lower() == "true":
4875
return True
@@ -53,7 +80,7 @@ def _recast_primitive(cls: Any, k: str, v: str) -> Any:
5380

5481

5582
def _field_to_type(field: Any, key: str, classes: Dict[str, Any]) -> Any:
56-
if field in [int, float, str, bool]:
83+
if field in [int, float, str, bool, typing.Any]:
5784
return field
5885
# If it's a ForwardRef we need to find base type
5986
if isinstance(field, get_forward_ref_type()):
@@ -64,26 +91,31 @@ def _field_to_type(field: Any, key: str, classes: Dict[str, Any]) -> Any:
6491
try:
6592
possible_types = field.__args__
6693
except AttributeError:
67-
raise InvalidRequest(f"Cannot process type {field.__repr__()} for field {key}")
94+
raise InvalidRequest(f"Cannot process type {field} for field {key}")
6895
# Assuming that the union is generated from typing.Optional, so only
6996
# contains one type and None
7097
# pylint: disable=unidiomatic-typecheck
71-
fields = [t for t in possible_types if type(None) != t]
98+
fields = [t for t in possible_types if type(None) != t] if possible_types else []
7299
if len(fields) != 1:
73-
raise InvalidRequest(f"Cannot process type {field.__repr__()} for field {key}")
100+
raise InvalidRequest(f"Cannot process type {field} for field {key}")
74101
field = fields[0]
75102
# If it's a primitive we're done
76-
if field in [int, float, str, bool]:
103+
if field in [int, float, str, bool, typing.Any]:
77104
return field
78105
# If it's a ForwardRef we need to find base type
79106
if isinstance(field, get_forward_ref_type()):
80107
# Assuming codegen added an _ as a prefix, removing it and then getting the
81108
# class from model classes
82109
return classes[field.__forward_arg__[1:]]
83-
# If it's not a type we don't know how to handle we bail
84-
if not str(field).startswith("typing.Sequence"):
85-
raise InvalidRequest(f"Cannot process type {field} for field {key}")
86-
return _field_to_type(field.__args__[0], key, classes)
110+
# reduce Sequence/AbstractSet to inner type
111+
if str(field).startswith("typing.Sequence") or str(field).startswith(
112+
"typing.AbstractSet"
113+
):
114+
return _field_to_type(field.__args__[0], key, classes)
115+
if str(field).startswith("typing.MutableMapping"):
116+
return _field_to_type(field.__args__[1], key, classes)
117+
# If it's a type we don't know how to handle, we bail
118+
raise InvalidRequest(f"Cannot process type {field} for field {key}")
87119

88120

89121
# pylint: disable=protected-access,no-member

tests/lib/recast_test.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=protected-access
2-
from typing import Any, Optional, Union
2+
from typing import Awaitable, Generic, Optional, Union
33
from unittest.mock import patch
44

55
import pytest
@@ -15,8 +15,14 @@
1515
from .sample_model import ResourceModel as ComplexResourceModel, SimpleResourceModel
1616

1717

18-
def test_recast_object_simple():
18+
def test_recast_complex_object():
1919
payload = {
20+
"ListListAny": [[{"key": "val"}]],
21+
"ListListInt": [["1", "2", "3"]],
22+
"ListSetInt": [{"1", "2", "3"}],
23+
"ASet": {"1", "2", "3"},
24+
"AnotherSet": {"a", "b", "c"},
25+
"AFreeformDict": {"somekey": "somevalue", "someotherkey": "1"},
2026
"AnInt": "1",
2127
"ABool": "true",
2228
"AList": [
@@ -42,6 +48,12 @@ def test_recast_object_simple():
4248
],
4349
}
4450
expected = {
51+
"ListSetInt": [{1, 2, 3}],
52+
"ListListInt": [[1, 2, 3]],
53+
"ListListAny": [[{"key": "val"}]],
54+
"ASet": {"1", "2", "3"},
55+
"AnotherSet": {"a", "b", "c"},
56+
"AFreeformDict": {"somekey": "somevalue", "someotherkey": "1"},
4557
"AnInt": 1,
4658
"ABool": True,
4759
"AList": [
@@ -87,10 +99,10 @@ def test_recast_object_invalid_sub_type():
8799

88100
def test_recast_list_invalid_sub_type():
89101
k = "key"
90-
v = (1, 2)
102+
v = [(1, 2)]
91103
with pytest.raises(InvalidRequest) as excinfo:
92104
_recast_lists(SimpleResourceModel, k, v, {})
93-
assert str(excinfo.value) == f"Unsupported type: {type(v)} for {k}"
105+
assert str(excinfo.value) == f"Unsupported type: {type(v[0])} for {k}"
94106

95107

96108
def test_recast_boolean_invalid_value():
@@ -103,7 +115,7 @@ def test_recast_boolean_invalid_value():
103115

104116
def test_field_to_type_unhandled_types():
105117
k = "key"
106-
for field in [Union[str, list], Any, Optional[Any]]:
118+
for field in [Union[str, list], Generic, Optional[Awaitable]]:
107119
with pytest.raises(InvalidRequest) as excinfo:
108120
_field_to_type(field, k, {})
109121
assert str(excinfo.value).startswith("Cannot process type ")

tests/lib/sample_model.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
import sys
66
from dataclasses import dataclass
77
from inspect import getmembers, isclass
8-
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar
8+
from typing import (
9+
AbstractSet,
10+
Any,
11+
Mapping,
12+
MutableMapping,
13+
Optional,
14+
Sequence,
15+
Type,
16+
TypeVar,
17+
)
918

1019
from cloudformation_cli_python_lib.interface import (
1120
BaseResourceHandlerRequest,
@@ -26,6 +35,12 @@ class ResourceHandlerRequest(BaseResourceHandlerRequest):
2635

2736
@dataclass
2837
class ResourceModel(BaseResourceModel):
38+
ListListAny: Optional[Sequence[Sequence[Any]]]
39+
ListSetInt: Optional[Sequence[AbstractSet[int]]]
40+
ListListInt: Optional[Sequence[Sequence[int]]]
41+
ASet: Optional[AbstractSet[Any]]
42+
AnotherSet: Optional[AbstractSet[str]]
43+
AFreeformDict: Optional[MutableMapping[str, Any]]
2944
AnInt: Optional[int]
3045
ABool: Optional[bool]
3146
NestedList: Optional[Sequence[Sequence["_NestedList"]]]
@@ -47,6 +62,12 @@ def _deserialize(
4762
dataclasses = {n: o for n, o in getmembers(sys.modules[__name__]) if isclass(o)}
4863
recast_object(cls, json_data, dataclasses)
4964
return cls(
65+
ListSetInt=json_data.get("ListSetInt"),
66+
ListListInt=json_data.get("ListListInt"),
67+
ListListAny=json_data.get("ListListAny"),
68+
ASet=json_data.get("ASet"),
69+
AnotherSet=json_data.get("AnotherSet"),
70+
AFreeformDict=json_data.get("AFreeformDict"),
5071
AnInt=json_data.get("AnInt"),
5172
ABool=json_data.get("ABool"),
5273
NestedList=deserialize_list(json_data.get("NestedList"), NestedList),

0 commit comments

Comments
 (0)