dify
This commit is contained in:
4
dify/api/libs/broadcast_channel/redis/__init__.py
Normal file
4
dify/api/libs/broadcast_channel/redis/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .channel import BroadcastChannel
|
||||
from .sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
|
||||
227
dify/api/libs/broadcast_channel/redis/_subscription.py
Normal file
227
dify/api/libs/broadcast_channel/redis/_subscription.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisSubscriptionBase(Subscription):
    """Base class for Redis pub/sub subscriptions with common functionality.

    This class provides shared functionality for both regular and sharded
    Redis pub/sub subscriptions, reducing code duplication and improving
    maintainability.

    A background listener thread pulls messages from Redis and places the
    ``bytes`` payloads into a bounded in-memory queue (1024 entries). When the
    queue is full the oldest entry is discarded to make room for the new one.
    Consumers read from that queue through iteration or ``receive()``.

    Subclasses supply the Redis-specific pieces by overriding
    ``_get_subscription_type``, ``_subscribe``, ``_unsubscribe``,
    ``_get_message`` and ``_get_message_type``.
    """

    def __init__(
        self,
        pubsub: PubSub,
        topic: str,
    ):
        """Store the PubSub handle and topic; nothing is subscribed yet.

        The actual subscription and the listener thread are created lazily by
        ``_start_if_needed``.
        """
        # The _pubsub is None only if the subscription is closed.
        self._pubsub: PubSub | None = pubsub
        self._topic = topic
        self._closed = threading.Event()
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
        self._dropped_count = 0
        self._listener_thread: threading.Thread | None = None
        # Guards the one-time transition from "not started" to "started".
        self._start_lock = threading.Lock()
        self._started = False

    def _start_if_needed(self) -> None:
        """Start the subscription if not already started.

        Raises:
            SubscriptionClosedError: if the subscription is closed or its
                PubSub handle has already been released.
        """
        with self._start_lock:
            if self._started:
                return
            if self._closed.is_set():
                raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
            if self._pubsub is None:
                raise SubscriptionClosedError(
                    f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
                )

            self._subscribe()
            _logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)

            self._listener_thread = threading.Thread(
                target=self._listen,
                name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
                daemon=True,
            )
            self._listener_thread.start()
            self._started = True

    def _listen(self) -> None:
        """Main listener loop for processing messages.

        Runs in the listener thread. Whatever way the loop ends (``close()``
        was called, Redis raised, or message handling failed unexpectedly) the
        subscription is marked closed and the PubSub handle is released.
        """
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while starting listening."
        try:
            self._pump_messages()
        finally:
            # If the listener stops on its own (e.g. a Redis connection error),
            # consumers must not keep polling a queue nobody fills anymore.
            # Setting the flag ends iteration and makes later receive() calls
            # raise SubscriptionClosedError.
            self._closed.set()
            _logger.debug(
                "%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic
            )
            self._release_pubsub(pubsub)

    def _pump_messages(self) -> None:
        """Pull messages from Redis and enqueue valid payloads until closed or Redis fails."""
        while not self._closed.is_set():
            try:
                raw_message = self._get_message()
            except Exception as e:
                # Log the exception and exit the listener thread gracefully
                # This handles Redis connection errors and other exceptions
                _logger.exception(
                    "Error getting message from Redis %s subscription, topic=%s: %s",
                    self._get_subscription_type(),
                    self._topic,
                    e,
                )
                return

            if raw_message is None:
                continue

            if raw_message.get("type") != self._get_message_type():
                continue

            # The channel may arrive as bytes or str; normalise it to str
            # before comparing with the subscribed topic.
            channel_field = raw_message.get("channel")
            if isinstance(channel_field, bytes):
                channel_name = channel_field.decode("utf-8")
            elif isinstance(channel_field, str):
                channel_name = channel_field
            else:
                channel_name = str(channel_field)

            if channel_name != self._topic:
                _logger.warning(
                    "Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
                )
                continue

            payload_bytes: bytes | None = raw_message.get("data")
            if not isinstance(payload_bytes, bytes):
                _logger.error(
                    "Received invalid data from %s channel %s, type=%s",
                    self._get_subscription_type(),
                    self._topic,
                    type(payload_bytes),
                )
                continue

            self._enqueue_message(payload_bytes)

    def _release_pubsub(self, pubsub: PubSub) -> None:
        """Unsubscribe and close ``pubsub``, then drop the reference to it.

        The two steps are guarded separately: a failing unsubscribe (likely
        when the connection is already broken) must not prevent ``close()``
        from running, otherwise the PubSub would never be released.
        """
        try:
            self._unsubscribe()
        except Exception as e:
            _logger.exception(
                "Error during cleanup of Redis %s subscription, topic=%s: %s",
                self._get_subscription_type(),
                self._topic,
                e,
            )

        try:
            pubsub.close()
            _logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
        except Exception as e:
            _logger.exception(
                "Error during cleanup of Redis %s subscription, topic=%s: %s",
                self._get_subscription_type(),
                self._topic,
                e,
            )
        finally:
            self._pubsub = None

    def _enqueue_message(self, payload: bytes) -> None:
        """Enqueue a message to the internal queue with dropping behavior.

        If the queue is full, the oldest queued message is discarded and the
        put is retried. Nothing is enqueued once the subscription is closed.
        """
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                try:
                    self._queue.get_nowait()
                except queue.Empty:
                    # A consumer emptied the queue in the meantime; just retry.
                    continue
                self._dropped_count += 1
                _logger.debug(
                    "Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
                    self._get_subscription_type(),
                    self._topic,
                    self._dropped_count,
                )

    def _message_iterator(self) -> Generator[bytes, None, None]:
        """Iterator for consuming messages from the subscription.

        Polls the queue in 0.1 second slices so that the loop notices when
        the subscription gets closed; iteration ends at that point.
        """
        while not self._closed.is_set():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        """Return an iterator over messages from the subscription.

        Raises:
            SubscriptionClosedError: if the subscription is closed.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()
        return iter(self._message_iterator())

    def receive(self, timeout: float | None = None) -> bytes | None:
        """Receive the next message from the subscription.

        Args:
            timeout: maximum number of seconds to wait. ``None`` waits until
                a message is available.

        Returns:
            The next payload, or ``None`` if no message arrived in time.

        Raises:
            SubscriptionClosedError: if the subscription is closed when the
                call is made.
        """
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()

        try:
            item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        return item

    def __enter__(self) -> Self:
        """Context manager entry point."""
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        """Context manager exit point; closes the subscription and never suppresses exceptions."""
        self.close()
        return None

    def close(self) -> None:
        """Close the subscription and clean up resources.

        Safe to call more than once, including after the listener thread has
        already stopped on its own.
        """
        self._closed.set()
        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
        # message retrieval method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        listener = self._listener_thread
        if listener is not None:
            listener.join(timeout=1.0)
            self._listener_thread = None

    # Abstract methods to be implemented by subclasses
    def _get_subscription_type(self) -> str:
        """Return the subscription type (e.g., 'regular' or 'sharded')."""
        raise NotImplementedError

    def _subscribe(self) -> None:
        """Subscribe to the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _unsubscribe(self) -> None:
        """Unsubscribe from the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _get_message(self) -> dict | None:
        """Get a message from Redis using the appropriate method."""
        raise NotImplementedError

    def _get_message_type(self) -> str:
        """Return the expected message type (e.g., 'message' or 'smessage')."""
        raise NotImplementedError
|
||||
67
dify/api/libs/broadcast_channel/redis/channel.py
Normal file
67
dify/api/libs/broadcast_channel/redis/channel.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
|
||||
class BroadcastChannel:
    """
    Regular (non-sharded) broadcast channel built on Redis Pub/Sub.

    Messages are delivered with "at most once" semantics using the Redis
    PUBLISH/SUBSCRIBE commands for real-time message delivery.

    The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
    """

    def __init__(self, redis_client: Redis):
        """Keep the Redis client that every topic of this channel will share."""
        self._client = redis_client

    def topic(self, topic: str) -> "Topic":
        """Return a `Topic` bound to this channel's client and the given topic name."""
        client = self._client
        return Topic(client, topic)
|
||||
|
||||
|
||||
class Topic:
    """A single Redis pub/sub topic that acts as both producer and subscriber."""

    def __init__(self, redis_client: Redis, topic: str):
        """Remember the Redis client and the topic name to publish/subscribe on."""
        self._topic = topic
        self._client = redis_client

    def as_producer(self) -> Producer:
        """Return this object as the producer side of the topic."""
        return self

    def publish(self, payload: bytes) -> None:
        """Publish `payload` to the topic via the client's `publish` call."""
        client = self._client
        client.publish(self._topic, payload)

    def as_subscriber(self) -> Subscriber:
        """Return this object as the subscriber side of the topic."""
        return self

    def subscribe(self) -> Subscription:
        """Create a new subscription to the topic, backed by a fresh PubSub handle."""
        pubsub = self._client.pubsub()
        return _RedisSubscription(pubsub=pubsub, topic=self._topic)
|
||||
|
||||
|
||||
class _RedisSubscription(RedisSubscriptionBase):
    """Subscription that uses the regular (non-sharded) Redis pub/sub calls."""

    def _get_subscription_type(self) -> str:
        """Label used for this subscription kind in log and error messages."""
        return "regular"

    def _subscribe(self) -> None:
        """Subscribe the PubSub handle to the topic."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.subscribe(self._topic)

    def _unsubscribe(self) -> None:
        """Unsubscribe the PubSub handle from the topic."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.unsubscribe(self._topic)

    def _get_message(self) -> dict | None:
        """Fetch the next message, waiting up to 0.1 s and skipping subscribe confirmations."""
        pubsub = self._pubsub
        assert pubsub is not None
        return pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)

    def _get_message_type(self) -> str:
        """Value of the message `type` field that the listener accepts."""
        return "message"
|
||||
65
dify/api/libs/broadcast_channel/redis/sharded_channel.py
Normal file
65
dify/api/libs/broadcast_channel/redis/sharded_channel.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
|
||||
class ShardedRedisBroadcastChannel:
    """
    Broadcast channel built on Redis 7.0+ Sharded Pub/Sub.

    Messages are delivered with "at most once" semantics using the
    SPUBLISH/SSUBSCRIBE commands, distributing channels across Redis cluster
    nodes for better scalability.

    NOTE(review): `BroadcastChannel` requires a client with
    `decode_responses=False`; presumably the same applies here - confirm.
    """

    def __init__(self, redis_client: Redis):
        """Keep the Redis client that every topic of this channel will share."""
        self._client = redis_client

    def topic(self, topic: str) -> "ShardedTopic":
        """Return a `ShardedTopic` bound to this channel's client and the given topic name."""
        client = self._client
        return ShardedTopic(client, topic)
|
||||
|
||||
|
||||
class ShardedTopic:
    """A single sharded Redis pub/sub topic that acts as both producer and subscriber."""

    def __init__(self, redis_client: Redis, topic: str):
        """Remember the Redis client and the topic name to publish/subscribe on."""
        self._topic = topic
        self._client = redis_client

    def as_producer(self) -> Producer:
        """Return this object as the producer side of the topic."""
        return self

    def publish(self, payload: bytes) -> None:
        """Publish `payload` to the topic via the client's `spublish` call."""
        client = self._client
        client.spublish(self._topic, payload)  # type: ignore[attr-defined]

    def as_subscriber(self) -> Subscriber:
        """Return this object as the subscriber side of the topic."""
        return self

    def subscribe(self) -> Subscription:
        """Create a new sharded subscription to the topic, backed by a fresh PubSub handle."""
        pubsub = self._client.pubsub()
        return _RedisShardedSubscription(pubsub=pubsub, topic=self._topic)
|
||||
|
||||
|
||||
class _RedisShardedSubscription(RedisSubscriptionBase):
    """Subscription that uses the Redis 7.0+ sharded pub/sub calls."""

    def _get_subscription_type(self) -> str:
        """Label used for this subscription kind in log and error messages."""
        return "sharded"

    def _subscribe(self) -> None:
        """Subscribe the PubSub handle to the topic with `ssubscribe`."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.ssubscribe(self._topic)  # type: ignore[attr-defined]

    def _unsubscribe(self) -> None:
        """Unsubscribe the PubSub handle from the topic with `sunsubscribe`."""
        pubsub = self._pubsub
        assert pubsub is not None
        pubsub.sunsubscribe(self._topic)  # type: ignore[attr-defined]

    def _get_message(self) -> dict | None:
        """Fetch the next sharded message, waiting up to 0.1 s and skipping subscribe confirmations."""
        pubsub = self._pubsub
        assert pubsub is not None
        return pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1)  # type: ignore[attr-defined]

    def _get_message_type(self) -> str:
        """Value of the message `type` field that the listener accepts."""
        return "smessage"
|
||||
Reference in New Issue
Block a user