Commit 2f282eb

Fix KafkaConsumer.poll() with zero timeout (#2613)
1 parent 4100319 commit 2f282eb

File tree: 7 files changed, +179 -127 lines
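This commit replaces the closure-based timeout bookkeeping (timeout_ms_fn plus ad-hoc time.time() deadline arithmetic) with a single kafka.util.Timer object that is created once per call and threaded through nested calls, so a zero timeout_ms budget is honored consistently. Alongside this, the internal coordinator entry points now report an expired budget by returning False instead of raising KafkaTimeoutError. Diffs for four of the seven changed files appear below; the kafka/util.py change that defines Timer is not part of this excerpt. What follows is only a minimal sketch of the Timer interface as inferred from the call sites in the diffs -- every implementation detail here is an assumption, not the library's actual code:

    import time

    from kafka.errors import KafkaTimeoutError


    class Timer(object):
        """Sketch of the Timer interface inferred from this diff's call sites.

        Timer(None) means "no deadline": timeout_ms stays None and expired
        stays False, so the value can be passed straight to poll(timeout_ms=...).
        """
        def __init__(self, timeout_ms, error_message=None):
            self._error_message = error_message
            if timeout_ms is None:
                self._deadline = None
            else:
                self._deadline = time.time() + timeout_ms / 1000.0

        @property
        def expired(self):
            # A timer with no deadline never expires
            return self._deadline is not None and time.time() >= self._deadline

        @property
        def timeout_ms(self):
            # Remaining budget in milliseconds (never negative), or None
            if self._deadline is None:
                return None
            return max(0, 1000 * (self._deadline - time.time()))

        def maybe_raise(self):
            # fetcher.py below calls this to turn expiry into an exception
            if self.expired:
                raise KafkaTimeoutError(self._error_message or 'Timer expired')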

kafka/client_async.py (+7 -11)

@@ -27,7 +27,7 @@
 from kafka.metrics.stats.rate import TimeUnit
 from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
 from kafka.protocol.metadata import MetadataRequest
-from kafka.util import Dict, WeakMethod, ensure_valid_topic_name, timeout_ms_fn
+from kafka.util import Dict, Timer, WeakMethod, ensure_valid_topic_name
 # Although this looks unused, it actually monkey-patches socket.socketpair()
 # and should be left in as long as we're using socket.socketpair() in this file
 from kafka.vendor import socketpair  # noqa: F401

@@ -645,12 +645,8 @@ def poll(self, timeout_ms=None, future=None):
         """
         if not isinstance(timeout_ms, (int, float, type(None))):
             raise TypeError('Invalid type for timeout: %s' % type(timeout_ms))
+        timer = Timer(timeout_ms)

-        begin = time.time()
-        if timeout_ms is not None:
-            timeout_at = begin + (timeout_ms / 1000)
-        else:
-            timeout_at = begin + (self.config['request_timeout_ms'] / 1000)
         # Loop for futures, break after first loop if None
         responses = []
         while True:

@@ -675,7 +671,7 @@ def poll(self, timeout_ms=None, future=None):
                 if future is not None and future.is_done:
                     timeout = 0
                 else:
-                    user_timeout_ms = 1000 * max(0, timeout_at - time.time())
+                    user_timeout_ms = timer.timeout_ms if timeout_ms is not None else self.config['request_timeout_ms']
                     idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms()
                     request_timeout_ms = self._next_ifr_request_timeout_ms()
                     log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", user_timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms)

@@ -698,7 +694,7 @@ def poll(self, timeout_ms=None, future=None):
                 break
             elif future.is_done:
                 break
-            elif timeout_ms is not None and time.time() >= timeout_at:
+            elif timeout_ms is not None and timer.expired:
                 break

         return responses

@@ -1175,16 +1171,16 @@ def await_ready(self, node_id, timeout_ms=30000):
         This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
         care.
         """
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
+        timer = Timer(timeout_ms)
         self.poll(timeout_ms=0)
         if self.is_ready(node_id):
             return True

-        while not self.is_ready(node_id) and inner_timeout_ms() > 0:
+        while not self.is_ready(node_id) and not timer.expired:
             if self.connection_failed(node_id):
                 raise Errors.KafkaConnectionError("Connection to %s failed." % (node_id,))
             self.maybe_connect(node_id)
-            self.poll(timeout_ms=inner_timeout_ms())
+            self.poll(timeout_ms=timer.timeout_ms)
         return self.is_ready(node_id)

     def send_and_receive(self, node_id, request):
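The essential behavioural detail in the reworked poll() loop is that the expiry check happens after a full pass of network I/O, so poll(timeout_ms=0) still performs one real round of work before returning. A compact sketch of that work-then-check shape, assuming the hypothetical Timer sketch above; do_one_pass is a stand-in for the real loop body, which also has extra break conditions for completed futures:

    def poll_once_then_check(timer, do_one_pass):
        # Work first, check the deadline second: an already-expired timer
        # (timeout_ms=0) still gets exactly one pass before the loop stops.
        responses = []
        while True:
            responses.extend(do_one_pass(timer.timeout_ms))
            if timer.timeout_ms is None or timer.expired:
                break  # the real loop also exits when a watched future completes
        return responses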

kafka/consumer/fetcher.py (+10 -5)

@@ -19,7 +19,7 @@
 from kafka.record import MemoryRecords
 from kafka.serializer import Deserializer
 from kafka.structs import TopicPartition, OffsetAndMetadata, OffsetAndTimestamp
-from kafka.util import timeout_ms_fn
+from kafka.util import Timer

 log = logging.getLogger(__name__)

@@ -230,15 +230,15 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):
         if not timestamps:
             return {}

-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout fetching offsets')
+        timer = Timer(timeout_ms, "Failed to get offsets by timestamps in %s ms" % (timeout_ms,))
         timestamps = copy.copy(timestamps)
         fetched_offsets = dict()
         while True:
             if not timestamps:
                 return {}

             future = self._send_list_offsets_requests(timestamps)
-            self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+            self._client.poll(future=future, timeout_ms=timer.timeout_ms)

             # Timeout w/o future completion
             if not future.is_done:

@@ -256,12 +256,17 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):

             if future.exception.invalid_metadata or self._client.cluster.need_update:
                 refresh_future = self._client.cluster.request_update()
-                self._client.poll(future=refresh_future, timeout_ms=inner_timeout_ms())
+                self._client.poll(future=refresh_future, timeout_ms=timer.timeout_ms)

                 if not future.is_done:
                     break
             else:
-                time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+                if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                    time.sleep(self.config['retry_backoff_ms'] / 1000)
+                else:
+                    time.sleep(timer.timeout_ms / 1000)
+
+            timer.maybe_raise()

         raise Errors.KafkaTimeoutError(
             "Failed to get offsets by timestamps in %s ms" % (timeout_ms,))

kafka/consumer/group.py (+17 -19)

@@ -18,7 +18,7 @@
 from kafka.metrics import MetricConfig, Metrics
 from kafka.protocol.list_offsets import OffsetResetStrategy
 from kafka.structs import OffsetAndMetadata, TopicPartition
-from kafka.util import timeout_ms_fn
+from kafka.util import Timer
 from kafka.version import __version__

 log = logging.getLogger(__name__)

@@ -679,41 +679,40 @@ def poll(self, timeout_ms=0, max_records=None, update_offsets=True):
         assert not self._closed, 'KafkaConsumer is closed'

         # Poll for new data until the timeout expires
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
+        timer = Timer(timeout_ms)
         while not self._closed:
-            records = self._poll_once(inner_timeout_ms(), max_records, update_offsets=update_offsets)
+            records = self._poll_once(timer, max_records, update_offsets=update_offsets)
             if records:
                 return records
-
-            if inner_timeout_ms() <= 0:
+            elif timer.expired:
                 break
-
         return {}

-    def _poll_once(self, timeout_ms, max_records, update_offsets=True):
+    def _poll_once(self, timer, max_records, update_offsets=True):
         """Do one round of polling. In addition to checking for new data, this does
         any needed heart-beating, auto-commits, and offset updates.

         Arguments:
-            timeout_ms (int): The maximum time in milliseconds to block.
+            timer (Timer): The maximum time in milliseconds to block.

         Returns:
             dict: Map of topic to list of records (may be empty).
         """
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
-        if not self._coordinator.poll(timeout_ms=inner_timeout_ms()):
+        if not self._coordinator.poll(timeout_ms=timer.timeout_ms):
             return {}

-        has_all_fetch_positions = self._update_fetch_positions(timeout_ms=inner_timeout_ms())
+        has_all_fetch_positions = self._update_fetch_positions(timeout_ms=timer.timeout_ms)

         # If data is available already, e.g. from a previous network client
         # poll() call to commit, then just return it immediately
         records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
+        log.debug('Fetched records: %s, %s', records, partial)
         # Before returning the fetched records, we can send off the
         # next round of fetches and avoid block waiting for their
         # responses to enable pipelining while the user is handling the
         # fetched records.
         if not partial:
+            log.debug("Sending fetches")
             futures = self._fetcher.send_fetches()
             if len(futures):
                 self._client.poll(timeout_ms=0)

@@ -723,7 +722,7 @@ def _poll_once(self, timeout_ms, max_records, update_offsets=True):

         # We do not want to be stuck blocking in poll if we are missing some positions
         # since the offset lookup may be backing off after a failure
-        poll_timeout_ms = inner_timeout_ms(self._coordinator.time_to_next_poll() * 1000)
+        poll_timeout_ms = min(timer.timeout_ms, self._coordinator.time_to_next_poll() * 1000)
         if not has_all_fetch_positions:
             poll_timeout_ms = min(poll_timeout_ms, self.config['retry_backoff_ms'])

@@ -749,15 +748,14 @@ def position(self, partition, timeout_ms=None):
             raise TypeError('partition must be a TopicPartition namedtuple')
         assert self._subscription.is_assigned(partition), 'Partition is not assigned'

-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout retrieving partition position')
+        timer = Timer(timeout_ms)
         position = self._subscription.assignment[partition].position
-        try:
-            while position is None:
-                # batch update fetch positions for any partitions without a valid position
-                self._update_fetch_positions(timeout_ms=inner_timeout_ms())
+        while position is None:
+            # batch update fetch positions for any partitions without a valid position
+            if self._update_fetch_positions(timeout_ms=timer.timeout_ms):
                 position = self._subscription.assignment[partition].position
-        except KafkaTimeoutError:
-            return None
+            elif timer.expired:
+                return None
         else:
             return position.offset
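The consumer-facing effect of this file's changes is that poll(timeout_ms=0) now completes one full round (coordinator poll, position update, fetch) and returns whatever is already available instead of bailing out before doing any work. A usage sketch against the public API (the topic, broker address, and group id are placeholders):

    from kafka import KafkaConsumer

    consumer = KafkaConsumer('my-topic',
                             bootstrap_servers='localhost:9092',
                             group_id='my-group')

    # Zero timeout: one non-blocking round, returning immediately with
    # whatever records are already available (possibly an empty dict).
    records = consumer.poll(timeout_ms=0)
    for tp, messages in records.items():
        for message in messages:
            print(tp.topic, tp.partition, message.offset, message.value)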

kafka/coordinator/base.py (+42 -20)

@@ -16,7 +16,7 @@
 from kafka.metrics.stats import Avg, Count, Max, Rate
 from kafka.protocol.find_coordinator import FindCoordinatorRequest
 from kafka.protocol.group import HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest, DEFAULT_GENERATION_ID, UNKNOWN_MEMBER_ID
-from kafka.util import timeout_ms_fn
+from kafka.util import Timer

 log = logging.getLogger('kafka.coordinator')

@@ -256,9 +256,9 @@ def ensure_coordinator_ready(self, timeout_ms=None):
             timeout_ms (numeric, optional): Maximum number of milliseconds to
                 block waiting to find coordinator. Default: None.

-        Raises: KafkaTimeoutError if timeout_ms is not None
+        Returns: True if coordinator found before timeout_ms, else False
         """
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to find group coordinator')
+        timer = Timer(timeout_ms)
         with self._client._lock, self._lock:
             while self.coordinator_unknown():

@@ -272,27 +272,37 @@ def ensure_coordinator_ready(self, timeout_ms=None):
                     else:
                         self.coordinator_id = maybe_coordinator_id
                         self._client.maybe_connect(self.coordinator_id)
-                        continue
+                        if timer.expired:
+                            return False
+                        else:
+                            continue
                 else:
                     future = self.lookup_coordinator()

-                self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+                self._client.poll(future=future, timeout_ms=timer.timeout_ms)

                 if not future.is_done:
-                    raise Errors.KafkaTimeoutError()
+                    return False

                 if future.failed():
                     if future.retriable():
                         if getattr(future.exception, 'invalid_metadata', False):
                             log.debug('Requesting metadata for group coordinator request: %s', future.exception)
                             metadata_update = self._client.cluster.request_update()
-                            self._client.poll(future=metadata_update, timeout_ms=inner_timeout_ms())
+                            self._client.poll(future=metadata_update, timeout_ms=timer.timeout_ms)
                             if not metadata_update.is_done:
-                                raise Errors.KafkaTimeoutError()
+                                return False
                         else:
-                            time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+                            if timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                                time.sleep(self.config['retry_backoff_ms'] / 1000)
+                            else:
+                                time.sleep(timer.timeout_ms / 1000)
                     else:
                         raise future.exception  # pylint: disable-msg=raising-bad-type
+            if timer.expired:
+                return False
+            else:
+                return True

     def _reset_find_coordinator_future(self, result):
         self._find_coordinator_future = None

@@ -407,21 +417,23 @@ def ensure_active_group(self, timeout_ms=None):
             timeout_ms (numeric, optional): Maximum number of milliseconds to
                 block waiting to join group. Default: None.

-        Raises: KafkaTimeoutError if timeout_ms is not None
+        Returns: True if group initialized before timeout_ms, else False
         """
         if self.config['api_version'] < (0, 9):
             raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker')
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
-        self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+        timer = Timer(timeout_ms)
+        if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
+            return False
         self._start_heartbeat_thread()
-        self.join_group(timeout_ms=inner_timeout_ms())
+        return self.join_group(timeout_ms=timer.timeout_ms)

     def join_group(self, timeout_ms=None):
         if self.config['api_version'] < (0, 9):
             raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker')
-        inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
+        timer = Timer(timeout_ms)
         while self.need_rejoin():
-            self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
+            if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
+                return False

             # call on_join_prepare if needed. We set a flag
             # to make sure that we do not call it a second

@@ -434,7 +446,7 @@ def join_group(self, timeout_ms=None):
             if not self.rejoining:
                 self._on_join_prepare(self._generation.generation_id,
                                       self._generation.member_id,
-                                      timeout_ms=inner_timeout_ms())
+                                      timeout_ms=timer.timeout_ms)
                 self.rejoining = True

             # fence off the heartbeat thread explicitly so that it cannot

@@ -449,16 +461,19 @@ def join_group(self, timeout_ms=None):
             while not self.coordinator_unknown():
                 if not self._client.in_flight_request_count(self.coordinator_id):
                     break
-                self._client.poll(timeout_ms=inner_timeout_ms(200))
+                poll_timeout_ms = 200 if timer.timeout_ms is None or timer.timeout_ms > 200 else timer.timeout_ms
+                self._client.poll(timeout_ms=poll_timeout_ms)
+                if timer.expired:
+                    return False
             else:
                 continue

             future = self._initiate_join_group()
-            self._client.poll(future=future, timeout_ms=inner_timeout_ms())
+            self._client.poll(future=future, timeout_ms=timer.timeout_ms)
             if future.is_done:
                 self._reset_join_group_future()
             else:
-                raise Errors.KafkaTimeoutError()
+                return False

             if future.succeeded():
                 self.rejoining = False

@@ -467,6 +482,7 @@ def join_group(self, timeout_ms=None):
                     self._generation.member_id,
                     self._generation.protocol,
                     future.value)
+                return True
             else:
                 exception = future.exception
                 if isinstance(exception, (Errors.UnknownMemberIdError,

@@ -476,7 +492,13 @@ def join_group(self, timeout_ms=None):
                     continue
                 elif not future.retriable():
                     raise exception  # pylint: disable-msg=raising-bad-type
-                time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
+                elif timer.expired:
+                    return False
+                else:
+                    if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
+                        time.sleep(self.config['retry_backoff_ms'] / 1000)
+                    else:
+                        time.sleep(timer.timeout_ms / 1000)

     def _send_join_group_request(self):
         """Join the group and return the assignment for the next generation.
