|
22 | 22 | )
|
23 | 23 |
|
24 | 24 |
|
| 25 | +async def rate_limiter_in_loop( |
| 26 | + max_capacity: int, rate: float, loop: asyncio.AbstractEventLoop |
| 27 | +) -> AsyncRateLimiter: |
| 28 | + """Helper function to create AsyncRateLimiter object inside given event loop.""" |
| 29 | + limiter = AsyncRateLimiter(max_capacity=max_capacity, rate=rate, loop=loop) |
| 30 | + return limiter |
| 31 | + |
| 32 | + |
25 | 33 | @pytest.mark.asyncio
|
26 |
| -async def test_rate_limiter_throttles_requests() -> None: |
| 34 | +async def test_rate_limiter_throttles_requests( |
| 35 | + async_loop: asyncio.AbstractEventLoop, |
| 36 | +) -> None: |
| 37 | + """Test to check whether rate limiter will throttle incoming requests.""" |
27 | 38 | counter = 0
|
28 | 39 | # allow 2 requests to go through every 5 seconds
|
29 |
| - limiter = AsyncRateLimiter(max_capacity=2, rate=1 / 5) |
| 40 | + limiter_future = asyncio.run_coroutine_threadsafe( |
| 41 | + rate_limiter_in_loop(max_capacity=2, rate=1 / 5, loop=async_loop), async_loop |
| 42 | + ) |
| 43 | + limiter = limiter_future.result() |
30 | 44 |
|
31 | 45 | async def increment() -> None:
|
32 | 46 | await limiter.acquire()
|
33 | 47 | nonlocal counter
|
34 | 48 | counter += 1
|
35 | 49 |
|
36 |
| - tasks = [increment() for _ in range(10)] |
| 50 | + # create 10 tasks calling increment() |
| 51 | + tasks = [async_loop.create_task(increment()) for _ in range(10)] |
37 | 52 |
|
38 |
| - done, pending = await asyncio.wait(tasks, timeout=11) |
| 53 | + # wait 10 seconds and check tasks |
| 54 | + done, pending = asyncio.run_coroutine_threadsafe( |
| 55 | + asyncio.wait(tasks, timeout=11), async_loop |
| 56 | + ).result() |
39 | 57 |
|
| 58 | + # verify 4 tasks completed and 6 pending due to rate limiter |
40 | 59 | assert counter == 4
|
41 | 60 | assert len(done) == 4
|
42 | 61 | assert len(pending) == 6
|
43 | 62 |
|
44 | 63 |
|
45 | 64 | @pytest.mark.asyncio
|
46 |
| -async def test_rate_limiter_completes_all_tasks() -> None: |
| 65 | +async def test_rate_limiter_completes_all_tasks( |
| 66 | + async_loop: asyncio.AbstractEventLoop, |
| 67 | +) -> None: |
| 68 | + """Test to check all requests will go through rate limiter successfully.""" |
47 | 69 | counter = 0
|
48 | 70 | # allow 1 request to go through per second
|
49 |
| - limiter = AsyncRateLimiter(max_capacity=1, rate=1) |
| 71 | + limiter_future = asyncio.run_coroutine_threadsafe( |
| 72 | + rate_limiter_in_loop(max_capacity=1, rate=1, loop=async_loop), async_loop |
| 73 | + ) |
| 74 | + limiter = limiter_future.result() |
50 | 75 |
|
51 | 76 | async def increment() -> None:
|
52 | 77 | await limiter.acquire()
|
53 | 78 | nonlocal counter
|
54 | 79 | counter += 1
|
55 | 80 |
|
56 |
| - tasks = [increment() for _ in range(10)] |
| 81 | + # create 10 tasks calling increment() |
| 82 | + tasks = [async_loop.create_task(increment()) for _ in range(10)] |
57 | 83 |
|
58 |
| - done, pending = await asyncio.wait(tasks, timeout=30) |
| 84 | + done, pending = asyncio.run_coroutine_threadsafe( |
| 85 | + asyncio.wait(tasks, timeout=30), async_loop |
| 86 | + ).result() |
59 | 87 |
|
| 88 | + # verify all tasks done and none pending |
60 | 89 | assert counter == 10
|
61 | 90 | assert len(done) == 10
|
62 | 91 | assert len(pending) == 0
|
0 commit comments