This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,134 @@
"""
Broadcast channel for Pub/Sub messaging.
"""
import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self
class Subscription(AbstractContextManager["Subscription"], Protocol):
"""A subscription to a topic that provides an iterator over received messages.
The subscription can be used as a context manager and will automatically
close when exiting the context.
Note: `Subscription` instances are not thread-safe. Each thread should create its own
subscription.
"""
@abstractmethod
def __iter__(self) -> Iterator[bytes]:
"""`__iter__` returns an iterator used to consume the message from this subscription.
If the caller did not enter the context, `__iter__` may lazily perform the setup before
yielding messages; otherwise `__enter__` handles it.”
If the subscription is closed, then the returned iterator exits without
raising any error.
"""
...
@abstractmethod
def close(self) -> None:
"""close closes the subscription, releases any resources associated with it."""
...
def __enter__(self) -> Self:
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
@abstractmethod
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the broadcast channel.
If `timeout` is specified, this method returns `None` if no message is
received within the given period. If `timeout` is `None`, the call blocks
until a message is received.
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
cancel a blocking subscription.
:param timeout: timeout for receive message, in seconds.
Returns:
bytes: The received message as a byte string, or
None: If the timeout expires before a message is received.
Raises:
SubscriptionClosed: If the subscription has already been closed.
"""
...
class Producer(Protocol):
    """Interface for publishing messages; an instance is already bound to a specific topic.

    `Producer` implementations must be thread-safe and support concurrent use
    by multiple threads.
    """

    @abstractmethod
    def publish(self, payload: bytes) -> None:
        """Publish `payload` to the topic this producer is bound to."""
class Subscriber(Protocol):
    """Interface for creating subscriptions; an instance is already bound to a specific topic.

    `Subscriber` implementations must be thread-safe and support concurrent use
    by multiple threads.
    """

    @abstractmethod
    def subscribe(self) -> Subscription:
        """Create and return a `Subscription` for the topic this subscriber is bound to."""
        ...
class Topic(Producer, Subscriber, Protocol):
    """A named channel that supports both publishing and subscribing.

    A topic grants read and write access at once. Use `as_producer()` for a
    write-only view or `as_subscriber()` for a read-only view when access
    should be restricted.

    `Topic` implementations must be thread-safe and support concurrent use by
    multiple threads.
    """

    @abstractmethod
    def as_producer(self) -> Producer:
        """Create a write-only view of this topic."""

    @abstractmethod
    def as_subscriber(self) -> Subscriber:
        """Create a read-only view of this topic."""
class BroadcastChannel(Protocol):
    """A channel with broadcasting semantics.

    Each channel is identified by a topic; different topics are isolated and
    do not affect each other. A topic may have multiple subscriptions, and
    every subscription should receive each message published to that topic.

    There is no requirement on message persistence. Once a subscription is
    created, it should receive all subsequently published messages.

    `BroadcastChannel` implementations must be thread-safe and support
    concurrent use by multiple threads.
    """

    @abstractmethod
    def topic(self, topic: str) -> "Topic":
        """Return a `Topic` instance for the given topic name."""

View File

@@ -0,0 +1,12 @@
class BroadcastChannelError(Exception):
    """Base class for all exceptions related to `BroadcastChannel`."""
class SubscriptionClosedError(BroadcastChannelError):
    """Raised when a subscription has been closed.

    Once a subscription is closed, its message-consuming methods should no
    longer be called.
    """

View File

@@ -0,0 +1,4 @@
# Re-export the broadcast channel classes so they can be imported directly from this package.
from .channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel
# Names exposed by `from <package> import *`.
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]

View File

@@ -0,0 +1,227 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class RedisSubscriptionBase(Subscription):
    """Base class for Redis pub/sub subscriptions with common functionality.

    This class holds the logic shared by the regular and the sharded Redis
    pub/sub subscriptions. A background daemon thread pulls messages from the
    `PubSub` object and puts them on a bounded in-memory queue (1024 entries);
    when the queue is full the oldest entry is dropped to make room.

    Subclasses supply the Redis-specific pieces by overriding
    `_get_subscription_type`, `_subscribe`, `_unsubscribe`, `_get_message`
    and `_get_message_type`.
    """

    def __init__(
        self,
        pubsub: PubSub,
        topic: str,
    ):
        """Bind the subscription to `pubsub` and `topic` without contacting Redis yet.

        Args:
            pubsub: The PubSub object used to subscribe and fetch messages.
            topic: Name of the channel this subscription listens to.
        """
        # The _pubsub is None only if the subscription is closed.
        self._pubsub: PubSub | None = pubsub
        self._topic = topic
        self._closed = threading.Event()
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
        self._dropped_count = 0
        self._listener_thread: threading.Thread | None = None
        # Serializes start-up against `close()` so exactly one party owns PubSub cleanup.
        self._start_lock = threading.Lock()
        self._started = False

    def _start_if_needed(self) -> None:
        """Subscribe and start the listener thread if that has not happened yet.

        Raises:
            SubscriptionClosedError: If the subscription is closed or its
                PubSub object has already been released.
        """
        with self._start_lock:
            if self._started:
                return
            if self._closed.is_set():
                raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
            if self._pubsub is None:
                raise SubscriptionClosedError(
                    f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
                )

            self._subscribe()
            _logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)

            listener = threading.Thread(
                target=self._listen,
                name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
                daemon=True,
            )
            listener.start()
            # Publish the thread only after it really started, so `close()` never joins an unstarted thread.
            self._listener_thread = listener
            self._started = True

    def _listen(self) -> None:
        """Listener-thread main loop: fetch messages and enqueue their payloads.

        The loop ends when the subscription is closed or when fetching a
        message fails. On every exit path the PubSub object is released and
        the subscription is marked closed, so consumers stop waiting on a
        queue that will never be filled again.
        """
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while starting listening."

        try:
            while not self._closed.is_set():
                try:
                    raw_message = self._get_message()
                except Exception as e:
                    # Log the exception and exit the listener thread gracefully.
                    # This handles Redis connection errors and other exceptions.
                    _logger.error(
                        "Error getting message from Redis %s subscription, topic=%s: %s",
                        self._get_subscription_type(),
                        self._topic,
                        e,
                        exc_info=True,
                    )
                    break

                if raw_message is None:
                    continue

                payload = self._extract_payload(raw_message)
                if payload is not None:
                    self._enqueue_message(payload)
        finally:
            # Runs even if message processing raised unexpectedly, so the PubSub is never leaked.
            _logger.debug(
                "%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic
            )
            try:
                self._release_pubsub(pubsub, unsubscribe=True)
            finally:
                # A dead listener means no more messages: keep "_pubsub is None" and "closed" in sync.
                self._closed.set()

    def _extract_payload(self, raw_message: dict) -> bytes | None:
        """Return the payload of `raw_message`, or None if the message must be ignored.

        A message is ignored when its type is not the expected one, when it
        comes from a channel other than this subscription's topic, or when
        its data is not a `bytes` object.
        """
        if raw_message.get("type") != self._get_message_type():
            return None

        channel_field = raw_message.get("channel")
        if isinstance(channel_field, bytes):
            # A malformed channel name is reported as unexpected instead of killing the listener.
            channel_name = channel_field.decode("utf-8", errors="replace")
        elif isinstance(channel_field, str):
            channel_name = channel_field
        else:
            channel_name = str(channel_field)

        if channel_name != self._topic:
            _logger.warning(
                "Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
            )
            return None

        payload_bytes = raw_message.get("data")
        if not isinstance(payload_bytes, bytes):
            _logger.error(
                "Received invalid data from %s channel %s, type=%s",
                self._get_subscription_type(),
                self._topic,
                type(payload_bytes),
            )
            return None
        return payload_bytes

    def _release_pubsub(self, pubsub: PubSub, *, unsubscribe: bool) -> None:
        """Close `pubsub` (optionally unsubscribing first) and clear `self._pubsub`.

        Errors are logged rather than raised. `pubsub.close()` is attempted
        even when unsubscribing fails.
        """
        try:
            try:
                if unsubscribe:
                    self._unsubscribe()
            finally:
                pubsub.close()
            _logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
        except Exception as e:
            _logger.error(
                "Error during cleanup of Redis %s subscription, topic=%s: %s",
                self._get_subscription_type(),
                self._topic,
                e,
                exc_info=True,
            )
        finally:
            self._pubsub = None

    def _enqueue_message(self, payload: bytes) -> None:
        """Put `payload` on the internal queue, dropping the oldest entry when the queue is full."""
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                try:
                    self._queue.get_nowait()
                    self._dropped_count += 1
                    _logger.debug(
                        "Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
                        self._get_subscription_type(),
                        self._topic,
                        self._dropped_count,
                    )
                except queue.Empty:
                    # A consumer emptied the queue in the meantime; simply retry the put.
                    continue

    def _message_iterator(self) -> Generator[bytes, None, None]:
        """Yield queued messages until the subscription is closed."""
        while not self._closed.is_set():
            try:
                # Poll with a short timeout so a close is noticed promptly.
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        """Return an iterator over messages from the subscription.

        Raises:
            SubscriptionClosedError: If the subscription is already closed.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()
        return self._message_iterator()

    def receive(self, timeout: float | None = None) -> bytes | None:
        """Receive the next message from the subscription.

        NOTE(review): the default here is `None` (block until a message
        arrives), whereas the `Subscription` protocol declares a default of
        0.1 seconds. The default is kept for backward compatibility; prefer
        passing an explicit timeout, because a blocking call is not woken up
        by `close()`.

        Args:
            timeout: Maximum time to wait, in seconds, or `None` to block.

        Returns:
            The next message, or `None` if no message arrived within `timeout`.

        Raises:
            SubscriptionClosedError: If the subscription is already closed.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()

        try:
            item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        return item

    def __enter__(self) -> Self:
        """Context manager entry point: start the subscription and return it."""
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        """Context manager exit point: close the subscription without suppressing exceptions."""
        self.close()
        return None

    def close(self) -> None:
        """Close the subscription and clean up resources. Calling it again is a no-op."""
        if self._closed.is_set():
            return

        self._closed.set()
        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
        # message retrieval method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        # The only exception is a subscription that was never started: no listener thread exists,
        # so the PubSub is released right here.
        with self._start_lock:
            listener = self._listener_thread
            pubsub = self._pubsub
            if not self._started and pubsub is not None:
                self._release_pubsub(pubsub, unsubscribe=False)

        if listener is not None:
            listener.join(timeout=1.0)
            self._listener_thread = None

    # Abstract methods to be implemented by subclasses
    def _get_subscription_type(self) -> str:
        """Return the subscription type (e.g., 'regular' or 'sharded')."""
        raise NotImplementedError

    def _subscribe(self) -> None:
        """Subscribe to the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _unsubscribe(self) -> None:
        """Unsubscribe from the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _get_message(self) -> dict | None:
        """Get a message from Redis using the appropriate method."""
        raise NotImplementedError

    def _get_message_type(self) -> str:
        """Return the expected message type (e.g., 'message' or 'smessage')."""
        raise NotImplementedError

View File

@@ -0,0 +1,67 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from ._subscription import RedisSubscriptionBase
class BroadcastChannel:
    """
    Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).

    Messages are delivered with "at most once" semantics, using the Redis
    PUBLISH/SUBSCRIBE commands for real-time message delivery.

    The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
    """

    def __init__(
        self,
        redis_client: Redis,
    ):
        """Keep `redis_client` for the topics created by this channel."""
        self._client = redis_client

    def topic(self, topic: str) -> "Topic":
        """Return a `Topic` for the channel named `topic`, sharing this channel's client."""
        client = self._client
        return Topic(client, topic)
class Topic:
    """A Redis channel name paired with the client used to publish and subscribe to it."""

    def __init__(self, redis_client: Redis, topic: str):
        """Remember the client and the channel name; no Redis command is issued here."""
        self._topic = topic
        self._client = redis_client

    def as_producer(self) -> Producer:
        """Return this same object, typed as a `Producer`."""
        return self

    def publish(self, payload: bytes) -> None:
        """Send `payload` to the channel with the client's `publish` call."""
        self._client.publish(self._topic, payload)

    def as_subscriber(self) -> Subscriber:
        """Return this same object, typed as a `Subscriber`."""
        return self

    def subscribe(self) -> Subscription:
        """Create a new subscription to the channel, backed by a fresh PubSub object."""
        pubsub = self._client.pubsub()
        return _RedisSubscription(pubsub=pubsub, topic=self._topic)
class _RedisSubscription(RedisSubscriptionBase):
    """Subscription built on the regular (non-sharded) Redis pub/sub commands."""

    def _get_subscription_type(self) -> str:
        """Return the label used for this subscription kind in logs and errors."""
        return "regular"

    def _subscribe(self) -> None:
        """Subscribe the PubSub object to this subscription's topic."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.subscribe(self._topic)

    def _unsubscribe(self) -> None:
        """Unsubscribe the PubSub object from this subscription's topic."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.unsubscribe(self._topic)

    def _get_message(self) -> dict | None:
        """Fetch the next message with `timeout=0.1`, ignoring subscribe confirmations."""
        pubsub = self._pubsub
        assert pubsub is not None
        message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
        return message

    def _get_message_type(self) -> str:
        """Return the `type` field value of the messages this subscription accepts."""
        return "message"

View File

@@ -0,0 +1,65 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from ._subscription import RedisSubscriptionBase
class ShardedRedisBroadcastChannel:
    """
    Redis 7.0+ Sharded Pub/Sub based broadcast channel implementation.

    Messages are delivered with "at most once" semantics using the
    SPUBLISH/SSUBSCRIBE commands, which distribute channels across Redis
    cluster nodes for better scalability.
    """

    def __init__(
        self,
        redis_client: Redis,
    ):
        """Keep `redis_client` for the topics created by this channel."""
        self._client = redis_client

    def topic(self, topic: str) -> "ShardedTopic":
        """Return a `ShardedTopic` for the channel named `topic`, sharing this channel's client."""
        client = self._client
        return ShardedTopic(client, topic)
class ShardedTopic:
    """A sharded Redis channel name paired with the client used to publish and subscribe to it."""

    def __init__(self, redis_client: Redis, topic: str):
        """Remember the client and the channel name; no Redis command is issued here."""
        self._topic = topic
        self._client = redis_client

    def as_producer(self) -> Producer:
        """Return this same object, typed as a `Producer`."""
        return self

    def publish(self, payload: bytes) -> None:
        """Send `payload` to the channel with the client's `spublish` call."""
        self._client.spublish(self._topic, payload)  # type: ignore[attr-defined]

    def as_subscriber(self) -> Subscriber:
        """Return this same object, typed as a `Subscriber`."""
        return self

    def subscribe(self) -> Subscription:
        """Create a new sharded subscription to the channel, backed by a fresh PubSub object."""
        pubsub = self._client.pubsub()
        return _RedisShardedSubscription(pubsub=pubsub, topic=self._topic)
class _RedisShardedSubscription(RedisSubscriptionBase):
    """Subscription built on the Redis 7.0+ sharded pub/sub commands."""

    def _get_subscription_type(self) -> str:
        """Return the label used for this subscription kind in logs and errors."""
        return "sharded"

    def _subscribe(self) -> None:
        """Subscribe the PubSub object to this subscription's topic with `ssubscribe`."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.ssubscribe(self._topic)  # type: ignore[attr-defined]

    def _unsubscribe(self) -> None:
        """Unsubscribe the PubSub object from this subscription's topic with `sunsubscribe`."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.sunsubscribe(self._topic)  # type: ignore[attr-defined]

    def _get_message(self) -> dict | None:
        """Fetch the next sharded message with `timeout=0.1`, ignoring subscribe confirmations."""
        pubsub = self._pubsub
        assert pubsub is not None
        message = pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1)  # type: ignore[attr-defined]
        return message

    def _get_message_type(self) -> str:
        """Return the `type` field value of the messages this subscription accepts."""
        return "smessage"