Skip to content

MultiDbClient implementation #3696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: feat/active-active
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions redis/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
import threading
from typing import Callable

class BackgroundScheduler:
    """
    Schedules background tasks for execution, either in a separate thread or in
    the running event loop.
    """

    def __init__(self):
        # Handle of the most recently scheduled timer; cancelled on teardown.
        self._next_timer = None

    def __del__(self):
        # Best-effort cancellation of the pending timer when the scheduler is
        # garbage-collected.
        timer = self._next_timer
        if timer:
            timer.cancel()

    def run_once(self, delay: float, callback: Callable, *args):
        """
        Run a callable task once after the given delay in seconds.
        """
        self._spawn_loop_thread(self._call_later, delay, callback, *args)

    def run_recurring(
        self,
        interval: float,
        callback: Callable,
        *args
    ):
        """
        Run a recurring callable task with the given interval in seconds.
        """
        self._spawn_loop_thread(self._call_later_recurring, interval, callback, *args)

    def _spawn_loop_thread(self, schedule_cb: Callable, timing: float, callback: Callable, *args):
        # A dedicated event loop runs in a daemon thread so the caller's
        # thread is never blocked by the scheduled work.
        loop = asyncio.new_event_loop()
        worker = threading.Thread(
            target=_start_event_loop_in_thread,
            args=(loop, schedule_cb, timing, callback, *args),
            daemon=True,
        )
        worker.start()

    def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args):
        # Remember the timer handle so __del__ can cancel the pending task.
        self._next_timer = loop.call_later(delay, callback, *args)

    def _call_later_recurring(
        self,
        loop: asyncio.AbstractEventLoop,
        interval: float,
        callback: Callable,
        *args
    ):
        # Schedule the first tick; _execute_recurring re-schedules itself.
        self._call_later(
            loop, interval, self._execute_recurring, loop, interval, callback, *args
        )

    def _execute_recurring(
        self,
        loop: asyncio.AbstractEventLoop,
        interval: float,
        callback: Callable,
        *args
    ):
        """
        Execute the recurring callable task, then schedule the next run
        after the same interval.
        """
        callback(*args)
        self._call_later(
            loop, interval, self._execute_recurring, loop, interval, callback, *args
        )


def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args):
    """
    Starts event loop in a thread and schedule callback as soon as event loop is ready.
    Used to be able to schedule tasks using loop.call_later.

    :param event_loop: loop to bind to the current thread and run forever
    :param call_soon_cb: callback invoked with ``(event_loop, *args)`` once the loop starts
    :return: does not return normally; blocks until the loop is stopped externally
    """
    # Bind the loop to this thread so loop-affine asyncio APIs resolve to it.
    asyncio.set_event_loop(event_loop)
    # Queue the callback before run_forever() so it fires as soon as the loop spins up.
    event_loop.call_soon(call_soon_cb, event_loop, *args)
    # Blocks this (daemon) thread for the lifetime of the loop.
    event_loop.run_forever()
4 changes: 2 additions & 2 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options):
conn.send_command(*args, **options)
return self.parse_response(conn, command_name, **options)

def _close_connection(self, conn) -> None:
def _close_connection(self, conn, error, *args) -> None:
"""
Close the connection before retrying.

Expand Down Expand Up @@ -633,7 +633,7 @@ def _execute_command(self, *args, **options):
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
lambda error: self._close_connection(conn, error, *args),
)
finally:
if self._single_connection_client:
Expand Down
73 changes: 73 additions & 0 deletions redis/data_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import threading
from typing import List, Any, TypeVar, Generic, Union

T = TypeVar('T')

class WeightedList(Generic[T]):
    """
    Thread-safe list of items ordered by descending weight.

    Items are kept sorted from the highest weight to the lowest, so the
    head of the list is always the heaviest item.
    """
    def __init__(self):
        # Storage is (weight, item) pairs, sorted by descending weight.
        self._items: List[tuple[Union[int, float], T]] = []
        # RLock: update_weight() calls remove()/add() while already holding the lock.
        self._lock = threading.RLock()

    def add(self, item: T, weight: float) -> None:
        """Add item with weight, maintaining descending sorted order."""
        with self._lock:
            # Find insertion point using binary search (list sorted descending).
            left, right = 0, len(self._items)
            while left < right:
                mid = (left + right) // 2
                if self._items[mid][0] < weight:
                    right = mid
                else:
                    left = mid + 1

            self._items.insert(left, (weight, item))

    def remove(self, item: T) -> Union[int, float]:
        """Remove first occurrence of item and return its weight.

        :raises ValueError: if the item is not present.
        """
        with self._lock:
            for i, (weight, stored_item) in enumerate(self._items):
                if stored_item == item:
                    self._items.pop(i)
                    return weight
            raise ValueError("Item not found")

    def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[T, Union[int, float]]]:
        """Get all (item, weight) pairs within the inclusive weight range."""
        with self._lock:
            result = []
            for weight, item in self._items:
                if weight < min_weight:
                    # List is sorted descending; no later item can match.
                    break
                if weight <= max_weight:
                    result.append((item, weight))
            return result

    def get_top_n(self, n: int) -> List[tuple[T, Union[int, float]]]:
        """Get top N the highest weighted items as (item, weight) pairs."""
        with self._lock:
            return [(item, weight) for weight, item in self._items[:n]]

    def update_weight(self, item: T, new_weight: float) -> Union[int, float]:
        """Update weight of an item, returning its previous weight."""
        with self._lock:
            old_weight = self.remove(item)
            self.add(item, new_weight)
            return old_weight

    def __iter__(self):
        """Iterate (item, weight) pairs in descending weight order."""
        with self._lock:
            items_copy = self._items.copy()  # Create snapshot as lock released after each 'yield'

        for weight, item in items_copy:
            yield item, weight

    def __len__(self):
        with self._lock:
            return len(self._items)

    def __getitem__(self, index) -> tuple[T, Union[int, float]]:
        """Return the (item, weight) pair stored at the given position."""
        with self._lock:
            weight, item = self._items[index]
            return item, weight
69 changes: 59 additions & 10 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict, Type

from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
Expand Down Expand Up @@ -42,6 +42,11 @@ def dispatch(self, event: object):
async def dispatch_async(self, event: object):
pass

@abstractmethod
def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]):
"""Register additional listeners."""
pass


class EventException(Exception):
"""
Expand All @@ -56,11 +61,14 @@ def __init__(self, exception: Exception, event: object):

class EventDispatcher(EventDispatcherInterface):
# TODO: Make dispatcher to accept external mappings.
def __init__(self):
def __init__(
self,
event_listeners: Optional[Dict[Type[object], List[EventListenerInterface]]] = None,
):
"""
Mapping should be extended for any new events or listeners to be added.
Dispatcher that dispatches events to listeners associated with given event.
"""
self._event_listeners_mapping = {
self._event_listeners_mapping: Dict[Type[object], List[EventListenerInterface]]= {
AfterConnectionReleasedEvent: [
ReAuthConnectionListener(),
],
Expand All @@ -77,17 +85,33 @@ def __init__(self):
],
}

self._lock = threading.Lock()
self._async_lock = asyncio.Lock()

if event_listeners:
self.register_listeners(event_listeners)

def dispatch(self, event: object):
listeners = self._event_listeners_mapping.get(type(event))
with self._lock:
listeners = self._event_listeners_mapping.get(type(event), [])

for listener in listeners:
listener.listen(event)
for listener in listeners:
listener.listen(event)

async def dispatch_async(self, event: object):
listeners = self._event_listeners_mapping.get(type(event))
with self._async_lock:
listeners = self._event_listeners_mapping.get(type(event), [])

for listener in listeners:
await listener.listen(event)
for listener in listeners:
await listener.listen(event)

def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]):
    """
    Register additional listeners, merging them into the existing
    per-event-type mapping under the dispatcher's lock.
    """
    with self._lock:
        for event in event_listeners:
            if event in self._event_listeners_mapping:
                # Deduplicate via set(); NOTE(review): this requires listeners to be
                # hashable and makes the merged listener order arbitrary — confirm
                # that dispatch order is not relied upon.
                self._event_listeners_mapping[event] = list(set(self._event_listeners_mapping[event] + event_listeners[event]))
            else:
                self._event_listeners_mapping[event] = event_listeners[event]


class AfterConnectionReleasedEvent:
Expand Down Expand Up @@ -225,6 +249,31 @@ def nodes(self) -> dict:
def credential_provider(self) -> Union[CredentialProvider, None]:
return self._credential_provider

class OnCommandFailEvent:
    """
    Event fired whenever a command fails during the execution.
    """
    def __init__(self, command: tuple, exception: Exception, client):
        # Capture the failure context; it is exposed read-only via properties.
        self._command = command
        self._exception = exception
        self._client = client

    @property
    def command(self) -> tuple:
        """The command (with its arguments) that failed."""
        return self._command

    @property
    def exception(self) -> Exception:
        """The exception raised while executing the command."""
        return self._exception

    @property
    def client(self):
        """The client instance the command was executed on."""
        return self._client

class ReAuthConnectionListener(EventListenerInterface):
"""
Expand Down
Empty file added redis/multidb/__init__.py
Empty file.
Loading