Skip to content

Commit 75f3bdd

Browse files
authored
Support protocol version 2.0.0 to callback handlers orchestrated by CloudFormation service (aws-cloudformation#101)
1 parent 9a9229c commit 75f3bdd

File tree

12 files changed

+75
-577
lines changed

12 files changed

+75
-577
lines changed

python/rpdk/python/codegen.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,30 @@ def __init__(self):
5050
self.package_name = None
5151
self.package_root = None
5252
self._use_docker = True
53+
self._protocol_version = "2.0.0"
5354

5455
def _init_from_project(self, project):
5556
self.namespace = tuple(s.lower() for s in project.type_info)
5657
self.package_name = "_".join(self.namespace)
5758
self._use_docker = project.settings.get("use_docker", True)
5859
self.package_root = project.root / "src"
5960

60-
def _prompt_for_use_docker(self, project):
61+
def _init_settings(self, project):
62+
LOG.debug("Writing settings")
6163
self._use_docker = input_with_validation(
6264
"Use docker for platform-independent packaging (Y/n)?\n",
6365
validate_no,
6466
"This is highly recommended unless you are experienced \n"
6567
"with cross-platform Python packaging.",
6668
)
6769
project.settings["use_docker"] = self._use_docker
70+
project.settings["protocolVersion"] = self._protocol_version
6871

6972
def init(self, project):
7073
LOG.debug("Init started")
7174

7275
self._init_from_project(project)
73-
self._prompt_for_use_docker(project)
76+
self._init_settings(project)
7477

7578
project.runtime = self.RUNTIME
7679
project.entrypoint = self.ENTRY_POINT.format(self.package_name)

src/cloudformation_cli_python_lib/callback.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

src/cloudformation_cli_python_lib/interface.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -95,30 +95,26 @@ class ProgressEvent:
9595
resourceModels: Optional[List[BaseModel]] = None
9696
nextToken: Optional[str] = None
9797

98-
def _serialize(
99-
self, to_response: bool = False, bearer_token: Optional[str] = None
100-
) -> MutableMapping[str, Any]:
98+
def _serialize(self) -> MutableMapping[str, Any]:
10199
# to match Java serialization, which drops `null` values, and the
102100
# contract tests currently expect this also
103101
ser = {k: v for k, v in self.__dict__.items() if v is not None}
102+
104103
# mutate to what's expected in the response
105-
if to_response:
106-
ser["bearerToken"] = bearer_token
107-
ser["operationStatus"] = ser.pop("status").name
108-
if self.resourceModel:
104+
105+
ser["status"] = ser.pop("status").name
106+
107+
if self.resourceModel:
108+
# pylint: disable=protected-access
109+
ser["resourceModel"] = self.resourceModel._serialize()
110+
if self.resourceModels:
111+
ser["resourceModels"] = [
109112
# pylint: disable=protected-access
110-
ser["resourceModel"] = self.resourceModel._serialize()
111-
if self.resourceModels:
112-
ser["resourceModels"] = [
113-
# pylint: disable=protected-access
114-
model._serialize()
115-
for model in self.resourceModels
116-
]
117-
del ser["callbackDelaySeconds"]
118-
if "callbackContext" in ser:
119-
del ser["callbackContext"]
120-
if self.errorCode:
121-
ser["errorCode"] = self.errorCode.name
113+
model._serialize()
114+
for model in self.resourceModels
115+
]
116+
if self.errorCode:
117+
ser["errorCode"] = self.errorCode.name
122118
return ser
123119

124120
@classmethod

src/cloudformation_cli_python_lib/resource.py

Lines changed: 19 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
import traceback
44
from datetime import datetime
55
from functools import wraps
6-
from time import sleep
76
from typing import Any, Callable, MutableMapping, Optional, Tuple, Type, Union
87

98
from .boto3_proxy import SessionProxy, _get_boto_session
10-
from .callback import report_progress
119
from .exceptions import InternalFailure, InvalidRequest, _HandlerError
1210
from .interface import (
1311
Action,
@@ -18,7 +16,6 @@
1816
)
1917
from .log_delivery import ProviderLogHandler
2018
from .metrics import MetricsPublisherProxy
21-
from .scheduler import cleanup_cloudwatch_events, reschedule_after_minutes
2219
from .utils import (
2320
BaseModel,
2421
Credentials,
@@ -32,7 +29,6 @@
3229
LOG = logging.getLogger(__name__)
3330

3431
MUTATING_ACTIONS = (Action.CREATE, Action.UPDATE, Action.DELETE)
35-
INVOCATION_TIMEOUT_MS = 60000
3632

3733
HandlerSignature = Callable[
3834
[Optional[SessionProxy], Any, MutableMapping[str, Any]], ProgressEvent
@@ -72,38 +68,6 @@ def _add_handler(f: HandlerSignature) -> HandlerSignature:
7268

7369
return _add_handler
7470

75-
@staticmethod
76-
def schedule_reinvocation(
77-
handler_request: HandlerRequest,
78-
handler_response: ProgressEvent,
79-
context: LambdaContext,
80-
session: SessionProxy,
81-
) -> bool:
82-
if handler_response.status != OperationStatus.IN_PROGRESS:
83-
return False
84-
# modify requestContext dict in-place, so that invoke count is bumped on local
85-
# reinvoke too
86-
reinvoke_context = handler_request.requestContext
87-
reinvoke_context["invocation"] = reinvoke_context.get("invocation", 0) + 1
88-
callback_delay_s = handler_response.callbackDelaySeconds
89-
remaining_ms = context.get_remaining_time_in_millis()
90-
91-
# when a handler requests a sub-minute callback delay, and if the lambda
92-
# invocation has enough runtime (with 20% buffer), we can re-run the handler
93-
# locally otherwise we re-invoke through CloudWatchEvents
94-
needed_ms_remaining = callback_delay_s * 1200 + INVOCATION_TIMEOUT_MS
95-
if callback_delay_s < 60 and remaining_ms > needed_ms_remaining:
96-
sleep(callback_delay_s)
97-
return True
98-
callback_delay_min = int(callback_delay_s / 60)
99-
reschedule_after_minutes(
100-
session,
101-
function_arn=context.invoked_function_arn,
102-
minutes_from_now=callback_delay_min,
103-
handler_request=handler_request,
104-
)
105-
return False
106-
10771
def _invoke_handler(
10872
self,
10973
session: Optional[SessionProxy],
@@ -169,26 +133,23 @@ def test_entrypoint(
169133
def _parse_request(
170134
event_data: MutableMapping[str, Any]
171135
) -> Tuple[
172-
Tuple[Optional[SessionProxy], Optional[SessionProxy], SessionProxy],
136+
Tuple[Optional[SessionProxy], Optional[SessionProxy]],
173137
Action,
174138
MutableMapping[str, Any],
175139
HandlerRequest,
176140
]:
177141
try:
178142
event = HandlerRequest.deserialize(event_data)
179-
platform_sess = _get_boto_session(event.requestData.platformCredentials)
180143
caller_sess = _get_boto_session(event.requestData.callerCredentials)
181144
provider_sess = _get_boto_session(event.requestData.providerCredentials)
182145
# credentials are used when rescheduling, so can't zero them out (for now)
183-
if platform_sess is None:
184-
raise ValueError("No platform credentials")
185146
action = Action[event.action]
186-
callback_context = event.requestContext.get("callbackContext", {})
147+
callback_context = event.callbackContext or {}
187148
except Exception as e: # pylint: disable=broad-except
188149
LOG.exception("Invalid request")
189150
raise InvalidRequest(f"{e} ({type(e).__name__})") from e
190151
return (
191-
(caller_sess, provider_sess, platform_sess),
152+
(caller_sess, provider_sess),
192153
action,
193154
callback_context,
194155
event,
@@ -224,66 +185,28 @@ def print_or_log(message: str) -> None:
224185

225186
try:
226187
sessions, action, callback, event = self._parse_request(event_data)
227-
caller_sess, provider_sess, platform_sess = sessions
188+
caller_sess, provider_sess = sessions
228189
ProviderLogHandler.setup(event, provider_sess)
229190
logs_setup = True
230191

231192
request = self._cast_resource_request(event)
232193

233194
metrics = MetricsPublisherProxy(event.awsAccountId, event.resourceType)
234-
metrics.add_metrics_publisher(platform_sess)
235195
metrics.add_metrics_publisher(provider_sess)
236-
# Acknowledge the task for first time invocation
237-
if not event.requestContext:
238-
report_progress(
239-
platform_sess,
240-
event.bearerToken,
241-
None,
242-
OperationStatus.IN_PROGRESS,
243-
OperationStatus.PENDING,
244-
None,
245-
"",
246-
)
247-
else:
248-
# If this invocation was triggered by a 're-invoke' CloudWatch Event,
249-
# clean it up
250-
cleanup_cloudwatch_events(
251-
platform_sess,
252-
event.requestContext.get("cloudWatchEventsRuleName", ""),
253-
event.requestContext.get("cloudWatchEventsTargetId", ""),
254-
)
255-
invoke = True
256-
while invoke:
257-
metrics.publish_invocation_metric(datetime.utcnow(), action)
258-
start_time = datetime.utcnow()
259-
error = None
260-
try:
261-
progress = self._invoke_handler(
262-
caller_sess, request, action, callback
263-
)
264-
except Exception as e: # pylint: disable=broad-except
265-
error = e
266-
m_secs = (datetime.utcnow() - start_time).total_seconds() * 1000.0
267-
metrics.publish_duration_metric(datetime.utcnow(), action, m_secs)
268-
if error:
269-
metrics.publish_exception_metric(datetime.utcnow(), action, error)
270-
raise error
271-
if progress.callbackContext:
272-
callback = progress.callbackContext
273-
event.requestContext["callbackContext"] = callback
274-
if event.action in MUTATING_ACTIONS:
275-
report_progress(
276-
platform_sess,
277-
event.bearerToken,
278-
progress.errorCode,
279-
progress.status,
280-
OperationStatus.IN_PROGRESS,
281-
progress.resourceModel,
282-
progress.message,
283-
)
284-
invoke = self.schedule_reinvocation(
285-
event, progress, context, platform_sess
286-
)
196+
197+
metrics.publish_invocation_metric(datetime.utcnow(), action)
198+
start_time = datetime.utcnow()
199+
error = None
200+
201+
try:
202+
progress = self._invoke_handler(caller_sess, request, action, callback)
203+
except Exception as e: # pylint: disable=broad-except
204+
error = e
205+
m_secs = (datetime.utcnow() - start_time).total_seconds() * 1000.0
206+
metrics.publish_duration_metric(datetime.utcnow(), action, m_secs)
207+
if error:
208+
metrics.publish_exception_metric(datetime.utcnow(), action, error)
209+
raise error
287210
except _HandlerError as e:
288211
print_or_log("Handler error")
289212
progress = e.to_progress_event()
@@ -296,6 +219,4 @@ def print_or_log(message: str) -> None:
296219

297220
# use the raw event_data as a last-ditch attempt to call back if the
298221
# request is invalid
299-
return progress._serialize( # pylint: disable=protected-access
300-
to_response=True, bearer_token=event_data.get("bearerToken")
301-
)
222+
return progress._serialize() # pylint: disable=protected-access

src/cloudformation_cli_python_lib/scheduler.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

src/cloudformation_cli_python_lib/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class RequestData:
5454
stackTags: Optional[Mapping[str, Any]] = None
5555
# platform credentials aren't really optional, but this is used to
5656
# zero them out to prevent e.g. accidental logging
57-
platformCredentials: Optional[Credentials] = None
5857
callerCredentials: Optional[Credentials] = None
5958
providerCredentials: Optional[Credentials] = None
6059
previousResourceProperties: Optional[Mapping[str, Any]] = None
@@ -91,6 +90,7 @@ class HandlerRequest:
9190
resourceTypeVersion: str
9291
requestData: RequestData
9392
stackId: str
93+
callbackContext: Optional[MutableMapping[str, Any]] = None
9494
nextToken: Optional[str] = None
9595
requestContext: MutableMapping[str, Any] = field(default_factory=dict)
9696

0 commit comments

Comments
 (0)