Skip to content

Commit ca962fa

Browse files
arpanshah29ask
authored andcommitted
Add new function to handle etas and limits together (celery#4251)
* Add new function to handle etas and limits together * Adding unit test * Fixing indentation
1 parent 9ca61fa commit ca962fa

File tree

3 files changed

+42
-13
lines changed

3 files changed

+42
-13
lines changed

celery/worker/consumer/consumer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,12 @@ def _limit_task(self, request, bucket, tokens):
305305
return bucket.add(request)
306306
return self._schedule_bucket_request(request, bucket, tokens)
307307

308+
def _limit_post_eta(self, request, bucket, tokens):
309+
self.qos.decrement_eventually()
310+
if bucket.contents:
311+
return bucket.add(request)
312+
return self._schedule_bucket_request(request, bucket, tokens)
313+
308314
def start(self):
309315
blueprint = self.blueprint
310316
while blueprint.state not in STOP_CONDITIONS:

celery/worker/strategy.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def default(task, app, consumer,
8484
get_bucket = consumer.task_buckets.__getitem__
8585
handle = consumer.on_task_request
8686
limit_task = consumer._limit_task
87+
limit_post_eta = consumer._limit_post_eta
8788
body_can_be_buffer = consumer.pool.body_can_be_buffer
8889
Request = symbol_by_name(task.Request)
8990
Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
@@ -123,6 +124,8 @@ def task_message_handler(message, body, ack, reject, callbacks,
123124
expires=req.expires and req.expires.isoformat(),
124125
)
125126

127+
bucket = None
128+
eta = None
126129
if req.eta:
127130
try:
128131
if req.utc:
@@ -133,17 +136,22 @@ def task_message_handler(message, body, ack, reject, callbacks,
133136
error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
134137
req.eta, exc, req.info(safe=True), exc_info=True)
135138
req.reject(requeue=False)
136-
else:
137-
consumer.qos.increment_eventually()
138-
call_at(eta, apply_eta_task, (req,), priority=6)
139-
else:
140-
if rate_limits_enabled:
141-
bucket = get_bucket(task.name)
142-
if bucket:
143-
return limit_task(req, bucket, 1)
144-
task_reserved(req)
145-
if callbacks:
146-
[callback(req) for callback in callbacks]
147-
handle(req)
148-
139+
if rate_limits_enabled:
140+
bucket = get_bucket(task.name)
141+
142+
if eta and bucket:
143+
consumer.qos.increment_eventually()
144+
return call_at(eta, limit_post_eta, (req, bucket, 1),
145+
priority=6)
146+
if eta:
147+
consumer.qos.increment_eventually()
148+
call_at(eta, apply_eta_task, (req,), priority=6)
149+
return task_message_handler
150+
if bucket:
151+
return limit_task(req, bucket, 1)
152+
153+
task_reserved(req)
154+
if callbacks:
155+
[callback(req) for callback in callbacks]
156+
handle(req)
149157
return task_message_handler

t/unit/worker/test_strategy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ def was_rate_limited(self):
9898
assert not self.was_reserved()
9999
return self.consumer._limit_task.called
100100

101+
def was_limited_with_eta(self):
102+
assert not self.was_reserved()
103+
called = self.consumer.timer.call_at.called
104+
if called:
105+
assert self.consumer.timer.call_at.call_args[0][1] == \
106+
self.consumer._limit_post_eta
107+
return called
108+
101109
def was_scheduled(self):
102110
assert not self.was_reserved()
103111
assert not self.was_rate_limited()
@@ -186,6 +194,13 @@ def test_when_rate_limited(self):
186194
C()
187195
assert C.was_rate_limited()
188196

197+
def test_when_rate_limited_with_eta(self):
198+
task = self.add.s(2, 2).set(countdown=10)
199+
with self._context(task, rate_limits=True, limit='1/m') as C:
200+
C()
201+
assert C.was_limited_with_eta()
202+
C.consumer.qos.increment_eventually.assert_called_with()
203+
189204
def test_when_rate_limited__limits_disabled(self):
190205
task = self.add.s(2, 2)
191206
with self._context(task, rate_limits=False, limit='1/m') as C:

0 commit comments

Comments
 (0)