Skip to content

Commit f52ba7e

Browse files
fix: make asyncio.Lock() run in background thread (GoogleCloudPlatform#252)
1 parent 7a8a130 commit f52ba7e

File tree

6 files changed

+82
-19
lines changed

6 files changed

+82
-19
lines changed

google/cloud/sql/connector/instance_connection_manager.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def _client_session(self) -> aiohttp.ClientSession:
233233
_project: str
234234
_region: str
235235

236+
_refresh_rate_limiter: AsyncRateLimiter
236237
_refresh_in_progress: asyncio.locks.Event
237238
_current: asyncio.Task # task wraps coroutine that returns InstanceMetadata
238239
_next: asyncio.Task # task wraps coroutine that returns another task
@@ -274,17 +275,18 @@ def __init__(
274275

275276
self._auth_init(credentials)
276277

277-
self._refresh_rate_limiter = AsyncRateLimiter(
278-
max_capacity=2, rate=1 / 30, loop=self._loop
279-
)
280-
281-
async def _set_instance_data() -> None:
282-
logger.debug("Updating instance data")
278+
async def _async_init() -> None:
279+
"""Initialize InstanceConnectionManager's variables that require the
280+
event loop running in background thread.
281+
"""
282+
self._refresh_rate_limiter = AsyncRateLimiter(
283+
max_capacity=2, rate=1 / 30, loop=self._loop
284+
)
283285
self._refresh_in_progress = asyncio.locks.Event()
284286
self._current = self._loop.create_task(self._get_instance_data())
285287
self._next = self._loop.create_task(self._schedule_refresh())
286288

287-
init_future = asyncio.run_coroutine_threadsafe(_set_instance_data(), self._loop)
289+
init_future = asyncio.run_coroutine_threadsafe(_async_init(), self._loop)
288290
init_future.result()
289291

290292
def __del__(self) -> None:

google/cloud/sql/connector/rate_limiter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def __init__(
4646
self.rate = rate
4747
self.max_capacity = max_capacity
4848
self._loop = loop or asyncio.get_event_loop()
49-
self._lock = asyncio.Lock()
5049
self._tokens: float = max_capacity
5150
self._last_token_update = self._loop.time()
51+
self._lock = asyncio.Lock()
5252

5353
def _update_token_count(self) -> None:
5454
"""

tests/conftest.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ def async_loop() -> Generator:
5151
Creates a loop in a background thread and returns it to use for testing.
5252
"""
5353
loop = asyncio.new_event_loop()
54-
thr = threading.Thread(target=loop.run_forever)
54+
thr = threading.Thread(target=loop.run_forever, daemon=True)
5555
thr.start()
5656
yield loop
5757
loop.stop()
58-
thr.join()
5958

6059

6160
@pytest.fixture

tests/system/test_connector_object.py

+24
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import logging
2020
import google.auth
2121
from google.cloud.sql.connector import connector
22+
import datetime
23+
import concurrent.futures
2224

2325

2426
def init_connection_engine(
@@ -80,3 +82,25 @@ def test_multiple_connectors() -> None:
8082
)
8183
except Exception as e:
8284
logging.exception("Failed to connect with multiple Connector objects!", e)
85+
86+
87+
def test_connector_in_ThreadPoolExecutor() -> None:
88+
"""Test that Connector can connect from ThreadPoolExecutor thread.
89+
This helps simulate how connector works in Cloud Run and Cloud Functions.
90+
"""
91+
92+
def get_time() -> datetime.datetime:
93+
"""Helper method for getting current time from database."""
94+
default_connector = connector.Connector()
95+
pool = init_connection_engine(default_connector)
96+
97+
# connect to database and get current time
98+
with pool.connect() as conn:
99+
current_time = conn.execute("SELECT NOW()").fetchone()
100+
return current_time[0]
101+
102+
# try running connector in ThreadPoolExecutor as Cloud Run does
103+
with concurrent.futures.ThreadPoolExecutor() as executor:
104+
future = executor.submit(get_time)
105+
return_value = future.result()
106+
assert isinstance(return_value, datetime.datetime)

tests/unit/test_instance_connection_manager.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,16 @@ def icm(
4646

4747
@pytest.fixture
4848
def test_rate_limiter(async_loop: asyncio.AbstractEventLoop) -> AsyncRateLimiter:
49-
return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop)
49+
async def rate_limiter_in_loop(
50+
async_loop: asyncio.AbstractEventLoop,
51+
) -> AsyncRateLimiter:
52+
return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop)
53+
54+
limiter_future = asyncio.run_coroutine_threadsafe(
55+
rate_limiter_in_loop(async_loop), async_loop
56+
)
57+
limiter = limiter_future.result()
58+
return limiter
5059

5160

5261
class MockMetadata:

tests/unit/test_rate_limiter.py

+37-8
Original file line numberDiff line numberDiff line change
@@ -22,41 +22,70 @@
2222
)
2323

2424

25+
async def rate_limiter_in_loop(
26+
max_capacity: int, rate: float, loop: asyncio.AbstractEventLoop
27+
) -> AsyncRateLimiter:
28+
"""Helper function to create AsyncRateLimiter object inside given event loop."""
29+
limiter = AsyncRateLimiter(max_capacity=max_capacity, rate=rate, loop=loop)
30+
return limiter
31+
32+
2533
@pytest.mark.asyncio
26-
async def test_rate_limiter_throttles_requests() -> None:
34+
async def test_rate_limiter_throttles_requests(
35+
async_loop: asyncio.AbstractEventLoop,
36+
) -> None:
37+
"""Test to check whether rate limiter will throttle incoming requests."""
2738
counter = 0
2839
# allow 2 requests to go through every 5 seconds
29-
limiter = AsyncRateLimiter(max_capacity=2, rate=1 / 5)
40+
limiter_future = asyncio.run_coroutine_threadsafe(
41+
rate_limiter_in_loop(max_capacity=2, rate=1 / 5, loop=async_loop), async_loop
42+
)
43+
limiter = limiter_future.result()
3044

3145
async def increment() -> None:
3246
await limiter.acquire()
3347
nonlocal counter
3448
counter += 1
3549

36-
tasks = [increment() for _ in range(10)]
50+
# create 10 tasks calling increment()
51+
tasks = [async_loop.create_task(increment()) for _ in range(10)]
3752

38-
done, pending = await asyncio.wait(tasks, timeout=11)
53+
# wait 10 seconds and check tasks
54+
done, pending = asyncio.run_coroutine_threadsafe(
55+
asyncio.wait(tasks, timeout=11), async_loop
56+
).result()
3957

58+
# verify 4 tasks completed and 6 pending due to rate limiter
4059
assert counter == 4
4160
assert len(done) == 4
4261
assert len(pending) == 6
4362

4463

4564
@pytest.mark.asyncio
46-
async def test_rate_limiter_completes_all_tasks() -> None:
65+
async def test_rate_limiter_completes_all_tasks(
66+
async_loop: asyncio.AbstractEventLoop,
67+
) -> None:
68+
"""Test to check all requests will go through rate limiter successfully."""
4769
counter = 0
4870
# allow 1 request to go through per second
49-
limiter = AsyncRateLimiter(max_capacity=1, rate=1)
71+
limiter_future = asyncio.run_coroutine_threadsafe(
72+
rate_limiter_in_loop(max_capacity=1, rate=1, loop=async_loop), async_loop
73+
)
74+
limiter = limiter_future.result()
5075

5176
async def increment() -> None:
5277
await limiter.acquire()
5378
nonlocal counter
5479
counter += 1
5580

56-
tasks = [increment() for _ in range(10)]
81+
# create 10 tasks calling increment()
82+
tasks = [async_loop.create_task(increment()) for _ in range(10)]
5783

58-
done, pending = await asyncio.wait(tasks, timeout=30)
84+
done, pending = asyncio.run_coroutine_threadsafe(
85+
asyncio.wait(tasks, timeout=30), async_loop
86+
).result()
5987

88+
# verify all tasks done and none pending
6089
assert counter == 10
6190
assert len(done) == 10
6291
assert len(pending) == 0

0 commit comments

Comments
 (0)