Skip to content

Commit f129167

Browse files
committed
Factor out waiting code into a new Sleeper class
- Moving code related to waiting/sleeping/retrying into a new class for a more object oriented approach. - Removing any reference to wait "tokens" to avoid confusion with edit tokens. - Note: `max_retries` and `retry_timeout` are no longer available on `Site`, but can still be passed into the constructor as before.
1 parent cd71b2b commit f129167

File tree

3 files changed

+124
-47
lines changed

3 files changed

+124
-47
lines changed

mwclient/client.py

Lines changed: 16 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import mwclient.errors as errors
2525
import mwclient.listing as listing
26+
from mwclient.sleep import Sleepers
2627

2728
try:
2829
import gzip
@@ -34,15 +35,6 @@
3435
log = logging.getLogger(__name__)
3536

3637

37-
class WaitToken(object):
38-
39-
def __init__(self):
40-
self.id = '%032x' % random.getrandbits(128)
41-
42-
def __hash__(self):
43-
return hash(self.id)
44-
45-
4638
class Site(object):
4739
api_limit = 500
4840

@@ -55,9 +47,6 @@ def __init__(self, host, path='/w/', ext='.php', pool=None, retry_timeout=30,
5547
self.ext = ext
5648
self.credentials = None
5749
self.compress = compress
58-
self.retry_timeout = retry_timeout
59-
self.max_retries = max_retries
60-
self.wait_callback = wait_callback
6150
self.max_lag = text_type(max_lag)
6251
self.force_login = force_login
6352

@@ -68,8 +57,7 @@ def __init__(self, host, path='/w/', ext='.php', pool=None, retry_timeout=30,
6857
else:
6958
raise RuntimeError('Authentication is not a tuple or an instance of AuthBase')
7059

71-
# The token string => token object mapping
72-
self.wait_tokens = weakref.WeakKeyDictionary()
60+
self.sleepers = Sleepers(max_retries, retry_timeout, wait_callback)
7361

7462
# Site properties
7563
self.blocked = False # Whether current user is blocked
@@ -192,18 +180,18 @@ def api(self, action, *args, **kwargs):
192180
else:
193181
kwargs['uiprop'] = 'blockinfo|hasmsg'
194182

195-
token = self.wait_token()
183+
sleeper = self.sleepers.make()
196184

197185
while True:
198186
info = self.raw_api(action, **kwargs)
199187
if not info:
200188
info = {}
201-
if self.handle_api_result(info, token=token):
189+
if self.handle_api_result(info, sleeper=sleeper):
202190
return info
203191

204-
def handle_api_result(self, info, kwargs=None, token=None):
205-
if token is None:
206-
token = self.wait_token()
192+
def handle_api_result(self, info, kwargs=None, sleeper=None):
193+
if sleeper is None:
194+
sleeper = self.sleepers.make()
207195

208196
try:
209197
userinfo = info['query']['userinfo']
@@ -217,7 +205,7 @@ def handle_api_result(self, info, kwargs=None, token=None):
217205
self.logged_in = 'anon' not in userinfo
218206
if 'error' in info:
219207
if info['error']['code'] in (u'internal_api_error_DBConnectionError', u'internal_api_error_DBQueryError'):
220-
self.wait(token)
208+
sleeper.sleep()
221209
return False
222210
if '*' in info['error']:
223211
raise errors.APIError(info['error']['code'],
@@ -258,7 +246,7 @@ def raw_call(self, script, data, files=None, retry_on_error=True):
258246
headers = {}
259247
if self.compress and gzip:
260248
headers['Accept-Encoding'] = 'gzip'
261-
token = self.wait_token((script, data))
249+
sleeper = self.sleepers.make((script, data))
262250
while True:
263251
scheme = 'http' # Should we move to 'https' as default?
264252
host = self.host
@@ -272,7 +260,7 @@ def raw_call(self, script, data, files=None, retry_on_error=True):
272260
if stream.headers.get('x-database-lag'):
273261
wait_time = int(stream.headers.get('retry-after'))
274262
log.warn('Database lag exceeds max lag. Waiting for %d seconds', wait_time)
275-
self.wait(token, wait_time)
263+
sleeper.sleep(wait_time)
276264
elif stream.status_code == 200:
277265
return stream.text
278266
elif stream.status_code < 500 or stream.status_code > 599:
@@ -281,15 +269,15 @@ def raw_call(self, script, data, files=None, retry_on_error=True):
281269
if not retry_on_error:
282270
stream.raise_for_status()
283271
log.warn('Received %s response: %s. Retrying in a moment.', stream.status_code, stream.text)
284-
self.wait(token)
272+
sleeper.sleep()
285273

286274
except requests.exceptions.ConnectionError:
287275
# In the event of a network problem (e.g. DNS failure, refused connection, etc),
288276
# Requests will raise a ConnectionError exception.
289277
if not retry_on_error:
290278
raise
291279
log.warn('Connection error. Retrying in a moment.')
292-
self.wait(token)
280+
sleeper.sleep()
293281

294282
def raw_api(self, action, *args, **kwargs):
295283
"""Sends a call to the API."""
@@ -316,25 +304,6 @@ def raw_index(self, action, *args, **kwargs):
316304
data = self._query_string(*args, **kwargs)
317305
return self.raw_call('index', data)
318306

319-
def wait_token(self, args=None):
320-
token = WaitToken()
321-
self.wait_tokens[token] = (0, args)
322-
return token
323-
324-
def wait(self, token, min_wait=0):
325-
retry, args = self.wait_tokens[token]
326-
self.wait_tokens[token] = (retry + 1, args)
327-
if retry > self.max_retries and self.max_retries != -1:
328-
raise errors.MaximumRetriesExceeded(self, token, args)
329-
self.wait_callback(self, token, retry, args)
330-
331-
timeout = self.retry_timeout * retry
332-
if timeout < min_wait:
333-
timeout = min_wait
334-
log.debug('Sleeping for %d seconds', timeout)
335-
time.sleep(timeout)
336-
return self.wait_tokens[token]
337-
338307
def require(self, major, minor, revision=None, raise_error=True):
339308
if self.version is None:
340309
if raise_error is None:
@@ -399,7 +368,7 @@ def login(self, username=None, password=None, cookies=None, domain=None):
399368
self.conn.cookies[self.host].update(cookies)
400369

401370
if self.credentials:
402-
wait_token = self.wait_token()
371+
sleeper = self.sleepers.make()
403372
kwargs = {
404373
'lgname': self.credentials[0],
405374
'lgpassword': self.credentials[1]
@@ -413,7 +382,7 @@ def login(self, username=None, password=None, cookies=None, domain=None):
413382
elif login['login']['result'] == 'NeedToken':
414383
kwargs['lgtoken'] = login['login']['token']
415384
elif login['login']['result'] == 'Throttled':
416-
self.wait(wait_token, login['login'].get('wait', 5))
385+
sleeper.sleep(int(login['login'].get('wait', 5)))
417386
else:
418387
raise errors.LoginError(self, login['login'])
419388

@@ -541,13 +510,13 @@ def upload(self, file=None, filename=None, description='', ignore=False, file_si
541510

542511
files = {'file': file}
543512

544-
wait_token = self.wait_token()
513+
sleeper = self.sleepers.make()
545514
while True:
546515
data = self.raw_call('api', postdata, files)
547516
info = json.loads(data)
548517
if not info:
549518
info = {}
550-
if self.handle_api_result(info, kwargs=predata, token=wait_token):
519+
if self.handle_api_result(info, kwargs=predata, sleeper=sleeper):
551520
return info.get('upload', {})
552521

553522
def parse(self, text=None, title=None, page=None):

mwclient/sleep.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import random
2+
import time
3+
import logging
4+
from mwclient.errors import MaximumRetriesExceeded
5+
6+
log = logging.getLogger(__name__)
7+
8+
9+
class Sleepers(object):
    """
    Factory for `Sleeper` objects that share one retry configuration.

    The retry limit, the base retry timeout and the callback are stored once
    here and passed on to every `Sleeper` produced by `make()`.
    """

    def __init__(self, max_retries, retry_timeout, callback=lambda *x: None):
        # The default callback accepts any arguments and does nothing.
        self.max_retries, self.retry_timeout, self.callback = (
            max_retries, retry_timeout, callback)

    def make(self, args=None):
        """
        Return a new `Sleeper` for a single operation.

        `args` is stored on the sleeper unchanged; the remaining settings
        come from this factory.
        """
        settings = dict(
            max_retries=self.max_retries,
            retry_timeout=self.retry_timeout,
            callback=self.callback,
        )
        return Sleeper(args, **settings)
18+
19+
20+
class Sleeper(object):
    """
    Count retries for one operation and sleep between attempts.

    For any given operation, a `Sleeper` object keeps count of the number of
    retries. For each retry, the sleep time increases until the max number of
    retries is reached and a `MaximumRetriesExceeded` is raised. The sleeper
    object should be discarded once the operation is successful.

    A `max_retries` of -1 disables the limit, so the sleeper never raises.
    """

    def __init__(self, args, max_retries, retry_timeout, callback):
        """
        Store the retry settings and start with a retry count of zero.

        `args` is kept unchanged and handed to the callback and to
        `MaximumRetriesExceeded`. `callback` is called as
        `callback(sleeper, retries, args)` before each sleep.
        """
        self.args = args
        self.retries = 0
        self.max_retries = max_retries
        self.retry_timeout = retry_timeout
        self.callback = callback

    def sleep(self, min_time=0):
        """
        Sleep a minimum of `min_time` seconds.

        The actual sleeping time will increase with the number of retries:
        it is `retry_timeout * (retries - 1)`, or `min_time` if that is
        larger. The first call therefore sleeps only for `min_time`.

        Raises `MaximumRetriesExceeded` once the number of calls exceeds
        `max_retries`, unless `max_retries` is -1.
        """
        self.retries += 1

        # -1 is the "no limit" sentinel that Site's retry logic honoured
        # before the waiting code was factored out; without this check a
        # caller passing max_retries=-1 would fail on the very first retry.
        if self.max_retries != -1 and self.retries > self.max_retries:
            raise MaximumRetriesExceeded(self, self.args)

        self.callback(self, self.retries, self.args)

        timeout = max(self.retry_timeout * (self.retries - 1), min_time)
        log.debug('Sleeping for %d seconds', timeout)
        time.sleep(timeout)

tests/test_sleep.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# encoding=utf-8
2+
from __future__ import print_function
3+
import unittest
4+
import time
5+
import mock
6+
import pytest
7+
from mwclient.sleep import Sleepers
8+
from mwclient.sleep import Sleeper
9+
from mwclient.errors import MaximumRetriesExceeded
10+
11+
if __name__ == "__main__":
    # Point people who run this file directly to the documented workflow.
    banner = (
        "Note: Running in stand-alone mode. Consult the README",
        " (section 'Contributing') for advice on running tests.",
    )
    print()
    print("\n".join(banner))
    print()
16+
17+
18+
class TestSleepers(unittest.TestCase):
    """Tests for `Sleepers` and `Sleeper`, run with `time.sleep` patched out."""

    def setUp(self):
        # Patch time.sleep so no test ever blocks; the mock records the
        # requested sleep durations for the assertions below.
        self.sleep = mock.patch('time.sleep').start()
        self.max_retries = 10
        self.sleepers = Sleepers(self.max_retries, 30)

    def tearDown(self):
        mock.patch.stopall()

    def test_make(self):
        """`make()` returns a fresh `Sleeper` with no retries recorded."""
        sleeper = self.sleepers.make()
        assert isinstance(sleeper, Sleeper)
        assert sleeper.retries == 0

    def test_sleep(self):
        """The first sleep lasts 0 seconds, the second one retry timeout."""
        sleeper = self.sleepers.make()
        sleeper.sleep()
        sleeper.sleep()
        self.sleep.assert_has_calls([mock.call(0), mock.call(30)])

    def test_min_time(self):
        """`min_time` is used when it exceeds the computed timeout."""
        sleeper = self.sleepers.make()
        sleeper.sleep(5)
        self.sleep.assert_has_calls([mock.call(5)])

    def test_retries_count(self):
        """Each call to `sleep()` increments the retry counter by one."""
        sleeper = self.sleepers.make()
        sleeper.sleep()
        sleeper.sleep()
        assert sleeper.retries == 2

    def test_max_retries(self):
        """`max_retries` sleeps succeed; the next one raises."""
        sleeper = self.sleepers.make()
        for _ in range(self.max_retries):
            sleeper.sleep()
        with pytest.raises(MaximumRetriesExceeded):
            sleeper.sleep()
56+
57+
# Run the suite with unittest's own runner when this file is executed directly.
if __name__ == '__main__':
58+
unittest.main()

0 commit comments

Comments
 (0)